[
  {
    "path": ".editorconfig",
    "content": "# http://editorconfig.org\n\nroot = true\n\n[*]\nindent_style = space\nindent_size = 4\ntrim_trailing_whitespace = true\ninsert_final_newline = true\ncharset = utf-8\nend_of_line = lf\n\n[*.bat]\nindent_style = tab\nend_of_line = crlf\n\n[LICENSE]\ninsert_final_newline = false\n\n[Makefile]\nindent_style = tab\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE.md",
    "content": "* inferno version:\n* Python version:\n* Operating System:\n\n### Description\n\nDescribe what you were trying to get done.\nTell us what happened, what went wrong, and what you expected to happen.\n\n### What I Did\n\n```\nPaste the command(s) you ran and the output.\nIf there was a crash, please include the traceback here.\n```\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nenv/\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\n*.egg-info/\n.installed.cfg\n*.egg\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*,cover\n.hypothesis/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# pyenv python configuration file\n.python-version\n"
  },
  {
    "path": ".travis.yml",
    "content": "language: python\n\ndist: xenial\n\npython:\n  - 3.7\n\nenv:\n  - PYTORCH_CONDA=\"pytorch\" TORCHVISION_CONDA=\"torchvision\" TORCHVISION_CHANNEL=pytorch\n\ninstall:\n  - wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh;\n  - bash miniconda.sh -b -p $HOME/miniconda\n  - export PATH=\"$HOME/miniconda/bin:$PATH\"\n  - conda config --set always_yes yes --set changeps1 no\n  - conda update -q conda\n  - conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION\n  - source activate test-environment\n  - conda install -c conda-forge networkx h5py scikit-image pyyaml dill tensorboardx\n  - conda install -c pytorch $PYTORCH_CONDA\n  - conda install -c $TORCHVISION_CHANNEL $TORCHVISION_CONDA\n\ndeploy:\n  provider: pypi\n  user: nasimrahaman\n  password:\n    secure: !!binary |\n      bWwzZitLcEpibHBaUWNhUVA4UUlGa2JsZDQxVkx3eFlkY1FiYWJqYkFvWm5pdDErRzlKRXZFM0hR\n      ZE15V0tIWm5JQlJRSGlveXdYNjAzQVc1UFV3ZjNBOG0zc21vK3RaZjVSYnM5aE5ySE93ajBXc1N4\n      akNHNGhOSnF6UnBDY2kwakxPeWhxaEwxQkR0empSaFdJbWVlOE81RDVPY2pSdGw1TDQ3QjhwVGor\n      TVREdlpSYTVFd2xNNXdadTJYWFVXL3ZQY0VLZE9xckFoVk5PSHpkTTh5MGM1S1lHaS9nNThVK2JO\n      OVp5RkFROVpuOEY3YmxPdzBQZnAvL202ZUkxamlKSmxhaE13UU4zV2tJRWRpNklVSTE0RUp1ck5s\n      Q28xL2kzNER0dGVkZzI0eVhULzcxRFl5Y0pZQWMrcWtoa1VVVUo4NEZKV3JjUjNqTnF5bVI3Ykty\n      cFJrR3JydjV0dUpGUnBhc2NIdEdKVUswMkdJWEJUc3JJWGg4bS9oRGtMaVJaMExBeitJQWR4b2tF\n      MzB0OWppZ0x5VXFSMmxnVmNvZERzRWZMRnJEMTBHeTJVS2FueVhlYmpsck9qK3V5S1dtZm5UTXg4\n      bGNzN09HWEZiUmo2K0ZuYTg5a00xN3poSXhzc3pSMnRGSVJwamV4a0gzZUpyZlpYY1daTFZ3QnV0\n      clUwZW10VEsxeGFmOGFjNTd3Wll1R3JXNEZJT1h2bmxoeS9pV0FMVlE4YnVFZFFjQnJ5YWFiRjUy\n      RkZvZk1SUnp3aDFhZ3Q3cUxVa0FIbXVuZ1NYQWZxMUlOTkVNYXRTcFVJUURJM3huWmNPeTNhSWFP\n      YkVpSlFHY1lrWlhXZ1Z2cVdvcktPOW53a29Hem5BSm1HRVZHYU11dDYwaGg2SGU1MVJPTll3WHc9\n  on:\n    all_branches: false\n    tags: true\n\nscript: \n  - source activate test-environment\n  - python setup.py install\n  - python -m 
unittest discover -s tests -v\n"
  },
  {
    "path": "AUTHORS.rst",
    "content": "=======\nCredits\n=======\n\nDevelopment Lead\n----------------\n\n* `Nasim Rahaman <https://github.com/nasimrahaman>`_  @ `Image Analysis and Learning Lab <https://hci.iwr.uni-heidelberg.de/mip>`_ , `Heidelberg Collaboratory for Image Processing <https://hci.iwr.uni-heidelberg.de/>`_ ,\n\n\nContributors\n------------\n\nIn no particular order,\n  *   `Steffen Wolf <https://github.com/Steffen-Wolf>`_  @ \n      `Image Analysis and Learning Lab <https://hci.iwr.uni-heidelberg.de/mip>`_ ,\n      `Heidelberg Collaboratory for Image Processing <https://hci.iwr.uni-heidelberg.de/>`_ ,\n  *   `Maurice Weiler <https://github.com/mauriceweiler>`_  @ \n      `Amsterdam Machine Learning Lab <http://amlab.science.uva.nl/>`_ ,\n      `University of Amsterdam <http://www.uva.nl/en/home>`_ ,   \n  *   `Constantin Pape <https://github.com/constantinpape>`_  @ \n      `Image Analysis and Learning Lab <https://hci.iwr.uni-heidelberg.de/mip>`_ ,\n      `Heidelberg Collaboratory for Image Processing <https://hci.iwr.uni-heidelberg.de/>`_ ,\n  *   `Sven Peter <https://github.com/svenpeter42>`_  @ \n      `Image Analysis and Learning Lab <https://hci.iwr.uni-heidelberg.de/mip>`_ ,\n      `Heidelberg Collaboratory for Image Processing <https://hci.iwr.uni-heidelberg.de/>`_ ,\n  *   `Manuel Haussmann <https://github.com/manuelhaussmann>`_  @ \n      `Image Analysis and Learning Lab <https://hci.iwr.uni-heidelberg.de/mip>`_ ,\n      `Heidelberg Collaboratory for Image Processing <https://hci.iwr.uni-heidelberg.de/>`_ ,\n  *   `Thorsten Beier <https://github.com/DerThorsten>`_  @ \n      `Image Analysis and Learning Lab <https://hci.iwr.uni-heidelberg.de/mip>`_ ,\n      `Heidelberg Collaboratory for Image Processing <https://hci.iwr.uni-heidelberg.de/>`_ ,\n  *   `Benjamin Striner <https://github.com/bstriner>`_  @ \n      `Machine Learning Department <https://www.ml.cmu.edu/>`_ ,\n      `Carnegie Mellon University <https://www.cmu.edu/>`_ ,\n"
  },
  {
    "path": "CONTRIBUTING.rst",
    "content": ".. highlight:: shell\n\n============\nContributing\n============\n\nContributions are welcome, and they are greatly appreciated! Every\nlittle bit helps, and credit will always be given.\n\nYou can contribute in many ways:\n\nTypes of Contributions\n----------------------\n\nReport Bugs\n~~~~~~~~~~~\n\nReport bugs at https://github.com/nasimrahaman/inferno/issues.\n\nIf you are reporting a bug, please include:\n\n* Your operating system name and version.\n* Any details about your local setup that might be helpful in troubleshooting.\n* Detailed steps to reproduce the bug.\n\nFix Bugs\n~~~~~~~~\n\nLook through the GitHub issues for bugs. Anything tagged with \"bug\"\nand \"help wanted\" is open to whoever wants to implement it.\n\nImplement Features\n~~~~~~~~~~~~~~~~~~\n\nLook through the GitHub issues for features. Anything tagged with \"enhancement\"\nand \"help wanted\" is open to whoever wants to implement it.\n\nWrite Documentation\n~~~~~~~~~~~~~~~~~~~\n\ninferno could always use more documentation, whether as part of the\nofficial inferno docs, in docstrings, or even on the web in blog posts,\narticles, and such.\n\nSubmit Feedback\n~~~~~~~~~~~~~~~\n\nThe best way to send feedback is to file an issue at https://github.com/nasimrahaman/inferno/issues.\n\nIf you are proposing a feature:\n\n* Explain in detail how it would work.\n* Keep the scope as narrow as possible, to make it easier to implement.\n* Remember that this is a volunteer-driven project, and that contributions\n  are welcome :)\n\nGet Started!\n------------\n\nReady to contribute? Here's how to set up `inferno` for local development.\n\n1. Fork the `inferno` repo on GitHub.\n2. Clone your fork locally::\n\n    $ git clone git@github.com:your_name_here/inferno.git\n\n3. Install your local copy into a virtualenv. 
Assuming you have virtualenvwrapper installed, this is how you set up your fork for local development::\n\n    $ mkvirtualenv inferno\n    $ cd inferno/\n    $ python setup.py develop\n\n4. Create a branch for local development::\n\n    $ git checkout -b name-of-your-bugfix-or-feature\n\n   Now you can make your changes locally.\n\n5. When you're done making changes, check that your changes pass flake8 and the tests, including testing other Python versions with tox::\n\n    $ flake8 inferno tests\n    $ python setup.py test or py.test\n    $ tox\n\n   To get flake8 and tox, just pip install them into your virtualenv.\n\n6. Commit your changes and push your branch to GitHub::\n\n    $ git add .\n    $ git commit -m \"Your detailed description of your changes.\"\n    $ git push origin name-of-your-bugfix-or-feature\n\n7. Submit a pull request through the GitHub website.\n\nPull Request Guidelines\n-----------------------\n\nBefore you submit a pull request, check that it meets these guidelines:\n\n1. The pull request should include tests.\n2. If the pull request adds functionality, the docs should be updated. Put\n   your new functionality into a function with a docstring, and add the\n   feature to the list in README.rst.\n3. The pull request should work for Python  3.5 and 3.6. Check\n   https://travis-ci.org/nasimrahaman/inferno/pull_requests\n   and make sure that the tests pass for all supported Python versions.\n\nTips\n----\n\nTo run a subset of tests::\n\n    $ python -m unittest tests.test_inferno\n\n\n\nSphinx Apidoc\n--------------\nbefore building the documentation\none needs to generate the auto-generated\nsphinxs api documentation.\nThese files need to be in the github repository.\n\n.. code:: bash\n  \n    cd docs\n    sphinx-apidoc -o inferno-apidoc ../inferno\n\n.. 
warning::\n\n    Do not make changes to `inferno/docs/inferno-apidoc` This folder is auto-generated\n    by the above mentioned command.\n\nThe following combines all the commands necessary to build the html documentation:\n\n.. code:: bash\n  \n    ./build_docs.sh\n\n"
  },
  {
    "path": "HISTORY.rst",
    "content": "=======\nHistory\n=======\n\n0.1.0 (2017-08-24)\n------------------\n\n* First early release on PyPI\n\n0.1.1 (2017-08-24)\n------------------\n\n* Version Increment\n    \n0.1.2 (2017-08-24)\n------------------\n\n* Version Increment\n\n\n0.1.3 (2017-08-24)\n------------------\n\n* Updated Documentation\n\n0.1.4 (2017-08-24)\n------------------\n\n* travis auto-deployment on pypi\n\n\n0.1.5 (2017-08-24)\n------------------\n\n* travis changes to run unittest\n\n\n0.1.6 (2017-08-24)\n------------------\n\n* travis missing packages for unittesting\n* fixed inconsistent version numbers\n\n0.1.7 (2017-08-25)\n------------------\n\n* setup.py critical bugix in install procedure\n\n\n\nCURRENT CHANGES\n-----------------\n* Flexible Unet\n"
  },
  {
    "path": "LICENSE",
    "content": "\nApache Software License 2.0\n\nCopyright (c) 2017, Inferno Developers\n\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\nhttp://www.apache.org/licenses/LICENSE-2.0\n\nUnless required by applicable law or agreed to in writing, software\ndistributed under the License is distributed on an \"AS IS\" BASIS,\nWITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\nSee the License for the specific language governing permissions and\nlimitations under the License.\n\n"
  },
  {
    "path": "MANIFEST.in",
    "content": "include AUTHORS.rst\ninclude CONTRIBUTING.rst\ninclude HISTORY.rst\ninclude LICENSE\ninclude README.rst\n\nrecursive-include tests *\nrecursive-exclude * __pycache__\nrecursive-exclude * *.py[co]\n\nrecursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif\n"
  },
  {
    "path": "Makefile",
    "content": ".PHONY: clean clean-test clean-pyc clean-build docs help\n.DEFAULT_GOAL := help\ndefine BROWSER_PYSCRIPT\nimport os, webbrowser, sys\ntry:\n\tfrom urllib import pathname2url\nexcept:\n\tfrom urllib.request import pathname2url\n\nwebbrowser.open(\"file://\" + pathname2url(os.path.abspath(sys.argv[1])))\nendef\nexport BROWSER_PYSCRIPT\n\ndefine PRINT_HELP_PYSCRIPT\nimport re, sys\n\nfor line in sys.stdin:\n\tmatch = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line)\n\tif match:\n\t\ttarget, help = match.groups()\n\t\tprint(\"%-20s %s\" % (target, help))\nendef\nexport PRINT_HELP_PYSCRIPT\nBROWSER := python -c \"$$BROWSER_PYSCRIPT\"\n\nhelp:\n\t@python -c \"$$PRINT_HELP_PYSCRIPT\" < $(MAKEFILE_LIST)\n\nclean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts\n\n\nclean-build: ## remove build artifacts\n\trm -fr build/\n\trm -fr dist/\n\trm -fr .eggs/\n\tfind . -name '*.egg-info' -exec rm -fr {} +\n\tfind . -name '*.egg' -exec rm -f {} +\n\nclean-pyc: ## remove Python file artifacts\n\tfind . -name '*.pyc' -exec rm -f {} +\n\tfind . -name '*.pyo' -exec rm -f {} +\n\tfind . -name '*~' -exec rm -f {} +\n\tfind . 
-name '__pycache__' -exec rm -fr {} +\n\nclean-test: ## remove test and coverage artifacts\n\trm -fr .tox/\n\trm -f .coverage\n\trm -fr htmlcov/\n\nlint: ## check style with flake8\n\tflake8 inferno tests\n\ntest: ## run tests quickly with the default Python\n\t\n\t\tpython setup.py test\n\ntest-all: ## run tests on every Python version with tox\n\ttox\n\ncoverage: ## check code coverage quickly with the default Python\n\tcoverage run --source inferno setup.py test\n\tcoverage report -m\n\tcoverage html\n\t$(BROWSER) htmlcov/index.html\n\ndocs: ## generate Sphinx HTML documentation, including API docs\n\trm -f docs/inferno.rst\n\trm -f docs/modules.rst\n\tsphinx-apidoc -o docs/ inferno\n\t$(MAKE) -C docs clean\n\t$(MAKE) -C docs html\n\t$(BROWSER) docs/_build/html/index.html\n\nservedocs: docs ## compile the docs watching for changes\n\twatchmedo shell-command -p '*.rst' -c '$(MAKE) -C docs html' -R -D .\n\nrelease: clean ## package and upload a release\n\tpython setup.py sdist upload\n\tpython setup.py bdist_wheel upload\n\ndist: clean ## builds source and wheel package\n\tpython setup.py sdist\n\tpython setup.py bdist_wheel\n\tls -l dist\n\ninstall: clean ## install the package to the active Python's site-packages\n\tpython setup.py install\n"
  },
  {
    "path": "README.rst",
    "content": "\n=======\nInferno\n=======\n\n.. image:: https://anaconda.org/conda-forge/inferno/badges/version.svg   \n        :target: https://anaconda.org/conda-forge/inferno\n\n.. image:: https://travis-ci.org/inferno-pytorch/inferno.svg?branch=master\n        :target: https://travis-ci.org/inferno-pytorch/inferno\n\n..\n  TODO new docs shield goes here, see https://github.com/inferno-pytorch/inferno/issues/139\n  .. image:: https://readthedocs.org/projects/inferno-pytorch/badge/?version=latest\n          :target: http://inferno-pytorch.readthedocs.io/en/latest/?badge=latest\n          :alt: Documentation Status\n\n\n.. image:: http://svgshare.com/i/2j7.svg\n\n\n\n\n\nInferno is a little library providing utilities and convenience functions/classes around \n`PyTorch <https://github.com/pytorch/pytorch>`_. \nIt's a work-in-progress, but the releases from v0.4 on should be fairly stable! \n\n\n\n* Free software: Apache Software License 2.0\n* Documentation: http://inferno-pytorch.readthedocs.io (Work in Progress).\n\n\nFeatures\n--------\n\nCurrent features include: \n  *   a basic \n      `Trainer class <https://github.com/nasimrahaman/inferno/tree/master/docs#preparing-the-trainer>`_ \n      to encapsulate the training boilerplate (iteration/epoch loops, validation and checkpoint creation),\n  *   a `graph API <https://github.com/nasimrahaman/inferno/blob/master/inferno/extensions/containers/graph.py>`_ for building models with complex architectures, powered by `networkx <https://github.com/networkx/networkx>`_. 
\n  *   `easy data-parallelism <https://github.com/nasimrahaman/inferno/tree/master/docs#using-gpus>`_ over multiple GPUs, \n  *   `a submodule <https://github.com/nasimrahaman/inferno/blob/master/inferno/extensions/initializers>`_ for `torch.nn.Module`-level parameter initialization,\n  *   `a submodule <https://github.com/nasimrahaman/inferno/blob/master/inferno/io/transform>`_ for data preprocessing / transforms,\n  *   `support <https://github.com/nasimrahaman/inferno/tree/master/docs#using-tensorboard>`_ for `Tensorboard <https://www.tensorflow.org/get_started/summaries_and_tensorboard>`_ (best with atleast `tensorflow-cpu <https://github.com/tensorflow/tensorflow>`_ installed)\n  *   `a callback API <https://github.com/nasimrahaman/inferno/tree/master/docs#setting-up-callbacks>`_ to enable flexible interaction with the trainer,\n  *   `various utility layers <https://github.com/nasimrahaman/inferno/tree/master/inferno/extensions/layers>`_ with more underway,\n  *   `a submodule <https://github.com/nasimrahaman/inferno/blob/master/inferno/io/volumetric>`_ for volumetric datasets, and more!\n\n  \n\n\n\n.. 
code:: python\n\n  import torch.nn as nn\n  from inferno.io.box.cifar import get_cifar10_loaders\n  from inferno.trainers.basic import Trainer\n  from inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger\n  from inferno.extensions.layers.convolutional import ConvELU2D\n  from inferno.extensions.layers.reshape import Flatten\n\n  # Fill these in:\n  LOG_DIRECTORY = '...'\n  SAVE_DIRECTORY = '...'\n  DATASET_DIRECTORY = '...'\n  DOWNLOAD_CIFAR = True\n  USE_CUDA = True\n\n  # Build torch model\n  model = nn.Sequential(\n      ConvELU2D(in_channels=3, out_channels=256, kernel_size=3),\n      nn.MaxPool2d(kernel_size=2, stride=2),\n      ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),\n      nn.MaxPool2d(kernel_size=2, stride=2),\n      ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),\n      nn.MaxPool2d(kernel_size=2, stride=2),\n      Flatten(),\n      nn.Linear(in_features=(256 * 4 * 4), out_features=10),\n      nn.LogSoftmax(dim=1)\n  )\n\n  # Load loaders\n  train_loader, validate_loader = get_cifar10_loaders(DATASET_DIRECTORY,\n                                                      download=DOWNLOAD_CIFAR)\n\n  # Build trainer\n  trainer = Trainer(model) \\\n    .build_criterion('NLLLoss') \\\n    .build_metric('CategoricalError') \\\n    .build_optimizer('Adam') \\\n    .validate_every((2, 'epochs')) \\\n    .save_every((5, 'epochs')) \\\n    .save_to_directory(SAVE_DIRECTORY) \\\n    .set_max_num_epochs(10) \\\n    .build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'),\n                                    log_images_every='never'),\n                  log_directory=LOG_DIRECTORY)\n\n  # Bind loaders\n  trainer \\\n      .bind_loader('train', train_loader) \\\n      .bind_loader('validate', validate_loader)\n\n  if USE_CUDA:\n    trainer.cuda()\n\n  # Go!\n  trainer.fit()\n\n\n\n\nTo visualize the training progress, navigate to `LOG_DIRECTORY` and fire up tensorboard with \n\n.. 
code:: bash\n\n  $ tensorboard --logdir=${PWD} --port=6007\n\n\nand navigate to `localhost:6007` with your browser.\n\n\n\nInstallation\n------------------------\n\nConda packages for python >= 3.6 for all distributions are availaible on conda-forge:\n\n.. code:: bash\n\n  $ conda install -c pytorch -c conda-forge inferno\n\n\n\nFuture Features: \n------------------------\nPlanned features include: \n  *   a class to encapsulate Hogwild! training over multiple GPUs, \n  *   minimal shape inference with a dry-run,\n  *   proper packaging and documentation,\n  *   cutting-edge fresh-off-the-press implementations of what the future has in store. :)\n\n\n\nCredits\n---------\nAll contributors are listed here_. \n.. _here: https://inferno-pytorch.github.io/inferno/html/authors.html\n\nThis package was partially generated with Cookiecutter_ and the `audreyr/cookiecutter-pypackage`_ project template + lots of work by Thorsten. \n\n.. _Cookiecutter: https://github.com/audreyr/cookiecutter\n.. _`audreyr/cookiecutter-pypackage`: https://github.com/audreyr/cookiecutter-pypackage\n\n"
  },
  {
    "path": "add2path.sh",
    "content": "#!/usr/bin/env bash\n# Run this script from within the directory.\nexport PYTHONPATH=${PYTHONPATH}:${PWD}"
  },
  {
    "path": "build_docs.sh",
    "content": "#!/bin/bash\ncd docs\nrm -r -f inferno-apidoc\nsphinx-apidoc -o inferno-apidoc ../inferno\nmake html\ncd .."
  },
  {
    "path": "conda-recipe/build.sh",
    "content": "PY_VER=$(python -c \"import sys; print('{}.{}'.format(*sys.version_info[:2]))\")\n\n# Install python modules\nmkdir -p ${PREFIX}/inferno\ncp -r inferno/* ${PREFIX}/inferno\necho \"${PREFIX}\" > ${PREFIX}/lib/python${PY_VER}/site-packages/inferno.pth\npython -m compileall ${PREFIX}/inferno\n"
  },
  {
    "path": "conda-recipe/meta.yaml",
    "content": "package:\n    name: inferno\n\n    {% set tagged_version = GIT_DESCRIBE_TAG|replace(\"v\",\"\")|replace(\"-\", \".\") %}\n\n    # If we're using a non-tagged revision, append '.postN' to the version\n    {% if GIT_DESCRIBE_NUMBER|int != 0 %}\n        {% set tagged_version = tagged_version + '.post' + GIT_DESCRIBE_NUMBER %}\n    {% endif %}\n\n    version: {{tagged_version}}\n\nsource:\n    path: ..\n\nbuild:\n    number: 1\n    string: py_{{PKG_BUILDNUM}}_g{{GIT_FULL_HASH[:7]}}\n\nrequirements:\n    build:\n        - python {{PY_VER}}*\n    run:\n        - python {{PY_VER}}*\n        - pytorch\n        - torchvision\n        - pyyaml\n        - scipy\n        - scikit-image\n        - scikit-learn\n        - h5py\n        - dill\n        - networkx 1.11\n        - tensorboardx\n        - sphinx_rtd_theme\n\n\ntest:\n    imports:\n        - inferno\n\nabout:\n    license: Apache License 2.0\n    summary: A utility library around PyTorch\n"
  },
  {
    "path": "docs/.gitignore",
    "content": "/inferno.rst\n/inferno.*.rst\n/modules.rst\n"
  },
  {
    "path": "docs/Makefile",
    "content": "# Makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line.\nSPHINXOPTS    =\nSPHINXBUILD   = sphinx-build\nPAPER         =\nBUILDDIR      = _build\n\n# User-friendly check for sphinx-build\nifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1)\n$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/)\nendif\n\n# Internal variables.\nPAPEROPT_a4     = -D latex_paper_size=a4\nPAPEROPT_letter = -D latex_paper_size=letter\nALLSPHINXOPTS   = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .\n# the i18n builder cannot share the environment and doctrees with the others\nI18NSPHINXOPTS  = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .\n\n.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext\n\nhelp:\n\t@echo \"Please use \\`make <target>' where <target> is one of\"\n\t@echo \"  html       to make standalone HTML files\"\n\t@echo \"  dirhtml    to make HTML files named index.html in directories\"\n\t@echo \"  singlehtml to make a single large HTML file\"\n\t@echo \"  pickle     to make pickle files\"\n\t@echo \"  json       to make JSON files\"\n\t@echo \"  htmlhelp   to make HTML files and a HTML help project\"\n\t@echo \"  qthelp     to make HTML files and a qthelp project\"\n\t@echo \"  devhelp    to make HTML files and a Devhelp project\"\n\t@echo \"  epub       to make an epub\"\n\t@echo \"  latex      to make LaTeX files, you can set PAPER=a4 or PAPER=letter\"\n\t@echo \"  latexpdf   to make LaTeX files and run them through pdflatex\"\n\t@echo \"  latexpdfja to make LaTeX files and run them through platex/dvipdfmx\"\n\t@echo \"  text    
   to make text files\"\n\t@echo \"  man        to make manual pages\"\n\t@echo \"  texinfo    to make Texinfo files\"\n\t@echo \"  info       to make Texinfo files and run them through makeinfo\"\n\t@echo \"  gettext    to make PO message catalogs\"\n\t@echo \"  changes    to make an overview of all changed/added/deprecated items\"\n\t@echo \"  xml        to make Docutils-native XML files\"\n\t@echo \"  pseudoxml  to make pseudoxml-XML files for display purposes\"\n\t@echo \"  linkcheck  to check all external links for integrity\"\n\t@echo \"  doctest    to run all doctests embedded in the documentation (if enabled)\"\n\nclean:\n\trm -rf $(BUILDDIR)/*\n\nhtml:\n\t$(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html\n\t@echo\n\t@echo \"Build finished. The HTML pages are in $(BUILDDIR)/html.\"\n\ndirhtml:\n\t$(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml\n\t@echo\n\t@echo \"Build finished. The HTML pages are in $(BUILDDIR)/dirhtml.\"\n\nsinglehtml:\n\t$(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml\n\t@echo\n\t@echo \"Build finished. 
The HTML page is in $(BUILDDIR)/singlehtml.\"\n\npickle:\n\t$(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle\n\t@echo\n\t@echo \"Build finished; now you can process the pickle files.\"\n\njson:\n\t$(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json\n\t@echo\n\t@echo \"Build finished; now you can process the JSON files.\"\n\nhtmlhelp:\n\t$(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp\n\t@echo\n\t@echo \"Build finished; now you can run HTML Help Workshop with the\" \\\n\t      \".hhp project file in $(BUILDDIR)/htmlhelp.\"\n\nqthelp:\n\t$(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp\n\t@echo\n\t@echo \"Build finished; now you can run \"qcollectiongenerator\" with the\" \\\n\t      \".qhcp project file in $(BUILDDIR)/qthelp, like this:\"\n\t@echo \"# qcollectiongenerator $(BUILDDIR)/qthelp/inferno.qhcp\"\n\t@echo \"To view the help file:\"\n\t@echo \"# assistant -collectionFile $(BUILDDIR)/qthelp/inferno.qhc\"\n\ndevhelp:\n\t$(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp\n\t@echo\n\t@echo \"Build finished.\"\n\t@echo \"To view the help file:\"\n\t@echo \"# mkdir -p $$HOME/.local/share/devhelp/inferno\"\n\t@echo \"# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/inferno\"\n\t@echo \"# devhelp\"\n\nepub:\n\t$(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub\n\t@echo\n\t@echo \"Build finished. 
The epub file is in $(BUILDDIR)/epub.\"\n\nlatex:\n\t$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex\n\t@echo\n\t@echo \"Build finished; the LaTeX files are in $(BUILDDIR)/latex.\"\n\t@echo \"Run \\`make' in that directory to run these through (pdf)latex\" \\\n\t      \"(use \\`make latexpdf' here to do that automatically).\"\n\nlatexpdf:\n\t$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex\n\t@echo \"Running LaTeX files through pdflatex...\"\n\t$(MAKE) -C $(BUILDDIR)/latex all-pdf\n\t@echo \"pdflatex finished; the PDF files are in $(BUILDDIR)/latex.\"\n\nlatexpdfja:\n\t$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex\n\t@echo \"Running LaTeX files through platex and dvipdfmx...\"\n\t$(MAKE) -C $(BUILDDIR)/latex all-pdf-ja\n\t@echo \"pdflatex finished; the PDF files are in $(BUILDDIR)/latex.\"\n\ntext:\n\t$(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text\n\t@echo\n\t@echo \"Build finished. The text files are in $(BUILDDIR)/text.\"\n\nman:\n\t$(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man\n\t@echo\n\t@echo \"Build finished. The manual pages are in $(BUILDDIR)/man.\"\n\ntexinfo:\n\t$(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo\n\t@echo\n\t@echo \"Build finished. The Texinfo files are in $(BUILDDIR)/texinfo.\"\n\t@echo \"Run \\`make' in that directory to run these through makeinfo\" \\\n\t      \"(use \\`make info' here to do that automatically).\"\n\ninfo:\n\t$(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo\n\t@echo \"Running Texinfo files through makeinfo...\"\n\tmake -C $(BUILDDIR)/texinfo info\n\t@echo \"makeinfo finished; the Info files are in $(BUILDDIR)/texinfo.\"\n\ngettext:\n\t$(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale\n\t@echo\n\t@echo \"Build finished. 
The message catalogs are in $(BUILDDIR)/locale.\"\n\nchanges:\n\t$(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes\n\t@echo\n\t@echo \"The overview file is in $(BUILDDIR)/changes.\"\n\nlinkcheck:\n\t$(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck\n\t@echo\n\t@echo \"Link check complete; look for any errors in the above output \" \\\n\t      \"or in $(BUILDDIR)/linkcheck/output.txt.\"\n\ndoctest:\n\t$(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest\n\t@echo \"Testing of doctests in the sources finished, look at the \" \\\n\t      \"results in $(BUILDDIR)/doctest/output.txt.\"\n\nxml:\n\t$(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml\n\t@echo\n\t@echo \"Build finished. The XML files are in $(BUILDDIR)/xml.\"\n\npseudoxml:\n\t$(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml\n\t@echo\n\t@echo \"Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml.\"\n"
  },
  {
    "path": "docs/_templates/layout.html",
    "content": "{# layout.html #}\n{# Import the theme's layout. #}\n{% extends \"!layout.html\" %}\n\n{% set css_files = css_files + ['_static/pygments.css'] %}"
  },
  {
    "path": "docs/_templates/template_module.rst",
    "content": "{{ fullname }}\n{{ underline }}\n\n.. automodule:: {{ fullname }}\n    \n   {% block functions %}\n   {% if functions %}\n\n   Functions\n   ==================\n\n   {% for item in functions %}\n\n   .. autofunction:: {{ item }}\n\n   .. include:: backreferences/{{fullname}}.{{item}}.examples\n\n   .. raw:: html\n\n           <div style='clear:both'></div>\n\n   {%- endfor %}\n   {% endif %}\n   {% endblock %}\n\n   {% block classes %}\n   {% if classes %}\n\n   Classes\n   -------\n\n   {% for item in classes %}\n   .. autoclass:: {{ item }}\n      :members:\n\n   .. include:: backreferences/{{fullname}}.{{item}}.examples\n   \n   .. raw:: html\n\n           <div style='clear:both'></div>\n\n\n   {%- endfor %}\n   {% endif %}\n   {% endblock %}\n\n   {% block exceptions %}\n   {% if exceptions %}\n\n   Exceptions\n   ----------\n\n   .. autosummary::\n   {% for item in exceptions %}\n      {{ item }}\n   {%- endfor %}\n   {% endif %}\n   {% endblock %}"
  },
  {
    "path": "docs/authors.rst",
    "content": ".. include:: ../AUTHORS.rst\n"
  },
  {
    "path": "docs/conf.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n#\n# inferno documentation build configuration file, created by\n# sphinx-quickstart on Tue Jul  9 22:26:36 2013.\n#\n# This file is execfile()d with the current directory set to its\n# containing dir.\n#\n# Note that not all possible configuration values are present in this\n# autogenerated file.\n#\n# All configuration values have a default; values that are commented out\n# serve to show the default.\nimport matplotlib\nmatplotlib.use('Agg')\nimport sphinx_gallery\n\n\nimport sys\n\nfrom unittest.mock import MagicMock\n\n\nclass Mock(MagicMock):\n    @classmethod\n    def __getattr__(cls, name):\n            return MagicMock()\n\n# MOCK_MODULES = ['pygtk',\n#                 'hdf5',\n#                 'skimage',\n#                 'argparse',\n#                 'pandas',\n#                 'torch',\n#                 'torch.nn', 'torch.nn.init', 'torch.nn.functional',\n#                 'torch.nn.parallel', 'torch.nn.parallel.data_parallel',\n#                 'torch.multiprocessing', 'torch.autograd',\n#                 'torch.utils', 'torch.utils.data',\n#                 'torch.optim', 'torch.sparse', 'torch.cuda']\n# sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)\n\n\n\nimport os\n\n# If extensions (or modules to document with autodoc) are in another\n# directory, add these directories to sys.path here. 
If the directory is\n# relative to the documentation root, use os.path.abspath to make it\n# absolute, like shown here.\n#sys.path.insert(0, os.path.abspath('.'))\n\n# Get the project root dir, which is the parent dir of this\ncwd = os.getcwd()\nproject_root = os.path.dirname(cwd)\n\n# Insert the project root dir as the first element in the PYTHONPATH.\n# This lets us ensure that the source package is imported, and that its\n# version is used.\nsys.path.insert(0, project_root)\n\n\nimport inferno\nimport inferno.extensions\nimport inferno.extensions.layers\nfrom inferno.extensions.layers import *\nfrom inferno.extensions.layers.reshape import *\n\n# -- General configuration ---------------------------------------------\n\n# If your documentation needs a minimal Sphinx version, state it here.\n#needs_sphinx = '1.0'\n\n# Add any Sphinx extension module names here, as strings. They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones.\nextensions = [\n    'sphinx.ext.autodoc', \n    'sphinx.ext.viewcode',\n    'sphinx.ext.autosummary',\n    'sphinx.ext.doctest', \n    'sphinx.ext.todo', \n    'sphinx.ext.ifconfig',\n    'sphinx.ext.mathjax',\n    'sphinx.ext.graphviz',\n    'sphinx_gallery.gen_gallery',\n    'sphinxcontrib.bibtex',\n    'sphinx.ext.napoleon',\n    'sphinxcontrib.inlinesyntaxhighlight'\n]\n\n\n\nsphinx_gallery_conf = {\n    # path to your examples scripts\n    'examples_dirs' :\n            '../examples',\n    # path where to save gallery generated examples\n    'gallery_dirs'  :\n            'auto_examples',\n    'backreferences_dir' :\n            'gen_modules/backreferences',\n    'scan_used_functions':\n        True,\n    'doc_module' :\n        ('inferno','inferno.extensions','inferno.extensions.layers','inferno.extensions.layers.convolutional'),\n\n    'docs_resolv': True,\n\n    'parallel_read_safe': True,\n\n    'reference_url':  {\n             # The module you locally document uses a None\n            'inferno': 
None,\n\n            # External python modules use their documentation websites\n            #'matplotlib': 'http://matplotlib.org',\n            'numpy': 'http://docs.scipy.org/doc/numpy-1.13.0'}\n}\n\n\n\n\n# Napoleon settings\nnapoleon_google_docstring = True\nnapoleon_numpy_docstring = True\nnapoleon_include_init_with_doc = False\nnapoleon_include_private_with_doc = False\nnapoleon_include_special_with_doc = True\nnapoleon_use_admonition_for_examples = False\nnapoleon_use_admonition_for_notes = False\nnapoleon_use_admonition_for_references = False\nnapoleon_use_ivar = False\nnapoleon_use_param = True\nnapoleon_use_rtype = True\n\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = ['_templates']\n\n# generate autosummary even if no references\nautosummary_generate = True\n\n\n# The suffix of source filenames.\nsource_suffix = '.rst'\n\n# The encoding of source files.\n#source_encoding = 'utf-8-sig'\n\n# The master toctree document.\nmaster_doc = 'index'\n\n# General information about the project.\nproject = u'inferno'\ncopyright = u\"2018, f\"\n\n# The version info for the project you're documenting, acts as replacement\n# for |version| and |release|, also used in various other places throughout\n# the built documents.\n#\n# The short X.Y version.\nversion = inferno.__version__\n# The full version, including alpha/beta/rc tags.\nrelease = inferno.__version__\n\n# The language for content autogenerated by Sphinx. 
Refer to documentation\n# for a list of supported languages.\n#language = None\n\n# There are two options for replacing |today|: either, you set today to\n# some non-false value, then it is used:\n#today = ''\n# Else, today_fmt is used as the format for a strftime call.\n#today_fmt = '%B %d, %Y'\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\nexclude_patterns = ['_build']\n\n# The reST default role (used for this markup: `text`) to use for all\n# documents.\n#default_role = None\n\n# If true, '()' will be appended to :func: etc. cross-reference text.\n#add_function_parentheses = True\n\n# If true, the current module name will be prepended to all description\n# unit titles (such as .. function::).\n#add_module_names = True\n\n# If true, sectionauthor and moduleauthor directives will be shown in the\n# output. They are ignored by default.\n#show_authors = False\n\n# The name of the Pygments (syntax highlighting) style to use.\npygments_style = 'sphinx'\n\n# A list of ignored prefixes for module index sorting.\n#modindex_common_prefix = []\n\n# If true, keep warnings as \"system message\" paragraphs in the built\n# documents.\n#keep_warnings = False\n\n\n# -- Options for HTML output -------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  See the documentation for\n# a list of builtin themes.\nhtml_theme = 'sphinx_rtd_theme'\n\n# Theme options are theme-specific and customize the look and feel of a\n# theme further.  For a list of options available for each theme, see the\n# documentation.\n#html_theme_options = {}\n\n# Add any paths that contain custom themes here, relative to this directory.\n#html_theme_path = []\n\n# The name for this set of Sphinx documents.  If None, it defaults to\n# \"<project> v<release> documentation\".\n#html_title = None\n\n# A shorter title for the navigation bar.  
Default is the same as\n# html_title.\n#html_short_title = None\n\n# The name of an image file (relative to this directory) to place at the\n# top of the sidebar.\n#html_logo = None\n\n# The name of an image file (within the static path) to use as favicon\n# of the docs.  This file should be a Windows icon file (.ico) being\n# 16x16 or 32x32 pixels large.\n#html_favicon = None\n\n# Add any paths that contain custom static files (such as style sheets)\n# here, relative to this directory. They are copied after the builtin\n# static files, so a file named \"default.css\" will overwrite the builtin\n# \"default.css\".\nhtml_static_path = ['_static']\n\n# If not '', a 'Last updated on:' timestamp is inserted at every page\n# bottom, using the given strftime format.\n#html_last_updated_fmt = '%b %d, %Y'\n\n# If true, SmartyPants will be used to convert quotes and dashes to\n# typographically correct entities.\n#html_use_smartypants = True\n\n# Custom sidebar templates, maps document names to template names.\n#html_sidebars = {}\n\n# Additional templates that should be rendered to pages, maps page names\n# to template names.\n#html_additional_pages = {}\n\n# If false, no module index is generated.\n#html_domain_indices = True\n\n# If false, no index is generated.\n#html_use_index = True\n\n# If true, the index is split into individual pages for each letter.\n#html_split_index = False\n\n# If true, links to the reST sources are added to the pages.\n#html_show_sourcelink = True\n\n# If true, \"Created using Sphinx\" is shown in the HTML footer.\n# Default is True.\n#html_show_sphinx = True\n\n# If true, \"(C) Copyright ...\" is shown in the HTML footer.\n# Default is True.\n#html_show_copyright = True\n\n# If true, an OpenSearch description file will be output, and all pages\n# will contain a <link> tag referring to it.  
The value of this option\n# must be the base URL from which the finished HTML is served.\n#html_use_opensearch = ''\n\n# This is the file name suffix for HTML files (e.g. \".xhtml\").\n#html_file_suffix = None\n\n# Output file base name for HTML help builder.\nhtmlhelp_basename = 'infernodoc'\n\n\n# -- Options for LaTeX output ------------------------------------------\n\nlatex_elements = {\n    # The paper size ('letterpaper' or 'a4paper').\n    #'papersize': 'letterpaper',\n\n    # The font size ('10pt', '11pt' or '12pt').\n    #'pointsize': '10pt',\n\n    # Additional stuff for the LaTeX preamble.\n    #'preamble': '',\n}\n\n# Grouping the document tree into LaTeX files. List of tuples\n# (source start file, target name, title, author, documentclass\n# [howto/manual]).\nlatex_documents = [\n    ('index', 'inferno.tex',\n     u'inferno Documentation',\n     u'Inferno Team', 'manual'),\n]\n\n# The name of an image file (relative to this directory) to place at\n# the top of the title page.\n#latex_logo = None\n\n# For \"manual\" documents, if this is true, then toplevel headings\n# are parts, not chapters.\n#latex_use_parts = False\n\n# If true, show page references after internal links.\n#latex_show_pagerefs = False\n\n# If true, show URL addresses after external links.\n#latex_show_urls = False\n\n# Documents to append as an appendix to all manuals.\n#latex_appendices = []\n\n# If false, no module index is generated.\n#latex_domain_indices = True\n\n\n# -- Options for manual page output ------------------------------------\n\n# One entry per manual page. List of tuples\n# (source start file, name, description, authors, manual section).\nman_pages = [\n    ('index', 'inferno',\n     u'inferno Documentation',\n     [u'Inferno Team'], 1)\n]\n\n# If true, show URL addresses after external links.\n#man_show_urls = False\n\n\n# -- Options for Texinfo output ----------------------------------------\n\n# Grouping the document tree into Texinfo files. 
List of tuples\n# (source start file, target name, title, author,\n#  dir menu entry, description, category)\ntexinfo_documents = [\n    ('index', 'inferno',\n     u'inferno Documentation',\n     u'Inferno Team',\n     'inferno',\n     'One line description of project.',\n     'Miscellaneous'),\n]\n\n# Documents to append as an appendix to all manuals.\n#texinfo_appendices = []\n\n# If false, no module index is generated.\n#texinfo_domain_indices = True\n\n# How to display URL addresses: 'footnote', 'no', or 'inline'.\n#texinfo_show_urls = 'footnote'\n\n# If true, do not generate a @detailmenu in the \"Top\" node's menu.\n#texinfo_no_detailmenu = False\n"
  },
  {
    "path": "docs/contributing.rst",
    "content": ".. include:: ../CONTRIBUTING.rst\n"
  },
  {
    "path": "docs/environment.yml",
    "content": "name: inferno_docs\n\nchannels:\n  - soumith\n  - anaconda\n\ndependencies:\n  - python==3.5\n  - pytorch>=0.1.12\n  - torchvision\n  - scikit-image \n  - pip:\n     - scipy>=0.13.0\n     - h5py\n     - scikit-image\n     - pyyaml\n     - dill\n     - sphinx-gallery\n     - sphinxcontrib-napoleon\n     - sphinxcontrib-bibtex\n     - sphinxcontrib-inlinesyntaxhighlight\n"
  },
  {
    "path": "docs/examples.rst",
    "content": ".. _inferno_examples_gallery:\n\nInferno Examples Gallery\n============================\n\n\n.. toctree::\n    :maxdepth: 5\n\n    ../auto_examples/index\n\n"
  },
  {
    "path": "docs/history.rst",
    "content": ".. include:: ../HISTORY.rst\n"
  },
  {
    "path": "docs/index.rst",
    "content": "Welcome to inferno's documentation!\n======================================\n\nContents:\n\n.. toctree::\n   :maxdepth: 1\n\n   readme\n   installation\n   usage\n   examples\n   contributing\n   inferno-apidoc/modules\n   authors\n   history\n   zbibliography\n\n.. automodule:: inferno\n\nIndices and tables\n==================\n\n* :ref:`genindex`\n* :ref:`modindex`\n* :ref:`search`\n"
  },
  {
    "path": "docs/inferno-apidoc/inferno.extensions.containers.rst",
    "content": "inferno.extensions.containers package\n=====================================\n\nSubmodules\n----------\n\ninferno.extensions.containers.graph module\n------------------------------------------\n\n.. automodule:: inferno.extensions.containers.graph\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.extensions.containers.sequential module\n-----------------------------------------------\n\n.. automodule:: inferno.extensions.containers.sequential\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n\nModule contents\n---------------\n\n.. automodule:: inferno.extensions.containers\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/inferno-apidoc/inferno.extensions.criteria.rst",
    "content": "inferno.extensions.criteria package\n===================================\n\nSubmodules\n----------\n\ninferno.extensions.criteria.core module\n---------------------------------------\n\n.. automodule:: inferno.extensions.criteria.core\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.extensions.criteria.elementwise\\_measures module\n--------------------------------------------------------\n\n.. automodule:: inferno.extensions.criteria.elementwise_measures\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.extensions.criteria.regularized module\n----------------------------------------------\n\n.. automodule:: inferno.extensions.criteria.regularized\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.extensions.criteria.set\\_similarity\\_measures module\n------------------------------------------------------------\n\n.. automodule:: inferno.extensions.criteria.set_similarity_measures\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n\nModule contents\n---------------\n\n.. automodule:: inferno.extensions.criteria\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/inferno-apidoc/inferno.extensions.initializers.rst",
    "content": "inferno.extensions.initializers package\n=======================================\n\nSubmodules\n----------\n\ninferno.extensions.initializers.base module\n-------------------------------------------\n\n.. automodule:: inferno.extensions.initializers.base\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.extensions.initializers.presets module\n----------------------------------------------\n\n.. automodule:: inferno.extensions.initializers.presets\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n\nModule contents\n---------------\n\n.. automodule:: inferno.extensions.initializers\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/inferno-apidoc/inferno.extensions.layers.rst",
    "content": "inferno.extensions.layers package\n=================================\n\nSubmodules\n----------\n\ninferno.extensions.layers.activations module\n--------------------------------------------\n\n.. automodule:: inferno.extensions.layers.activations\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.extensions.layers.building\\_blocks module\n-------------------------------------------------\n\n.. automodule:: inferno.extensions.layers.building_blocks\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.extensions.layers.convolutional module\n----------------------------------------------\n\n.. automodule:: inferno.extensions.layers.convolutional\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.extensions.layers.device module\n---------------------------------------\n\n.. automodule:: inferno.extensions.layers.device\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.extensions.layers.identity module\n-----------------------------------------\n\n.. automodule:: inferno.extensions.layers.identity\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.extensions.layers.prefab module\n---------------------------------------\n\n.. automodule:: inferno.extensions.layers.prefab\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.extensions.layers.res\\_unet module\n------------------------------------------\n\n.. automodule:: inferno.extensions.layers.res_unet\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.extensions.layers.reshape module\n----------------------------------------\n\n.. automodule:: inferno.extensions.layers.reshape\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.extensions.layers.sampling module\n-----------------------------------------\n\n.. 
automodule:: inferno.extensions.layers.sampling\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.extensions.layers.unet\\_base module\n-------------------------------------------\n\n.. automodule:: inferno.extensions.layers.unet_base\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n\nModule contents\n---------------\n\n.. automodule:: inferno.extensions.layers\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/inferno-apidoc/inferno.extensions.metrics.rst",
    "content": "inferno.extensions.metrics package\n==================================\n\nSubmodules\n----------\n\ninferno.extensions.metrics.arand module\n---------------------------------------\n\n.. automodule:: inferno.extensions.metrics.arand\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.extensions.metrics.base module\n--------------------------------------\n\n.. automodule:: inferno.extensions.metrics.base\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.extensions.metrics.categorical module\n---------------------------------------------\n\n.. automodule:: inferno.extensions.metrics.categorical\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.extensions.metrics.cremi\\_score module\n----------------------------------------------\n\n.. automodule:: inferno.extensions.metrics.cremi_score\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.extensions.metrics.voi module\n-------------------------------------\n\n.. automodule:: inferno.extensions.metrics.voi\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n\nModule contents\n---------------\n\n.. automodule:: inferno.extensions.metrics\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/inferno-apidoc/inferno.extensions.optimizers.rst",
    "content": "inferno.extensions.optimizers package\n=====================================\n\nSubmodules\n----------\n\ninferno.extensions.optimizers.adam module\n-----------------------------------------\n\n.. automodule:: inferno.extensions.optimizers.adam\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.extensions.optimizers.annealed\\_adam module\n---------------------------------------------------\n\n.. automodule:: inferno.extensions.optimizers.annealed_adam\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n\nModule contents\n---------------\n\n.. automodule:: inferno.extensions.optimizers\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/inferno-apidoc/inferno.extensions.rst",
    "content": "inferno.extensions package\n==========================\n\nSubpackages\n-----------\n\n.. toctree::\n\n    inferno.extensions.containers\n    inferno.extensions.criteria\n    inferno.extensions.initializers\n    inferno.extensions.layers\n    inferno.extensions.metrics\n    inferno.extensions.optimizers\n\nModule contents\n---------------\n\n.. automodule:: inferno.extensions\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/inferno-apidoc/inferno.io.box.rst",
    "content": "inferno.io.box package\n======================\n\nSubmodules\n----------\n\ninferno.io.box.binary\\_blobs module\n-----------------------------------\n\n.. automodule:: inferno.io.box.binary_blobs\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.io.box.camvid module\n----------------------------\n\n.. automodule:: inferno.io.box.camvid\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.io.box.cifar module\n---------------------------\n\n.. automodule:: inferno.io.box.cifar\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.io.box.cityscapes module\n--------------------------------\n\n.. automodule:: inferno.io.box.cityscapes\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n\nModule contents\n---------------\n\n.. automodule:: inferno.io.box\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/inferno-apidoc/inferno.io.core.rst",
    "content": "inferno.io.core package\n=======================\n\nSubmodules\n----------\n\ninferno.io.core.base module\n---------------------------\n\n.. automodule:: inferno.io.core.base\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.io.core.concatenate module\n----------------------------------\n\n.. automodule:: inferno.io.core.concatenate\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.io.core.data\\_utils module\n----------------------------------\n\n.. automodule:: inferno.io.core.data_utils\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.io.core.zip module\n--------------------------\n\n.. automodule:: inferno.io.core.zip\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n\nModule contents\n---------------\n\n.. automodule:: inferno.io.core\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/inferno-apidoc/inferno.io.rst",
    "content": "inferno.io package\n==================\n\nSubpackages\n-----------\n\n.. toctree::\n\n    inferno.io.box\n    inferno.io.core\n    inferno.io.transform\n    inferno.io.volumetric\n\nModule contents\n---------------\n\n.. automodule:: inferno.io\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/inferno-apidoc/inferno.io.transform.rst",
    "content": "inferno.io.transform package\n============================\n\nSubmodules\n----------\n\ninferno.io.transform.base module\n--------------------------------\n\n.. automodule:: inferno.io.transform.base\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.io.transform.generic module\n-----------------------------------\n\n.. automodule:: inferno.io.transform.generic\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.io.transform.image module\n---------------------------------\n\n.. automodule:: inferno.io.transform.image\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.io.transform.volume module\n----------------------------------\n\n.. automodule:: inferno.io.transform.volume\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n\nModule contents\n---------------\n\n.. automodule:: inferno.io.transform\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/inferno-apidoc/inferno.io.volumetric.rst",
    "content": "inferno.io.volumetric package\n=============================\n\nSubmodules\n----------\n\ninferno.io.volumetric.lazy\\_volume\\_loader module\n-------------------------------------------------\n\n.. automodule:: inferno.io.volumetric.lazy_volume_loader\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.io.volumetric.volume module\n-----------------------------------\n\n.. automodule:: inferno.io.volumetric.volume\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.io.volumetric.volumetric\\_utils module\n----------------------------------------------\n\n.. automodule:: inferno.io.volumetric.volumetric_utils\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n\nModule contents\n---------------\n\n.. automodule:: inferno.io.volumetric\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/inferno-apidoc/inferno.rst",
    "content": "inferno package\n===============\n\nSubpackages\n-----------\n\n.. toctree::\n\n    inferno.extensions\n    inferno.io\n    inferno.trainers\n    inferno.utils\n\nSubmodules\n----------\n\ninferno.inferno module\n----------------------\n\n.. automodule:: inferno.inferno\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n\nModule contents\n---------------\n\n.. automodule:: inferno\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/inferno-apidoc/inferno.trainers.callbacks.logging.rst",
    "content": "inferno.trainers.callbacks.logging package\n==========================================\n\nSubmodules\n----------\n\ninferno.trainers.callbacks.logging.base module\n----------------------------------------------\n\n.. automodule:: inferno.trainers.callbacks.logging.base\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.trainers.callbacks.logging.tensorboard module\n-----------------------------------------------------\n\n.. automodule:: inferno.trainers.callbacks.logging.tensorboard\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n\nModule contents\n---------------\n\n.. automodule:: inferno.trainers.callbacks.logging\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/inferno-apidoc/inferno.trainers.callbacks.rst",
    "content": "inferno.trainers.callbacks package\n==================================\n\nSubpackages\n-----------\n\n.. toctree::\n\n    inferno.trainers.callbacks.logging\n\nSubmodules\n----------\n\ninferno.trainers.callbacks.base module\n--------------------------------------\n\n.. automodule:: inferno.trainers.callbacks.base\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.trainers.callbacks.console module\n-----------------------------------------\n\n.. automodule:: inferno.trainers.callbacks.console\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.trainers.callbacks.essentials module\n--------------------------------------------\n\n.. automodule:: inferno.trainers.callbacks.essentials\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.trainers.callbacks.scheduling module\n--------------------------------------------\n\n.. automodule:: inferno.trainers.callbacks.scheduling\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.trainers.callbacks.tqdm module\n--------------------------------------\n\n.. automodule:: inferno.trainers.callbacks.tqdm\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.trainers.callbacks.tqdmstub module\n------------------------------------------\n\n.. automodule:: inferno.trainers.callbacks.tqdmstub\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n\nModule contents\n---------------\n\n.. automodule:: inferno.trainers.callbacks\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/inferno-apidoc/inferno.trainers.rst",
    "content": "inferno.trainers package\n========================\n\nSubpackages\n-----------\n\n.. toctree::\n\n    inferno.trainers.callbacks\n\nSubmodules\n----------\n\ninferno.trainers.basic module\n-----------------------------\n\n.. automodule:: inferno.trainers.basic\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n\nModule contents\n---------------\n\n.. automodule:: inferno.trainers\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/inferno-apidoc/inferno.utils.rst",
    "content": "inferno.utils package\n=====================\n\nSubmodules\n----------\n\ninferno.utils.exceptions module\n-------------------------------\n\n.. automodule:: inferno.utils.exceptions\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.utils.io\\_utils module\n------------------------------\n\n.. automodule:: inferno.utils.io_utils\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.utils.math\\_utils module\n--------------------------------\n\n.. automodule:: inferno.utils.math_utils\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.utils.model\\_utils module\n---------------------------------\n\n.. automodule:: inferno.utils.model_utils\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.utils.python\\_utils module\n----------------------------------\n\n.. automodule:: inferno.utils.python_utils\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.utils.test\\_utils module\n--------------------------------\n\n.. automodule:: inferno.utils.test_utils\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.utils.torch\\_utils module\n---------------------------------\n\n.. automodule:: inferno.utils.torch_utils\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\ninferno.utils.train\\_utils module\n---------------------------------\n\n.. automodule:: inferno.utils.train_utils\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n\nModule contents\n---------------\n\n.. automodule:: inferno.utils\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/inferno-apidoc/modules.rst",
    "content": "inferno\n=======\n\n.. toctree::\n   :maxdepth: 4\n\n   inferno\n"
  },
  {
    "path": "docs/installation.rst",
    "content": ".. highlight:: shell\n\n==================================\nInstallation\n==================================\n\nInstall on Linux and OSX\n------------------------\n\nDevelopers\n~~~~~~~~~~~~~~~~~~~~~~\n\nFirst, make sure `you have Pytorch installed <http://pytorch.org/>`_. \n\nThen, clone this repository with: \n\n.. code:: python\n\n  $ git clone https://github.com/nasimrahaman/inferno.git\n\n\nNext, install the dependencies.\n\n.. code:: python\n\n  $ cd inferno\n  $ pip install -r requirements.txt\n\n\nIf you use python from the shell: \n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nFinally, add *inferno* to your `PYTHONPATH` with:\n\n.. code:: python\n\n  source add2path.sh\n\nIf you use PyCharm:\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\nRefer to this `QA <https://askubuntu.com/questions/684550/importing-a-python-module-works-from-command-line-but-not-from-pycharm>`_ about setting up paths with Pycharm.\n\n\n\n\n\n======================================================\nInstallation via PyPi / pip / setup.py(Experimental)\n======================================================\n\nYou need to install pytorch via pip before installing\ninferno.  Follow the `pytorch installation guide`_.\n\nStable release\n--------------\n\nTo install inferno, run this command in your terminal:\n\n.. code-block:: console\n\n    $ pip install inferno-pytorch\n\nThis is the preferred method to install inferno, as it will always install the most recent stable release. \n\nIf you don't have `pip`_ installed, this `Python installation guide`_ can guide\nyou through the process.\n\n.. _pip: https://pip.pypa.io\n.. _Python installation guide: http://docs.python-guide.org/en/latest/starting/installation/\n.. _pytorch installation guide: http://pytorch.org/\n\nFrom sources\n------------------------\nFirst, make sure `you have Pytorch installed <http://pytorch.org/>`_. 
\nThe sources for inferno can be downloaded from the `Github repo`_.\nYou can either clone the public repository:\n\n.. code-block:: console\n\n    $ git clone git://github.com/nasimrahaman/inferno\n\nOr download the `tarball`_:\n\n.. code-block:: console\n\n    $ curl  -OL https://github.com/nasimrahaman/inferno/tarball/master\n\nOnce you have a copy of the source, you can install it with:\n\n.. code-block:: console\n\n    $ python setup.py install\n\n\n.. _Github repo: https://github.com/nasimrahaman/inferno\n.. _tarball: https://github.com/nasimrahaman/inferno/tarball/master\n"
  },
  {
    "path": "docs/make.bat",
    "content": "@ECHO OFF\n\nREM Command file for Sphinx documentation\n\nif \"%SPHINXBUILD%\" == \"\" (\n\tset SPHINXBUILD=sphinx-build\n)\nset BUILDDIR=_build\nset ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% .\nset I18NSPHINXOPTS=%SPHINXOPTS% .\nif NOT \"%PAPER%\" == \"\" (\n\tset ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS%\n\tset I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS%\n)\n\nif \"%1\" == \"\" goto help\n\nif \"%1\" == \"help\" (\n\t:help\n\techo.Please use `make ^<target^>` where ^<target^> is one of\n\techo.  html       to make standalone HTML files\n\techo.  dirhtml    to make HTML files named index.html in directories\n\techo.  singlehtml to make a single large HTML file\n\techo.  pickle     to make pickle files\n\techo.  json       to make JSON files\n\techo.  htmlhelp   to make HTML files and a HTML help project\n\techo.  qthelp     to make HTML files and a qthelp project\n\techo.  devhelp    to make HTML files and a Devhelp project\n\techo.  epub       to make an epub\n\techo.  latex      to make LaTeX files, you can set PAPER=a4 or PAPER=letter\n\techo.  text       to make text files\n\techo.  man        to make manual pages\n\techo.  texinfo    to make Texinfo files\n\techo.  gettext    to make PO message catalogs\n\techo.  changes    to make an overview over all changed/added/deprecated items\n\techo.  xml        to make Docutils-native XML files\n\techo.  pseudoxml  to make pseudoxml-XML files for display purposes\n\techo.  linkcheck  to check all external links for integrity\n\techo.  doctest    to run all doctests embedded in the documentation if enabled\n\tgoto end\n)\n\nif \"%1\" == \"clean\" (\n\tfor /d %%i in (%BUILDDIR%\\*) do rmdir /q /s %%i\n\tdel /q /s %BUILDDIR%\\*\n\tgoto end\n)\n\n\n%SPHINXBUILD% 2> nul\nif errorlevel 9009 (\n\techo.\n\techo.The 'sphinx-build' command was not found. 
Make sure you have Sphinx\n\techo.installed, then set the SPHINXBUILD environment variable to point\n\techo.to the full path of the 'sphinx-build' executable. Alternatively you\n\techo.may add the Sphinx directory to PATH.\n\techo.\n\techo.If you don't have Sphinx installed, grab it from\n\techo.http://sphinx-doc.org/\n\texit /b 1\n)\n\nif \"%1\" == \"html\" (\n\t%SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html\n\tif errorlevel 1 exit /b 1\n\techo.\n\techo.Build finished. The HTML pages are in %BUILDDIR%/html.\n\tgoto end\n)\n\nif \"%1\" == \"dirhtml\" (\n\t%SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml\n\tif errorlevel 1 exit /b 1\n\techo.\n\techo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml.\n\tgoto end\n)\n\nif \"%1\" == \"singlehtml\" (\n\t%SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml\n\tif errorlevel 1 exit /b 1\n\techo.\n\techo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml.\n\tgoto end\n)\n\nif \"%1\" == \"pickle\" (\n\t%SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle\n\tif errorlevel 1 exit /b 1\n\techo.\n\techo.Build finished; now you can process the pickle files.\n\tgoto end\n)\n\nif \"%1\" == \"json\" (\n\t%SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json\n\tif errorlevel 1 exit /b 1\n\techo.\n\techo.Build finished; now you can process the JSON files.\n\tgoto end\n)\n\nif \"%1\" == \"htmlhelp\" (\n\t%SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp\n\tif errorlevel 1 exit /b 1\n\techo.\n\techo.Build finished; now you can run HTML Help Workshop with the ^\n.hhp project file in %BUILDDIR%/htmlhelp.\n\tgoto end\n)\n\nif \"%1\" == \"qthelp\" (\n\t%SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp\n\tif errorlevel 1 exit /b 1\n\techo.\n\techo.Build finished; now you can run \"qcollectiongenerator\" with the ^\n.qhcp project file in %BUILDDIR%/qthelp, like this:\n\techo.^> qcollectiongenerator %BUILDDIR%\\qthelp\\inferno.qhcp\n\techo.To view the help file:\n\techo.^> 
assistant -collectionFile %BUILDDIR%\\qthelp\\inferno.ghc\n\tgoto end\n)\n\nif \"%1\" == \"devhelp\" (\n\t%SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp\n\tif errorlevel 1 exit /b 1\n\techo.\n\techo.Build finished.\n\tgoto end\n)\n\nif \"%1\" == \"epub\" (\n\t%SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub\n\tif errorlevel 1 exit /b 1\n\techo.\n\techo.Build finished. The epub file is in %BUILDDIR%/epub.\n\tgoto end\n)\n\nif \"%1\" == \"latex\" (\n\t%SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex\n\tif errorlevel 1 exit /b 1\n\techo.\n\techo.Build finished; the LaTeX files are in %BUILDDIR%/latex.\n\tgoto end\n)\n\nif \"%1\" == \"latexpdf\" (\n\t%SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex\n\tcd %BUILDDIR%/latex\n\tmake all-pdf\n\tcd %BUILDDIR%/..\n\techo.\n\techo.Build finished; the PDF files are in %BUILDDIR%/latex.\n\tgoto end\n)\n\nif \"%1\" == \"latexpdfja\" (\n\t%SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex\n\tcd %BUILDDIR%/latex\n\tmake all-pdf-ja\n\tcd %BUILDDIR%/..\n\techo.\n\techo.Build finished; the PDF files are in %BUILDDIR%/latex.\n\tgoto end\n)\n\nif \"%1\" == \"text\" (\n\t%SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text\n\tif errorlevel 1 exit /b 1\n\techo.\n\techo.Build finished. The text files are in %BUILDDIR%/text.\n\tgoto end\n)\n\nif \"%1\" == \"man\" (\n\t%SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man\n\tif errorlevel 1 exit /b 1\n\techo.\n\techo.Build finished. The manual pages are in %BUILDDIR%/man.\n\tgoto end\n)\n\nif \"%1\" == \"texinfo\" (\n\t%SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo\n\tif errorlevel 1 exit /b 1\n\techo.\n\techo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo.\n\tgoto end\n)\n\nif \"%1\" == \"gettext\" (\n\t%SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale\n\tif errorlevel 1 exit /b 1\n\techo.\n\techo.Build finished. 
The message catalogs are in %BUILDDIR%/locale.\n\tgoto end\n)\n\nif \"%1\" == \"changes\" (\n\t%SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes\n\tif errorlevel 1 exit /b 1\n\techo.\n\techo.The overview file is in %BUILDDIR%/changes.\n\tgoto end\n)\n\nif \"%1\" == \"linkcheck\" (\n\t%SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck\n\tif errorlevel 1 exit /b 1\n\techo.\n\techo.Link check complete; look for any errors in the above output ^\nor in %BUILDDIR%/linkcheck/output.txt.\n\tgoto end\n)\n\nif \"%1\" == \"doctest\" (\n\t%SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest\n\tif errorlevel 1 exit /b 1\n\techo.\n\techo.Testing of doctests in the sources finished, look at the ^\nresults in %BUILDDIR%/doctest/output.txt.\n\tgoto end\n)\n\nif \"%1\" == \"xml\" (\n\t%SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml\n\tif errorlevel 1 exit /b 1\n\techo.\n\techo.Build finished. The XML files are in %BUILDDIR%/xml.\n\tgoto end\n)\n\nif \"%1\" == \"pseudoxml\" (\n\t%SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml\n\tif errorlevel 1 exit /b 1\n\techo.\n\techo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml.\n\tgoto end\n)\n\n:end\n"
  },
  {
    "path": "docs/readme.rst",
    "content": ".. include:: ../README.rst\n"
  },
  {
    "path": "docs/refs.bib",
    "content": "\n@inproceedings{alush_2013_simbad,\ntitle={Break and Conquer: Efficient Correlation Clustering for Image Segmentation},\nauthor={Alush, Amir and Goldberger, Jacob},\nbooktitle={2nd International Workshop on Similarity-Based Pattern Analysis and Recognition},\nyear={2013}\n}\n"
  },
  {
    "path": "docs/usage.rst",
    "content": "=====\nUsage\n=====\n\n\nInferno is a utility library built around [PyTorch](http://pytorch.org/), designed to help you train and even build complex pytorch models. And in this tutorial, we'll see how! If you're new to PyTorch, I highly recommended you work through the [Pytorch tutorials](http://pytorch.org/tutorials/) first.\n\nBuilding a PyTorch Model\n~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nInferno's training machinery works with just about any valid [Pytorch module](http://pytorch.org/docs/master/nn.html#torch.nn.Module). However, to make things even easier, we also provide pre-configured layers that work out-of-the-box. Let's use them to build a convolutional neural network for Cifar-10.\n\n.. code:: python\n\n    import torch.nn as nn\n    from inferno.extensions.layers.convolutional import ConvELU2D\n    from inferno.extensions.layers.reshape import Flatten\n\n`ConvELU2D` is a 2-dimensional convolutional layer with orthogonal weight initialization and [ELU](http://pytorch.org/docs/master/nn.html#torch.nn.ELU) activation. `Flatten` reshapes the 4 dimensional activation tensor to a matrix. Let's use the Sequential container to chain together a bunch of convolutional and pooling layers, followed by a linear and softmax layer. \n\n\n.. code:: python\n\n    model = nn.Sequential(\n        ConvELU2D(in_channels=3, out_channels=256, kernel_size=3),\n        nn.MaxPool2d(kernel_size=2, stride=2),\n        ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),\n        nn.MaxPool2d(kernel_size=2, stride=2),\n        ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),\n        nn.MaxPool2d(kernel_size=2, stride=2),\n        Flatten(),\n        nn.Linear(in_features=(256 * 4 * 4), out_features=10),\n        nn.Softmax()\n    )\n\nModels this size don't win competitions anymore, but it'll do for our purpose. \n\nData Logistics \n**************************\n\nWith our model built, it's time to worry about the data generators. Or is it? \n\n.. 
code:: python\n\n    from inferno.io.box.cifar import get_cifar10_loaders\n    train_loader, validate_loader = get_cifar10_loaders('path/to/cifar10',\n                                                        download=True,\n                                                        train_batch_size=128,\n                                                        test_batch_size=100)\n\nCIFAR-10 works out-of-the-`box` (pun very much intended) with all the fancy data-augmentation and normalization. Of course, it's perfectly fine if you have your own [`DataLoader`](http://pytorch.org/docs/master/data.html#torch.utils.data.DataLoader).\n\n\nPreparing the Trainer\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nWith our model and data loaders good to go, it's finally time to build the trainer. To start, let's initialize one. \n\n.. code:: python\n\n    from inferno.trainers.basic import Trainer\n\n    trainer = Trainer(model)\n    # Tell trainer about the data loaders\n    trainer.bind_loader('train', train_loader).bind_loader('validate', validate_loader)\n\n\nNow to the things we could do with it. \n\nSetting up Checkpointing\n***************************************\nWhen training a model for days, it's usually a good idea to store the current training state to disk every once in a while. To set this up, we tell `trainer` where to store these *checkpoints* and how often. \n\n.. code:: python\n\n    trainer.save_to_directory('path/to/save/directory').save_every((25, 'epochs'))\n\nSo we're saving once every 25 epochs. But what if an epoch takes forever, and you don't wish to wait that long? \n\n.. code:: python\n\n    trainer.save_every((1000, 'iterations'))\n\nIn this setting, you're saving once every 1000 iterations (= batches). But we might also want to create a checkpoint when the validation score is the best. Easy as 1, 2,\n\n.. code:: python\n\n    trainer.save_at_best_validation_score()\n\nRemember that a checkpoint contains the entire training state, and not just the model. 
Everything is included in the checkpoint file, including optimizer, criterion, and callbacks but __not the data loaders__. \n\nSetting up Validation\n**************************\nLet's say you wish to validate once every 2 epochs.\n\n.. code:: python\n\n    trainer.validate_every((2, 'epochs'))\n\n\nTo be able to validate, you'll need to specify a validation metric.\n\n.. code:: python\n\n    trainer.build_metric('CategoricalError')\n\nInferno looks for a metric `'CategoricalError'` in `inferno.extensions.metrics`. To specify your own metric, subclass `inferno.extensions.metrics.base.Metric` and implement the `forward` method. With that done, you could:\n\n.. code:: python\n\n    trainer.build_metric(MyMetric)\n\nor \n\n.. code:: python\n\n    trainer.build_metric(MyMetric, **my_metric_kwargs)\n\n\nA metric might be way too expensive to evaluate every training iteration without slowing down the training. If this is the case and you'd like to evaluate the metric every (say) 10 *training* iterations:\n\n.. code:: python\n\n    trainer.evaluate_metric_every((10, 'iterations'))\n\nHowever, while validating, the metric is evaluated once every iteration.\n\nSetting up the Criterion and Optimizer\n***************************************\nWith that out of the way, let's set up a training criterion and an optimizer. \n\n.. code:: python\n\n    # set up the criterion\n    trainer.build_criterion('CrossEntropyLoss')\n\nThe `trainer` looks for a `'CrossEntropyLoss'` in `torch.nn`, which it finds. But any of the following would have worked: \n\n.. code:: python\n\n    trainer.build_criterion(nn.CrossEntropyLoss)\n\nor \n\n.. code:: python\n\n    trainer.build_criterion(nn.CrossEntropyLoss())\n\nWhat this means is that if you have your own loss criterion that has the same API as any of the criteria found in `torch.nn`, you should be fine by just plugging it in. \n\nThe same holds for the optimizer:\n\n.. 
code:: python\n\n    trainer.build_optimizer('Adam', weight_decay=0.0005)\n\nLike for criteria, the `trainer` looks for a `'Adam'` in `torch.optim` (among other places), and initializes it with `model`'s parameters. Any keywords you might use for `torch.optim.Adam`, you could pass them to the `build_optimizer` method. \n\nOr alternatively, you could use:\n\n.. code:: python\n\n    from torch.optim import Adam\n\n    trainer.build_optimizer(Adam, weight_decay=0.0005)\n\n\nIf you implemented your own optimizer (by subclassing `torch.optim.Optimizer`), you should be able to use it instead of `Adam`. Alternatively, if you already have an optimizer *instance*, you could do:\n\n.. code:: python\n\n    optimizer = MyOptimizer(model.parameters(), **optimizer_kwargs)\n    trainer.build_optimizer(optimizer)\n\n\nSetting up Training Duration\n********************************\nYou probably don't want to train forever, in which case you must specify: \n\n.. code:: python\n\n    trainer.set_max_num_epochs(100)\n\nor \n\n.. code:: python\n\n    trainer.set_max_num_iterations(10000)\n\n\nIf you like to train indefinitely (or until you're happy with the results), use:\n\n.. code:: python\n\n    trainer.set_max_num_iterations('inf')\n\nIn this case, you'll need to interrupt the training manually with a `KeyboardInterrupt`. \n\nSetting up Callbacks\n*********************\nCallbacks are pretty handy when it comes to interacting with the `Trainer`. More precisely: `Trainer` defines a number of events as 'triggers' for callbacks. Currently, these are: \n\n.. 
code:: python\n\n    BEGIN_OF_FIT,\n    END_OF_FIT,\n    BEGIN_OF_TRAINING_RUN,\n    END_OF_TRAINING_RUN,\n    BEGIN_OF_EPOCH,\n    END_OF_EPOCH,\n    BEGIN_OF_TRAINING_ITERATION,\n    END_OF_TRAINING_ITERATION,\n    BEGIN_OF_VALIDATION_RUN,\n    END_OF_VALIDATION_RUN,\n    BEGIN_OF_VALIDATION_ITERATION,\n    END_OF_VALIDATION_ITERATION,\n    BEGIN_OF_SAVE,\n    END_OF_SAVE\n\n\nAs an example, let's build a simple callback to interrupt the training on NaNs. We check at the end of every training iteration whether the training loss is NaN, and accordingly raise a `RuntimeError`. \n\n.. code:: python\n\n    import numpy as np\n    from inferno.trainers.callbacks.base import Callback\n\n    class NaNDetector(Callback):\n        def end_of_training_iteration(self, **_):\n            # The callback object has the trainer as an attribute. \n            # The trainer populates its 'states' with torch tensors (NOT VARIABLES!)\n            training_loss = self.trainer.get_state('training_loss')\n            # Extract float from torch tensor\n            training_loss = training_loss[0]\n            if np.isnan(training_loss):\n                raise RuntimeError(\"NaNs detected!\")\n\n\nWith the callback defined, all we need to do is register it with the trainer:\n\n.. code:: python\n\n    trainer.register_callback(NaNDetector())\n\n\nSo the next time you get `RuntimeError: \"NaNs detected!`, you know the drill. \n\nUsing Tensorboard\n**************************\n\nInferno supports logging scalars and images to Tensorboard out-of-the-box, though this requires you have at least [tensorflow-cpu](https://github.com/tensorflow/tensorflow) installed. Let's say you want to log scalars every iteration and images every 20 iterations:\n\n.. 
code:: python\n\n    from inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger\n\n    trainer.build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'),\n                                           log_images_every=(20, 'iterations')),\n                         log_directory='/path/to/log/directory')\n\n\nAfter you've started training, use a bash shell to fire up tensorboard with:\n\n.. code:: bash\n\n    $ tensorboard --logdir=/path/to/log/directory --port=6007\n    \nand navigate to `localhost:6007` with your favorite browser.\n\nFine print: missing the `log_images_every` keyword argument to `TensorboardLogger` will result in images being logged every iteration. If you don't have a fast hard drive, this might actually slow down the training. To not log images, just use `log_images_every='never'`. \n\nUsing GPUs\n*************\n\nTo use just one GPU: \n\n.. code:: python\n\n    trainer.cuda()\n\n\nFor multi-GPU data-parallel training, simply pass `trainer.cuda` a list of devices: \n\n.. code:: python\n\n    trainer.cuda(devices=[0, 1, 2, 3])\n\n\n__Pro-tip__: Say you only want to use GPUs 0, 3, 5 and 7 (your colleagues might love you for this). Before running your training script, simply: \n\n.. code:: bash\n\n    $ export CUDA_VISIBLE_DEVICES=0,3,5,7\n    $ python train.py\n\nThis maps device 0 to 0, 3 to 1, 5 to 2 and 7 to 3. \n\nOne more thing\n**************************\n\n\nOnce you have everything configured, use \n\n.. code:: python\n\n    trainer.fit()\n\nto commence training! This last step is kinda important. :wink:\n\nCherries:\n~~~~~~~~~~~~~~~~~~~~~~\n\n\nBuilding Complex Models with the Graph API\n****************************************************\n\n\n\nWork in Progress:\n\n\nParameter Initialization\n**************************\n\nWork in Progress:\n\n\nSupport\n*************\nWork in Progress:\n\n"
  },
  {
    "path": "docs/zbibliography.rst",
    "content": ".. _inferno_bibliography:\n\nBibliography\n============================\n\nThe bibliography: \n\n.. bibliography:: refs.bib\n    :style: alpha"
  },
  {
    "path": "examples/README.txt",
    "content": "\n.. _examples-index:\n\nGallery of Examples\n===================\n\n"
  },
  {
    "path": "examples/plot_cheap_unet.py",
    "content": "\"\"\"\nUNet Tutorial\n================================\nA unet example which can be run without a gpu\n\"\"\"\n\n##############################################################################\n# Preface\n# --------------\n# We start with some unspectacular multi purpose imports needed for this example\nimport matplotlib.pyplot as plt\nimport torch\nfrom torch import nn\nimport numpy\n\n\n##############################################################################\n\n# determine whether we have a gpu\n# and should use cuda\nUSE_CUDA = torch.cuda.is_available()\n\n\n##############################################################################\n# Dataset\n# --------------\n# For simplicity we will use a toy dataset where we need to perform\n# a binary segmentation task.\nfrom inferno.io.box.binary_blobs import get_binary_blob_loaders\n\n# convert labels from long to float as needed by\n# binary cross entropy loss\ndef label_transform(x):\n    return torch.from_numpy(x).float()\n#label_transform = lambda x : torch.from_numpy(x).float()\n\ntrain_loader, test_loader, validate_loader = get_binary_blob_loaders(\n    size=8, # how many images per {train,test,validate}\n    train_batch_size=2,\n    length=256, # <= size of the images\n    gaussian_noise_sigma=1.4, # <= how noise are the images\n    train_label_transform = label_transform,\n    validate_label_transform = label_transform\n)\n\nimage_channels = 1   # <-- number of channels of the image\npred_channels = 1  # <-- number of channels needed for the prediction\n\nif False:\n    ##############################################################################\n    # Visualize Dataset\n    # ~~~~~~~~~~~~~~~~~~~~~~\n    fig = plt.figure()\n\n    for i,(image, target) in enumerate(train_loader):\n        ax = fig.add_subplot(1, 2, 1)\n        ax.imshow(image[0,0,...])\n        ax.set_title('raw data')\n        ax = fig.add_subplot(1, 2, 2)\n        ax.imshow(target[0,...])\n        ax.set_title('ground 
truth')\n        break\n    fig.tight_layout()\n    plt.show()\n\n\n\n\n##############################################################################\n# Training\n# ----------------------------\n# To train the unet, we use the infernos Trainer class of inferno.\n# Since we train many models later on in this example we encapsulate\n# the training in a function (see :ref:`sphx_glr_auto_examples_trainer.py` for\n# an example dedicated to the trainer itself).\nfrom inferno.trainers import Trainer\nfrom inferno.utils.python_utils import ensure_dir\n\ndef train_model(model, loaders, **kwargs):\n\n    trainer = Trainer(model)\n    trainer.build_criterion('BCEWithLogitsLoss')\n    trainer.build_optimizer('Adam', lr=kwargs.get('lr', 0.0001))\n    #trainer.validate_every((kwargs.get('validate_every', 10), 'epochs'))\n    #trainer.save_every((kwargs.get('save_every', 10), 'epochs'))\n    #trainer.save_to_directory(ensure_dir(kwargs.get('save_dir', 'save_dor')))\n    trainer.set_max_num_epochs(kwargs.get('max_num_epochs', 20))\n\n    # bind the loaders\n    trainer.bind_loader('train', loaders[0])\n    trainer.bind_loader('validate', loaders[1])\n\n    if USE_CUDA:\n        trainer.cuda()\n\n    # do the training\n    trainer.fit()\n\n    return trainer\n\n\n\n\n##############################################################################\n# Prediction\n# ----------------------------\n# The trainer contains the trained model and we can do predictions.\n# We use :code:`unwrap` to convert the results to numpy arrays.\n# Since we want to do many prediction we encapsulate the\n# the prediction in a function\nfrom inferno.utils.torch_utils import unwrap\n\ndef predict(trainer, test_loader,  save_dir=None):\n\n\n    trainer.eval_mode()\n    for image, target in test_loader:\n\n        # transfer image to gpu\n        image = image.cuda() if USE_CUDA else image\n\n        # get batch size from image\n        batch_size = image.size()[0]\n\n        for b in range(batch_size):\n      
      prediction = trainer.apply_model(image)\n            prediction = torch.nn.functional.sigmoid(prediction)\n\n            image = unwrap(image,      as_numpy=True, to_cpu=True)\n            prediction = unwrap(prediction, as_numpy=True, to_cpu=True)\n            target = unwrap(target, as_numpy=True, to_cpu=True)\n\n            fig = plt.figure()\n\n            ax = fig.add_subplot(2, 2, 1)\n            ax.imshow(image[b,0,...])\n            ax.set_title('raw data')\n\n            ax = fig.add_subplot(2, 2, 2)\n            ax.imshow(target[b,...])\n            ax.set_title('ground truth')\n\n            ax = fig.add_subplot(2, 2, 4)\n            ax.imshow(prediction[b,...])\n            ax.set_title('prediction')\n\n            fig.tight_layout()\n            plt.show()\n\n\n\n##############################################################################\n# Custom UNet\n# ----------------------------\n# Often one needs to have a UNet with custom layers.\n# Here we show how to implement such a customized UNet.\n# To this end we derive from :code:`UNetBase`.\n# For the sake of this example we will create\n# a Unet which uses depthwise convolutions and might be trained on a CPU\nfrom inferno.extensions.models import UNetBase\nfrom inferno.extensions.layers import ConvSELU2D, ConvReLU2D, ConvELU2D, ConvSigmoid2D,Conv2D,ConvActivation\n\n\nclass CheapConv(nn.Module):\n    def __init__(self, in_channels, out_channels, activated):\n        super(CheapConv, self).__init__()\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        if activated:\n            self.convs = torch.nn.Sequential(\n                ConvActivation(in_channels=in_channels, out_channels=in_channels, depthwise=True, kernel_size=(3, 3), activation='ReLU', dim=2),\n                ConvReLU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=(1,1))\n            )\n        else:\n            self.convs = torch.nn.Sequential(\n                
ConvActivation(in_channels=in_channels, out_channels=in_channels, depthwise=True, kernel_size=(3, 3), activation='ReLU', dim=2),\n                Conv2D(in_channels=in_channels, out_channels=out_channels, kernel_size=(1,1))\n            )\n    def forward(self, x):\n        assert x.shape[1] == self.in_channels,\"input has wrong number of channels\"\n        x =  self.convs(x)\n        assert x.shape[1] == self.out_channels,\"output has wrong number of channels\"\n        return x \n\n\nclass CheapConvBlock(nn.Module):\n    def __init__(self, in_channels, out_channels, activated):\n        super(CheapConvBlock, self).__init__()\n        self.activated = activated\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        if(in_channels != out_channels):\n            self.start = ConvReLU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=(1,1))\n        else:\n            self.start = None\n        self.conv_a = CheapConv(in_channels=out_channels, out_channels=out_channels, activated=True)\n        self.conv_b = CheapConv(in_channels=out_channels, out_channels=out_channels, activated=False)\n        self.activation = torch.nn.ReLU()\n    def forward(self, x):\n        x_input = x\n        if self.start is not None:\n            x_input = self.start(x_input)\n\n        x = self.conv_a(x_input)\n        x = self.conv_b(x)\n\n        x = x + x_input\n\n        if self.activated:\n            x = self.activation(x)\n        return x\n\nclass MySimple2DCpUnet(UNetBase):\n    def __init__(self, in_channels, out_channels, depth=3, residual=False, **kwargs):\n        super(MySimple2DCpUnet, self).__init__(in_channels=in_channels, out_channels=out_channels,\n                                             dim=2, depth=depth, **kwargs)\n\n    def conv_op_factory(self, in_channels, out_channels, part, index):\n\n        # last? 
\n        last = part == 'up' and index==0\n        return CheapConvBlock(in_channels=in_channels, out_channels=out_channels, activated=not last),False\n\n\n\nfrom inferno.extensions.layers import RemoveSingletonDimension\nmodel_b = torch.nn.Sequential(\n    CheapConv(in_channels=image_channels, out_channels=4, activated=True),\n    MySimple2DCpUnet(in_channels=4, out_channels=pred_channels) ,\n    RemoveSingletonDimension(dim=1)\n)\n\n\n###################################################\n# do the training (with the same functions as before)\ntrainer = train_model(model=model_b, loaders=[train_loader, validate_loader], save_dir='model_b', lr=0.001)\n\n###################################################\n# do the training (with the same functions as before)1\npredict(trainer=trainer, test_loader=test_loader)\n\n"
  },
  {
    "path": "examples/plot_train_side_loss_unet.py",
    "content": "\"\"\"\nTrain Side Loss UNet Example\n================================\n\nIn this example a UNet with side supervision\nand auxiliary loss  implemented\n\n\"\"\"\n\n##############################################################################\n# Imports needed for this example\nimport torch\nimport torch.nn as nn\nfrom inferno.io.box.binary_blobs import get_binary_blob_loaders\nfrom inferno.trainers.basic import Trainer\n\nfrom inferno.extensions.layers.convolutional import  Conv2D\nfrom inferno.extensions.models.res_unet import _ResBlock as ResBlock\nfrom inferno.extensions.models import ResBlockUNet\nfrom inferno.utils.torch_utils import unwrap\nfrom inferno.utils.python_utils import ensure_dir\nimport pylab\n\n\n##############################################################################\n# To create a UNet with side loss we create a new nn.Module class\n# which has a ResBlockUNet as member.\n# The ResBlockUNet is configured such that the results of the\n# bottom convolution and all the results of the up-stream\n# convolutions are returned as (side)-output.\n# a 1x1 convolutions is used to give the side outputs\n# the right number of out_channels and UpSampling is\n# used to resize all side-outputs to the full resolution\n# of the input. 
These side `side-predictions` are\n# returned by our MySideLossUNet.\n# Furthermore, all  `side-predictions` are concatenated\n# and feed trough another two residual blocks to make\n# the final prediction.\nclass MySideLossUNet(nn.Module):\n    def __init__(self, in_channels, out_channels, depth=3):\n        super(MySideLossUNet, self).__init__()\n\n        self.depth = depth\n        self.unet = ResBlockUNet(in_channels=in_channels, out_channels=in_channels*2,\n                                 dim=2, unet_kwargs=dict(depth=depth),\n                                 side_out_parts=['bottom', 'up'])\n\n        # number of out channels\n        self.n_channels_per_output = self.unet.n_channels_per_output\n\n        # 1x1 conv to give the side outs of the unet\n        # the right number of channels\n        # and a Upsampling to give the right shape\n        upscale_factor = 2**self.depth\n        conv_and_scale = []\n        for n_channels in self.n_channels_per_output:\n\n            # conv blocks\n            conv = Conv2D(in_channels=n_channels, out_channels=out_channels, kernel_size=1)\n            if upscale_factor > 1:\n                upsample = nn.Upsample(scale_factor=upscale_factor)\n                conv_and_scale.append(nn.Sequential(conv, upsample))\n            else:\n                conv_and_scale.append(conv)\n\n            upscale_factor //= 2\n\n        self.conv_and_scale = nn.ModuleList(conv_and_scale)\n\n        # combined number of channels after concat\n        # concat side output predictions with main output of unet\n        self.n_channels_combined = (self.depth + 1)* out_channels + in_channels*2\n\n        self.final_block = nn.Sequential(\n            ResBlock(dim=2,in_channels=self.n_channels_combined, out_channels=self.n_channels_combined),\n            ResBlock(in_channels=self.n_channels_combined, out_channels=out_channels,\n                    dim=2, activated=False),\n        )\n\n    def forward(self, input):\n        outs = 
self.unet(input)\n        assert len(outs) == len(self.n_channels_per_output)\n\n        # convert the unet output into the right number of\n        preds = [None] * len(outs)\n        for i,out in enumerate(outs):\n            preds[i] = self.conv_and_scale[i](out)\n\n        # this is the side output\n        preds =  tuple(preds)\n\n        # concat side output predictions with main output of unet\n        combined = torch.cat(preds + (outs[-1],), 1)\n\n        final_res = self.final_block(combined)\n\n        # return everything\n        return preds + (final_res,)\n\n##############################################################################\n# We use a custom loss functions which applied CrossEntropyLoss\n# to all side outputs.\n# The side outputs are weighted in a quadratic fashion and added up\n# into a single value\nclass MySideLoss(nn.Module):\n    \"\"\"Wrap a criterion. Collect regularization losses from model and combine with wrapped criterion.\n    \"\"\"\n\n    def __init__(self):\n        super(MySideLoss, self).__init__()\n        self.criterion = nn.CrossEntropyLoss(reduce=True)\n\n        w = 1.0\n        l = None\n\n    def forward(self, predictions, target):\n        w = 1.0\n        l = None\n        for p in predictions:\n            ll = self.criterion(p, target)*w\n            if l is None:\n                l = ll\n            else:\n                l += ll\n            w *= 2\n        return l\n\n\n\n##############################################################################\n# Training boilerplate (see :ref:`sphx_glr_auto_examples_trainer.py`)\nLOG_DIRECTORY = ensure_dir('log')\nSAVE_DIRECTORY = ensure_dir('save')\nDATASET_DIRECTORY = ensure_dir('dataset')\n\n\nUSE_CUDA = torch.cuda.is_available()\n\n# Build a residual unet where the last layer is not activated\nsl_unet = MySideLossUNet(in_channels=5, out_channels=2)\n\nmodel = nn.Sequential(\n    ResBlock(dim=2, in_channels=1, out_channels=5),\n    sl_unet\n)\ntrain_loader, 
test_loader, validate_loader = get_binary_blob_loaders(\n    train_batch_size=3,\n    length=512, # <= size of the images\n    gaussian_noise_sigma=1.5 # <= how noise are the images\n)\n\n# Build trainer\ntrainer = Trainer(model)\ntrainer.build_criterion(MySideLoss())\ntrainer.build_optimizer('Adam')\ntrainer.validate_every((10, 'epochs'))\n#trainer.save_every((10, 'epochs'))\n#trainer.save_to_directory(SAVE_DIRECTORY)\ntrainer.set_max_num_epochs(40)\n\n# Bind loaders\ntrainer \\\n    .bind_loader('train', train_loader)\\\n    .bind_loader('validate', validate_loader)\n\nif USE_CUDA:\n    trainer.cuda()\n\n# Go!\ntrainer.fit()\n\n\n##############################################################################\n# Predict with the trained network\n# and visualize the results\n\n# predict:\n#trainer.load(best=True)\ntrainer.bind_loader('train', train_loader)\ntrainer.bind_loader('validate', validate_loader)\ntrainer.eval_mode()\n\nif USE_CUDA:\n    trainer.cuda()\n\n# look at an example\nfor img,target in test_loader:\n    if USE_CUDA:\n        img = img.cuda()\n\n    # softmax on each of the prediction\n    preds = trainer.apply_model(img)\n    preds = [nn.functional.softmax(pred,dim=1)        for pred in preds]\n    preds = [unwrap(pred, as_numpy=True, to_cpu=True) for pred in preds]\n    img    = unwrap(img,  as_numpy=True, to_cpu=True)\n    target  = unwrap(target, as_numpy=True, to_cpu=True)\n\n    n_plots = len(preds) + 2\n    batch_size = preds[0].shape[0]\n\n    for b in range(batch_size):\n\n        fig = pylab.figure()\n\n        ax1 = fig.add_subplot(2,4,1)\n        ax1.set_title('image')\n        ax1.imshow(img[b,0,...])\n\n        ax2 = fig.add_subplot(2,4,2)\n        ax2.set_title('ground truth')\n        ax2.imshow(target[b,...])\n\n        for i,pred in enumerate(preds):\n            axn = fig.add_subplot(2,4, 3+i)\n            axn.imshow(pred[b,1,...])\n\n            if i + 1 < len(preds):\n                axn.set_title('side prediction %d'%i)\n       
     else:\n                axn.set_title('combined prediction')\n\n        pylab.show()\n\n    break\n"
  },
  {
    "path": "examples/plot_unet_tutorial.py",
    "content": "\"\"\"\nUNet Tutorial\n================================\nA tentative tutorial on the usage\nof the unet framework in inferno\n\"\"\"\n\n##############################################################################\n# Preface\n# --------------\n# We start with some unspectacular multi purpose imports needed for this example\nimport matplotlib.pyplot as plt\nimport torch\nimport numpy\n\n##############################################################################\n\n# determine whether we have a gpu\n# and should use cuda\nUSE_CUDA = torch.cuda.is_available()\n\n\n##############################################################################\n# Dataset\n# --------------\n# For simplicity we will use a toy dataset where we need to perform\n# a binary segmentation task.\nfrom inferno.io.box.binary_blobs import get_binary_blob_loaders\n\n# convert labels from long to float as needed by\n# binary cross entropy loss\ndef label_transform(x):\n    return torch.from_numpy(x).float()\n#label_transform = lambda x : torch.from_numpy(x).float()\n\ntrain_loader, test_loader, validate_loader = get_binary_blob_loaders(\n    size=8, # how many images per {train,test,validate}\n    train_batch_size=2,\n    length=256, # <= size of the images\n    gaussian_noise_sigma=1.4, # <= how noise are the images\n    train_label_transform = label_transform,\n    validate_label_transform = label_transform\n)\n\nimage_channels = 1   # <-- number of channels of the image\npred_channels = 1  # <-- number of channels needed for the prediction\n\n##############################################################################\n# Visualize Dataset\n# ~~~~~~~~~~~~~~~~~~~~~~\nfig = plt.figure()\n\nfor i,(image, target) in enumerate(train_loader):\n    ax = fig.add_subplot(1, 2, 1)\n    ax.imshow(image[0,0,...])\n    ax.set_title('raw data')\n    ax = fig.add_subplot(1, 2, 2)\n    ax.imshow(target[0,...])\n    ax.set_title('ground truth')\n    
break\nfig.tight_layout()\nplt.show()\n\n\n##############################################################################\n# Simple UNet\n# ----------------------------\n# We start with a very simple predefined\n# res block UNet. By default, this UNet uses  ReLUs (in conjunction with batchnorm) as nonlinearities\n# With :code:`activated=False` we make sure that the last layer\n# is not activated since we chain the UNet with a sigmoid\n# activation function.\nfrom inferno.extensions.models import ResBlockUNet\nfrom inferno.extensions.layers import RemoveSingletonDimension\n\nmodel = torch.nn.Sequential(\n    ResBlockUNet(dim=2, in_channels=image_channels, out_channels=pred_channels,  activated=False),\n    RemoveSingletonDimension(dim=1),\n    torch.nn.Sigmoid()\n)\n\n##############################################################################\n# while the model above will work in principal, it has some drawbacks.\n# Within the UNet, the number of features is increased by a multiplicative\n# factor while going down, the so-called gain. 
The default value for the gain is 2.\n# Since we start with only a single channel we could either increase the gain,\n# or use a some convolutions to increase the number of channels\n# before the the UNet.\nfrom inferno.extensions.layers import ConvReLU2D\nmodel_a = torch.nn.Sequential(\n    ConvReLU2D(in_channels=image_channels, out_channels=5, kernel_size=3),\n    ResBlockUNet(dim=2, in_channels=5, out_channels=pred_channels,  activated=False,\n        res_block_kwargs=dict(batchnorm=True,size=2)) ,\n    RemoveSingletonDimension(dim=1)\n    # torch.nn.Sigmoid()\n)\n\n\n\n\n\n##############################################################################\n# Training\n# ----------------------------\n# To train the unet, we use the infernos Trainer class of inferno.\n# Since we train many models later on in this example we encapsulate\n# the training in a function (see :ref:`sphx_glr_auto_examples_trainer.py` for\n# an example dedicated to the trainer itself).\nfrom inferno.trainers import Trainer\nfrom inferno.utils.python_utils import ensure_dir\n\ndef train_model(model, loaders, **kwargs):\n\n    trainer = Trainer(model)\n    trainer.build_criterion('BCEWithLogitsLoss')\n    trainer.build_optimizer('Adam', lr=kwargs.get('lr', 0.0001))\n    #trainer.validate_every((kwargs.get('validate_every', 10), 'epochs'))\n    #trainer.save_every((kwargs.get('save_every', 10), 'epochs'))\n    #trainer.save_to_directory(ensure_dir(kwargs.get('save_dir', 'save_dor')))\n    trainer.set_max_num_epochs(kwargs.get('max_num_epochs', 200))\n\n    # bind the loaders\n    trainer.bind_loader('train', loaders[0])\n    trainer.bind_loader('validate', loaders[1])\n\n    if USE_CUDA:\n        trainer.cuda()\n\n    # do the training\n    trainer.fit()\n\n    return trainer\n\n\ntrainer = train_model(model=model_a, loaders=[train_loader, validate_loader], save_dir='model_a', lr=0.01)\n\n\n\n##############################################################################\n# Prediction\n# 
----------------------------\n# The trainer contains the trained model and we can do predictions.\n# We use :code:`unwrap` to convert the results to numpy arrays.\n# Since we want to do many prediction we encapsulate the\n# the prediction in a function\nfrom inferno.utils.torch_utils import unwrap\n\ndef predict(trainer, test_loader,  save_dir=None):\n\n\n    trainer.eval_mode()\n    for image, target in test_loader:\n\n        # transfer image to gpu\n        image = image.cuda() if USE_CUDA else image\n\n        # get batch size from image\n        batch_size = image.size()[0]\n\n        for b in range(batch_size):\n            prediction = trainer.apply_model(image)\n            prediction = torch.nn.functional.sigmoid(prediction)\n\n            image = unwrap(image,      as_numpy=True, to_cpu=True)\n            prediction = unwrap(prediction, as_numpy=True, to_cpu=True)\n            target = unwrap(target, as_numpy=True, to_cpu=True)\n\n            fig = plt.figure()\n\n            ax = fig.add_subplot(2, 2, 1)\n            ax.imshow(image[b,0,...])\n            ax.set_title('raw data')\n\n            ax = fig.add_subplot(2, 2, 2)\n            ax.imshow(target[b,...])\n            ax.set_title('ground truth')\n\n            ax = fig.add_subplot(2, 2, 4)\n            ax.imshow(prediction[b,...])\n            ax.set_title('prediction')\n\n            fig.tight_layout()\n            plt.show()\n\n###################################################\n# do the prediction\npredict(trainer=trainer, test_loader=test_loader)\n\n\n\n\n##############################################################################\n# Custom UNet\n# ----------------------------\n# Often one needs to have a UNet with custom layers.\n# Here we show how to implement such a customized UNet.\n# To this end we derive from :code:`UNetBase`.\n# For the sake of this example we will create\n# a rather exotic UNet which uses different types\n# of convolutions/non-linearities in the different 
branches\n# of the unet\nfrom inferno.extensions.models import UNetBase\nfrom inferno.extensions.layers import ConvSELU2D, ConvReLU2D, ConvELU2D, ConvSigmoid2D,Conv2D\nfrom inferno.extensions.layers.sampling import Upsample\n\nclass MySimple2DUnet(UNetBase):\n    def __init__(self, in_channels, out_channels, depth=3, **kwargs):\n        super(MySimple2DUnet, self).__init__(in_channels=in_channels, out_channels=out_channels,\n                                             dim=2, depth=depth, **kwargs)\n\n    def conv_op_factory(self, in_channels, out_channels, part, index):\n\n        if part == 'down':\n            return torch.nn.Sequential(\n                ConvELU2D(in_channels=in_channels,  out_channels=out_channels, kernel_size=3),\n                ConvELU2D(in_channels=out_channels,  out_channels=out_channels, kernel_size=3)\n            ), False\n        elif part == 'bottom':\n            return torch.nn.Sequential(\n                ConvReLU2D(in_channels=in_channels,  out_channels=out_channels, kernel_size=3),\n                ConvReLU2D(in_channels=out_channels,  out_channels=out_channels, kernel_size=3),\n            ), False\n        elif part == 'up':\n            # are we in the very last block?\n            if index  == 0:\n                return torch.nn.Sequential(\n                    ConvELU2D(in_channels=in_channels,  out_channels=out_channels, kernel_size=3),\n                    Conv2D(in_channels=out_channels,  out_channels=out_channels, kernel_size=3)\n                ), False\n            else:\n                return torch.nn.Sequential(\n                    ConvELU2D(in_channels=in_channels,   out_channels=out_channels, kernel_size=3),\n                    ConvReLU2D(in_channels=out_channels,  out_channels=out_channels, kernel_size=3)\n                ), False\n        else:\n            raise RuntimeError(\"something is wrong\")\n\n\n\n\n    # this function CAN be implemented, if not, MaxPooling is used by default\n    def 
downsample_op_factory(self, index):\n        return torch.nn.MaxPool2d(kernel_size=2, stride=2)\n\n    # this function CAN be implemented, if not, Upsampling is used by default\n    def upsample_op_factory(self, index):\n        return Upsample(mode='bilinear', align_corners=False,scale_factor=2)\n\nmodel_b = torch.nn.Sequential(\n    ConvReLU2D(in_channels=image_channels, out_channels=5, kernel_size=3),\n    MySimple2DUnet(in_channels=5, out_channels=pred_channels) ,\n    RemoveSingletonDimension(dim=1)\n)\n\n\n###################################################\n# do the training (with the same functions as before)\ntrainer = train_model(model=model_b, loaders=[train_loader, validate_loader], save_dir='model_b', lr=0.001)\n\n###################################################\n# do the training (with the same functions as before)\npredict(trainer=trainer, test_loader=test_loader)\n\n"
  },
  {
    "path": "examples/regularized_mnist.py",
    "content": "\"\"\"\nRegularized MNIST Example\n================================\n\nThis example demonstrates adding and logging arbitrary regularization losses, in this case,\nL2 activity regularization and L1 weight regularization.\n\n- Add a `_losses` dictionary to any module containing loss names and values\n- Use a criterion from `inferno.extensions.criteria.regularized` that will collect and add those losses\n- Call `Trainer.observe_training_and_validation_states` to log the losses as well\n\"\"\"\n\nimport argparse\nimport sys\n\nimport torch\nimport torch.nn as nn\nfrom torchvision import datasets, transforms\n\nfrom inferno.extensions.layers.reshape import Flatten\nfrom inferno.trainers.basic import Trainer\nfrom inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger\n\n\nclass RegularizedLinear(nn.Linear):\n    def __init__(self, *args, ar_weight=1e-3, l1_weight=1e-3, **kwargs):\n        super(RegularizedLinear, self).__init__(*args, **kwargs)\n        self.ar_weight = ar_weight\n        self.l1_weight = l1_weight\n        self._losses = {}\n\n    def forward(self, input):\n        output = super(RegularizedLinear, self).forward(input)\n        self._losses['activity_regularization'] = (output * output).sum() * self.ar_weight\n        self._losses['l1_weight_regularization'] = torch.abs(self.weight).sum() * self.l1_weight\n        return output\n\n\ndef model_fn():\n    return nn.Sequential(\n        Flatten(),\n        RegularizedLinear(in_features=784, out_features=256),\n        nn.LeakyReLU(),\n        RegularizedLinear(in_features=256, out_features=128),\n        nn.LeakyReLU(),\n        RegularizedLinear(in_features=128, out_features=10)\n    )\n\n\ndef mnist_data_loaders(args):\n    kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}\n    train_loader = torch.utils.data.DataLoader(\n        datasets.MNIST('./data', train=True, download=True,\n                       transform=transforms.Compose([\n                
           transforms.ToTensor(),\n                           transforms.Normalize((0.1307,), (0.3081,))\n                       ])),\n        batch_size=args.batch_size, shuffle=True, **kwargs)\n    test_loader = torch.utils.data.DataLoader(\n        datasets.MNIST('./data', train=False, transform=transforms.Compose([\n            transforms.ToTensor(),\n            transforms.Normalize((0.1307,), (0.3081,))\n        ])),\n        batch_size=args.test_batch_size, shuffle=True, **kwargs)\n    return train_loader, test_loader\n\n\ndef train_model(args):\n    model = model_fn()\n    train_loader, validate_loader = mnist_data_loaders(args)\n\n    # Build trainer\n    trainer = Trainer(model) \\\n        .build_criterion('RegularizedCrossEntropyLoss') \\\n        .build_metric('CategoricalError') \\\n        .build_optimizer('Adam') \\\n        .validate_every((1, 'epochs')) \\\n        .save_every((1, 'epochs')) \\\n        .save_to_directory(args.save_directory) \\\n        .set_max_num_epochs(args.epochs) \\\n        .build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'),\n                                        log_images_every='never'),\n                      log_directory=args.save_directory)\n\n    # Record regularization losses\n    trainer.logger.observe_training_and_validation_states([\n        'main_loss',\n        'total_regularization_loss',\n        'activity_regularization',\n        'l1_weight_regularization'\n    ])\n\n    # Bind loaders\n    trainer \\\n        .bind_loader('train', train_loader) \\\n        .bind_loader('validate', validate_loader)\n\n    if args.cuda:\n        trainer.cuda()\n\n    # Go!\n    trainer.fit()\n\n\ndef main(argv):\n    # Training settings\n    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')\n    parser.add_argument('--batch-size', type=int, default=64, metavar='N',\n                        help='input batch size for training (default: 64)')\n    parser.add_argument('--save-directory', 
type=str, default='output/mnist/v1',\n                        help='output directory')\n    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',\n                        help='input batch size for testing (default: 1000)')\n    parser.add_argument('--epochs', type=int, default=20, metavar='N',\n                        help='number of epochs to train (default: 20)')\n    parser.add_argument('--no-cuda', action='store_true', default=False,\n                        help='disables CUDA training')\n    args = parser.parse_args(argv)\n    args.cuda = not args.no_cuda and torch.cuda.is_available()\n    train_model(args)\n\n\nif __name__ == '__main__':\n    main(sys.argv[1:])\n"
  },
  {
    "path": "examples/trainer.py",
    "content": "\"\"\"\nTrainer Example\n================================\n\nThis example should illustrate how to use the trainer class.\n\n\"\"\"\n\nimport torch.nn as nn\nfrom inferno.io.box.cifar import get_cifar10_loaders\nfrom inferno.trainers.basic import Trainer\nfrom inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger\nfrom inferno.extensions.layers import ConvELU2D\nfrom inferno.extensions.layers import Flatten\nfrom inferno.utils.python_utils import ensure_dir\n\nfrom inferno.extensions.layers import SELU\n\n##################################################\n# change directories to your needs\nLOG_DIRECTORY = ensure_dir('log')\nSAVE_DIRECTORY = ensure_dir('save')\nDATASET_DIRECTORY = ensure_dir('dataset')\n\n##################################################\n# shall models be downloaded\nDOWNLOAD_CIFAR = True\nUSE_CUDA = True\n\n##################################################\n# Build torch model\nmodel = nn.Sequential(\n    ConvELU2D(in_channels=3, out_channels=256, kernel_size=3),\n    nn.MaxPool2d(kernel_size=2, stride=2),\n    ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),\n    nn.MaxPool2d(kernel_size=2, stride=2),\n    ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),\n    nn.MaxPool2d(kernel_size=2, stride=2),\n    Flatten(),\n    nn.Linear(in_features=(256 * 4 * 4), out_features=10),\n    nn.Softmax()\n)\n\n##################################################\n# data loaders\ntrain_loader, validate_loader = get_cifar10_loaders(DATASET_DIRECTORY,\n                                        download=DOWNLOAD_CIFAR)\n\n##################################################\n# Build trainer\ntrainer = Trainer(model)\ntrainer.build_criterion('CrossEntropyLoss')\ntrainer.build_metric('CategoricalError')\ntrainer.build_optimizer('Adam')\ntrainer.validate_every((2, 'epochs'))\ntrainer.save_every((5, 
'epochs'))\ntrainer.save_to_directory(SAVE_DIRECTORY)\ntrainer.set_max_num_epochs(10)\ntrainer.build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'),\n                                log_images_every='never'), \n              log_directory=LOG_DIRECTORY)\n\n##################################################\n# Bind loaders\ntrainer.bind_loader('train', train_loader)\ntrainer.bind_loader('validate', validate_loader)\n\n##################################################\n# activate cuda\nif USE_CUDA:\n    trainer.cuda()\n\n##################################################\n# fit\ntrainer.fit()\n"
  },
  {
    "path": "inferno/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n\n\"\"\"Top-level package for inferno.\"\"\"\n\nfrom . import extensions\nfrom . import io\nfrom . import trainers\nfrom . import utils\nfrom .version import __version__\n\n\n__all__ = ['extensions', 'io', 'trainers', 'utils']\n\n__author__ = \"\"\"Nasim Rahaman\"\"\"\n__email__ = 'nasim.rahaman@iwr.uni-heidelberg.de'\n"
  },
  {
    "path": "inferno/extensions/__init__.py",
    "content": "from . import containers\nfrom . import criteria\nfrom . import initializers\nfrom . import layers\nfrom . import metrics\nfrom . import optimizers\nfrom . import models\n# Backward support\nfrom . import models as model\n\n__all__ = ['containers', 'criteria', 'initializers', 'layers', 'metrics', 'optimizers',\n           'models', 'model']"
  },
  {
    "path": "inferno/extensions/containers/__init__.py",
    "content": "from .graph import *\nfrom .sequential import *\n"
  },
  {
    "path": "inferno/extensions/containers/graph.py",
    "content": "from collections import OrderedDict\nimport sys\nimport threading\nimport multiprocessing as mp\nimport copy\nimport gc\n\nimport networkx as nx\nfrom networkx import is_directed_acyclic_graph, topological_sort\nfrom torch import nn as nn\n\nfrom ...utils import python_utils as pyu\nfrom ...utils.exceptions import assert_\nfrom ..layers.device import OnDevice\nfrom ..layers.identity import Identity\n\n__all__ = ['NNGraph', 'Graph']\n\n\nclass NNGraph(nx.DiGraph):\n    \"\"\"A NetworkX DiGraph, except that node and edge ordering matters.\"\"\"\n    # We don't copy torch tensors, only to have them deleted.\n    ATTRIBUTES_TO_NOT_COPY = {'payload'}\n    node_dict_factory = OrderedDict\n    adjlist_dict_factory = OrderedDict\n\n    def copy(self, **init_kwargs):\n        new = type(self)(**init_kwargs)\n        # Remove all attributes and copy only the graph structure\n        for source, target in self.edges_iter():\n            # Add new nodes\n            new.add_node(source)\n            new.add_node(target)\n            # Copy attributes\n            new.node[source].update(copy.deepcopy({key: value\n                                                   for key, value in self.node[source].items()\n                                                   if key not in self.ATTRIBUTES_TO_NOT_COPY}))\n            new.node[target].update(copy.deepcopy({key: value\n                                                   for key, value in self.node[target].items()\n                                                   if key not in self.ATTRIBUTES_TO_NOT_COPY}))\n            # Add new edge\n            new.add_edge(copy.deepcopy(source), copy.deepcopy(target))\n            old_edge_attributes = self[source][target]\n            new_edge_attributes = {key: value for key, value in old_edge_attributes.items()\n                                   if key not in self.ATTRIBUTES_TO_NOT_COPY}\n            new_edge_attributes = copy.deepcopy(new_edge_attributes)\n            
new[source][target].update(new_edge_attributes)\n        return new\n\n\nclass Graph(nn.Module):\n    \"\"\"\n    A graph structure to build networks with complex architectures. The resulting graph model\n    can be used like any other `torch.nn.Module`. The graph structure used behind the scenes\n    is a `networkx.DiGraph`. This internal graph is exposed by the `apply_on_graph` method,\n    which can be used with any NetworkX function (e.g. for plotting with matplotlib or GraphViz).\n\n    Examples\n    --------\n    The naive inception module (without the max-pooling for simplicity) with ELU-layers of 64 units\n    can be built as following, (assuming 64 input channels):\n\n        >>> from inferno.extensions.layers.reshape import Concatenate\n        >>> from inferno.extensions.layers.convolutional import ConvELU2D\n        >>> import torch\n        >>> # Build the model\n        >>> inception_module = Graph()\n        >>> inception_module.add_input_node('input')\n        >>> inception_module.add_node('conv1x1', ConvELU2D(64, 64, 3), previous='input')\n        >>> inception_module.add_node('conv3x3', ConvELU2D(64, 64, 3), previous='input')\n        >>> inception_module.add_node('conv5x5', ConvELU2D(64, 64, 3), previous='input')\n        >>> inception_module.add_node('cat', Concatenate(),\n        >>>                           previous=['conv1x1', 'conv3x3', 'conv5x5'])\n        >>> inception_module.add_output_node('output', 'cat')\n        >>> # Build dummy variable\n        >>> input = torch.rand(1, 64, 100, 100)\n        >>> # Get output\n        >>> output = inception_module(input)\n\n    \"\"\"\n    def __init__(self, graph=None):\n        \"\"\"\n        Construct the graph object.\n\n        Parameters\n        ----------\n            graph : networkx.DiGraph or NNGraph\n                Graph to build the object from (optional).\n        \"\"\"\n        super(Graph, self).__init__()\n        # Privates\n        self._thread_to_graph_mapping = {}\n        
self._creator_thread = threading.get_ident()\n        self._creator_pid = mp.current_process().pid\n        # Publics\n        if graph is not None:\n            self.graph = graph\n        else:\n            self.graph = NNGraph()\n\n    @property\n    def graph(self):\n        # `graph` needs to be different for every thread, because torch.nn.parallel.replicate does\n        # not make a copy.\n        graph = self._thread_to_graph_mapping.get(threading.get_ident())\n        if graph is None:\n            creator_thread_graph = self._thread_to_graph_mapping.get(self._creator_thread)\n            assert creator_thread_graph is not None\n            graph = creator_thread_graph.copy()\n            # We don't need to clear payloads because the copy method of NNGraph copies only the\n            # graph structure and not the attributes\n            self._thread_to_graph_mapping.update({threading.get_ident(): graph})\n        return graph\n\n    @graph.setter\n    def graph(self, value):\n        assert_(isinstance(value, NNGraph), exception_type=TypeError)\n        self._thread_to_graph_mapping.update({threading.get_ident(): value})\n\n    def is_node_in_graph(self, name):\n        \"\"\"\n        Checks whether a node is in the graph.\n\n        Parameters\n        ----------\n        name : str\n            Name of the node.\n\n        Returns\n        -------\n        bool\n        \"\"\"\n        return name in self.graph.nodes\n\n    def is_source_node(self, name):\n        \"\"\"\n        Checks whether a given node (by name) is a source node.\n        A source node has no incoming edges.\n\n        Parameters\n        ----------\n        name : str\n            Name of the node.\n\n        Returns\n        -------\n        bool\n\n        Raises\n        ------\n        AssertionError\n            if node is not found in the graph.\n        \"\"\"\n        assert self.is_node_in_graph(name)\n        return self.graph.in_degree(name) == 0\n\n    def 
is_sink_node(self, name):\n        \"\"\"\n        Checks whether a given node (by name) is a sink node.\n        A sink node has no outgoing edges.\n\n        Parameters\n        ----------\n        name : str\n            Name of the node.\n\n        Returns\n        -------\n        bool\n\n        Raises\n        ------\n        AssertionError\n            if node is not found in the graph.\n        \"\"\"\n        assert self.is_node_in_graph(name)\n        return self.graph.out_degree(name) == 0\n\n    @property\n    def output_nodes(self):\n        \"\"\"\n        Gets a list of output nodes. The order is relevant and is the same as that\n        in which the forward method returns its outputs.\n\n        Returns\n        -------\n        list\n            A list of names (str) of the output nodes.\n        \"\"\"\n        return [name for name, node_attributes in self.graph.nodes.items()\n                if node_attributes.get('is_output_node', False)]\n\n    @property\n    def input_nodes(self):\n        \"\"\"\n        Gets a list of input nodes. 
The order is relevant and is the same as that\n        in which the forward method accepts its inputs.\n\n        Returns\n        -------\n        list\n            A list of names (str) of the input nodes.\n        \"\"\"\n        return [name for name, node_attributes in self.graph.nodes.items()\n                if node_attributes.get('is_input_node', False)]\n\n    @property\n    def graph_is_valid(self):\n        \"\"\"Checks if the graph is valid.\"\"\"\n        # Check if the graph is a DAG\n        is_dag = is_directed_acyclic_graph(self.graph)\n        # Check if output nodes are sinks\n        output_nodes_are_sinks = all([self.is_sink_node(name) for name in self.output_nodes])\n        # Check inf input nodes are sources\n        input_nodes_are_sources = all([self.is_source_node(name) for name in self.input_nodes])\n        # TODO Check whether only input nodes are sources and only output nodes are sinks\n        # Conclude\n        is_valid = is_dag and output_nodes_are_sinks and input_nodes_are_sources\n        return is_valid\n\n    def assert_graph_is_valid(self):\n        \"\"\"Asserts that the graph is valid.\"\"\"\n        assert is_directed_acyclic_graph(self.graph), \"Graph is not a DAG.\"\n        for name in self.output_nodes:\n            assert self.is_sink_node(name), \"Output node {} is not a sink.\".format(name)\n            assert not self.is_source_node(name), \"Output node {} is a source node. \" \\\n                                                  \"Make sure it's connected.\".format(name)\n        for name in self.input_nodes:\n            assert self.is_source_node(name), \"Input node {} is not a source.\".format(name)\n            assert not self.is_sink_node(name), \"Input node {} is a sink node. 
\" \\\n                                                \"Make sure it's connected.\".format(name)\n\n    def add_node(self, name, module, previous=None):\n        \"\"\"\n        Add a node to the graph.\n\n        Parameters\n        ----------\n        name : str\n            Name of the node. Nodes are identified by their names.\n\n        module : torch.nn.Module\n            Torch module for this node.\n\n        previous : str or list of str\n            (List of) name(s) of the previous node(s).\n\n        Returns\n        -------\n        Graph\n            self\n        \"\"\"\n        assert isinstance(module, nn.Module)\n        self.add_module(name, module)\n        self.graph.add_node(name)\n        if previous is not None:\n            for _previous in pyu.to_iterable(previous):\n                self.add_edge(_previous, name)\n        return self\n\n    def add_input_node(self, name):\n        \"\"\"\n        Add an input to the graph. The order in which input nodes are added is the\n        order in which the forward method accepts its inputs.\n\n        Parameters\n        ----------\n        name : str\n            Name of the input node.\n\n        Returns\n        -------\n        Graph\n            self\n        \"\"\"\n        self.add_module(name, Identity())\n        self.graph.add_node(name, is_input_node=True)\n        return self\n\n    def add_output_node(self, name, previous=None):\n        \"\"\"\n        Add an output to the graph. 
The order in which output nodes are added is the\n        order in which the forward method returns its outputs.\n\n        Parameters\n        ----------\n        name : str\n            Name of the output node.\n\n        Returns\n        -------\n        Graph\n            self\n        \"\"\"\n        self.graph.add_node(name, is_output_node=True)\n        if previous is not None:\n            for _previous in pyu.to_iterable(previous):\n                self.add_edge(_previous, name)\n        return self\n\n    def add_edge(self, from_node, to_node):\n        \"\"\"\n        Add an edge between two nodes.\n\n        Parameters\n        ----------\n        from_node : str\n            Name of the source node.\n        to_node : str\n            Name of the target node.\n\n        Returns\n        -------\n        Graph\n            self\n\n        Raises\n        ------\n        AssertionError\n            if either of the two nodes is not in the graph,\n            or if the edge is not 'legal'.\n        \"\"\"\n        assert self.is_node_in_graph(from_node)\n        assert self.is_node_in_graph(to_node)\n        self.graph.add_edge(from_node, to_node)\n        assert self.graph_is_valid\n        return self\n\n    def apply_on_graph(self, function, *args, **kwargs):\n        \"\"\"Applies a `function` on the internal graph.\"\"\"\n        return function(self, *args, **kwargs)\n\n    def get_module_for_nodes(self, names):\n        \"\"\"\n        Gets the `torch.nn.Module` object for nodes corresponding to `names`.\n\n        Parameters\n        ----------\n        names : str or list of str\n            Names of the nodes to fetch the modules of.\n\n        Returns\n        -------\n        list or torch.nn.Module\n            Module or a list of modules corresponding to `names`.\n\n        \"\"\"\n        names = pyu.to_iterable(names)\n        modules = []\n        for name in names:\n            assert self.is_node_in_graph(name), \"Node '{}' is not in 
graph.\".format(name)\n            module = getattr(self, name, None)\n            assert module is not None, \"Node '{}' is in the graph but could not find a module \" \\\n                                       \"corresponding to it.\".format(name)\n            modules.append(module)\n        return pyu.from_iterable(modules)\n\n    def to_device(self, names, target_device, device_ordinal=None, asynchronous=False):\n        \"\"\"Transfer nodes in the network to a specified device.\"\"\"\n        names = pyu.to_iterable(names)\n        for name in names:\n            assert self.is_node_in_graph(name), \"Node '{}' is not in graph.\".format(name)\n            module = getattr(self, name, None)\n            assert module is not None, \"Node '{}' is in the graph but could not find a module \" \\\n                                       \"corresponding to it.\".format(name)\n            # Transfer\n            module_on_device = OnDevice(module, target_device,\n                                        device_ordinal=device_ordinal,\n                                        asynchronous=asynchronous)\n            setattr(self, name, module_on_device)\n        return self\n\n    def get_parameters_for_nodes(self, names, named=False):\n        \"\"\"Get parameters of all nodes listed in `names`.\"\"\"\n        if not named:\n            parameters = (parameter\n                          for module in pyu.to_iterable(self.get_module_for_nodes(names))\n                          for parameter in module.parameters())\n        else:\n            parameters = ((name, parameter)\n                          for module in pyu.to_iterable(self.get_module_for_nodes(names))\n                          for name, parameter in module.named_parameters())\n        return parameters\n\n    def clear_payloads(self, graph=None):\n        graph = self.graph if graph is None else graph\n        for edge in list(graph.edges(data=True)):\n            source, target, _ = edge\n            if 
'payload' in graph[source][target]:\n                del graph[source][target]['payload']\n\n    def forward_through_node(self, name, input=None):\n        # If input is a tuple/list, it will NOT be unpacked.\n        # Make sure the node is in the graph\n        if input is None:\n            # Make sure the node is not a source node\n            assert not self.is_source_node(name), \\\n                \"Node '{}' did not get an input but is a source node.\".format(name)\n            # Get input from payload\n            incoming_edges = self.graph.in_edges(name)\n            input = []\n            for incoming, this in incoming_edges:\n                # Append to input\n                input.append(self.graph[incoming][this]['payload'])\n                # Clear reference for the garbage collector to do its thing\n                del self.graph[incoming][this]['payload']\n        else:\n            assert self.is_node_in_graph(name)\n            # Convert input to list\n            input = [input]\n        # Get outputs\n        try:\n            outputs = pyu.to_iterable(getattr(self, name)(*input))\n        except Exception as e:\n            input_spec_string = \"\\n\".join([\"--[{}]-{}-->[{}]\".format(incoming,\n                                                                     tuple(_input.size()),\n                                                                     this)\n                                           for (incoming, this), _input in\n                                           zip(self.graph.in_edges(name), input)])\n\n            message = \"In node '{}': {}\\n\" \\\n                      \"Inputs to this node were:\\n{}\"\\\n                .format(name, str(e), input_spec_string)\n            raise type(e)(message).with_traceback(sys.exc_info()[2])\n        # Distribute outputs to outgoing payloads if required\n        if not self.is_sink_node(name):\n            outgoing_edges = self.graph.out_edges(name)\n            if len(outputs) == 
1:\n                # Support for replication\n                outputs *= len(outgoing_edges)\n            # Make sure the number of outputs check out\n            assert len(outputs) == len(outgoing_edges), \\\n                \"Number of outputs from the model ({}) does not match the number \" \\\n                \"of out-edges ({}) in the graph for this node ('{}').\".format(len(outputs),\n                                                                              len(outgoing_edges),\n                                                                              name)\n            for (this, outgoing), output in zip(outgoing_edges, outputs):\n                self.graph[this][outgoing].update({'payload': output})\n        # Collect garbage to free some GPU memory?\n        del input\n        gc.collect()\n        # Return outputs\n        return pyu.from_iterable(outputs)\n\n    def forward(self, *inputs):\n        self.assert_graph_is_valid()\n        input_nodes = self.input_nodes\n        output_nodes = self.output_nodes\n        assert len(inputs) == len(input_nodes), \"Was expecting {} \" \\\n                                                \"arguments for as many input nodes, got {}.\"\\\n            .format(len(input_nodes), len(inputs))\n        # Unpack inputs to input nodes\n        for input, input_node in zip(inputs, input_nodes):\n            self.forward_through_node(input_node, input=input)\n        # Toposort the graph\n        toposorted = topological_sort(self.graph)\n        # Remove all input and output nodes\n        toposorted = [name for name in toposorted\n                      if name not in input_nodes and name not in output_nodes]\n        # Since we'll be clearing payloads anyway, it makes no sense whatsoever\n        # to evaluate sink nodes\n        toposorted = [name for name in toposorted if not self.is_sink_node(name)]\n        # Forward\n        for node in toposorted:\n            self.forward_through_node(node)\n        # Read 
outputs from output nodes\n        outputs = []\n        for output_node in output_nodes:\n            # Get all incoming edges to output node\n            outputs_from_node = [self.graph[incoming][this]['payload']\n                                 for incoming, this in self.graph.in_edges(output_node)]\n            outputs.append(pyu.from_iterable(outputs_from_node))\n        # Clear payloads for next pass\n        self.clear_payloads()\n        # Done.\n        return pyu.from_iterable(outputs)\n"
  },
  {
    "path": "inferno/extensions/containers/sequential.py",
    "content": "import torch.nn as nn\nfrom ...utils import python_utils as pyu\n\n\n__all__ = ['Sequential1', 'Sequential2']\n\n\nclass Sequential1(nn.Sequential):\n    \"\"\"Like torch.nn.Sequential, but with a few extra methods.\"\"\"\n    def __len__(self):\n        return len(self._modules.values())\n\n\nclass Sequential2(Sequential1):\n    \"\"\"Another sequential container.\n    Identical to torch.nn.Sequential, except that modules may return multiple outputs and\n    accept multiple inputs.\n    \"\"\"\n    def forward(self, *input):\n        for module in self._modules.values():\n            input = pyu.to_iterable(module(*pyu.to_iterable(input)))\n        return pyu.from_iterable(input)\n"
  },
  {
    "path": "inferno/extensions/criteria/__init__.py",
    "content": "from .set_similarity_measures import *\nfrom .elementwise_measures import *\nfrom .core import *\nfrom .regularized import *\n\n__all__ = ['set_similarity_measures', 'elementwise_measures','core','regularized']"
  },
  {
    "path": "inferno/extensions/criteria/core.py",
    "content": "import torch.nn as nn\nfrom functools import reduce\nfrom ...utils.exceptions import assert_, ShapeError, NotTorchModuleError\n\n\n__all__ = ['Criteria', 'As2DCriterion']\n\n\nclass Criteria(nn.Module):\n    \"\"\"Aggregate multiple criteria to one.\"\"\"\n    def __init__(self, *criteria):\n        super(Criteria, self).__init__()\n        if len(criteria) == 1 and isinstance(criteria[0], (list, tuple)):\n            criteria = list(criteria[0])\n        else:\n            criteria = list(criteria)\n        # Validate criteria\n        assert all([isinstance(criterion, nn.Module) for criterion in criteria]), \\\n            \"Criterion must be a torch module.\"\n        self.criteria = criteria\n\n    def forward(self, prediction, target):\n        assert isinstance(prediction, (list, tuple)), \\\n            \"`prediction` must be a list or a tuple, got {} instead.\"\\\n                .format(type(prediction).__name__)\n        assert isinstance(target, (list, tuple)), \\\n            \"`target` must be a list or a tuple, got {} instead.\" \\\n                .format(type(target).__name__)\n        assert len(prediction) == len(target), \\\n            \"Number of predictions must equal the number of targets. 
\" \\\n            \"Got {} predictions but {} targets.\".format(len(prediction), len(target))\n        # Compute losses\n        losses = [criterion(_prediction, _target)\n                  for _prediction, _target, criterion in zip(prediction, target, self.criteria)]\n        # Aggregate losses\n        loss = reduce(lambda x, y: x + y, losses)\n        # Done\n        return loss\n\n\nclass As2DCriterion(nn.Module):\n    \"\"\"\n    Makes a given criterion applicable on (N, C, H, W) prediction and (N, H, W) target tensors,\n    if they're applicable to (N, C) prediction and (N,) target tensors .\n    \"\"\"\n    def __init__(self, criterion):\n        super(As2DCriterion, self).__init__()\n        assert_(isinstance(criterion, nn.Module),\n                \"Criterion must be a module, got a {} instead.\"\n                .format(type(criterion).__name__),\n                NotTorchModuleError)\n        self.criterion = criterion\n\n    def forward(self, prediction, target):\n        # Validate input\n        assert_(prediction.dim() == 4, \"`prediction` is expected to be a 4D tensor of shape \"\n                                       \"(N, C, H, W), got a {}D \"\n                                       \"tensor instead.\".format(prediction.dim()),\n                ShapeError)\n        assert_(target.dim() == 3, \"`target` is expected to be a 3D tensor of shape \"\n                                   \"(N, H, W), got a {}D \"\n                                   \"tensor instead.\".format(target.dim()),\n                ShapeError)\n        # prediction is assumed to be NCHW, and target NHW.\n        # this makes target (NHW,)\n        target = target.contiguous().view(-1)\n        # This makes prediction (N, H, W, C) --> (NHW, C)\n        num_channels = prediction.size(1)\n        prediction = prediction.permute(0, 2, 3, 1).contiguous().view(-1, num_channels)\n        # Now, the criterion should be applicable as is\n        loss = self.criterion(prediction, target)\n   
     return loss\n"
  },
  {
    "path": "inferno/extensions/criteria/elementwise_measures.py",
    "content": "import torch.nn as nn\nfrom ...utils.exceptions import assert_\n\n\nclass WeightedMSELoss(nn.Module):\n    NEGATIVE_CLASS_WEIGHT = 1.\n\n    def __init__(self, positive_class_weight=1., positive_class_value=1., size_average=True):\n        super(WeightedMSELoss, self).__init__()\n        assert_(positive_class_weight >= 0,\n                \"Positive class weight can't be less than zero, got {}.\"\n                .format(positive_class_weight),\n                ValueError)\n        self.mse = nn.MSELoss(size_average=size_average)\n        self.positive_class_weight = positive_class_weight\n        self.positive_class_value = positive_class_value\n\n    def forward(self, input, target):\n        # Get a mask\n        positive_class_mask = target.data.eq(self.positive_class_value).type_as(target.data)\n        # Get differential weights (positive_weight - negative_weight,\n        # i.e. subtract 1, assuming the negative weight is gauged at 1)\n        weight_differential = (positive_class_mask\n                               .mul_(self.positive_class_weight - self.NEGATIVE_CLASS_WEIGHT))\n        # Get final weight by adding weight differential to a tensor with negative weights\n        weights = weight_differential.add_(self.NEGATIVE_CLASS_WEIGHT)\n        # `weights` should be positive if NEGATIVE_CLASS_WEIGHT is not messed with.\n        sqrt_weights = weights.sqrt_()\n        return self.mse(input * sqrt_weights, target * sqrt_weights)\n"
  },
  {
    "path": "inferno/extensions/criteria/regularized.py",
    "content": "import warnings\n\nimport torch\nfrom torch import nn\n\nfrom . import set_similarity_measures, core\n\n__all__ = [\n    'RegularizedLoss',\n    'RegularizedCrossEntropyLoss',\n    'RegularizedBCEWithLogitsLoss',\n    'RegularizedBCELoss',\n    'RegularizedMSELoss',\n    'RegularizedNLLLoss'\n]\n\n\ndef collect_losses(module):\n    \"\"\"Collect `_losses` dictionaries from module and children\n\n    :param module: a Module to be searched for losses\n    :return: dictionary of loss names to values\n    \"\"\"\n    losses = {}\n\n    def _collect(m):\n        if hasattr(m, '_losses'):\n            for k, v in m._losses.items():\n                if k in losses:\n                    losses[k] = losses[k] + v\n                else:\n                    losses[k] = v\n\n    module.apply(_collect)\n    return losses\n\n\ndef build_criterion(criterion, *args, **kwargs):\n    \"\"\"Build a criterion\n\n    :param criterion: criterion class, name of criterion class, or instance of criterion\n    :param args: args for constructor\n    :param kwargs: kwargs for constructor\n    :return: instance of criterion\n    \"\"\"\n    if isinstance(criterion, str):\n        for module in [nn, core, set_similarity_measures]:\n            criterion_class = getattr(module, criterion, None)\n            if criterion_class is not None:\n                break\n        assert criterion_class is not None, \"Criterion {} not found.\".format(criterion)\n    elif callable(criterion) and isinstance(criterion, type):\n        criterion_class = criterion\n    elif isinstance(criterion, torch.nn.Module):\n        return criterion\n    else:\n        raise NotImplementedError\n    return criterion_class(*args, **kwargs)\n\n\nclass RegularizedLoss(nn.Module):\n    \"\"\"Wrap a criterion. 
Collect regularization losses from model and combine with wrapped criterion.\n    \"\"\"\n\n    def __init__(self, criterion, *args, **kwargs):\n        super(RegularizedLoss, self).__init__()\n        self.criterion = build_criterion(criterion, *args, **kwargs)\n\n    def forward(self, *args, trainer=None, model=None, **kwargs):\n        # calculate wrapped loss\n        main_loss = self.criterion(*args, **kwargs)\n\n        # If no trainer, we cannot record states\n        if trainer is None:\n            warnings.warn('No trainer parameter provided. Not logging regularization losses.')\n        elif model is None:\n            model = trainer.model\n\n        # If no model or trainer, we cannot record states or collect losses\n        if model is None:\n            warnings.warn('No model or trainer parameter provided. Not calculating regularization losses.')\n            regularization_losses = {}\n            total_regularization_loss = None\n            total_loss = main_loss\n        else:\n            regularization_losses = collect_losses(model)\n            total_regularization_loss = sum(regularization_losses.values())\n            total_loss = main_loss + total_regularization_loss\n\n        # Record losses if trainer provided\n        if trainer is not None:\n            # prefix depending on mode\n            if self.training:\n                prefix = 'training'\n            else:\n                prefix = 'validation'\n            # main loss\n            updates = {'{}_main_loss'.format(prefix): main_loss}\n            # total regularization loss\n            if total_regularization_loss is not None:\n                updates['{}_total_regularization_loss'.format(prefix)] = total_regularization_loss\n            # detailed regularization losses\n            for k, v in regularization_losses.items():\n                updates['{}_{}'.format(prefix, k)] = v\n            # record state\n            trainer.update_state_from_dictionary(updates)\n\n        
return total_loss\n\n\n# Convenience wrappers for common losses\nclass RegularizedCrossEntropyLoss(RegularizedLoss):\n    def __init__(self, *args, **kwargs):\n        super(RegularizedCrossEntropyLoss, self).__init__(nn.CrossEntropyLoss, *args, **kwargs)\n\n\nclass RegularizedBCEWithLogitsLoss(RegularizedLoss):\n    def __init__(self, *args, **kwargs):\n        super(RegularizedBCEWithLogitsLoss, self).__init__(nn.BCEWithLogitsLoss, *args, **kwargs)\n\n\nclass RegularizedBCELoss(RegularizedLoss):\n    def __init__(self, *args, **kwargs):\n        super(RegularizedBCELoss, self).__init__(nn.BCELoss, *args, **kwargs)\n\n\nclass RegularizedMSELoss(RegularizedLoss):\n    def __init__(self, *args, **kwargs):\n        super(RegularizedMSELoss, self).__init__(nn.MSELoss, *args, **kwargs)\n\n\nclass RegularizedNLLLoss(RegularizedLoss):\n    def __init__(self, *args, **kwargs):\n        super(RegularizedNLLLoss, self).__init__(nn.NLLLoss, *args, **kwargs)\n"
  },
  {
    "path": "inferno/extensions/criteria/set_similarity_measures.py",
    "content": "import torch.nn as nn\nfrom ...utils.torch_utils import flatten_samples\n\n__all__ = ['SorensenDiceLoss', 'GeneralizedDiceLoss']\n\n\nclass SorensenDiceLoss(nn.Module):\n    \"\"\"\n    Computes a loss scalar, which when minimized maximizes the Sorensen-Dice similarity\n    between the input and the target. For both inputs and targets it must be the case that\n    `input_or_target.size(1) = num_channels`.\n    \"\"\"\n    def __init__(self, weight=None, channelwise=True, eps=1e-6):\n        \"\"\"\n        Parameters\n        ----------\n        weight : torch.FloatTensor or torch.cuda.FloatTensor\n            Class weights. Applies only if `channelwise = True`.\n        channelwise : bool\n            Whether to apply the loss channelwise and sum the results (True)\n            or to apply it on all channels jointly (False).\n        \"\"\"\n        super(SorensenDiceLoss, self).__init__()\n        self.register_buffer('weight', weight)\n        self.channelwise = channelwise\n        self.eps = eps\n\n    def forward(self, input, target):\n        \"\"\"\n        input:      torch.FloatTensor or torch.cuda.FloatTensor\n        target:     torch.FloatTensor or torch.cuda.FloatTensor\n\n        Expected shape of the inputs: (batch_size, nb_channels, ...)\n        \"\"\"\n        assert input.size() == target.size()\n        if not self.channelwise:\n            numerator = (input * target).sum()\n            denominator = (input * input).sum() + (target * target).sum()\n            loss = -2. 
* (numerator / denominator.clamp(min=self.eps))\n        else:\n            # TODO This should be compatible with Pytorch 0.2, but check\n            # Flatten input and target to have the shape (C, N),\n            # where N is the number of samples\n            input = flatten_samples(input)\n            target = flatten_samples(target)\n            # Compute numerator and denominator (by summing over samples and\n            # leaving the channels intact)\n            numerator = (input * target).sum(-1)\n            denominator = (input * input).sum(-1) + (target * target).sum(-1)\n            channelwise_loss = -2 * (numerator / denominator.clamp(min=self.eps))\n            if self.weight is not None:\n                # With pytorch < 0.2, channelwise_loss.size = (C, 1).\n                if channelwise_loss.dim() == 2:\n                    channelwise_loss = channelwise_loss.squeeze(1)\n                assert self.weight.size() == channelwise_loss.size()\n                # Apply weight\n                channelwise_loss = self.weight * channelwise_loss\n            # Sum over the channels to compute the total loss\n            loss = channelwise_loss.sum()\n        return loss\n\n\nclass GeneralizedDiceLoss(nn.Module):\n    \"\"\"\n    Computes the scalar Generalized Dice Loss defined in https://arxiv.org/abs/1707.03237\n\n    This version works for multiple classes and expects predictions for every class (e.g. 
softmax output) and\n    one-hot targets for every class.\n    \"\"\"\n    def __init__(self, weight=None, channelwise=False, eps=1e-6):\n        super(GeneralizedDiceLoss, self).__init__()\n        self.register_buffer('weight', weight)\n        self.channelwise = channelwise\n        self.eps = eps\n\n    def forward(self, input, target):\n        \"\"\"\n        input: torch.FloatTensor or torch.cuda.FloatTensor\n        target:     torch.FloatTensor or torch.cuda.FloatTensor\n\n        Expected shape of the inputs:\n            - if not channelwise: (batch_size, nb_classes, ...)\n            - if channelwise:     (batch_size, nb_channels, nb_classes, ...)\n        \"\"\"\n        assert input.size() == target.size()\n        if not self.channelwise:\n            # Flatten input and target to have the shape (nb_classes, N),\n            # where N is the number of samples\n            input = flatten_samples(input)\n            target = flatten_samples(target)\n\n            # Find classes weights:\n            sum_targets = target.sum(-1)\n            class_weigths = 1. / (sum_targets * sum_targets).clamp(min=self.eps)\n\n            # Compute generalized Dice loss:\n            numer = ((input * target).sum(-1) * class_weigths).sum()\n            denom = ((input + target).sum(-1) * class_weigths).sum()\n\n            loss = 1. - 2. 
* numer / denom.clamp(min=self.eps)\n        else:\n            def flatten_and_preserve_channels(tensor):\n                tensor_dim = tensor.dim()\n                assert tensor_dim >= 3\n                num_channels = tensor.size(1)\n                num_classes = tensor.size(2)\n                # Permute the channel axis to first\n                permute_axes = list(range(tensor_dim))\n                permute_axes[0], permute_axes[1], permute_axes[2] = permute_axes[1], permute_axes[2], permute_axes[0]\n                permuted = tensor.permute(*permute_axes).contiguous()\n                flattened = permuted.view(num_channels, num_classes, -1)\n                return flattened\n\n            # Flatten input and target to have the shape (nb_channels, nb_classes, N)\n            input = flatten_and_preserve_channels(input)\n            target = flatten_and_preserve_channels(target)\n\n            # Find classes weights:\n            sum_targets = target.sum(-1)\n            class_weigths = 1. / (sum_targets * sum_targets).clamp(min=self.eps)\n\n            # Compute generalized Dice loss:\n            numer = ((input * target).sum(-1) * class_weigths).sum(-1)\n            denom = ((input + target).sum(-1) * class_weigths).sum(-1)\n\n            channelwise_loss = 1. - 2. * numer / denom.clamp(min=self.eps)\n\n            if self.weight is not None:\n                if channelwise_loss.dim() == 2:\n                    channelwise_loss = channelwise_loss.squeeze(1)\n                assert self.weight.size() == channelwise_loss.size(),\\\n                    \"\"\"`weight` should have shape (nb_channels, ),\n                       `target` should have shape (batch_size, nb_channels, nb_classes, ...)\"\"\"\n                # Apply channel weights:\n                channelwise_loss = self.weight * channelwise_loss\n\n            loss = channelwise_loss.sum()\n\n        return loss\n"
  },
  {
    "path": "inferno/extensions/initializers/__init__.py",
    "content": "from .base import *\nfrom .presets import *\n\n"
  },
  {
    "path": "inferno/extensions/initializers/base.py",
    "content": "import torch.nn.init as init\n\n\n__all__ = ['Initializer',\n           'Initialization',\n           'WeightInitFunction',\n           'BiasInitFunction',\n           'TensorInitFunction']\n\n\nclass Initializer(object):\n    \"\"\"\n    Base class for all initializers.\n    \"\"\"\n\n    # TODO Support LSTMs and GRUs\n    VALID_LAYERS = {'Conv1d', 'Conv2d', 'Conv3d',\n                    'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d',\n                    'Linear', 'Bilinear',\n                    'Embedding'}\n\n    def __call__(self, module):\n        module_class_name = module.__class__.__name__\n        if module_class_name in self.VALID_LAYERS:\n            # Apply to weight and bias\n            try:\n                if hasattr(module, 'weight'):\n                    self.call_on_weight(module.weight.data)\n            except NotImplementedError:\n                # Don't cry if it's not implemented\n                pass\n\n            try:\n                if hasattr(module, 'bias'):\n                    self.call_on_bias(module.bias.data)\n            except NotImplementedError:\n                pass\n\n        return module\n\n    def call_on_bias(self, tensor):\n        return self.call_on_tensor(tensor)\n\n    def call_on_weight(self, tensor):\n        return self.call_on_tensor(tensor)\n\n    def call_on_tensor(self, tensor):\n        raise NotImplementedError\n\n    @classmethod\n    def initializes_weight(cls):\n        return 'call_on_tensor' in cls.__dict__ or 'call_on_weight' in cls.__dict__\n\n    @classmethod\n    def initializes_bias(cls):\n        return 'call_on_tensor' in cls.__dict__ or 'call_on_bias' in cls.__dict__\n\n\nclass Initialization(Initializer):\n    def __init__(self, weight_initializer=None, bias_initializer=None):\n        if weight_initializer is None:\n            self.weight_initializer = Initializer()\n        else:\n            if isinstance(weight_initializer, Initializer):\n                
assert weight_initializer.initializes_weight()\n                self.weight_initializer = weight_initializer\n            elif isinstance(weight_initializer, str):\n                init_function = getattr(init, weight_initializer, None)\n                assert init_function is not None\n                self.weight_initializer = WeightInitFunction(init_function=init_function)\n            else:\n                # Provision for weight_initializer to be a function\n                assert callable(weight_initializer)\n                self.weight_initializer = WeightInitFunction(init_function=weight_initializer)\n\n        if bias_initializer is None:\n            self.bias_initializer = Initializer()\n        else:\n            if isinstance(bias_initializer, Initializer):\n                assert bias_initializer.initializes_bias()\n                self.bias_initializer = bias_initializer\n            elif isinstance(bias_initializer, str):\n                init_function = getattr(init, bias_initializer, None)\n                assert init_function is not None\n                self.bias_initializer = BiasInitFunction(init_function=init_function)\n            else:\n                assert callable(bias_initializer)\n                self.bias_initializer = BiasInitFunction(init_function=bias_initializer)\n\n    def call_on_weight(self, tensor):\n        return self.weight_initializer.call_on_weight(tensor)\n\n    def call_on_bias(self, tensor):\n        return self.bias_initializer.call_on_bias(tensor)\n\n\nclass WeightInitFunction(Initializer):\n    def __init__(self, init_function, *init_function_args, **init_function_kwargs):\n        super(WeightInitFunction, self).__init__()\n        assert callable(init_function)\n        self.init_function = init_function\n        self.init_function_args = init_function_args\n        self.init_function_kwargs = init_function_kwargs\n\n    def call_on_weight(self, tensor):\n        return self.init_function(tensor, 
*self.init_function_args, **self.init_function_kwargs)\n\n\nclass BiasInitFunction(Initializer):\n    def __init__(self, init_function, *init_function_args, **init_function_kwargs):\n        super(BiasInitFunction, self).__init__()\n        assert callable(init_function)\n        self.init_function = init_function\n        self.init_function_args = init_function_args\n        self.init_function_kwargs = init_function_kwargs\n\n    def call_on_bias(self, tensor):\n        return self.init_function(tensor, *self.init_function_args, **self.init_function_kwargs)\n\n\nclass TensorInitFunction(Initializer):\n    def __init__(self, init_function, *init_function_args, **init_function_kwargs):\n        super(TensorInitFunction, self).__init__()\n        assert callable(init_function)\n        self.init_function = init_function\n        self.init_function_args = init_function_args\n        self.init_function_kwargs = init_function_kwargs\n\n    def call_on_tensor(self, tensor):\n        return self.init_function(tensor, *self.init_function_args, **self.init_function_kwargs)\n\n"
  },
  {
    "path": "inferno/extensions/initializers/presets.py",
    "content": "import numpy as np\nimport torch.nn.init as init\nfrom functools import partial\n\nfrom .base import Initialization, Initializer\n\n\n__all__ = ['Constant', 'NormalWeights',\n           'SELUWeightsZeroBias',\n           'ELUWeightsZeroBias',\n           'OrthogonalWeightsZeroBias',\n           'KaimingNormalWeightsZeroBias']\n\n\nclass Constant(Initializer):\n    \"\"\"Initialize with a constant.\"\"\"\n    def __init__(self, constant):\n        self.constant = constant\n\n    def call_on_tensor(self, tensor):\n        tensor.fill_(self.constant)\n        return tensor\n\n\nclass NormalWeights(Initializer):\n    \"\"\"\n    Initialize weights with random numbers drawn from the normal distribution at\n    `mean` and `stddev`.\n    \"\"\"\n    def __init__(self, mean=0., stddev=1., sqrt_gain_over_fan_in=None):\n        self.mean = mean\n        self.stddev = stddev\n        self.sqrt_gain_over_fan_in = sqrt_gain_over_fan_in\n\n    def compute_fan_in(self, tensor):\n        if tensor.dim() == 2:\n            return tensor.size(1)\n        else:\n            return np.prod(list(tensor.size())[1:])\n\n    def call_on_weight(self, tensor):\n        # Compute stddev if required\n        if self.sqrt_gain_over_fan_in is not None:\n            stddev = self.stddev * \\\n                     np.sqrt(self.sqrt_gain_over_fan_in / self.compute_fan_in(tensor))\n        else:\n            stddev = self.stddev\n        # Init\n        tensor.normal_(self.mean, stddev)\n\n\nclass OrthogonalWeightsZeroBias(Initialization):\n    def __init__(self, orthogonal_gain=1.):\n        # This prevents a deprecated warning in Pytorch 0.4+\n        orthogonal = getattr(init, 'orthogonal_', init.orthogonal)\n        super(OrthogonalWeightsZeroBias, self)\\\n            .__init__(weight_initializer=partial(orthogonal, gain=orthogonal_gain),\n                      bias_initializer=Constant(0.))\n\n\nclass KaimingNormalWeightsZeroBias(Initialization):\n    def __init__(self, 
relu_leakage=0):\n        # This prevents a deprecated warning in Pytorch 0.4+\n        kaiming_normal = getattr(init, 'kaiming_normal_', init.kaiming_normal)\n        super(KaimingNormalWeightsZeroBias, self)\\\n            .__init__(weight_initializer=partial(kaiming_normal, a=relu_leakage),\n                      bias_initializer=Constant(0.))\n\n\nclass SELUWeightsZeroBias(Initialization):\n    def __init__(self):\n        super(SELUWeightsZeroBias, self)\\\n            .__init__(weight_initializer=NormalWeights(sqrt_gain_over_fan_in=1.),\n                      bias_initializer=Constant(0.))\n\n\nclass ELUWeightsZeroBias(Initialization):\n    def __init__(self):\n        super(ELUWeightsZeroBias, self)\\\n            .__init__(weight_initializer=NormalWeights(sqrt_gain_over_fan_in=1.5505188080679277),\n                      bias_initializer=Constant(0.))\n"
  },
  {
    "path": "inferno/extensions/layers/__init__.py",
    "content": "__all__ = []\nfrom .activations import *\nfrom .convolutional import *\nfrom .device import *\nfrom .reshape import *\nfrom .convolutional_blocks import *\n\n#######################################################\n# the following is to make the sphinx example\n# gallery makes proper cross-references\nfrom .activations       import _all as _activations_all\nfrom .convolutional     import _all as _convolutional_all\nfrom .device            import _all as _device_all\nfrom .reshape           import _all as _reshape_all\nfrom .convolutional_blocks   import _all as _convolutional_blocks_all\nfrom .identity          import _all as _identity_all\n\n__all__.extend(_activations_all)\n__all__.extend(_convolutional_all)\n__all__.extend(_device_all)\n__all__.extend(_reshape_all)\n__all__.extend(_convolutional_blocks_all)\n__all__.extend(_identity_all)\n\n_all = __all__\n"
  },
  {
    "path": "inferno/extensions/layers/activations.py",
    "content": "import torch.nn.functional as F\nimport torch.nn as nn\nfrom ...utils.torch_utils import where\n\n__all__ = ['SELU']\n_all = __all__\n\nclass SELU(nn.Module):\n    def forward(self, input):\n        return self.selu(input)\n\n    @staticmethod\n    def selu(x):\n        alpha = 1.6732632423543772848170429916717\n        scale = 1.0507009873554804934193349852946\n        # noinspection PyTypeChecker\n        return scale * where(x >= 0, x, alpha * F.elu(x))"
  },
  {
    "path": "inferno/extensions/layers/convolutional.py",
    "content": "import torch.nn as nn\nimport sys\nimport functools\nfrom ..initializers import (\n    OrthogonalWeightsZeroBias,\n    KaimingNormalWeightsZeroBias,\n    SELUWeightsZeroBias,\n)\nfrom ..initializers import Initializer\nfrom .normalization import BatchNormND\nfrom .activations import SELU\nfrom ...utils.exceptions import assert_, ShapeError\nfrom ...utils.partial_cls import register_partial_cls\n\n# we append to this later on\n__all__ = [\n    \"GlobalConv2D\",\n]\n_all = __all__\n\nregister_partial_cls_here = functools.partial(register_partial_cls, module=__name__)\n\n\nclass ConvActivation(nn.Module):\n    \"\"\"Convolutional layer with 'SAME' padding by default followed by an activation.\"\"\"\n\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        kernel_size,\n        dim,\n        activation,\n        stride=1,\n        dilation=1,\n        groups=None,\n        depthwise=False,\n        bias=True,\n        deconv=False,\n        initialization=None,\n        valid_conv=False,\n    ):\n        super(ConvActivation, self).__init__()\n        # Validate dim\n        assert_(\n            dim in [1, 2, 3],\n            \"`dim` must be one of [1, 2, 3], got {}.\".format(dim),\n            ShapeError,\n        )\n        self.dim = dim\n        # Check if depthwise\n        if depthwise:\n\n            # We know that in_channels == out_channels, but we also want a consistent API.\n            # As a compromise, we allow that out_channels be None or 'auto'.\n            out_channels = in_channels if out_channels in [None, \"auto\"] else out_channels\n            assert_(\n                in_channels == out_channels,\n                \"For depthwise convolutions, number of input channels (given: {}) \"\n                \"must equal the number of output channels (given {}).\".format(\n                    in_channels, out_channels\n                ),\n                ValueError,\n            )\n            assert_(\n     
           groups is None or groups == in_channels,\n                \"For depthwise convolutions, groups (given: {}) must \"\n                \"equal the number of channels (given: {}).\".format(groups, in_channels),\n            )\n            groups = in_channels\n        else:\n            groups = 1 if groups is None else groups\n        self.depthwise = depthwise\n        if valid_conv:\n            self.conv = getattr(nn, \"Conv{}d\".format(self.dim))(\n                in_channels=in_channels,\n                out_channels=out_channels,\n                kernel_size=kernel_size,\n                stride=stride,\n                dilation=dilation,\n                groups=groups,\n                bias=bias,\n            )\n        elif not deconv:\n            # Get padding\n            padding = self.get_padding(kernel_size, dilation)\n            self.conv = getattr(nn, \"Conv{}d\".format(self.dim))(\n                in_channels=in_channels,\n                out_channels=out_channels,\n                kernel_size=kernel_size,\n                padding=padding,\n                stride=stride,\n                dilation=dilation,\n                groups=groups,\n                bias=bias,\n            )\n        else:\n            self.conv = getattr(nn, \"ConvTranspose{}d\".format(self.dim))(\n                in_channels=in_channels,\n                out_channels=out_channels,\n                kernel_size=kernel_size,\n                stride=stride,\n                dilation=dilation,\n                groups=groups,\n                bias=bias,\n            )\n        if initialization is None:\n            pass\n        elif isinstance(initialization, Initializer):\n            self.conv.apply(initialization)\n        else:\n            raise NotImplementedError\n\n        if isinstance(activation, str):\n            self.activation = getattr(nn, activation)()\n        elif isinstance(activation, nn.Module):\n            self.activation = activation\n        elif 
activation is None:\n            self.activation = None\n        else:\n            raise NotImplementedError\n\n    def forward(self, input):\n        conved = self.conv(input)\n        if self.activation is not None:\n            activated = self.activation(conved)\n        else:\n            # No activation\n            activated = conved\n        return activated\n\n    def _pair_or_triplet(self, object_):\n        if isinstance(object_, (list, tuple)):\n            assert len(object_) == self.dim\n            return object_\n        else:\n            object_ = [object_] * self.dim\n            return object_\n\n    def _get_padding(self, _kernel_size, _dilation):\n        assert isinstance(_kernel_size, int)\n        assert isinstance(_dilation, int)\n        assert _kernel_size % 2 == 1\n        return ((_kernel_size - 1) // 2) * _dilation\n\n    def get_padding(self, kernel_size, dilation):\n        kernel_size = self._pair_or_triplet(kernel_size)\n        dilation = self._pair_or_triplet(dilation)\n        padding = [\n            self._get_padding(_kernel_size, _dilation)\n            for _kernel_size, _dilation in zip(kernel_size, dilation)\n        ]\n        return tuple(padding)\n\n# for consistency\nConvActivationND = ConvActivation\n\n\n# noinspection PyUnresolvedReferences\nclass _BNReLUSomeConv(object):\n    def forward(self, input):\n        normed = self.batchnorm(input)\n        activated = self.activation(normed)\n        conved = self.conv(activated)\n        return conved\n\nclass BNReLUConvBaseND(_BNReLUSomeConv, ConvActivation):\n    def __init__(self, in_channels, out_channels, kernel_size, dim, stride=1, dilation=1, deconv=False):\n\n        super(BNReLUConvBaseND, self).__init__(\n            in_channels=in_channels,\n            out_channels=out_channels,\n            kernel_size=kernel_size,\n            dim=dim,\n            stride=stride,\n            activation=nn.ReLU(inplace=True),\n            dilation=dilation,\n            
deconv=deconv,\n            initialization=KaimingNormalWeightsZeroBias(0),\n        )\n        self.batchnorm = BatchNormND(dim, in_channels)\n\n\ndef _register_conv_cls(conv_name,  fix=None, default=None):\n    if fix is None:\n        fix = {}\n    if default is None:\n        default = {}\n\n    # simple conv activation\n    activations = [\"ReLU\", \"ELU\", \"Sigmoid\", \"SELU\", \"\"]\n    init_map = {\n        \"ReLU\": KaimingNormalWeightsZeroBias,\n        \"SELU\": SELUWeightsZeroBias\n    }\n    for activation_str in activations:\n        cls_name = \"{}{}ND\".format(conv_name,activation_str)\n        __all__.append(cls_name)\n        initialization_cls = init_map.get(activation_str, OrthogonalWeightsZeroBias)\n        if activation_str == \"\":\n            activation = None\n            _fix = {**fix}\n            _default = {'activation':None}\n        elif activation_str == \"SELU\":\n            activation = nn.SELU(inplace=True)\n            _fix={**fix, 'activation':activation}\n            _default = {**default}\n        else:\n            activation = activation_str\n            _fix={**fix, 'activation':activation}\n            _default = {**default}\n\n        register_partial_cls_here(ConvActivation, cls_name,\n            fix=_fix,\n            default={**_default, 'initialization':initialization_cls()}\n        )\n        for dim in [1, 2, 3]:\n            cls_name = \"{}{}{}D\".format(conv_name,activation_str, dim)\n            __all__.append(cls_name)\n            register_partial_cls_here(ConvActivation, cls_name,\n                fix={**_fix, 'dim':dim},\n                default={**_default, 'initialization':initialization_cls()}\n            )\n\ndef _register_bnr_conv_cls(conv_name,  fix=None, default=None):\n    if fix is None:\n        fix = {}\n    if default is None:\n        default = {}\n    # register exactly once; an erroneous outer 'for dim' loop here used to\n    # triplicate the registrations below and duplicate __all__ entries\n    if True:\n\n        cls_name = \"BNReLU{}ND\".format(conv_name)\n        __all__.append(cls_name)\n        
register_partial_cls_here(BNReLUConvBaseND, cls_name,fix=fix,default=default)\n\n        for dim in [1, 2, 3]:\n            cls_name = \"BNReLU{}{}D\".format(conv_name, dim)\n            __all__.append(cls_name)\n            register_partial_cls_here(BNReLUConvBaseND, cls_name,\n                fix={**fix, 'dim':dim},\n                default=default)\n\n# conv classes\n_register_conv_cls(\"Conv\")\n_register_conv_cls(\"ValidConv\",  fix=dict(valid_conv=True))\n_register_conv_cls(\"Deconv\", fix=dict(deconv=True), default=dict(kernel_size=2, stride=2))\n_register_conv_cls(\"StridedConv\",  default=dict(stride=2))\n_register_conv_cls(\"DilatedConv\",  fix=dict(dilation=2))\n_register_conv_cls(\"DepthwiseConv\", fix=dict(deconv=False, depthwise=True), default=dict(out_channels='auto'))\n\n# BatchNormRelu classes\n_register_bnr_conv_cls(\"Conv\", fix=dict(deconv=False))\n_register_bnr_conv_cls(\"Deconv\", fix=dict(deconv=True))\n_register_bnr_conv_cls(\"StridedConv\",  default=dict(stride=2))\n_register_bnr_conv_cls(\"DilatedConv\",  default=dict(dilation=2))\n_register_bnr_conv_cls(\"DepthwiseConv\", fix=dict(deconv=False, depthwise=True), default=dict(out_channels='auto'))\n\ndel _register_conv_cls\ndel _register_bnr_conv_cls\n\n\n\n\nclass GlobalConv2D(nn.Module):\n    \"\"\"From https://arxiv.org/pdf/1703.02719.pdf\n    Main idea: we can have a bigger kernel size computationally acceptable\n    if we separate 2D-conv in 2 1D-convs \"\"\"\n\n    def __init__(\n        self,\n        in_channels,\n        out_channels,\n        kernel_size,\n        local_conv_type,\n        activation=None,\n        use_BN=False,\n        **kwargs\n    ):\n        super(GlobalConv2D, self).__init__()\n        self.in_channels = in_channels\n        self.out_channels = out_channels\n        self.kernel_size = kernel_size\n        assert isinstance(kernel_size, (int, list, tuple))\n        if isinstance(kernel_size, int):\n            kernel_size = (kernel_size,) * 2\n        
self.kwargs = kwargs\n        self.conv1a = local_conv_type(\n            in_channels=self.in_channels,\n            out_channels=self.out_channels,\n            kernel_size=(kernel_size[0], 1),\n            **kwargs\n        )\n        self.conv1b = local_conv_type(\n            in_channels=self.out_channels,\n            out_channels=self.out_channels,\n            kernel_size=(1, kernel_size[1]),\n            **kwargs\n        )\n        self.conv2a = local_conv_type(\n            in_channels=self.in_channels,\n            out_channels=self.out_channels,\n            kernel_size=(1, kernel_size[1]),\n            **kwargs\n        )\n        self.conv2b = local_conv_type(\n            in_channels=self.out_channels,\n            out_channels=self.out_channels,\n            kernel_size=(kernel_size[0], 1),\n            **kwargs\n        )\n        if use_BN:\n            self.batchnorm = nn.BatchNorm2d(self.out_channels)\n        else:\n            self.batchnorm = None\n        self.activation = activation\n\n    def forward(self, input_):\n        out1 = self.conv1a(input_)\n        out1 = self.conv1b(out1)\n        out2 = self.conv2a(input_)\n        out2 = self.conv2b(out2)\n        out = out1.add(1, out2)\n        if self.activation is not None:\n            out = self.activation(out)\n        if self.batchnorm is not None:\n            out = self.batchnorm(out)\n        return out\n"
  },
  {
    "path": "inferno/extensions/layers/convolutional_blocks.py",
    "content": "import torch.nn as nn\nfrom .convolutional import BNReLUConv2D, BNReLUDeconv2D, Conv2D, Deconv2D\nfrom ...utils import python_utils as pyu\nfrom ...utils.exceptions import assert_\n\n__all__ = ['ResidualBlock', 'PreActSimpleResidualBlock']\n_all = __all__\n\n\nclass ResidualBlock(nn.Module):\n    def __init__(self, layers, resample=None):\n        super(ResidualBlock, self).__init__()\n        assert pyu.is_listlike(layers)\n        self.layers = nn.Sequential(*layers)\n        self.resample = resample\n\n    def forward(self, input):\n        preaddition = self.layers(input)\n        if self.resample is not None:\n            skip = self.resample(input)\n        else:\n            skip = input\n        output = preaddition + skip\n        return output\n\n\nclass PreActSimpleResidualBlock(ResidualBlock):\n    def __init__(self, in_channels, num_hidden_channels, upsample=False, downsample=False):\n        layers = []\n        if downsample:\n            assert_(not upsample, \"Both downsample and upsample is set to true.\", ValueError)\n            layers.append(BNReLUConv2D(in_channels=in_channels,\n                                       out_channels=num_hidden_channels,\n                                       kernel_size=3,\n                                       stride=2))\n            resample = nn.Sequential(Conv2D(in_channels=in_channels,\n                                            out_channels=in_channels,\n                                            kernel_size=1, stride=2),\n                                     nn.BatchNorm2d(in_channels))\n        elif upsample:\n            layers.append(BNReLUDeconv2D(in_channels=in_channels,\n                                         out_channels=num_hidden_channels,\n                                         kernel_size=2,\n                                         stride=2))\n            resample = nn.Sequential(Deconv2D(in_channels=in_channels,\n                                              
out_channels=in_channels,\n                                              kernel_size=2, stride=2),\n                                     nn.BatchNorm2d(in_channels))\n        else:\n            layers.append(BNReLUConv2D(in_channels=in_channels,\n                                       out_channels=num_hidden_channels,\n                                       kernel_size=3))\n            resample = None\n        layers.append(BNReLUConv2D(in_channels=num_hidden_channels,\n                                   out_channels=in_channels,\n                                   kernel_size=3))\n        super(PreActSimpleResidualBlock, self).__init__(layers, resample)\n\n\n# TODO PreActBottleneckResidualBlock\n"
  },
  {
    "path": "inferno/extensions/layers/device.py",
    "content": "import torch.nn as nn\nfrom ...utils.python_utils import from_iterable, to_iterable\nfrom ...utils.exceptions import assert_, DeviceError\n\n__all__ = ['DeviceTransfer', 'OnDevice']\n_all = __all__\n\n\nclass DeviceTransfer(nn.Module):\n    \"\"\"Layer to transfer variables to a specified device.\"\"\"\n    def __init__(self, target_device, device_ordinal=None, asynchronous=False):\n        \"\"\"\n        Parameters\n        ----------\n        target_device : {'cpu', 'cuda'}\n            Device to transfer to.\n        device_ordinal : int\n            Device ordinal if target_device == 'cuda'.\n        asynchronous : bool\n            Whether to use asynchronous transfers.\n        \"\"\"\n        super(DeviceTransfer, self).__init__()\n        # Validate arguments\n        assert_(target_device in ['cpu', 'cuda'],\n                \"Target device must either be 'cpu' or 'cuda'.\",\n                DeviceError)\n        if target_device == 'cpu':\n            assert_(device_ordinal is None,\n                    \"'device_ordinal' must be None if target_device is 'cpu'.\",\n                    DeviceError)\n        self.target_device = target_device\n        self.device_ordinal = device_ordinal\n        # must be stored: forward() reads self.asynchronous\n        self.asynchronous = asynchronous\n\n    def forward(self, *inputs):\n        if self.target_device == 'cuda':\n            transferred = tuple(input_.cuda(device=self.device_ordinal,\n                                            non_blocking=self.asynchronous)\n                                for input_ in inputs)\n        elif self.target_device == 'cpu':\n            transferred = tuple(input_.cpu() for input_ in inputs)\n        else:\n            raise NotImplementedError\n        return from_iterable(transferred)\n\n\nclass OnDevice(nn.Module):\n    \"\"\"\n    Moves a module to a device. 
The advantage of using this over `torch.nn.Module.cuda` is\n    that the inputs are transferred to the same device as the module, enabling easy model\n    parallelism.\n    \"\"\"\n    def __init__(self, module, target_device, device_ordinal=None, asynchronous=False):\n        \"\"\"\n        Parameters\n        ----------\n        module : torch.nn.Module\n            Module to transfer to device.\n        target_device : {'cuda', 'cpu'}\n            The device to move `module` to. Must be either 'cuda' or 'cpu'.\n        device_ordinal : int\n            Ordinal of the GPU device if `target_device = 'cuda'`.\n        asynchronous : bool\n            Whether to use asynchronous transfers.\n        \"\"\"\n        super(OnDevice, self).__init__()\n        # Validate arguments\n        assert_(target_device in ['cpu', 'cuda'],\n                \"Target device must either be 'cpu' or 'cuda'.\",\n                DeviceError)\n        if target_device == 'cpu':\n            assert_(device_ordinal is None,\n                    \"'device_ordinal' must be None if target_device is 'cpu'.\",\n                    DeviceError)\n        self.target_device = target_device\n        self.device_ordinal = device_ordinal\n        self.asynchronous = asynchronous\n        # This is a no-op if module is already in the right device\n        self.device_transfer = DeviceTransfer(self.target_device,\n                                              device_ordinal=self.device_ordinal,\n                                              asynchronous=self.asynchronous)\n\n        self.module = self.transfer_module(module)\n\n    def transfer_module(self, module):\n        if self.target_device == 'cuda':\n            return module.cuda(device_id=self.device_ordinal)\n        elif self.target_device == 'cpu':\n            return module.cpu()\n        else:\n            raise NotImplementedError\n\n    def forward(self, *inputs):\n        # Transfer inputs (no-op if they're already on the right 
device)\n        transferred = to_iterable(self.device_transfer(*inputs))\n        output = self.module(*transferred)\n        return output\n"
  },
  {
    "path": "inferno/extensions/layers/identity.py",
    "content": "import torch.nn as nn\n__all__ = ['Identity']\n_all = __all__\n\nclass Identity(nn.Module):  \n    def __init__(self):\n        super(Identity, self).__init__()\n\n    def forward(self, x):\n        return x"
  },
  {
    "path": "inferno/extensions/layers/normalization.py",
    "content": "import torch.nn as nn\n\n\nclass BatchNormND(nn.Module):\n    def __init__(self, dim, num_features, \n                 eps=1e-5, momentum=0.1, \n                 affine=True,track_running_stats=True):\n        super(BatchNormND, self).__init__()\n        assert dim in [1, 2, 3]\n        self.bn = getattr(nn, 'BatchNorm{}d'.format(dim))(num_features=num_features,\n            eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)\n\n    def forward(self, x):\n        return self.bn(x)"
  },
  {
    "path": "inferno/extensions/layers/reshape.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom ...utils.exceptions import assert_, ShapeError\nfrom ...utils import python_utils as pyu\n\n\n__all__ = ['View', 'AsMatrix', 'Flatten',\n           'As3D', 'As2D',\n           'Concatenate', 'Cat',\n           'ResizeAndConcatenate', 'PoolCat',\n           'GlobalMeanPooling', 'GlobalMaxPooling',\n           'Sum', 'SplitChannels','Squeeze', 'RemoveSingletonDimension']\n_all = __all__\n\nclass View(nn.Module):\n    def __init__(self, as_shape):\n        super(View, self).__init__()\n        self.as_shape = self.validate_as_shape(as_shape)\n\n    def validate_as_shape(self, as_shape):\n        assert all([isinstance(_s, int) or _s == 'x' for _s in as_shape])\n\n        all_int_indices = [_n for _n, _s in enumerate(as_shape) if isinstance(_s, int)]\n        if all_int_indices:\n            first_int_at_index = all_int_indices[0]\n            assert all([isinstance(_s, int) for _s in as_shape[first_int_at_index:]])\n        return as_shape\n\n    def forward(self, input):\n        input_shape = list(input.size())\n        reshaped_shape = [_s if isinstance(_s, int) else input_shape[_n]\n                          for _n, _s in enumerate(self.as_shape)]\n        output = input.view(*reshaped_shape)\n        return output\n\n\nclass AsMatrix(View):\n    def __init__(self):\n        super(AsMatrix, self).__init__(as_shape=['x', 'x'])\n\n\nclass Flatten(View):\n    def __init__(self):\n        super(Flatten, self).__init__(as_shape=['x', -1])\n\n\nclass As3D(nn.Module):\n    def __init__(self, channel_as_z=False, num_channels_or_num_z_slices=1):\n        super(As3D, self).__init__()\n        self.channel_as_z = channel_as_z\n        self.num_channels_or_num_z_slices = num_channels_or_num_z_slices\n\n    def forward(self, input):\n        if input.dim() == 5:\n            # If input is a batch of 3D volumes - return as is\n            return input\n        elif input.dim() == 4:\n     
       # If input is a batch of 2D images, reshape\n            b, c, _0, _1 = list(input.size())\n            assert_(c % self.num_channels_or_num_z_slices == 0,\n                    \"Number of channels of the 4D image tensor (= {}) must be \"\n                    \"divisible by the set number of channels or number of z slices \"\n                    \"of the 5D volume tensor (= {}).\"\n                    .format(c, self.num_channels_or_num_z_slices),\n                    ShapeError)\n            c //= self.num_channels_or_num_z_slices\n            if self.channel_as_z:\n                # Move channel axis to z\n                return input.view(b, self.num_channels_or_num_z_slices, c, _0, _1)\n            else:\n                # Keep channel axis where it is, but add a singleton dimension for z\n                return input.view(b, c, self.num_channels_or_num_z_slices, _0, _1)\n        elif input.dim() == 2:\n            # We have a matrix which we wish to turn to a 3D batch\n            b, c = list(input.size())\n            return input.view(b, c, 1, 1, 1)\n        else:\n            raise NotImplementedError\n\n\nclass As2D(nn.Module):\n    def __init__(self, z_as_channel=True):\n        super(As2D, self).__init__()\n        self.z_as_channel = z_as_channel\n\n    def forward(self, input):\n        if input.dim() == 5:\n            b, c, _0, _1, _2 = list(input.size())\n            if not self.z_as_channel:\n                assert _0 == 1\n            # Reshape\n            return input.view(b, c * _0, _1, _2)\n        elif input.dim() == 4:\n            # Nothing to do here - input is already 2D\n            return input\n        elif input.dim() == 2:\n            # We make singleton dimensions\n            b, c = list(input.size())\n            return input.view(b, c, 1, 1)\n\n\nclass Concatenate(nn.Module):\n    \"\"\"Concatenate input tensors along a specified dimension.\"\"\"\n    def __init__(self, dim=1):\n        super(Concatenate, 
self).__init__()\n        self.dim = dim\n\n    def forward(self, *inputs):\n        return torch.cat(inputs, dim=self.dim)\n\n\nclass ResizeAndConcatenate(nn.Module):\n    \"\"\"\n    Resize input tensors spatially (to a specified target size) before concatenating\n    them along the a given dim (channel, i.e. 1 by default). The down-sampling mode can\n    be specified ('average' or 'max'), but the up-sampling is always 'nearest'.\n    \"\"\"\n\n    POOL_MODE_MAPPING = {'avg': 'avg',\n                         'average': 'avg',\n                         'mean': 'avg',\n                         'max': 'max'}\n\n    def __init__(self, target_size, pool_mode='average', dim=1):\n        super(ResizeAndConcatenate, self).__init__()\n        self.target_size = target_size\n        assert_(pool_mode in self.POOL_MODE_MAPPING.keys(),\n                \"`pool_mode` must be one of {}, got {} instead.\"\n                .format(self.POOL_MODE_MAPPING.keys(), pool_mode),\n                ValueError)\n        self.pool_mode = self.POOL_MODE_MAPPING.get(pool_mode)\n        self.dim = dim\n\n    def forward(self, *inputs):\n        dim = inputs[0].dim()\n        assert_(dim in [4, 5],\n                'Input tensors must either be 4 or 5 '\n                'dimensional, but inputs[0] is {}D.'.format(dim),\n                ShapeError)\n        # Get resize function\n        spatial_dim = {4: 2, 5: 3}[dim]\n        resize_function = getattr(F, 'adaptive_{}_pool{}d'.format(self.pool_mode,\n                                                                  spatial_dim))\n        target_size = pyu.as_tuple_of_len(self.target_size, spatial_dim)\n        # Do the resizing\n        resized_inputs = []\n        for input_num, input in enumerate(inputs):\n            # Make sure the dim checks out\n            assert_(input.dim() == dim,\n                    \"Expected inputs[{}] to be a {}D tensor, got a {}D \"\n                    \"tensor instead.\".format(input_num, dim, input.dim()),\n 
                   ShapeError)\n            resized_inputs.append(resize_function(input, target_size))\n        # Concatenate along the channel axis\n        if len(resized_inputs) > 1:\n            concatenated = torch.cat(tuple(resized_inputs), self.dim)\n        else:\n            concatenated = resized_inputs[0]\n        # Done\n        return concatenated\n\n\nclass Cat(Concatenate):\n    \"\"\"An alias for `Concatenate`. Hey, everyone knows who Cat is.\"\"\"\n    pass\n\n\nclass PoolCat(ResizeAndConcatenate):\n    \"\"\"Alias for `ResizeAndConcatenate`, just to annoy snarky web developers.\"\"\"\n    pass\n\n\nclass GlobalMeanPooling(ResizeAndConcatenate):\n    \"\"\"Global mean pooling layer.\"\"\"\n    def __init__(self):\n        super(GlobalMeanPooling, self).__init__((1, 1), 'average')\n\n\nclass GlobalMaxPooling(ResizeAndConcatenate):\n    \"\"\"Global max pooling layer.\"\"\"\n    def __init__(self):\n        super(GlobalMaxPooling, self).__init__((1, 1), 'max')\n\n\nclass Sum(nn.Module):\n    \"\"\"Sum all inputs.\"\"\"\n    def forward(self, *inputs):\n        return torch.stack(inputs, dim=0).sum(0)\n\n\nclass SplitChannels(nn.Module):\n    \"\"\"Split input at a given index along the channel axis.\"\"\"\n    def __init__(self, channel_index):\n        super(SplitChannels, self).__init__()\n        self.channel_index = channel_index\n\n    def forward(self, input):\n        if isinstance(self.channel_index, int):\n            split_location = self.channel_index\n        elif self.channel_index == 'half':\n            split_location = input.size(1) // 2\n        else:\n            raise NotImplementedError\n        assert split_location < input.size(1)\n        split_0 = input[:, 0:split_location, ...]\n        split_1 = input[:, split_location:, ...]\n        return split_0, split_1\n\n\n\nclass Squeeze(nn.Module):\n    def __init__(self):\n        super(Squeeze, self).__init__()\n    def  forward(self, x):\n        return x.squeeze()\n\nclass 
RemoveSingletonDimension(nn.Module):\n    def __init__(self, dim=1):\n        super(RemoveSingletonDimension, self).__init__()\n        self.dim = dim\n    def forward(self, x):\n        size = list(x.size())\n        if size[self.dim] != 1:\n            raise RuntimeError(\"RemoveSingletonDimension expects a single channel at dim %d, shape=%s\"%(self.dim,str(size)))\n\n        slicing = []\n        for s in size:\n            slicing.append(slice(0, s))\n\n        slicing[self.dim] = 0\n\n        return x[tuple(slicing)]"
  },
  {
    "path": "inferno/extensions/layers/sampling.py",
    "content": "import torch.nn as nn\n\n__all__ = ['AnisotropicUpsample', 'AnisotropicPool', 'Upsample', 'AnisotropicUpsample2D', 'AnisotropicPool2D']\n\n\n# torch is deprecating nn.Upsample in favor of nn.functional.interpolate\n# we wrap interpolate here to still use Upsample as class\nclass Upsample(nn.Module):\n    def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):\n        self.size = size\n        self.scale_factor = scale_factor\n        self.mode = mode\n        self.align_corners = align_corners\n        super(Upsample, self).__init__()\n        # interpolate was only introduced in torch 0.4.1 for backward compatibility\n        # we check if we have the attribute here and fall back to Upsample otherwise\n        if hasattr(nn.functional, 'interpolate'):\n            self.have_interpolate = True\n        else:\n            self.have_interpolate = False\n            self.sampler = nn.Upsample(size=size, scale_factor=scale_factor,\n                                       mode=mode, align_corners=align_corners)\n\n    def forward(self, input):\n        if self.have_interpolate:\n            return nn.functional.interpolate(input, self.size, self.scale_factor,\n                                             self.mode, self.align_corners)\n        else:\n            return self.sampler(input)\n\n\nclass AnisotropicUpsample(nn.Module):\n    def __init__(self, scale_factor):\n        super(AnisotropicUpsample, self).__init__()\n        self.upsampler = Upsample(scale_factor=scale_factor)\n\n    def forward(self, input):\n        # input is 3D of shape NCDHW\n        N, C, D, H, W = input.size()\n        # Fold C and D axes in one\n        folded = input.view(N, C * D, H, W)\n        # Upsample\n        upsampled = self.upsampler(folded)\n        # Unfold out the C and D axes\n        unfolded = upsampled.view(N, C, D,\n                                  self.upsampler.scale_factor * H,\n                                  
self.upsampler.scale_factor * W)\n        # Done\n        return unfolded\n\n\nclass AnisotropicPool(nn.MaxPool3d):\n    def __init__(self, downscale_factor):\n        ds = downscale_factor\n        super(AnisotropicPool, self).__init__(kernel_size=(1, ds + 1, ds + 1),\n                                              stride=(1, ds, ds),\n                                              padding=(0, 1, 1))\n\nclass AnisotropicUpsample2D(nn.Module):\n    def __init__(self, scale_factor):\n        super(AnisotropicUpsample2D, self).__init__()\n        self.upsampler = nn.Upsample(scale_factor=scale_factor)\n\n    def forward(self, input):\n        # input is 2D of shape NCDW (or NCDH, egal)\n        N, C, D, W = input.size()\n        # Fold C and D axes in one\n        folded = input.view(N, C * D, W)\n        # Upsample\n        upsampled = self.upsampler(folded)\n        # Unfold out the C and D axes\n        unfolded = upsampled.view(N, C, D,\n                                  self.upsampler.scale_factor * W)\n        # Done\n        return unfolded\n\n\nclass AnisotropicPool2D(nn.MaxPool2d):\n    def __init__(self, downscale_factor):\n        ds = downscale_factor\n        super(AnisotropicPool2D, self).__init__(kernel_size=(1, ds + 1),\n                                              stride=(1, ds),\n                                              padding=(0, 1))\n\n"
  },
  {
    "path": "inferno/extensions/metrics/__init__.py",
    "content": "from .categorical import *\nfrom .arand import *\n"
  },
  {
    "path": "inferno/extensions/metrics/arand.py",
    "content": "from .base import Metric\nimport numpy as np\nimport scipy.sparse as sparse\nimport logging\n\n\nclass ArandScore(Metric):\n    \"\"\"Arand Score, as defined in [1].\n\n    References\n    ----------\n    [1]: http://journal.frontiersin.org/article/10.3389/fnana.2015.00142/full#h3\n    \"\"\"\n    def __init__(self, average_slices=True):\n        self.average_slices = average_slices\n\n    # compute the arand score for a prediction target pair\n    def _arand_for_tensor(self, prediction, target):\n        # check if we need to average over slices\n        average_slices = self.average_slices and prediction.ndim == 3\n        score_is_invalid = False\n\n        # average the rand score over 3d slices\n        if average_slices:\n            # average the arand values over the 3d slices\n            evaluation_values = [adapted_rand(pred, targ) for pred, targ in zip(prediction, target)]\n            # check if the score is invalid\n            if all(ev_val is None for ev_val in evaluation_values):\n                score_is_invalid = True\n                score = 0\n            else:\n                score = np.mean([eval_val[0] for eval_val in evaluation_values if eval_val is not None])\n\n        # compute rand score on whole image / volume\n        else:\n            score = adapted_rand(prediction, target)\n            # check if the score is invalid\n            if score is None:\n                score_is_invalid = True\n                score = 0\n            else:\n                score = score[0]\n\n        if score_is_invalid:\n            logger = logging.getLogger(__name__)\n            logger.warning(\"All slices were invalid, returning worst possible score\")\n        return score\n\n    def forward(self, prediction, target):\n        assert(prediction.shape == target.shape), \"%s, %s\" % (str(prediction.shape),\n                                                              str(target.shape))\n        assert prediction.shape[1] == 1, 
\"Expect singleton channel axis\"\n        prediction = prediction.cpu().numpy()\n        target = target.cpu().numpy()\n\n        ndim = prediction.ndim\n        assert ndim in (4, 5), \"Expect 2 or 3d input with additional batch and channel axis\"\n\n        # return the average arand error over the batches\n        return np.mean([self._arand_for_tensor(pred[0], targ[0])\n                        for pred, targ in zip(prediction, target)])\n\n\nclass ArandError(ArandScore):\n    \"\"\"Arand Error = 1 - <arand score>\"\"\"\n    def __init__(self, **super_kwargs):\n        super(ArandError, self).__init__(**super_kwargs)\n\n    def forward(self, prediction, target):\n        return 1. - super(ArandError, self).forward(prediction, target)\n\n\n# Evaluation code courtesy of Juan Nunez-Iglesias, taken from\n# https://github.com/janelia-flyem/gala/blob/master/gala/evaluate.py\ndef adapted_rand(seg, gt):\n    \"\"\"Compute Adapted Rand error as defined by the SNEMI3D contest [1]\n    Formula is given as 1 - the maximal F-score of the Rand index\n    (excluding the zero component of the original labels). 
Adapted\n    from the SNEMI3D MATLAB script, hence the strange style.\n\n    Parameters\n    ----------\n    seg : np.ndarray\n        the segmentation to score, where each value is the label at that point\n    gt : np.ndarray, same shape as seg\n        the groundtruth to score against, where each value is a label\n\n    Returns\n    -------\n    are : float\n        The adapted Rand error; equal to $1 - \\frac{2pr}{p + r}$,\n        where $p$ and $r$ are the precision and recall described below.\n    prec : float, optional\n        The adapted Rand precision.\n    rec : float, optional\n        The adapted Rand recall.\n\n    References\n    ----------\n    [1]: http://brainiac2.mit.edu/SNEMI3D/evaluation\n    \"\"\"\n    assert seg.shape == gt.shape, \"%s, %s\" % (str(seg.shape), str(gt.shape))\n    logger = logging.getLogger(__name__)\n\n    if np.any(seg == 0):\n        logger.debug(\"Zeros in segmentation, treating as background.\")\n    if np.any(gt == 0):\n        logger.debug(\"Zeros in ground truth, 0's will be ignored.\")\n\n    seg_zeros = np.all(seg == 0)\n    gt_zeros = np.all(gt == 0)\n    # return None if either gt or segmentation are all zeros\n    if seg_zeros or gt_zeros:\n        logger.debug(\"Either segmentation or groundtruth are all zeros, returning None.\")\n        return None\n\n    # segA is truth, segB is query\n    segA = np.ravel(gt)\n    segB = np.ravel(seg)\n\n    # mask to foreground in A\n    mask = (segA > 0)\n    segA = segA[mask]\n    segB = segB[mask]\n\n    # number of nonzero pixels in original segA\n    n = segA.size\n    n_labels_A = int(np.amax(segA)) + 1\n    n_labels_B = int(np.amax(segB)) + 1\n\n    ones_data = np.ones(n)\n    p_ij = sparse.csr_matrix((ones_data, (segA.ravel(), segB.ravel())),\n                             shape=(n_labels_A, n_labels_B),\n                             dtype=np.uint64)\n\n    # In the paper where adapted rand is proposed, they treat each background\n    # pixel in segB as a different value 
(i.e., unique label for each pixel).\n    # To do this, we sum them differently than others\n\n    # ind (label_gt, label_seg), so ignore 0 seg labels\n    B_nonzero = p_ij[:, 1:]\n    B_zero = p_ij[:, 0]\n\n    # this is a count\n    num_B_zero = B_zero.sum()\n\n    # sum of the joint distribution\n    #   separate sum of B>0 and B=0 parts\n    sum_p_ij = (B_nonzero).power(2).sum() + num_B_zero\n\n    # these are marginal probabilities\n    # sum over all seg labels overlapping one gt label (except 0 labels)\n    a_i = p_ij.sum(1)\n    b_i = B_nonzero.sum(0)\n\n    sum_a = np.power(a_i, 2).sum()\n    sum_b = np.power(b_i, 2).sum() + num_B_zero\n\n    precision = float(sum_p_ij) / sum_b\n    recall = float(sum_p_ij) / sum_a\n    f_score = 2.0 * precision * recall / (precision + recall)\n    return f_score, precision, recall\n"
  },
  {
    "path": "inferno/extensions/metrics/base.py",
    "content": "\n\nclass Metric(object):\n\n    def forward(self, *args, **kwargs):\n        raise NotImplementedError\n\n    def __call__(self, prediction, target, **kwargs):\n        # We might have listlike predictions (e.g. multi-scale)\n        # If so, we evaluate the metric on the first prediction,\n        # which should be at the original scale\n        if isinstance(prediction, (list, tuple)):\n            prediction = prediction[0]\n        # same is true for the target\n        if isinstance(target, (list, tuple)):\n            target = target[0]\n        # Make sure prediction and target live on the same device.\n        # If they don't, move target to the right device.\n        if not prediction.is_cuda:\n            # Move to CPU\n            target = target.cpu()\n        else:\n            # Find device to move to\n            device_ordinal = prediction.get_device()\n            target = target.cuda(device_ordinal)\n        return self.forward(prediction, target, **kwargs)\n"
  },
  {
    "path": "inferno/extensions/metrics/categorical.py",
    "content": "import torch\nfrom .base import Metric\nfrom ...utils.torch_utils import flatten_samples, is_label_tensor\nfrom ...utils.exceptions import assert_, DTypeError, ShapeError\n\n\nclass CategoricalError(Metric):\n    \"\"\"Categorical error.\"\"\"\n    def __init__(self, aggregation_mode='mean'):\n        assert aggregation_mode in ['mean', 'sum']\n        self.aggregation_mode = aggregation_mode\n\n    def forward(self, prediction, target):\n        # Check if prediction is binary or not\n        is_binary = len(prediction.size()) == 1 or prediction.size(1) == 1\n\n        if len(target.size()) > 1:\n            target = target.squeeze(1)\n        assert len(target.size()) == 1\n\n        if is_binary:\n            # Binary classification\n            prediction = prediction > 0.5\n            incorrect = prediction.type_as(target).ne(target).float()\n            if self.aggregation_mode == 'mean':\n                return incorrect.mean()\n            else:\n                return incorrect.sum()\n        else:\n            # Multiclass classificiation\n            _, predicted_class = torch.max(prediction, 1)\n            if predicted_class.dim() == prediction.dim():\n                # Support for Pytorch 0.1.12\n                predicted_class = predicted_class.squeeze(1)\n            incorrect = predicted_class.type_as(target).ne(target).float()\n            if self.aggregation_mode == 'mean':\n                return incorrect.mean()\n            else:\n                return incorrect.sum()\n\n\nclass IOU(Metric):\n    \"\"\"Intersection over Union. 
\"\"\"\n    def __init__(self, ignore_class=None, sharpen_prediction=False, eps=1e-6):\n        super(IOU, self).__init__()\n        self.eps = eps\n        self.ignore_class = ignore_class\n        self.sharpen_prediction = sharpen_prediction\n\n    def forward(self, prediction, target):\n        # Assume that is one of:\n        #   prediction.shape = (N, C, H, W)\n        #   prediction.shape = (N, C, D, H, W)\n        #   prediction.shape = (N, C)\n        # The corresponding target shapes are either:\n        #   target.shape = (N, H, W)\n        #   target.shape = (N, D, H, W)\n        #   target.shape = (N,)\n        # Or:\n        #   target.shape = (N, C, H, W)\n        #   target.shape = (N, C, D, H, W)\n        #   target.shape = (N, C)\n        # First, reshape prediction to (C, -1)\n        flattened_prediction = flatten_samples(prediction)\n        # Take measurements\n        num_classes, num_samples = flattened_prediction.size()\n        # We need to figure out if the target is a int label tensor or a onehot tensor.\n        # The former always has one dimension less, so\n        if target.dim() == (prediction.dim() - 1):\n            # Labels, we need to go one hot\n            # Make sure it's a label\n            assert_(is_label_tensor(target),\n                    \"Target must be a label tensor (of dtype long) if it has one \"\n                    \"dimension less than the prediction.\",\n                    DTypeError)\n            # Reshape target to (1, -1) for it to work with scatter\n            flattened_target = target.view(1, -1)\n            # Convert target to onehot with shape (C, -1)\n            # Make sure the target is consistent\n            assert_(target.max() < num_classes)\n            onehot_targets = flattened_prediction \\\n                .new(num_classes, num_samples) \\\n                .zero_() \\\n                .scatter_(0, flattened_target, 1)\n        elif target.dim() == prediction.dim():\n            # Onehot, 
nothing to do except flatten\n            onehot_targets = flatten_samples(target)\n        else:\n            raise ShapeError(\"Target must have the same number of dimensions as the \"\n                             \"prediction, or one less. Got target.dim() = {} but \"\n                             \"prediction.dim() = {}.\".format(target.dim(), prediction.dim()))\n        # Cast onehot_targets to float if required (this is a no-op if it's already float)\n        onehot_targets = onehot_targets.float()\n        # Sharpen prediction if required to. Sharpening in this sense means to replace\n        # the max predicted probability with 1.\n        if self.sharpen_prediction:\n            _, predicted_classes = torch.max(flattened_prediction, 0)\n            # Case for pytorch 0.2, where predicted_classes is (N,) instead of (1, N)\n            if predicted_classes.dim() == 1:\n                predicted_classes = predicted_classes.view(1, -1)\n            # Scatter\n            flattened_prediction = flattened_prediction\\\n                .new(num_classes, num_samples).zero_().scatter_(0, predicted_classes, 1)\n        # Now to compute the IOU = (a * b).sum()/(a**2 + b**2 - a * b).sum()\n        # We sum over all samples to obtain a classwise iou\n        numerator = (flattened_prediction * onehot_targets).sum(-1)\n        denominator = \\\n            flattened_prediction.sub_(onehot_targets).pow_(2).clamp_(min=self.eps).sum(-1) + \\\n            numerator\n        classwise_iou = numerator.div_(denominator)\n        # If we're ignoring a class, don't count its contribution to the mean\n        if self.ignore_class is not None:\n            ignore_class = self.ignore_class \\\n                if self.ignore_class != -1 else onehot_targets.size(0) - 1\n            assert_(ignore_class < onehot_targets.size(0),\n                    \"`ignore_class` = {} must be at least one less than the number \"\n                    \"of classes = {}.\".format(ignore_class, 
onehot_targets.size(0)),\n                    ValueError)\n            num_classes = onehot_targets.size(0)\n            dont_ignore_class = list(range(num_classes))\n            dont_ignore_class.pop(ignore_class)\n            if classwise_iou.is_cuda:\n                dont_ignore_class = \\\n                    torch.LongTensor(dont_ignore_class).cuda(classwise_iou.get_device())\n            else:\n                dont_ignore_class = torch.LongTensor(dont_ignore_class)\n            iou = classwise_iou[dont_ignore_class].mean()\n        else:\n            iou = classwise_iou.mean()\n        return iou\n\n\nclass NegativeIOU(IOU):\n    def forward(self, prediction, target):\n        return -1 * super(NegativeIOU, self).forward(prediction, target)\n"
  },
  {
    "path": "inferno/extensions/metrics/cremi_score.py",
    "content": "import numpy as np\nfrom .voi import voi\nfrom .arand import adapted_rand\n\n\n# TODO build metrics object\n\n\ndef cremi_metrics(seg, gt, no_seg_ignore=True):\n    if no_seg_ignore:\n        if 0  in seg:\n            seg += 1\n    vi_s, vi_m = voi(seg, gt)\n    rand = 1. - adapted_rand(seg, gt)[0]\n    cs = np.sqrt((vi_s + vi_m) * rand)\n    return cs, vi_s, vi_m, rand\n"
  },
  {
    "path": "inferno/extensions/metrics/voi.py",
    "content": "from .base import Metric\n\nimport numpy as np\nimport scipy.sparse as sparse\n\n\nclass VoiScore(Metric):\n    \"\"\"\n    Computes a score based on the variation of information according to [1].\n    References\n    ----------\n    [1] Meila, M. (2007). Comparing clusterings - an information based\n    distance. Journal of Multivariate Analysis 98, 873-895.\n    \"\"\"\n    def forward(self, prediction, target):\n        assert(len(prediction) == len(target))\n        segmentation = prediction.cpu().numpy()\n        target = target.cpu().numpy()\n        return np.mean([sum(voi(segmentation[i], target[i]))\n                        for i in range(len(prediction))])\n\n\n# Copied from `cremi-python`\n# https://github.com/cremi/cremi_python/blob/master/cremi/evaluation/voi.py\n\n# Evaluation code courtesy of Juan Nunez-Iglesias, taken from\n# https://github.com/janelia-flyem/gala/blob/master/gala/evaluate.py\n\ndef voi(seg, gt, ignore_reconstruction=[], ignore_groundtruth=[0]):\n    \"\"\"Return the conditional entropies of the variation of information metric. [1]\n\n    Let X be a seg, and Y a ground truth labelling. The variation of\n    information between the two is the sum of two conditional entropies:\n\n        VI(X, Y) = H(X|Y) + H(Y|X).\n\n    The first one, H(X|Y), is a measure of oversegmentation, the second one,\n    H(Y|X), a measure of undersegmentation. 
These measures are referred to as\n    the variation of information split or merge error, respectively.\n\n    Parameters\n    ----------\n    seg : np.ndarray, int type, arbitrary shape\n        A candidate segmentation.\n    gt : np.ndarray, int type, same shape as `seg`\n        The ground truth segmentation.\n    ignore_seg, ignore_gt : list of int, optional\n        Any points having a label in this list are ignored in the evaluation.\n        By default, only the label 0 in the ground truth will be ignored.\n\n    Returns\n    -------\n    (split, merge) : float\n        The variation of information split and merge error, i.e., H(X|Y) and H(Y|X)\n\n    References\n    ----------\n    [1] Meila, M. (2007). Comparing clusterings - an information based\n    distance. Journal of Multivariate Analysis 98, 873-895.\n    \"\"\"\n    hyxg, hxgy = split_vi(seg, gt, ignore_reconstruction, ignore_groundtruth)\n    return hxgy, hyxg\n\n\ndef split_vi(x, y=None, ignore_x=[0], ignore_y=[0]):\n    \"\"\"Return the symmetric conditional entropies associated with the VI.\n\n    The variation of information is defined as VI(X,Y) = H(X|Y) + H(Y|X).\n    If Y is the ground-truth segmentation, then H(Y|X) can be interpreted\n    as the amount of under-segmentation of Y and H(X|Y) is then the amount\n    of over-segmentation.  In other words, a perfect over-segmentation\n    will have H(Y|X)=0 and a perfect under-segmentation will have H(X|Y)=0.\n\n    If y is None, x is assumed to be a contingency table.\n\n    Parameters\n    ----------\n    x : np.ndarray\n        Label field (int type) or contingency table (float). 
`x` is\n        interpreted as a contingency table (summing to 1.0) if and only if `y`\n        is not provided.\n    y : np.ndarray of int, same shape as x, optional\n        A label field to compare to `x`.\n    ignore_x, ignore_y : list of int, optional\n        Any points having a label in this list are ignored in the evaluation.\n        Ignore 0-labeled points by default.\n\n    Returns\n    -------\n    sv : np.ndarray of float, shape (2,)\n        The conditional entropies of Y|X and X|Y.\n\n    See Also\n    --------\n    vi\n    \"\"\"\n    _, _, _, hxgy, hygx, _, _ = vi_tables(x, y, ignore_x, ignore_y)\n    # false merges, false splits\n    return np.array([hygx.sum(), hxgy.sum()])\n\n\ndef vi_tables(x, y=None, ignore_x=[0], ignore_y=[0]):\n    \"\"\"Return probability tables used for calculating VI.\n\n    If y is None, x is assumed to be a contingency table.\n\n    Parameters\n    ----------\n    x, y : np.ndarray\n        Either x and y are provided as equal-shaped np.ndarray label fields\n        (int type), or y is not provided and x is a contingency table\n        (sparse.csc_matrix) that may or may not sum to 1.\n    ignore_x, ignore_y : list of int, optional\n        Rows and columns (respectively) to ignore in the contingency table.\n        These are labels that are not counted when evaluating VI.\n\n    Returns\n    -------\n    pxy : sparse.csc_matrix of float\n        The normalized contingency table.\n    px, py, hxgy, hygx, lpygx, lpxgy : np.ndarray of float\n        The proportions of each label in `x` and `y` (`px`, `py`), the\n        per-segment conditional entropies of `x` given `y` and vice-versa, the\n        per-segment conditional probability p log p.\n    \"\"\"\n    if y is not None:\n        pxy = contingency_table(x, y, ignore_x, ignore_y)\n    else:\n        cont = x\n        total = float(cont.sum())\n        # normalize, since it is an identity op if already done\n        pxy = cont / total\n\n    # Calculate 
probabilities\n    px = np.array(pxy.sum(axis=1)).ravel()\n    py = np.array(pxy.sum(axis=0)).ravel()\n    # Remove zero rows/cols\n    nzx = px.nonzero()[0]\n    nzy = py.nonzero()[0]\n    nzpx = px[nzx]\n    nzpy = py[nzy]\n    nzpxy = pxy[nzx, :][:, nzy]\n\n    # Calculate log conditional probabilities and entropies\n    lpygx = np.zeros(np.shape(px))\n    lpygx[nzx] = xlogx(divide_rows(nzpxy, nzpx)).sum(axis=1).squeeze()  # \\sum_x{p_{y|x} \\log{p_{y|x}}}\n    hygx = -(px * lpygx)  # \\sum_x{p_x H(Y|X=x)} = H(Y|X)\n\n    lpxgy = np.zeros(np.shape(py))\n    lpxgy[nzy] = xlogx(divide_columns(nzpxy, nzpy)).sum(axis=0)\n    hxgy = -(py * lpxgy)\n\n    return [pxy] + list(map(np.asarray, [px, py, hxgy, hygx, lpygx, lpxgy]))\n\n\ndef contingency_table(seg, gt, ignore_seg=[0], ignore_gt=[0], norm=True):\n    \"\"\"Return the contingency table for all regions in matched segmentations.\n\n    Parameters\n    ----------\n    seg : np.ndarray, int type, arbitrary shape\n        A candidate segmentation.\n    gt : np.ndarray, int type, same shape as `seg`\n        The ground truth segmentation.\n    ignore_seg : list of int, optional\n        Values to ignore in `seg`. Voxels in `seg` having a value in this list\n        will not contribute to the contingency table. (default: [0])\n    ignore_gt : list of int, optional\n        Values to ignore in `gt`. Voxels in `gt` having a value in this list\n        will not contribute to the contingency table. (default: [0])\n    norm : bool, optional\n        Whether to normalize the table so that it sums to 1.\n\n    Returns\n    -------\n    cont : scipy.sparse.csc_matrix\n        A contingency table. `cont[i, j]` will equal the number of voxels\n        labeled `i` in `seg` and `j` in `gt`. 
(Or the proportion of such voxels\n        if `norm=True`.)\n    \"\"\"\n    segr = seg.ravel()\n    gtr = gt.ravel()\n    ignored = np.zeros(segr.shape, bool)\n    data = np.ones(len(gtr))\n    for i in ignore_seg:\n        ignored[segr == i] = True\n    for j in ignore_gt:\n        ignored[gtr == j] = True\n    data[ignored] = 0\n    cont = sparse.coo_matrix((data, (segr, gtr))).tocsc()\n    if norm:\n        cont /= float(cont.sum())\n    return cont\n\n\ndef divide_columns(matrix, row, in_place=False):\n    \"\"\"Divide each column of `matrix` by the corresponding element in `row`.\n\n    The result is as follows: out[i, j] = matrix[i, j] / row[j]\n\n    Parameters\n    ----------\n    matrix : np.ndarray, scipy.sparse.csc_matrix or csr_matrix, shape (M, N)\n        The input matrix.\n    column : a 1D np.ndarray, shape (N,)\n        The row dividing `matrix`.\n    in_place : bool (optional, default False)\n        Do the computation in-place.\n\n    Returns\n    -------\n    out : same type as `matrix`\n        The result of the row-wise division.\n    \"\"\"\n    if in_place:\n        out = matrix\n    else:\n        out = matrix.copy()\n    if type(out) in [sparse.csc_matrix, sparse.csr_matrix]:\n        if type(out) == sparse.csc_matrix:\n            convert_to_csc = True\n            out = out.tocsr()\n        else:\n            convert_to_csc = False\n        row_repeated = np.take(row, out.indices)\n        nz = out.data.nonzero()\n        out.data[nz] /= row_repeated[nz]\n        if convert_to_csc:\n            out = out.tocsc()\n    else:\n        out /= row[np.newaxis, :]\n    return out\n\n\ndef divide_rows(matrix, column, in_place=False):\n    \"\"\"Divide each row of `matrix` by the corresponding element in `column`.\n\n    The result is as follows: out[i, j] = matrix[i, j] / column[i]\n\n    Parameters\n    ----------\n    matrix : np.ndarray, scipy.sparse.csc_matrix or csr_matrix, shape (M, N)\n        The input matrix.\n    column : a 1D 
np.ndarray, shape (M,)\n        The column dividing `matrix`.\n    in_place : bool (optional, default False)\n        Do the computation in-place.\n\n    Returns\n    -------\n    out : same type as `matrix`\n        The result of the row-wise division.\n    \"\"\"\n    if in_place:\n        out = matrix\n    else:\n        out = matrix.copy()\n    if type(out) in [sparse.csc_matrix, sparse.csr_matrix]:\n        if type(out) == sparse.csr_matrix:\n            convert_to_csr = True\n            out = out.tocsc()\n        else:\n            convert_to_csr = False\n        column_repeated = np.take(column, out.indices)\n        nz = out.data.nonzero()\n        out.data[nz] /= column_repeated[nz]\n        if convert_to_csr:\n            out = out.tocsr()\n    else:\n        out /= column[:, np.newaxis]\n    return out\n\n\ndef xlogx(x, out=None, in_place=False):\n    \"\"\"Compute x * log_2(x).\n\n    We define 0 * log_2(0) = 0\n\n    Parameters\n    ----------\n    x : np.ndarray or scipy.sparse.csc_matrix or csr_matrix\n        The input array.\n    out : same type as x (optional)\n        If provided, use this array/matrix for the result.\n    in_place : bool (optional, default False)\n        Operate directly on x.\n\n    Returns\n    -------\n    y : same type as x\n        Result of x * log_2(x).\n    \"\"\"\n    if in_place:\n        y = x\n    elif out is None:\n        y = x.copy()\n    else:\n        y = out\n    if type(y) in [sparse.csc_matrix, sparse.csr_matrix]:\n        z = y.data\n    else:\n        z = y\n    nz = z.nonzero()\n    z[nz] *= np.log2(z[nz])\n    return y\n"
  },
  {
    "path": "inferno/extensions/models/__init__.py",
    "content": "from .unet import UNet, UNetBase\nfrom .res_unet import ResBlockUNet\n"
  },
  {
    "path": "inferno/extensions/models/res_unet.py",
    "content": "import torch\nimport torch.nn as nn\nfrom ..layers.convolutional import ConvActivation\nfrom .unet import UNetBase\nfrom ...utils.python_utils import require_dict_kwargs\n\n__all__ = ['ResBlockUNet']\n_all = __all__\n\n\n\n# We only use this for the u-net implementation here\n# in favor of less code duplication it might be a\n# good ideat to replace this with 'ResidualBlock' from layers.convolutional_blocks\nclass _ResBlockBase(nn.Module):\n    def __init__(self, in_channels, out_channels, dim,\n                 size=2, force_skip_op=False, activated=True):\n        super(_ResBlockBase, self).__init__()\n        self.in_channels = int(in_channels)\n        self.out_channels = int(out_channels)\n        self.size = int(size)\n        self.activated = bool(activated)\n        self.force_skip_op = bool(force_skip_op)\n        self.dim = int(dim)\n\n        if self.in_channels != self.out_channels or self.force_skip_op:\n            self.activated_skip_op = self.activated_skip_op_factory(in_channels=self.in_channels,\n                                                                    out_channels=self.out_channels)\n\n        conv_ops = []\n        activation_ops = []\n        for i in range(self.size):\n\n            # the convolutions\n            if i == 0:\n                op = self.nonactivated_conv_op_factory(in_channels=self.out_channels,\n                                                       out_channels=self.out_channels, index=i)\n            else:\n                op = self.nonactivated_conv_op_factory(in_channels=self.out_channels,\n                                                       out_channels=self.out_channels, index=i)\n            conv_ops.append(op)\n\n            # the activations\n            if i < self.size or self.activated:\n                activation_ops.append(self.activation_op_factory(index=i))\n\n        self.conv_ops = nn.ModuleList(conv_ops)\n        self.activation_ops = nn.ModuleList(activation_ops)\n\n    def 
activated_skip_op_factory(self, in_channels, out_channels):\n        raise NotImplementedError(\"activated_skip_op_factory need to be implemented by deriving class\")\n\n    def nonactivated_conv_op_factory(self, in_channels, out_channels, index):\n        raise NotImplementedError(\"conv_op_factory need to be implemented by deriving class\")\n\n    def activation_op_factory(self, index):\n        return nn.ReLU()\n\n    def forward(self, input):\n\n        if input.size(1) != self.in_channels:\n            raise RuntimeError(\"wrong number of channels: expected %d, got %d\"%\n                (self.in_channels, input.size(1)))\n\n        if input.dim() != self.dim + 2:\n            raise RuntimeError(\"wrong number of dim: expected %d, got %d\"%\n                (self.dim+2, input.dim()))\n\n        if self.in_channels != self.out_channels or self.force_skip_op:\n            skip_res = self.activated_skip_op(input)\n        else:\n            skip_res = input\n\n        assert skip_res.size(1) == self.out_channels\n\n        res = skip_res\n        for i in  range(self.size):\n            res = self.conv_ops[i](res)\n            assert res.size(1)  == self.out_channels\n            if i + 1 < self.size:\n                res = self.activation_ops[i](res)\n\n        non_activated = skip_res + res\n        if self.activated:\n            return self.activation_ops[-1](non_activated)\n        else:\n            return non_activated\n\n\nclass _ResBlock(_ResBlockBase):\n    def __init__(self, in_channels, out_channels, dim, size=2, activated=True,\n                 activation='ReLU', batchnorm=True, force_skip_op=False, conv_kwargs=None):\n\n        # trick to store  nn-module before call of super\n        # => we put it in a list\n        if isinstance(activation, str):\n            self.activation_op = [getattr(torch.nn, activation)()]\n        elif isinstance(activation, nn.Module):\n            self.activation_op = [activation]\n        else:\n            raise 
RuntimeError(\"activation must be a striong or a torch.nn.Module\")\n\n        # keywords for conv\n        if conv_kwargs is None:\n            conv_kwargs = dict(\n                 kernel_size=3, dim=dim, activation=None,\n                 stride=1, dilation=1, groups=None, depthwise=False, bias=True,\n                 deconv=False, initialization=None\n            )\n        elif isinstance(conv_kwargs, dict):\n            conv_kwargs['activation'] = None\n        else:\n            raise RuntimeError(\"conv_kwargs must be either None or a dict\")\n        self.conv_kwargs = conv_kwargs\n\n        self.dim = dim\n        self.batchnorm = batchnorm\n\n        self.conv_1x1_kwargs = dict(kernel_size=1, dim=dim, activation=None,\n                                    stride=1, dilation=1, groups=None,\n                                    depthwise=False, bias=True, deconv=False,\n                                    initialization=None)\n\n        super(_ResBlock, self).__init__(in_channels=in_channels,\n                                        out_channels=out_channels,\n                                        dim=dim, size=size,\n                                        force_skip_op=force_skip_op,\n                                        activated=activated)\n\n    def activated_skip_op_factory(self, in_channels, out_channels):\n        conv_op = ConvActivation(in_channels=in_channels,\n                                 out_channels=out_channels, **self.conv_1x1_kwargs)\n        if self.batchnorm:\n            batchnorm_op = self.batchnorm_op_factory(in_channels=out_channels)\n            return torch.nn.Sequential(conv_op, batchnorm_op, self.activation_op[0])\n        else:\n            return torch.nn.Sequential(conv_op, self.activation_op[0])\n\n    def nonactivated_conv_op_factory(self, in_channels, out_channels, index):\n        conv_op = ConvActivation(in_channels=in_channels,\n                                 out_channels=out_channels, **self.conv_kwargs)\n     
   if self.batchnorm:\n            batchnorm_op = self.batchnorm_op_factory(in_channels=out_channels)\n            return torch.nn.Sequential(conv_op, batchnorm_op)\n        else:\n            return conv_op\n\n    def activation_op_factory(self, index):\n        return self.activation_op[0]\n\n    def batchnorm_op_factory(self, in_channels):\n        bn_cls_name = 'BatchNorm{}d'.format(int(self.dim))\n        bn_op_cls = getattr(torch.nn, bn_cls_name)\n        return bn_op_cls(in_channels)\n\n\n# TODO not sure how to handle out-channels properly.\n# For now, we just force the corrcect number in the last decoder layer\nclass ResBlockUNet(UNetBase):\n    \"\"\"TODO.\n\n        ACCC\n\n    Attributes:\n        activated (TYPE): Description\n        dim (TYPE): Description\n        res_block_kwargs (TYPE): Description\n        side_out_parts (TYPE): Description\n        unet_kwargs (TYPE): Description\n    \"\"\"\n    def __init__(self, in_channels, dim, out_channels, unet_kwargs=None,\n                 res_block_kwargs=None, activated=True,\n                 side_out_parts=None):\n\n        self.dim = dim\n        self.unet_kwargs = require_dict_kwargs(unet_kwargs, \"unet_kwargs must be a dict or None\")\n        self.res_block_kwargs = require_dict_kwargs(res_block_kwargs,\n                                                    \"res_block_kwargs must be a dict or None\")\n        self.activated = activated\n        if isinstance(side_out_parts, str):\n            self.side_out_parts = set([side_out_parts])\n        elif isinstance(side_out_parts, (tuple,list)):\n            self.side_out_parts = set(side_out_parts)\n        else:\n            self.side_out_parts = set()\n\n        super(ResBlockUNet, self).__init__(in_channels=in_channels,\n                                           out_channels=out_channels,\n                                           dim=dim,\n                                           **self.unet_kwargs)\n\n    def conv_op_factory(self, 
in_channels, out_channels, part, index):\n\n        # is this the very last convolutional block?\n        very_last = (part == 'up' and index == 0)\n\n        # should the residual block be activated?\n        activated = not very_last or self.activated\n\n        # should the output be part of the overall\n        # return-list in the forward pass of the UNet\n        use_as_output = part in self.side_out_parts\n\n        # residual block used within the UNet\n        return _ResBlock(in_channels=in_channels, out_channels=out_channels,\n                         dim=self.dim, activated=activated,\n                         **self.res_block_kwargs), use_as_output\n"
  },
  {
    "path": "inferno/extensions/models/unet.py",
    "content": "import torch\nimport torch.nn as nn\nfrom ..layers.identity import Identity\nfrom ..layers.convolutional import ConvELU2D, ConvELU3D, Conv2D, Conv3D\nfrom ..layers.sampling import Upsample as InfernoUpsample\nfrom ...utils.math_utils import max_allowed_ds_steps\n\n\n__all__ = ['UNetBase', 'UNet', 'ResBlockUNet']\n_all = __all__\n\n\nclass UNetBase(nn.Module):\n\n    \"\"\" Base class for implementing UNets.\n        The depth and dimension of the UNet is flexible.\n        The deriving classes must implement\n        `conv_op_factory` and can implement\n        `upsample_op_factory` and\n        `downsample_op_factory`.\n\n    Attributes:\n        in_channels (int): Number of input channels.\n        dim (int): Spatial dimension of data (must be 2 or 3).\n        out_channels (int): Number of output channels. Set to None by default,\n            which sets the number of out channels to the number of input channels\n            to preserve symmetry of feature channels (default: None).\n        depth (int): How many down-sampling / up-sampling steps\n            shall be performed (default: 3).\n        gain (int): Multiplicative increase of channels while going down in the UNet.\n            The same factor is used to decrease the number of channels while\n            going up in the UNet (default: 2).\n        residual (bool): If residual is true, the output of the down-streams\n            are added to the up-stream results.\n            Otherwise the results are concatenated (default: False).\n    \"\"\"\n\n    def __init__(self, in_channels, dim, out_channels=None, depth=3,\n                 gain=2, residual=False, upsample_mode=None, p_dropout=None):\n\n        super(UNetBase, self).__init__()\n\n        # early sanity check\n        if dim not in [2, 3]:\n            raise RuntimeError(\"UNetBase is only implemented for 2D and 3D\")\n\n        # settings related members\n        self.in_channels  = int(in_channels)\n        self.dim          = 
int(dim)\n        self.out_channels = self.in_channels if out_channels is\\\n            None else int(out_channels)\n        self.depth        = int(depth)\n        self.gain         = int(gain)\n        self.residual     = bool(residual)\n        self.p_dropout = p_dropout\n\n        # members to remember what to store as side output\n        self._store_conv_down = []\n        self._store_conv_bottom = False\n        self._store_conv_up = []\n\n        # number of channels per side output\n        self.n_channels_per_output = []\n\n        # members to hold actual nn.Modules / nn.ModuleLists\n        self._pre_conv_down_ops  = None\n        self._post_conv_down_ops = None\n        self._conv_down_ops  = None\n\n        self._pre_conv_up_ops  = None\n        self._post_conv_up_ops = None\n        self._conv_up_ops = None\n\n        self._upsample_ops = None\n        self._downsample_ops = None\n\n        self._pre_conv_bottom_ops  = None\n        self._post_conv_bottom_ops = None\n        self._conv_bottom_op = None\n\n        # upsample kwargs\n        self._upsample_kwargs = self._make_upsample_kwargs(upsample_mode=upsample_mode)\n\n        ########################################\n        # default dropout\n        ########################################\n        if self.p_dropout is not None:\n            self.use_dropout = True\n            if self.dim == 2:\n                self._channel_dropout_op = torch.nn.Dropout2d(p=float(self.p_dropout),\n                                                              inplace=False)\n            else:\n                self._channel_dropout_op = torch.nn.Dropout3d(p=float(self.p_dropout),\n                                                              inplace=False)\n        else:\n            self.use_dropout = False\n\n        # down-stream convolution blocks\n        self._init__downstream()\n\n        # pooling / downsample operators\n        self._downsample_ops = nn.ModuleList([\n            
self.downsample_op_factory(i) for i in range(depth)\n        ])\n\n        # upsample operators\n        # we flip the index that is given as argument to index consistently in up and\n        # downstream sampling factories\n        self._upsample_ops = nn.ModuleList([\n            self.upsample_op_factory(depth - i - 1) for i in range(depth)\n        ])\n\n        # bottom block of the unet\n        self._init__bottom()\n\n        # up-stream convolution blocks\n        self._init__upstream()\n\n        assert len(self.n_channels_per_output) == self._store_conv_down.count(True) + \\\n            self._store_conv_up.count(True) + int(self._store_conv_bottom)\n\n    def _get_num_channels(self, depth):\n        assert depth > 0\n        return self.in_channels * self.gain**depth\n\n    def _init__downstream(self):\n        conv_down_ops = []\n        self._store_conv_down = []\n\n        current_in_channels = self.in_channels\n\n        for i in range(self.depth):\n            out_channels = self._get_num_channels(i + 1)\n            op, return_op_res = self.conv_op_factory(in_channels=current_in_channels,\n                                                     out_channels=out_channels,\n                                                     part='down', index=i)\n            conv_down_ops.append(op)\n            if return_op_res:\n                self.n_channels_per_output.append(out_channels)\n                self._store_conv_down.append(True)\n            else:\n                self._store_conv_down.append(False)\n\n            # increase the number of channels\n            current_in_channels = out_channels\n\n        # store as proper torch ModuleList\n        self._conv_down_ops = nn.ModuleList(conv_down_ops)\n\n        return current_in_channels\n\n    def _init__bottom(self):\n\n        current_in_channels = self._get_num_channels(self.depth)\n\n        factory_res = self.conv_op_factory(in_channels=current_in_channels,\n            
out_channels=current_in_channels, part='bottom', index=0)\n        if isinstance(factory_res, tuple):\n            self._conv_bottom_op, self._store_conv_bottom = factory_res\n            if self._store_conv_bottom:\n                self.n_channels_per_output.append(current_in_channels)\n        else:\n            self._conv_bottom_op = factory_res\n            self._store_conv_bottom = False\n\n    def _init__upstream(self):\n        conv_up_ops = []\n        current_in_channels = self._get_num_channels(self.depth)\n\n        for i in range(self.depth):\n            # the number of out channels (set to self.out_channels for last decoder)\n            out_channels = self.out_channels if i + 1 == self.depth else \\\n                self._get_num_channels(self.depth - i - 1)\n\n            # if not residual we concat which needs twice as many channels\n            fac = 1 if self.residual else 2\n\n            # we flip the index that is given as argument to index consistently in up and\n            # downstream conv factories\n            op, return_op_res = self.conv_op_factory(in_channels=fac*current_in_channels,\n                                                     out_channels=out_channels,\n                                                     part='up', index=self.depth - i - 1)\n            conv_up_ops.append(op)\n            if return_op_res:\n                self.n_channels_per_output.append(out_channels)\n                self._store_conv_up.append(True)\n            else:\n                self._store_conv_up.append(False)\n\n            # decrease the number of input_channels\n            current_in_channels = out_channels\n\n        # store as proper torch ModuleLis\n        self._conv_up_ops = nn.ModuleList(conv_up_ops)\n\n        # the last block needs to be stored in any case\n        if not self._store_conv_up[-1]:\n            self._store_conv_up[-1] = True\n            self.n_channels_per_output.append(out_channels)\n\n    def 
_make_upsample_kwargs(self, upsample_mode):\n        \"\"\"To avoid some waring from pytorch, and some missing implementations\n        for the arguments need to be handle carefully in this helper functions\n\n        Args:\n            upsample_mode (str): users choice for upsampling  interpolation style.\n        \"\"\"\n        if upsample_mode is None:\n            if self.dim == 2:\n                upsample_mode = 'bilinear'\n            elif self.dim == 3:\n                # upsample_mode = 'nearest'\n                upsample_mode = 'trilinear'\n\n        upsample_kwargs = dict(scale_factor=2, mode=upsample_mode)\n        if upsample_mode in ('bilinear', 'trilinear'):\n            upsample_kwargs['align_corners'] = False\n        return upsample_kwargs\n\n    def _forward_sanity_check(self, input):\n        if isinstance(input, tuple):\n            raise RuntimeError(\"tuples of tensors are not supported\")\n        shape = input.shape\n\n        if shape[1] != self.in_channels:\n            raise RuntimeError(\"wrong number of channels: expected %d, got %d\"%\n                (self.in_channels, input.size(1)))\n\n        if input.dim() != self.dim + 2:\n            raise RuntimeError(\"wrong number of dim: expected %d, got %d\"%\n                (self.dim+2, input.dim()))\n        self._check_scaling(input)\n\n    # override if model has different scaling\n    def _check_scaling(self, input):\n        shape = input.shape\n        mx = max_allowed_ds_steps(shape=shape[2:2+self.dim], factor=2)\n        if mx < self.depth:\n            raise RuntimeError(\"cannot downsample %d times, with shape %s\"%\n                (self.depth, str(input.size())) )\n\n    def forward(self, input):\n\n        # check if input is suitable\n        self._forward_sanity_check(input=input)\n\n        # collect all desired outputs\n        side_out = []\n\n        # remember all conv-block results of the downward part\n        # of the UNet\n        down_res = []\n\n        
#################################\n        # downwards part\n        #################################\n        out = input\n        for d in range(self.depth):\n\n            out = self._conv_down_ops[d](out)\n            #out = self.dropout\n\n            down_res.append(out)\n\n            if self._store_conv_down[d]:\n                side_out.append(out)\n\n            out = self._downsample_ops[d](out)\n\n        #################################\n        # bottom part\n        #################################\n        out = self._conv_bottom_op(out)\n        if self._store_conv_bottom:\n            side_out.append(out)\n\n        #################################\n        # upward part\n        #################################\n        down_res = list(reversed(down_res)) # <- eases indexing\n        for d in range(self.depth):\n\n            # upsample\n            out = self._upsample_ops[d](out)\n\n            # the result of the downward part\n            a = down_res[d]\n\n            # add or concat?\n            if self.residual:\n                out = a + out\n            else:\n                out = torch.cat([a, out], 1)\n\n            # the convolutional block\n            out = self._conv_up_ops[d](out)\n\n            if self._store_conv_up[d]:\n                side_out.append(out)\n\n        # if  len(side_out) == 1 we actually have no side output\n        # just the main output\n        if len(side_out) == 1:\n            return side_out[0]\n        else:\n            return tuple(side_out)\n\n    def downsample_op_factory(self, index):\n        C = nn.MaxPool2d if self.dim == 2 else nn.MaxPool3d\n        return C(kernel_size=2, stride=2)\n\n    def upsample_op_factory(self, index):\\\n        return InfernoUpsample(**self._upsample_kwargs)\n        #return nn.Upsample(**self._upsample_kwargs)\n\n    def conv_op_factory(self, in_channels, out_channels, part, index):\n        raise NotImplementedError(\"conv_op_factory need to be implemented by 
deriving class\")\n\n    def _dropout(self, x):\n        if self.use_dropout:\n            return self._channel_dropout_op(x)\n        else:\n            return x\n\n\n# TODO implement function to load a pretrained unet\nclass UNet(UNetBase):\n    \"\"\"\n    Default 2d / 3d U-Net implementation following:\n    https://arxiv.org/abs/1505.04597\n    \"\"\"\n    def __init__(self, in_channels, out_channels, dim,\n                 depth=4, initial_features=64, gain=2,\n                 final_activation=None, p_dropout=None):\n        # convolutional types for inner convolutions and output convolutions\n        self.default_conv = ConvELU2D if dim == 2 else ConvELU3D\n        last_conv = Conv2D if dim == 2 else Conv3D\n\n        # init the base class\n        super(UNet, self).__init__(in_channels=initial_features, dim=dim,\n                                   depth=depth, gain=gain, p_dropout=p_dropout)\n        # initial conv layer to go from the number of input channels, which are defined by the data\n        # (usually 1 or 3) to the initial number of feature maps\n        self._initial_conv = self.default_conv(in_channels, initial_features, 3)\n\n        # get the final output and activation activation\n        if isinstance(final_activation, str):\n            activation = getattr(nn, final_activation)()\n        elif isinstance(final_activation, nn.Module):\n            activation = final_activation\n        elif final_activation is None:\n            activation = None\n        else:\n            raise NotImplementedError(\"Activation of type %s is not supported\" % type(final_activation))\n\n        # override the unet base attributes for out_channels\n        self.out_channels = int(out_channels)\n        if activation is None:\n            self._output = last_conv(initial_features, self.out_channels, 1)\n        else:\n            self._output = nn.Sequential(last_conv(initial_features, self.out_channels, 1),\n                                         
activation)\n\n    def forward(self, input):\n        # TODO implement 2d from 3d input (see neurofire)\n        x = self._initial_conv(input)\n        x = super(UNet, self).forward(x)\n        return self._output(x)\n\n    def conv_op_factory(self, in_channels, out_channels, part, index):\n\n        # is this the first convolutional block?\n        first = (part == 'down' and index == 0)\n\n        # if this is the first conv block, we just need\n        # a single convolution, because we have the `_initial_conv` already\n        if first:\n            conv = self.default_conv(in_channels, out_channels, 3)\n        else:\n            conv = nn.Sequential(self.default_conv(in_channels, out_channels, 3),\n                                 self.default_conv(out_channels, out_channels, 3))\n        return conv, False\n"
  },
  {
    "path": "inferno/extensions/optimizers/__init__.py",
    "content": "from .adam import Adam\nfrom .annealed_adam import AnnealedAdam\nfrom .ranger import Ranger, RangerQH, RangerVA\n"
  },
  {
    "path": "inferno/extensions/optimizers/adam.py",
    "content": "import math\nfrom torch.optim import Optimizer\n\n\nclass Adam(Optimizer):\n    \"\"\"Implements Adam algorithm with the option of adding a L1 penalty.\n\n    It has been proposed in `Adam: A Method for Stochastic Optimization`_.\n\n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups\n        lr (float, optional): learning rate (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its square (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability (default: 1e-8)\n        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)\n\n    .. _Adam\\: A Method for Stochastic Optimization:\n        https://arxiv.org/abs/1412.6980\n    \"\"\"\n\n    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,\n                 lambda_l1=0, weight_decay=0, **kwargs):\n        defaults = dict(lr=lr, betas=betas, eps=eps,\n                        lambda_l1=lambda_l1, weight_decay=weight_decay,\n                        **kwargs)\n        super(Adam, self).__init__(params, defaults)\n\n    def step(self, closure=None):\n        \"\"\"Performs a single optimization step.\n\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        for group in self.param_groups:\n            for p in group['params']:\n                if p.grad is None:\n                    continue\n                grad = p.grad.data\n                state = self.state[p]\n\n                # State initialization\n                if len(state) == 0:\n                    state['step'] = 0\n                    # Exponential moving average of gradient values\n   
                 state['exp_avg'] = grad.new().resize_as_(grad).zero_()\n                    # Exponential moving average of squared gradient values\n                    state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_()\n\n                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']\n                beta1, beta2 = group['betas']\n\n                state['step'] += 1\n\n                if group['lambda_l1'] != 0:\n                    grad.add_(group['lambda_l1'], p.data.sign())\n                if group['weight_decay'] != 0:\n                    grad.add_(group['weight_decay'], p.data)\n                \n                # Decay the first and second moment running average coefficient\n                exp_avg.mul_(beta1).add_(1 - beta1, grad)\n                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)\n\n                denom = exp_avg_sq.sqrt().add_(group['eps'])\n\n                bias_correction1 = 1 - beta1 ** state['step']\n                bias_correction2 = 1 - beta2 ** state['step']\n                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1\n\n                p.data.addcdiv_(-step_size, exp_avg, denom)\n\n        return loss\n"
  },
  {
    "path": "inferno/extensions/optimizers/annealed_adam.py",
    "content": "from .adam import Adam\n\n\nclass AnnealedAdam(Adam):\n    \"\"\"Implements Adam algorithm with learning rate annealing and optional L1 penalty.\n\n    It has been proposed in `Adam: A Method for Stochastic Optimization`_.\n\n    Arguments:\n        params (iterable): iterable of parameters to optimize or dicts defining\n            parameter groups\n        lr (float, optional): learning rate (default: 1e-3)\n        betas (Tuple[float, float], optional): coefficients used for computing\n            running averages of gradient and its square (default: (0.9, 0.999))\n        eps (float, optional): term added to the denominator to improve\n            numerical stability (default: 1e-8)\n        lambda_l1 (float, optional): L1 penalty (default: 0)\n        weight_decay (float, optional): L2 penalty (weight decay) (default: 0)\n        lr_decay(float, optional): decay learning rate by this factor after every step\n            (default: 1.)\n\n    .. _Adam\\: A Method for Stochastic Optimization:\n        https://arxiv.org/abs/1412.6980\n    \"\"\"\n\n    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,\n                 lambda_l1=0, weight_decay=0, lr_decay=1.):\n        defaults = dict(lr=lr, betas=betas, eps=eps,\n                        lambda_l1=lambda_l1, weight_decay=weight_decay,\n                        lr_decay=lr_decay)\n        super(AnnealedAdam, self).__init__(params, **defaults)\n\n    def step(self, closure=None):\n        \"\"\"Performs a single optimization step.\n\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        # Do an optimization step\n        super(AnnealedAdam, self).step(closure=closure)\n        # Update learning rate\n        for group in self.param_groups:\n            group['lr'] *= group['lr_decay']\n"
  },
  {
    "path": "inferno/extensions/optimizers/ranger.py",
    "content": "# easy support for additional ranger optimizers from\n# https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer\ntry:\n    from ranger import Ranger, RangerVA, RangerQH\nexcept ImportError:\n    Ranger = None\n    RangerVA = None\n    RangerQH = None\n"
  },
  {
    "path": "inferno/inferno.py",
    "content": "# -*- coding: utf-8 -*-\n\n\"\"\"Main module.\"\"\"\n"
  },
  {
    "path": "inferno/io/__init__.py",
    "content": "from . import box\nfrom . import core\nfrom . import transform\nfrom . import volumetric\n"
  },
  {
    "path": "inferno/io/box/__init__.py",
    "content": "\"\"\"Things that work out of the box. ;)\"\"\"\n\nfrom .camvid import CamVid, get_camvid_loaders\nfrom .cityscapes import Cityscapes, get_cityscapes_loaders\nfrom .cifar import get_cifar10_loaders, get_cifar100_loaders\n\n\n__all__ = [\n    'CamVid','get_camvid_loaders', 'Cityscapes', 'get_cityscapes_loaders',\n    'get_cifar10_loaders','get_cifar100_loaders'\n]\n"
  },
  {
    "path": "inferno/io/box/binary_blobs.py",
    "content": "import torch.utils.data as data\nimport skimage.data\nimport numpy\nfrom operator import mul\nfrom functools import reduce\n\nclass BinaryBlobs(data.Dataset):\n\n\n    def __init__(self, size=20, length=512, blob_size_fraction=0.1,\n                 n_dim=2, volume_fraction=0.5,split='train', \n                 uniform_noise_range=(-1.2, 1.2),\n                 gaussian_noise_sigma=1.2,\n                 noise_scale_factor=8,\n                 image_transform=None, \n                 label_transform=None, \n                 joint_transform=None):\n        # how many images are in the dataset\n        self.size = size\n\n        # blob related members\n        self.length             = length\n        self.blob_size_fraction = blob_size_fraction\n        self.n_dim              = n_dim\n        self.volume_fraction    = volume_fraction\n\n        # which split {'train', 'test', 'validate'}\n        self.split              = split\n\n        # noise related members\n        self.uniform_noise_range = uniform_noise_range\n        self.gaussian_noise_sigma = float(gaussian_noise_sigma)\n        self.noise_scale_factor = noise_scale_factor\n\n        # transforms\n        self.image_transform = image_transform\n        self.label_transform = label_transform\n        self.joint_transform = joint_transform\n\n        # internal\n        split_to_seed = dict(train=0, test=1, validate=2)\n        self.master_seed  = split_to_seed[self.split]*self.size\n\n    def __getitem__(self, index):\n\n        # generate the labels\n        label = skimage.data.binary_blobs(\n            length=self.length, \n            blob_size_fraction=self.blob_size_fraction, \n            n_dim=self.n_dim, \n            volume_fraction=self.volume_fraction,\n            seed=self.master_seed + index)\n\n        # make the raw image [-1,1]\n        image  = label.astype('float32')*2\n        image -= 1\n\n\n        # add uniform noise \n        low, high = 
self.uniform_noise_range\n        uniform_noise   = numpy.random.uniform(low=low, high=high, \n                                               size=image.size)\n        image += uniform_noise.reshape(image.shape)\n\n        # add gaussian noise\n        gaussian_noise   = numpy.random.normal(scale=self.gaussian_noise_sigma, \n                                              size=image.size)\n        image += gaussian_noise.reshape(image.shape)\n\n\n        # generate noise at lower scales\n        small_shape = [s//self.noise_scale_factor for s in label.shape]\n        small_size = reduce(mul, small_shape, 1)\n        small_noise_img   = numpy.random.uniform(low=low, high=high, \n                                               size=small_size)\n        small_noise_img   = small_noise_img.reshape(small_shape)\n\n        gaussian_noise   = numpy.random.normal(scale=self.gaussian_noise_sigma, \n                                              size=small_size)\n        small_noise_img += gaussian_noise.reshape(small_shape)\n\n        noise_img = skimage.transform.resize(image = small_noise_img, \n            output_shape=image.shape,  mode='reflect')\n\n\n        image += noise_img\n\n        image -= image.mean()\n        image /= image.std()\n        \n        label = label.astype('long')\n        try:\n            # Apply transforms\n            if self.image_transform is not None:\n                image = self.image_transform(image)\n            if self.label_transform is not None:\n                label = self.label_transform(label)\n            if self.joint_transform is not None:\n                image, label = self.joint_transform(image, label)\n        except Exception:\n            print(\"[!] 
An Exception occurred while applying the transforms at \"\n                  \"index {} of split '{}'.\".format(index, self.split))\n            raise\n\n        image = image[None,...]\n        return image, label\n\n    def __len__(self):\n        return self.size\n\n\ndef get_binary_blob_loaders(train_batch_size=1, test_batch_size=1,\n                            num_workers=1,\n                            train_image_transform=None,\n                            train_label_transform=None,\n                            train_joint_transform=None,\n                            validate_image_transform=None,\n                            validate_label_transform=None,\n                            validate_joint_transform=None,\n                            test_image_transform=None,\n                            test_label_transform=None,\n                            test_joint_transform=None,\n                            **kwargs):\n    \n    trainset = BinaryBlobs(split='train',   image_transform=train_image_transform, \n        label_transform=train_label_transform, joint_transform=train_joint_transform, **kwargs)\n    testset  = BinaryBlobs(split='test',    image_transform=test_image_transform,\n        label_transform=test_label_transform, joint_transform=test_joint_transform, **kwargs)\n    validset = BinaryBlobs(split='validate',image_transform=validate_image_transform, \n        label_transform=validate_label_transform, joint_transform=validate_joint_transform, **kwargs)\n\n\n    trainloader = data.DataLoader(trainset, batch_size=train_batch_size,\n                                            num_workers=num_workers)\n\n    testloader = data.DataLoader(testset, batch_size=test_batch_size,\n                                            num_workers=num_workers)\n\n    validloader = data.DataLoader(validset, batch_size=test_batch_size,\n                                            num_workers=num_workers)\n\n    return trainloader, testloader, validloader\n\nif 
__name__ == \"__main__\":\n    ds = BinaryBlobs()\n    ds[0]\n"
  },
  {
    "path": "inferno/io/box/camvid.py",
    "content": "# Adapted from felixgwu's PR here:\n# https://github.com/felixgwu/vision/blob/cf491d301f62ae9c77ff7250fb7def5cd55ec963/torchvision/datasets/camvid.py\n\nimport os\nimport torch\nimport torch.utils.data as data\nimport numpy as np\nfrom PIL import Image\nfrom torchvision.datasets.folder import default_loader\nfrom ...utils.exceptions import assert_\nfrom ..transform.base import Compose\nfrom ..transform.generic import Normalize, NormalizeRange, Cast, AsTorchBatch, Label2OneHot\nfrom ..transform.image import \\\n    RandomSizedCrop, RandomGammaCorrection, RandomFlip, Scale, PILImage2NumPyArray\n\ntry:\n    from torchvision.datasets.folder import is_image_file\nexcept ImportError:\n    from torchvision.datasets.folder import IMG_EXTENSIONS, has_file_allowed_extension\n\n\n    def is_image_file(filename):\n        return has_file_allowed_extension(filename, IMG_EXTENSIONS)\n\nCAMVID_CLASSES = ['Sky',\n                  'Building',\n                  'Column-Pole',\n                  'Road',\n                  'Sidewalk',\n                  'Tree',\n                  'Sign-Symbol',\n                  'Fence',\n                  'Car',\n                  'Pedestrain',\n                  'Bicyclist',\n                  'Void']\n\n# weights when using median frequency balancing used in SegNet paper\n# https://arxiv.org/pdf/1511.00561.pdf\n# The numbers were generated by:\n# https://github.com/yandex/segnet-torch/blob/master/datasets/camvid-gen.lua\nCAMVID_CLASS_WEIGHTS = [0.58872014284134,\n                        0.51052379608154,\n                        2.6966278553009,\n                        0.45021694898605,\n                        1.1785038709641,\n                        0.77028578519821,\n                        2.4782588481903,\n                        2.5273461341858,\n                        1.0122526884079,\n                        3.2375309467316,\n                        4.1312313079834,\n                        0]\n# mean and 
std\nCAMVID_MEAN = [0.41189489566336, 0.4251328133025, 0.4326707089857]\nCAMVID_STD = [0.27413549931506, 0.28506257482912, 0.28284674400252]\n\nCAMVID_CLASS_COLORS = [\n    (128, 128, 128),\n    (128, 0, 0),\n    (192, 192, 128),\n    (128, 64, 128),\n    (0, 0, 192),\n    (128, 128, 0),\n    (192, 128, 128),\n    (64, 64, 128),\n    (64, 0, 128),\n    (64, 64, 0),\n    (0, 128, 192),\n    (0, 0, 0),\n]\n\n\ndef make_dataset(dir):\n    images = []\n    for root, _, fnames in sorted(os.walk(dir)):\n        for fname in fnames:\n            if is_image_file(fname):\n                path = os.path.join(root, fname)\n                item = path\n                images.append(item)\n    return images\n\n\ndef label_to_long_tensor(pic):\n    label = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))\n    label = label.view(pic.size[1], pic.size[0], 1)\n    label = label.transpose(0, 1).transpose(0, 2).squeeze().contiguous().long()\n    return label\n\n\ndef label_to_pil_image(label):\n    label = label.unsqueeze(0)\n    colored_label = torch.zeros(3, label.size(1), label.size(2)).byte()\n    for i, color in enumerate(CAMVID_CLASS_COLORS):\n        mask = label.eq(i)\n        for j in range(3):\n            colored_label[j].masked_fill_(mask, color[j])\n    npimg = colored_label.numpy()\n    npimg = np.transpose(npimg, (1, 2, 0))\n    mode = None\n    if npimg.shape[2] == 1:\n        npimg = npimg[:, :, 0]\n        mode = \"L\"\n\n    return Image.fromarray(npimg, mode=mode)\n\n\nclass CamVid(data.Dataset):\n    SPLIT_NAME_MAPPING = {'train': 'train',\n                          'training': 'train',\n                          'validate': 'val',\n                          'val': 'val',\n                          'validation': 'val',\n                          'test': 'test',\n                          'testing': 'test'}\n    # Dataset statistics\n    CLASS_WEIGHTS = CAMVID_CLASS_WEIGHTS\n    CLASSES = CAMVID_CLASSES\n    MEAN = CAMVID_MEAN\n    STD = 
CAMVID_STD\n\n    def __init__(self, root, split='train',\n                 image_transform=None, label_transform=None, joint_transform=None,\n                 download=False, loader=default_loader):\n        # Validate\n        assert_(split in self.SPLIT_NAME_MAPPING.keys(),\n                \"`split` must be one of {}\".format(set(self.SPLIT_NAME_MAPPING.keys())),\n                KeyError)\n        # Root directory and split\n        self.root_directory = root\n        self.split = self.SPLIT_NAME_MAPPING.get(split)\n        # Utils\n        self.image_loader = loader\n        # Transforms\n        self.image_transform = image_transform\n        self.label_transform = label_transform\n        self.joint_transform = joint_transform\n        # For when we implement download:\n        if download:\n            self.download()\n        # Make dataset with paths to the image\n        self.image_paths = make_dataset(os.path.join(self.root_directory, self.split))\n\n    def __getitem__(self, index):\n        path = self.image_paths[index]\n        image = self.image_loader(path)\n        label = Image.open(path.replace(self.split, self.split + 'annot'))\n        # Apply transforms\n        if self.image_transform is not None:\n            image = self.image_transform(image)\n        if self.label_transform is not None:\n            label = self.label_transform(label)\n        if self.joint_transform is not None:\n            image, label = self.joint_transform(image, label)\n        return image, label\n\n    def __len__(self):\n        return len(self.image_paths)\n\n    def download(self):\n        # TODO: please download the dataset from\n        # https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid\n        raise NotImplementedError\n\n\n# noinspection PyTypeChecker\ndef get_camvid_loaders(root_directory, image_shape=(360, 480), labels_as_onehot=False,\n                       train_batch_size=1, validate_batch_size=1, test_batch_size=1,\n            
           num_workers=2):\n    # Make transforms\n    image_transforms = Compose(PILImage2NumPyArray(),\n                               NormalizeRange(),\n                               RandomGammaCorrection(),\n                               Normalize(mean=CAMVID_MEAN, std=CAMVID_STD))\n    label_transforms = PILImage2NumPyArray()\n    joint_transforms = Compose(RandomSizedCrop(ratio_between=(0.6, 1.0),\n                                               preserve_aspect_ratio=True),\n                               # Scale raw image back to the original shape\n                               Scale(output_image_shape=image_shape,\n                                     interpolation_order=3, apply_to=[0]),\n                               # Scale segmentation back to the original shape\n                               # (without interpolation)\n                               Scale(output_image_shape=image_shape,\n                                     interpolation_order=0, apply_to=[1]),\n                               RandomFlip(allow_ud_flips=False),\n                               # Cast raw image to float\n                               Cast('float', apply_to=[0]))\n    if labels_as_onehot:\n        # See cityscapes loader to understand why this is here.\n        joint_transforms\\\n            .add(Label2OneHot(num_classes=len(CAMVID_CLASS_WEIGHTS), dtype='bool',\n                              apply_to=[1]))\\\n            .add(Cast('float', apply_to=[1]))\n    else:\n        # Cast label image to long\n        joint_transforms.add(Cast('long', apply_to=[1]))\n    # Batchify\n    joint_transforms.add(AsTorchBatch(2, add_channel_axis_if_necessary=False))\n    # Build datasets\n    train_dataset = CamVid(root_directory, split='train',\n                           image_transform=image_transforms,\n                           label_transform=label_transforms,\n                           joint_transform=joint_transforms)\n    validate_dataset = CamVid(root_directory, 
split='validate',\n                              image_transform=image_transforms,\n                              label_transform=label_transforms,\n                              joint_transform=joint_transforms)\n    test_dataset = CamVid(root_directory, split='test',\n                          image_transform=image_transforms,\n                          label_transform=label_transforms,\n                          joint_transform=joint_transforms)\n    # Build loaders\n    train_loader = data.DataLoader(train_dataset, batch_size=train_batch_size,\n                                   shuffle=True, num_workers=num_workers, pin_memory=True)\n    validate_loader = data.DataLoader(validate_dataset, batch_size=validate_batch_size,\n                                      shuffle=True, num_workers=num_workers, pin_memory=True)\n    test_loader = data.DataLoader(test_dataset, batch_size=test_batch_size,\n                                  shuffle=True, num_workers=num_workers, pin_memory=True)\n    return train_loader, validate_loader, test_loader\n"
  },
  {
    "path": "inferno/io/box/cifar.py",
    "content": "import os\nimport torch\nimport torchvision\nimport torchvision.transforms as transforms\nfrom torch.utils.data.sampler import SubsetRandomSampler\n\n\ndef get_cifar10_loaders(root_directory, train_batch_size=128, test_batch_size=256,\n                        download=False, augment=False, validation_dataset_size=None):\n    # Data preparation for CIFAR10.\n    if augment:\n        transform_train = transforms.Compose([\n            transforms.RandomCrop(32, padding=4),\n            transforms.RandomHorizontalFlip(),\n            transforms.ToTensor(),\n            transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.2435, 0.2616)),\n        ])\n        transform_test = transforms.Compose([\n            transforms.ToTensor(),\n            transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.2435, 0.2616)),\n        ])\n    else:\n        transform_train = transforms.Compose([\n            transforms.ToTensor(),\n            transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.2435, 0.2616)),\n        ])\n        transform_test = transforms.Compose([\n            transforms.ToTensor(),\n            transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.2435, 0.2616)),\n        ])\n\n    trainset = torchvision.datasets.CIFAR10(root=os.path.join(root_directory, 'data'),\n                                            train=True, download=download,\n                                            transform=transform_train)\n    if validation_dataset_size:\n        indices = torch.randperm(len(trainset))\n        train_indices = indices[:(len(indices) - validation_dataset_size)]\n        valid_indices = indices[(len(indices) - validation_dataset_size):]\n        validset = torchvision.datasets.CIFAR10(root=os.path.join(root_directory, 'data'),\n                                                train=True, download=download,\n                                                transform=transform_test)\n        
trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size,\n                                                  pin_memory=True, num_workers=1,\n                                                  sampler=SubsetRandomSampler(train_indices))\n        validloader = torch.utils.data.DataLoader(validset, batch_size=test_batch_size,\n                                                  pin_memory=True, num_workers=1,\n                                                  sampler=SubsetRandomSampler(valid_indices))\n    else:\n        trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size,\n                                                  shuffle=True, pin_memory=True,  num_workers=1)\n\n    testset = torchvision.datasets.CIFAR10(root=os.path.join(root_directory, 'data'),\n                                           train=False, download=download,\n                                           transform=transform_test)\n    testloader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size,\n                                             shuffle=False, pin_memory=True,  num_workers=1)\n\n    if validation_dataset_size:\n        return trainloader, validloader, testloader\n    else:\n        return trainloader, testloader\n\n\ndef get_cifar100_loaders(root_directory, train_batch_size=128, test_batch_size=100,\n                         download=False, augment=False, validation_dataset_size=None):\n    # Data preparation for CIFAR100. 
Adapted from\n    # https://github.com/kuangliu/pytorch-cifar/blob/master/main.py\n    if augment:\n        transform_train = transforms.Compose([\n            transforms.RandomCrop(32, padding=4),\n            transforms.RandomHorizontalFlip(),\n            transforms.ToTensor(),\n            transforms.Normalize(mean=(0.5071, 0.4865, 0.4409), std=(0.2673, 0.2564, 0.2762)),\n        ])\n        transform_test = transforms.Compose([\n            transforms.ToTensor(),\n            transforms.Normalize(mean=(0.5071, 0.4865, 0.4409), std=(0.2673, 0.2564, 0.2762)),\n        ])\n    else:\n        transform_train = transforms.Compose([\n            transforms.ToTensor(),\n            transforms.Normalize(mean=(0.5071, 0.4865, 0.4409), std=(0.2673, 0.2564, 0.2762)),\n        ])\n        transform_test = transforms.Compose([\n            transforms.ToTensor(),\n            transforms.Normalize(mean=(0.5071, 0.4865, 0.4409), std=(0.2673, 0.2564, 0.2762)),\n        ])\n\n    trainset = torchvision.datasets.CIFAR100(root=os.path.join(root_directory, 'data'),\n                                             train=True, download=download,\n                                             transform=transform_train)\n    if validation_dataset_size:\n        indices = torch.randperm(len(trainset))\n        train_indices = indices[:(len(indices) - validation_dataset_size)]\n        valid_indices = indices[(len(indices) - validation_dataset_size):]\n        validset = torchvision.datasets.CIFAR100(root=os.path.join(root_directory, 'data'),\n                                                 train=True, download=download,\n                                                 transform=transform_test)\n        trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size,\n                                                  pin_memory=True,  num_workers=1,\n                                                  sampler=SubsetRandomSampler(train_indices))\n        validloader = 
torch.utils.data.DataLoader(validset, batch_size=test_batch_size,\n                                                  pin_memory=True, num_workers=1,\n                                                  sampler=SubsetRandomSampler(valid_indices))\n    else:\n        trainloader = torch.utils.data.DataLoader(trainset, batch_size=train_batch_size,\n                                                  shuffle=True, pin_memory=True, num_workers=1)\n\n    testset = torchvision.datasets.CIFAR100(root=os.path.join(root_directory, 'data'),\n                                            train=False, download=download,\n                                            transform=transform_test)\n    testloader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size,\n                                             shuffle=False, pin_memory=True, num_workers=1)\n\n    if validation_dataset_size:\n        return trainloader, validloader, testloader\n    else:\n        return trainloader, testloader\n"
  },
  {
    "path": "inferno/io/box/cityscapes.py",
    "content": "import zipfile\nimport io\nimport os\nimport torch.utils.data as data\nfrom PIL import Image\nfrom os.path import join, relpath, abspath\nfrom ...utils.exceptions import assert_\nfrom ..transform.base import Compose\nfrom ..transform.generic import \\\n    Normalize, NormalizeRange, Cast, AsTorchBatch, Project, Label2OneHot\nfrom ..transform.image import \\\n    RandomSizedCrop, RandomGammaCorrection, RandomFlip, Scale, PILImage2NumPyArray\nfrom ..core import Concatenate\n\n\nCITYSCAPES_CLASSES = {\n    0: 'unlabeled',\n    1: 'ego vehicle',\n    2: 'rectification border',\n    3: 'out of roi',\n    4: 'static',\n    5: 'dynamic',\n    6: 'ground',\n    7: 'road',\n    8: 'sidewalk',\n    9: 'parking',\n    10: 'rail track',\n    11: 'building',\n    12: 'wall',\n    13: 'fence',\n    14: 'guard rail',\n    15: 'bridge',\n    16: 'tunnel',\n    17: 'pole',\n    18: 'polegroup',\n    19: 'traffic light',\n    20: 'traffic sign',\n    21: 'vegetation',\n    22: 'terrain',\n    23: 'sky',\n    24: 'person',\n    25: 'rider',\n    26: 'car',\n    27: 'truck',\n    28: 'bus',\n    29: 'caravan',\n    30: 'trailer',\n    31: 'train',\n    32: 'motorcycle',\n    33: 'bicycle',\n    -1: 'license plate'\n}\n\nIGNORE_CLASS_LABEL = 19\n\n# Class labels to use for training, found here:\n# https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py#L61\nCITYSCAPES_CLASSES_TO_LABELS = {\n    0: IGNORE_CLASS_LABEL,\n    1: IGNORE_CLASS_LABEL,\n    2: IGNORE_CLASS_LABEL,\n    3: IGNORE_CLASS_LABEL,\n    4: IGNORE_CLASS_LABEL,\n    5: IGNORE_CLASS_LABEL,\n    6: IGNORE_CLASS_LABEL,\n    7: 0,\n    8: 1,\n    9: IGNORE_CLASS_LABEL,\n    10: IGNORE_CLASS_LABEL,\n    11: 2,\n    12: 3,\n    13: 4,\n    14: IGNORE_CLASS_LABEL,\n    15: IGNORE_CLASS_LABEL,\n    16: IGNORE_CLASS_LABEL,\n    17: 5,\n    18: IGNORE_CLASS_LABEL,\n    19: 6,\n    20: 7,\n    21: 8,\n    22: 9,\n    23: 10,\n    24: 11,\n    25: 12,\n    26: 13,\n    27: 14,\n  
  28: 15,\n    29: IGNORE_CLASS_LABEL,\n    30: IGNORE_CLASS_LABEL,\n    31: 16,\n    32: 17,\n    33: 18,\n    -1: IGNORE_CLASS_LABEL\n}\n\n# Map classes to official cityscapes colors\nCITYSCAPES_CLASS_COLOR_MAPPING = {\n    0: (0, 0, 0),\n    1: (0, 0, 0),\n    2: (0, 0, 0),\n    3: (0, 0, 0),\n    4: (0, 0, 0),\n    5: (111, 74, 0),\n    6: (81, 0, 81),\n    7: (128, 64, 128),\n    8: (244, 35, 232),\n    9: (250, 170, 160),\n    10: (230, 150, 140),\n    11: (70, 70, 70),\n    12: (102, 102, 156),\n    13: (190, 153, 153),\n    14: (180, 165, 180),\n    15: (150, 100, 100),\n    16: (150, 120, 90),\n    17: (153, 153, 153),\n    18: (153, 153, 153),\n    19: (250, 170, 30),\n    20: (220, 220, 0),\n    21: (107, 142, 35),\n    22: (152, 251, 152),\n    23: (70, 130, 180),\n    24: (220, 20, 60),\n    25: (255, 0, 0),\n    26: (0, 0, 142),\n    27: (0, 0, 70),\n    28: (0, 60, 100),\n    29: (0, 0, 90),\n    30: (0, 0, 110),\n    31: (0, 80, 100),\n    32: (0, 0, 230),\n    33: (119, 11, 32),\n    -1: (0, 0, 142),\n}\n\n# Weights corresponding to the outputs\nCITYSCAPES_LABEL_WEIGHTS = {\n    0: 1.,\n    1: 1.,\n    2: 1.,\n    3: 1.,\n    4: 1.,\n    5: 1.,\n    6: 1.,\n    7: 1.,\n    8: 1.,\n    9: 1.,\n    10: 1.,\n    11: 1.,\n    12: 1.,\n    13: 1.,\n    14: 1.,\n    15: 1.,\n    16: 1.,\n    17: 1.,\n    18: 1.,\n    19: 0.\n}\n\n# 0:void 1:flat  2:construction  3:object  4:nature  5:sky  6:human  7:vehicle\nCITYSCAPES_CATEGORIES = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2,\n                         3, 3, 3, 3, 4, 4, 5, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7]\n\nCITYSCAPES_IGNORE_IN_EVAL = [True, True, True, True, True, True, True, False, False, True, True,\n                             False, False, False, True, True, True, False, True, False, False,\n                             False, False, False, False,\n                             False, False, False, False, True, True, False, False, False, True]\n\n# mean and std\nCITYSCAPES_MEAN = [0.28689554, 
0.32513303, 0.28389177]\nCITYSCAPES_STD = [0.18696375, 0.19017339, 0.18720214]\n\n\ndef get_matching_labelimage_file(f, groundtruth):\n    fs = f.split('/')\n    fs[0] = groundtruth\n    fs[-1] = str.replace(fs[-1], 'leftImg8bit', groundtruth + '_labelIds')\n    return '/'.join(fs)\n\n\ndef get_filelist(path):\n    if path.endswith('.zip'):\n        return zipfile.ZipFile(path, 'r').filelist\n    elif os.path.isdir(path):\n        return [relpath(join(root, filename), abspath(join(path, '..')))\n                for root, _, filenames in os.walk(path) for filename in filenames]\n    else:\n        raise NotImplementedError(\"Path must be a zip archive or a directory.\")\n\n\ndef make_dataset(path, split):\n    images = []\n    for f in get_filelist(path):\n        if isinstance(f, str):\n            fn = f\n            fns = f.split('/')\n        else:\n            fn = f.filename\n            fns = f.filename.split('/')\n        if fns[-1].endswith('.png') and fns[1] == split:\n            # use first folder name to identify train/val/test images\n            if split == 'train_extra':\n                groundtruth = 'gtCoarse'\n            else:\n                groundtruth = 'gtFine'\n\n            fl = get_matching_labelimage_file(fn, groundtruth)\n            images.append((f, fl))\n    return images\n\n\ndef extract_image(path, image_path):\n    if path.endswith('.zip'):\n        # read image directly from zipfile if path is a zip\n        return Image.open(io.BytesIO(zipfile.ZipFile(path, 'r').read(image_path)))\n    else:\n        return Image.open(join(abspath(join(path, '..')), image_path), 'r')\n\n\nclass Cityscapes(data.Dataset):\n    SPLIT_NAME_MAPPING = {'train': 'train',\n                          'training': 'train',\n                          'validate': 'val',\n                          'val': 'val',\n                          'validation': 'val',\n                          'test': 'test',\n                          'testing': 'test',\n              
            'training_extra': 'train_extra',\n                          'train_extra': 'train_extra'}\n\n    # Dataset statistics\n    CLASSES = CITYSCAPES_CLASSES\n    MEAN = CITYSCAPES_MEAN\n    STD = CITYSCAPES_STD\n\n    BLACKLIST = ['leftImg8bit/train_extra/troisdorf/troisdorf_000000_000073_leftImg8bit.png']\n\n    def __init__(self, root_folder, split='train', read_from_zip_archive=True,\n                 image_transform=None, label_transform=None, joint_transform=None):\n        \"\"\"\n        Parameters:\n        root_folder: folder that contains both leftImg8bit_trainvaltest.zip and\n               gtFine_trainvaltest.zip archives.\n        split: name of dataset split (i.e. 'train_extra', 'train', 'val' or 'test') \n        \"\"\"\n\n        assert_(split in self.SPLIT_NAME_MAPPING.keys(),\n                \"`split` must be one of {}\".format(set(self.SPLIT_NAME_MAPPING.keys())),\n                KeyError)\n        self.split = self.SPLIT_NAME_MAPPING.get(split)\n        self.read_from_zip_archive = read_from_zip_archive\n\n        # Get roots\n        self.image_root, self.label_root = [join(root_folder, groot)\n                                            for groot in self.get_image_and_label_roots()]\n\n        # Transforms\n        self.image_transform = image_transform\n        self.label_transform = label_transform\n        self.joint_transform = joint_transform\n        # Make list with paths to the images\n        self.image_paths = make_dataset(self.image_root, self.split)\n\n    def __getitem__(self, index):\n        pi, pl = self.image_paths[index]\n        if pi in self.BLACKLIST:\n            # Select the next image if the current image is bad\n            return self[index + 1]\n        image = extract_image(self.image_root, pi)\n        label = extract_image(self.label_root, pl)\n        try:\n            # Apply transforms\n            if self.image_transform is not None:\n                image = self.image_transform(image)\n            if 
self.label_transform is not None:\n                label = self.label_transform(label)\n            if self.joint_transform is not None:\n                image, label = self.joint_transform(image, label)\n        except Exception:\n            print(\"[!] An Exception occurred while applying the transforms at \"\n                  \"index {} of split '{}'.\".format(index, self.split))\n            raise\n        return image, label\n\n    def __len__(self):\n        return len(self.image_paths)\n\n    def download(self):\n        # TODO: please download the dataset from\n        # https://www.cityscapes-dataset.com/\n        raise NotImplementedError\n\n    def get_image_and_label_roots(self):\n        all_roots = {\n            'zipped':\n                {\n                    'train': ('leftImg8bit_trainvaltest.zip', 'gtFine_trainvaltest.zip'),\n                    'val': ('leftImg8bit_trainvaltest.zip', 'gtFine_trainvaltest.zip'),\n                    'train_extra': ('leftImg8bit_trainextra.zip', 'gtCoarse.zip')\n                },\n            'unzipped':\n                {\n                    'train': ('leftImg8bit', 'gtFine'),\n                    'val': ('leftImg8bit', 'gtFine'),\n                    'train_extra': ('leftImg8bit', 'gtCoarse')\n                }\n        }\n        image_and_label_roots = all_roots\\\n            .get('zipped' if self.read_from_zip_archive else 'unzipped').get(self.split)\n        return image_and_label_roots\n\n\ndef make_transforms(image_shape, labels_as_onehot):\n    # Make transforms\n    image_transforms = Compose(PILImage2NumPyArray(),\n                               NormalizeRange(),\n                               RandomGammaCorrection(),\n                               Normalize(mean=CITYSCAPES_MEAN, std=CITYSCAPES_STD))\n    label_transforms = Compose(PILImage2NumPyArray(),\n                               Project(projection=CITYSCAPES_CLASSES_TO_LABELS))\n    joint_transforms = 
Compose(RandomSizedCrop(ratio_between=(0.6, 1.0),\n                                               preserve_aspect_ratio=True),\n                               # Scale raw image back to the original shape\n                               Scale(output_image_shape=image_shape,\n                                     interpolation_order=3, apply_to=[0]),\n                               # Scale segmentation back to the original shape\n                               # (without interpolation)\n                               Scale(output_image_shape=image_shape,\n                                     interpolation_order=0, apply_to=[1]),\n                               RandomFlip(allow_ud_flips=False),\n                               # Cast raw image to float\n                               Cast('float', apply_to=[0]))\n    if labels_as_onehot:\n        # Applying Label2OneHot on the full label image makes it unnecessarily expensive,\n        # because we're throwing it away with RandomSizedCrop and Scale. 
Tests show that it's\n        # ~1 sec faster per image.\n        joint_transforms \\\n            .add(Label2OneHot(num_classes=len(CITYSCAPES_LABEL_WEIGHTS), dtype='bool',\n                              apply_to=[1])) \\\n            .add(Cast('float', apply_to=[1]))\n    else:\n        # Cast label image to long\n        joint_transforms.add(Cast('long', apply_to=[1]))\n    # Batchify\n    joint_transforms.add(AsTorchBatch(2, add_channel_axis_if_necessary=False))\n    # Return as kwargs\n    return {'image_transform': image_transforms,\n            'label_transform': label_transforms,\n            'joint_transform': joint_transforms}\n\n\ndef get_cityscapes_loaders(root_directory, image_shape=(1024, 2048), labels_as_onehot=False,\n                           include_coarse_dataset=False, read_from_zip_archive=True,\n                           train_batch_size=1, validate_batch_size=1, num_workers=2):\n    # Build datasets\n    train_dataset = Cityscapes(root_directory, split='train',\n                               read_from_zip_archive=read_from_zip_archive,\n                               **make_transforms(image_shape, labels_as_onehot))\n    if include_coarse_dataset:\n        # Build coarse dataset\n        coarse_dataset = Cityscapes(root_directory, split='train_extra',\n                                    read_from_zip_archive=read_from_zip_archive,\n                                    **make_transforms(image_shape, labels_as_onehot))\n        # ... 
and concatenate with train_dataset\n        train_dataset = Concatenate(coarse_dataset, train_dataset)\n    validate_dataset = Cityscapes(root_directory, split='validate',\n                                  read_from_zip_archive=read_from_zip_archive,\n                                  **make_transforms(image_shape, labels_as_onehot))\n\n    # Build loaders\n    train_loader = data.DataLoader(train_dataset, batch_size=train_batch_size,\n                                   shuffle=True, num_workers=num_workers, pin_memory=True)\n    validate_loader = data.DataLoader(validate_dataset, batch_size=validate_batch_size,\n                                      shuffle=True, num_workers=num_workers, pin_memory=True)\n    return train_loader, validate_loader\n"
  },
  {
    "path": "inferno/io/core/__init__.py",
    "content": "from .base import SyncableDataset\nfrom .zip import Zip, ZipReject\nfrom .concatenate import Concatenate\n"
  },
  {
    "path": "inferno/io/core/base.py",
    "content": "from torch.utils.data.dataset import Dataset\n\n\nclass SyncableDataset(Dataset):\n    def __init__(self, base_sequence=None):\n        self.base_sequence = base_sequence\n\n    def sync_with(self, dataset):\n        if hasattr(dataset, 'base_sequence'):\n            self.base_sequence = dataset.base_sequence\n        return self\n\n    def __len__(self):\n        if self.base_sequence is None:\n            raise RuntimeError(\"Class {} does not specify a base sequence. Either specify \"\n                               \"one by assigning to self.base_sequence or override the \"\n                               \"__len__ method.\".format(self.__class__.__name__))\n        else:\n            return len(self.base_sequence)\n\n\nclass IndexSpec(object):\n    \"\"\"\n    Class to wrap any extra index information a `Dataset` object might want to send back.\n    This could be useful in (say) inference, where we would wish to (asynchronously) know\n    more about the current input.\n    \"\"\"\n    def __init__(self, index=None, base_sequence_at_index=None):\n        self.index = index\n        self.base_sequence_at_index = base_sequence_at_index\n\n    def __int__(self):\n        return int(self.index)\n"
  },
  {
    "path": "inferno/io/core/concatenate.py",
    "content": "import numpy as np\nfrom torch.utils.data.dataset import Dataset\nfrom ...utils import python_utils as pyu\n\n\nclass Concatenate(Dataset):\n    \"\"\"\n    Concatenates multiple datasets to one. This class does not implement\n    synchronization primitives.\n    \"\"\"\n    def __init__(self, *datasets, transforms=None):\n        assert all([isinstance(dataset, Dataset) for dataset in datasets])\n        assert len(datasets) >= 1\n        assert transforms is None or callable(transforms)\n        self.datasets = datasets\n        self.transforms = transforms\n\n    def map_index(self, index):\n        # Get a list of lengths of all datasets. Say the answer is [4, 3, 3],\n        # and we're looking for index = 5.\n        len_list = list(map(len, self.datasets))\n        # Cumulate to a numpy array. The answer is [4, 7, 10]\n        cumulative_len_list = np.cumsum(len_list)\n        # When the index is subtracted, we get [-1, 2, 5]. We're looking for the (index\n        # of the) first cumulated len which is larger than the index (in this case,\n        # 7 (index 1)).\n        offset_cumulative_len_list = cumulative_len_list - index\n        dataset_index = np.argmax(offset_cumulative_len_list > 0)\n        # With the dataset index, we figure out the index in dataset\n        if dataset_index == 0:\n            # First dataset - index corresponds to index_in_dataset\n            index_in_dataset = index\n        else:\n            # Get cumulated length up to the current dataset\n            len_up_to_dataset = cumulative_len_list[dataset_index - 1]\n            # Compute index_in_dataset as that what's left\n            index_in_dataset = index - len_up_to_dataset\n        return dataset_index, index_in_dataset\n\n    def __getitem__(self, index):\n        assert index < len(self)\n        dataset_index, index_in_dataset = self.map_index(index)\n        fetched = self.datasets[dataset_index][index_in_dataset]\n        if self.transforms is 
None:\n            return fetched\n        elif callable(self.transforms):\n            return self.transforms(*pyu.to_iterable(fetched))\n        else:\n            raise NotImplementedError\n\n    def __len__(self):\n        return sum([len(dataset) for dataset in self.datasets])\n\n    def __repr__(self):\n        if len(self.datasets) < 3:\n            return \"Concatenate(\" + \\\n                   \", \".join([dataset.__repr__() for dataset in self.datasets[:-1]]) + \", \" + \\\n                   self.datasets[-1].__repr__() + \\\n                   \")\"\n        else:\n            return \"Concatenate({}xDatasets)\".format(len(self.datasets))\n"
  },
  {
    "path": "inferno/io/core/data_utils.py",
    "content": "\ndef implements_sync_primitives(dataset):\n    return hasattr(dataset, 'sync_with') and callable(getattr(dataset, 'sync_with'))\n\n\ndef defines_base_sequence(dataset):\n    return hasattr(dataset, 'base_sequence') and dataset.base_sequence is not None\n"
  },
  {
    "path": "inferno/io/core/zip.py",
    "content": "from torch.utils.data.dataset import Dataset\nimport torch.multiprocessing as mp\nimport numpy as np\nfrom . import data_utils as du\nfrom .base import SyncableDataset\nfrom ...utils.exceptions import assert_\nfrom ...utils import python_utils as pyu\nimport random\n\n\nclass Zip(SyncableDataset):\n    \"\"\"\n    Zip two or more datasets to one dataset. If the datasets implement synchronization primitives,\n    they are all synchronized with the first dataset.\n    \"\"\"\n\n    def __init__(self, *datasets, sync=False, transforms=None):\n        super(Zip, self).__init__()\n        assert_(len(datasets) >= 1, \"Expecting one or more datasets, got none.\", ValueError)\n        for dataset_index, dataset in enumerate(datasets):\n            assert_(isinstance(dataset, Dataset),\n                    \"Object at position {} of type {} is not a subclass of \"\n                    \"`torch.utils.data.dataset.Dataset`\"\n                    .format(dataset_index, type(dataset).__name__),\n                    TypeError)\n        assert_(transforms is None or callable(transforms),\n                \"Given `transforms` is not callable.\",\n                TypeError)\n        self.datasets = datasets\n        self.sync = sync\n        self.transforms = transforms\n        if self.sync:\n            self.sync_datasets()\n        # Inherit base sequence if sync'ing\n        if self.sync and all([du.defines_base_sequence(dataset) for dataset in self.datasets]):\n            self.base_sequence = list(zip(*[dataset.base_sequence for dataset in self.datasets]))\n        else:\n            self.base_sequence = None\n\n    def sync_datasets(self):\n        master_dataset = self.datasets[0]\n        for dataset in self.datasets[1:]:\n            if du.implements_sync_primitives(dataset):\n                dataset.sync_with(master_dataset)\n\n    def sync_with(self, dataset):\n        master_dataset = self.datasets[0]\n        if 
du.implements_sync_primitives(master_dataset):\n            master_dataset.sync_with(dataset)\n        # Sync all other datasets\n        self.sync_datasets()\n\n    def __getitem__(self, index):\n        assert_(index < len(self), exception_type=IndexError)\n        fetched = [dataset[index] for dataset in self.datasets]\n        if self.transforms is None:\n            return fetched\n        elif callable(self.transforms):\n            return self.transforms(*fetched)\n        else:\n            raise RuntimeError\n\n    def __len__(self):\n        if du.defines_base_sequence(self):\n            return super(Zip, self).__len__()\n        else:\n            return min([len(dataset) for dataset in self.datasets])\n\n    def __repr__(self):\n        if len(self.datasets) > 3:\n            return \"{}({}xDatasets)\".format(type(self).__name__, len(self.datasets))\n        else:\n            return \"{}(\".format(type(self).__name__) + \\\n                   \", \".join([dataset.__repr__() for dataset in self.datasets[:-1]]) + \", \" + \\\n                   self.datasets[-1].__repr__() + \\\n                   ')'\n\n\nclass ZipReject(Zip):\n    \"\"\"\n    Extends `Zip` by the functionality of rejecting samples that don't fulfill\n    a specified rejection criterion.\n    \"\"\"\n\n    def __init__(self, *datasets, sync=False, transforms=None,\n                 rejection_dataset_indices, rejection_criterion,\n                 random_jump_after_reject=True):\n        \"\"\"\n        Parameters\n        ----------\n        datasets : list or tuple\n            Datasets to zip.\n        sync : bool\n            Whether to synchronize zipped datasets if a synchronization primitive is available.\n        transforms : callable\n            Transforms to apply on the fetched batch.\n        rejection_dataset_indices : int or list or tuple\n            Indices (or index) corresponding to the datasets which are used to determine whether\n            a batch should be 
rejected.\n        rejection_criterion : callable\n            Criterion for rejection of batch. Must be a callable that accepts one or more\n            arrays / tensors and returns True if the corresponding batch should be rejected,\n            False otherwise. Should accept as many inputs as the number of elements in\n            `rejection_dataset_indices` if the latter is a list, and 1 otherwise. Note that\n            the order of the inputs to the `rejection_criterion` is the same as the order of\n            the indices in `rejection_dataset_indices`.\n        random_jump_after_reject: bool\n            Whether to try a random index or the rejected index incremented by one after rejection.\n        \"\"\"\n        super(ZipReject, self).__init__(*datasets, sync=sync, transforms=transforms)\n        for rejection_dataset_index in pyu.to_iterable(rejection_dataset_indices):\n            assert_(rejection_dataset_index < len(datasets),\n                    \"Index of the dataset to be used for rejection (= {}) is larger \"\n                    \"than the number of datasets (= {}) minus one.\"\n                    .format(rejection_dataset_index, len(datasets)),\n                    IndexError)\n        self.rejection_dataset_indices = pyu.to_iterable(rejection_dataset_indices)\n        assert_(callable(rejection_criterion),\n                \"Rejection criterion is not callable as it should be.\",\n                TypeError)\n        # return true if fetched should be rejected\n        self.rejection_criterion = rejection_criterion\n        # Array shared over processes to keep track of which indices have been rejected\n        self.rejected = mp.Array('b', len(self))\n        self.available_indices = None\n        # optional index mapping to exclude rejected indices, reducing dataset size (see remove_rejected())\n        self.index_mapping = None\n\n        self.random_jump_after_reject = random_jump_after_reject\n\n    def remove_rejected(self):\n        # 
remove the indices belonging to samples that were rejected from the dataset\n        # this changes the length of the dataset\n        rejected = np.array(self.rejected[:])\n        self.index_mapping = np.argwhere(1 - rejected)[:, 0]\n        self.rejected = mp.Array('b', len(self))\n        # just in case of num_workers == 0\n        self.available_indices = None\n\n    def __len__(self):\n        if hasattr(self, 'index_mapping') and self.index_mapping is not None:\n            return len(self.index_mapping)\n        else:\n\n            return super(ZipReject, self).__len__()\n\n    def next_index_to_try(self, index):\n        if self.random_jump_after_reject:\n            return np.random.randint(len(self))\n        else:\n            return (index + 1) % len(self)\n\n    def fetch_from_rejection_datasets(self, index):\n        rejection_fetched = [self.datasets[rejection_dataset_index][index]\n                             for rejection_dataset_index in self.rejection_dataset_indices]\n        return rejection_fetched\n\n    def __getitem__(self, index):\n        # we increase the index until a valid batch of 'rejection_dataset' is found\n        assert_(index < len(self), exception_type=IndexError)\n        index_ = index\n        # if we have a rejection dataset, check if the rejection criterion is fulfilled\n        # and update the index\n        if self.rejection_dataset_indices is not None:\n            # at the start of each epoch, compute the available indices from the shared variable\n            if self.available_indices is None:\n                self.available_indices = set(np.argwhere(1 - np.array(self.rejected[:]))[:, 0])\n\n            reject = True\n            while reject:\n                # check if there are no potentially valid indices left\n                if not self.available_indices:\n                    raise RuntimeError(\"ZipReject: No valid batch was found!\")\n\n                # check if this index was marked as rejected before\n  
              if index_ not in self.available_indices:\n                    index_ = self.next_index_to_try(index_)\n                    continue\n                # check if this index was marked as rejected in any process\n                if self.rejected[index_]:\n                    self.available_indices.remove(index_)\n                    continue\n\n                # map the index, if an index_mapping has been defined (see remove_rejected())\n                mapped_index_ = index_ if self.index_mapping is None else self.index_mapping[index_]\n                # we only fetch the dataset which has the rejection criterion\n                # and only fetch all datasets when a valid index is found\n                rejection_fetched = self.fetch_from_rejection_datasets(mapped_index_)\n                # check if this batch is to be rejected\n                reject = self.rejection_criterion(*rejection_fetched)\n\n                # if so, increase the index and add it\n                if reject:\n                    self.rejected[index_] = True\n                    self.available_indices.remove(index_)\n\n            # fetch all other datasets and concatenate them with the valid rejection_fetch\n            fetched = []\n            for dataset_index, dataset in enumerate(self.datasets):\n                if dataset_index in self.rejection_dataset_indices:\n                    # Find the index in `rejection_fetched` corresponding to this dataset_index\n                    index_in_rejection_fetched = self.rejection_dataset_indices.index(dataset_index)\n                    # ... 
and append to fetched\n                    fetched.append(rejection_fetched[index_in_rejection_fetched])\n                else:\n                    # Fetch and append to fetched\n                    fetched.append(dataset[mapped_index_])\n        else:\n            # map the index, if an index_mapping has been defined (see remove_rejected())\n            mapped_index_ = index_ if self.index_mapping is None else self.index_mapping[index_]\n            fetched = [dataset[mapped_index_] for dataset in self.datasets]\n        # apply transforms if present\n        if self.transforms is not None:\n            assert_(callable(self.transforms), \"`self.transforms` is not callable.\", TypeError)\n            fetched = self.transforms(*fetched)\n        return fetched\n"
  },
  {
    "path": "inferno/io/transform/__init__.py",
    "content": "from .base import Transform, Compose\nfrom . import generic\nfrom . import image\nfrom . import volume\n"
  },
  {
    "path": "inferno/io/transform/base.py",
    "content": "from ...utils import python_utils as pyu\nimport numpy as np\n\n\nclass Transform(object):\n    \"\"\"\n    Base class for a Transform. The argument `apply_to` (list) specifies the indices of\n    the tensors this transform will be applied to.\n\n    The following methods are recognized (in order of descending priority):\n        - `batch_function`: Applies to all tensors in a batch simultaneously\n        - `tensor_function`: Applies to just __one__ tensor at a time.\n        - `volume_function`: For 3D volumes, applies to just __one__ volume at a time.\n        - `image_function`: For 2D or 3D volumes, applies to just __one__ image at a time.\n\n    For example, if both `volume_function` and `image_function` are defined, this means that\n    only the former will be called. If the inputs are therefore not 5D batch-tensors of 3D\n    volumes, a `NotImplementedError` is raised.\n    \"\"\"\n    def __init__(self, apply_to=None):\n        \"\"\"\n        Parameters\n        ----------\n        apply_to : list or tuple\n            Indices of tensors to apply this transform to. 
The indices are with respect\n            to the list of arguments this object is called with.\n        \"\"\"\n        self._random_variables = {}\n        self._apply_to = list(apply_to) if apply_to is not None else None\n\n    def build_random_variables(self, **kwargs):\n        pass\n\n    def clear_random_variables(self):\n        self._random_variables = {}\n\n    def get_random_variable(self, key, default=None, build=True,\n                            **random_variable_building_kwargs):\n        if key in self._random_variables:\n            return self._random_variables.get(key, default)\n        else:\n            if not build:\n                return default\n            else:\n                self.build_random_variables(**random_variable_building_kwargs)\n                return self.get_random_variable(key, default, build=False)\n\n    def set_random_variable(self, key, value):\n        self._random_variables.update({key: value})\n\n    def __call__(self, *tensors, **transform_function_kwargs):\n        tensors = pyu.to_iterable(tensors)\n        # Get the list of the indices of the tensors to which we're going to apply the transform\n        apply_to = list(range(len(tensors))) if self._apply_to is None else self._apply_to\n        # Flush random variables and assume they're built by image_function\n        self.clear_random_variables()\n        if hasattr(self, 'batch_function'):\n            transformed = self.batch_function(tensors, **transform_function_kwargs)\n            return pyu.from_iterable(transformed)\n        elif hasattr(self, 'tensor_function'):\n            transformed = [self._apply_tensor_function(tensor, **transform_function_kwargs)\n                           if tensor_index in apply_to else tensor\n                           for tensor_index, tensor in enumerate(tensors)]\n            return pyu.from_iterable(transformed)\n        elif hasattr(self, 'volume_function'):\n            # Loop over all tensors\n            transformed = 
[self._apply_volume_function(tensor, **transform_function_kwargs)\n                           if tensor_index in apply_to else tensor\n                           for tensor_index, tensor in enumerate(tensors)]\n            return pyu.from_iterable(transformed)\n        elif hasattr(self, 'image_function'):\n            # Loop over all tensors\n            transformed = [self._apply_image_function(tensor, **transform_function_kwargs)\n                           if tensor_index in apply_to else tensor\n                           for tensor_index, tensor in enumerate(tensors)]\n            return pyu.from_iterable(transformed)\n        else:\n            raise NotImplementedError\n\n    # noinspection PyUnresolvedReferences\n    def _apply_tensor_function(self, tensor, **transform_function_kwargs):\n        if isinstance(tensor, list):\n            return [self._apply_tensor_function(tens) for tens in tensor]\n        return self.tensor_function(tensor)\n\n    # noinspection PyUnresolvedReferences\n    def _apply_image_function(self, tensor, **transform_function_kwargs):\n        assert pyu.has_callable_attr(self, 'image_function')\n        if isinstance(tensor, list):\n            return [self._apply_image_function(tens) for tens in tensor]\n        # 2D case\n        if tensor.ndim == 4:\n            return np.array([np.array([self.image_function(image, **transform_function_kwargs)\n                                       for image in channel_image])\n                             for channel_image in tensor])\n        # 3D case\n        elif tensor.ndim == 5:\n            return np.array([np.array([np.array([self.image_function(image,\n                                                                     **transform_function_kwargs)\n                                                 for image in volume])\n                                       for volume in channel_volume])\n                             for channel_volume in tensor])\n        elif tensor.ndim == 3:\n   
         # Assume we have a 3D volume (signature zyx) and apply the image function\n            # on all yx slices.\n            return np.array([self.image_function(image, **transform_function_kwargs)\n                             for image in tensor])\n        elif tensor.ndim == 2:\n            # Assume we really do have an image.\n            return self.image_function(tensor, **transform_function_kwargs)\n        else:\n            raise NotImplementedError\n\n    # noinspection PyUnresolvedReferences\n    def _apply_volume_function(self, tensor, **transform_function_kwargs):\n        assert pyu.has_callable_attr(self, 'volume_function')\n        if isinstance(tensor, list):\n            return [self._apply_volume_function(tens) for tens in tensor]\n        # 3D case\n        if tensor.ndim == 5:\n            # tensor is bczyx\n            # volume function is applied to zyx, i.e. loop over b and c\n            # FIXME This loops one time too many\n            return np.array([np.array([np.array([self.volume_function(volume,\n                                                                      **transform_function_kwargs)\n                                                 for volume in channel_volume])\n                                       for channel_volume in batch])\n                             for batch in tensor])\n        elif tensor.ndim == 4:\n            # We're applying the volume function on a czyx tensor, i.e. 
we loop over c and apply\n            # volume function to (zyx)\n            return np.array([self.volume_function(volume, **transform_function_kwargs)\n                             for volume in tensor])\n        elif tensor.ndim == 3:\n            # We're applying the volume function on the volume itself\n            return self.volume_function(tensor, **transform_function_kwargs)\n        else:\n            cname = self.__class__.__name__\n            raise NotImplementedError(\"Volume function not implemented for ndim %i called in %s\" % (tensor.ndim, cname))\n\n\nclass Compose(object):\n    \"\"\"Composes multiple callables (including but not limited to `Transform` objects).\"\"\"\n    def __init__(self, *transforms):\n        \"\"\"\n        Parameters\n        ----------\n        transforms : list of callable or tuple of callable\n            Transforms to compose.\n        \"\"\"\n        assert all([callable(transform) for transform in transforms])\n        self.transforms = list(transforms)\n\n    def add(self, transform):\n        assert callable(transform)\n        self.transforms.append(transform)\n        return self\n\n    def remove(self, name):\n        transform_idx = None\n        for idx, transform in enumerate(self.transforms):\n            if type(transform).__name__ == name:\n                transform_idx = idx\n                break\n        if transform_idx is not None:\n            self.transforms.pop(transform_idx)\n        return self\n\n    def __call__(self, *tensors):\n        intermediate = tensors\n        for transform in self.transforms:\n            intermediate = pyu.to_iterable(transform(*intermediate))\n        return pyu.from_iterable(intermediate)\n\n\nclass DTypeMapping(object):\n    DTYPE_MAPPING = {'float32': 'float32',\n                     'float': 'float32',\n                     'double': 'float64',\n                     'float64': 'float64',\n                     'half': 'float16',\n                     'float16': 
'float16',\n                     'long': 'int64',\n                     'int64': 'int64',\n                     'byte': 'uint8',\n                     'uint8': 'uint8',\n                     'int': 'int32',\n                     'int32': 'int32'}\n"
  },
  {
    "path": "inferno/io/transform/generic.py",
    "content": "import numpy as np\nimport torch\nfrom .base import Transform, DTypeMapping\nfrom ...utils.exceptions import assert_, DTypeError\n\n\nclass Normalize(Transform):\n    \"\"\"Normalizes input to zero mean unit variance.\"\"\"\n    def __init__(self, eps=1e-4, mean=None, std=None, ignore_value=None, **super_kwargs):\n        \"\"\"\n        Parameters\n        ----------\n        eps : float\n            A small epsilon for numerical stability.\n        mean : list or float or numpy.ndarray\n            Global dataset mean for all channels.\n        std : list or float or numpy.ndarray\n            Global dataset std for all channels.\n        super_kwargs : dict\n            Kwargs to the superclass `inferno.io.transform.base.Transform`.\n        \"\"\"\n        super(Normalize, self).__init__(**super_kwargs)\n        self.eps = eps\n        self.mean = np.asarray(mean) if mean is not None else None\n        self.std = np.asarray(std) if std is not None else None\n        self.ignore_value = ignore_value\n\n    def tensor_function(self, tensor):\n        # if we have a background value that we don't want to normalize\n        mask = None if self.ignore_value is None else (tensor != self.ignore_value)\n        if mask is None:\n            mean = np.asarray(tensor.mean()) if self.mean is None else self.mean\n            std = np.asarray(tensor.std()) if self.std is None else self.std\n        else:\n            mean = np.asarray(tensor[mask].mean()) if self.mean is None else self.mean\n            std = np.asarray(tensor[mask].std()) if self.std is None else self.std\n        # Figure out how to reshape mean and std\n        reshape_as = [-1] + [1] * (tensor.ndim - 1)\n        # Normalize\n        if mask is None:\n            tensor = (tensor - mean.reshape(*reshape_as)) / (std.reshape(*reshape_as) + self.eps)\n        else:\n            # if tensor is int, the normalized tensor will be in int as well\n            tensor = tensor.astype('float64')\n   
         tensor[mask] = ((tensor - mean.reshape(*reshape_as)) \\\n                            / (std.reshape(*reshape_as) + self.eps))[mask]\n        return tensor\n\n\nclass NormalizeRange(Transform):\n    \"\"\"Normalizes input by a constant.\"\"\"\n    def __init__(self, normalize_by=255., **super_kwargs):\n        \"\"\"\n        Parameters\n        ----------\n        normalize_by : float or int\n            Scalar to normalize by.\n        super_kwargs : dict\n            Kwargs to the superclass `inferno.io.transform.base.Transform`.\n        \"\"\"\n        super(NormalizeRange, self).__init__(**super_kwargs)\n        self.normalize_by = float(normalize_by)\n\n    def tensor_function(self, tensor):\n        return tensor / self.normalize_by\n\n\nclass Project(Transform):\n    \"\"\"\n    Given a projection mapping (i.e. a dict) and an input tensor, this transform replaces\n    all values in the tensor that equal a key in the mapping with the value corresponding to\n    the key.\n    \"\"\"\n    def __init__(self, projection, **super_kwargs):\n        \"\"\"\n        Parameters\n        ----------\n        projection : dict\n            The projection mapping.\n        super_kwargs : dict\n            Keywords to the super class.\n        \"\"\"\n        super(Project, self).__init__(**super_kwargs)\n        self.projection = dict(projection)\n\n    def tensor_function(self, tensor):\n        output = np.zeros_like(tensor)\n        for source, target in self.projection.items():\n            output[tensor == source] = target\n        return output\n\n\nclass Label2OneHot(Transform, DTypeMapping):\n    \"\"\"Convert integer labels to one-hot vectors for arbitrary dimensional data.\"\"\"\n    def __init__(self, num_classes, dtype='float', **super_kwargs):\n        \"\"\"\n        Parameters\n        ----------\n        num_classes : int\n            Number of classes.\n        dtype : str\n            Datatype of the output.\n        super_kwargs : dict\n       
     Keyword arguments to the superclass.\n        \"\"\"\n        super(Label2OneHot, self).__init__(**super_kwargs)\n        self.num_classes = num_classes\n        self.dtype = self.DTYPE_MAPPING.get(dtype)\n\n    def tensor_function(self, tensor):\n        reshaped_arange = np.arange(self.num_classes).reshape(-1, *(1,)*tensor.ndim)\n        output = np.equal(reshaped_arange, tensor).astype(self.dtype)\n        # output = np.zeros(shape=(self.num_classes,) + tensor.shape, dtype=self.dtype)\n        # # Optimizing for simplicity and memory efficiency, because one would usually\n        # # spawn multiple workers\n        # for class_num in range(self.num_classes):\n        #     output[class_num] = tensor == class_num\n        return output\n\n\nclass Cast(Transform, DTypeMapping):\n    \"\"\"Casts inputs to a specified datatype.\"\"\"\n    def __init__(self, dtype='float', **super_kwargs):\n        \"\"\"\n        Parameters\n        ----------\n        dtype : {'float16', 'float32', 'float64', 'half', 'float', 'double'}\n            Datatype to cast to.\n        super_kwargs : dict\n            Kwargs to the superclass `inferno.io.transform.base.Transform`.\n        \"\"\"\n        super(Cast, self).__init__(**super_kwargs)\n        assert dtype in self.DTYPE_MAPPING.keys()\n        self.dtype = self.DTYPE_MAPPING.get(dtype)\n\n    def tensor_function(self, tensor):\n        return getattr(np, self.dtype)(tensor)\n\n\nclass AsTorchBatch(Transform):\n    \"\"\"Converts a given numpy array to a torch batch tensor.\n\n    The result is a torch tensor __without__ the leading batch axis. For example,\n    if the input is an image of shape `(100, 100)`, the output is a batch of shape\n    `(1, 100, 100)`. 
The collate function will add the leading batch axis to obtain\n    a tensor of shape `(N, 1, 100, 100)`, where `N` is the batch-size.\n    \"\"\"\n    def __init__(self, dimensionality, add_channel_axis_if_necessary=True, **super_kwargs):\n        \"\"\"\n        Parameters\n        ----------\n        dimensionality : {1, 2, 3}\n            Dimensionality of the data: 1 if vector, 2 if image, 3 if volume.\n        add_channel_axis_if_necessary : bool\n            Whether to add a channel axis where necessary. For example, if `dimensionality = 2`\n            and the input tensor has 2 dimensions (i.e. an image), setting\n            `add_channel_axis_if_necessary` to True results in the output being a 3 dimensional\n            tensor, where the leading dimension is a singleton and corresponds to `channel`.\n        super_kwargs : dict\n            Kwargs to the superclass `inferno.io.transform.base.Transform`.\n        \"\"\"\n        super(AsTorchBatch, self).__init__(**super_kwargs)\n        assert dimensionality in [1, 2, 3]\n        self.dimensionality = dimensionality\n        self.add_channel_axis_if_necessary = bool(add_channel_axis_if_necessary)\n\n    def _to_batch(self, tensor):\n        assert_(isinstance(tensor, np.ndarray),\n                \"Expected numpy array, got %s\" % type(tensor),\n                DTypeError)\n        if self.dimensionality == 3:\n            # We're dealing with a volume. tensor can either be 3D or 4D\n            assert tensor.ndim in [3, 4]\n            if tensor.ndim == 3 and self.add_channel_axis_if_necessary:\n                # Add channel axis\n                return torch.from_numpy(tensor[None, ...])\n            else:\n                # Channel axis is in already\n                return torch.from_numpy(tensor)\n        elif self.dimensionality == 2:\n            # We're dealing with an image. 
tensor can either be 2D or 3D\n            assert tensor.ndim in [2, 3]\n            if tensor.ndim == 2 and self.add_channel_axis_if_necessary:\n                # Add channel axis\n                return torch.from_numpy(tensor[None, ...])\n            else:\n                # Channel axis is in already\n                return torch.from_numpy(tensor)\n        elif self.dimensionality == 1:\n            # We're dealing with a vector - it has to be 1D\n            assert tensor.ndim == 1\n            return torch.from_numpy(tensor)\n        else:\n            raise NotImplementedError\n\n    def tensor_function(self, tensor):\n        assert_(isinstance(tensor, (list, np.ndarray)),\n                \"Expected numpy array or list, got %s\" % type(tensor),\n                DTypeError)\n        if isinstance(tensor, np.ndarray):\n            return self._to_batch(tensor)\n        else:\n            return [self._to_batch(elem) for elem in tensor]\n"
  },
  {
    "path": "inferno/io/transform/image.py",
    "content": "import numpy as np\nfrom scipy.ndimage import zoom\nfrom scipy.ndimage.filters import gaussian_filter\nfrom scipy.ndimage.interpolation import map_coordinates, rotate\nfrom scipy.ndimage.morphology import binary_dilation, binary_erosion\nfrom skimage.exposure import adjust_gamma\nfrom warnings import catch_warnings, simplefilter\n\nfrom .base import Transform\nfrom ...utils.exceptions import assert_, ShapeError\n\n\nclass PILImage2NumPyArray(Transform):\n    \"\"\"Convert a PIL Image object to a numpy array.\n\n    For images with multiple channels (say RGB), the channel axis is moved to front. Therefore,\n    a (100, 100, 3) RGB image becomes an array of shape (3, 100, 100).\n    \"\"\"\n    def tensor_function(self, tensor):\n        tensor = np.asarray(tensor)\n        if tensor.ndim == 3:\n            # There's a channel axis - we move it to front\n            tensor = np.moveaxis(tensor, source=-1, destination=0)\n        elif tensor.ndim == 2:\n            pass\n        else:\n            raise NotImplementedError(\"Expected tensor to be a 2D or 3D \"\n                                      \"numpy array, got a {}D array instead.\"\n                                      .format(tensor.ndim))\n        return tensor\n\n\nclass Scale(Transform):\n    \"\"\"Scales an image to a given size with spline interpolation of requested order.\n\n    Unlike torchvision.transforms.Scale, this does not depend on PIL and therefore works\n    with numpy arrays. If you do have a PIL image and wish to use this transform, consider\n    applying `PILImage2NumPyArray` first.\n\n    Warnings\n    --------\n    This transform uses `scipy.ndimage.zoom` and requires scipy >= 0.13.0 to work correctly.\n    \"\"\"\n    def __init__(self, output_image_shape, interpolation_order=3, zoom_kwargs=None, **super_kwargs):\n        \"\"\"\n        Parameters\n        ----------\n        output_image_shape : list or tuple or int or None\n            Target size of the output image. 
Aspect ratio may not be preserved.\n            If output_image_shape is None, image input size will be preserved\n        interpolation_order : int\n            Interpolation order for the spline interpolation.\n        zoom_kwargs : dict\n            Keyword arguments for `scipy.ndimage.zoom`.\n        super_kwargs : dict\n            Keyword arguments for the superclass.\n        \"\"\"\n        super(Scale, self).__init__(**super_kwargs)\n        if output_image_shape is not None:\n            output_image_shape = (output_image_shape, output_image_shape) \\\n                if isinstance(output_image_shape, int) else tuple(output_image_shape)\n            assert_(len(output_image_shape) == 2,\n                    \"`output_image_shape` must be an integer or a tuple of length 2.\",\n                    ValueError)\n        self.output_image_shape = output_image_shape\n        self.interpolation_order = interpolation_order\n        self.zoom_kwargs = {} if zoom_kwargs is None else dict(zoom_kwargs)\n\n    def image_function(self, image):\n        source_height, source_width = image.shape\n        target_height, target_width = self.output_image_shape\n        # We're on Python 3 - take a deep breath and relax.\n        zoom_height, zoom_width = (target_height / source_height), (target_width / source_width)\n        with catch_warnings():\n            # Ignore warning that scipy should be > 0.13 (it's 0.19 these days)\n            simplefilter('ignore')\n            rescaled_image = zoom(image, (zoom_height, zoom_width),\n                                  order=self.interpolation_order, **self.zoom_kwargs)\n        # This should never happen\n        assert_(rescaled_image.shape == (target_height, target_width),\n                \"Shape mismatch that shouldn't have happened if you were on scipy > 0.13.0. 
\"\n                \"Are you on scipy > 0.13.0?\",\n                ShapeError)\n        return rescaled_image\n\n\nclass RandomCrop(Transform):\n    \"\"\"Crop input to a given size.\n\n    This is similar to torchvision.transforms.RandomCrop, except that it operates on\n    numpy arrays instead of PIL images. If you do have a PIL image and wish to use this\n    transform, consider applying `PILImage2NumPyArray` first.\n\n    Warnings\n    --------\n    If `output_image_shape` is larger than the image itself, the image is not cropped\n    (along the relevant dimensions).\n    \"\"\"\n    def __init__(self, output_image_shape, **super_kwargs):\n        \"\"\"\n        Parameters\n        ----------\n        output_image_shape : tuple or list or int\n            Expected shape of the output image. Could be an integer, (say) 100, in\n            which case it's interpreted as `(100, 100)`. Note that if the image shape\n            along some (or all) dimension is smaller, say `(50, 200)`, the resulting\n            output images will have the shape `(50, 100)`.\n        super_kwargs : dict\n            Keywords to the super class.\n        \"\"\"\n        super(RandomCrop, self).__init__(**super_kwargs)\n        # Privates\n        self._image_shape_cache = None\n        # Publics\n        output_image_shape = (output_image_shape, output_image_shape) \\\n            if isinstance(output_image_shape, int) else tuple(output_image_shape)\n        assert_(len(output_image_shape) == 2,\n                \"`output_image_shape` must be an integer or a tuple of length 2.\",\n                ValueError)\n        self.output_image_shape = output_image_shape\n\n    def clear_random_variables(self):\n        self._image_shape_cache = None\n        super(RandomCrop, self).clear_random_variables()\n\n    def build_random_variables(self, height_leeway, width_leeway):\n        if height_leeway > 0:\n            self.set_random_variable('height_location',\n                            
         np.random.randint(low=0, high=height_leeway + 1))\n        if width_leeway > 0:\n            self.set_random_variable('width_location',\n                                     np.random.randint(low=0, high=width_leeway + 1))\n\n    def image_function(self, image):\n        # Validate image shape\n        if self._image_shape_cache is not None:\n            assert_(self._image_shape_cache == image.shape,\n                    \"RandomCrop works on multiple images simultaneously only \"\n                    \"if they have the same shape. Was expecting an image of \"\n                    \"shape {}, got one of shape {} instead.\"\n                    .format(self._image_shape_cache, image.shape),\n                    ShapeError)\n        else:\n            self._image_shape_cache = image.shape\n        source_height, source_width = image.shape\n        crop_height, crop_width = self.output_image_shape\n        height_leeway = source_height - crop_height\n        width_leeway = source_width - crop_width\n        if height_leeway > 0:\n            # Crop height\n            height_location = self.get_random_variable('height_location',\n                                                       height_leeway=height_leeway,\n                                                       width_leeway=width_leeway)\n            cropped = image[height_location:(height_location + crop_height), :]\n            assert cropped.shape[0] == self.output_image_shape[0], \"Well, shit.\"\n        else:\n            cropped = image\n        if width_leeway > 0:\n            # Crop width\n            width_location = self.get_random_variable('width_location',\n                                                      height_leeway=height_leeway,\n                                                      width_leeway=width_leeway)\n            cropped = cropped[:, width_location:(width_location + crop_width)]\n            assert cropped.shape[1] == self.output_image_shape[1], \"Well, shit.\"\n        
return cropped\n\n\nclass RandomSizedCrop(Transform):\n    \"\"\"Extract a randomly sized crop from the image.\n\n    The ratio of the sizes of the cropped and the original image can be limited within\n    specified bounds along both axes. To resize back to a constant sized image, compose\n    with `Scale`.\n    \"\"\"\n    def __init__(self, ratio_between=None, height_ratio_between=None, width_ratio_between=None,\n                 preserve_aspect_ratio=False, relative_target_aspect_ratio=None, **super_kwargs):\n        \"\"\"\n        Parameters\n        ----------\n        ratio_between : tuple\n            Specify the bounds between which to sample the crop ratio. This applies to\n            both height and width if not overridden. Can be None if both height and width\n            ratios are specified individually.\n        height_ratio_between : tuple\n            Specify the bounds between which to sample the vertical crop ratio.\n            Can be None if `ratio_between` is not None.\n        width_ratio_between : tuple\n            Specify the bounds between which to sample the horizontal crop ratio.\n            Can be None if `ratio_between` is not None.\n        preserve_aspect_ratio : bool\n            Whether to preserve aspect ratio. If both `height_ratio_between`\n            and `width_ratio_between` are specified, the former is used if this\n            is set to True.\n        relative_target_aspect_ratio : float\n            Specify the target aspect ratio (W x H) relative to the input image\n            (i.e. by mapping the input image ratio to 1:1). For instance, if an image\n            has the size 1024 (H) x 2048 (W), a relative target aspect ratio of 0.5\n            might yield images of size 1024 x 1024. 
Note that this only applies if\n            `preserve_aspect_ratio` is set to False.\n        super_kwargs : dict\n            Keyword arguments for the super class.\n        \"\"\"\n        super(RandomSizedCrop, self).__init__(**super_kwargs)\n        # Privates\n        self._image_shape_cache = None\n        # Publics\n        height_ratio_between = tuple(height_ratio_between) \\\n            if height_ratio_between is not None else tuple(ratio_between)\n        width_ratio_between = tuple(width_ratio_between) \\\n            if width_ratio_between is not None else tuple(ratio_between)\n        assert_(height_ratio_between is not None,\n                \"`height_ratio_between` is not specified.\",\n                ValueError)\n        assert_(width_ratio_between is not None,\n                \"`width_ratio_between` is not specified.\",\n                ValueError)\n        self.height_ratio_between = height_ratio_between\n        self.width_ratio_between = width_ratio_between\n        self.preserve_aspect_ratio = preserve_aspect_ratio\n        self.relative_target_aspect_ratio = relative_target_aspect_ratio\n\n    def build_random_variables(self, image_shape):\n        # Seed RNG\n        np.random.seed()\n        # Compute random variables\n        source_height, source_width = image_shape\n        height_ratio = np.random.uniform(low=self.height_ratio_between[0],\n                                         high=self.height_ratio_between[1])\n        if self.preserve_aspect_ratio:\n            width_ratio = height_ratio\n        elif self.relative_target_aspect_ratio is not None:\n            width_ratio = height_ratio * self.relative_target_aspect_ratio\n        else:\n            width_ratio = np.random.uniform(low=self.width_ratio_between[0],\n                                            high=self.width_ratio_between[1])\n        crop_height = int(np.round(height_ratio * source_height))\n        crop_width = int(np.round(width_ratio * source_width))\n        
height_leeway = source_height - crop_height\n        width_leeway = source_width - crop_width\n        # Set random variables\n        if height_leeway > 0:\n            self.set_random_variable('height_location',\n                                     np.random.randint(low=0, high=height_leeway + 1))\n        if width_leeway > 0:\n            self.set_random_variable('width_location',\n                                     np.random.randint(low=0, high=width_leeway + 1))\n        self.set_random_variable('crop_height', crop_height)\n        self.set_random_variable('crop_width', crop_width)\n        self.set_random_variable('height_leeway', height_leeway)\n        self.set_random_variable('width_leeway', width_leeway)\n\n    def image_function(self, image):\n        # Validate image shape\n        if self._image_shape_cache is not None:\n            assert_(self._image_shape_cache == image.shape,\n                    \"RandomCrop works on multiple images simultaneously only \"\n                    \"if they have the same shape. 
Was expecting an image of \"\n                    \"shape {}, got one of shape {} instead.\"\n                    .format(self._image_shape_cache, image.shape),\n                    ShapeError)\n        else:\n            self._image_shape_cache = image.shape\n        height_leeway = self.get_random_variable('height_leeway', image_shape=image.shape)\n        width_leeway = self.get_random_variable('width_leeway', image_shape=image.shape)\n        if height_leeway > 0:\n            height_location = self.get_random_variable('height_location',\n                                                       image_shape=image.shape)\n            crop_height = self.get_random_variable('crop_height',\n                                                   image_shape=image.shape)\n            cropped = image[height_location:(height_location + crop_height), :]\n        else:\n            cropped = image\n        if width_leeway > 0:\n            width_location = self.get_random_variable('width_location',\n                                                      image_shape=image.shape)\n            crop_width = self.get_random_variable('crop_width',\n                                                  image_shape=image.shape)\n            cropped = cropped[:, width_location:(width_location + crop_width)]\n        return cropped\n\n\nclass RandomGammaCorrection(Transform):\n    \"\"\"Applies gamma correction [1] with a random gamma.\n\n    This transform uses `skimage.exposure.adjust_gamma`, which requires the input be positive.\n\n    References\n    ----------\n    [1] https://en.wikipedia.org/wiki/Gamma_correction\n    \"\"\"\n    def __init__(self, gamma_between=(0.5, 2.), gain=1, **super_kwargs):\n        \"\"\"\n        Parameters\n        ----------\n        gamma_between : tuple or list\n            Specifies the range within which to sample gamma (uniformly).\n        gain : int or float\n            The resulting gamma corrected image is multiplied by this `gain`.\n        
super_kwargs : dict\n            Keyword arguments for the superclass.\n        \"\"\"\n        super(RandomGammaCorrection, self).__init__(**super_kwargs)\n        self.gamma_between = list(gamma_between)\n        self.gain = gain\n\n    def build_random_variables(self):\n        np.random.seed()\n        self.set_random_variable('gamma',\n                                 np.random.uniform(low=self.gamma_between[0],\n                                                   high=self.gamma_between[1]))\n\n    def image_function(self, image):\n        gamma_adjusted = adjust_gamma(image,\n                                      gamma=self.get_random_variable('gamma'),\n                                      gain=self.gain)\n        return gamma_adjusted\n\n\nclass ElasticTransform(Transform):\n    \"\"\"Random Elastic Transformation.\"\"\"\n    NATIVE_DTYPES = {'float32', 'float64'}\n    PREFERRED_DTYPE = 'float32'\n\n    def __init__(self, alpha, sigma, order=1, invert=False, **super_kwargs):\n        self._initial_dtype = None\n        super(ElasticTransform, self).__init__(**super_kwargs)\n        self.alpha = alpha\n        self.sigma = sigma\n        self.order = order\n        self.invert = invert\n\n    def build_random_variables(self, **kwargs):\n        # All this is done just once per batch (i.e. until `clear_random_variables` is called)\n        np.random.seed()\n        imshape = kwargs.get('imshape')\n        # Build and scale random fields\n        random_field_x = np.random.uniform(-1, 1, imshape) * self.alpha\n        random_field_y = np.random.uniform(-1, 1, imshape) * self.alpha\n        # Smooth random field (this has to be done just once per reset)\n        sdx = gaussian_filter(random_field_x, self.sigma, mode='reflect')\n        sdy = gaussian_filter(random_field_y, self.sigma, mode='reflect')\n        # Make meshgrid\n        x, y = np.meshgrid(np.arange(imshape[1]), np.arange(imshape[0]))\n        # Make inversion coefficient\n        _inverter = 1. 
if not self.invert else -1.\n        # Distort meshgrid indices (invert if required)\n        flow_y, flow_x = (y + _inverter * sdy).reshape(-1, 1), (x + _inverter * sdx).reshape(-1, 1)\n        # Set random states\n        self.set_random_variable('flow_x', flow_x)\n        self.set_random_variable('flow_y', flow_y)\n\n    def cast(self, image):\n        if image.dtype not in self.NATIVE_DTYPES:\n            self._initial_dtype = image.dtype\n            image = image.astype(self.PREFERRED_DTYPE)\n        return image\n\n    def uncast(self, image):\n        if self._initial_dtype is not None:\n            image = image.astype(self._initial_dtype)\n        self._initial_dtype = None\n        return image\n\n    def image_function(self, image):\n        # Cast image to one of the native dtypes (one which that is supported by scipy)\n        image = self.cast(image)\n        # Take measurements\n        imshape = image.shape\n        # Obtain flows\n        flows = self.get_random_variable('flow_y', imshape=imshape), \\\n                self.get_random_variable('flow_x', imshape=imshape)\n        # Map cooordinates from image to distorted index set\n        transformed_image = map_coordinates(image, flows,\n                                            mode='reflect', order=self.order).reshape(imshape)\n        # Uncast image to the original dtype\n        transformed_image = self.uncast(transformed_image)\n        return transformed_image\n\n\nclass AdditiveGaussianNoise(Transform):\n    \"\"\"Add gaussian noise to the input.\"\"\"\n    def __init__(self, sigma, **super_kwargs):\n        super(AdditiveGaussianNoise, self).__init__(**super_kwargs)\n        self.sigma = sigma\n\n    def build_random_variables(self, **kwargs):\n        np.random.seed()\n        self.set_random_variable('noise', np.random.normal(loc=0, scale=self.sigma,\n                                                           size=kwargs.get('imshape')))\n\n    def image_function(self, image):\n       
 image = image + self.get_random_variable('noise', imshape=image.shape)\n        return image\n\n\nclass RandomRotate(Transform):\n    \"\"\"Random 90-degree rotations.\"\"\"\n    def __init__(self, **super_kwargs):\n        super(RandomRotate, self).__init__(**super_kwargs)\n\n    def build_random_variables(self, **kwargs):\n        np.random.seed()\n        self.set_random_variable('k', np.random.randint(0, 4))\n\n    def image_function(self, image):\n        return np.rot90(image, k=self.get_random_variable('k'))\n\n\nclass RandomTranspose(Transform):\n    \"\"\"Random 2d transpose.\"\"\"\n    def __init__(self, **super_kwargs):\n        super(RandomTranspose, self).__init__(**super_kwargs)\n\n    def build_random_variables(self, **kwargs):\n        np.random.seed()\n        self.set_random_variable('do_transpose', np.random.uniform() > 0.5)\n\n    def image_function(self, image):\n        if self.get_random_variable('do_transpose'):\n            image = np.transpose(image)\n        return image\n\n\nclass RandomFlip(Transform):\n    \"\"\"Random left-right or up-down flips.\"\"\"\n    def __init__(self, allow_lr_flips=True, allow_ud_flips=True, **super_kwargs):\n        super(RandomFlip, self).__init__(**super_kwargs)\n        self.allow_lr_flips = allow_lr_flips\n        self.allow_ud_flips = allow_ud_flips\n\n    def build_random_variables(self, **kwargs):\n        np.random.seed()\n        self.set_random_variable('flip_lr', np.random.uniform() > 0.5)\n        self.set_random_variable('flip_ud', np.random.uniform() > 0.5)\n\n    def image_function(self, image):\n        if self.allow_lr_flips and self.get_random_variable('flip_lr'):\n            image = np.fliplr(image)\n        if self.allow_ud_flips and self.get_random_variable('flip_ud'):\n            image = np.flipud(image)\n        return image\n\n\nclass CenterCrop(Transform):\n    \"\"\" Crop patch of size `size` from the center of the image \"\"\"\n    def __init__(self, size, **super_kwargs):\n     
   super(CenterCrop, self).__init__(**super_kwargs)\n        assert isinstance(size, (int, tuple))\n        self.size = (size, size) if isinstance(size, int) else size\n\n    def image_function(self, image):\n        h, w = image.shape\n        th, tw = self.size\n        if h > th:\n            y1 = int(round((h - th) / 2.))\n            image = image[y1:y1 + th, :]\n        if w > tw:\n            x1 = int(round((w - tw) / 2.))\n            image = image[:, x1:x1 + tw]\n        return image\n\n\nclass BinaryMorphology(Transform):\n    \"\"\"\n    Apply a binary morphology operation on an image. Supported operations are dilation\n    and erosion.\n    \"\"\"\n    def __init__(self, mode, num_iterations=1, morphology_kwargs=None, **super_kwargs):\n        \"\"\"\n        Parameters\n        ----------\n        mode : {'dilate', 'erode'}\n            Whether to dilate or erode.\n        num_iterations : int\n            Number of iterations to apply the operation for.\n        morphology_kwargs: dict\n            Keyword arguments to the morphology function\n            (i.e. `scipy.ndimage.morphology.binary_erosion` or\n            `scipy.ndimage.morphology.binary_erosion`)\n        super_kwargs : dict\n            Keyword arguments to the superclass.\n        \"\"\"\n        super(BinaryMorphology, self).__init__(**super_kwargs)\n        # Validate and assign mode\n        assert_(mode in ['dilate', 'erode'],\n                \"Mode must be one of ['dilate', 'erode']. 
Got {} instead.\".format(mode),\n                ValueError)\n        self.mode = mode\n        self.num_iterations = num_iterations\n        self.morphology_kwargs = {} if morphology_kwargs is None else dict(morphology_kwargs)\n\n    def image_function(self, image):\n        if self.mode == 'dilate':\n            transformed_image = binary_dilation(image, iterations=self.num_iterations,\n                                                **self.morphology_kwargs)\n        elif self.mode == 'erode':\n            transformed_image = binary_erosion(image, iterations=self.num_iterations,\n                                               **self.morphology_kwargs)\n        else:\n            raise ValueError\n        # Cast transformed image to the right dtype and return\n        return transformed_image.astype(image.dtype)\n\n\nclass BinaryDilation(BinaryMorphology):\n    \"\"\"Apply a binary dilation operation on an image.\"\"\"\n    def __init__(self, num_iterations=1, morphology_kwargs=None, **super_kwargs):\n        super(BinaryDilation, self).__init__(mode='dilate', num_iterations=num_iterations,\n                                             morphology_kwargs=morphology_kwargs,\n                                             **super_kwargs)\n\n\nclass BinaryErosion(BinaryMorphology):\n    \"\"\"Apply a binary erosion operation on an image.\"\"\"\n    def __init__(self, num_iterations=1, morphology_kwargs=None, **super_kwargs):\n        super(BinaryErosion, self).__init__(mode='erode', num_iterations=num_iterations,\n                                            morphology_kwargs=morphology_kwargs,\n                                            **super_kwargs)\n\n\nclass FineRandomRotations(Transform):\n    \"\"\" Random Rotation with random uniform angle distribution\n        batch_function applies to rotation of input and label image\n\n        Parameters\n        ----------\n        angle_range : int\n                      maximum angle of rotation\n        axes        : 
tuple, default (1,2) assuming that channel axis is 0\n                      pair of axis that define the 2d-plane of rotation\n        mask_label  : constant value that is used to pad the label images\n    \"\"\"\n    def __init__(self, angle_range, axes=(1,2), mask_label=0, **super_kwargs):\n        super(FineRandomRotations, self).__init__(**super_kwargs)\n        self.angle_range = angle_range\n        self.axes = axes\n        self.ml = mask_label\n\n    def build_random_variables(self):\n        np.random.seed()\n        self.set_random_variable('angle',\n                 np.random.uniform(low=-self.angle_range,\n                                   high=self.angle_range))\n\n    def batch_function(self, image):\n        angle = self.get_random_variable('angle')\n        return rotate(image[0], angle, axes=self.axes, reshape=False), \\\n               rotate(image[1], angle, axes=self.axes, order=0, cval=self.ml, reshape=False)\n\n\nclass RandomScaleSegmentation(Transform):\n    \"\"\" Random Scale input and label image\n\n        Parameters\n        ----------\n        scale_range : tuple of floats defining (min, max) scales\n                      maximum angle of rotation\n        resize  : if True, image is cropped or padded to the original size\n        pad_const: value used for constant padding\n    \"\"\"\n    def __init__(self, scale_range, resize=True, pad_const=0, **super_kwargs):\n        super(RandomScaleSegmentation, self).__init__(**super_kwargs)\n        self.scale_range = scale_range\n        self.resize = resize\n        self.pad_const = pad_const\n\n    def build_random_variables(self):\n        np.random.seed()\n        self.set_random_variable('seg_scale',\n                 np.random.uniform(low=self.scale_range[0],\n                                   high=self.scale_range[1]))\n\n    def batch_function(self, image):\n        scale = self.get_random_variable('seg_scale')\n        input_image, segmentation = image\n        image_shape = 
np.array(input_image.shape[1:])\n        if input_image.ndim == segmentation.ndim + 1:\n            segmentation = segmentation[None]\n        with catch_warnings():\n            simplefilter('ignore')\n            img = np.stack([zoom(x, scale, order=3) for x in input_image])\n            seg = np.stack([zoom(x, scale, order=0) for x in segmentation])\n        new_shape = np.array(img.shape[1:])\n        if self.resize:\n            if scale > 1.:\n                # crop image to original size\n                crop_l = (new_shape - image_shape) // 2\n                crop_r = new_shape - image_shape - crop_l\n                cropping = [slice(None)] + [slice(c[0] if c[0] > 0 else None,\n                                                 -c[1] if c[1] > 0 else None) for c in zip(crop_l, crop_r)]\n                img = img[cropping]\n                seg = seg[cropping]\n            else:\n                # pad image to original size\n                pad_l = (image_shape - new_shape) // 2\n                pad_r = image_shape - new_shape - pad_l\n                padding = [(0,0)] + list(zip(pad_l, pad_r))\n                img = np.pad(img, padding, 'constant', constant_values=self.pad_const)\n                seg = np.pad(seg, padding, 'constant', constant_values=self.pad_const)\n        return img, seg\n"
  },
  {
    "path": "inferno/io/transform/volume.py",
    "content": "import numpy as np\nimport scipy\nfrom scipy.ndimage import zoom\nfrom scipy.ndimage.morphology import binary_dilation, binary_erosion\nfrom .base import Transform\nfrom ...utils.exceptions import assert_\n\n\nclass RandomFlip3D(Transform):\n    def __init__(self, **super_kwargs):\n        super(RandomFlip3D, self).__init__(**super_kwargs)\n\n    def build_random_variables(self, **kwargs):\n        np.random.seed()\n        self.set_random_variable('flip_lr', np.random.uniform() > 0.5)\n        self.set_random_variable('flip_ud', np.random.uniform() > 0.5)\n        self.set_random_variable('flip_z', np.random.uniform() > 0.5)\n\n    def volume_function(self, volume):\n        if self.get_random_variable('flip_lr'):\n            volume = volume[:, :, ::-1].copy()\n        if self.get_random_variable('flip_ud'):\n            volume = volume[:, ::-1, :].copy()\n        if self.get_random_variable('flip_z'):\n            volume = volume[::-1, :, :].copy()\n        return volume\n\n\nclass RandomRot3D(Transform):\n    def __init__(self, rot_range, p=0.125, reshape=False, order=0, mode='nearest', **super_kwargs):\n        super(RandomRot3D, self).__init__(**super_kwargs)\n        self.rot_range = rot_range\n        self.p = p\n        self.reshape = reshape\n        self.order = order\n        self.mode = mode\n\n    def build_random_variables(self, **kwargs):\n        np.random.seed()\n\n        self.set_random_variable('do_z', np.random.uniform() < self.p)\n        self.set_random_variable('do_y', np.random.uniform() < self.p)\n        self.set_random_variable('do_x', np.random.uniform() < self.p)\n\n        self.set_random_variable('angle_z', np.random.uniform(-self.rot_range, self.rot_range))\n        self.set_random_variable('angle_y', np.random.uniform(-self.rot_range, self.rot_range))\n        self.set_random_variable('angle_x', np.random.uniform(-self.rot_range, self.rot_range))\n\n    def volume_function(self, volume):\n        angle_z = 
self.get_random_variable('angle_z')\n        angle_y = self.get_random_variable('angle_y')\n        angle_x = self.get_random_variable('angle_x')\n\n        # rotate along z-axis\n        if self.get_random_variable('do_z'):\n            volume = scipy.ndimage.interpolation.rotate(volume, angle_z,\n                                                        order=self.order, mode=self.mode,\n                                                        axes=(0, 1), reshape=self.reshape)\n        # rotate along y-axis\n        if self.get_random_variable('do_y'):\n            volume = scipy.ndimage.interpolation.rotate(volume, angle_y,\n                                                        order=self.order, mode=self.mode,\n                                                        axes=(0, 2), reshape=self.reshape)\n        # rotate along x-axis\n        if self.get_random_variable('do_x'):\n            volume = scipy.ndimage.interpolation.rotate(volume, angle_x,\n                                                        order=self.order, mode=self.mode,\n                                                        axes=(1, 2), reshape=self.reshape)\n        return volume\n\n\n# TODO this is obsolete\nclass AdditiveRandomNoise3D(Transform):\n    \"\"\" Add gaussian noise to 3d volume\n\n    Need to know input shape before application, but can be\n    synchronized between different inputs (cf. 
`AdditiveNoise`)\n    Arguments:\n        shape: shape of input volumes\n        std: standard deviation of gaussian\n        super_kwargs: keyword arguments for `Transform` base class\n    \"\"\"\n    def __init__(self, shape, std, **super_kwargs):\n        super(AdditiveRandomNoise3D, self).__init__(**super_kwargs)\n        self.shape = shape\n        self.std = float(std)\n\n    def build_random_variables(self, **kwargs):\n        np.random.seed()\n        self.set_random_variable('noise_vol',\n                                 np.random.normal(loc=0.0, scale=self.std, size=self.shape))\n\n    def volume_function(self, volume):\n        noise_vol = self.get_random_variable('noise_vol')\n        return volume + noise_vol\n\n\n# TODO different options than gaussian\nclass AdditiveNoise(Transform):\n    \"\"\" Add noise to 3d volume\n\n    Do NOT need to know input shape before application, but CANNOT be\n    synchronized between different inputs (cf. `AdditiveRandomNoise`)\n    Arguments:\n        sigma: sigma for noise\n        mode: mode of distribution (only gaussian supported for now)\n        super_kwargs: keyword arguments for `Transform` base class\n    \"\"\"\n    def __init__(self, sigma, mode='gaussian', **super_kwargs):\n        assert mode == 'gaussian'\n        super().__init__(**super_kwargs)\n        self.sigma = sigma\n\n    # TODO check if volume is tensor and use torch functions in that case\n    def tensor_function(self, volume):\n        volume += np.random.normal(loc=0, scale=self.sigma, size=volume.shape)\n        return volume\n\n\nclass CentralSlice(Transform):\n    def volume_function(self, volume):\n        half_z = volume.shape[0] // 2\n        return volume[half_z:half_z + 1, ...]\n\n\nclass VolumeCenterCrop(Transform):\n    \"\"\" Crop patch of size `size` from the center of the volume \"\"\"\n    def __init__(self, size, **super_kwargs):\n        super().__init__(**super_kwargs)\n        assert isinstance(size, (int, tuple))\n        
self.size = (size, size, size) if isinstance(size, int) else size\n        assert len(self.size) == 3\n\n    def volume_function(self, volume):\n        h, w, d = volume.shape\n        th, tw, td = self.size\n        x1 = int(round((w - tw) / 2.))\n        y1 = int(round((h - th) / 2.))\n        z1 = int(round((d - td) / 2.))\n        return volume[y1:y1+th, x1:x1+tw, z1:z1+td]\n\n\nclass VolumeAsymmetricCrop(Transform):\n    \"\"\" Crop `crop_left` from the left borders and `crop_right` from the right borders \"\"\"\n    def __init__(self, crop_left, crop_right, **super_kwargs):\n        super(VolumeAsymmetricCrop, self).__init__(**super_kwargs)\n        assert isinstance(crop_left, (list, tuple))\n        assert isinstance(crop_right, (list, tuple))\n        assert len(crop_left) == 3\n        assert len(crop_right) == 3\n        self.crop_left = crop_left\n        self.crop_right = crop_right\n\n    def volume_function(self, volume):\n        x1, y1, z1 = self.crop_left\n        x2, y2, z2 = (np.array(volume.shape) - np.array(self.crop_right)).astype('uint32')\n        return volume[x1:x2, y1:y2, z1:z2]\n\n\nclass Slices2Channels(Transform):\n    \"\"\" Needed for training 2D network with slices above/below as additional channels\n        For the input data transforms one dimension (x, y or z) into channels\n        For the target data just takes the central slice and discards all the rest\"\"\"\n    def __init__(self, num_channels, downsampling=1, **super_kwargs):\n        super(Slices2Channels, self).__init__(**super_kwargs)\n        self.channels = num_channels\n        self.downsampling = downsampling\n\n    def batch_function(self, batch):\n        try:\n            axis = batch[0].shape.index(self.channels)\n        except ValueError:\n            print(\"The axis has the shape of the desired channels number!\")\n        half = int(self.channels/2)\n        new_input = np.moveaxis(batch[0], axis, 0)\n        # take every nth slice to the both directions of the 
central slice\n        indices = []\n        for i in range(self.channels):\n            if i % self.downsampling == half % self.downsampling:\n                indices.append(i)\n        new_input = new_input[indices]   # num_chan after - int (num_chan/(2*downsample)) * 2 + 1\n        new_target = np.moveaxis(batch[1], axis, 0)\n        new_target = new_target[half]\n        return (new_input, new_target)\n\n\nclass RandomScale3D(Transform):\n    \"\"\"Scales a volume with a random zoom factor with spline interpolation of requested order\"\"\"\n    def __init__(self, zoom_factor_range, interpolation_order=0, p=0.5,\n                 same_zoom=True, zoom_kwargs=None, **super_kwargs):\n        \"\"\"\n        Parameters\n        ----------\n        zoom_factor_range : list or tuple\n            The allowed range to sample zoom factors along the axes.\n        interpolation_order : int\n            Interpolation order for the spline interpolation.\n        p : float\n            Probability that the axis gets zoomed\n        same_zoom: bool\n            Apply the same zoom factor to all the axes\n        zoom_kwargs : dict\n            Keyword arguments for `scipy.ndimage.zoom`.\n        super_kwargs : dict\n            Keyword arguments for the superclass.\n        \"\"\"\n        super(RandomScale3D, self).__init__(**super_kwargs)\n        assert_(len(zoom_factor_range) == 2,\n                \"`zoom_factor_range` must be a list or a tuple of length 2.\",\n                ValueError)\n        self.min = zoom_factor_range[0]\n        self.max = zoom_factor_range[1]\n        self.interpolation_order = interpolation_order\n        self.p = p\n        self.same_zoom = same_zoom\n        self.zoom_kwargs = {} if zoom_kwargs is None else dict(zoom_kwargs)\n\n    def build_random_variables(self, **kwargs):\n        np.random.seed()\n        self.set_random_variable('do_z', np.random.uniform() < self.p)\n        self.set_random_variable('do_y', np.random.uniform() < 
self.p)\n        self.set_random_variable('do_x', np.random.uniform() < self.p)\n        self.set_random_variable('zoom_z', np.random.uniform(self.min, self.max))\n        self.set_random_variable('zoom_y', np.random.uniform(self.min, self.max))\n        self.set_random_variable('zoom_x', np.random.uniform(self.min, self.max))\n\n    def volume_function(self, volume):\n        zoom_z = self.get_random_variable('zoom_z') \\\n            if self.get_random_variable('do_z') else 1\n        zoom_y = self.get_random_variable('zoom_y') \\\n            if self.get_random_variable('do_y') else 1\n        zoom_x = self.get_random_variable('zoom_x') \\\n            if self.get_random_variable('do_x') else 1\n\n        if self.same_zoom:\n            zoom_y, zoom_x = zoom_z, zoom_z\n\n        zoomed_volume = zoom(volume, (zoom_z, zoom_y, zoom_x),\n                             order=self.interpolation_order, **self.zoom_kwargs)\n        return zoomed_volume\n\n\nclass RandomBinaryMorphology3D(Transform):\n    \"\"\"\n    Apply a random binary morphology operation  (dilation or erosion).\n    Allowed range of iteration number can be set.\n    \"\"\"\n    def __init__(self, p=0.5, num_iter_range=(1, 5), morphology_kwargs=None, **super_kwargs):\n        \"\"\"\n        Parameters\n        ----------\n        p : float\n            Probability that any operation is applied\n        num_iter_range : list or tuple\n            The allowed range of iteration number to apply the operation for.\n        morphology_kwargs: dict\n            Keyword arguments to the morphology function\n            (i.e. 
`scipy.ndimage.morphology.binary_dilation` or\n            `scipy.ndimage.morphology.binary_erosion`)\n        super_kwargs : dict\n            Keyword arguments to the superclass.\n        \"\"\"\n        super(RandomBinaryMorphology3D, self).__init__(**super_kwargs)\n        assert_(len(num_iter_range) == 2,\n                \"`num_iter_range` must be a list or a tuple of length 2.\",\n                ValueError)\n        self.p = p\n        self.min_iter = num_iter_range[0]\n        self.max_iter = num_iter_range[1] + 1\n        self.morphology_kwargs = {} if morphology_kwargs is None else dict(morphology_kwargs)\n\n    def build_random_variables(self, **kwargs):\n        np.random.seed()\n        self.set_random_variable('do', np.random.uniform() < self.p)\n        self.set_random_variable('erode', np.random.uniform() < 0.5)\n        self.set_random_variable('iter_num', np.random.randint(self.min_iter, self.max_iter))\n\n    def volume_function(self, volume):\n        do = self.get_random_variable('do')\n        erode_mode = self.get_random_variable('erode')\n        iter_num = self.get_random_variable('iter_num')\n\n        if do:\n            if erode_mode:\n                transformed_volume = binary_erosion(volume, iterations=iter_num,\n                                                    **self.morphology_kwargs)\n            else:\n                transformed_volume = binary_dilation(volume, iterations=iter_num,\n                                                     **self.morphology_kwargs)\n            volume = transformed_volume.astype(volume.dtype)\n\n        return volume\n\n\nclass CropPad2Divisible(Transform):\n    \"\"\"\n    Given the number, symmetrically crops/pads the volume\n    for all dimensions to be divisible by this number.\n    Used e.g. 
to feed input with any shape to models with pooling layers.\n    The threshold of cropping vs padding can be specified.\n    \"\"\"\n    def __init__(self, divisor=16, crop_pad_threshold=0.2,\n                 mode='constant', padding_kwargs=None, **super_kwargs):\n        \"\"\"\n        Parameters\n        ----------\n        divisor : int\n            A number that all dimensions should be divisible by\n        crop_pad_threshold : float\n            When \"division remainder to divisor\" ratio is lower then this number,\n            input volume will be cropped, otherwise - padded.\n            Set to 0 to only pad and 1 to only crop.\n        mode: ‘constant’, ‘edge’, ‘symmetric’, etc\n            See all the possible modes in numpy.pad doc\n        padding_kwargs: dict\n            Keyword arguments to numpy.pad\n        super_kwargs : dict\n            Keyword arguments to the superclass.\n        \"\"\"\n        super(CropPad2Divisible, self).__init__(**super_kwargs)\n        assert_(0 <= crop_pad_threshold <= 1,\n                \"threshold must be between 0 and 1 inclusive\",\n                ValueError)\n        assert_(divisor % 2 == 0, \"divisor must be an even number\", ValueError)\n        self.divisor = divisor\n        self.crop_pad_threshold = crop_pad_threshold\n        self.mode = mode\n        self.padding_kwargs = {} if padding_kwargs is None else dict(padding_kwargs)\n\n    def volume_function(self, volume):\n        half_div = int(self.divisor/2)\n        remainders = [axis % self.divisor for axis in volume.shape]\n        to_pad = [remainder/self.divisor >= self.crop_pad_threshold\n                  for remainder in remainders]\n        diffs = [(int(np.floor(remainder/2)), int(np.ceil(remainder/2)))\n                 for remainder in remainders]\n        padding = [(half_div - diff[0], half_div - diff[1])\n                   if pad else (0, 0)\n                   for diff, pad in zip(diffs, to_pad)]\n        cropping = [slice(diff[0], 
-diff[1])\n                    if not (pad or diff[1] == 0) else slice(None, None)\n                    for diff, pad in zip(diffs, to_pad)]\n        volume = np.pad(volume, pad_width=padding, mode=self.mode, **self.padding_kwargs)\n        volume = volume[cropping]\n\n        return volume\n\n\nclass CropPad2Size(Transform):\n    \"\"\"\n    Adjust the input volume to the given size:\n    Symmetrically crops if input > size, symmetrically pads if input < size.\n    \"\"\"\n    def __init__(self, output_size, mode='constant',\n                 padding_kwargs=None, **super_kwargs):\n        \"\"\"\n        Parameters\n        ----------\n        output_size : int, tuple or list\n            The output size. If int, the same value is used for all axes\n        mode: `constant`, `edge`, `symmetric`, etc\n            See all the possible modes in numpy.pad doc\n        padding_kwargs: dict\n            Keyword arguments to numpy.pad\n        super_kwargs : dict\n            Keyword arguments to the superclass.\n        \"\"\"\n        super(CropPad2Size, self).__init__(**super_kwargs)\n        self.output_size = output_size if isinstance(output_size, (list, tuple)) \\\n                                       else (output_size, ) * 3\n        assert len(self.output_size) == 3, 'The size should be given for all the dimensions'\n        self.mode = mode\n        self.padding_kwargs = {} if padding_kwargs is None else dict(padding_kwargs)\n\n    def volume_function(self, volume):\n        difference = [inp - outp for inp, outp in zip(volume.shape, self.output_size)]\n        to_pad = [diff < 0 for diff in difference]\n        to_crop = [diff > 0 for diff in difference]\n        diffs = [(int(np.floor(diff/2)), int(np.ceil(diff/2)))\n                 for diff in np.abs(difference)]\n        padding = [(diff[0], diff[1]) if pad else (0, 0)\n                   for diff, pad in zip(diffs, to_pad)]\n        cropping = [slice(diff[0], -diff[1]) if crop else slice(None, None)\n    
                for diff, crop in zip(diffs, to_crop)]\n        volume = np.pad(volume, pad_width=padding, mode=self.mode, **self.padding_kwargs)\n        volume = volume[cropping]\n\n        return volume\n"
  },
  {
    "path": "inferno/io/volumetric/__init__.py",
    "content": "from .volume import VolumeLoader, HDF5VolumeLoader, TIFVolumeLoader\rfrom .lazy_volume_loader import LazyHDF5VolumeLoader, LazyZarrVolumeLoader, LazyN5VolumeLoader\r"
  },
  {
    "path": "inferno/io/volumetric/lazy_volume_loader.py",
    "content": "import numpy as np\nimport os\nimport pickle\nfrom concurrent import futures\n\n# try to load io libraries (h5py and z5py)\ntry:\n    import h5py\n    WITH_H5PY = True\nexcept ImportError:\n    WITH_H5PY = False\n\ntry:\n    import z5py\n    WITH_Z5PY = True\nexcept ImportError:\n    WITH_Z5PY = False\n\nfrom ..core.base import SyncableDataset\nfrom ..core.base import IndexSpec\nfrom . import volumetric_utils as vu\nfrom ...utils import python_utils as pyu\n\n\n# TODO support h5py as well\ndef filter_base_sequence(input_path, input_key,\n                         window_size, stride,\n                         filter_function, n_threads):\n    with z5py.File(input_path, 'r') as f:\n        ds = f[input_key]\n        shape = list(ds.shape)\n        sequence = vu.slidingwindowslices(shape=shape,\n                                          window_size=window_size,\n                                          strides=stride,\n                                          shuffle=True,\n                                          add_overhanging=True)\n\n        def check_slice(slice_id, slice_):\n            print(\"Checking slice_id\", slice_id)\n            data = ds[slice_]\n            if filter_function(data):\n                return None\n            else:\n                return slice_\n\n        with futures.ThreadPoolExecutor(n_threads) as tp:\n            tasks = [tp.submit(check_slice, slice_id, slice_) for slice_id, slice_ in enumerate(sequence)]\n            filtered_sequence = [t.result() for t in tasks]\n\n        filtered_sequence = [seq for seq in filtered_sequence if seq is not None]\n        return filtered_sequence\n\n\nclass LazyVolumeLoaderBase(SyncableDataset):\n    def __init__(self, dataset, window_size, stride, downsampling_ratio=None, padding=None,\n                 padding_mode='reflect', transforms=None, return_index_spec=False, name=None,\n                 data_slice=None, base_sequence=None):\n        super(LazyVolumeLoaderBase, 
self).__init__()\n        assert len(window_size) == dataset.ndim, \"%i, %i\" % (len(window_size), dataset.ndim)\n        assert len(stride) == dataset.ndim\n        # Validate transforms\n        assert transforms is None or callable(transforms)\n\n        self.name = name\n        self.return_index_spec = return_index_spec\n        self.dataset = dataset\n        self.window_size = window_size\n        self.stride = stride\n        self.padding_mode = padding_mode\n        self.transforms = transforms\n        # slicing and padding\n        self.data_slice = self.normalize_slice(data_slice)\n        self.padding = padding\n        # DataloaderIter should do the shuffling\n        self.shuffle = False\n\n        # compute the shape\n        self.shape = self.get_shape()\n        self._data_shape = tuple(dsl.stop - dsl.start for dsl in self.data_slice)\\\n            if self.data_slice is not None else self.dataset.shape\n\n        if downsampling_ratio is None:\n            self.downsampling_ratio = [1] * self.dataset.ndim\n        elif isinstance(downsampling_ratio, int):\n            self.downsampling_ratio = [downsampling_ratio] * self.dataset.ndim\n        elif isinstance(downsampling_ratio, (list, tuple)):\n            assert len(downsampling_ratio) == self.dataset.ndim\n            self.downsampling_ratio = list(downsampling_ratio)\n        else:\n            raise NotImplementedError\n\n        if base_sequence is None:\n            self.base_sequence = self.make_sliding_windows()\n        else:\n            self.base_sequence = self.load_base_sequence(base_sequence)\n\n    @staticmethod\n    def load_base_sequence(base_sequence):\n        if isinstance(base_sequence, (list, tuple)):\n            return base_sequence\n        elif isinstance(base_sequence, str):\n            assert os.path.exists(base_sequence)\n            with open(base_sequence, 'rb') as f:\n                base_sequence = pickle.load(f)\n            return base_sequence\n        else:\n 
           raise ValueError(\"Unsupported base_sequence format, must be either listlike or str\")\n\n    def normalize_slice(self, data_slice):\n        if data_slice is None:\n            return None\n        slice_ = tuple(slice(0 if sl.start is None else sl.start,\n                             sh if sl.stop is None else sl.stop)\n                       for sl, sh in zip(data_slice, self.dataset.shape))\n        if len(slice_) < self.dataset.ndim:\n            slice_ = slice_ + tuple(slice(0, sh) for sh in self.dataset.shape[len(slice_):])\n        return slice_\n\n    # get the effective shape after slicing and / or padding\n    def get_shape(self):\n        if self.data_slice is None:\n            shape = self.dataset.shape\n        else:\n            # get the shape from the data slice (don't support ellipses)\n            shape = tuple(slice_.stop - slice_.start for slice_ in self.data_slice)\n        if self.padding is not None:\n            # TODO is this correct ???\n            shape = tuple(sh + sum(pad) for sh, pad in zip(shape, self.padding))\n        return shape\n\n    def make_sliding_windows(self):\n        return list(vu.slidingwindowslices(shape=list(self.shape),\n                                           window_size=self.window_size,\n                                           strides=self.stride,\n                                           shuffle=self.shuffle,\n                                           add_overhanging=True,\n                                           ds=self.downsampling_ratio))\n\n    def __getitem__(self, index):\n        # Casting to int would allow index to be IndexSpec objects.\n        index = int(index)\n        slices = self.base_sequence[index]\n\n        slices_ = tuple(slices)\n\n        # check if we have padding and if we need to pad\n        if self.padding is not None:\n\n            # get the start and stop positions in the dataset without padding\n            starts = [sl.start - pad[0] for sl, pad in 
zip(slices_, self.padding)]\n            stops = [sl.stop - pad[0] for sl, pad in zip(slices_, self.padding)]\n\n            # check if we need to pad to the left\n            pad_left = None\n            if any(start < 0 for start in starts):\n                pad_left = tuple(abs(start) if start < 0 else 0 for start in starts)\n                starts = [max(0, start) for start in starts]\n\n            # check if we need to pad to the right\n            pad_right = None\n            if any(stop > sh for stop, sh in zip(stops, self._data_shape)):\n                pad_right = tuple(stop - sh if stop > sh else 0\n                                  for stop, sh in zip(stops, self._data_shape))\n                stops = [min(sh, stop) for sh, stop in zip(self._data_shape, stops)]\n\n            # check if we need any paddingand if so calculate the padding width\n            need_padding = pad_left is not None or pad_right is not None\n            if need_padding:\n                # check the pad width (left and right) that we need for this batch\n                pad_left = (0,) * len(self.shape) if pad_left is None else pad_left\n                pad_right = (0,) * len(self.shape) if pad_right is None else pad_right\n                pad_width = tuple((pl, pr) for pl, pr in zip(pad_left, pad_right))\n\n            # update the slicing\n            slices_ = tuple(slice(start, stop) for start, stop in zip(starts, stops))\n        else:\n            need_padding = False\n\n        # if we have data-slices, we need to bring\n        # the slices back to the volume space\n        if self.data_slice is not None:\n            slices_ = tuple(slice(sl.start + dsl.start, sl.stop + dsl.start)\n                            for sl, dsl in zip(slices_, self.data_slice))\n\n        # load the slice and pad if necessary\n        sliced_volume = self.dataset[slices_]\n        if need_padding:\n            sliced_volume = np.pad(sliced_volume, pad_width=pad_width,\n                         
          mode=self.padding_mode)\n\n        if self.transforms is None:\n            transformed = sliced_volume\n        else:\n            transformed = self.transforms(sliced_volume)\n        if self.return_index_spec:\n            return transformed, IndexSpec(index=index, base_sequence_at_index=slices)\n        else:\n            return transformed\n\n    def clone(self, dataset=None, transforms=None, name=None):\n        # Make sure the dataset shapes check out\n        assert dataset.shape == self.dataset.shape\n        # Make a new instance (without initializing)\n        new = type(self).__new__(type(self))\n        # Update dictionary to initialize\n        new_dict = dict(self.__dict__)\n        if dataset is not None:\n            new_dict.update({'dataset': dataset})\n        if transforms is not None:\n            new_dict.update({'transforms': transforms})\n        if name is not None:\n            new_dict.update({'name': name})\n        new.__dict__.update(new_dict)\n        return new\n\n    def __repr__(self):\n        return \"{}(shape={}, name={})\".format(type(self).__name__, self.dataset.shape, self.name)\n\n\n# baseclass for hdf5, zarr or n5 volume loaders\nclass LazyVolumeLoader(LazyVolumeLoaderBase):\n    def __init__(self, file_impl, path,\n                 path_in_file=None, data_slice=None, transforms=None,\n                 name=None, **slicing_config):\n\n        if isinstance(path, dict):\n            assert name is not None\n            assert name in path\n            self.path = path.get(name)\n        elif isinstance(path, str):\n            assert os.path.exists(path), path\n            self.path = path\n        else:\n            raise NotImplementedError(\"Not implemented for type %s\" % type(path))\n\n        if isinstance(path_in_file, dict):\n            assert name is not None\n            assert name in path_in_file\n            self.path_in_file = path_in_file.get(name)\n        elif isinstance(path_in_file, str):\n     
       self.path_in_file = path_in_file\n        elif path_in_file is None:\n            self.path_in_file = None\n        else:\n            raise NotImplementedError\n\n        if data_slice is None or isinstance(data_slice, (str, list, tuple)):\n            data_slice = vu.parse_data_slice(data_slice)\n        elif isinstance(data_slice, dict):\n            assert name is not None\n            assert name in data_slice\n            data_slice = vu.parse_data_slice(data_slice.get(name))\n        else:\n            raise NotImplementedError\n        self.validate_data_slice(data_slice)\n\n        slicing_config_for_name = pyu.get_config_for_name(slicing_config, name)\n\n        assert 'window_size' in slicing_config_for_name\n        assert 'stride' in slicing_config_for_name\n\n        self.file_ = file_impl(self.path, mode='r')\n        # Initialize superclass with the volume\n        super(LazyVolumeLoader, self).__init__(dataset=self.file_[self.path_in_file], name=name,\n                                               transforms=transforms, data_slice=data_slice,\n                                               **slicing_config_for_name)\n\n    # we do not support step in the dataslice\n    def validate_data_slice(self, data_slice):\n        if data_slice is not None:\n            assert all(sl.step in (None, 1) for sl in data_slice), \"Complicated step is not supported\"\n\n\nclass LazyHDF5VolumeLoader(LazyVolumeLoader):\n    def __init__(self, path, path_in_h5_dataset=None, data_slice=None, transforms=None,\n                 name=None, **slicing_config):\n        assert WITH_H5PY, \"Need h5py to load volume from hdf5 file.\"\n        super(LazyHDF5VolumeLoader, self).__init__(file_impl=h5py.File, path=path,\n                                                   path_in_file=path_in_h5_dataset,\n                                                   data_slice=data_slice, transforms=transforms,\n                                                   name=name, 
**slicing_config)\n\n    # this is not pythonic, but we need to close the h5py file\n    def __del__(self):\n        self.file_.close()\n\n\nclass LazyN5VolumeLoader(LazyVolumeLoader):\n    def __init__(self, path, path_in_file=None, data_slice=None, transforms=None,\n                 name=None, **slicing_config):\n        assert WITH_Z5PY, \"Need z5py to load volume from N5 file.\"\n        assert slicing_config.get('downsampling_ratio', None) is None,\\\n            \"Downsampling is not supported by z5py based loaderes\"\n        super(LazyN5VolumeLoader, self).__init__(file_impl=z5py.N5File, path=path,\n                                                 path_in_file=path_in_file,\n                                                 data_slice=data_slice,\n                                                 transforms=transforms,\n                                                 name=name, **slicing_config)\n\n\nclass LazyZarrVolumeLoader(LazyVolumeLoader):\n    def __init__(self, path, path_in_file=None, data_slice=None, transforms=None,\n                 name=None, **slicing_config):\n        assert WITH_Z5PY, \"Need z5py to load volume from zarr file.\"\n        assert slicing_config.get('downsampling_ratio', None) is None,\\\n            \"Downsampling is not supported by z5py based loaderes\"\n        super(LazyZarrVolumeLoader, self).__init__(file_impl=z5py.ZarrFile, path=path,\n                                                   path_in_file=path_in_file,\n                                                   data_slice=data_slice,\n                                                   transforms=transforms,\n                                                   name=name, **slicing_config)\n"
  },
  {
    "path": "inferno/io/volumetric/volume.py",
    "content": "import numpy as np\nimport os\nimport skimage.io\n\nfrom ..core.base import SyncableDataset\nfrom ..core.base import IndexSpec\nfrom . import volumetric_utils as vu\nfrom ...utils import io_utils as iou\nfrom ...utils import python_utils as pyu\nfrom ...utils.exceptions import assert_, ShapeError\n\n\nclass VolumeLoader(SyncableDataset):\n    \"\"\" Loader for in-memory volumetric data.\n\n    Parameters\n    ----------\n    volume: np.ndarray\n        the volumetric data\n    window_size: list or tuple\n        size of the (3d) sliding window used for iteration\n    stride: list or tuple\n        stride of the (3d) sliding window used for iteration\n    downsampling_ratio: list or tuple (default: None)\n        factor by which the data is downsampled (no downsapling by default)\n    padding: list (default: None)\n        padding for data, follows np.pad syntax\n    padding_mode: str (default: 'reflect')\n        padding mode as in np.pad\n    transforms: callable (default: None)\n       transforms applied on each batch loaded from volume\n    return_index_spec: bool (default: False)\n        whether to return the index spec for each batch\n    name: str (default: None)\n        name of this volume\n    is_multichannel: bool (default: False)\n        is this a multichannel volume? 
sliding window is NOT applied to channel dimension\n    \"\"\"\n\n    def __init__(self, volume, window_size, stride, downsampling_ratio=None, padding=None,\n                 padding_mode='reflect', transforms=None, return_index_spec=False, name=None,\n                 is_multichannel=False):\n        super(VolumeLoader, self).__init__()\n        # Validate volume\n        assert isinstance(volume, np.ndarray), str(type(volume))\n        # Validate window size and stride\n        if is_multichannel:\n            assert_(len(window_size) + 1 == volume.ndim, \"%i, %i\" % (len(window_size),\n                                                                     volume.ndim),\n                                                                    ShapeError)\n            assert_(len(stride) + 1 == volume.ndim, exception_type=ShapeError)\n            # TODO implemnent downsampling and padding for multi-channel volume\n            assert_(downsampling_ratio is None, exception_type=NotImplementedError)\n            assert_(padding is None, exception_type=NotImplementedError)\n        else:\n            assert_(len(window_size) == volume.ndim, \"%i, %i\" % (len(window_size),\n                                                                 volume.ndim),\n                                                                ShapeError)\n            assert_(len(stride) == volume.ndim, exception_type=ShapeError)\n        # Validate transforms\n        assert_(transforms is None or callable(transforms))\n\n        self.name = name\n        self.return_index_spec = return_index_spec\n        self.volume = volume\n        self.window_size = window_size\n        self.stride = stride\n        self.padding_mode = padding_mode\n        self.is_multichannel = is_multichannel\n        self.transforms = transforms\n        # DataloaderIter should do the shuffling\n        self.shuffle = False\n\n        ndim = self.volume.ndim - 1 if is_multichannel else self.volume.ndim\n\n        if 
downsampling_ratio is None:\n            self.downsampling_ratio = [1] * ndim\n        elif isinstance(downsampling_ratio, int):\n            self.downsampling_ratio = [downsampling_ratio] * self.volume.ndim\n        elif isinstance(downsampling_ratio, (list, tuple)):\n            assert_(len(downsampling_ratio) == self.volume.ndim, exception_type=ShapeError)\n            self.downsampling_ratio = list(downsampling_ratio)\n        else:\n            raise NotImplementedError\n\n        if padding is None:\n            self.padding = [[0, 0]] * ndim\n        else:\n            self.padding = padding\n            self.pad_volume()\n\n        self.base_sequence = self.make_sliding_windows()\n\n    def pad_volume(self, padding=None):\n        padding = self.padding if padding is None else padding\n        if padding is None:\n            return self.volume\n        else:\n            #for symmertic padding only one int can be passed for each axis\n            assert_(all(isinstance(pad, (int, tuple, list)) for pad in self.padding),\\\n                \"Expect int or iterable\", TypeError)\n            self.padding = [[pad, pad] if isinstance(pad, int) else pad for pad in self.padding]\n            self.volume = np.pad(self.volume,\n                                 pad_width=self.padding,\n                                 mode=self.padding_mode)\n            return self.volume\n\n    def make_sliding_windows(self):\n        shape = self.volume.shape[1:] if self.is_multichannel else self.volume.shape\n        return list(vu.slidingwindowslices(shape=list(shape),\n                                           window_size=self.window_size,\n                                           strides=self.stride,\n                                           shuffle=self.shuffle,\n                                           add_overhanging=True,\n                                           ds=self.downsampling_ratio))\n\n    def __getitem__(self, index):\n        # Casting to int would 
allow index to be IndexSpec objects.\n        index = int(index)\n        slices = self.base_sequence[index]\n        if self.is_multichannel:\n            slices = (slice(None),) + tuple(slices)\n        sliced_volume = self.volume[tuple(slices)]\n        if self.transforms is None:\n            transformed = sliced_volume\n        else:\n            transformed = self.transforms(sliced_volume)\n        if self.return_index_spec:\n            return transformed, IndexSpec(index=index, base_sequence_at_index=slices)\n        else:\n            return transformed\n\n    def clone(self, volume=None, transforms=None, name=None):\n        # Make sure the volume shapes check out\n        assert_(volume.shape == self.volume.shape, exception_type=ShapeError)\n        # Make a new instance (without initializing)\n        new = type(self).__new__(type(self))\n        # Update dictionary to initialize\n        new_dict = dict(self.__dict__)\n        if volume is not None:\n            new_dict.update({'volume': volume})\n        if transforms is not None:\n            new_dict.update({'transforms': transforms})\n        if name is not None:\n            new_dict.update({'name': name})\n        new.__dict__.update(new_dict)\n        return new\n\n    def __repr__(self):\n        return \"{}(shape={}, name={})\".format(type(self).__name__, self.volume.shape, self.name)\n\n\nclass HDF5VolumeLoader(VolumeLoader):\n    \"\"\" Loader for volumes stored in hdf5, zarr or n5.\n\n    Zarr and n5 are file formats very similar to hdf5, but use\n    the regular filesystem to store data instead of a filesystem\n    in a file as hdf5.\n    The file type will be infered from the extension:\n    .hdf5, .h5 and .hdf map to hdf5\n    .n5 maps to n5\n    .zr and .zarr map to zarr\n    It will fail for other extensions.\n\n    Parameters\n    ----------\n    path: str\n        path to file\n    path_in_h5_dataset: str (default: None)\n        path in file\n    data_slice: slice (default: None)\n 
       slice loaded from dataset\n    transforms: callable (default: None)\n       transforms applied on each batch loaded from volume\n    name: str (default: None)\n        name of this volume\n    slicing_config: kwargs\n        keyword arguments for base class `VolumeLoader`\n    \"\"\"\n\n    @staticmethod\n    def is_h5(file_path):\n        ext = os.path.splitext(file_path)[1].lower()\n        if ext in ('.h5', '.hdf', '.hdf5'):\n            return True\n        elif ext in ('.zarr', '.zr', '.n5'):\n            return False\n        else:\n            raise RuntimeError(\"Could not infer volume type for file extension %s\" % ext)\n\n    def __init__(self, path, path_in_h5_dataset=None, data_slice=None, transforms=None,\n                 name=None, **slicing_config):\n\n        if isinstance(path, dict):\n            assert name is not None\n            assert name in path\n            self.path = path.get(name)\n        elif isinstance(path, str):\n            assert os.path.exists(path), path\n            self.path = path\n        else:\n            raise NotImplementedError\n\n        if isinstance(path_in_h5_dataset, dict):\n            assert name is not None\n            assert name in path_in_h5_dataset\n            self.path_in_h5_dataset = path_in_h5_dataset.get(name)\n        elif isinstance(path_in_h5_dataset, str):\n            self.path_in_h5_dataset = path_in_h5_dataset\n        elif path_in_h5_dataset is None:\n            self.path_in_h5_dataset = None\n        else:\n            raise NotImplementedError\n\n        # get the dataslice\n        if data_slice is None or isinstance(data_slice, (str, list)):\n            self.data_slice = vu.parse_data_slice(data_slice)\n        elif isinstance(data_slice, dict):\n            assert name is not None\n            assert name in data_slice\n            self.data_slice = vu.parse_data_slice(data_slice.get(name))\n        else:\n            raise NotImplementedError\n\n        slicing_config_for_name 
= pyu.get_config_for_name(slicing_config, name)\n\n        # adapt data-slice if this is a multi-channel volume (slice is not applied to channel dimension)\n        if self.data_slice is not None and slicing_config_for_name.get('is_multichannel', False):\n            self.data_slice = (slice(None),) + self.data_slice\n\n        assert 'window_size' in slicing_config_for_name, str(slicing_config_for_name)\n        assert 'stride' in slicing_config_for_name\n\n        # Read in volume from file (can be hdf5, n5 or zarr)\n        if self.is_h5(self.path):\n            volume = iou.fromh5(self.path, self.path_in_h5_dataset,\n                                dataslice=self.data_slice)\n        else:\n            volume = iou.fromz5(self.path, self.path_in_h5_dataset,\n                                dataslice=self.data_slice)\n        # Initialize superclass with the volume\n        super(HDF5VolumeLoader, self).__init__(volume=volume, name=name, transforms=transforms,\n                                               **slicing_config_for_name)\n\n\nclass TIFVolumeLoader(VolumeLoader):\n    \"\"\"Loader for volumes stored in .tif files.\"\"\"\n    def __init__(self, path, data_slice=None, transforms=None, name=None, **slicing_config):\n        \"\"\"\n        Parameters\n        ----------\n        path : str\n            Path to the volume.\n        transforms : callable\n            Transforms to apply on the read volume.\n        slicing_config : dict\n            Dictionary specifying the sliding window. 
Must contain keys 'window_size'\n            and 'stride'.\n        \"\"\"\n        if isinstance(path, dict):\n            assert name in path.keys()\n            assert os.path.exists(path.get(name))\n            self.path = path.get(name)\n        elif isinstance(path, str):\n            assert os.path.exists(path)\n            self.path = path\n        else:\n            raise NotImplementedError\n\n        assert 'window_size' in slicing_config\n        assert 'stride' in slicing_config\n\n        if data_slice is None or isinstance(data_slice, (str, list)):\n            self.data_slice = vu.parse_data_slice(data_slice)\n        elif isinstance(data_slice, dict):\n            assert name is not None\n            assert name in data_slice\n            self.data_slice = vu.parse_data_slice(data_slice.get(name))\n        else:\n            raise NotImplementedError\n\n        # Read in volume from file\n        volume = skimage.io.imread(self.path)\n        # and slice it\n        volume = volume[self.data_slice] if self.data_slice is not None else volume\n        # Initialize superclass with the volume\n        super(TIFVolumeLoader, self).__init__(volume=volume, transforms=transforms,\n                                              **slicing_config)\n"
  },
  {
    "path": "inferno/io/volumetric/volumetric_utils.py",
    "content": "import random\nimport itertools as it\n\n\ndef slidingwindowslices(shape, window_size, strides,\n                        ds=1, shuffle=True, rngseed=None,\n                        dataslice=None, add_overhanging=True):\n    # only support lists or tuples for shape, window_size and strides\n    assert isinstance(shape, (list, tuple))\n    assert isinstance(window_size, (list, tuple)), \"%s\" % (str(type(window_size)))\n    assert isinstance(strides, (list, tuple))\n\n    dim = len(shape)\n    assert len(window_size) == dim\n    assert len(strides) == dim\n\n    # check for downsampling\n    assert isinstance(ds, (list, tuple, int))\n    if isinstance(ds, int):\n        ds = [ds] * dim\n    assert len(ds) == dim\n\n    # Seed RNG if a seed is provided\n    if rngseed is not None:\n        random.seed(rngseed)\n\n    # sliding windows in one dimenstion\n    def dimension_window(start, stop, wsize, stride, dimsize, ds_dim):\n        starts = range(start, stop + 1, stride)\n        slices = [slice(st, st + wsize, ds_dim) for st in starts if st + wsize <= dimsize]\n\n        # add an overhanging window at the end if the windoes\n        # do not fit and `add_overhanging`\n        if slices[-1].stop != dimsize and add_overhanging:\n            slices.append(slice(dimsize - wsize, dimsize, ds_dim))\n\n        if shuffle:\n            random.shuffle(slices)\n        return slices\n\n    # determine adjusted start and stop coordinates if we have a dataslice\n    # otherwise predict the whole volume\n    if dataslice is not None:\n        assert len(dataslice) == dim, \"Dataslice must be a tuple with len = data dimension.\"\n        starts = [0 if sl.start is None else sl.start for sl in dataslice]\n        stops = [sh - wsize if sl.stop is None else sl.stop - wsize\n                 for sl, wsize, sh in zip(dataslice, window_size, shape)]\n    else:\n        starts = dim * [0]\n        stops = [dimsize - wsize if wsize != dimsize else dimsize\n                
 for dimsize, wsize in zip(shape, window_size)]\n\n    assert all(stp > strt for strt, stp in zip(starts, stops)),\\\n        \"%s, %s\" % (str(starts), str(stops))\n    nslices = [dimension_window(start, stop, wsize, stride, dimsize, ds_dim)\n               for start, stop, wsize, stride, dimsize, ds_dim\n               in zip(starts, stops, window_size, strides, shape, ds)]\n    return it.product(*nslices)\n\n\n# This code is legacy af, don't judge\n# Define a sliding window iterator (this time, more readable than a wannabe one-liner)\ndef slidingwindowslices_depr(shape, nhoodsize, stride=1, ds=1, window=None, ignoreborder=True,\n                             shuffle=True, rngseed=None,\n                             startmins=None, startmaxs=None, dataslice=None):\n    \"\"\"\n    Returns a generator yielding (shuffled) sliding window slice objects.\n    :type shape: int or list of int\n    :param shape: Shape of the input data\n    :type nhoodsize: int or list of int\n    :param nhoodsize: Window size of the sliding window.\n    :type stride: int or list of int\n    :param stride: Stride of the sliding window.\n    :type shuffle: bool\n    :param shuffle: Whether to shuffle the iterator.\n    \"\"\"\n\n    # Determine dimensionality of the data\n    datadim = len(shape)\n\n    # Parse window\n    if window is None:\n        window = ['x'] * datadim\n    else:\n        assert len(window) == datadim, \\\n            \"Window must have the same length as the number of data dimensions.\"\n\n    # Parse nhoodsize and stride\n    nhoodsize = [nhoodsize, ] * datadim if isinstance(nhoodsize, int) else nhoodsize\n    stride = [stride, ] * datadim if isinstance(stride, int) else stride\n    ds = [ds, ] * datadim if isinstance(ds, int) else ds\n\n    # Seed RNG if a seed is provided\n    if rngseed is not None:\n        random.seed(rngseed)\n\n    # Define a function that gets a 1D slice\n    def _1Dwindow(startmin, startmax, nhoodsize, stride, ds, seqsize, shuffle):\n      
  starts = range(startmin, startmax + 1, stride)\n\n        if ignoreborder:\n            slices = [slice(st, st + nhoodsize, ds) for st in starts if st + nhoodsize <= seqsize]\n        else:\n            slices = [slice(st, ((st + nhoodsize) if st + nhoodsize <= seqsize else None), ds)\n                      for st in starts]\n\n        if shuffle:\n            random.shuffle(slices)\n        return slices\n\n    # Get window start limits\n    if dataslice is None:\n        startmins = [0, ] * datadim if startmins is None else startmins\n        startmaxs = [shp - nhoodsiz for shp, nhoodsiz in zip(shape, nhoodsize)] \\\n            if startmaxs is None else startmaxs\n    else:\n        assert len(dataslice) == datadim, \\\n            \"Dataslice must be a tuple with len = data dimension.\"\n        startmins = [sl.start for sl in dataslice]\n        startmaxs = [sl.stop - nhoodsiz for sl, nhoodsiz in zip(dataslice, nhoodsize)]\n\n    def _to_list(x):\n        if not isinstance(x, (list, tuple)):\n            return list(x)\n        else:\n            return x\n\n    # The final iterator is going to be a cartesian product of the lists in nslices\n    nslices = [_1Dwindow(startmin, startmax, nhoodsiz, st, dsample, datalen, shuffle) if windowspec == 'x'\n               else [slice(ws, ws + 1) for ws in _to_list(windowspec)]\n               for startmin, startmax, datalen, nhoodsiz, st, windowspec, dsample in zip(startmins, startmaxs, shape,\n                                                                                         nhoodsize, stride, window, ds)]\n\n    return it.product(*nslices)\n\n\ndef parse_data_slice(data_slice):\n    \"\"\"Parse a dataslice as a list of slice objects.\"\"\"\n    if data_slice is None:\n        return data_slice\n    elif isinstance(data_slice, (list, tuple)) and \\\n            all([isinstance(_slice, slice) for _slice in data_slice]):\n        return tuple(data_slice)\n    else:\n        assert isinstance(data_slice, str)\n    
# Get rid of whitespace\n    data_slice = data_slice.replace(' ', '')\n    # Split by commas\n    dim_slices = data_slice.split(',')\n    # Build slice objects\n    slices = []\n    for dim_slice in dim_slices:\n        indices = dim_slice.split(':')\n        if len(indices) == 2:\n            start, stop, step = indices[0], indices[1], None\n        elif len(indices) == 3:\n            start, stop, step = indices\n        else:\n            raise RuntimeError\n        # Convert to ints\n        start = int(start) if start != '' else None\n        stop = int(stop) if stop != '' else None\n        step = int(step) if step is not None and step != '' else None\n        # Build slices\n        slices.append(slice(start, stop, step))\n    return tuple(slices)\n"
  },
  {
    "path": "inferno/trainers/__init__.py",
    "content": "from . import basic\nfrom . import callbacks\nfrom . basic import Trainer\n__all__ = ['basic','callbacks','Trainer']"
  },
  {
    "path": "inferno/trainers/basic.py",
    "content": "from datetime import datetime\nfrom inspect import signature\nimport os\nimport shutil\n\n# These are fetched from globals, they're not unused\n# noinspection PyUnresolvedReferences\nimport dill\n# noinspection PyUnresolvedReferences\nimport pickle\n\n\nimport torch\nfrom numpy import inf\nfrom torch.utils.data import DataLoader\nfrom torch.nn.parallel.data_parallel import data_parallel\nfrom .callbacks.logging.base import Logger\nfrom .callbacks.logging import get_logger\n\nfrom ..utils import train_utils as tu\nfrom ..utils import python_utils as pyu\nfrom ..utils import torch_utils as thu\nfrom ..extensions import metrics\nfrom ..extensions import optimizers\nfrom ..extensions import criteria\nfrom .callbacks import CallbackEngine\nfrom .callbacks import Console\nfrom ..utils.exceptions import assert_, NotSetError, NotTorchModuleError, DeviceError\n\n# NOTE for distributed training, we might also need\n# from apex.parallel import DistributedDataParallel as DDP\n# but I don't know where exactly to put it.\ntry:\n    from apex import amp\nexcept ImportError:\n    amp = None\n\n\nclass Trainer(object):\n    \"\"\"A basic trainer.\n\n    Given a torch model, this class encapsulates the training and validation loops,\n    checkpoint creation, logging, CPU <-> GPU transfers and managing data-loaders.\n\n    In addition, this class interacts with the callback engine (found at\n    `inferno.trainers.callbacks.base.CallbackEngine`), which manages callbacks at\n    certain preset events.\n\n    Notes\n    -----\n    Logging is implemented as a special callback, in the sense that it's jointly\n    managed by the this class and the callback engine. 
This is primarily because\n    general callbacks are not intended to be serializable, but not being able to\n    serialize the logger is a nuisance.\n    \"\"\"\n    def __init__(self, model=None):\n        \"\"\"\n        Parameters\n        ----------\n        model : torch.nn.Module\n            Torch model to bind to.\n        \"\"\"\n        # Privates\n        # Core\n        self._model = None\n        self._optimizer = None\n        self._criterion = None\n        self._retain_graph = False\n        self._backprop_every = 1\n\n        # Metric evaluation\n        self._metric = None\n        self._evaluate_metric_every = None\n        self._metric_evaluation_externally_triggered = False\n        self._last_metric_evaluated_at_epoch = 0\n\n        # Logging\n        self._logger = None\n        self._last_logged = {}\n        self._log_directory = {}\n\n        # Data logistics\n        self._loaders = {}\n        self._loader_iters = {}\n        self._loader_specs = {}\n\n        # Iteration and epoch book-keeping\n        self._iteration_count = 0\n        self._epoch_count = 0\n        self._batch_count = 0\n        self._current_mode = 'train'\n\n        # GPU and dtype business\n        self._use_cuda = False\n        self._dtype = 'float'\n        self._devices = None\n        self._base_device_ordinal = None\n\n        # Validation\n        self._save_at_best_validation_score = False\n        self._best_validation_score = None\n        self._is_iteration_with_best_validation_score = False\n        self._validate_every = None\n        self._num_validation_iterations = None\n        self._target_batch_dim = 0\n        self._validation_criterion = None\n        # We should exclude the zero-th epoch from validation\n        self._last_validated_at_epoch = 0\n        self._last_validated_at_iteration = 0\n        # This is to allow a callback to trigger a validation by setting\n        # trainer.validate_now = True\n        
self._validation_externally_triggered = False\n\n        # Checkpointing\n        self._save_every = None\n        self._save_to_directory = None\n        self._pickle_module = 'pickle'\n        # Defaults for file names\n        self._checkpoint_filename = 'checkpoint.pytorch'\n        self._best_checkpoint_filename = 'best_checkpoint.pytorch'\n\n        # Nothing to save at epoch 0\n        self._last_saved_at_epoch = 0\n        # This is to allow a callback to trigger a save by setting trainer.save_now = True\n        self._save_externally_triggered = False\n\n        # Stopping conditions\n        self._max_num_iterations = None\n        self._max_num_epochs = None\n\n        # Callbacks and states\n        self._callback_engine = CallbackEngine().bind_trainer(self)\n        self._state = {}\n\n        # Print console\n        self._console = Console()\n\n        # Train with mixed precision, only works\n        # if we have apex\n        self._mixed_precision = False\n        self._apex_opt_level = 'O1'\n\n        # Public\n        if model is not None:\n            self.model = model\n\n    @property\n    def mixed_precision(self):\n        return self._mixed_precision\n\n    # this needs to be called after model and optimizer are set\n    @mixed_precision.setter\n    def mixed_precision(self, mp):\n        if mp:\n            assert_(amp is not None, \"Cannot use mixed precision training without apex library\", RuntimeError)\n            assert_(self.model is not None and self._optimizer is not None,\n                    \"Model and optimizer need to be set before activating mixed precision\", RuntimeError)\n            # in order to support BCE loss\n            amp.register_float_function(torch, 'sigmoid')\n            # For now, we don't allow to set 'keep_batchnorm' and 'loss_scale'\n            self.model, self._optimizer = amp.initialize(self.model, self._optimizer,\n                                                         
opt_level=self._apex_opt_level,\n                                                         keep_batchnorm_fp32=None)\n        self._mixed_precision = mp\n\n    @property\n    def apex_opt_level(self):\n        return self._apex_opt_level\n\n    @apex_opt_level.setter\n    def apex_opt_level(self, opt_level):\n        assert_(opt_level in ('O0', 'O1', 'O2', 'O3'),\n                \"Invalid optimization level\", ValueError)\n        self._apex_opt_level = opt_level\n\n    @property\n    def console(self):\n        \"\"\"Get the current console.\"\"\"\n        return self._console\n\n    def set_console(self, console):\n        assert_(isinstance(console, Console), \"`console` must be a Console object.\", TypeError)\n        self._console = console\n        return self\n\n    def quiet(self):\n        self.console.toggle_progress(False)\n        return self\n\n    @property\n    def callbacks(self):\n        \"\"\"Gets the callback engine.\"\"\"\n        return self._callback_engine\n\n    def register_callback(self, callback, trigger='auto', **callback_kwargs):\n        \"\"\"\n        Registers a callback with the internal callback engine.\n\n        Parameters\n        ----------\n        callback : type or callable\n            Callback to register.\n        trigger : str\n            Specify the event that triggers the callback. Leave at 'auto' to have the\n            callback-engine figure out the triggers. 
See\n            `inferno.training.callbacks.base.CallbackEngine` documentation for more on this.\n        callback_kwargs : dict\n            If `callback` is a type, initialize an instance with these keywords to the\n            __init__ method.\n        Returns\n        -------\n        Trainer\n            self.\n        \"\"\"\n        if isinstance(callback, type):\n            callback = callback(**callback_kwargs)\n        self._callback_engine.register_callback(callback, trigger=trigger)\n        return self\n\n    @property\n    def model(self):\n        \"\"\"Gets the model.\"\"\"\n        assert_(self._model is not None, \"Model is not defined yet.\", NotSetError)\n        return self._model\n\n    @model.setter\n    def model(self, value):\n        self.bind_model(value)\n\n    def bind_model(self, model):\n        \"\"\"\n        Binds a model to the trainer. Equivalent to setting model.\n\n        Parameters\n        ----------\n        model : torch.nn.Module\n            Model to bind.\n\n        Returns\n        -------\n        Trainer\n            self.\n        \"\"\"\n        assert_(isinstance(model, torch.nn.Module),\n                \"Model must be a torch.nn.Module.\",\n                NotTorchModuleError)\n        self._model = model\n        # Transfer model to GPU if required\n        if self._use_cuda:\n            self._model.cuda()\n        return self\n\n    @property\n    def model_is_defined(self):\n        return self._model is not None\n\n    @property\n    def retain_graph(self):\n        return self._retain_graph\n\n    @retain_graph.setter\n    def retain_graph(self, value):\n        assert isinstance(value, bool)\n        self._retain_graph = value\n\n    @property\n    def backprop_every(self):\n        return self._backprop_every\n\n    @backprop_every.setter\n    def backprop_every(self, value):\n        self.set_backprop_every(value)\n\n    def set_backprop_every(self, num_steps):\n        \"\"\"\n        Set frequency 
of backpropagation.\n        To use in cases of small batch sizes.\n\n        Parameters\n        ----------\n        num_steps : number of steps (iterations/batches) to backprop after\n\n        Returns\n        -------\n        Trainer\n            self\n        \"\"\"\n        assert isinstance(num_steps, int)\n        self._backprop_every = num_steps\n        return self\n\n    @property\n    def optimizer(self):\n        \"\"\"Gets the optimizer.\"\"\"\n        assert_(self._optimizer is not None, \"Optimizer is not set yet.\", NotSetError)\n        return self._optimizer\n\n    @optimizer.setter\n    def optimizer(self, value):\n        if isinstance(value, str) or callable(value):\n            self.build_optimizer(value)\n        elif isinstance(value, dict):\n            self.build_optimizer(**value)\n        else:\n            raise NotImplementedError\n\n    @property\n    def optimizer_is_defined(self):\n        return self._optimizer is not None\n\n    def build_optimizer(self, method, param_groups=None, **kwargs):\n        \"\"\"\n        Builds the optimizer for training.\n\n        Parameters\n        ----------\n        method : str or callable or torch.optim.Optimizer\n            Name of the optimizer when str, handle to the optimizer class when callable,\n            or a torch.optim.Optimizer instance. If a name is provided, this method looks\n            for the optimizer in `torch.optim` module first and in\n            inferno.extensions.optimizers second.\n        param_groups : list of dict\n            Specifies the parameter group. 
Defaults to model.parameters() if None.\n        kwargs : dict\n            Keyword arguments to the optimizer.\n\n        Returns\n        -------\n        Trainer\n            self.\n\n        Raises\n        ------\n        AssertionError\n            if optimizer is not found\n        NotImplementedError\n            if method is not str or callable.\n        \"\"\"\n        if isinstance(method, str):\n            optimizer_class = getattr(torch.optim, method, None)\n            if optimizer_class is None:\n                # Look for optimizer in extensions\n                optimizer_class = getattr(optimizers, method, None)\n            assert optimizer_class is not None, \"Optimizer {} not found.\".format(method)\n        elif callable(method) and isinstance(method, type):\n            optimizer_class = method\n        elif isinstance(method, torch.optim.Optimizer):\n            self._optimizer = method\n            return self\n        else:\n            raise NotImplementedError\n        param_groups = self.model.parameters() if param_groups is None else param_groups\n        self._optimizer = optimizer_class(param_groups, **kwargs)\n        return self\n\n    @property\n    def criterion(self):\n        \"\"\"Gets the loss criterion.\"\"\"\n        assert_(self._criterion is not None, \"Criterion is not set yet.\", NotSetError)\n        return self._criterion\n\n    @criterion.setter\n    def criterion(self, value):\n        if isinstance(value, str) or callable(value):\n            self.build_criterion(value)\n        elif isinstance(value, dict):\n            self.build_criterion(**value)\n        else:\n            raise RuntimeError(f\"Criterion can either be set to a string, callable or a dict. 
\"\n                               f\"Got {type(value).__name__} instead.\")\n\n    def build_criterion(self, method, **kwargs):\n        \"\"\"\n        Builds the loss criterion for training.\n\n        Parameters\n        ----------\n        method : str or callable or torch.nn.Module\n            Name of the criterion when str, criterion class when callable, or a\n            torch.nn.Module instance. If a name is provided, this method looks\n            for the criterion in `torch.nn`.\n        kwargs : dict\n            Keyword arguments to the criterion class' constructor if applicable.\n\n        Returns\n        -------\n        Trainer\n            self.\n\n        Raises\n        ------\n        AssertionError\n            if criterion is not found.\n        NotImplementedError\n            if method is neither a str nor a callable.\n        \"\"\"\n        if isinstance(method, str):\n            # Look for criteria in torch\n            criterion_class = getattr(torch.nn, method, None)\n            if criterion_class is None:\n                # Look for it in extensions\n                criterion_class = getattr(criteria, method, None)\n            assert criterion_class is not None, \"Criterion {} not found.\".format(method)\n        elif callable(method) and isinstance(method, type):\n            criterion_class = method\n        elif isinstance(method, torch.nn.Module):\n            self._criterion = method\n            return self\n        else:\n            raise NotImplementedError\n        self._criterion = criterion_class(**kwargs)\n        # Transfer criterion to GPU if required. This is necessary for e.g. 
weighted loss,\n        # where the weight is registered as a buffer.\n        # The criterion is to be cuda'ed only if the model is on CUDA (self._use_cuda) and\n        # the base_device is not CPU (ordinal -1).\n        if hasattr(self, '_base_device_ordinal'):\n            # This is to not break old checkpoints\n            base_device_ordinal = self._base_device_ordinal\n        else:\n            base_device_ordinal = None\n        if self._use_cuda and base_device_ordinal != 1:\n            self._criterion.cuda()\n        return self\n\n    @property\n    def criterion_is_defined(self):\n        return self._criterion is not None\n\n    @property\n    def validation_criterion(self):\n        if self._validation_criterion is None:\n            return self.criterion\n        else:\n            return self._validation_criterion\n\n    @validation_criterion.setter\n    def validation_criterion(self, value):\n        if isinstance(value, str) or callable(value):\n            self.build_validation_criterion(value)\n        elif isinstance(value, dict):\n            self.build_validation_criterion(**value)\n        else:\n            raise RuntimeError(f\"Validation criterion can either be set to a string, callable \"\n                               f\"or a dict. Got {type(value).__name__} instead.\")\n\n    def build_validation_criterion(self, method, **kwargs):\n        \"\"\"\n        Builds the loss criterion for validation.\n\n        Parameters\n        ----------\n        method : str or callable or torch.nn.Module\n            Name of the criterion when str, criterion class when callable, or a\n            torch.nn.Module instance. 
If a name is provided, this method looks\n            for the criterion in `torch.nn`.\n        kwargs : dict\n            Keyword arguments to the criterion class' constructor if applicable.\n\n        Returns\n        -------\n        Trainer\n            self.\n\n        Raises\n        ------\n        AssertionError\n            if criterion is not found.\n        NotImplementedError\n            if method is neither a str nor a callable.\n        \"\"\"\n        if isinstance(method, str):\n            # Look for criteria in torch\n            criterion_class = getattr(torch.nn, method, None)\n            if criterion_class is None:\n                # Look for it in extensions\n                criterion_class = getattr(criteria, method, None)\n            assert criterion_class is not None, \"Criterion {} not found.\".format(method)\n        elif callable(method) and isinstance(method, type):\n            criterion_class = method\n        elif isinstance(method, torch.nn.Module):\n            self._validation_criterion = method\n            return self\n        else:\n            raise NotImplementedError\n        self._validation_criterion = criterion_class(**kwargs)\n        # Transfer criterion to GPU if required. This is necessary for e.g. 
weighted loss,\n        # where the weight is registered as a buffer.\n        # The criterion is to be cuda'ed only if the model is on CUDA (self._use_cuda) and\n        # the base_device is not CPU (ordinal -1).\n        if hasattr(self, '_base_device_ordinal'):\n            # This is to not break old checkpoints\n            base_device_ordinal = self._base_device_ordinal\n        else:\n            base_device_ordinal = None\n        if self._use_cuda and base_device_ordinal != 1:\n            self._validation_criterion.cuda()\n        return self\n\n    def validation_criterion_is_train_criterion(self, yes=True):\n        if yes:\n            # This will cause the property to return train criterion\n            self._validation_criterion = None\n        return self\n\n    @property\n    def validation_criterion_is_defined(self):\n        return self._validation_criterion is not None\n\n    @property\n    def metric(self):\n        \"\"\"Gets the evaluation metric.\"\"\"\n        assert_(self._metric is not None, \"Metric is not set yet.\", NotSetError)\n        return self._metric\n\n    @metric.setter\n    def metric(self, value):\n        if callable(value) or isinstance(value, str):\n            self.build_metric(value)\n        else:\n            raise NotImplementedError\n\n    @property\n    def evaluating_metric_every(self):\n        return self._evaluate_metric_every\n\n    def evaluate_metric_every(self, frequency):\n        \"\"\"\n        Set frequency of metric evaluation __during training__ (and not during validation).\n\n        Parameters\n        ----------\n        frequency : inferno.utils.train_utils.Frequency or str or tuple or list or int\n            Metric evaluation frequency. If str, it could be (say) '10 iterations' or '1 epoch'.\n            If tuple (or list), it could be (10, 'iterations') or (1, 'epoch'). 
If int\n            (say 10), it's interpreted as (10, 'iterations').\n\n        Returns\n        -------\n        Trainer\n            self\n        \"\"\"\n        self._evaluate_metric_every = tu.Frequency.build_from(frequency, priority='iterations')\n        assert self._evaluate_metric_every.is_consistent\n        return self\n\n    @property\n    def evaluate_metric_now(self):\n        if self._metric_evaluation_externally_triggered:\n            # Reset trigger\n            self._metric_evaluation_externally_triggered = False\n            return True\n        elif self._evaluate_metric_every is None:\n            # By default, evaluate metric every time\n            return True\n        elif self._evaluate_metric_every is not None and self._evaluate_metric_every.by_epoch:\n            # Don't evaluate if we've done so already this epoch\n            if self._last_metric_evaluated_at_epoch == self._epoch_count:\n                return False\n            else:\n                # If we haven't evaluated this epoch, check if we should\n                return self._evaluate_metric_every.match(epoch_count=self._epoch_count)\n        else:\n            # This is reached when evaluate_metric_every is defined and matching by\n            # iteration count\n            return self._evaluate_metric_every.match(iteration_count=self._iteration_count)\n\n    @evaluate_metric_now.setter\n    def evaluate_metric_now(self, value):\n        self._metric_evaluation_externally_triggered = bool(value)\n\n    def build_metric(self, method, **kwargs):\n        \"\"\"\n        Builds the metric for evaluation.\n\n        Parameters\n        ----------\n        method : callable or str\n            Name of the metric when string, metric class or a callable object\n            when callable. 
If a name is provided, this method looks for the metric in\n            `inferno.extensions.metrics`.\n\n        kwargs : dict\n            Keyword arguments to the metric class' constructor, if applicable.\n\n        Returns\n        -------\n        Trainer\n            self.\n\n        Raises\n        ------\n        AssertionError: if the metric is not found.\n        \"\"\"\n        if callable(method):\n            if isinstance(method, type):\n                self._metric = method(**kwargs)\n            else:\n                self._metric = method\n        elif isinstance(method, str):\n            assert hasattr(metrics, method), \\\n                \"Could not find the metric '{}'.\".format(method)\n            self._metric = getattr(metrics, method)(**kwargs)\n        else:\n            raise NotImplementedError\n        return self\n\n    @property\n    def metric_is_defined(self):\n        \"\"\"Checks if the metric is defined.\"\"\"\n        return self._metric is not None\n\n    def eval_mode(self):\n        \"\"\"Set model, criterion and metric to eval mode\"\"\"\n        self._current_mode = 'eval'\n        self.model.eval()\n        if self.criterion_is_defined and isinstance(self.criterion, torch.nn.Module):\n            self.criterion.eval()\n        if self.metric_is_defined and isinstance(self.metric, torch.nn.Module):\n            self.metric.eval()\n        return self\n\n    def train_mode(self):\n        \"\"\"Set model, criterion and metric to train mode\"\"\"\n        self._current_mode = 'train'\n        self.model.train()\n        if self.criterion_is_defined and isinstance(self.criterion, torch.nn.Module):\n            self.criterion.train()\n        if self.metric_is_defined and isinstance(self.metric, torch.nn.Module):\n            self.metric.train()\n        return self\n\n    @property\n    def train_loader(self):\n        assert self._loaders.get('train') is not None\n        return self._loaders.get('train')\n\n    
@train_loader.setter
    def train_loader(self, value):
        # Only torch DataLoader instances are accepted.
        assert isinstance(value, DataLoader)
        self._loaders.update({'train': value})

    @property
    def validate_loader(self):
        # Asserts that a 'validate' loader was bound (see `bind_loader`).
        assert self._loaders.get('validate') is not None
        return self._loaders.get('validate')

    @validate_loader.setter
    def validate_loader(self, value):
        assert isinstance(value, DataLoader)
        self._loaders.update({'validate': value})

    @property
    def logger(self):
        """Gets the logger."""
        return self._logger

    @logger.setter
    def logger(self, value):
        # A dict is forwarded as keyword arguments to `build_logger`;
        # anything else is treated as the logger object/name/class itself.
        if isinstance(value, dict):
            self.build_logger(**value)
        else:
            self.build_logger(logger=value)

    @property
    def log_directory(self):
        """Gets the log directory."""
        return self._log_directory

    @log_directory.setter
    def log_directory(self, value):
        """Sets the log directory."""
        if value is not None:
            self.set_log_directory(value)

    @property
    def pickle_module(self):
        # NOTE(review): resolves the module object from this file's globals by
        # name — assumes 'pickle'/'dill' are imported at module level; confirm.
        module_ = globals().get(self._pickle_module, None)
        assert_(module_ is not None, "Pickle module not found!", ModuleNotFoundError)
        return module_

    # Names accepted by the `pickle_module` setter.
    _ALLOWED_PICKLE_MODULES = {'pickle', 'dill'}

    @pickle_module.setter
    def pickle_module(self, value):
        assert_(isinstance(value, str), "`pickle_module` must be set to a string.", TypeError)
        assert_(value in self._ALLOWED_PICKLE_MODULES,
                f"Pickle module must be one of {self._ALLOWED_PICKLE_MODULES}, "
                f"got {value} instead.", ValueError)
        self._pickle_module = value

    @property
    def saving_every(self):
        """Gets the frequency at which checkpoints are made."""
        return self._save_every

    def save_at_best_validation_score(self, yes=True):
        """Sets whether to save when the validation score is the best seen."""
        self._save_at_best_validation_score = yes
        return self

    # Read-only trigger: True when a checkpoint should be written now — either
    # externally requested (one-shot flag), at a new best validation score (if
    # enabled), or when the configured save frequency matches.
    @property
    def save_now(self):
        if self._save_externally_triggered:
            # Reset trigger
            self._save_externally_triggered = False
            # Save if externally triggered
            return True
        elif self._save_at_best_validation_score and self._is_iteration_with_best_validation_score:
            return True
        else:
            # Check if we're saving by epoch
            if self._save_every is not None and self._save_every.by_epoch:
                # Don't save if we've already saved once this epoch
                if self._epoch_count == self._last_saved_at_epoch:
                    return False
                else:
                    # If we haven't saved this epoch, check if we should
                    return self._save_every.match(epoch_count=self._epoch_count)
            else:
                # We're saving by iterations
                return self._save_every is not None and \
                   self._save_every.match(iteration_count=self._iteration_count)

    @save_now.setter
    def save_now(self, value):
        """Can be set to true to trigger a checkpoint creation."""
        self._save_externally_triggered = bool(value)

    def save_every(self, frequency, to_directory=None,
                   checkpoint_filename=None, best_checkpoint_filename=None):
        """
        Set checkpoint creation frequency.

        Parameters
        ----------
        frequency : inferno.utils.train_utils.Frequency or tuple or str
            Checkpoint creation frequency. 
Examples: '100 iterations' or '1 epochs'.
        to_directory : str
            Directory where the checkpoints are to be created.
        checkpoint_filename : str
            Name of the checkpoint file.
        best_checkpoint_filename : str
            Name of the best checkpoint file.
        Returns
        -------
        Trainer
            self.
        """
        self._save_every = tu.Frequency.build_from(frequency, priority='iterations')
        assert self._save_every.is_consistent
        self.save_to_directory(to_directory, checkpoint_filename, best_checkpoint_filename)
        return self

    @property
    def save_directory(self):
        """Gets the directory checkpoints are written to."""
        return self._save_to_directory

    def save_to_directory(self, to_directory=None, checkpoint_filename=None,
                          best_checkpoint_filename=None):
        # Each argument is optional; only the provided ones update trainer state.
        # The target directory is created if it does not already exist.
        if to_directory is not None:
            assert_(isinstance(to_directory, str), exception_type=TypeError)
            if not os.path.exists(to_directory):
                os.makedirs(to_directory)
            else:
                assert os.path.isdir(to_directory)
            self._save_to_directory = to_directory
        if checkpoint_filename is not None:
            assert_(isinstance(checkpoint_filename, str), exception_type=TypeError)
            self._checkpoint_filename = checkpoint_filename
        if best_checkpoint_filename is not None:
            assert_(isinstance(best_checkpoint_filename, str), exception_type=TypeError)
            self._best_checkpoint_filename = best_checkpoint_filename
        return self

    @property
    def validating_every(self):
        """Gets the frequency at which validation is run."""
        return self._validate_every

    # Read-only trigger: True when validation should run now — either externally
    # requested (one-shot flag) or when the configured frequency matches and we
    # have not validated at this epoch/iteration already.
    @property
    def validate_now(self):
        if self._validation_externally_triggered:
            # Reset trigger
            self._validation_externally_triggered = False
            return True
        elif self._validate_every is not None and self._validate_every.by_epoch:
            # Don't validate if we've done so already this epoch
            if self._last_validated_at_epoch == self._epoch_count:
                return False
            else:
                # If we haven't validated this epoch, check if we should
                return self._validate_every.match(epoch_count=self._epoch_count,
                                                  match_zero=False)
        else:
            # Don't validate if we've done once already this iteration
            if self._last_validated_at_iteration == self._iteration_count:
                return False
            else:
                # If we haven't validated this iteration, check if we should. The `match_zero` is
                # redundant, but we'll leave it on anyway.
                return self._validate_every is not None and \
                       self._validate_every.match(iteration_count=self._iteration_count,
                                                  match_zero=False)

    @validate_now.setter
    def validate_now(self, value):
        # Setting this to a truthy value externally triggers a one-shot validation.
        self._validation_externally_triggered = bool(value)

    def validate_every(self, frequency, for_num_iterations=None):
        """
        Set validation frequency.

        Parameters
        ----------
        frequency : inferno.utils.train_utils.Frequency or str or tuple or list or int
            Validation frequency. If str, it could be (say) '10 iterations' or '1 epoch'.
            If tuple (or list), it could be (10, 'iterations') or (1, 'epoch'). If int
            (say 10), it's interpreted as (10, 'iterations').
        for_num_iterations : int
            Number of iterations to validate for. If not set, the model is validated on
            the entire dataset (i.e. 
till the data loader is exhausted).\n\n        Returns\n        -------\n        Trainer\n            self\n        \"\"\"\n        self._validate_every = tu.Frequency.build_from(frequency, priority='iterations')\n        assert self._validate_every.is_consistent\n        self._num_validation_iterations = for_num_iterations\n        return self\n\n    @property\n    def iteration_count(self):\n        return self._iteration_count\n\n    @property\n    def epoch_count(self):\n        return self._epoch_count\n\n    @property\n    def target_batch_dim(self):\n        return self._target_batch_dim\n\n    @target_batch_dim.setter\n    def target_batch_dim(self, value):\n        assert_(value in [0, 1],\n                \"target_batch_dim must be either 0 or 1, got {value} instead.\".format(value=value),\n                ValueError)\n        self._target_batch_dim = value\n\n    def set_target_batch_dim(self, value):\n        self.target_batch_dim = value\n        return self\n\n    def build_logger(self, logger=None, log_directory=None, **kwargs):\n        \"\"\"\n        Build the logger.\n\n        Parameters\n        ----------\n        logger : inferno.trainers.callbacks.logging.base.Logger or str or type\n            Must either be a Logger object or the name of a logger or the class of a logger.\n        log_directory : str\n            Path to the directory where the log files are to be stored.\n        kwargs : dict\n            Keyword arguments to the logger class.\n\n        Returns\n        -------\n        Trainer\n            self\n        \"\"\"\n        if isinstance(logger, Logger):\n            # Set logger and register with the callback engine.\n            self._logger = logger\n            self.callbacks.register_callback(self._logger)\n        elif callable(logger):\n            kwargs.update({'log_directory': log_directory})\n            self._logger = logger(**kwargs)\n            self.callbacks.register_callback(self._logger)\n        elif 
isinstance(logger, str):\n            self._logger = get_logger(logger)(**kwargs)\n            self.callbacks.register_callback(self._logger)\n        elif logger is None:\n            pass\n        else:\n            raise NotImplementedError\n\n        if log_directory is not None:\n            self.set_log_directory(log_directory)\n        return self\n\n    def set_log_directory(self, log_directory):\n        \"\"\"\n        Set the directory where the log files are to be stored.\n\n        Parameters\n        ----------\n        log_directory : str\n            Directory where the log files are to be stored.\n\n        Returns\n        -------\n        Trainer\n            self\n        \"\"\"\n        self._log_directory = log_directory\n        if self._logger is not None:\n            self._logger.set_log_directory(log_directory)\n        return self\n\n    # States that are fetched dynamically from the trainer object via properties are\n    # dynamic states. Such states can not be updated.\n    # The following dictionary maps state keys to the corresponding trainer attribute\n    DYNAMIC_STATES = {'learning_rate': 'current_learning_rate'}\n\n    def update_state(self, key, value):\n        assert key not in self.DYNAMIC_STATES, \\\n            \"State at key '{}' cannot be updated because it's dynamic.\".format(key)\n        self._state.update({key: value})\n        return self\n\n    def update_state_from_dictionary(self, dictionary):\n        # Unwrap variables (or tensors)\n        self._state.update({\n            state_key: thu.unwrap(state)\n            for state_key, state in dictionary.items()})\n\n    def update_state_from_model_state_hooks(self):\n        if hasattr(self.model, '_state_hooks'):\n            state_hooks = getattr(self.model, '_state_hooks')\n            if isinstance(state_hooks, dict):\n                self.update_state_from_dictionary(state_hooks)\n\n    def get_state(self, key, default=None):\n        if key in 
self.DYNAMIC_STATES:\n            return getattr(self, self.DYNAMIC_STATES.get(key), default)\n        else:\n            return self._state.get(key, default)\n\n    @property\n    def current_learning_rate(self):\n        return self.get_current_learning_rate()\n\n    def get_current_learning_rate(self):\n        \"\"\"\n        Gets the current learning rate.\n        Returns\n        -------\n        list or float\n            List of learning rates if there are multiple parameter groups, or a float\n            if there's just one.\n        \"\"\"\n        learning_rate = [param_group.get('lr', -1.)\n                         for param_group in self.optimizer.param_groups]\n        learning_rate = [_learning_rate[0] if thu.is_tensor(_learning_rate) else _learning_rate\n                         for _learning_rate in learning_rate]\n        return pyu.from_iterable(learning_rate)\n\n    def to(self, device):\n        \"\"\"\n        Send trainer to device\n        ----------\n        device : string or torch.device\n            Target device where trainer/model should be send to\n        \"\"\"\n        if device == 'cuda':\n            return self.cuda()\n        elif device == 'cpu':\n            return self.cpu()\n        elif isinstance(device, torch.torch.device):\n            self.to(device.type)\n        else:\n            raise NotImplementedError(\"Can not send trainer to device\", device)\n\n    def cuda(self, devices=None, base_device=None):\n        \"\"\"\n        Train on the GPU.\n\n        Parameters\n        ----------\n        devices : list\n            Specify the ordinals of the devices to use for dataparallel training.\n\n        base_device : {'cpu', 'cuda'}\n            When using data-parallel training, specify where the result tensors\n            are collected. 
If 'cuda', the results are collected in `devices[0]`.

        Returns
        -------
        Trainer
            self
        """
        # Validate base_device
        assert_(base_device in [None, 'cpu', 'cuda'],
                "`base_device` must either be 'cpu' or 'cuda', got {} instead."
                .format(base_device),
                DeviceError)
        if isinstance(devices, int) or (isinstance(devices, (list, tuple)) and len(devices) == 1):
            # No data-parallelism, make sure base_device is not CPU
            assert_(base_device != 'cpu',
                    "Without dataparallelism, `base_device` cannot be 'cpu'.",
                    DeviceError)
        # -1 encodes "collect on CPU"; None means default (devices[0]).
        self._base_device_ordinal = {None: None, 'cpu': -1, 'cuda': None}.get(base_device)
        # Move model to CUDA
        if self.model_is_defined:
            self.model.cuda()
        # Move criterion to cuda if base device ordinal is not -1 (i.e. CPU)
        # (the criterion is evaluated on the base device)
        if self.criterion_is_defined and self._base_device_ordinal != -1:
            self.criterion.cuda()
        elif self.criterion_is_defined and self._base_device_ordinal == -1:
            # Criterion is evaluated on the CPU, make sure that's where it lives
            self.criterion.cpu()
        self._use_cuda = True
        self._devices = devices
        return self

    def cpu(self):
        """
        Train on the CPU.

        Returns
        -------
        Trainer
            self
        """
        if self.model_is_defined:
            self.model.cpu()
        if self.criterion_is_defined:
            self.criterion.cpu()
        self._use_cuda = False
        self._devices = None
        return self

    def is_cuda(self):
        """Returns whether using GPU for training."""
        return self._use_cuda

    def to_device(self, objects):
        # Recursively sends (nested) lists/tuples of objects to the GPU when
        # CUDA is enabled; otherwise returns the objects unchanged.
        if isinstance(objects, (list, tuple)):
            return type(objects)([self.to_device(_object) for _object in objects])
        else:
            return objects.cuda() if self._use_cuda else objects

    def apply_model(self, *inputs):
        if hasattr(self, '_base_device_ordinal'):
            # This is to not break old checkpoints
            base_device_ordinal = self._base_device_ordinal
        else:
            base_device_ordinal = None
        # With device ordinals set, run data-parallel; otherwise a plain forward.
        if self._devices is not None:
            return data_parallel(self.model, inputs, list(self._devices),
                                 output_device=base_device_ordinal)
        else:
            return self.model(*inputs)

    def cast(self, objects):
        # Recursively casts float tensors to the trainer's precision (`_dtype`).
        if isinstance(objects, (list, tuple)):
            return type(objects)([self.cast(_object) for _object in objects])
        else:
            # Cast only the float types, while leaving the ints alone
            if objects.__class__.__name__ in ['HalfTensor', 'FloatTensor', 'DoubleTensor']:
                cast_fn = getattr(objects, self._dtype, None)
            else:
                cast_fn = None

            if cast_fn is not None:
                return cast_fn()
            else:
                return objects

    def set_precision(self, dtype):
        """
        Set training precision.

        Parameters
        ----------
        dtype : {'double', 'float', 'half'}
            Training precision.

        Returns
        -------
        Trainer
            self
        """
        assert dtype in ['double', 'float', 'half']
        self._dtype = dtype
        # Casts the model in place via its .double()/.float()/.half() method.
        self._model = getattr(self._model, dtype)()
        return self

    @property
    def dtype(self):
        return self._dtype

    @dtype.setter
    def dtype(self, value):
        self.set_precision(value)

    def bind_loader(self, name, loader, num_inputs=None, num_targets=1):
        """
        Bind a data loader to the trainer.

        Parameters
    
    ----------
        name : {'train', 'validate', 'test'}
            Name of the loader, i.e. what it should be used for.
        loader : torch.utils.data.DataLoader
            DataLoader object.
        num_inputs : int
            Number of input tensors from the `loader`.
        num_targets : int
            Number of target tensors from the `loader`.

        Returns
        -------
        Trainer
            self

        Raises
        ------
        KeyError
            if name is invalid.
        TypeError
            if loader is not a DataLoader instance.
        """
        assert_(name in ['train', 'validate', 'test'],
                "`name` must be one of ['train', 'validate', 'test']. "
                "Got {} instead.".format(name),
                KeyError)
        assert_(isinstance(loader, DataLoader),
                "`loader` must be a DataLoader object. "
                "Got {} instead.".format(type(loader).__name__),
                TypeError)
        # Check to see if the loader is actually new. This should usually be True.
        is_new_loader = loader is not self._loaders.get(name)
        self._loaders.update({name: loader})
        # We also need to account for the case when a loader is being replaced. When this happens,
        # the old DataLoaderIter might still have processes running, which we need to kill.
        if is_new_loader and name in self._loader_iters:
            # This is when the previous loader already has a DataLoaderIter running.
            # The DataLoaderIter implements a __del__ method, which shuts down workers.
            del self._loader_iters[name]
        # Trainers loaded from pickle files might not have '_loader_specs', therefore:
        if not hasattr(self, '_loader_specs'):
            setattr(self, '_loader_specs', {})
        self._loader_specs.update({name: {'num_inputs': num_inputs,
                                          'num_targets': num_targets}})
        return self

    def get_loader_specs(self, name):
        # Returns the {'num_inputs': ..., 'num_targets': ...} dict recorded by
        # `bind_loader` for the given loader name.
        assert name in self._loader_specs.keys(), \
            "Could not find specs about loader '{}'. Valid loader names are: {}" \
                .format(name, set(self._loader_specs.keys()))
        return self._loader_specs.get(name)

    def fetch_next_batch(self, from_loader='train', restart_exhausted_generators=True,
                         update_batch_count=True, update_epoch_count_if_generator_exhausted=True):
        # Fetches the next batch from the named loader, lazily creating its
        # iterator and (optionally) restarting it once when exhausted.
        # Check if the iterator is built
        if from_loader not in self._loader_iters:
            self._loader_iters.update({from_loader: self._loaders[from_loader].__iter__()})
        # Try to fetch from iterator
        try:
            # Fetch
            next_batch = next(self._loader_iters[from_loader])
            # Verify
            self.verify_batch(next_batch, from_loader)
            if update_batch_count:
                self._batch_count += 1
            return next_batch
        except StopIteration:
            # This if clause prevents infinite recursion if the loader is empty
            if restart_exhausted_generators:
                self._loader_iters.update({from_loader: self._loaders[from_loader].__iter__()})
                # Update epoch count
                if update_epoch_count_if_generator_exhausted:
                    self.next_epoch()
                return self.fetch_next_batch(from_loader, restart_exhausted_generators=False,
                                             update_batch_count=update_batch_count)
            else:
                raise

    def verify_batch(self, batch, from_loader):
        # Sanity-checks the batch length against the loader's recorded specs.
        loader_specs = self.get_loader_specs(from_loader)
        num_inputs = loader_specs.get('num_inputs')
        num_targets = loader_specs.get('num_targets')
        if None not in [num_inputs, num_targets]:
            assert len(batch) == num_inputs + num_targets, \
                "Was expecting a batch with {} (= num_inputs) + {} (= num_targets) tensors, " \
                "got one with {} tensors.".format(num_inputs, num_targets, len(batch))
        # NOTE(review): the strict '>' checks below contradict the equality check
        # above whenever num_targets == 0 (which `split_batch` explicitly
        # allows) — verify whether '>=' was intended.
        if num_inputs is not None:
            assert len(batch) > num_inputs, \
                "Expecting {} inputs, but the batch contains only {} tensors." \
                    .format(num_inputs, len(batch))
        if num_targets is not None:
            assert len(batch) > num_targets, \
                "Expecting {} outputs, but the batch contains only {} tensors." \
                    .format(num_targets, len(batch))
        return batch

    def split_batch(self, batch, from_loader):
        # Splits a fetched batch into (inputs, targets) using the loader specs.
        loader_specs = self.get_loader_specs(from_loader)
        num_inputs = loader_specs.get('num_inputs')
        num_targets = loader_specs.get('num_targets')
        assert not (num_targets is None and num_inputs is None), \
            "Can not split batch if both the number of inputs and targets is not known."
        if num_inputs is None:
            # Unknown number of inputs
            num_inputs = len(batch) - num_targets    #to allow for num_targets == 0
            inputs, targets = batch[:num_inputs], batch[num_inputs:]
        elif num_targets is None:
            # Unknown 
number of targets
            inputs, targets = batch[:num_inputs], batch[num_inputs:]
        else:
            # Known number of inputs and targets
            inputs, targets = batch[:num_inputs], batch[-num_targets:]
        return inputs, pyu.from_iterable(targets)

    def restart_generators(self, of_loader=None):
        # Rebuilds the iterator(s) of the named loader, or of all loaders when
        # `of_loader` is None.
        if of_loader is None:
            of_loader = self._loaders.keys()
        else:
            assert of_loader in self._loaders.keys(), \
                "Key {} not in loaders ({})".format(of_loader, list(self._loaders))
            of_loader = pyu.to_iterable(of_loader)

        self._loader_iters.update({from_loader: self._loaders[from_loader].__iter__()
                                   for from_loader in of_loader})
        return self

    def wrap_batch(self, batch, from_loader=None, requires_grad=False):
        # Sends a batch to the right device, casts it to the training precision,
        # and optionally enables gradients on its tensors.
        base_device_ordinal = \
            self._base_device_ordinal if hasattr(self, '_base_device_ordinal') else None
        # First, send to the right device
        if base_device_ordinal is None:
            # Both inputs and labels are sent to the device
            batch = self.to_device(batch)
        elif base_device_ordinal == -1:
            # Input batches go to device, while labels remain on the CPU.
            # To start, we need the number of input batches, i.e. from_loader must not be None
            assert_(from_loader is not None,
                    "`from_loader` needs to be specified if base_device_ordinal is -1 "
                    "(i.e. base device for data-parallel training is CPU).",
                    ValueError)
            loader_spec = self._loader_specs.get(from_loader)
            assert_(loader_spec is not None,
                    "No `loader_spec` found for loader key '{}'.".format(from_loader),
                    RuntimeError)
            num_inputs = loader_spec['num_inputs']
            if num_inputs is None:
                num_inputs = len(batch) - loader_spec['num_targets']
            # Fetch input batches and send'em to device (leave the targets alone)
            inputs = batch[:num_inputs]
            inputs = self.to_device(inputs)
            # Finally, build the batch
            batch = inputs + batch[num_inputs:]
        else:
            raise ValueError("Internal Error: Invalid base_device_ordinal: {}."
                             .format(base_device_ordinal))

        # Cast to the right dtype and return
        batch = self.cast(batch)
        # Set gradients if required
        variable_batch = []
        for batch_num, _batch in enumerate(batch):
            if thu.is_tensor(_batch):
                variable_batch.append(_batch.requires_grad_() if requires_grad else _batch)
            elif pyu.is_listlike(_batch):
                variable_batch.append([__batch.requires_grad_() if requires_grad else __batch
                                       for __batch in _batch])
            else:
                raise RuntimeError(f"Was Expecting batch at index {batch_num} to be either a "
                                   f"tensor or a list of tensors. Got {type(_batch)} instead.")
        batch = type(batch)(variable_batch)
        return batch

    def next_iteration(self):
        # Advances the global iteration clock.
        self._iteration_count += 1

    def next_epoch(self):
        # Advances the epoch clock, resetting the per-epoch batch count and
        # firing the END_OF_EPOCH / BEGIN_OF_EPOCH callbacks around the bump.
        # Callback before the end of epoch
        self.callbacks.call(self.callbacks.END_OF_EPOCH,
                            epoch_count=self._epoch_count,
                            batch_count=self._batch_count,
                            iteration_count=self._iteration_count)
        self._epoch_count += 1
        self._batch_count = 0
        # Callback after the start of epoch
        self.callbacks.call(self.callbacks.BEGIN_OF_EPOCH,
                            epoch_count=self._epoch_count,
                            batch_count=self._batch_count,
                            iteration_count=self._iteration_count)

    def stop_fitting(self, max_num_iterations=None, max_num_epochs=None):
        # Returns True when the iteration or epoch budget is exhausted;
        # iteration count takes precedence over epoch count.
        # First priority to iteration count
        if max_num_iterations is not None or max_num_epochs is None:
            max_num_iterations = \
                self._max_num_iterations if max_num_iterations is None else max_num_iterations
            assert_(max_num_iterations is not None,
                    "Neither max_num_iterations nor max_num_epochs was set.",
                    RuntimeError)
            return self._iteration_count >= max_num_iterations
        else:
            # max_num_epochs is specified. 
It could be 'auto', in which case we read from the
            # class attribute
            max_num_epochs = self._max_num_epochs \
                if isinstance(max_num_epochs, str) and max_num_epochs.lower() == 'auto' \
                else max_num_epochs
            return self._epoch_count >= max_num_epochs

    # String spellings accepted as "no limit" by the max-iteration/epoch setters.
    INF_STRINGS = {'inf', 'infinity', 'infty'}

    def set_max_num_iterations(self, max_num_iterations):
        """
        Set the maximum number of training iterations.

        Parameters
        ----------
        max_num_iterations : int or float or str
            Maximum number of training iterations. If float, it should equal numpy.inf.
            If str, it should be one of {'inf', 'infinity', 'infty'}.

        Returns
        -------
        Trainer
            self
        """
        max_num_iterations = \
            inf if max_num_iterations in self.INF_STRINGS else max_num_iterations
        # Validate type
        assert_(isinstance(max_num_iterations, int) or max_num_iterations == inf,
                "max_num_iterations must be an integer or numpy.inf, got {} instead."
                .format(type(max_num_iterations).__name__),
                TypeError)
        self._max_num_iterations = max_num_iterations
        return self

    def set_max_num_epochs(self, max_num_epochs):
        """
        Set the maximum number of training epochs.

        Parameters
        ----------
        max_num_epochs : int or float or str
            Maximum number of training epochs. If float, it should equal numpy.inf.
            If str, it should be one of {'inf', 'infinity', 'infty'}.

        Returns
        -------
        Trainer
            self
        """
        max_num_epochs = inf if max_num_epochs in self.INF_STRINGS else max_num_epochs
        assert_(isinstance(max_num_epochs, int) or max_num_epochs == inf,
                "max_num_epochs must be an integer or numpy.inf, got {} instead."
                .format(type(max_num_epochs).__name__),
                TypeError)
        self._max_num_epochs = max_num_epochs
        return self

    def fit(self, max_num_iterations=None, max_num_epochs=None):
        """
        Fit model.

        Parameters
        ----------
        max_num_iterations : int or float or str
            (Optional) Maximum number of training iterations. Overrides the value set by
            `Trainer.set_max_num_iterations`. If float, it should equal numpy.inf.
            If str, it should be one of {'inf', 'infinity', 'infty'}.
        max_num_epochs : int or float or str
            (Optional) Maximum number of training epochs. Overrides the value set by
            `Trainer.set_max_num_epochs`. If float, it should equal numpy.inf.
            If str, it should be one of {'inf', 'infinity', 'infty'}.

        Returns
        -------
        Trainer
            self

        """
        # Takes care of:
        #   - dispatching train
        #   - validation
        #   - learning rate scheduling
        #   - saving

        # Arguments override the stored limits; INF_STRINGS map to numpy.inf.
        max_num_iterations = inf if max_num_iterations in self.INF_STRINGS else max_num_iterations
        max_num_iterations = self._max_num_iterations if max_num_iterations is None \
            else max_num_iterations

        max_num_epochs = inf if max_num_epochs in self.INF_STRINGS else max_num_epochs
        max_num_epochs = self._max_num_epochs if max_num_epochs is None else max_num_epochs

        self.callbacks.call(self.callbacks.BEGIN_OF_FIT,
                            max_num_iterations=max_num_iterations,
                            max_num_epochs=max_num_epochs)

        # Local clock
        run_num = 0
        while True:
            if self.stop_fitting(max_num_iterations, max_num_epochs):
                self.console.info("Exceeded max number of iterations / epochs, breaking.")
                break
            # Train
            self.train_for(break_callback=lambda *args: self.stop_fitting(max_num_iterations,
                                                                          max_num_epochs))
            # Check if it's time to validate
            if self.validate_now:
                self.console.info("Validating.")
                self.validate_for()
            # Check if it's time to save
            if self.save_now:
                self.console.info("Saving.")
                self.save()
            run_num += 1

        # Call callback
        self.callbacks.call(self.callbacks.END_OF_FIT,
                            max_num_iterations=max_num_iterations,
                            max_num_epochs=max_num_epochs,
                            
num_runs=run_num)\n\n        return self\n\n    def apply_model_and_loss(self, inputs, target, backward=True, mode=None):\n        if mode is None:\n            mode = self._current_mode\n            assert_(mode in ['train', 'eval'],\n                    f\"`mode` must be one of ['train', 'eval'], got {mode} instead.\", ValueError)\n        # Compute prediction\n        prediction = self.apply_model(*inputs)\n        # Compute loss\n        kwargs = {}\n        if (isinstance(self.criterion, torch.nn.Module) and\n                'trainer' in signature(self.criterion.forward).parameters):\n            kwargs['trainer'] = self\n        if mode == 'train':\n            loss = self.criterion(prediction, target, **kwargs) \\\n                   if len(target) != 0 else self.criterion(prediction, **kwargs)\n        elif mode == 'eval':\n            loss = self.validation_criterion(prediction, target, **kwargs) \\\n                   if len(target) != 0 else self.validation_criterion(prediction, **kwargs)\n        else:\n            raise ValueError\n        if backward:\n            # Backprop if required\n            # retain_graph option is needed for some custom\n            # loss functions like malis, False per default\n            if self.mixed_precision:\n                with amp.scale_loss(loss, self.optimizer) as scaled_loss:\n                    scaled_loss.backward(retain_graph=self.retain_graph)\n            else:\n                loss.backward(retain_graph=self.retain_graph)\n        return prediction, loss\n\n    def train_for(self, num_iterations=None, break_callback=None):\n        # Switch model to train mode\n        self.train_mode()\n        # Call callback\n        self.callbacks.call(self.callbacks.BEGIN_OF_TRAINING_RUN,\n                            num_iterations=num_iterations)\n        # iteration_num is a local clock. 
There's the global self._iteration_count that keeps\n        # actual track of the number of iterations - this is updated by the call to\n        # self.next_iteration().\n        iteration_num = 0\n        while True:\n            if num_iterations is not None and iteration_num >= num_iterations:\n                self.console.info(\"Finished {} iterations. Breaking...\".format(num_iterations))\n                break\n            # Break if break callback asks us to\n            if break_callback is not None and break_callback(iteration_num):\n                self.console.info(\"Breaking on request from callback.\")\n                break\n            self.console.progress(\"Training iteration {} (batch {} of epoch {}).\"\n                                  .format(iteration_num, self._batch_count, self._epoch_count))\n            # Call callback\n            self.callbacks.call(self.callbacks.BEGIN_OF_TRAINING_ITERATION,\n                                iteration_num=iteration_num)\n            # No interrupts while computing - a SIGINT could shoot down the driver if\n            # done at the wrong time. 
Not sure if this has something to do with pinned memory\n            with pyu.delayed_keyboard_interrupt():\n                # Get batch\n                batch = self.fetch_next_batch('train')\n                # Send to device and wrap as variable\n                batch = self.wrap_batch(batch, from_loader='train')\n                # Separate inputs from targets\n                inputs, target = self.split_batch(batch, from_loader='train')\n                # Apply model, compute loss and backprop\n                prediction, loss = self.apply_model_and_loss(inputs, target, backward=True,\n                                                             mode='train')\n            self.callbacks.call(self.callbacks.AFTER_MODEL_AND_LOSS_IS_APPLIED,\n                                prediction=prediction, loss=loss, iteration_num=iteration_num)\n            # Compute metric\n            if self.metric_is_defined and self.evaluate_metric_now:\n                self._last_metric_evaluated_at_epoch = self._epoch_count\n                # TODO Make unwrap a method for folks to overload\n                error = self.metric(thu.unwrap(prediction, to_cpu=False),\n                                    thu.unwrap(target, to_cpu=False))\n                self.update_state('training_error', thu.unwrap(error))\n            else:\n                error = None\n            # Update state from computation\n            self.update_state('training_inputs', thu.unwrap(inputs))\n            self.update_state('training_target', thu.unwrap(target))\n            self.update_state('training_prediction', thu.unwrap(prediction))\n            self.update_state('training_loss', thu.unwrap(loss))\n            # Update state from model's state hooks\n            self.update_state_from_model_state_hooks()\n            if iteration_num % self.backprop_every == 0:\n               # Update parameters\n                self.optimizer.step()\n                # Zero out the grads\n                
self.optimizer.zero_grad()\n            # Call callback\n            self.callbacks.call(self.callbacks.END_OF_TRAINING_ITERATION,\n                                iteration_num=iteration_num)\n            # Prepare for next iteration\n            self.next_iteration()\n            # Break if validating or saving. It's important that the next_iteration() method is\n            # called before checking validate_now and save_now - because otherwise, the iteration\n            # counter is never updated after the first save and validate, resulting in an infinite\n            # save + validate loop.\n            if self.validate_now:\n                self.console.info(\"Breaking to validate.\")\n                break\n            if self.save_now:\n                self.console.info(\"Breaking to save.\")\n                break\n            iteration_num += 1\n\n        self.callbacks.call(self.callbacks.END_OF_TRAINING_RUN, num_iterations=num_iterations)\n        return self\n\n    def validate_for(self, num_iterations=None, loader_name='validate'):\n        \"\"\"\n        Validate for a given number of validation (if `num_iterations is not None`)\n        or over the entire (validation) data set.\n\n        Parameters\n        ----------\n        num_iterations : int\n            Number of iterations to validate for. To validate on the entire dataset,\n            leave this as `None`.\n        loader_name : str\n            Name of the data loader to use for validation. 
'validate' is the obvious default.\n\n        Returns\n        -------\n        Trainer\n            self.\n        \"\"\"\n        assert_(loader_name in ['validate', 'test', 'train'],\n                \"Invalid `loader_name`: {}\".format(loader_name),\n                ValueError)\n        # Average over errors\n        validation_error_meter = tu.AverageMeter()\n        validation_loss_meter = tu.AverageMeter()\n        iteration_num = 0\n        num_iterations = \\\n            self._num_validation_iterations if num_iterations is None else num_iterations\n\n        # Switch to eval mode (e.g. for batchnorm, etc.)\n        self.eval_mode()\n\n        if loader_name not in self._loader_iters:\n            self._loader_iters.update({loader_name: self._loaders[loader_name].__iter__()})\n\n        # If we don't know num_iterations, we're validating the entire dataset - so we might as\n        # well restart the loader now\n        if num_iterations is None:\n            self.restart_generators(loader_name)\n\n        # Record the epoch we're validating in\n        self._last_validated_at_epoch = self._epoch_count\n        self._last_validated_at_iteration = self._iteration_count\n        self.callbacks.call(self.callbacks.BEGIN_OF_VALIDATION_RUN,\n                            num_iterations=num_iterations,\n                            num_iterations_in_generator=len(self._loader_iters[loader_name]),\n                            last_validated_at_epoch=self._last_validated_at_epoch)\n\n        while True:\n            if num_iterations is not None and iteration_num >= num_iterations:\n                break\n\n            self.callbacks.call(self.callbacks.BEGIN_OF_VALIDATION_ITERATION,\n                                iteration_num=iteration_num)\n\n            try:\n                batch = self.fetch_next_batch(loader_name,\n                                              restart_exhausted_generators=num_iterations is not None,\n                                          
    update_batch_count=False,\n                                              update_epoch_count_if_generator_exhausted=False)\n            except StopIteration:\n                self.console.info(\"{} generator exhausted, breaking.\".format(loader_name))\n                break\n\n            self.console.progress(\"Validating iteration {}.\".format(iteration_num))\n\n            # Delay SIGINTs till after computation\n            with pyu.delayed_keyboard_interrupt(), torch.no_grad():\n                # Wrap\n                batch = self.wrap_batch(batch, from_loader=loader_name)\n                # Separate\n                inputs, target = self.split_batch(batch, from_loader=loader_name)\n                # Apply model, compute loss\n                output, loss = self.apply_model_and_loss(inputs, target, backward=False,\n                                                         mode='eval')\n            if isinstance(target, (list, tuple)):\n                batch_size = target[0].size(self._target_batch_dim)\n            else:\n                batch_size = target.size(self._target_batch_dim)\n            validation_loss_meter.update(thu.unwrap(loss, extract_item=True), n=batch_size)\n\n            # Compute validation_error\n            if self.metric_is_defined:\n                validation_error = self.metric(thu.unwrap(output, to_cpu=False),\n                                               thu.unwrap(target, to_cpu=False))\n                if torch.is_tensor(validation_error):\n                    # Convert to float\n                    validation_error = thu.unwrap(validation_error, extract_item=True)\n                self.update_state('validation_error', thu.unwrap(validation_error))\n                validation_error_meter.update(validation_error, n=batch_size)\n\n            self.update_state('validation_inputs', thu.unwrap(inputs))\n            self.update_state('validation_target', thu.unwrap(target))\n            self.update_state('validation_prediction', 
thu.unwrap(output))\n            self.update_state('validation_loss', thu.unwrap(loss))\n            # This is here for legacy reasons and will eventually be deprecated.\n            self.update_state('validation_input', self.get_state('validation_inputs'))\n            # Update from model's state hooks\n            self.update_state_from_model_state_hooks()\n\n            self.callbacks.call(self.callbacks.END_OF_VALIDATION_ITERATION,\n                                iteration_num=iteration_num)\n\n            iteration_num += 1\n\n        self.console.info(\"Done validating. Logging results...\")\n\n        # Report\n        validation_results = {\n            'validation_loss': validation_loss_meter.avg,\n            'validation_error': (validation_error_meter.avg if self.metric_is_defined else None)\n        }\n        self.record_validation_results(**validation_results)\n\n        self.console.info(\"Validation loss: {validation_loss}; validation error: \"\n                          \"{validation_error}\".format(**validation_results))\n\n        self.callbacks.call(self.callbacks.END_OF_VALIDATION_RUN,\n                            validation_loss_meter=validation_loss_meter,\n                            validation_error_meter=validation_error_meter if\n                            self.metric_is_defined else None)\n        return self\n\n    def record_validation_results(self, validation_loss, validation_error):\n        # Update state\n        self.update_state('validation_loss_averaged', thu.unwrap(validation_loss))\n        if validation_error is not None:\n            self.update_state('validation_error_averaged', thu.unwrap(validation_error))\n        # Prefer the error metric (if provided). 
This should be handled with care -\n        # validation error should either always not be None, or otherwise.\n        validation_score = validation_loss if validation_error is None else validation_error\n        # Check if validation error is less than the best so far\n        if self._best_validation_score is None or validation_score < self._best_validation_score:\n            # Best score so far. The following flag will trigger a save\n            self._is_iteration_with_best_validation_score = True\n            self._best_validation_score = validation_score\n\n    def get_config(self, exclude_loader=True):\n        # Returns a config dictionary, like __getstate__. Except optionally without the\n        # data loaders (which might be yuuuuuge if it contains the data)\n        config_dict = dict(self.__dict__)\n        # Loader iterators can't be pickled\n        if '_loader_iters' in config_dict:\n            config_dict.update({'_loader_iters': {}})\n        if exclude_loader:\n            if '_loaders' in config_dict:\n                config_dict.update({'_loaders': {}})\n        return config_dict\n\n    def set_config(self, config_dict):\n        # TODO some sanity checks on config_dict (e.g. 
whether the model is actually a model, etc)\n        self.__dict__.update(config_dict)\n        # Rebind trainer to callback engine\n        self.callbacks.bind_trainer(self)\n        # Have callback engine rebind all callbacks to trainer\n        self.callbacks.rebind_trainer_to_all_callbacks()\n        return self\n\n    def save(self, exclude_loader=True, stash_best_checkpoint=True):\n        # Log the epoch for save_now\n        self._last_saved_at_epoch = self._epoch_count\n\n        self.callbacks.call(self.callbacks.BEGIN_OF_SAVE,\n                            save_to_directory=self._save_to_directory,\n                            epoch_count=self._epoch_count,\n                            batch_count=self._batch_count,\n                            iteration_count=self._iteration_count,\n                            is_iteration_with_best_validation_score=self._is_iteration_with_best_validation_score)\n\n        checkpoint_path = os.path.join(self._save_to_directory,\n                                       self._checkpoint_filename)\n        best_checkpoint_path = os.path.join(self._save_to_directory,\n                                            self._best_checkpoint_filename)\n\n        # Save the state dictionary\n        torch.save(self.get_config(exclude_loader=exclude_loader),\n                   checkpoint_path,\n                   pickle_module=self.pickle_module)\n\n        self.callbacks.call(self.callbacks.END_OF_SAVE,\n                            save_to_directory=self._save_to_directory,\n                            checkpoint_path=checkpoint_path,\n                            best_checkpoint_path=best_checkpoint_path,\n                            epoch_count=self._epoch_count,\n                            batch_count=self._batch_count,\n                            iteration_count=self._iteration_count,\n                            is_iteration_with_best_validation_score=self._is_iteration_with_best_validation_score)\n\n        if 
self._is_iteration_with_best_validation_score and stash_best_checkpoint:\n            # Do the stashin'\n            shutil.copyfile(checkpoint_path, best_checkpoint_path)\n\n        # This is required to prevent an infinite save loop?\n        self._is_iteration_with_best_validation_score = False\n        self.console.info(\"Saved to {}.\".format(self._save_to_directory))\n        return self\n\n    def save_model(self, to_directory=None):\n        to_directory = self._save_to_directory if to_directory is None else to_directory\n        # Save the state dictionary\n        torch.save(self.model,\n                   os.path.join(to_directory, 'model.pytorch'),\n                   pickle_module=self.pickle_module)\n        return self\n\n    def load(self, from_directory=None, best=False, filename=None, map_location=None):\n        \"\"\"\n        Load the trainer from checkpoint.\n\n        Parameters\n        ----------\n        from_directory : str\n            Path to the directory where the checkpoint is located. The filename should be\n            'checkpoint.pytorch' if best=False, or 'best_checkpoint.pytorch' if best=True.\n        best : bool\n            Whether to load the best checkpoint. 
The filename in `from_directory` should be\n            'best_checkpoint.pytorch'.\n        filename : str\n            Overrides the default filename.\n        map_location : function, torch.device, string or a dict\n            Specify how to remap storage locations.\n\n        Returns\n        -------\n        Trainer\n            self\n        \"\"\"\n        from_directory = self._save_to_directory if from_directory is None else from_directory\n        assert from_directory is not None, \"Nowhere to load from.\"\n        # Get file name\n        if filename is None:\n            filename = self._best_checkpoint_filename if best else self._checkpoint_filename\n        # Load the dictionary\n        config_dict = torch.load(os.path.join(from_directory, filename),\n                                 pickle_module=self.pickle_module, map_location=map_location)\n\n        # This is required to prevent an infinite save loop?\n        self._is_iteration_with_best_validation_score = False\n        # Set config\n        self.set_config(config_dict)\n        return self\n\n    def load_model(self, from_directory=None, filename=None):\n        from_directory = self._save_to_directory if from_directory is None else from_directory\n        filename = 'model.pytorch' if filename is None else filename\n        # Load the model\n        model = torch.load(os.path.join(from_directory, filename),\n                           pickle_module=self.pickle_module)\n        # Set model\n        self.model = model\n        return self\n\n    def load_(self, *args, **kwargs):\n        # Here for legacy reasons - use load instead.\n        return self.load(*args, **kwargs)\n\n    @pyu.deprecated(\"please use self.console.{info,progress,warning,debug} instead\")\n    def print(self, message):\n        print(\"[+][{}] {}\".format(str(datetime.now()), message))\n\n    @classmethod\n    def build(cls, model=None, **trainer_config):\n        \"\"\"Factory function to build the trainer.\"\"\"\n   
     # Check if trainer is to be loaded from file\n        if trainer_config.get('load_from_checkpoint'):\n            # Load checkpoint config\n            trainer = cls(model).save_every(**trainer_config.get('checkpoint_config'))\n            trainer.load_()\n        else:\n            trainer = cls(model)\n            if 'logger_config' in trainer_config:\n                trainer.build_logger(**trainer_config.get('logger_config'))\n            if 'criterion_config' in trainer_config:\n                trainer.build_criterion(**trainer_config.get('criterion_config'))\n            if 'optimizer_config' in trainer_config:\n                trainer.build_optimizer(**trainer_config.get('optimizer_config'))\n            if 'metric_config' in trainer_config:\n                trainer.build_metric(**trainer_config.get('metric_config'))\n            if 'checkpoint_config' in trainer_config:\n                trainer.save_every(**trainer_config.get('checkpoint_config'))\n            if 'validation_config' in trainer_config:\n                trainer.validate_every(**trainer_config.get('validation_config'))\n            if 'max_num_iterations' in trainer_config:\n                trainer.set_max_num_iterations(trainer_config.get('max_num_iterations'))\n            if 'max_num_epochs' in trainer_config:\n                trainer.set_max_num_epochs(trainer_config.get('max_num_epochs'))\n            if trainer_config.get('use_cuda'):\n                devices = trainer_config.get('use_cuda').get('devices') \\\n                    if isinstance(trainer_config.get('use_cuda'), dict) else None\n                trainer.cuda(devices=devices)\n            if 'training_precision' in trainer_config:\n                trainer.set_precision(trainer_config.get('training_precision'))\n        return trainer\n"
  },
  {
    "path": "inferno/trainers/callbacks/__init__.py",
    "content": "__all__ = ['CallbackEngine', 'Callback', 'Console', 'essentials', 'scheduling', 'gradients']\n\nfrom .base import CallbackEngine, Callback\nfrom .console import Console\nfrom . import essentials\nfrom . import scheduling\nfrom . import gradients\n\ntry:\n    from .tqdm import TQDMProgressBar\n    __all__.append('TQDMProgressBar')\nexcept ImportError:\n    from .tqdmstub import TQDMProgressBar\n"
  },
  {
    "path": "inferno/trainers/callbacks/base.py",
    "content": "from ...utils import python_utils as pyu\n\n\nclass CallbackEngine(object):\n    \"\"\"\n    Gathers and manages callbacks.\n\n    Callbacks are callables which are to be called by trainers when certain events ('triggers')\n    occur. They could be any callable object, but if endowed with a `bind_trainer` method,\n    it's called when the callback is registered. It is recommended that callbacks\n    (or their `__call__` methods) use the double-star syntax for keyword arguments.\n    \"\"\"\n    # Triggers\n    BEGIN_OF_FIT = 'begin_of_fit'\n    END_OF_FIT = 'end_of_fit'\n    BEGIN_OF_TRAINING_RUN = 'begin_of_training_run'\n    END_OF_TRAINING_RUN = 'end_of_training_run'\n    BEGIN_OF_EPOCH = 'begin_of_epoch'\n    END_OF_EPOCH = 'end_of_epoch'\n    BEGIN_OF_TRAINING_ITERATION = 'begin_of_training_iteration'\n    AFTER_MODEL_AND_LOSS_IS_APPLIED = 'after_model_and_loss_is_applied'\n    END_OF_TRAINING_ITERATION = 'end_of_training_iteration'\n    BEGIN_OF_VALIDATION_RUN = 'begin_of_validation_run'\n    END_OF_VALIDATION_RUN = 'end_of_validation_run'\n    BEGIN_OF_VALIDATION_ITERATION = 'begin_of_validation_iteration'\n    END_OF_VALIDATION_ITERATION = 'end_of_validation_iteration'\n    BEGIN_OF_SAVE = 'begin_of_save'\n    END_OF_SAVE = 'end_of_save'\n\n    TRIGGERS = {BEGIN_OF_FIT,\n                END_OF_FIT,\n                BEGIN_OF_TRAINING_RUN,\n                END_OF_TRAINING_RUN,\n                BEGIN_OF_EPOCH,\n                END_OF_EPOCH,\n                BEGIN_OF_TRAINING_ITERATION,\n                AFTER_MODEL_AND_LOSS_IS_APPLIED,\n                END_OF_TRAINING_ITERATION,\n                BEGIN_OF_VALIDATION_RUN,\n                END_OF_VALIDATION_RUN,\n                BEGIN_OF_VALIDATION_ITERATION,\n                END_OF_VALIDATION_ITERATION,\n                BEGIN_OF_SAVE,\n                END_OF_SAVE}\n\n    def __init__(self):\n        self._trainer = None\n        self._callback_registry = {trigger: set() for trigger in 
self.TRIGGERS}\n        self._last_known_epoch = None\n        self._last_known_iteration = None\n\n    def register_new_trigger(self, trigger_name):\n        self.TRIGGERS.add(trigger_name)\n        self._callback_registry.update({trigger_name: set()})\n\n    def bind_trainer(self, trainer):\n        self._trainer = trainer\n        return self\n\n    def unbind_trainer(self):\n        self._trainer = None\n        return self\n\n    @property\n    def trainer_is_bound(self):\n        return self._trainer is not None\n\n    def register_callback(self, callback, trigger='auto', bind_trainer=True):\n        assert callable(callback)\n        # Automatic callback registration based on their methods\n        if trigger == 'auto':\n            automatic_registration_successful = False\n            for trigger in self.TRIGGERS:\n                if pyu.has_callable_attr(callback, trigger):\n                    automatic_registration_successful = True\n                    self.register_callback(callback, trigger, bind_trainer)\n            assert automatic_registration_successful, \\\n                \"Callback could not be auto-registered: no triggers recognized.\"\n            return self\n        # Validate triggers\n        assert trigger in self.TRIGGERS\n        # Add to callback registry\n        self._callback_registry.get(trigger).add(callback)\n        # Register trainer with the callback if required\n        bind_trainer_to_callback = self.trainer_is_bound and \\\n                                   bind_trainer and \\\n                                   pyu.has_callable_attr(callback, 'bind_trainer')\n        if bind_trainer_to_callback:\n            callback.bind_trainer(self._trainer)\n        return self\n\n    def rebind_trainer_to_all_callbacks(self):\n        # FIXME This makes bind_trainer in register_callback reduntant,\n        # especially if used by the trainer class, so... 
deprecate bind_traner.\n        for callbacks_at_trigger in self._callback_registry.values():\n            for callback in callbacks_at_trigger:\n                # Register trainer with the callback if required\n                bind_trainer_to_callback = self.trainer_is_bound and \\\n                                           pyu.has_callable_attr(callback, 'bind_trainer')\n                if bind_trainer_to_callback:\n                    callback.bind_trainer(self._trainer)\n\n    def call(self, trigger, **kwargs):\n        assert trigger in self.TRIGGERS\n        kwargs.update({'trigger': trigger})\n        for callback in self._callback_registry.get(trigger):\n            callback(**kwargs)\n\n    def get_config(self):\n        # Pop trainer\n        config_dict = dict(self.__dict__)\n        config_dict.update({'_trainer': None})\n        return config_dict\n\n    def set_config(self, config_dict):\n        self.__dict__.update(config_dict)\n        return self\n\n    def __getstate__(self):\n        return self.get_config()\n\n    def __setstate__(self, state):\n        self.set_config(state)\n\n\nclass Callback(object):\n    \"\"\"Recommended (but not required) base class for callbacks.\"\"\"\n    def __init__(self):\n        self._trainer = None\n        self._debugging = False\n        self.register_instance(self)\n\n    @classmethod\n    def register_instance(cls, instance):\n        if hasattr(cls, '_instance_registry') and instance not in cls._instance_registry:\n            cls._instance_registry.append(instance)\n        else:\n            cls._instance_registry = [instance]\n\n    @classmethod\n    def get_instances(cls):\n        if hasattr(cls, '_instance_registry'):\n            return pyu.from_iterable(cls._instance_registry)\n        else:\n            return None\n\n    @property\n    def trainer(self):\n        return self._trainer\n\n    def bind_trainer(self, trainer):\n        self._trainer = trainer\n        return self\n\n    def 
unbind_trainer(self):\n        self._trainer = None\n        return self\n\n    def __call__(self, **kwargs):\n        if 'trigger' in kwargs:\n            if hasattr(self, kwargs.get('trigger')) and \\\n                    callable(getattr(self, kwargs.get('trigger'))):\n                getattr(self, kwargs.get('trigger'))(**kwargs)\n\n    def get_config(self):\n        config_dict = dict(self.__dict__)\n        config_dict.update({'_trainer': None})\n        return config_dict\n\n    def set_config(self, config_dict):\n        self.__dict__.update(config_dict)\n        return self\n\n    def __getstate__(self):\n        return self.get_config()\n\n    def __setstate__(self, state):\n        self.set_config(state)\n\n    def toggle_debug(self):\n        self._debugging = not self._debugging\n        return self\n\n    def debug_print(self, message):\n        if self._debugging:\n            self.trainer.console.debug(\"[{}] {}\".format(type(self).__name__, message))\n"
  },
  {
    "path": "inferno/trainers/callbacks/console.py",
    "content": "from datetime import datetime\nfrom .base import Callback\n\nclass StdoutPrinter(object):\n    def print(self, message):\n        print(\"[+][{}] {}\".format(str(datetime.now()), message))\n\n\nclass Console(object):\n    LEVEL_INFO = 1\n    LEVEL_PROGRESS = 2\n    LEVEL_WARNING = 3\n    LEVEL_DEBUG = 4\n\n    def __init__(self, printer=StdoutPrinter()):\n        self._printer = printer\n        self._enabled = {self.LEVEL_INFO, self.LEVEL_PROGRESS, self.LEVEL_WARNING}\n\n    def set_console(self, console):\n        self._printer = console\n\n    def _print(self, message, level):\n        if level not in self._enabled:\n            return\n\n        self._printer.print(message)\n\n    def info(self, message):\n        self._print(\"[INFO    ] \" + message, self.LEVEL_INFO)\n\n    def print(self, message):\n        self.info(message)\n\n    def progress(self, message):\n        self._print(\"[PROGRESS] \" + message, self.LEVEL_PROGRESS)\n\n    def warning(self, message):\n        self._print(\"[WARNING ] \" + message, self.LEVEL_WARNING)\n\n    def debug(self, message):\n        self._print(\"[DEBUG   ] \" + message, self.LEVEL_DEBUG)\n\n    def _toggle(self, level, state):\n        if state:\n            self._enabled.add(level)\n        else:\n            if level in self._enabled:\n                self._enabled.remove(level)\n\n    def toggle_info(self, state):\n        self._toggle(self.LEVEL_INFO, state)\n\n    def toggle_progress(self, state):\n        self._toggle(self.LEVEL_PROGRESS, state)\n\n    def toggle_warning(self, state):\n        self._toggle(self.LEVEL_WARNING, state)\n\n\n\nclass ShowMinimalConsoleInfo(Callback):\n    \"\"\"\n    Callback to show only minimum training info on console \n    viz. 
current epoch number, current learning rate,\n    training loss and training error if exists.\n    \"\"\"\n    def __init__(self, *args, **kwargs):\n        super(ShowMinimalConsoleInfo, self).__init__(*args, **kwargs)\n\n    def begin_of_fit(self,**_):\n        self.trainer.quiet()\n\n    def end_of_epoch(self, **_):\n        training_loss = self.trainer.get_state('training_loss')\n        training_error = self.trainer.get_state('training_error')\n        learning_rate = self.trainer.get_state('learning_rate')\n\n        self.trainer.console.info(\"--------------------------------\")\n        self.trainer.console.info(\"Epoch \"+str(self.trainer.epoch_count))\n        if training_loss is not None:\n            self.trainer.console.info(\"Train Loss \"+str(training_loss.item()))\n        if training_error is not None:\n            self.trainer.console.info(\"Train Error \"+str(training_error.item()))\n        self.trainer.console.info(\"Current LR \"+str(learning_rate))"
  },
  {
    "path": "inferno/trainers/callbacks/essentials.py",
    "content": "import numpy as np\nimport os\nimport h5py as h5\nfrom ...utils import torch_utils as tu\nfrom ...utils.train_utils import Frequency\nfrom ...utils.exceptions import assert_, FrequencyValueError, NotUnwrappableError\nfrom ...utils import python_utils as pyu\nfrom .base import Callback\nimport gc\n\n\nclass NaNDetector(Callback):\n    def end_of_training_iteration(self, **_):\n        training_loss = self.trainer.get_state('training_loss')\n        # Extract scalar\n        if tu.is_tensor(training_loss):\n            training_loss = tu.unwrap(training_loss, extract_item=True)\n        if not np.isfinite(training_loss):\n            raise RuntimeError(\"Loss is not finite (loss={})!\".format(training_loss))\n\n\nclass PersistentSave(Callback):\n    def __init__(self, template='checkpoint.pytorch.epoch{epoch_count}.iteration{iteration_count}'):\n        super(PersistentSave, self).__init__()\n        self.template = template\n\n    def begin_of_save(self, **kwargs):\n        self._orig_checkpoint_filename = self.trainer._checkpoint_filename\n        self.trainer._checkpoint_filename = self.template.format(**kwargs)\n\n    def end_of_save(self, save_to_directory, **_):\n        orig_checkpoint_path = os.path.join(save_to_directory, self._orig_checkpoint_filename)\n\n        if os.path.lexists(orig_checkpoint_path):\n            os.remove(orig_checkpoint_path)\n        os.symlink(self.trainer._checkpoint_filename, orig_checkpoint_path)\n\n        self.trainer._checkpoint_filename = self._orig_checkpoint_filename\n\n\nclass DumpHDF5Every(Callback):\n    \"\"\"Dumps intermediate training states to a HDF5 file.\"\"\"\n    def __init__(self, frequency, to_directory,\n                 filename_template='dump.{mode}.epoch{epoch_count}.iteration{iteration_count}.h5',\n                 force_dump=False, dump_after_every_validation_run=False):\n        super(DumpHDF5Every, self).__init__()\n        # Privates\n        self._dump_every = None\n        
self._trainer_states_to_be_dumped_while_training = {'training_inputs',\n                                                            'training_target',\n                                                            'training_prediction'}\n        self._trainer_states_to_be_dumped_while_validating = {'validation_inputs',\n                                                              'validation_target',\n                                                              'validation_prediction'}\n        self._dump_cache = {}\n        # Publics\n        self.dump_every = frequency\n        self.dump_directory = to_directory\n        self.dump_filename_template = filename_template\n        self.force_dump = force_dump    # hihi\n        self.dump_after_every_validation_run = dump_after_every_validation_run\n\n    @property\n    def dump_every(self):\n        return self._dump_every\n\n    @dump_every.setter\n    def dump_every(self, value):\n        self._dump_every = Frequency.build_from(value)\n        assert_(self._dump_every.is_consistent,\n                \"Dump frequency is not consistent.\",\n                FrequencyValueError)\n\n    @property\n    def dump_now(self):\n        return self.dump_every.match(iteration_count=self.trainer.iteration_count,\n                                     epoch_count=self.trainer.epoch_count,\n                                     persistent=True, match_zero=True)\n\n    def add_to_dump_cache(self, key, value):\n        if pyu.is_listlike(value):\n            for value_num, _value in enumerate(value):\n                self.add_to_dump_cache(\"{}_{}\".format(key, value_num), _value)\n        else:\n            self._dump_cache.update({key: value})\n\n    def clear_dump_cache(self):\n        self._dump_cache.clear()\n\n    def dump_state(self, key, dump_while='training'):\n        # Validate arguments\n        keyword_mapping = {'train': 'training',\n                           'training': 'training',\n                           
'validation': 'validating',\n                           'validating': 'validating'}\n        dump_while = keyword_mapping.get(dump_while)\n        assert_(dump_while is not None,\n                \"The keyword dump_while must be one of: {}.\"\n                .format(set(keyword_mapping.keys())),\n                ValueError)\n        assert_(isinstance(key, str),\n                \"State key must be a string, got {} instead.\".format(type(key).__name__),\n                TypeError)\n        # Add to set of observed states\n        if dump_while == 'training':\n            self._trainer_states_to_be_dumped_while_training.add(key)\n        elif dump_while == 'validating':\n            self._trainer_states_to_be_dumped_while_validating.add(key)\n        else:\n            raise NotImplementedError\n        return self\n\n    def dump_states(self, keys, dump_while='training'):\n        for key in keys:\n            self.dump_state(key, dump_while=dump_while)\n        return self\n\n    def get_file_path(self, mode):\n        # Make sure the dump directory exists\n        if not os.path.exists(self.dump_directory):\n            os.mkdir(self.dump_directory)\n        else:\n            assert_(os.path.isdir(self.dump_directory),\n                    \"Dump directory {} is a file.\".format(self.dump_directory),\n                    FileExistsError)\n        filename = self.dump_filename_template.format(epoch_count=self.trainer.epoch_count,\n                                                      iteration_count=self.trainer.iteration_count,\n                                                      mode=mode)\n        return os.path.join(self.dump_directory, filename)\n\n    def dump(self, mode):\n        with h5.File(name=self.get_file_path(mode), mode='w') as h5_file:\n            for key, to_dump in self._dump_cache.items():\n                if to_dump is None:\n                    continue\n                try:\n                    to_dump = tu.unwrap(to_dump, 
as_numpy=True)\n                except NotUnwrappableError:\n                    # Can't unwrap to_dump, but let's not throw a tantrum if we're not required to\n                    if not self.force_dump:\n                        continue\n                    else:\n                        raise\n                # Do the dumpin'\n                h5_file.create_dataset(name=key, data=to_dump)\n\n    def end_of_training_iteration(self, **_):\n        dump_now = self.dump_now\n        if dump_now:\n            # To be double sure\n            self.clear_dump_cache()\n            # Get object to dump\n            for state_name in self._trainer_states_to_be_dumped_while_training:\n                self.add_to_dump_cache(state_name, self.trainer.get_state(state_name))\n            # Dump\n            self.dump(mode='training')\n            # Clear cache\n            self.clear_dump_cache()\n\n    def end_of_validation_run(self, **_):\n        if self.dump_after_every_validation_run:\n            # To be double sure\n            self.clear_dump_cache()\n            # Get object to dump\n            for state_name in self._trainer_states_to_be_dumped_while_validating:\n                self.add_to_dump_cache(state_name, self.trainer.get_state(state_name))\n            # Dump\n            self.dump(mode='validation')\n            # Clear cache\n            self.clear_dump_cache()\n\n\nclass SaveAtBestValidationScore(Callback):\n    \"\"\"\n    Triggers a save at the best EMA (exponential moving average) validation score.\n    The basic `Trainer` has built in support for saving at the best validation score, but this\n    callback might eventually replace that functionality.\n    \"\"\"\n    def __init__(self, smoothness=0, verbose=False):\n        super(SaveAtBestValidationScore, self).__init__()\n        # Privates\n        self._ema_validation_score = None\n        self._best_ema_validation_score = None\n        # Publics\n        self.smoothness = smoothness\n        
self.verbose = verbose\n\n    def end_of_validation_run(self, **_):\n        # Get score (i.e. validation error if available, else validation loss)\n        current_validation_score = self.trainer.get_state('validation_error_averaged')\n        current_validation_score = self.trainer.get_state('validation_loss_averaged') \\\n            if current_validation_score is None else current_validation_score\n        # Maintain ema\n        if self._ema_validation_score is None:\n            self._ema_validation_score = current_validation_score\n            self._best_ema_validation_score = current_validation_score\n            # If no previous score is known, assume this is the best score and save\n            self.trainer._is_iteration_with_best_validation_score = True\n        else:\n            self._ema_validation_score = self.smoothness * self._ema_validation_score + \\\n                                         (1 - self.smoothness) * current_validation_score\n            # This overrides the default behaviour, but reduces to it if smoothness = 0\n            self.trainer._is_iteration_with_best_validation_score = \\\n                self._ema_validation_score < self._best_ema_validation_score\n        # Trigger a save\n        if self.trainer._is_iteration_with_best_validation_score:\n            if self.verbose:\n                self.trainer.console.info(\"Current smoothed validation score {} is better \"\n                                          \"than the best smoothed validation score {}.\"\n                                          .format(self._ema_validation_score,\n                                                  self._best_ema_validation_score))\n            self._best_ema_validation_score = self._ema_validation_score\n            self.trainer.save_now = True\n        else:\n            if self.verbose:\n                self.trainer.console.info(\"Current smoothed validation score {} is not better \"\n                                          \"than the 
best smoothed validation score {}.\"\n                                          .format(self._ema_validation_score,\n                                                  self._best_ema_validation_score))\n        # Done\n\n\nclass ParameterEMA(Callback):\n    \"\"\"Maintain a moving average of network parameters.\"\"\"\n    def __init__(self, momentum):\n        \"\"\"\n        Parameters\n        ----------\n        momentum : float\n            Momentum for the moving average. The following holds:\n            `new_moving_average = momentum * old_moving_average + (1 - momentum) * value`\n        \"\"\"\n        super(ParameterEMA, self).__init__()\n        # Privates\n        self._parameters = None\n        # Publics\n        self.momentum = momentum\n\n    def maintain(self):\n        if self._parameters is None:\n            self._parameters = [p.data.new().zero_() for p in self.trainer.model.parameters()]\n        for p_model, p_ema in zip(self.trainer.model.parameters(), self._parameters):\n            p_ema.mul_(self.momentum).add_(p_model.data.mul(1. 
- self.momentum))\n\n    def apply(self):\n        assert_(self._parameters is not None,\n                \"Can't apply parameter EMA's: not available.\",\n                ValueError)\n        for p_model, p_ema in zip(self.trainer.model.parameters(), self._parameters):\n            p_model.data.copy_(p_ema)\n\n    def end_of_training_iteration(self, **_):\n        self.maintain()\n\n\nclass GradientClip(Callback):\n    def __init__(self, clip_value=None, clip_norm=None):\n        super(GradientClip, self).__init__()\n        assert_(not (clip_value is None and clip_norm is None),\n                \"Must provide either clip_value or clip_norm.\",\n                ValueError)\n        assert_(clip_value is None or clip_norm is None,\n                f\"Must provide only one, but not both: \"\n                f\"clip_value ({clip_value}) or clip_norm ({clip_norm}).\",\n                RuntimeError)\n        self._clip_value = clip_value\n        self._clip_norm = clip_norm\n\n    @property\n    def mode(self):\n        return 'value' if self._clip_value is not None else 'norm'\n\n    @property\n    def norm_or_value(self):\n        return self._clip_value if self._clip_value is not None else self._clip_norm\n\n    def after_model_and_loss_is_applied(self, **_):\n        tu.clip_gradients_(self.trainer.model.parameters(), self.mode, self.norm_or_value)\n\n\nclass GarbageCollection(Callback):\n    \"\"\"\n    Callback that triggers garbage collection at the end of every\n    training iteration in order to reduce the memory footprint of training\n    \"\"\"\n\n    def end_of_training_iteration(self, **_):\n        gc.collect()\n"
  },
  {
    "path": "inferno/trainers/callbacks/gradients.py",
    "content": "from ...utils.train_utils import Frequency\nfrom ...utils.exceptions import assert_, FrequencyValueError\nfrom .base import Callback\n\n\nclass LogOutputGradients(Callback):\n    \"\"\"Logs the gradient of the network output\"\"\"\n\n    def __init__(self, frequency):\n        super(LogOutputGradients, self).__init__()\n        self.log_every = frequency\n        self.registered = False\n        self.hook_handle = None\n\n    @property\n    def log_every(self):\n        return self._log_every\n\n    @log_every.setter\n    def log_every(self, value):\n        self._log_every = Frequency(value, 'iterations')\n        assert_(self.log_every.is_consistent,\n                \"Log frequency is not consistent.\",\n                FrequencyValueError)\n\n    def hook(self, module, grad_input, grad_output):\n\n        #remove hook if trainer does not exits\n        if self.trainer is None:\n            self.hook_handle.remove()\n            return\n\n        if self.log_every.match(iteration_count=self.trainer.iteration_count,\n                                epoch_count=self.trainer.epoch_count,\n                                persistent=True, match_zero=True):\n            self.trainer.update_state('output_gradient', grad_output[0].detach().float().clone().cpu())\n\n    def add_hook(self):\n        self.hook_handle = self.trainer.model.register_backward_hook(self.hook)\n\n    def begin_of_fit(self, **kwargs):\n        self._trainer.logger.observe_state(\"output_gradient\",\n                                           observe_while='training')\n        self.add_hook()\n\n    def begin_of_save(self, **_):\n        # remove hook from model, because you can't pickle it.\n        if self.hook_handle is not None:\n            self.hook_handle.remove()\n            self.hook_handle = None\n\n    def end_of_save(self, **_):\n        # add hook after model save\n        self.add_hook()\n\n"
  },
  {
    "path": "inferno/trainers/callbacks/logging/__init__.py",
    "content": "__all__ = ['get_logger']\ntry:\n    INFERNO_WITH_TENSORBOARD_LOGGER = True\n    from .tensorboard import TensorboardLogger\n    __all__.append('TensorboardLogger')\nexcept ImportError:\n    INFERNO_WITH_TENSORBOARD_LOGGER = False\n\n\ndef get_logger(name):\n    if name in globals():\n        return globals().get(name)\n    else:\n        raise NotImplementedError(\"Logger not found.\")\n"
  },
  {
    "path": "inferno/trainers/callbacks/logging/base.py",
    "content": "import os\nfrom ..base import Callback\n\n\nclass Logger(Callback):\n    \"\"\"\n    A special callback for logging.\n\n    Loggers are special because they're required to be serializable, whereas other\n    callbacks have no such guarantees. In this regard, they jointly handled by\n    trainers and the callback engine.\n    \"\"\"\n    def __init__(self, log_directory=None):\n        super(Logger, self).__init__()\n        self._log_directory = None\n        if log_directory is not None:\n            self.set_log_directory(log_directory)\n\n    @property\n    def log_directory(self):\n        if self._log_directory is not None:\n            return self._log_directory\n        elif self.trainer is not None and self.trainer._log_directory is not None:\n            return self.trainer._log_directory\n        else:\n            raise RuntimeError(\"No log directory found.\")\n\n    @log_directory.setter\n    def log_directory(self, value):\n        self.set_log_directory(value)\n\n    def set_log_directory(self, log_directory):\n        assert isinstance(log_directory, str)\n        if not os.path.isdir(log_directory):\n            assert not os.path.exists(log_directory)\n            os.makedirs(log_directory)\n        self._log_directory = log_directory\n        return self\n"
  },
  {
    "path": "inferno/trainers/callbacks/logging/tensorboard.py",
    "content": "import warnings\nimport numpy as np\nfrom torch.utils.tensorboard import SummaryWriter\n\nfrom .base import Logger\nfrom ....utils import torch_utils as tu\nfrom ....utils import python_utils as pyu\nfrom ....utils import train_utils as tru\nfrom ....utils.exceptions import assert_\n\n\nclass TaggedImage(object):\n    def __init__(self, array, tag):\n        self.array = array\n        self.tag = tag\n\n\nclass TensorboardLogger(Logger):\n    \"\"\"Class to enable logging of training progress to Tensorboard.\n\n    Currently supports logging scalars and images.\n    \"\"\"\n    # This is hard coded because tensorboardX doesn't have a __version__\n    _TENSORBOARDX_IMAGE_FORMAT = 'CHW'\n    _DEBUG = False\n\n    def __init__(self, log_directory=None,\n                 log_scalars_every=None, log_images_every=None, log_histograms_every=None,\n                 send_image_at_batch_indices='all', send_image_at_channel_indices='all',\n                 send_volume_at_z_indices='mid'):\n        \"\"\"\n        Parameters\n        ----------\n        log_directory : str\n            Path to the directory where the log files will be placed.\n        log_scalars_every : str or tuple or inferno.utils.train_utils.Frequency\n            How often scalars should be logged to Tensorboard. By default, once every iteration.\n        log_images_every : str or tuple or inferno.utils.train_utils.Frequency\n            How often images should be logged to Tensorboard. By default, once every epoch.\n        log_histograms_every : str or tuple or inferno.utils.train_utils.Frequency\n            How often histograms should be logged to Tensorboard. By default, never.\n        send_image_at_batch_indices : list or str\n            The indices of the batches to be logged. An `image_batch` usually has the shape\n            (num_samples, num_channels, num_rows, num_cols). 
By setting this argument to say\n            [0, 2], only images corresponding to `image_batch[0]` and `image_batch[2]` are\n            logged. When a str, it should be 'all', in which case, all samples are logged.\n        send_image_at_channel_indices : list or str\n            Similar to `send_image_at_batch_indices`, but applying to channels.\n        send_volume_at_z_indices : list or str\n            For 3D batches of shape (num_samples, num_channels, num_z_slices, num_rows, num_cols),\n            select the indices of the z slices to be logged. When a str, it could be 'all' or\n            'mid' (to log the central z slice).\n\n        Warnings\n        --------\n        Leaving log_images_every to the default (i.e. once every iteration) might generate a\n        large logfile and/or slow down the training.\n        \"\"\"\n        super(TensorboardLogger, self).__init__(log_directory=log_directory)\n        self._log_scalars_every = None\n        self._log_images_every = None\n        self._log_histograms_every = None\n        self._writer = None\n        self._config = {'image_batch_indices': send_image_at_batch_indices,\n                        'image_channel_indices': send_image_at_channel_indices,\n                        'volume_z_indices': send_volume_at_z_indices}\n        # We ought to know the trainer states we're observing (and plotting to tensorboard).\n        # These are the defaults.\n        self._trainer_states_being_observed_while_training = {'training_loss',\n                                                              'training_error',\n                                                              'training_prediction',\n                                                              'training_inputs',\n                                                              'training_target',\n                                                              'learning_rate'}\n        self._trainer_states_being_observed_while_validating = 
{'validation_error_averaged',\n                                                                'validation_loss_averaged'}\n        if log_scalars_every is not None:\n            self.log_scalars_every = log_scalars_every\n        if log_images_every is not None:\n            self.log_images_every = log_images_every\n        if log_histograms_every is not None:\n            self.log_histograms_every = log_histograms_every\n\n    @property\n    def writer(self):\n        if self._writer is None:\n            self._writer = SummaryWriter(self.log_directory)\n        return self._writer\n\n    @property\n    def log_scalars_every(self):\n        if self._log_scalars_every is None:\n            self._log_scalars_every = tru.Frequency(1, 'iterations')\n        return self._log_scalars_every\n\n    @log_scalars_every.setter\n    def log_scalars_every(self, value):\n        self._log_scalars_every = tru.Frequency.build_from(value)\n\n    @property\n    def log_scalars_now(self):\n        # Using persistent=True in a property getter is probably not a very good idea...\n        # We need to make sure that this getter is called only once per callback-call.\n        return self.log_scalars_every.match(iteration_count=self.trainer.iteration_count,\n                                            epoch_count=self.trainer.epoch_count,\n                                            persistent=True)\n\n    @property\n    def log_images_every(self):\n        if self._log_images_every is None:\n            self._log_images_every = tru.Frequency(1, 'epochs')\n        return self._log_images_every\n\n    @log_images_every.setter\n    def log_images_every(self, value):\n        self._log_images_every = tru.Frequency.build_from(value)\n\n    @property\n    def log_images_now(self):\n        # Using persistent=True in a property getter is probably not a very good idea...\n        # We need to make sure that this getter is called only once per callback-call.\n        return 
self.log_images_every.match(iteration_count=self.trainer.iteration_count,\n                                           epoch_count=self.trainer.epoch_count,\n                                           persistent=True)\n\n    @property\n    def log_histograms_every(self):\n        if self._log_histograms_every is None:\n            self._log_histograms_every = tru.Frequency('never')\n        return self._log_histograms_every\n\n    @log_histograms_every.setter\n    def log_histograms_every(self, value):\n        self._log_histograms_every = tru.Frequency.build_from(value)\n\n    @property\n    def log_histograms_now(self):\n        # Using persistent=True in a property getter is probably not a very good idea...\n        # We need to make sure that this getter is called only once per callback-call.\n        return self.log_histograms_every.match(iteration_count=self.trainer.iteration_count,\n                                               epoch_count=self.trainer.epoch_count,\n                                               persistent=True)\n\n    def observe_state(self, key, observe_while='training'):\n        # Validate arguments\n        keyword_mapping = {'train': 'training',\n                           'training': 'training',\n                           'validation': 'validating',\n                           'validating': 'validating'}\n        observe_while = keyword_mapping.get(observe_while)\n        assert_(observe_while is not None,\n                \"The keyword observe_while must be one of: {}.\"\n                .format(set(keyword_mapping.keys())),\n                ValueError)\n        assert_(isinstance(key, str),\n                \"State key must be a string, got {} instead.\".format(type(key).__name__),\n                TypeError)\n        # Add to set of observed states\n        if observe_while == 'training':\n            self._trainer_states_being_observed_while_training.add(key)\n        elif observe_while == 'validating':\n            
self._trainer_states_being_observed_while_validating.add(key)\n        else:\n            raise NotImplementedError\n        return self\n\n    def unobserve_state(self, key, observe_while='training'):\n        if observe_while == 'training':\n            self._trainer_states_being_observed_while_training.remove(key)\n        elif observe_while == 'validating':\n            self._trainer_states_being_observed_while_validating.remove(key)\n        else:\n            raise NotImplementedError\n        return self\n\n    def unobserve_states(self, keys, observe_while='training'):\n        for key in keys:\n            self.unobserve_state(key, observe_while=observe_while)\n        return self\n\n    def observe_training_and_validation_state(self, key):\n        for mode in ['training', 'validation']:\n            self.observe_state('{}_{}'.format(mode, key), observe_while=mode)\n\n    def observe_states(self, keys, observe_while='training'):\n        for key in keys:\n            self.observe_state(key, observe_while=observe_while)\n        return self\n\n    def observe_training_and_validation_states(self, keys):\n        for key in keys:\n            self.observe_training_and_validation_state(key)\n        return self\n\n    def log_object(self, tag, object_,\n                   allow_scalar_logging=True,\n                   allow_image_logging=True,\n                   allow_histogram_logging=True):\n        assert isinstance(tag, str)\n        if isinstance(object_, (list, tuple)):\n            for object_num, _object in enumerate(object_):\n                self.log_object(\"{}_{}\".format(tag, object_num),\n                                _object,\n                                allow_scalar_logging,\n                                allow_image_logging,\n                                allow_histogram_logging)\n            return\n\n        # Check whether object is a scalar\n        if tu.is_scalar_tensor(object_) and allow_scalar_logging:\n            # Log 
scalar\n            value = tu.unwrap(object_.float(), extract_item=True)\n            self.log_scalar(tag, value, step=self.trainer.iteration_count)\n        elif isinstance(object_, (float, int)) and allow_scalar_logging:\n            value = float(object_)\n            self.log_scalar(tag, value, step=self.trainer.iteration_count)\n        elif tu.is_label_image_or_volume_tensor(object_) and allow_image_logging:\n            # Add a channel axis and log as images\n            self.log_image_or_volume_batch(tag, object_[:, None, ...],\n                                           self.trainer.iteration_count)\n        elif tu.is_image_or_volume_tensor(object_):\n            if allow_image_logging:\n                # Log images\n                self.log_image_or_volume_batch(tag, object_, self.trainer.iteration_count)\n        elif tu.is_vector_tensor(object_) and allow_histogram_logging:\n            # Log histograms\n            values = tu.unwrap(object_, as_numpy=True)\n            self.log_histogram(tag, values, self.trainer.iteration_count)\n        else:\n            # Object is neither a scalar nor an image nor a vector, there's nothing we can do\n            if tu.is_tensor(object_) and self._DEBUG:\n                # Throw a warning when in debug mode.\n                warnings.warn(\"Unsupported attempt to log tensor `{}` of shape `{}`\".format(tag, object_.size()))\n\n    def end_of_training_iteration(self, **_):\n        log_scalars_now = self.log_scalars_now\n        log_images_now = self.log_images_now\n        log_histograms_now = self.log_histograms_now\n        if not log_scalars_now and not log_images_now:\n            # Nothing to log, so we won't bother\n            return\n        # Read states\n        for state_key in self._trainer_states_being_observed_while_training:\n            state = self.trainer.get_state(state_key, default=None)\n            if state is None:\n                # State not found in trainer but don't throw a hissy fit\n  
              continue\n            self.log_object(state_key, state,\n                            allow_scalar_logging=log_scalars_now,\n                            allow_image_logging=log_images_now,\n                            allow_histogram_logging=log_histograms_now)\n\n    def end_of_validation_run(self, **_):\n        # Log everything\n        # Read states\n        for state_key in self._trainer_states_being_observed_while_validating:\n            state = self.trainer.get_state(state_key, default=None)\n            if state is None:\n                # State not found in trainer but don't throw a hissy fit\n                continue\n            self.log_object(state_key, state,\n                            allow_scalar_logging=True,\n                            allow_image_logging=True,\n                            allow_histogram_logging=False)\n\n    def _tag_image(self, image, base_tag, prefix=None, instance_num=None, channel_num=None,\n                   slice_num=None):\n        tag = base_tag\n        if prefix is not None:\n            tag = '{}/{}'.format(base_tag, prefix)\n        if instance_num is not None:\n            tag = '{}/instance_{}'.format(tag, instance_num)\n        if channel_num is not None:\n            tag = '{}/channel_{}'.format(tag, channel_num)\n        if slice_num is not None:\n            tag = '{}/slice_{}'.format(tag, slice_num)\n        return TaggedImage(image, tag)\n\n    def extract_images_from_batch(self, batch, base_tag=None, prefix=None):\n        if base_tag is None:\n            assert_(prefix is None,\n                    \"`base_tag` is not provided - `prefix` must be None in this case.\",\n                    ValueError)\n        # Special case when batch is a list or tuple of batches\n        if isinstance(batch, (list, tuple)):\n            image_list = []\n            for batch_num, _batch in enumerate(batch):\n                image_list.extend(\n                    self.extract_images_from_batch(_batch, 
base_tag=base_tag,\n                                                   prefix='batch_{}'.format(batch_num)))\n            return image_list\n        # `batch` really is a tensor from now on.\n        batch_is_image_tensor = tu.is_image_tensor(batch)\n        batch_is_volume_tensor = tu.is_volume_tensor(batch)\n        assert batch_is_volume_tensor != batch_is_image_tensor, \\\n            \"Batch must either be a image or a volume tensor.\"\n        # Convert to numpy\n        batch = batch.float().numpy()\n        # Get the indices of the batches we want to send to tensorboard\n        batch_indices = self._config.get('image_batch_indices', 'all')\n        if batch_indices == 'all':\n            batch_indices = list(range(batch.shape[0]))\n        elif isinstance(batch_indices, (list, tuple)):\n            pass\n        elif isinstance(batch_indices, int):\n            batch_indices = [batch_indices]\n        else:\n            raise NotImplementedError\n        # Get the indices of the channels we want to send to tensorboard\n        channel_indices = self._config.get('image_channel_indices', 'all')\n        if channel_indices == 'all':\n            channel_indices = list(range(batch.shape[1]))\n        elif isinstance(channel_indices, (list, tuple)):\n            pass\n        elif isinstance(channel_indices, int):\n            channel_indices = [channel_indices]\n        else:\n            raise NotImplementedError\n        # Extract images from batch\n        if batch_is_image_tensor:\n            image_list = [(self._tag_image(image,\n                                           base_tag=base_tag, prefix=prefix,\n                                           instance_num=instance_num,\n                                           channel_num=channel_num)\n                           if base_tag is not None else image)\n                          for instance_num, instance in enumerate(batch)\n                          for channel_num, image in enumerate(instance)\n    
                      if instance_num in batch_indices and channel_num in channel_indices]\n        else:\n            assert batch_is_volume_tensor\n            # Trim away along the z axis\n            z_indices = self._config.get('volume_z_indices', 'mid')\n            if z_indices == 'all':\n                z_indices = list(range(batch.shape[2]))\n            elif z_indices == 'mid':\n                z_indices = [batch.shape[2] // 2]\n            elif isinstance(z_indices, (list, tuple)):\n                pass\n            elif isinstance(z_indices, int):\n                z_indices = [z_indices]\n            else:\n                raise NotImplementedError\n            # I'm going to hell for this.\n            image_list = [(self._tag_image(image,\n                                           base_tag=base_tag, prefix=prefix,\n                                           instance_num=instance_num,\n                                           channel_num=channel_num,\n                                           slice_num=slice_num)\n                           if base_tag is not None else image)\n                          for instance_num, instance in enumerate(batch)\n                          for channel_num, volume in enumerate(instance)\n                          for slice_num, image in enumerate(volume)\n                          if instance_num in batch_indices and\n                          channel_num in channel_indices and\n                          slice_num in z_indices]\n        # Done.\n        return image_list\n\n    def log_image_or_volume_batch(self, tag, batch, step=None):\n        assert pyu.is_maybe_list_of(tu.is_image_or_volume_tensor)(batch)\n        step = step or self.trainer.iteration_count\n        image_list = self.extract_images_from_batch(batch, base_tag=tag)\n        self.log_images(tag, image_list, step)\n\n    def log_scalar(self, tag, value, step):\n        \"\"\"\n        Parameter\n        ----------\n        tag : basestring\n       
     Name of the scalar\n        value\n        step : int\n            training iteration\n        \"\"\"\n        self.writer.add_scalar(tag=tag, scalar_value=value, global_step=step)\n\n    def log_images(self, tag, images, step, image_format='CHW'):\n        \"\"\"Logs a list of images.\"\"\"\n        assert_(image_format.upper() in ['CHW', 'HWC'],\n                \"Image format must be either 'CHW' or 'HWC'. Got {} instead.\".format(image_format),\n                ValueError)\n        for image_num, image in enumerate(images):\n            if isinstance(image, TaggedImage):\n                tag = image.tag\n                image = image.array\n            else:\n                tag = \"{}/{}\".format(tag, image_num)\n            # This will fail for the wrong tensorboard version.\n            image = self._order_image_axes(image, image_format, self._TENSORBOARDX_IMAGE_FORMAT)\n            # unfortunately tensorboardX does not have a __version__ attribute\n            # so I don't see how to check for the version and provide backwards\n            # compatability here\n            # tensorboardX borks if the number of image channels is not 3\n            # if image.shape[-1] == 1:\n            #     image = image[..., [0, 0, 0]]\n            image = self._normalize_image(image)\n            # print(image.dtype, image.shape)\n            self.writer.add_image(tag, img_tensor=image, global_step=step)\n\n    @staticmethod\n    def _order_image_axes(image, image_format='CHW', target_format='CHW'):\n        # image axis gymnastics\n        _not_implemented_message = \"target_format must be 'CHW' or 'HCW'.\"\n        if image.ndim == 2:\n            if target_format == 'CHW':\n                # image is 2D - tensorboardX 1.4+ needs a channel axis in the front\n                image = image[None, ...]\n            elif target_format == 'HWC':\n                # image is 2D - tensorboardX 1.3- needs a channel axis in the end\n                image = image[..., None]\n 
           else:\n                raise NotImplementedError(_not_implemented_message)\n        elif image.ndim == 3 and image_format.upper() == 'CHW':\n            if target_format == 'CHW':\n                # Nothing to do here\n                pass\n            elif target_format == 'HCW':\n                # We have a CHW image, but need HWC.\n                image = np.moveaxis(image, 0, 2)\n            else:\n                raise NotImplementedError(_not_implemented_message)\n        elif image.ndim == 3 and image_format.upper() == 'HWC':\n            if target_format == 'CHW':\n                # We have a HWC image, but need CHW\n                image = np.moveaxis(image, 2, 0)\n            elif target_format == 'HWC':\n                # Nothing to do here\n                pass\n            else:\n                raise NotImplementedError(_not_implemented_message)\n        else:\n            raise RuntimeError\n        return image\n\n    @staticmethod\n    def _normalize_image(image):\n        normalized_image = image - image.min()\n        maxval = normalized_image.max()\n        if maxval > 0:\n            normalized_image = normalized_image / maxval\n        return normalized_image\n\n    def log_histogram(self, tag, values, step, bins=1000):\n        \"\"\"Logs the histogram of a list/vector of values.\"\"\"\n        # TODO\n        raise NotImplementedError\n\n    def get_config(self):\n        # Apparently, some SwigPyObject objects cannot be pickled - so we need to build the\n        # writer on the fly.\n        config = super(TensorboardLogger, self).get_config()\n        config.update({'_writer': None})\n        return config\n"
  },
  {
    "path": "inferno/trainers/callbacks/scheduling.py",
    "content": "from ...utils.train_utils import Frequency, Duration, MovingAverage\nfrom ...utils import python_utils as pyu\nfrom ...utils.exceptions import assert_, NotSetError\nfrom .base import Callback\nfrom functools import reduce\n\n\nclass _Scheduler(Callback):\n    def __init__(self, monitor='auto', monitor_momentum=0., monitor_while='auto'):\n        super(_Scheduler, self).__init__()\n        # Privates\n        self._monitor_value_moving_average = MovingAverage(momentum=monitor_momentum)\n        self._monitor_while = 'auto'\n        self._monitor = 'auto'\n        # Publics\n        self.monitor = monitor\n        self.monitor_while = monitor_while\n\n    @property\n    def monitor(self):\n        assert_(self._monitor is not None, \"Monitor is not set yet.\", NotSetError)\n        return self._monitor\n\n    @monitor.setter\n    def monitor(self, value):\n        self._monitor = value\n\n    @property\n    def monitor_value(self):\n        return self.get_monitor_value()[0]\n\n    @property\n    def monitor_while(self):\n        if self._monitor_while == 'auto':\n            monitor_value, monitor = self.get_monitor_value()\n            if monitor.startswith('training_'):\n                self._monitor_while = 'training'\n            elif monitor.startswith('validation_'):\n                self._monitor_while = 'validation'\n            else:\n                raise RuntimeError(\"Could not parse `monitor_while`. 
\"\n                                   \"Please provide one manually.\")\n        return self._monitor_while\n\n    @monitor_while.setter\n    def monitor_while(self, value):\n        value_mapping = {'auto': 'auto',\n                         'training': 'training',\n                         'validation': 'validation',\n                         'validating': 'validation'}\n        value = value_mapping.get(value)\n        assert_(value is not None,\n                \"`monitor_while` must be one of {}, got {} instead.\"\n                .format(value_mapping.keys(), value),\n                ValueError)\n        self._monitor_while = value\n\n    def get_monitor_value(self):\n        if self._monitor == 'auto':\n            # Try to get validation error\n            monitor_value = self.trainer.get_state('validation_error_averaged')\n            if monitor_value is not None:\n                self._monitor = 'validation_error_averaged'\n                return monitor_value, self._monitor\n            monitor_value = self.trainer.get_state('validation_loss_averaged')\n            if monitor_value is not None:\n                self._monitor = 'validation_loss_averaged'\n                return monitor_value, self._monitor\n            monitor_value = self.trainer.get_state('training_error')\n            if monitor_value is not None:\n                self._monitor = 'training_error'\n                return monitor_value, self._monitor\n            monitor_value = self.trainer.get_state('training_loss')\n            if monitor_value is not None:\n                self._monitor = 'training_loss'\n                return monitor_value, self._monitor\n            else:\n                raise RuntimeError(\"Could not auto-fetch a monitor_value. 
\"\n                                   \"Please specify a monitor manually.\")\n        else:\n            monitor_value = self.trainer.get_state(self._monitor)\n            assert_(monitor_value is not None,\n                    \"Could not fetch the specified monitor ('{}') from trainer's state.\"\n                    .format(self._monitor),\n                    ValueError)\n            return monitor_value, self._monitor\n\n    def maintain_monitor_moving_average(self):\n        monitor_value = self.monitor_value\n        self._monitor_value_moving_average.update(monitor_value)\n        return monitor_value\n\n\nclass AutoLR(_Scheduler):\n    \"\"\"\n    Callback to decay or hike the learning rate automatically when a specified monitor\n    stops improving.\n\n    The monitor should be decreasing, i.e. lower value --> better performance.\n    \"\"\"\n    def __init__(self, factor, patience, required_minimum_relative_improvement=0,\n                 consider_improvement_with_respect_to='best',\n                 cooldown_duration=None, monitor='auto', monitor_momentum=0,\n                 monitor_while='auto', exclude_param_groups=None, verbose=False):\n        \"\"\"\n        Parameters\n        ----------\n        factor : float\n            Factor to multiply the learning rate with when out of patience\n            and not in cooldown. Setting `factor < 1` results in a LR decay,\n            whereas setting `factor > 1` results in a LR hike.\n        patience : str or tuple or inferno.utils.train_utils.Duration\n            Specifies how long to wait for an improvement before a LR decay is triggered.\n        required_minimum_relative_improvement : float\n            Specifies by how much (as a fraction of the current value) the monitor should\n            improve to consider the improvement significant. 
Leaving this to zero implies\n            the monitor will be considered improving even if it's only so slightly better.\n        consider_improvement_with_respect_to : {'best', 'previous'}\n            While determining if the monitor has improved, the improvement is considered with\n            respect to this value. Could be 'best' or 'previous'.\n        cooldown_duration: str or tuple or inferno.utils.train_utils.Duration\n            Wait for this duration to resume operation after having decayed LR.\n        monitor : str\n            Specifies the monitor. Monitor must be a trainer state, and decrease with\n            increasing performance. Examples: 'validation_error', 'training_loss'.\n            The monitor can be 'auto' in which case it's recommended that you specify\n            `monitor_while`.\n        monitor_momentum : float\n            A momentum to smooth the monitor history with. Usually recommended to smooth out\n            any fluctuations in the monitor value.\n        monitor_while : {'auto', 'training', 'validating'}\n            Whether to monitor while training or validating. If the monitor is specified\n            (i.e. 
is not 'auto'), this can be left to 'auto'.\n        exclude_param_groups : int or list\n            Parameter groups to __not__ apply the LR decay on.\n        verbose : bool\n            Specifies if a message be printed before decaying.\n        \"\"\"\n        super(AutoLR, self).__init__(monitor=monitor, monitor_momentum=monitor_momentum,\n                                     monitor_while=monitor_while)\n        # Validate\n        assert_(consider_improvement_with_respect_to in ['best', 'previous'],\n                \"`consider_improvement_with_respect_to` must be either 'best' or 'previous', \"\n                \"and not {}\".format(consider_improvement_with_respect_to),\n                ValueError)\n        # Privates\n        self._patience = None\n        self._cooldown = None\n        self._last_decayed_at = {'iteration_count': None, 'epoch_count': None}\n        self._last_improved_at = {'iteration_count': None, 'epoch_count': None}\n        self._best_monitor_value = None\n        # Publics\n        self.patience = patience\n        self.cooldown_duration = cooldown_duration\n        self.factor = factor\n        self.required_minimum_relative_improvement = required_minimum_relative_improvement\n        self.consider_improvement_with_respect_to = consider_improvement_with_respect_to\n        self.exclude_param_groups = pyu.to_iterable(exclude_param_groups) \\\n            if exclude_param_groups is not None else None\n        self.verbose = verbose\n\n    @property\n    def patience(self):\n        assert_(self._patience is not None, \"Patience is not set yet.\", NotSetError)\n        return self._patience\n\n    @patience.setter\n    def patience(self, value):\n        self._patience = Duration.build_from(value)\n\n    @property\n    def cooldown_duration(self):\n        return self._cooldown\n\n    @cooldown_duration.setter\n    def cooldown_duration(self, value):\n        if value is not None:\n            self._cooldown = 
Duration.build_from(value)\n\n    @property\n    def duration_since_last_decay(self):\n        since_last_decayed = {}\n        if self._last_decayed_at.get('iteration_count') is None:\n            since_last_decayed.update({'iteration_count': self.trainer.iteration_count})\n        else:\n            since_last_decayed.update(\n                {'iteration_count': (self.trainer.iteration_count -\n                                     self._last_decayed_at['iteration_count'])\n                 })\n\n        if self._last_decayed_at.get('epoch_count') is None:\n            since_last_decayed.update({'epoch_count': self.trainer.epoch_count})\n        else:\n            since_last_decayed.update(\n                {'epoch_count': (self.trainer.epoch_count -\n                                 self._last_decayed_at['epoch_count'])\n                 })\n        return since_last_decayed\n\n    @property\n    def duration_since_last_improvment(self):\n        since_last_improved = {}\n        if self._last_improved_at.get('iteration_count') is None:\n            since_last_improved.update({'iteration_count': self.trainer.iteration_count})\n        else:\n            since_last_improved.update(\n                {'iteration_count': (self.trainer.iteration_count -\n                                     self._last_improved_at['iteration_count'])\n                 })\n\n        if self._last_improved_at.get('epoch_count') is None:\n            since_last_improved.update({'epoch_count': self.trainer.epoch_count})\n        else:\n            since_last_improved.update(\n                {'epoch_count': (self.trainer.epoch_count -\n                                 self._last_improved_at['epoch_count'])\n                 })\n        return since_last_improved\n\n    @property\n    def out_of_patience(self):\n        return self.patience.match(**self.duration_since_last_improvment)\n\n    @property\n    def in_cooldown(self):\n        if self.cooldown_duration is not None:\n            
return not self.cooldown_duration.match(**self.duration_since_last_decay)\n        else:\n            return False\n\n    def decay(self):\n        exclude_param_groups = \\\n            [] if self.exclude_param_groups is None else list(self.exclude_param_groups)\n        for param_group_num, param_group in enumerate(self.trainer.optimizer.param_groups):\n            if param_group_num not in exclude_param_groups:\n                param_group['lr'] *= self.factor\n                self.debug_print(\"Decayed LR of param_group {} from {} --> {}\"\n                                 .format(param_group_num,\n                                         param_group['lr'] / self.factor,\n                                         param_group['lr']))\n        self._last_decayed_at.update({'iteration_count': self.trainer.iteration_count,\n                                      'epoch_count': self.trainer.epoch_count})\n\n    def maintain_monitor_moving_average(self):\n        monitor_value = super(AutoLR, self).maintain_monitor_moving_average()\n        if self._best_monitor_value is None:\n            self._best_monitor_value = monitor_value\n\n    @property\n    def monitor_value_has_significantly_improved(self):\n        if self._monitor_value_moving_average.previous is None:\n            # There's nothing to compare with\n            return True\n        else:\n            improvement_baseline = \\\n                self._best_monitor_value \\\n                if self.consider_improvement_with_respect_to == 'best' else \\\n                self._monitor_value_moving_average.previous\n            monitor_value_has_significantly_improved = \\\n                self.is_significantly_less_than(self._monitor_value_moving_average.val,\n                                                improvement_baseline,\n                                                self.required_minimum_relative_improvement)\n            self.debug_print(\"Is {} significantly less than {} with min_relative_delta = 
{}? {}.\"\n                             .format(self._monitor_value_moving_average.val,\n                                     improvement_baseline,\n                                     self.required_minimum_relative_improvement,\n                                     monitor_value_has_significantly_improved))\n            # monitor_value_has_significantly_improved could be False, even if the current\n            # moving average is less than the best monitor value, if the improvement is not\n            # significant enough\n            self._best_monitor_value = min([self._best_monitor_value,\n                                            self._monitor_value_moving_average.val])\n            if monitor_value_has_significantly_improved:\n                self._last_improved_at.update({'iteration_count': self.trainer.iteration_count,\n                                               'epoch_count': self.trainer.epoch_count})\n            return monitor_value_has_significantly_improved\n\n    def end_of_training_iteration(self, **_):\n        # Decay if we're not in cooldown (and monitoring while training)\n        if self.monitor_while == 'training':\n            self.maintain_monitor_moving_average()\n            if not self.monitor_value_has_significantly_improved and \\\n                    self.out_of_patience and not self.in_cooldown:\n                if self.verbose:\n                    self.trainer.console.info(\"Monitor '{}' has not significantly improved, decaying LR.\"\n                                       .format(self.monitor))\n                self.decay()\n\n    def end_of_validation_run(self, **_):\n        if self.monitor_while == 'validation':\n            self.maintain_monitor_moving_average()\n            if not self.monitor_value_has_significantly_improved \\\n                    and self.out_of_patience and not self.in_cooldown:\n                if self.verbose:\n                    self.trainer.console.info(\"Monitor '{}' has not significantly 
improved \"\n                                       \"({} vs. {}), decaying LR.\"\n                                       .format(self.monitor,\n                                               self._monitor_value_moving_average.val,\n                                               self._best_monitor_value))\n                self.decay()\n\n    @staticmethod\n    def is_significantly_less_than(x, y, min_relative_delta):\n        eps = 1.e-6\n        if x > y:\n            return False\n        relative_delta = abs(y - x) / (abs(y) + eps)\n        return relative_delta > min_relative_delta\n\n\nclass AutoLRDecay(AutoLR):\n    \"\"\"\n    Callback to decay the learning rate automatically when a specified monitor\n    stops improving.\n\n    The monitor should be decreasing, i.e. lower value --> better performance.\n    \"\"\"\n    pass\n\n\nclass DecaySpec(object):\n    \"\"\"A class to specify when to decay (or hike) LR and by what factor.\"\"\"\n    def __init__(self, duration, factor):\n        # Privates\n        self._matched = False\n        # Publics\n        self.duration = Duration.build_from(duration)\n        self.factor = factor\n\n    def match(self, iteration_count=None, epoch_count=None, when_equal_return=True):\n        match_result = self.duration.match(iteration_count=iteration_count,\n                                           epoch_count=epoch_count,\n                                           when_equal_return=when_equal_return)\n        if match_result and not self._matched:\n            # First match\n            self._matched = True\n            return match_result\n        else:\n            # Already matched once (or more often)\n            return False\n\n    def new(self):\n        return type(self)(self.duration, self.factor)\n\n    @classmethod\n    def build_from(cls, args):\n        if isinstance(args, (list, tuple)):\n            return cls(*args)\n        elif isinstance(args, dict):\n            return cls(**args)\n        elif 
isinstance(args, cls):\n            return args\n        else:\n            raise NotImplementedError(\"Can't build DecaySpec from {}.\".format(type(args)))\n\n\nclass ManualLR(Callback):\n    def __init__(self, decay_specs, exclude_param_groups=None):\n        super(ManualLR, self).__init__()\n        self.decay_specs = [DecaySpec.build_from(decay_spec)\n                            for decay_spec in pyu.to_iterable(decay_specs)]\n        self.exclude_param_groups = pyu.to_iterable(exclude_param_groups) \\\n            if exclude_param_groups is not None else None\n\n    def match(self):\n        # Find the decayspec that matched\n        matched = [decay_spec\n                   for decay_spec in self.decay_specs\n                   if decay_spec.match(iteration_count=self.trainer.iteration_count,\n                                       epoch_count=self.trainer.epoch_count)]\n        if matched:\n            # Allow for more than one matches; in which case the factors are multiplied\n            global_factor = reduce(lambda x, y: x * y,\n                                   [matched_decay_spec.factor for matched_decay_spec in matched])\n            return True, global_factor\n        else:\n            return False, None\n\n    def decay(self, factor):\n        exclude_param_groups = \\\n            [] if self.exclude_param_groups is None else list(self.exclude_param_groups)\n        for param_group_num, param_group in enumerate(self.trainer.optimizer.param_groups):\n            if param_group_num not in exclude_param_groups:\n                param_group['lr'] *= factor\n                self.debug_print(\"Decayed LR of param_group {} from {} --> {}\"\n                                 .format(param_group_num,\n                                         param_group['lr'] / factor,\n                                         param_group['lr']))\n\n    def end_of_training_iteration(self, **_):\n        matched, global_factor = self.match()\n        if matched:\n            
assert global_factor is not None\n            self.decay(global_factor)\n\n\nclass SaveModelRegularly(Callback):\n    \"\"\"saves the network weights in regular intervals\"\"\"\n\n    def __init__(self, frequency):\n        super().__init__()\n        self._save_every = Frequency.build_from(frequency)\n\n    @property\n    def save_now(self):\n        return self._save_every.match(iteration_count=self.trainer.iteration_count,\n                                      epoch_count=self.trainer.epoch_count,\n                                      persistent=True, match_zero=True)\n\n    def end_of_training_iteration(self, **_):\n        if self.save_now:\n            self.trainer.save_model()\n"
  },
  {
    "path": "inferno/trainers/callbacks/tqdm.py",
    "content": "from .base import Callback\nfrom tqdm import tqdm\nfrom datetime import datetime\nfrom .console import Console\n\n\nclass TQDMPrinter(object):\n    def __init__(self, progress):\n        self._progress = progress\n\n    def print(self, message):\n        if self._progress.outer_bar is not None:\n            self._progress.outer_bar.clear()\n        tqdm.write(message)\n        if self._progress.outer_bar is not None:\n            self._progress.outer_bar.refresh()\n\n\nclass TQDMConsole(Console):\n    def __init__(self):\n        super(TQDMConsole, self).__init__(printer=TQDMPrinter(TQDMProgressBar()))\n\n\nclass TQDMProgressBar(Callback):\n    def __init__(self, *args, **kwargs):\n        super(TQDMProgressBar, self).__init__(*args, **kwargs)\n        self.epoch_bar = None\n        self.outer_bar = None\n        self.is_training = False\n        self.is_validation = False\n\n    def bind_trainer(self, *args, **kwargs):\n        super(TQDMProgressBar, self).bind_trainer(*args, **kwargs)\n        self.trainer.console.toggle_progress(False)\n        self.trainer.console.set_console(TQDMPrinter(self))\n\n    def _init_epoch_bar_train(self):\n        n_batch = len(self.trainer._loader_iters['train'])\n        self.epoch_bar = tqdm(total=n_batch, position=1, dynamic_ncols=True)\n        self.epoch_bar.update(self.trainer._batch_count)\n        self.epoch_bar.set_description(\"Training epoch %d\" % self.trainer._epoch_count)\n\n    def print(self, message, **_):\n        if self.outer_bar is not None:\n            self.outer_bar.clear()\n        tqdm.write(\"[+][{}] {}\".format(str(datetime.now()), message))\n        if self.outer_bar is not None:\n            self.outer_bar.refresh()\n\n    def begin_of_fit(self, max_num_epochs, **_):\n        if isinstance(max_num_epochs, int):\n            self.outer_bar = tqdm(total=max_num_epochs, position=0, dynamic_ncols=True)\n        else:\n            self.outer_bar = tqdm(total=1000, position=0, 
dynamic_ncols=True)\n        self.outer_bar.set_description(\"Epochs\")\n\n    def end_of_fit(self, **_):\n        if self.outer_bar is not None:\n            self.outer_bar.close()\n            self.outer_bar = None\n\n    def begin_of_epoch(self, **_):\n        if self.epoch_bar is not None:\n            self.epoch_bar.close()\n\n    def end_of_epoch(self, **_):\n        if self.outer_bar is not None:\n            self.outer_bar.update(1)\n\n    def begin_of_training_iteration(self, **_):\n        if not self.epoch_bar and 'train' in self.trainer._loader_iters.keys():\n            self._init_epoch_bar_train()\n            return\n\n        if self.epoch_bar:\n            self.epoch_bar.update(1)\n\n    def begin_of_validation_iteration(self, **_):\n        if self.epoch_bar:\n            self.epoch_bar.update(1)\n\n    def begin_of_training_run(self, **_):\n        self.is_training = True\n\n    def end_of_training_run(self, **_):\n        self.is_training = False\n        if self.epoch_bar:\n            self.epoch_bar.close()\n            self.epoch_bar = None\n\n    def begin_of_validation_run(self, num_iterations, num_iterations_in_generator, last_validated_at_epoch, **_):\n        self.is_validation = True\n        nmax = num_iterations\n        if not nmax:\n            nmax = num_iterations_in_generator\n\n        self.epoch_bar = tqdm(total=nmax, position=1, dynamic_ncols=True)\n        self.epoch_bar.set_description(\"Validating epoch %d\" % (last_validated_at_epoch-1))\n\n    def end_of_validation_run(self, **_):\n        self.is_validation = False\n        if self.epoch_bar:\n            self.epoch_bar.close()\n            self.epoch_bar = None\n"
  },
  {
    "path": "inferno/trainers/callbacks/tqdmstub.py",
    "content": "from .base import Callback\n\nclass TQDMProgressBar(Callback):\n    def __init__(self, *args, **kwargs):\n        super(TQDMProgressBar, self).__init__(*args, **kwargs)\n\n    def bind_trainer(self, *args, **kwargs):\n        super(TQDMProgressBar, self).bind_trainer(*args, **kwargs)\n        self.trainer.console.warning(\"tqdm is not installed. will fall back to normal stdout console.\")\n\n    def begin_of_fit(self, **_):\n        pass\n"
  },
  {
    "path": "inferno/utils/__init__.py",
    "content": ""
  },
  {
    "path": "inferno/utils/exceptions.py",
    "content": "\"\"\"Exceptions and Error Handling\"\"\"\n\n\ndef assert_(condition, message='', exception_type=AssertionError):\n    \"\"\"Like assert, but with arbitrary exception types.\"\"\"\n    if not condition:\n        raise exception_type(message)\n\n\n# ------ VALUE ERRORS ------\n\n\nclass ShapeError(ValueError):\n    pass\n\n\nclass FrequencyValueError(ValueError):\n    pass\n\n\nclass DeviceError(ValueError):\n    pass\n\n\nclass NotSetError(ValueError):\n    pass\n\n\n# ------ TYPE ERRORS ------\n\n\nclass NotTorchModuleError(TypeError):\n    pass\n\n\nclass FrequencyTypeError(TypeError):\n    pass\n\n\nclass DTypeError(TypeError):\n    pass\n\n\n# ------ LOOKUP ERRORS ------\n\n\nclass ClassNotFoundError(LookupError):\n    pass\n\n\n# ------ NOT-IMPLEMENTED ERRORS ------\n\n\nclass NotUnwrappableError(NotImplementedError):\n    pass"
  },
  {
    "path": "inferno/utils/io_utils.py",
    "content": "import os\nimport h5py as h5\nimport numpy as np\nimport yaml\nfrom skimage.io import imsave\n\n\n# Function to load in a dataset from a h5file\ndef fromh5(path, datapath=None, dataslice=None, asnumpy=True, preptrain=None):\n    \"\"\"\n    Opens a hdf5 file at path, loads in the dataset at datapath, and returns dataset\n    as a numpy array.\n    \"\"\"\n    # Check if path exists (thanks Lukas!)\n    assert os.path.exists(path), \"Path {} does not exist.\".format(path)\n    with h5.File(path, 'r') as f:\n        # Init dataset\n        h5dataset = f[datapath] if datapath is not None else f.values()[0]\n        # Slice dataset\n        h5dataset = h5dataset[dataslice] if dataslice is not None else h5dataset\n        # Convert to numpy if required\n        h5dataset = np.asarray(h5dataset) if asnumpy else h5dataset\n        # Apply preptrain\n        h5dataset = preptrain(h5dataset) if preptrain is not None else h5dataset\n    return h5dataset\n\n\n# TODO we could also do **h5_kwargs instead\ndef toh5(data, path, datapath='data', compression=None, chunks=None):\n    \"\"\"Write `data` to a HDF5 volume.\"\"\"\n    with h5.File(path) as f:\n        f.create_dataset(datapath, data=data, compression=compression, chunks=chunks)\n\n\ndef fromz5(path, datapath, dataslice=None, n_threads=8):\n    # we import z5py only here because we don't want to assume that it's in the env\n    import z5py\n    assert os.path.exists(path), \"Path {} does not exist.\".format(path)\n    with z5py.File(path) as f:\n        ds = f[datapath]\n        ds.n_threads = n_threads\n        data = ds[:] if dataslice is None else ds[dataslice]\n    return data\n\n\n# Yaml to dict reader\ndef yaml2dict(path):\n    if isinstance(path, dict):\n        # Forgivable mistake that path is a dict already\n        return path\n    with open(path, 'r') as f:\n        readict = yaml.load(f, Loader=yaml.FullLoader)\n    return readict\n\n\ndef print_tensor(tensor, prefix, directory):\n    
\"\"\"Prints a image or volume tensor to file as images.\"\"\"\n    def _print_image(image, prefix, batch, channel, z=None):\n        if z is None:\n            file_name = \"{}--B-{}--CH-{}.png\".format(prefix, batch, channel)\n        else:\n            file_name = \"{}--B-{}--CH-{}--Z-{}.png\".format(prefix, batch, channel, z)\n        full_file_name = os.path.join(directory, file_name)\n        imsave(arr=image, fname=full_file_name)\n\n    for batch in range(tensor.shape[0]):\n        for channel in range(tensor.shape[1]):\n            if tensor.ndim == 4:\n                _print_image(tensor[batch, channel, ...], prefix, batch, channel)\n            else:\n                for plane in range(tensor.shape[2]):\n                    _print_image(tensor[batch, channel, plane, ...], prefix, batch, channel, plane)\n"
  },
  {
    "path": "inferno/utils/math_utils.py",
    "content": "\n\ndef max_allowed_ds_steps(shape, factor):\n    \"\"\"How often can a shape be down-sampled by a given factor\n        such that non of the divisions will give non-integers.\n\n    Args:\n        shape (listlike): tensor shape\n        factor (integer): downsample factor\n\n    Returns:\n        int: maximum allowed downsample operations\n    \"\"\"\n    def max_allowed_ds_steps_impl(size, factor):\n\n        current_size = float(size)\n        allowed_steps = 0\n        while(True):\n\n            new_size = current_size / float(factor)\n            if(new_size >=1 and new_size.is_integer()):\n\n                current_size = new_size\n                allowed_steps += 1\n            else:\n                break\n        return allowed_steps\n\n    min_steps = float('inf')\n\n    for s in shape:\n        min_steps = int(min(min_steps, max_allowed_ds_steps_impl(s, factor)))\n\n    return min_steps\n"
  },
  {
    "path": "inferno/utils/model_utils.py",
    "content": "import torch\nfrom .exceptions import assert_, NotTorchModuleError, ShapeError\n\n\ndef is_model_cuda(model):\n    try:\n        return next(model.parameters()).is_cuda\n    except StopIteration:\n        # Assuming that if a network has no parameters, it doesn't use CUDA\n        return False\n\n\nclass ModelTester(object):\n    def __init__(self, input_shape, expected_output_shape):\n        self._is_cuda = False\n        self.input_shape = input_shape\n        self.expected_output_shape = expected_output_shape\n\n    def cuda(self):\n        self._is_cuda = True\n        return self\n\n    def get_input(self):\n        with torch.no_grad():\n            if self._is_cuda:\n                return torch.rand(*self.input_shape, requires_grad=False).cuda()\n            else:\n                return torch.rand(*self.input_shape, requires_grad=False)\n\n    def __call__(self, model):\n        # Make sure model is a model\n        assert_(isinstance(model, torch.nn.Module),\n                \"Model is not a torch module.\",\n                NotTorchModuleError)\n        # Transfer to cuda if required\n        if not is_model_cuda(model) and self._is_cuda:\n            model.cuda()\n        input_ = self.get_input()\n        output = model(input_)\n        assert_(list(output.size()) == list(self.expected_output_shape),\n                \"Expected output shape {} for input shape {}, \"\n                \"got output of shape {} instead.\".format(list(self.expected_output_shape),\n                                                         list(self.input_shape),\n                                                         list(output.size())),\n                ShapeError)\n        return model\n\n\nclass MultiscaleModelTester(ModelTester):\n    def __call__(self, model):\n        # Make sure model is a model\n        assert_(isinstance(model, torch.nn.Module),\n                \"Model is not a torch module.\",\n                NotTorchModuleError)\n        # 
Transfer to cuda if required\n        if not is_model_cuda(model) and self._is_cuda:\n            model.cuda()\n        input_ = self.get_input()\n        output = model(input_)\n        assert_(isinstance(output, tuple), \"Expect tuple output\")\n        for scale in range(len(output)):\n            assert_(list(output[scale].size()) == list(self.expected_output_shape[scale]),\n                    \"Expected output shape {} for input shape {}, \"\n                    \"got output of shape {} instead.\".format(list(self.expected_output_shape[scale]),\n                                                             list(self.input_shape),\n                                                             list(output[scale].size())),\n                    ShapeError)\n        return model\n"
  },
  {
    "path": "inferno/utils/partial_cls.py",
    "content": "import functools\nimport sys\nimport types\nimport inspect\n\n\n__all__ =  [\n    'partial_cls',\n    'register_partial_cls'\n]\n\n\ndef partial_cls(base_cls, name, module, fix=None, default=None):\n\n    # helper function\n    def insert_if_not_present(dict_a, dict_b):\n        for kw,val in dict_b.items():\n            if kw not in dict_a:\n                dict_a[kw] = val\n        return dict_a\n\n    # helper function\n    def insert_call_if_present(dict_a, dict_b, callback):\n        for kw,val in dict_b.items():\n            if kw not in dict_a:\n                dict_a[kw] = val\n            else:\n                callback(kw)\n        return dict_a\n\n    # helper class\n    class PartialCls(object):\n        def __init__(self, base_cls, name, module, fix=None, default=None):\n\n            self.base_cls = base_cls\n            self.name = name\n            self.module = module\n            self.fix = [fix, {}][fix is None]\n            self.default = [default, {}][default is None]\n\n            if self.fix.keys() & self.default.keys():\n                raise TypeError('fix and default share keys')\n\n            # remove binded kw\n            self._allowed_kw = self._get_allowed_kw()\n\n        def _get_allowed_kw(self):\n\n            \n            argspec = inspect.getfullargspec(base_cls.__init__)\n            args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations = argspec\n\n            if varargs is not None:\n                raise TypeError('partial_cls can only be used if __init__ has no varargs')\n\n            if varkw is not None:\n                raise TypeError('partial_cls can only be used if __init__ has no varkw')\n\n            if kwonlyargs is not None and kwonlyargs != []:\n                raise TypeError('partial_cls can only be used without kwonlyargs')\n\n            if args is None or len(args) < 1:\n                raise TypeError('seems like self is missing')\n            \n            \n           
 return [kw for kw in args[1:] if kw  not in self.fix]   \n         \n\n        def _build_kw(self, args, kwargs):\n            # handle *args\n            if len(args) > len(self._allowed_kw):\n                raise TypeError(\"to many arguments\")\n\n            all_args =  {}\n            for arg, akw in zip(args, self._allowed_kw):\n                all_args[akw] = arg\n\n            # handle **kwargs\n            intersection = self.fix.keys() & kwargs.keys()\n            if len(intersection) >= 1:\n                kw = intersection.pop()\n                raise TypeError(\"`{}.__init__` got unexpected keyword argument '{}'\".format(name, kw))\n\n            def raise_cb(kw):\n                raise TypeError(\"{}.__init__ got multiple values for argument '{}'\".format(name, kw))\n            all_args = insert_call_if_present(all_args, kwargs, raise_cb)\n\n            # handle fixed arguments\n            def raise_cb(kw):\n                raise TypeError()\n            all_args = insert_call_if_present(all_args, self.fix, raise_cb)\n\n            # handle defaults\n            all_args = insert_if_not_present(all_args, self.default)\n\n            # handle fixed \n            all_args.update(self.fix)\n\n            return all_args\n\n        def build_cls(self):\n\n            def new_init(self_of_new_cls, *args, **kwargs):\n                combined_args = self._build_kw(args=args, kwargs=kwargs)\n\n                #call base cls init\n                super(self_of_new_cls.__class__, self_of_new_cls).__init__(**combined_args)\n\n            return type(name, (self.base_cls,), {\n                '__module__': self.module,\n                '__init__' : new_init\n            })\n            return cls\n\n\n    return PartialCls(base_cls=base_cls, name=name, module=module,\n        fix=fix, default=default).build_cls()\n\n\ndef register_partial_cls(base_cls, name, module, fix=None, default=None):\n    module_dict = sys.modules[module].__dict__\n    generatedClass = 
partial_cls(base_cls=base_cls,name=name, module=module,\n        fix=fix, default=default)\n    module_dict[generatedClass.__name__] = generatedClass\n    del generatedClass\n\n\nif __name__ == \"__main__\":\n\n    class Conv(object):\n        def __init__(self, dim, activation, stride=1):\n            print(f\"dim {dim} act {activation} stride {stride}\")\n\n\n    Conv2D = partial_cls(Conv,'Conv2D',__name__, fix=dict(dim=2), default=dict(stride=2))\n\n\n    #obj =  Conv2D(activation='a')\n    #obj =  Conv2D('a',activation='a', stride=3)\n    obj =  Conv2D('fu','bar')    \n\n"
  },
  {
    "path": "inferno/utils/python_utils.py",
    "content": "\"\"\"Utility functions with no external dependencies.\"\"\"\nimport signal\nimport warnings\nimport functools\nimport inspect\nimport os\n\nfrom threading import current_thread, main_thread\n\n\ndef ensure_dir(directory):\n    \"\"\"ensure the existence of e directory at a given path\n\n        If the directory does not exist it is created\n\n    Args:\n        directory (str): path of the directory\n\n    Returns:\n        str: path of the directory\n    \"\"\"\n    if not os.path.exists(directory):\n        os.makedirs(directory)\n    return directory\n\n\ndef require_dict_kwargs(kwargs, msg=None):\n    \"\"\" Ensure arguments passed kwargs are either None or a dict.\n        If arguments are neither a dict nor None a RuntimeError\n        is thrown\n    Args:\n        kwargs (object): possible dict or None\n        msg (None, optional): Error msg\n\n    Returns:\n        dict: kwargs dict\n\n    Raises:\n        RuntimeError: if the passed value is neither a dict nor None\n            this error is raised\n    \"\"\"\n    if kwargs is None:\n        return dict()\n    elif isinstance(kwargs, dict):\n        return kwargs\n    else:\n        if msg is None:\n            raise RuntimeError(\"value passed as keyword argument dict is neither None nor a dict\")\n        else:\n            raise RuntimeError(\"%s\"%str(msg))\n\n\ndef is_listlike(x):\n    return isinstance(x, (list, tuple))\n\n\ndef to_iterable(x):\n    return [x] if not is_listlike(x) else x\n\n\ndef from_iterable(x):\n    return x[0] if (is_listlike(x) and len(x) == 1) else x\n\n\ndef robust_len(x):\n    return len(x) if is_listlike(x) else 1\n\n\ndef as_tuple_of_len(x, len_):\n    if is_listlike(x):\n        assert len(x) == len_, \\\n            \"Listlike object of len {} can't be returned \" \\\n            \"as a tuple of length {}.\".format(len(x), len_)\n        return tuple(x)\n    else:\n        return (x,) * len_\n\n\ndef has_callable_attr(object_, name):\n    return 
hasattr(object_, name) and callable(getattr(object_, name))\n\n\ndef is_maybe_list_of(check_function):\n    def decorated_function(object_, **kwargs):\n        if isinstance(object_, (list, tuple)):\n            return all([check_function(_object, **kwargs) for _object in object_])\n        else:\n            return check_function(object_, **kwargs)\n    return decorated_function\n\n\nclass delayed_keyboard_interrupt(object):\n    \"\"\"\n    Delays SIGINT over critical code.\n    Borrowed from:\n    https://stackoverflow.com/questions/842557/\n    how-to-prevent-a-block-of-code-from-being-interrupted-by-keyboardinterrupt-in-py\n    \"\"\"\n    # PEP8: Context manager class in lowercase\n    def __enter__(self):\n        if current_thread() is main_thread():\n            self.signal_received = False\n            self.old_handler = signal.getsignal(signal.SIGINT)\n            signal.signal(signal.SIGINT, self.handler)\n\n    def handler(self, sig, frame):\n        self.signal_received = (sig, frame)\n\n    def __exit__(self, type, value, traceback):\n        if current_thread() is main_thread():\n            signal.signal(signal.SIGINT, self.old_handler)\n            if self.signal_received:\n                self.old_handler(*self.signal_received)\n\n\ndef get_config_for_name(config, name):\n    config_for_name = {}\n    for key, val in config.items():\n        if isinstance(val, dict) and name in val:\n            # we leave the slicing_config validation to classes higher up in MRO\n            config_for_name.update({key: val.get(name)})\n        else:\n            config_for_name.update({key: val})\n    return config_for_name\n\nstring_types = (type(b''), type(u''))\n\n\ndef deprecated(reason):\n    \"\"\"\n    This is a decorator which can be used to mark functions\n    as deprecated. 
It will result in a warning being emitted\n    when the function is used.\n\n    Borrowed from\n    https://stackoverflow.com/questions/2536307/\n    decorators-in-the-python-standard-lib-deprecated-specifically\n    by Laurent LAPORTE\n    https://stackoverflow.com/users/1513933/laurent-laporte\n\n    \"\"\"\n\n    if isinstance(reason, string_types):\n\n        # The @deprecated is used with a 'reason'.\n        #\n        # .. code-block:: python\n        #\n        #    @deprecated(\"please, use another function\")\n        #    def old_function(x, y):\n        #      pass\n\n        def decorator(func1):\n\n            if inspect.isclass(func1):\n                fmt1 = \"Call to deprecated class {name} ({reason}).\"\n            else:\n                fmt1 = \"Call to deprecated function {name} ({reason}).\"\n\n            @functools.wraps(func1)\n            def new_func1(*args, **kwargs):\n                warnings.simplefilter('always', DeprecationWarning)\n                warnings.warn(\n                    fmt1.format(name=func1.__name__, reason=reason),\n                    category=DeprecationWarning,\n                    stacklevel=2\n                )\n                warnings.simplefilter('default', DeprecationWarning)\n                return func1(*args, **kwargs)\n\n            return new_func1\n\n        return decorator\n\n    elif inspect.isclass(reason) or inspect.isfunction(reason):\n\n        # The @deprecated is used without any 'reason'.\n        #\n        # .. 
code-block:: python\n        #\n        #    @deprecated\n        #    def old_function(x, y):\n        #      pass\n\n        func2 = reason\n\n        if inspect.isclass(func2):\n            fmt2 = \"Call to deprecated class {name}.\"\n        else:\n            fmt2 = \"Call to deprecated function {name}.\"\n\n        @functools.wraps(func2)\n        def new_func2(*args, **kwargs):\n            warnings.simplefilter('always', DeprecationWarning)\n            warnings.warn(\n                fmt2.format(name=func2.__name__),\n                category=DeprecationWarning,\n                stacklevel=2\n            )\n            warnings.simplefilter('default', DeprecationWarning)\n            return func2(*args, **kwargs)\n\n        return new_func2\n\n    else:\n        raise TypeError(repr(type(reason)))\n"
  },
  {
    "path": "inferno/utils/test_utils.py",
    "content": "import torch\nfrom torch.utils.data.dataset import TensorDataset\nfrom torch.utils.data.dataloader import DataLoader\nimport numpy as np\n\n\ndef generate_random_data(num_samples, shape, num_classes, hardness=0.3, dtype=None):\n    \"\"\"Generate a random dataset with a given hardness and number of classes.\"\"\"\n    dataset_input = np.zeros((num_samples,) + shape, dtype=dtype)\n    dataset_target = np.random.randint(num_classes, size=num_samples)\n    for sample_num in range(num_samples):\n        dataset_input[sample_num] = np.random.normal(loc=dataset_target[sample_num],\n                                                     scale=(1 - hardness),\n                                                     size=shape)\n    return dataset_input, dataset_target\n\n\ndef generate_random_dataset(num_samples, shape, num_classes, hardness=0.3, dtype=None):\n    \"\"\"Generate a random dataset with a given hardness and number of classes.\"\"\"\n    # Generate numpy arrays\n    dataset_input, dataset_target = generate_random_data(num_samples, shape, num_classes,\n                                                         hardness=hardness, dtype=dtype)\n    # Convert to tensor and build dataset\n    dataset = TensorDataset(torch.from_numpy(dataset_input),\n                            torch.from_numpy(dataset_target))\n    return dataset\n\n\ndef generate_random_dataloader(num_samples, shape, num_classes, hardness=0.3, dtype=None,\n                               batch_size=1, shuffle=False, num_workers=0, pin_memory=False):\n    \"\"\"Generate a loader with a random dataset of given hardness and number of classes.\"\"\"\n    dataset = generate_random_dataset(num_samples, shape, num_classes, hardness=hardness,\n                                      dtype=dtype)\n    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,\n                            num_workers=num_workers, pin_memory=pin_memory)\n    return dataloader\n"
  },
  {
    "path": "inferno/utils/torch_utils.py",
    "content": "import numpy as np\nimport torch\n\nfrom .python_utils import delayed_keyboard_interrupt\nfrom .exceptions import assert_, ShapeError, NotUnwrappableError\n\n\ndef unwrap(input_, to_cpu=True, as_numpy=False, extract_item=False):\n    if isinstance(input_, (list, tuple)):\n        return type(input_)([unwrap(_t, to_cpu=to_cpu, as_numpy=as_numpy)\n                             for _t in input_])\n    elif torch.is_tensor(input_):\n        tensor = input_\n    elif isinstance(input_, np.ndarray):\n        return input_\n    elif isinstance(input_, (float, int)):\n        return input_\n    else:\n        raise NotUnwrappableError(\"Cannot unwrap a '{}'.\"\n                                  .format(type(input_).__name__))\n    # Transfer to CPU if required\n    if to_cpu:\n        with delayed_keyboard_interrupt():\n            tensor = tensor.cpu().detach()\n    # Convert to numpy if required\n    if as_numpy:\n        return tensor.cpu().detach().numpy()\n    elif extract_item:\n        try:\n            return tensor.item()\n        except AttributeError:\n            return tensor[0]\n    else:\n        return tensor\n\n\ndef is_tensor(object_):\n    missed_tensor_classes = (torch.HalfTensor,)\n    return torch.is_tensor(object_) or isinstance(object_, missed_tensor_classes)\n\n\ndef is_label_tensor(object_):\n    return is_tensor(object_) and object_.type() in ['torch.LongTensor', 'torch.cuda.LongTensor']\n\n\ndef is_image_tensor(object_):\n    return is_tensor(object_) and object_.dim() == 4\n\n\ndef is_volume_tensor(object_):\n    return is_tensor(object_) and object_.dim() == 5\n\n\ndef is_image_or_volume_tensor(object_):\n    return is_image_tensor(object_) or is_volume_tensor(object_)\n\n\ndef is_label_image_tensor(object_):\n    return is_label_tensor(object_) and object_.dim() == 3\n\n\ndef is_label_volume_tensor(object_):\n    return is_label_tensor(object_) and object_.dim() == 4\n\n\ndef is_label_image_or_volume_tensor(object_):\n    
return is_label_image_tensor(object_) or is_label_volume_tensor(object_)\n\n\ndef is_matrix_tensor(object_):\n    return is_tensor(object_) and object_.dim() == 2\n\n\ndef is_scalar_tensor(object_):\n    return is_tensor(object_) and object_.dim() <= 1 and object_.numel() == 1\n\n\ndef is_vector_tensor(object_):\n    return is_tensor(object_) and object_.dim() == 1 and object_.numel() > 1\n\n\ndef assert_same_size(tensor_1, tensor_2):\n    assert_(list(tensor_1.size()) == list(tensor_2.size()),\n            \"Tensor sizes {} and {} do not match.\".format(tensor_1.size(), tensor_2.size()),\n            ShapeError)\n\n\ndef where(condition, if_true, if_false):\n    \"\"\"\n    Torch equivalent of numpy.where.\n\n    Parameters\n    ----------\n    condition : torch.ByteTensor or torch.cuda.ByteTensor\n        Condition to check.\n    if_true : torch.Tensor or torch.cuda.Tensor\n        Output value if condition is true.\n    if_false: torch.Tensor or torch.cuda.Tensor\n        Output value if condition is false\n\n    Returns\n    -------\n    torch.Tensor\n\n    Raises\n    ------\n    AssertionError\n        if if_true and if_false don't have the same datatype.\n    \"\"\"\n    # noinspection PyArgumentList\n    assert if_true.type() == if_false.type(), \\\n        \"Type mismatch: {} and {}\".format(if_true.data.type(), if_false.data.type())\n    casted_condition = condition.type_as(if_true)\n    output = casted_condition * if_true + (1 - casted_condition) * if_false\n    return output\n\n\ndef flatten_samples(input_):\n    \"\"\"\n    Flattens a tensor or a variable such that the channel axis is first and the sample axis\n    is second. The shapes are transformed as follows:\n        (N, C, H, W) --> (C, N * H * W)\n        (N, C, D, H, W) --> (C, N * D * H * W)\n        (N, C) --> (C, N)\n    The input must be atleast 2d.\n    \"\"\"\n    assert_(input_.dim() >= 2,\n            \"Tensor or variable must be atleast 2D. 
Got one of dim {}.\"\n            .format(input_.dim()),\n            ShapeError)\n    # Get number of channels\n    num_channels = input_.size(1)\n    # Permute the channel axis to first\n    permute_axes = list(range(input_.dim()))\n    permute_axes[0], permute_axes[1] = permute_axes[1], permute_axes[0]\n    # For input shape (say) NCHW, this should have the shape CNHW\n    permuted = input_.permute(*permute_axes).contiguous()\n    # Now flatten out all but the first axis and return\n    flattened = permuted.view(num_channels, -1)\n    return flattened\n\n\ndef clip_gradients_(parameters, mode, norm_or_value):\n    assert_(mode in ['norm', 'value'],\n            f\"Mode must be 'norm' or 'value', got '{mode}' instead.\",\n            ValueError)\n    if mode == 'norm':\n        torch.nn.utils.clip_grad_norm_(parameters, norm_or_value)\n    elif mode == 'value':\n        torch.nn.utils.clip_grad_value_(parameters, norm_or_value)\n    else:\n        raise NotImplementedError\n"
  },
  {
    "path": "inferno/utils/train_utils.py",
    "content": "\"\"\"Utilities for training.\"\"\"\nimport numpy as np\nfrom .exceptions import assert_, FrequencyTypeError, FrequencyValueError\n\n\nclass AverageMeter(object):\n    \"\"\"\n    Computes and stores the average and current value.\n    Taken from https://github.com/pytorch/examples/blob/master/imagenet/main.py\n    \"\"\"\n    def __init__(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n\n\nclass MovingAverage(object):\n    \"\"\"Computes the moving average of a given float.\"\"\"\n    def __init__(self, momentum=0):\n        self.momentum = momentum\n        self.val = None\n        self.previous = None\n\n    def reset(self):\n        self.val = None\n\n    def update(self, val):\n        self.previous = self.val\n        if self.val is None:\n            self.val = val\n        else:\n            self.val = self.momentum * self.val + (1 - self.momentum) * val\n        return self.val\n\n    @property\n    def relative_change(self):\n        if None not in [self.val, self.previous]:\n            relative_change = (self.previous - self.val) / self.previous\n            return relative_change\n        else:\n            return None\n\n\nclass CLUI(object):\n    \"\"\"Command Line User Interface\"\"\"\n\n    def __call__(self, f):\n        def decorated(cls, *args, **kwargs):\n            try:\n                f(cls, *args, **kwargs)\n            except KeyboardInterrupt:\n                options_ = input(\"[!] Interrupted. 
Please select:\\n\"\n                                 \"[w] Save\\n\"\n                                 \"[d] Debug with PDB\\n\"\n                                 \"[q] Quit\\n\"\n                                 \"[c] Continue\\n\"\n                                 \"[?] >>> \")\n                save_now = 'w' in options_\n                quit_now = 'q' in options_\n                debug_now = 'd' in options_\n                continue_now = 'c' in options_ or not quit_now\n\n                if save_now:\n                    cls.save()\n\n                if debug_now:\n                    print(\"[*] Firing up PDB. The trainer instance might be accessible as 'cls'.\")\n                    import pdb\n                    pdb.set_trace()\n\n                if quit_now:\n                    cls.print(\"Exiting.\")\n                    raise SystemExit\n\n                if continue_now:\n                    return\n\n        return decorated\n\n\nclass Frequency(object):\n\n    def __init__(self, value=None, units=None):\n        # Private\n        self._last_match_value = None\n        self._value = None\n        self._units = None\n        # Public\n        self.value = value\n        self.units = units\n\n    @property\n    def value(self):\n        return self._value\n\n    @value.setter\n    def value(self, value):\n        # If value is not being set, make sure the frequency never matches muhahaha\n        if value is None or value == 'never':\n            value = np.inf\n        self.assert_value_consistent(value)\n        self._value = value\n\n    UNIT_PRIORITY = 'iterations'\n    VALID_UNIT_NAME_MAPPING = {'iterations': 'iterations',\n                               'iteration': 'iterations',\n                               'epochs': 'epochs',\n                               'epoch': 'epochs'}\n\n    @property\n    def units(self):\n        return self._units\n\n    @units.setter\n    def units(self, value):\n        if value is None:\n            value = 
self.UNIT_PRIORITY\n        self.assert_units_consistent(value)\n        self._units = self.VALID_UNIT_NAME_MAPPING.get(value)\n\n    def assert_value_consistent(self, value=None):\n        value = value or self.value\n        # Make sure that value is an integer or inf\n        assert_(isinstance(value, (int, float)),\n                \"Value must be an integer or np.inf, got {} instead.\"\n                .format(type(value).__name__),\n                FrequencyTypeError)\n        if isinstance(value, float):\n            assert_(value == np.inf,\n                    \"Provided value must be numpy.inf if a float, got {}.\".format(value),\n                    FrequencyValueError)\n\n    def assert_units_consistent(self, units=None):\n        units = units or self.units\n        # Map\n        units = self.VALID_UNIT_NAME_MAPPING.get(units)\n        assert_(units is not None, \"Unit '{}' not understood.\".format(units),\n                FrequencyValueError)\n\n    @property\n    def is_consistent(self):\n        try:\n            self.assert_value_consistent()\n            self.assert_units_consistent()\n            return True\n        except (FrequencyValueError, FrequencyTypeError):\n            return False\n\n    def epoch(self):\n        self.units = 'epochs'\n        return self\n\n    def iteration(self):\n        self.units = 'iterations'\n        return self\n\n    @property\n    def by_epoch(self):\n        return self.units == 'epochs'\n\n    @property\n    def by_iteration(self):\n        return self.units == 'iterations'\n\n    def every(self, value):\n        self.value = value\n        return self\n\n    def match(self, iteration_count=None, epoch_count=None, persistent=False, match_zero=True):\n        match_value = {'iterations': iteration_count, 'epochs': epoch_count}.get(self.units)\n        if not match_zero and match_value == 0:\n            match = False\n        else:\n            match = match_value is not None and \\\n                    
self.value != np.inf and \\\n                    match_value % self.value == 0\n        if persistent and match and self._last_match_value == match_value:\n            # Last matched value is the current matched value, i.e. we've matched once already,\n            # and don't need to match again\n            match = False\n        if match:\n            # Record current match value as the last known match value to maintain persistency\n            self._last_match_value = match_value\n        return match\n\n    def __str__(self):\n        return \"{} {}\".format(self.value, self.units)\n\n    def __repr__(self):\n        return \"{}(value={}, units={})\".format(type(self).__name__, self.value, self.units)\n\n    @classmethod\n    def from_string(cls, string):\n        assert_(isinstance(string, str), \"`string` must be a string, got {} instead.\"\n                .format(type(string).__name__), TypeError)\n        if string == 'never':\n            return cls(np.inf, 'iterations')\n        else:\n            value_and_unit = string.split(' ')\n            assert_(len(value_and_unit) == 2,\n                    \"Was expecting a string 'value units' with one white-space \"\n                    \"between 'value' and 'units'.\", ValueError)\n            value, unit = value_and_unit\n            value = np.inf if value == 'inf' else int(value)\n            return cls(value, unit)\n\n    @classmethod\n    def build_from(cls, args, priority='iterations'):\n        if isinstance(args, int):\n            return cls(args, priority)\n        elif isinstance(args, (tuple, list)):\n            return cls(*args)\n        elif isinstance(args, Frequency):\n            return args\n        elif isinstance(args, str):\n            return cls.from_string(args)\n        else:\n            raise NotImplementedError\n\n\nclass Duration(Frequency):\n    \"\"\"Like frequency, but measures a duration.\"\"\"\n    def match(self, iteration_count=None, epoch_count=None, 
when_equal_return=False, **_):\n        match_value = {'iterations': iteration_count, 'epochs': epoch_count}.get(self.units)\n        assert_(match_value is not None,\n                \"Could not match duration because {} is not known.\".format(self.units),\n                ValueError)\n        if match_value == self.value:\n            return when_equal_return\n        return match_value > self.value\n\n    def compare(self, iteration_count=None, epoch_count=None):\n        compare_value = {'iterations': iteration_count, 'epochs': epoch_count}.get(self.units)\n        assert_(compare_value is not None,\n                \"Could not match duration because {} is not known.\".format(self.units),\n                ValueError)\n        compared = {'iterations': None, 'epochs': None}\n        compared.update({self.units: self.value - compare_value})\n        return compared\n\n    def __sub__(self, other):\n        assert_(isinstance(other, Duration),\n                \"Object of type {} cannot be subtracted from \"\n                \"a Duration object.\".format(type(other)),\n                TypeError)\n        assert_(other.units == self.units,\n                \"The Duration objects being subtracted must have the same units.\",\n                ValueError)\n        return Duration(value=(self.value - other.value), units=self.units)\n\n\nclass NoLogger(object):\n    def __init__(self, logdir=None):\n        self.logdir = logdir\n\n    def log_value(self, *kwargs):\n        pass\n\n\ndef set_state(module, key, value):\n    \"\"\"Writes `key`-`value` pair to `module`'s state hook.\"\"\"\n    if hasattr(module, '_state_hooks'):\n        state_hooks = getattr(module, '_state_hooks')\n        assert isinstance(state_hooks, dict), \\\n            \"State hook (i.e. 
module._state_hooks) is not a dictionary.\"\n        state_hooks.update({key: value})\n    else:\n        setattr(module, '_state_hooks', {key: value})\n    return module\n\n\ndef get_state(module, key, default=None):\n    \"\"\"Gets key from `module`'s state hooks.\"\"\"\n    return getattr(module, '_state_hooks', {}).get(key, default)\n"
  },
  {
    "path": "inferno/version.py",
    "content": "__version__ = '0.4.0'\n"
  },
  {
    "path": "readthedocs.yml",
    "content": "conda:\n    file: docs/environment.yml\npython:\n  version: 3.5\n  pip_install: false"
  },
  {
    "path": "requirements.txt",
    "content": "dill\npyyaml\nscipy>=0.13.0\nh5py\nnumpy>=1.8\nscikit-image"
  },
  {
    "path": "requirements_dev.txt",
    "content": "pip==8.1.2\nbumpversion==0.5.3\nwheel==0.29.0\nwatchdog==0.8.3\nflake8==2.6.0\ntox==2.3.1\ncoverage==4.1\nSphinx==1.4.8\ncryptography==1.7\nPyYAML==5.1\ndill\npyyaml\nscipy>=0.13.0\nh5py\nscikit-image\nsphinx-gallery\nsphinxcontrib-napoleon\nsphinxcontrib-inlinesyntaxhighlight\nsphinx_rtd_theme"
  },
  {
    "path": "setup.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n\n\"\"\"The setup script.\"\"\"\n\nfrom setuptools import setup, find_packages\nimport runpy\n__version__ = runpy.run_path('inferno/version.py')['__version__']\n\n\nwith open('README.rst') as readme_file:\n    readme = readme_file.read()\n\nwith open('HISTORY.rst') as history_file:\n    history = history_file.read()\n\nrequirements = [\n    # TODO: put package requirements here\n    \"pip>=8.1.2\",\n    \"torch>=0.1.12\",\n    \"dill\",\n    \"pyyaml\",\n    \"scipy>=0.13.0\",\n    \"h5py\",\n    \"numpy>=1.8\",\n    \"scikit-image\",\n    \"torchvision\",\n    \"tqdm\"\n]\n\n\nsetup_requirements = [\n    'pytest-runner'\n]\n\ntest_requirements = [\n    'pytest', 'unittest'\n]\n\ndependency_links = [\n    'http://download.pytorch.org/whl/cu75/torch-0.2.0.post1-cp35-cp35m-manylinux1_x86_64.whl#egg=torch-0.2.0'\n]\n\nsetup(\n    name='inferno-pytorch',\n    version=__version__,\n    description=\"Inferno is a little library providing utilities and convenience functions/classes around PyTorch.\",\n    long_description=readme + '\\n\\n' + history,\n    author=\"Nasim Rahaman\",\n    author_email='nasim.rahaman@iwr.uni-heidelberg.de',\n    url='https://github.com/inferno-pytorch/inferno',\n    packages=find_packages(where='.', exclude=[\"*.tests\", \"*.tests.*\",\n                                               \"tests.*\", \"tests\",\n                                               \"__pycache__\", \"*.pyc\"]),\n    dependency_links=dependency_links,\n    include_package_data=True,\n    install_requires=requirements,\n    license=\"Apache Software License 2.0\",\n    zip_safe=False,\n    keywords='inferno pytorch torch deep learning cnn deep-pyromania',\n    classifiers=[\n        # How mature is this project? 
Common values are\\\n        #   2 - Pre-Alpha',\n        #   3 - Alpha,\n        #   4 - Beta,\n        #   5 - Production/Stable\n        'Development Status :: 2 - Pre-Alpha',\n        # Indicate who your project is intended for\n        'Intended Audience :: Science/Research',\n        'License :: OSI Approved :: Apache Software License',\n        'Natural Language :: English',\n        'Programming Language :: Python :: 3.5',\n        'Programming Language :: Python :: 3.6'\n    ],\n    test_suite='test',\n    tests_require=test_requirements,\n    setup_requires=setup_requirements,\n)\n"
  },
  {
    "path": "tests/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n\n\"\"\"Unit test package for inferno.\"\"\"\n"
  },
  {
    "path": "tests/test_extensions/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n\n\"\"\"Unit test package for inferno.\"\"\"\n"
  },
  {
    "path": "tests/test_extensions/test_containers/test_graph.py",
    "content": "import unittest\nfrom functools import reduce\nimport torch\n\n\nclass TestGraph(unittest.TestCase):\n    def setUp(self):\n        import torch.nn as nn\n        from inferno.utils.python_utils import from_iterable\n\n        class DummyNamedModule(nn.Module):\n            def __init__(self, name, history, num_inputs=1):\n                super(DummyNamedModule, self).__init__()\n                self.name = name\n                self.history = history\n                self.num_inputs = num_inputs\n\n            def forward(self, *inputs):\n                assert len(inputs) == self.num_inputs\n                self.history.append(self.name)\n                if self.num_inputs > 1:\n                    output = reduce(lambda x, y: x + y, inputs)\n                else:\n                    output = from_iterable(inputs)\n\n                return output\n\n        self.DummyNamedModule = DummyNamedModule\n\n    # @unittest.skip\n    def test_graph_dummy_basic(self):\n        import torch\n        from inferno.extensions.containers.graph import Graph\n\n        if not hasattr(self, 'DummyNamedModule'):\n            self.setUp()\n\n        DummyNamedModule = self.DummyNamedModule\n\n        history = []\n        # Build graph\n        model = Graph()\n        model.add_input_node('input_0')\n        model.add_input_node('input_1')\n        model.add_node('conv0_0', DummyNamedModule('conv0_0', history))\n        model.add_node('conv0_1', DummyNamedModule('conv0_1', history))\n        model.add_node('conv1', DummyNamedModule('conv1', history, 2))\n        model.add_node('conv2', DummyNamedModule('conv2', history))\n        model.add_output_node('output_0')\n        model.add_edge('input_0', 'conv0_0')\\\n            .add_edge('input_1', 'conv0_1')\\\n            .add_edge('conv0_0', 'conv1')\\\n            .add_edge('conv0_1', 'conv1')\\\n            .add_edge('conv1', 'conv2')\\\n            .add_edge('conv2', 'output_0')\n\n        input_0 = 
torch.rand(10, 10)\n        input_1 = torch.rand(10, 10)\n        model(input_0, input_1)\n        self.assertTrue(history == ['conv0_0', 'conv0_1', 'conv1', 'conv2'] or\n                        history == ['conv0_1', 'conv0_0', 'conv1', 'conv2'])\n\n    # @unittest.skip\n    def test_graph_dummy_inception(self):\n        import torch\n        from inferno.extensions.containers.graph import Graph\n\n        if not hasattr(self, 'DummyNamedModule'):\n            self.setUp()\n\n        DummyNamedModule = self.DummyNamedModule\n\n        history = []\n        # Build graph\n        model = Graph()\n        model.add_input_node('input_0')\n        model.add_node('conv0', DummyNamedModule('conv0', history), 'input_0')\n        model.add_node('conv1_0', DummyNamedModule('conv1_0', history), 'conv0')\n        model.add_node('conv1_1', DummyNamedModule('conv1_1', history), 'conv0')\n        model.add_node('conv2', DummyNamedModule('conv2', history, 2),\n                       ['conv1_0', 'conv1_1'])\n        model.add_output_node('output_0', 'conv2')\n        input_0 = torch.rand(10, 10)\n        model(input_0)\n        self.assertTrue(history == ['conv0', 'conv1_0', 'conv1_1', 'conv2'] or\n                        history == ['conv0', 'conv1_1', 'conv1_0', 'conv2'])\n\n    # @unittest.skip\n    def test_graph_basic(self):\n        from inferno.extensions.containers.graph import Graph\n        from inferno.extensions.layers.convolutional import ConvELU2D\n        from inferno.utils.model_utils import ModelTester\n        # Build graph\n        model = Graph()\n        model.add_input_node('input_0')\n        model.add_node('conv0', ConvELU2D(1, 10, 3), previous='input_0')\n        model.add_node('conv1', ConvELU2D(10, 1, 3), previous='conv0')\n        model.add_output_node('output_0', previous='conv1')\n        ModelTester((1, 1, 100, 100), (1, 1, 100, 100))(model)\n\n    @unittest.skipUnless(torch.cuda.is_available(), \"No cuda.\")\n    def 
test_graph_device_transfers(self):\n        from inferno.extensions.containers.graph import Graph\n        from inferno.extensions.layers.convolutional import ConvELU2D\n        import torch\n        # Build graph\n        model = Graph()\n        model.add_input_node('input_0')\n        model.add_node('conv0', ConvELU2D(1, 10, 3), previous='input_0')\n        model.add_node('conv1', ConvELU2D(10, 1, 3), previous='conv0')\n        model.add_output_node('output_0', previous='conv1')\n        # Transfer\n        model.to_device('conv0', 'cpu').to_device('conv1', 'cuda', 0)\n        x = torch.rand(1, 1, 100, 100)\n        y = model(x)\n        self.assertIsInstance(y.data, torch.cuda.FloatTensor)\n\n    @unittest.skip(\"Needs machine with 4 GPUs\")\n    def test_multi_gpu(self):\n        import torch\n        import torch.nn as nn\n        from torch.nn.parallel.data_parallel import data_parallel\n        from inferno.extensions.containers.graph import Graph\n\n        input_shape = [8, 1, 3, 128, 128]\n        model = Graph() \\\n            .add_input_node('input') \\\n            .add_node('conv0', nn.Conv3d(1, 10, 3, padding=1), previous='input') \\\n            .add_node('conv1', nn.Conv3d(10, 1, 3, padding=1), previous='conv0') \\\n            .add_output_node('output', previous='conv1')\n\n        model.cuda()\n        input = torch.rand(*input_shape).cuda()\n        data_parallel(model, input, device_ids=[0, 1, 2, 3])\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/test_extensions/test_criteria/test_core.py",
    "content": "import unittest\nimport torch\nimport torch.nn as nn\n\n\nclass TestCore(unittest.TestCase):\n    def test_as_2d_criterion(self):\n        from inferno.extensions.criteria.core import As2DCriterion\n\n        prediction = torch.FloatTensor(2, 10, 100, 100).uniform_()\n        prediction = nn.Softmax2d()(prediction)\n        target = torch.LongTensor(2, 100, 100).fill_(0)\n        criterion = As2DCriterion(nn.CrossEntropyLoss())\n        criterion(prediction, target)\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/test_extensions/test_criteria/test_elementwise_measures.py",
    "content": "import unittest\nimport inferno.extensions.criteria.elementwise_measures as em\nimport torch\n\n\nclass TestElementwiseMeasures(unittest.TestCase):\n    def test_weighted_mse_loss(self):\n        input = torch.zeros(10, 10)\n        target = torch.ones(10, 10)\n        loss = em.WeightedMSELoss(positive_class_weight=2.)(input, target)\n        self.assertAlmostEqual(loss.item(), 2., delta=1e-5)\n        target = torch.zeros(10, 10)\n        input = torch.ones(10, 10)\n        loss = em.WeightedMSELoss(positive_class_weight=2.)(input, target)\n        self.assertAlmostEqual(loss.item(), 1., delta=1e-5)\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/test_extensions/test_criteria/test_set_similarity_measures.py",
    "content": "import unittest\nimport torch\n\n\nclass SetSimilarityTest(unittest.TestCase):\n    def get_dummy_variables(self):\n        x = torch.zeros(3, 2, 100, 100).uniform_()\n        y = torch.zeros(3, 2, 100, 100).uniform_()\n        return x, y\n\n    def get_dummy_variables_with_channels_and_classes(self):\n        # (batch_size, channels, classes, ...)\n        x = torch.zeros(3, 2, 5, 100, 100).uniform_()\n        y = torch.zeros(3, 2, 5, 100, 100).uniform_()\n        return x, y\n\n\nclass TestSorensenDice(SetSimilarityTest):\n    # noinspection PyCallingNonCallable\n    def test_channelwise(self):\n        from inferno.extensions.criteria.set_similarity_measures import SorensenDiceLoss\n        x, y = self.get_dummy_variables()\n        channelwise = SorensenDiceLoss(channelwise=True)\n        not_channelwise = SorensenDiceLoss(channelwise=False)\n        # Compute expected channelwise loss\n        expected_channelwise_loss = \\\n            not_channelwise(x[:, 0, ...], y[:, 0, ...]) + \\\n            not_channelwise(x[:, 1, ...], y[:, 1, ...])\n        # Compute channelwise\n        channelwise_loss = channelwise(x, y)\n        # Compare\n        self.assertAlmostEqual(expected_channelwise_loss.item(), channelwise_loss.item())\n\n\nclass TestGeneralizedSorensenDice(SetSimilarityTest):\n    def test_channelwise(self):\n        from inferno.extensions.criteria.set_similarity_measures import GeneralizedDiceLoss\n        x, y = self.get_dummy_variables_with_channels_and_classes()\n        channelwise = GeneralizedDiceLoss(channelwise=True)\n        not_channelwise = GeneralizedDiceLoss(channelwise=False)\n        # Compute channelwise loss and expected one:\n        channelwise_loss = channelwise(x, y)\n        expected_channelwise_loss = \\\n            not_channelwise(x[:, 0, ...], y[:, 0, ...]) + \\\n            not_channelwise(x[:, 1, ...], y[:, 1, ...])\n        # Compare\n        self.assertAlmostEqual(expected_channelwise_loss.item(), 
channelwise_loss.item())\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/test_extensions/test_layers/deprecated/building_blocks.py",
    "content": "import unittest\nimport torch\nimport inferno.extensions.layers.building_blocks as bb\n\n\nclass ResBlockTest(unittest.TestCase):\n\n    def test_2D_simple_(self):\n\n        x = torch.rand(1, 3, 64, 15)\n        model = bb.ResBlock(in_channels=3, out_channels=3, dim=2)\n        xx = model(x)\n        out_size = xx.size()\n        self.assertEqual(list(out_size), [1,3, 64, 15])\n\n    def test_3D_simple_(self):\n\n        x = torch.rand(1,3,20, 64,15)\n        model = bb.ResBlock(in_channels=3, out_channels=3, dim=3)\n        xx = model(x)\n        out_size = xx.size()\n        self.assertEqual(list(out_size), [1,3, 20, 64, 15])\n\n    def test_2D_simple_2(self):\n\n        x = torch.rand(1,3,64,64)\n        model = bb.ResBlock(in_channels=3, out_channels=6, dim=2)\n        xx = model(x)\n        out_size = xx.size()\n        self.assertEqual(list(out_size), [1,6, 64, 64])\n\n    def test_2D_simple_3(self):\n\n        x = torch.rand(1,3,64,64)\n        model = bb.ResBlock(in_channels=3, out_channels=6, dim=2, size=4)\n        xx = model(x)\n        out_size = xx.size()\n        self.assertEqual(list(out_size), [1,6, 64, 64])\n\n    def test_2D_simple_4(self):\n\n        x = torch.rand(1,6,64,64)\n        model = bb.ResBlock(in_channels=6, out_channels=6, dim=2, size=4,\n            force_skip_op=True)\n        xx = model(x)\n        out_size = xx.size()\n        self.assertEqual(list(out_size), [1,6, 64, 64])\n\n    def test_2D_simple_5(self):\n\n        x = torch.rand(1,6,64,64)\n        model = bb.ResBlock(in_channels=6, batchnorm=False, out_channels=6, dim=2, size=4,\n            force_skip_op=True)\n        xx = model(x)\n        out_size = xx.size()\n        self.assertEqual(list(out_size), [1,6, 64, 64])\n\n    def test_2D_simple_6(self):\n\n        x = torch.rand(1,6,64,64)\n        model = bb.ResBlock(in_channels=6, batchnorm=False, out_channels=6, dim=2, size=4,\n            force_skip_op=True, activated=False)\n        xx = model(x)\n      
  out_size = xx.size()\n        self.assertEqual(list(out_size), [1,6, 64, 64])\n\n    def test_3D_simple_6(self):\n\n        x = torch.rand(1,6,64,64, 20)\n        model = bb.ResBlock(in_channels=6, batchnorm=False, out_channels=6, dim=3, size=4,\n            force_skip_op=True, activated=False)\n        xx = model(x)\n        out_size = xx.size()\n        self.assertEqual(list(out_size), [1,6, 64, 64, 20])\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/test_extensions/test_layers/test_activations.py",
    "content": "import unittest\nimport torch\nimport inferno.extensions.layers.activations as activations\n\n\nclass ActivationTest(unittest.TestCase):\n    def test_selu(self):\n        x = torch.rand(100)\n        y = activations.SELU()(x)\n        self.assertEqual(list(x.size()), list(y.size()))\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/test_extensions/test_layers/test_convolutional.py",
    "content": "import unittest\nimport torch\nfrom inferno.utils.model_utils import ModelTester\n\n\nclass TestConvolutional(unittest.TestCase):\n    @unittest.skipIf(not torch.cuda.is_available(), \"GPU not available.\")\n    def test_bn_relu_depthwise_conv2d_pyinn(self):\n        from inferno.extensions.layers.convolutional import BNReLUDepthwiseConv2D\n        model = BNReLUDepthwiseConv2D(10, 'auto', 3)\n        ModelTester((1, 10, 100, 100),\n                    (1, 10, 100, 100)).cuda()(model)\n        self.assertTrue(model.depthwise)\n        self.assertEqual(model.conv.groups, 10)\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/test_extensions/test_layers/test_device.py",
    "content": "import unittest\nfrom inferno.extensions.layers.device import DeviceTransfer, OnDevice\nimport torch\n\n\nclass TransferTest(unittest.TestCase):\n    @unittest.skipIf(not torch.cuda.is_available(), \"GPU not available.\")\n    def test_device_transfer(self):\n        if not torch.cuda.is_available():\n            return\n        # Build transfer model\n        transfer = DeviceTransfer('cpu')\n        x = torch.rand(10, 10).cuda()\n        y = transfer(x)\n        loss = y.mean()\n        loss.backward()\n        self.assertFalse(y.data.is_cuda)\n        self.assertIsNotNone(x.grad)\n        self.assertTrue(x.grad.data.is_cuda)\n\n    @unittest.skipIf(not torch.cuda.is_available(), \"GPU not available.\")\n    def test_on_device(self):\n        if not torch.cuda.is_available():\n            return\n        # Build variable on the GPU\n        x = torch.rand(1, 10)\n        # Build model over multiple devices\n        multi_device_model = torch.nn.Sequential(OnDevice(torch.nn.Linear(10, 10), 'cuda'),\n                                                 OnDevice(torch.nn.Linear(10, 10), 'cpu'))\n        y = multi_device_model(x)\n        self.assertIsInstance(y.data, torch.FloatTensor)\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/test_extensions/test_layers/test_reshape.py",
    "content": "import unittest\nimport torch\n\n\nclass TestReshape(unittest.TestCase):\n    def _get_input_variable(self, *shape):\n        return torch.rand(*shape)\n\n    def test_as_matrix(self):\n        from inferno.extensions.layers.reshape import AsMatrix\n\n        input = self._get_input_variable(10, 20, 1, 1)\n        as_matrix = AsMatrix()\n        output = as_matrix(input)\n        self.assertEqual(list(output.size()), [10, 20])\n\n    def test_flatten(self):\n        from inferno.extensions.layers.reshape import Flatten\n\n        input = self._get_input_variable(10, 20, 2, 2)\n        flatten = Flatten()\n        output = flatten(input)\n        self.assertEqual(list(output.size()), [10, 80])\n\n    def test_as_2d(self):\n        from inferno.extensions.layers.reshape import As2D\n\n        as_2d = As2D()\n\n        output_shape = as_2d(self._get_input_variable(10, 20, 3, 30, 30)).size()\n        self.assertEqual(list(output_shape), [10, 60, 30, 30])\n\n        output_shape = as_2d(self._get_input_variable(10, 20, 30, 30)).size()\n        self.assertEqual(list(output_shape), [10, 20, 30, 30])\n\n        output_shape = as_2d(self._get_input_variable(10, 20)).size()\n        self.assertEqual(list(output_shape), [10, 20, 1, 1])\n\n    def test_as_3d(self):\n        from inferno.extensions.layers.reshape import As3D\n        from inferno.utils.exceptions import ShapeError\n\n        as_3d = As3D()\n\n        output_shape = as_3d(self._get_input_variable(10, 20, 3, 30, 30)).size()\n        self.assertEqual(list(output_shape), [10, 20, 3, 30, 30])\n\n        output_shape = as_3d(self._get_input_variable(10, 20, 30, 30)).size()\n        self.assertEqual(list(output_shape), [10, 20, 1, 30, 30])\n\n        output_shape = as_3d(self._get_input_variable(10, 20)).size()\n        self.assertEqual(list(output_shape), [10, 20, 1, 1, 1])\n\n        as_3d.channel_as_z = True\n        output_shape = as_3d(self._get_input_variable(10, 20, 30, 30)).size()\n        
self.assertEqual(list(output_shape), [10, 1, 20, 30, 30])\n\n        as_3d.num_channels_or_num_z_slices = 2\n        output_shape = as_3d(self._get_input_variable(10, 40, 30, 30)).size()\n        self.assertEqual(list(output_shape), [10, 2, 20, 30, 30])\n\n        with self.assertRaises(ShapeError):\n            output_shape = as_3d(self._get_input_variable(10, 41, 30, 30)).size()\n            self.assertEqual(list(output_shape), [10, 2, 20, 30, 30])\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/test_extensions/test_metrics/categorical.py",
    "content": "import unittest\nimport torch\nfrom inferno.extensions.metrics import IOU\n\n\nclass TestCategorical(unittest.TestCase):\n    def test_iou_basic(self):\n        # from one hot\n        predicted_image = torch.zeros(*(2, 10, 10))\n        predicted_image[:, 0:4, 0:4] = 1\n        target_image = torch.zeros(*(2, 10, 10))\n        target_image[:, 0:3, 0:3] = 1\n        expected_iou = (3 * 3)/(4 * 4)\n        iou = IOU()(predicted_image[None, ...], target_image[None, ...])\n        self.assertAlmostEqual(iou, expected_iou, places=4)\n\n    def test_iou_with_ignore_class(self):\n        predicted_image = torch.zeros(*(2, 10, 10))\n        predicted_image[0, 0:4, 0:4] = 1\n        target_image = torch.zeros(*(2, 10, 10))\n        target_image[:, 0:3, 0:3] = 1\n        expected_iou = (3 * 3) / (4 * 4)\n        iou = IOU(ignore_class=1)(predicted_image[None, ...], target_image[None, ...])\n        self.assertAlmostEqual(iou, expected_iou, places=4)\n\n    def test_multiclass_iou(self):\n        predicted_image = torch.zeros(*(2, 10, 10))\n        predicted_image[0, 0:4, 0:4] = 1\n        target_image = torch.zeros(*(2, 10, 10))\n        target_image[:, 0:3, 0:3] = 1\n        iou_class_0 = (3 * 3) / (4 * 4)\n        iou_class_1 = 0\n        expected_mean_iou = 0.5 * (iou_class_0 + iou_class_1)\n        iou = IOU()(predicted_image[None, ...], target_image[None, ...])\n        self.assertAlmostEqual(iou, expected_mean_iou, places=4)\n\n    def test_multiclass_iou_with_ignore_class(self):\n        predicted_image = torch.zeros(*(3, 10, 10))\n        predicted_image[0, 0:4, 0:4] = 1\n        # Have the third plane be crap\n        predicted_image[2, :, :] = 1\n        target_image = torch.zeros(*(3, 10, 10))\n        target_image[:, 0:3, 0:3] = 1\n        iou_class_0 = (3 * 3) / (4 * 4)\n        iou_class_1 = 0\n        expected_mean_iou = 0.5 * (iou_class_0 + iou_class_1)\n        iou = IOU(ignore_class=-1)(predicted_image[None, ...], target_image[None, ...])\n 
       self.assertAlmostEqual(iou, expected_mean_iou, places=4)\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/test_extensions/test_models/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n\n\"\"\"Unit test package for inferno.\"\"\"\n"
  },
  {
    "path": "tests/test_extensions/test_models/test_res_unet.py",
    "content": "import unittest\nimport torch\nimport torch.cuda as cuda\nfrom inferno.utils.model_utils import ModelTester\n\n\nclass ResUNetTest(unittest.TestCase):\n    def test_res_unet_2d(self):\n        from inferno.extensions.models import ResBlockUNet\n        tester = ModelTester((1, 1, 256, 256), (1, 1, 256, 256))\n        if cuda.is_available():\n            tester.cuda()\n        tester(ResBlockUNet(in_channels=1, out_channels=1, dim=2))\n\n    def test_res_unet_3d(self):\n        from inferno.extensions.models import ResBlockUNet\n        tester = ModelTester((1, 1, 16, 64, 64), (1, 1, 16, 64, 64))\n        if cuda.is_available():\n            tester.cuda()\n        # test default unet 3d\n        tester(ResBlockUNet(in_channels=1, out_channels=1, dim=3))\n\n    def test_2d_side_out_bot_up(self):\n        from inferno.extensions.models import ResBlockUNet\n        depth = 3\n        in_channels = 3\n\n        x = torch.rand(1, in_channels, 64, 32)\n        model = ResBlockUNet(in_channels=in_channels,\n                             out_channels=8, dim=2,\n                             side_out_parts=['bottom','up'],\n                             unet_kwargs=dict(depth=depth))\n\n        out_list = model(x)\n        self.assertEqual(len(out_list), depth + 1)\n\n        self.assertEqual(list(out_list[0].size()), [1, 24, 8, 4])\n        self.assertEqual(list(out_list[1].size()), [1, 12, 16, 8])\n        self.assertEqual(list(out_list[2].size()), [1, 6, 32, 16])\n        self.assertEqual(list(out_list[3].size()), [1, 8, 64, 32])\n\n    def test_2d_side_out_up(self):\n        from inferno.extensions.models import ResBlockUNet\n        depth = 3\n        in_channels = 3\n\n        x = torch.rand(1, in_channels, 64, 32)\n        model = ResBlockUNet(in_channels=in_channels,\n                             out_channels=8, dim=2,\n                             side_out_parts=['up'],\n                             unet_kwargs=dict(depth=depth))\n\n        out_list = 
model(x)\n        self.assertEqual(len(out_list), depth)\n\n        self.assertEqual(list(out_list[0].size()), [1,12, 16, 8])\n        self.assertEqual(list(out_list[1].size()), [1, 6, 32, 16])\n        self.assertEqual(list(out_list[2].size()), [1, 8, 64, 32])\n\n    def test_2d_side_out_down(self):\n        from inferno.extensions.models import ResBlockUNet\n        depth = 3\n        in_channels = 3\n\n        x = torch.rand(1, in_channels, 64, 32)\n        model = ResBlockUNet(in_channels=in_channels,\n                             out_channels=8, dim=2,\n                             side_out_parts=['down'],\n                             unet_kwargs=dict(depth=depth))\n\n        out_list = model(x)\n        self.assertEqual(len(out_list), depth  + 1)\n\n        self.assertEqual(list(out_list[0].size()), [1, 6, 64, 32])\n        self.assertEqual(list(out_list[1].size()), [1, 12, 32, 16])\n        self.assertEqual(list(out_list[2].size()), [1, 24, 16, 8])\n\n        # the actual output\n        self.assertEqual(list(out_list[3].size()), [1, 8, 64, 32])\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/test_extensions/test_models/test_unet.py",
    "content": "import unittest\nimport torch.cuda as cuda\nfrom inferno.utils.model_utils import ModelTester, MultiscaleModelTester\nfrom inferno.extensions.models import UNet\n\nclass _MultiscaleUNet(UNet):\n    def conv_op_factory(self, in_channels, out_channels, part, index):\n        return super(_MultiscaleUNet, self).conv_op_factory(in_channels, out_channels, part, index)[0], True\n\n    def forward(self, input):\n        x = self._initial_conv(input)\n        x = list(super(UNet, self).forward(x))\n        x[-1] = self._output(x[-1])\n        return tuple(x)\n\n\nclass UNetTest(unittest.TestCase):\n    def test_unet_2d(self):\n        tester = ModelTester((1, 1, 256, 256), (1, 1, 256, 256))\n        if cuda.is_available():\n            tester.cuda()\n        tester(UNet(1, 1, dim=2, initial_features=32))\n\n    def test_unet_3d(self):\n        tester = ModelTester((1, 1, 16, 64, 64), (1, 1, 16, 64, 64))\n        if cuda.is_available():\n            tester.cuda()\n        # test default unet 3d\n        tester(UNet(1, 1, dim=3, initial_features=8))\n\n    def test_monochannel_unet_3d(self):\n        nc = 2\n        class _UNetMonochannel(_MultiscaleUNet):\n            def _get_num_channels(self, depth):\n                return nc\n\n        shapes = [(1, nc, 16, 64, 64), (1, nc, 8, 32, 32), (1, nc, 4, 16, 16), (1, nc, 2, 8, 8), (1, nc, 1, 4, 4),\n                  (1, nc, 2, 8, 8), (1, nc, 4, 16, 16), (1, nc, 8, 32, 32), (1, 1, 16, 64, 64)]\n        tester = MultiscaleModelTester((1, 1, 16, 64, 64), shapes)\n        if cuda.is_available():\n            tester.cuda()\n        tester(_UNetMonochannel(1, 1, dim=3, initial_features=8))\n\n    def test_inverse_pyramid_unet_2d(self):\n        class _UNetInversePyramid(_MultiscaleUNet):\n            def _get_num_channels(self, depth):\n                return [13, 12, 11][depth - 1]\n\n        shapes = [(1, 13, 16, 64), (1, 12, 8, 32), (1, 11, 4, 16), (1, 11, 2, 8),\n                  (1, 12, 4, 16), (1, 13, 8, 32), 
(1, 1, 16, 64)]\n        tester = MultiscaleModelTester((1, 1, 16, 64), shapes)\n        if cuda.is_available():\n            tester.cuda()\n        tester(_UNetInversePyramid(1, 1, dim=2, depth=3, initial_features=8))\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/test_inferno.py",
    "content": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n\n\"\"\"Tests for `inferno` package.\"\"\"\n\n\nimport unittest\nimport numpy as np\nimport torch\nimport os\nimport shutil\nfrom os.path import dirname, join\nfrom torch.utils.data.dataset import TensorDataset\nfrom torch.utils.data.dataloader import DataLoader\nfrom inferno.extensions.layers import Conv2D, BNReLUConv2D\nfrom inferno.extensions.layers import AsMatrix\nfrom inferno.extensions.containers import Graph\nfrom inferno.trainers.basic import Trainer\nfrom inferno.trainers.callbacks.essentials import NaNDetector\nfrom inferno.trainers.callbacks.base import Callback\nfrom torch import nn\n\n\nclass TestInferno(unittest.TestCase):\n    \"\"\"Tests for `inferno` package.\"\"\"\n\n    NUM_SAMPLES = 100\n    NUM_TRAINING_SAMPLES = 70\n    NUM_CLASSES = 10\n    WORKING_DIRECTORY = dirname(__file__)\n\n    def read_environment_variables(self):\n        self.NUM_SAMPLES = int(os.getenv('INFERNO_TEST_NUM_SAMPLES', str(self.NUM_SAMPLES)))\n        self.NUM_TRAINING_SAMPLES = int(os.getenv('INFERNO_TEST_NUM_SAMPLES',\n                                                  str(self.NUM_TRAINING_SAMPLES)))\n        self.NUM_CLASSES = int(os.getenv('INFERNO_TEST_NUM_CLASSES', str(self.NUM_CLASSES)))\n        self.WORKING_DIRECTORY = os.getenv('INFERNO_TEST_WORKING_DIRECTORY',\n                                           self.WORKING_DIRECTORY)\n\n    def setUp(self):\n        \"\"\"Set up test fixtures, if any.\"\"\"\n        self.setUpDatasets()\n\n    def setUpDatasets(self):\n        # Build training dataset\n        inputs, targets = self.generate_random_data(self.NUM_SAMPLES, (3, 32, 32),\n                                                    num_classes=self.NUM_CLASSES,\n                                                    dtype='float32')\n        # Split to train and split\n        train_inputs, train_targets = inputs[:self.NUM_TRAINING_SAMPLES], \\\n                                      
targets[:self.NUM_TRAINING_SAMPLES]\n        validate_inputs, validate_targets = inputs[self.NUM_TRAINING_SAMPLES:], \\\n                                            targets[self.NUM_TRAINING_SAMPLES:]\n        # Convert to tensor and build dataset\n        train_dataset = TensorDataset(torch.from_numpy(train_inputs),\n                                      torch.from_numpy(train_targets))\n        validate_dataset = TensorDataset(torch.from_numpy(validate_inputs),\n                                         torch.from_numpy(validate_targets))\n        # Build dataloaders from dataset\n        self.train_loader = DataLoader(train_dataset, batch_size=16,\n                                       shuffle=True, num_workers=0, pin_memory=False)\n        self.validate_loader = DataLoader(validate_dataset, batch_size=16,\n                                          shuffle=True, num_workers=0, pin_memory=False)\n\n    def setUpCallbacks(self):\n\n        class RecordSaveInfo(Callback):\n            def __init__(self):\n                super(RecordSaveInfo, self).__init__()\n                self.best_saves_at_iteration_epoch = []\n                self.saves_at_iteration_epoch = []\n\n            def begin_of_save(self, epoch_count, iteration_count,\n                              is_iteration_with_best_validation_score, **_):\n                if is_iteration_with_best_validation_score:\n                    self.best_saves_at_iteration_epoch.append((iteration_count, epoch_count))\n                else:\n                    self.saves_at_iteration_epoch.append((iteration_count, epoch_count))\n\n        self.RecordSaveInfo = RecordSaveInfo\n\n    def generate_random_data(self, num_samples, shape, num_classes,\n                             hardness=0.3, dtype=None):\n        dataset_input = np.zeros((num_samples,) + shape, dtype=dtype)\n        dataset_target = np.random.randint(num_classes, size=num_samples)\n        for sample_num in range(num_samples):\n            
dataset_input[sample_num] = np.random.normal(loc=dataset_target[sample_num],\n                                                         scale=(1 - hardness),\n                                                         size=shape)\n        return dataset_input, dataset_target\n\n    def tearDown(self):\n        \"\"\"Tear down test fixtures, if any.\"\"\"\n        if os.path.exists(join(self.WORKING_DIRECTORY, 'Weights')):\n            shutil.rmtree(join(self.WORKING_DIRECTORY, 'Weights'))\n\n    def build_graph_model(self):\n        model = Graph()\n        model\\\n            .add_input_node('input')\\\n            .add_node('conv1', Conv2D(3, 8, 3), 'input')\\\n            .add_node('conv2', BNReLUConv2D(8, 8, 3), 'conv1')\\\n            .add_node('pool1', nn.MaxPool2d(kernel_size=2, stride=2), 'conv2')\\\n            .add_node('conv3', BNReLUConv2D(8, 8, 3), 'pool1')\\\n            .add_node('pool2', nn.MaxPool2d(kernel_size=2, stride=2), 'conv3')\\\n            .add_node('conv4', BNReLUConv2D(8, 8, 3), 'pool2')\\\n            .add_node('pool3', nn.AdaptiveAvgPool2d(output_size=(1, 1)), 'conv4')\\\n            .add_node('matrix', AsMatrix(), 'pool3')\\\n            .add_node('linear', nn.Linear(8, self.NUM_CLASSES), 'matrix')\\\n            .add_output_node('output', 'linear')\n        return model\n\n    def test_training_cpu(self):\n        \"\"\"Test Trainer.\"\"\"\n        # Build model\n        model = self.build_graph_model()\n\n        # Build callbacks\n        # save_info_recorder = RecordSaveInfo()\n        # Build trainer\n        trainer = Trainer(model)\\\n            .save_every((2, 'epochs'), to_directory=join(self.WORKING_DIRECTORY, 'Weights'))\\\n            .validate_every((100, 'iterations'), for_num_iterations=10)\\\n            .set_max_num_epochs(4)\\\n            .save_at_best_validation_score()\\\n            .build_optimizer('RMSprop')\\\n            .build_criterion('CrossEntropyLoss')\\\n            .build_metric('CategoricalError')\\\n  
          .register_callback(NaNDetector)\n        # Bind datasets\n        trainer\\\n            .bind_loader('train', self.train_loader)\\\n            .bind_loader('validate', self.validate_loader)\n        # Go\n        trainer.pickle_module = 'dill'\n        trainer.fit()\n\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/test_io/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n\n\"\"\"Unit test package for inferno.\"\"\"\n"
  },
  {
    "path": "tests/test_io/test_box/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n\n\"\"\"Unit test package for inferno.\"\"\"\n"
  },
  {
    "path": "tests/test_io/test_box/test_camvid.py",
    "content": "import os\nfrom os.path import join, dirname, exists, isdir\nimport unittest\nimport numpy as np\n\n\n_CAMVID_ROOT = None\n\n\ndef _camvid_available():\n    return _CAMVID_ROOT is not None or os.environ.get('CAMVID_ROOT') is not None\n\n\nclass TestCamvid(unittest.TestCase):\n    CAMVID_ROOT = _CAMVID_ROOT\n    PLOT_DIRECTORY = join(dirname(__file__), 'plots')\n\n    def get_camvid_root(self):\n        if self.CAMVID_ROOT is None:\n            root = os.environ.get('CAMVID_ROOT')\n            assert root is not None, \"Camvid Root not found.\"\n        else:\n            return self.CAMVID_ROOT\n\n    @unittest.skipUnless(_camvid_available(), \"No root available.\")\n    def test_camvid_dataset_without_transforms(self):\n        from inferno.io.box.camvid import CamVid\n        camvid = CamVid(self.get_camvid_root())\n        image, label = camvid[0]\n        image = np.asarray(image)\n        label = np.asarray(label)\n        self.assertSequenceEqual(image.shape, (360, 480, 3))\n        self.assertSequenceEqual(label.shape, (360, 480))\n        self.assertLessEqual(label.max(), 11)\n\n    @unittest.skipUnless(_camvid_available(), \"No root available.\")\n    def _test_camvid_dataset_with_transforms(self):\n        from inferno.io.box.camvid import CamVid\n        from inferno.io.transform.base import Compose\n        from inferno.io.transform.image import PILImage2NumPyArray, RandomSizedCrop, Scale\n        from inferno.utils.io_utils import print_tensor\n\n        camvid = CamVid(self.get_camvid_root(),\n                        image_transform=Compose(),\n                        label_transform=Compose(),\n                        joint_transform=Compose())\n        camvid.image_transform.add(PILImage2NumPyArray())\n        camvid.label_transform.add(PILImage2NumPyArray())\n        image, label = camvid[0]\n        self.assertSequenceEqual(image.shape, (3, 360, 480))\n        self.assertSequenceEqual(label.shape, (360, 480))\n        # Add crop 
trafo\n        camvid.joint_transform.add(RandomSizedCrop(ratio_between=(0.7, 1.0),\n                                                   preserve_aspect_ratio=True))\n        # We need 2 scale transforms, one with order 3 (image) and the other with order 0 (label)\n        camvid.joint_transform.add(Scale(output_image_shape=(360, 480),\n                                         interpolation_order=3, apply_to=[0]))\n        camvid.joint_transform.add(Scale(output_image_shape=(360, 480),\n                                         interpolation_order=0, apply_to=[1]))\n        image, label = camvid[0]\n        self.assertSequenceEqual(image.shape, (3, 360, 480))\n        self.assertSequenceEqual(label.shape, (360, 480))\n        self.assertLessEqual(len(np.unique(label)), 12)\n        # Print tensors to make sure they look legit\n        if not exists(self.PLOT_DIRECTORY):\n            os.mkdir(self.PLOT_DIRECTORY)\n        else:\n            assert isdir(self.PLOT_DIRECTORY)\n        print_tensor(image[None, ...], prefix='IMG--', directory=self.PLOT_DIRECTORY)\n        print_tensor(label[None, None, ...], prefix='LAB--', directory=self.PLOT_DIRECTORY)\n        print(\"[+] Inspect images at {}\".format(self.PLOT_DIRECTORY))\n\n    @unittest.skipUnless(_camvid_available(), \"No root available.\")\n    def test_camvid_dataset_with_transforms(self):\n        from inferno.io.box.camvid import get_camvid_loaders\n        from inferno.utils.io_utils import print_tensor\n\n        train_loader, validate_loader, test_loader = get_camvid_loaders(self.get_camvid_root())\n        train_dataset = train_loader.dataset\n        image, label = train_dataset[0]\n        # Make sure the shapes checkout\n        self.assertSequenceEqual(image.size(), (3, 360, 480))\n        self.assertSequenceEqual(label.size(), (360, 480))\n        self.assertEqual(image.type(), 'torch.FloatTensor')\n        self.assertEqual(label.type(), 'torch.LongTensor')\n        # Print tensors to make sure they 
look legit\n        if not exists(self.PLOT_DIRECTORY):\n            os.mkdir(self.PLOT_DIRECTORY)\n        else:\n            assert isdir(self.PLOT_DIRECTORY)\n        print_tensor(image.numpy()[None, ...], prefix='IMG--', directory=self.PLOT_DIRECTORY)\n        print_tensor(label.numpy()[None, None, ...], prefix='LAB--', directory=self.PLOT_DIRECTORY)\n        print(\"[+] Inspect images at {}\".format(self.PLOT_DIRECTORY))\n\n    @unittest.skipUnless(_camvid_available(), \"No root available.\")\n    def test_camvid_dataset_with_transforms_onehot(self):\n        from inferno.io.box.camvid import get_camvid_loaders\n        from inferno.utils.io_utils import print_tensor\n\n        train_loader, validate_loader, test_loader = get_camvid_loaders(self.get_camvid_root(),\n                                                                        labels_as_onehot=True)\n        train_dataset = train_loader.dataset\n        image, label = train_dataset[0]\n        # Make sure the shapes checkout\n        self.assertSequenceEqual(image.size(), (3, 360, 480))\n        self.assertSequenceEqual(label.size(), (12, 360, 480))\n        self.assertEqual(image.type(), 'torch.FloatTensor')\n        self.assertEqual(label.type(), 'torch.FloatTensor')\n        # Print tensors to make sure they look legit\n        if not exists(self.PLOT_DIRECTORY):\n            os.mkdir(self.PLOT_DIRECTORY)\n        else:\n            assert isdir(self.PLOT_DIRECTORY)\n        print_tensor(image.numpy()[None, ...], prefix='IMG--', directory=self.PLOT_DIRECTORY)\n        print_tensor(label.numpy()[None, ...], prefix='LAB--', directory=self.PLOT_DIRECTORY)\n        print(\"[+] Inspect images at {}\".format(self.PLOT_DIRECTORY))\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/test_io/test_box/test_cityscapes.py",
    "content": "import os\nfrom os.path import join, dirname, exists, isdir\nimport unittest\nimport numpy as np\nimport time\n\n_CITYSCAPES_ROOT = None\n\n\ndef _cityscapes_available():\n    return _CITYSCAPES_ROOT is not None or os.environ.get('CITYSCAPES_ROOT') is not None\n\n\nclass TestCityscapes(unittest.TestCase):\n    CITYSCAPES_ROOT = _CITYSCAPES_ROOT\n    PLOT_DIRECTORY = join(dirname(__file__), 'plots')\n    INCLUDE_COARSE = False\n\n    def get_cityscapes_root(self):\n        if self.CITYSCAPES_ROOT is None:\n            root = os.environ.get('CITYSCAPES_ROOT')\n            assert root is not None, \"Cityscapes Root not found.\"\n        else:\n            return self.CITYSCAPES_ROOT\n\n    @unittest.skipUnless(_cityscapes_available(), \"No cityscapes available.\")\n    def test_cityscapes_dataset_without_transforms(self):\n        from inferno.io.box.cityscapes import Cityscapes\n        cityscapes = Cityscapes(self.get_cityscapes_root())\n        image, label = cityscapes[0]\n        image = np.asarray(image)\n        label = np.asarray(label)\n        self.assertSequenceEqual(image.shape, (1024, 2048, 3))\n        self.assertSequenceEqual(label.shape, (1024, 2048))\n        self.assertLessEqual(label.max(), 33)\n\n    @unittest.skipUnless(_cityscapes_available(), \"No cityscapes available.\")\n    def test_cityscapes_dataset_without_transforms_unzipped(self):\n        from inferno.io.box.cityscapes import Cityscapes\n        cityscapes = Cityscapes(join(self.get_cityscapes_root(), 'extracted'),\n                                read_from_zip_archive=False)\n        image, label = cityscapes[0]\n        image = np.asarray(image)\n        label = np.asarray(label)\n        self.assertSequenceEqual(image.shape, (1024, 2048, 3))\n        self.assertSequenceEqual(label.shape, (1024, 2048))\n        self.assertLessEqual(label.max(), 33)\n\n    @unittest.skipUnless(_cityscapes_available(), \"No cityscapes available.\")\n    def 
test_cityscapes_dataset_with_transforms(self):\n        from inferno.io.box.cityscapes import get_cityscapes_loaders\n        from inferno.utils.io_utils import print_tensor\n\n        train_loader, validate_loader = get_cityscapes_loaders(self.get_cityscapes_root(),\n                                                               include_coarse_dataset=self.INCLUDE_COARSE)\n        train_dataset = train_loader.dataset\n        tic = time.time()\n        image, label = train_dataset[0]\n        toc = time.time()\n        print(\"[+] Loaded sample in {} seconds.\".format(toc - tic))\n        # Make sure the shapes checkout\n        self.assertSequenceEqual(image.size(), (3, 1024, 2048))\n        self.assertSequenceEqual(label.size(), (1024, 2048))\n        self.assertEqual(image.type(), 'torch.FloatTensor')\n        self.assertEqual(label.type(), 'torch.LongTensor')\n        # Print tensors to make sure they look legit\n        if not exists(self.PLOT_DIRECTORY):\n            os.mkdir(self.PLOT_DIRECTORY)\n        else:\n            assert isdir(self.PLOT_DIRECTORY)\n        print_tensor(image.numpy()[None, ...], prefix='IMG--', directory=self.PLOT_DIRECTORY)\n        for class_id in np.unique(label.numpy()):\n            print_tensor((label.numpy()[None, None, ...] 
== class_id).astype('float32'),\n                         prefix='LAB-{}--'.format(class_id),\n                         directory=self.PLOT_DIRECTORY)\n        print_tensor(label.numpy()[None, None, ...],\n                     prefix='LAB--',\n                     directory=self.PLOT_DIRECTORY)\n        print(\"[+] Inspect images at {}\".format(self.PLOT_DIRECTORY))\n\n    @unittest.skipUnless(_cityscapes_available(), \"No cityscapes available.\")\n    def test_cityscapes_dataset_with_transforms_unzipped(self):\n        from inferno.io.box.cityscapes import get_cityscapes_loaders\n        from inferno.utils.io_utils import print_tensor\n\n        train_loader, validate_loader = get_cityscapes_loaders(join(self.get_cityscapes_root(),\n                                                                    'extracted'),\n                                                               include_coarse_dataset=self.INCLUDE_COARSE,\n                                                               read_from_zip_archive=False)\n        train_dataset = train_loader.dataset\n        tic = time.time()\n        image, label = train_dataset[0]\n        toc = time.time()\n        print(\"[+] Loaded sample in {} seconds.\".format(toc - tic))\n        # Make sure the shapes checkout\n        self.assertSequenceEqual(image.size(), (3, 1024, 2048))\n        self.assertSequenceEqual(label.size(), (1024, 2048))\n        self.assertEqual(image.type(), 'torch.FloatTensor')\n        self.assertEqual(label.type(), 'torch.LongTensor')\n        # Print tensors to make sure they look legit\n        if not exists(self.PLOT_DIRECTORY):\n            os.mkdir(self.PLOT_DIRECTORY)\n        else:\n            assert isdir(self.PLOT_DIRECTORY)\n        print_tensor(image.numpy()[None, ...], prefix='IMG--', directory=self.PLOT_DIRECTORY)\n        for class_id in np.unique(label.numpy()):\n            print_tensor((label.numpy()[None, None, ...] 
== class_id).astype('float32'),\n                         prefix='LAB-{}--'.format(class_id),\n                         directory=self.PLOT_DIRECTORY)\n        print_tensor(label.numpy()[None, None, ...],\n                     prefix='LAB--',\n                     directory=self.PLOT_DIRECTORY)\n        print(\"[+] Inspect images at {}\".format(self.PLOT_DIRECTORY))\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/test_io/test_core/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n\n\"\"\"Unit test package for inferno.\"\"\"\n"
  },
  {
    "path": "tests/test_io/test_core/test_concatenate.py",
    "content": "import unittest\n\n\nclass ConcatenateTest(unittest.TestCase):\n    def test_concatenate(self):\n        from inferno.io.core import Concatenate\n        from torch.utils.data.dataset import Dataset\n\n        with self.assertRaises(AssertionError):\n            cated = Concatenate([1, 2, 3], [4, 5, 6, 7])\n\n        class ListDataset(list, Dataset):\n            pass\n\n        dataset_1 = ListDataset([1, 2, 3, 4])\n        dataset_2 = ListDataset([5, 6, 7])\n        dataset_3 = ListDataset([8, 9, 10, 11, 12])\n\n        cated = Concatenate(dataset_1, dataset_2, dataset_3)\n        self.assertEqual(len(cated), 12)\n\n        # Try to fetch\n        self.assertEqual(cated[2], 3)\n        self.assertEqual(cated[4], 5)\n        self.assertEqual(cated[6], 7)\n        self.assertEqual(cated[10], 11)\n        self.assertEqual(cated[11], 12)\n\n        with self.assertRaises(AssertionError):\n            _ = cated[12]\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/test_io/test_core/test_zip.py",
    "content": "import unittest\n\n\nclass ZipTest(unittest.TestCase):\n    def test_zip_minimal(self):\n        \"\"\"Minimal test with python lists as iterators.\"\"\"\n        from inferno.io.core import Zip\n        from torch.utils.data.dataset import Dataset\n\n        with self.assertRaises(TypeError):\n            zipped = Zip([1, 2, 3], [4, 5, 6, 7])\n\n        # This is required because Zip checks if its inputs are actually torch datasets\n        class ListDataset(list, Dataset):\n            pass\n\n        dataset_1 = ListDataset([1, 2, 3, 4])\n        dataset_2 = ListDataset([5, 6, 7, 8, 9])\n        zipped = Zip(dataset_1, dataset_2)\n        self.assertEqual(len(zipped), 4)\n\n        fetched = zipped[1]\n        self.assertEqual(fetched, [2, 6])\n\n        with self.assertRaises(IndexError):\n            fetched = zipped[4]\n\n    def test_zip_sync(self):\n        \"\"\"Test synchronization mechanics.\"\"\"\n        # TODO\n\n    def test_zip_reject(self):\n        from inferno.io.core import ZipReject\n        from torch.utils.data.dataset import Dataset\n\n        # This is required because Zip checks if its inputs are actually torch datasets\n        class ListDataset(list, Dataset):\n            pass\n\n        def rejection_criterion(sample_1, sample_2):\n            return sample_1 < sample_2\n\n        dataset_1 = ListDataset([1, 2, 3, 4])\n        dataset_2 = ListDataset([2, 1, 3, 4])\n        dataset_3 = ListDataset([0, 1, 2, 3])\n\n        zipped = ZipReject(dataset_1, dataset_2, dataset_3,\n                           rejection_criterion=rejection_criterion,\n                           random_jump_after_reject=False,\n                           rejection_dataset_indices=[0, 1])\n        fetched = zipped[0]\n        self.assertSequenceEqual(fetched, [2, 1, 1])\n\n        zipped = ZipReject(dataset_1, dataset_2, dataset_3,\n                           rejection_criterion=rejection_criterion,\n                           
rejection_dataset_indices=[1, 0])\n        fetched = zipped[0]\n        self.assertSequenceEqual(fetched, [1, 2, 0])\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/test_io/test_volumetric/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n\n\"\"\"Unit test package for inferno.\"\"\"\n"
  },
  {
    "path": "tests/test_io/test_volumetric/test_lazy_volume_loader.py",
    "content": "import unittest\nimport os\nimport numpy as np\n\n# try to load io libraries (h5py and z5py)\ntry:\n    import h5py\n    WITH_H5PY = True\nexcept ImportError:\n    WITH_H5PY = False\n\n# try:\n#     import z5py\n#     WITH_Z5PY = True\n# except ImportError:\n#     WITH_Z5PY = False\n\n\nclass TestLazyVolumeLoader(unittest.TestCase):\n\n    def tearDown(self):\n        try:\n            os.remove('tmp.h5')\n        except OSError:\n            pass\n\n    @unittest.skipUnless(WITH_H5PY, \"Need h5py\")\n    def test_h5_loader(self):\n        from inferno.io.volumetric.lazy_volume_loader import LazyHDF5VolumeLoader\n        shape = (100, 100)\n\n        # test default data loader\n        data = np.arange(np.product(shape)).reshape(shape)\n        with h5py.File('tmp.h5') as f:\n            f.create_dataset('data', data=data)\n\n        loader = LazyHDF5VolumeLoader('tmp.h5', 'data',\n                                      window_size=[10, 10], stride=[10, 10],\n                                      return_index_spec=True)\n        self.assertEqual(loader.shape, shape)\n        for batch, index in loader:\n            expected = data[index.base_sequence_at_index]\n            self.assertEqual(batch.shape, expected.shape)\n            self.assertTrue(np.allclose(batch, expected))\n\n    @unittest.skipUnless(WITH_H5PY, \"Need h5py\")\n    def test_h5_loader_data_slice(self):\n        from inferno.io.volumetric.lazy_volume_loader import LazyHDF5VolumeLoader\n        shape = (100, 100, 100)\n        data_slice = np.s_[:, 20:80, 10:30]\n\n        # test default data loader\n        data = np.arange(np.product(shape)).reshape(shape)\n        with h5py.File('tmp.h5') as f:\n            f.create_dataset('data', data=data)\n        data = data[data_slice]\n\n        loader = LazyHDF5VolumeLoader('tmp.h5', 'data',\n                                      window_size=[10, 10, 10], stride=[10, 10, 10],\n                                      return_index_spec=True, 
data_slice=data_slice)\n        self.assertEqual(loader.shape, data.shape)\n        for batch, index in loader:\n            slice_ = index.base_sequence_at_index\n            expected = data[slice_]\n            self.assertEqual(batch.shape, expected.shape)\n            self.assertTrue(np.allclose(batch, expected))\n\n    @unittest.skipUnless(WITH_H5PY, \"Need h5py\")\n    def test_h5_loader_pad(self):\n        from inferno.io.volumetric.lazy_volume_loader import LazyHDF5VolumeLoader\n        shape = (100, 100, 100)\n        pad = [[0, 10], [0, 0], [5, 15]]\n\n        # test default data loader\n        data = np.arange(np.product(shape)).reshape(shape)\n        with h5py.File('tmp.h5') as f:\n            f.create_dataset('data', data=data)\n        data = np.pad(data, pad_width=pad, mode='constant')\n\n        loader = LazyHDF5VolumeLoader('tmp.h5', 'data',\n                                      window_size=[20, 20, 20], stride=[20, 20, 20],\n                                      return_index_spec=True, padding=pad, padding_mode='constant')\n        self.assertEqual(loader.shape, data.shape)\n        for batch, index in loader:\n            slice_ = index.base_sequence_at_index\n            expected = data[slice_]\n            self.assertEqual(batch.shape, expected.shape)\n            self.assertTrue(np.allclose(batch, expected))\n\n    @unittest.skipUnless(WITH_H5PY, \"Need h5py\")\n    def test_h5_loader_data_slice_pad(self):\n        from inferno.io.volumetric.lazy_volume_loader import LazyHDF5VolumeLoader\n        shape = (100, 100, 100)\n        data_slice = np.s_[:, 20:80, 10:90]\n        pad = [[0, 10], [5, 5], [5, 15]]\n\n        # test default data loader\n        data = np.arange(np.product(shape)).reshape(shape)\n        with h5py.File('tmp.h5') as f:\n            f.create_dataset('data', data=data)\n        data = data[data_slice]\n        data = np.pad(data, pad_width=pad, mode='constant')\n\n        loader = LazyHDF5VolumeLoader('tmp.h5', 'data',\n  
                                    window_size=[20, 20, 20], stride=[20, 20, 20],\n                                      return_index_spec=True, padding=pad, padding_mode='constant',\n                                      data_slice=data_slice)\n        self.assertEqual(loader.shape, data.shape)\n        for batch, index in loader:\n            slice_ = index.base_sequence_at_index\n            expected = data[slice_]\n            self.assertEqual(batch.shape, expected.shape)\n            self.assertTrue(np.allclose(batch, expected))\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/test_io/test_volumetric/test_volume_loader.py",
    "content": "import unittest\nimport os\nfrom shutil import rmtree\n\nimport numpy as np\nimport h5py\n\n\nclass TestVolumeLoader(unittest.TestCase):\n    shape = (100, 100, 100)\n    def setUp(self):\n        self.data = np.random.rand(*self.shape)\n\n    def test_loader(self):\n        from inferno.io.volumetric import VolumeLoader\n        loader = VolumeLoader(self.data,\n                              window_size=(10, 10, 10),\n                              stride=(10, 10, 10), return_index_spec=True)\n        for batch, idx in loader:\n            slice_ = loader.base_sequence[int(idx)]\n            expected = self.data[slice_]\n            self.assertEqual(batch.shape, expected.shape)\n            self.assertTrue(np.allclose(batch, expected))\n\n\nclass TestHDF5VolumeLoader(unittest.TestCase):\n    shape = (100, 100, 100)\n    def setUp(self):\n        try:\n            os.mkdir('./tmp')\n        except OSError:\n            pass\n        self.data = np.random.rand(*self.shape)\n        with h5py.File('./tmp/data.h5') as f:\n            f.create_dataset('data', data=self.data)\n\n    def tearDown(self):\n        try:\n            rmtree('./tmp')\n        except OSError:\n            pass\n\n    def test_hdf5_loader(self):\n        from inferno.io.volumetric import HDF5VolumeLoader\n        loader = HDF5VolumeLoader('./tmp/data.h5', 'data',\n                                  window_size=(10, 10, 10),\n                                  stride=(10, 10, 10), return_index_spec=True)\n        for batch, idx in loader:\n            slice_ = loader.base_sequence[int(idx)]\n            expected = self.data[slice_]\n            self.assertEqual(batch.shape, expected.shape)\n            self.assertTrue(np.allclose(batch, expected))\n\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/test_training/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n\n\"\"\"Unit test package for inferno.\"\"\"\n"
  },
  {
    "path": "tests/test_training/test_basic.py",
    "content": "from unittest import TestCase, skipUnless\nimport torch\nfrom unittest import main\nimport time\nfrom os.path import join, dirname\n\n\nclass TestTrainer(TestCase):\n    # Parameters\n    ROOT_DIR = dirname(__file__)\n    CUDA = False\n    HALF_PRECISION = False\n    DOWNLOAD_CIFAR = True\n\n    @staticmethod\n    def _make_test_model():\n        import torch.nn as nn\n        from inferno.extensions.layers.reshape import AsMatrix\n\n        toy_net = nn.Sequential(nn.Conv2d(3, 8, 3, 1, 1),\n                                nn.ELU(),\n                                nn.MaxPool2d(2),\n                                nn.Conv2d(8, 8, 3, 1, 1),\n                                nn.ELU(),\n                                nn.MaxPool2d(2),\n                                nn.Conv2d(8, 16, 3, 1, 1),\n                                nn.ELU(),\n                                nn.AdaptiveAvgPool2d((1, 1)),\n                                AsMatrix(),\n                                nn.Linear(16, 10))\n        return toy_net\n\n    def test_cifar(self):\n        from inferno.trainers.basic import Trainer\n        from inferno.io.box.cifar import get_cifar10_loaders\n        # Build cifar10 loaders\n        trainloader, testloader = get_cifar10_loaders(root_directory=join(self.ROOT_DIR, 'data'),\n                                                      download=self.DOWNLOAD_CIFAR)\n        # Make model\n        net = self._make_test_model()\n        tic = time.time()\n        # Make trainer\n        trainer = Trainer(model=net)\\\n            .build_optimizer('Adam')\\\n            .build_criterion('CrossEntropyLoss')\\\n            .build_metric('CategoricalError')\\\n            .validate_every((1, 'epochs'))\\\n            .save_every((1, 'epochs'), to_directory=join(self.ROOT_DIR, 'saves'))\\\n            .save_at_best_validation_score()\\\n            .set_max_num_epochs(2)\n        # Bind trainer to datasets\n        trainer.bind_loader('train', 
trainloader).bind_loader('validate', testloader)\n        # Check device and fit\n        if self.CUDA:\n            if self.HALF_PRECISION:\n                trainer.cuda().set_precision('half').fit()\n            else:\n                trainer.cuda().fit()\n        else:\n            trainer.fit()\n        toc = time.time()\n        print(\"[*] Elapsed time: {} seconds.\".format(toc - tic))\n\n    def test_multi_io(self):\n        from torch.utils.data.dataset import Dataset\n        from torch.utils.data.dataloader import DataLoader\n        from inferno.trainers.basic import Trainer\n\n        class DummyDataset(Dataset):\n            def __len__(self):\n                return 42\n\n            def __getitem__(self, item):\n                # 2 inputs and 3 targets (say)\n                return torch.rand(3, 32, 32), \\\n                       torch.rand(3, 32, 32), \\\n                       torch.rand(1).uniform_(), \\\n                       torch.rand(1).uniform_(), \\\n                       torch.rand(1).uniform_()\n\n        class DummyNetwork(torch.nn.Module):\n            def __init__(self):\n                super(DummyNetwork, self).__init__()\n                self.conv = torch.nn.Conv2d(3, 1, 3, padding=1)\n\n            def forward(self, *inputs):\n                assert len(inputs) == 2\n                out = self.conv(inputs[0])\n                return out.view(inputs[0].size(0), -1).mean(1), \\\n                       out.view(inputs[0].size(0), -1).mean(1), \\\n                       out.view(inputs[0].size(0), -1).mean(1)\n\n        class DummyCriterion(torch.nn.Module):\n            def forward(self, predictions, targets):\n                assert len(predictions) == len(targets) == 3\n                return predictions[0].mean()\n\n        loader = DataLoader(DummyDataset())\n        net = DummyNetwork()\n\n        trainer = Trainer(net)\\\n            .build_criterion(DummyCriterion)\\\n            .build_optimizer('Adam')\\\n            
.set_max_num_iterations(50)\\\n            .bind_loader('train', loader, num_inputs=2, num_targets=3)\n\n        trainer.fit()\n\n    def test_serialization(self):\n        from inferno.trainers.basic import Trainer\n        import os\n\n        # Make model\n        net = self._make_test_model()\n        # Make trainer\n        trainer = Trainer(model=net) \\\n            .build_optimizer('Adam') \\\n            .build_criterion('CrossEntropyLoss') \\\n            .build_metric('CategoricalError') \\\n            .validate_every((1, 'epochs')) \\\n            .save_every((1, 'epochs'), to_directory=os.path.join(self.ROOT_DIR, 'saves')) \\\n            .save_at_best_validation_score() \\\n            .set_max_num_epochs(2)\n\n        # Try to serialize\n        trainer.save()\n\n        # Try to unserialize\n        trainer = Trainer(net).save_to_directory(os.path.join(self.ROOT_DIR, 'saves')).load()\n\n    @skipUnless(torch.cuda.device_count() >= 4, \"Not enough cuda devices for test_multi_gpu.\")\n    def test_multi_gpu(self):\n        if not torch.cuda.is_available():\n            return\n\n        from inferno.trainers.basic import Trainer\n        from inferno.io.box.cifar import get_cifar10_loaders\n        import os\n\n        # Make model\n        net = self._make_test_model()\n        # Make trainer\n        trainer = Trainer(model=net) \\\n            .build_optimizer('Adam') \\\n            .build_criterion('CrossEntropyLoss') \\\n            .build_metric('CategoricalError') \\\n            .validate_every((1, 'epochs')) \\\n            .save_every((1, 'epochs'), to_directory=os.path.join(self.ROOT_DIR, 'saves')) \\\n            .save_at_best_validation_score() \\\n            .set_max_num_epochs(2)\\\n            .cuda(devices=[0, 1, 2, 3], base_device='cpu')\n\n        train_loader, validate_loader = get_cifar10_loaders(root_directory=self.ROOT_DIR,\n                                                            download=True)\n        
trainer.bind_loader('train', train_loader)\n        trainer.bind_loader('validate', validate_loader)\n\n        trainer.fit()\n\n    def test_save(self):\n        from inferno.trainers.basic import Trainer\n        trainer = Trainer().save_to_directory(to_directory=self.ROOT_DIR,\n                                              checkpoint_filename='dummy.pytorch')\n        trainer.save()\n        # Instantiate new trainer and load\n        trainer = Trainer().load(from_directory=self.ROOT_DIR, filename='dummy.pytorch')\n\n    @skipUnless(torch.cuda.device_count() >= 2, \"Not enough cuda devices for test_multi_gpu_setup.\")\n    def test_multi_gpu_setup(self):\n        from torch.nn import CrossEntropyLoss\n        from inferno.trainers.basic import Trainer\n        # Test base_device = 'cpu'\n        # Build model\n        net = self._make_test_model()\n        # Make dummy criterion\n        criterion = CrossEntropyLoss(weight=torch.rand(10))\n        # Make trainer\n        trainer = Trainer(net).build_criterion(criterion).cuda([0, 1], base_device='cpu')\n        self.assertIsInstance(trainer.criterion.weight, torch.FloatTensor)\n        # Test base_device = 'cpu'\n        # Build model\n        net = self._make_test_model()\n        criterion = CrossEntropyLoss(weight=torch.rand(10))\n        # Make trainer\n        trainer = Trainer(net).build_criterion(criterion).cuda([0, 1], base_device='cuda')\n        self.assertIsInstance(trainer.criterion.weight, torch.cuda.FloatTensor)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "tests/test_training/test_callbacks/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n\n\"\"\"Unit test package for inferno.\"\"\"\n"
  },
  {
    "path": "tests/test_training/test_callbacks/test_base.py",
    "content": "import unittest\nimport torch\nfrom inferno.trainers.callbacks.base import Callback, CallbackEngine\nfrom inferno.trainers.basic import Trainer\nfrom os.path import join, dirname, exists\nfrom os import makedirs\nfrom shutil import rmtree\n\n\nclass DummyCallback(Callback):\n    def end_of_training_iteration(self, **_):\n        assert self.trainer is not None\n\n\nclass WrongDummyCallback(Callback):\n    def end_of_iteration(self):\n        pass\n\n\nclass CallbackMechTest(unittest.TestCase):\n    ROOT_DIR = join(dirname(__file__), 'root')\n\n    def setUp(self):\n        makedirs(self.ROOT_DIR, exist_ok=True)\n\n    def tearDown(self):\n        if exists(self.ROOT_DIR):\n            rmtree(self.ROOT_DIR)\n\n    def test_serialization(self):\n        # Build engine and trainer\n        callback_engine = CallbackEngine().bind_trainer(Trainer())\n        callback_engine.register_callback(DummyCallback())\n        # Serialize\n        torch.save(callback_engine, join(self.ROOT_DIR, 'callback_engine.pkl'))\n        # Unserialize\n        callback_engine = torch.load(join(self.ROOT_DIR, 'callback_engine.pkl'))\n        # Make sure the trainer is detached\n        self.assertIsNone(callback_engine._trainer)\n        self.assertIsInstance(next(iter(callback_engine\n                                        ._callback_registry\n                                        .get('end_of_training_iteration'))),\n                              DummyCallback)\n\n    def test_auto_registry(self):\n        callback_engine = CallbackEngine().bind_trainer(Trainer())\n        callback_engine.register_callback(DummyCallback())\n        self.assertIsInstance(next(iter(callback_engine\n                                        ._callback_registry\n                                        .get('end_of_training_iteration'))),\n                              DummyCallback)\n        with self.assertRaises(AssertionError):\n            
callback_engine.register_callback(WrongDummyCallback())\n\n    def test_instance_registry(self):\n        class Foo(Callback):\n            pass\n\n        class Bar(Callback):\n            pass\n\n        foo = Foo()\n        bar = Bar()\n        self.assertIs(foo.get_instances(), foo)\n        self.assertIs(bar.get_instances(), bar)\n        foo2 = Foo()\n        self.assertSequenceEqual(foo2.get_instances(), [foo, foo2])\n        self.assertIs(bar.get_instances(), bar)\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/test_training/test_callbacks/test_essentials.py",
    "content": "import unittest\nimport shutil\nimport h5py as h5\nfrom os.path import dirname, join\nfrom os import listdir\nfrom inferno.trainers.basic import Trainer\nfrom inferno.trainers.callbacks.essentials import DumpHDF5Every\nfrom inferno.utils.test_utils import generate_random_dataloader\nfrom inferno.extensions.layers import Conv2D, AsMatrix\nfrom torch.nn import Sequential, MaxPool2d, AdaptiveAvgPool2d, Linear, Softmax\n\n\nclass TestEssentials(unittest.TestCase):\n    WORKING_DIRECTORY = dirname(__file__)\n\n    def setUp(self):\n        # Build a simple ass model\n        model = Sequential(Conv2D(3, 8, 3, activation='ReLU'),\n                           MaxPool2d(2, 2),\n                           Conv2D(8, 8, 3, activation='ReLU'),\n                           MaxPool2d(2, 2),\n                           Conv2D(8, 8, 3, activation='ReLU'),\n                           MaxPool2d(2, 2),\n                           Conv2D(8, 8, 3, activation='ReLU'),\n                           AdaptiveAvgPool2d((1, 1)),\n                           AsMatrix(),\n                           Linear(8, 10))\n\n        train_dataloader = generate_random_dataloader(512, (3, 32, 32), 10, batch_size=16,\n                                                      dtype='float32')\n        validate_dataloader = generate_random_dataloader(32, (3, 32, 32), 10, batch_size=16,\n                                                         dtype='float32')\n        # Build trainer\n        trainer = Trainer(model)\\\n            .bind_loader('train', train_dataloader)\\\n            .bind_loader('validate', validate_dataloader)\\\n            .save_to_directory(to_directory=join(self.WORKING_DIRECTORY, 'Weights'))\\\n            .build_criterion('CrossEntropyLoss').build_optimizer('RMSprop')\n        self.trainer = trainer\n\n    def test_dump_hdf5_every(self):\n        # Configure callback\n        dumper = DumpHDF5Every((1, 'epoch'),\n                               
to_directory=join(self.WORKING_DIRECTORY, 'Weights'),\n                               dump_after_every_validation_run=True)\n        self.trainer\\\n            .set_max_num_epochs(4)\\\n            .register_callback(dumper)\\\n            .validate_every((16, 'iterations'))\n\n        self.trainer.fit()\n        all_files = listdir(join(self.WORKING_DIRECTORY, 'Weights'))\n        for epoch in range(5):\n            self.assertIn('dump.training.epoch{}.iteration{}.h5'.format(epoch, epoch * 32),\n                          all_files)\n            # We don't validate at last epoch\n            if epoch != 4:\n                self.assertIn('dump.validation.epoch{}.iteration{}.h5'\n                              .format(epoch, (epoch * 32) + 16),\n                              all_files)\n                self.assertIn('dump.validation.epoch{}.iteration{}.h5'\n                              .format(epoch, (epoch * 32) + 32),\n                              all_files)\n\n        # Check if the keys are right in a training dump\n        sample_file_path = join(self.WORKING_DIRECTORY, 'Weights',\n                                'dump.training.epoch0.iteration0.h5')\n        with h5.File(sample_file_path, 'r') as sample_file:\n            all_dataset_names = list(sample_file.keys())\n        self.assertSequenceEqual(all_dataset_names,\n                                 ['training_inputs_0', 'training_prediction', 'training_target'])\n        # Check if the keys are right in a validation dump\n        sample_file_path = join(self.WORKING_DIRECTORY, 'Weights',\n                                'dump.validation.epoch0.iteration16.h5')\n        with h5.File(sample_file_path, 'r') as sample_file:\n            all_dataset_names = list(sample_file.keys())\n        self.assertSequenceEqual(all_dataset_names,\n                                 ['validation_inputs_0', 'validation_prediction',\n                                  'validation_target'])\n\n    def tearDown(self):\n        
shutil.rmtree(join(self.WORKING_DIRECTORY, 'Weights'))\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/test_training/test_callbacks/test_logging/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n\n\"\"\"Unit test package for inferno.\"\"\"\n"
  },
  {
    "path": "tests/test_training/test_callbacks/test_logging/test_base.py",
    "content": "import unittest\nfrom inferno.trainers.callbacks.logging.base import Logger\nfrom inferno.trainers.basic import Trainer\nfrom os.path import join, dirname\n\n\nclass DummyLogger(Logger):\n    def end_of_training_iteration(self, **_):\n        pass\n\n\nclass TestLogger(unittest.TestCase):\n    ROOT = dirname(__file__)\n\n    def test_serialization(self):\n        trainer = Trainer()\\\n            .build_logger(logger=DummyLogger())\\\n            .save_to_directory(join(self.ROOT, 'saves'))\n        trainer.save()\n        # Unserialize\n        trainer = Trainer().load(from_directory=join(self.ROOT, 'saves'))\n        # Check if the loggers are consistent\n        logger_from_trainer = trainer._logger\n        logger_from_callback_engine = \\\n            next(iter(trainer.callbacks._callback_registry['end_of_training_iteration']))\n        self.assertIs(logger_from_trainer, logger_from_callback_engine)\n        self.assertIs(logger_from_callback_engine.trainer, trainer)\n\n\nif __name__ == '__main__':\n    unittest.main()"
  },
  {
    "path": "tests/test_training/test_callbacks/test_logging/test_tensorboard.py",
    "content": "import unittest\n\nimport os\nfrom shutil import rmtree\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom inferno.trainers.basic import Trainer\nfrom torch.utils.data.dataset import TensorDataset\nfrom torch.utils.data.dataloader import DataLoader\nfrom inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger\nfrom inferno.extensions.layers.reshape import AsMatrix\n\n\nclass TestTensorboard(unittest.TestCase):\n    ROOT_DIR = os.path.dirname(__file__)\n    PRECISION = 'float'\n    SAVE_DIRECTORY = os.path.join(ROOT_DIR, 'saves')\n    LOG_DIRECTORY = os.path.join(ROOT_DIR, 'logs')\n\n    @staticmethod\n    def _make_test_model(input_channels):\n        toy_net = nn.Sequential(nn.Conv2d(input_channels, 8, 3, 1, 1),\n                                nn.ELU(),\n                                nn.MaxPool2d(2),\n                                nn.Conv2d(8, 8, 3, 1, 1),\n                                nn.ELU(),\n                                nn.MaxPool2d(2),\n                                nn.Conv2d(8, 16, 3, 1, 1),\n                                nn.ELU(),\n                                nn.AdaptiveMaxPool2d((1, 1)),\n                                AsMatrix(),\n                                nn.Linear(16, 10))\n        return toy_net\n\n    def tearDown(self):\n        for d in [self.SAVE_DIRECTORY, self.LOG_DIRECTORY]:\n            try:\n                rmtree(d)\n            except OSError:\n                pass\n\n    def get_random_dataloaders(self, input_channels=3):\n        # Convert build random tensor dataset\n        data_shape = (1, input_channels, 64, 64)\n        target_shape = (1)\n        random_array = torch.from_numpy(np.random.rand(*data_shape)).float()\n        target_array = torch.from_numpy(np.random.randint(0, 9, size=target_shape))\n        train_dataset = TensorDataset(random_array, target_array)\n        test_dataset = TensorDataset(random_array, target_array)\n\n        # Build dataloaders from 
dataset\n        train_loader = DataLoader(train_dataset, batch_size=1,\n                                  shuffle=True, num_workers=0, pin_memory=False)\n        test_loader = DataLoader(test_dataset, batch_size=1,\n                                 shuffle=True, num_workers=0, pin_memory=False)\n        return train_loader, test_loader\n\n    def get_trainer(self, input_channels):\n        # Build model\n        net = self._make_test_model(input_channels)\n        # Build trainer\n        trainer = Trainer(net)\\\n            .build_logger(TensorboardLogger(send_image_at_batch_indices=0,\n                                            send_image_at_channel_indices='all',\n                                            log_images_every=(20, 'iterations')),\n                          log_directory=self.LOG_DIRECTORY)\\\n            .build_criterion('CrossEntropyLoss')\\\n            .build_metric('CategoricalError')\\\n            .build_optimizer('Adam')\\\n            .validate_every((1, 'epochs'))\\\n            .save_every((2, 'epochs'), to_directory=self.SAVE_DIRECTORY)\\\n            .save_at_best_validation_score()\\\n            .set_max_num_epochs(2)\\\n            .set_precision(self.PRECISION)\n        # Bind loaders\n        train_loader, test_loader = self.get_random_dataloaders(input_channels=input_channels)\n        trainer.bind_loader('train', train_loader).bind_loader('validate', test_loader)\n        return trainer\n\n    def test_tensorboard(self):\n        trainer = self.get_trainer(3)\n        trainer.fit()\n\n    def test_tensorboard_grayscale(self):\n        trainer = self.get_trainer(1)\n        trainer.fit()\n\n    def test_serialization(self):\n        trainer = self.get_trainer(3)\n        # Serialize\n        trainer.save()\n        # Unserialize\n        trainer = Trainer().load(os.path.join(self.ROOT_DIR, 'saves'))\n        train_loader, test_loader = self.get_random_dataloaders(input_channels=3)\n        trainer.bind_loader('train', 
train_loader).bind_loader('validate', test_loader)\n        trainer.fit()\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/test_training/test_callbacks/test_scheduling.py",
    "content": "import unittest\nfrom inferno.trainers.callbacks.scheduling import ManualLR\nfrom torch import nn\nfrom torch.optim import Adam\n\n\nclass TestSchedulers(unittest.TestCase):\n\n    def test_manual_lr(self):\n        class DummyTrainer(object):\n            def __init__(self):\n                self.iteration_count = 0\n                self.epoch_count = 0\n                self.optimizer = Adam(nn.Linear(10, 10).parameters(), lr=1.)\n\n        manual_lr = ManualLR([((100, 'iterations'), 0.5),\n                              ((200, 'iterations'), 0.5),\n                              ((200, 'iterations'), 0.1)])\n        trainer = DummyTrainer()\n        manual_lr._trainer = trainer\n\n        manual_lr.end_of_training_iteration()\n        self.assertEqual(trainer.optimizer.param_groups[0]['lr'], 1.)\n        trainer.iteration_count = 100\n        manual_lr.end_of_training_iteration()\n        self.assertEqual(trainer.optimizer.param_groups[0]['lr'], 0.5)\n        trainer.iteration_count = 200\n        manual_lr.end_of_training_iteration()\n        self.assertEqual(trainer.optimizer.param_groups[0]['lr'], 0.025)\n        trainer.iteration_count = 300\n        self.assertEqual(trainer.optimizer.param_groups[0]['lr'], 0.025)\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/test_utils/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n\n\"\"\"Unit test package for inferno.\"\"\"\n"
  },
  {
    "path": "tests/test_utils/test_model_utils.py",
    "content": "import unittest\nimport inferno.utils.model_utils as mu\nfrom inferno.utils.exceptions import ShapeError\nimport torch\nimport torch.nn as nn\n\n\nclass ModelUtilTester(unittest.TestCase):\n    def test_model_tester(self):\n        model = mu.ModelTester((1, 10, 32, 32), (1, 20, 32, 32))(nn.Conv2d(10, 20, 3, padding=1))\n        with self.assertRaises(ShapeError):\n            mu.ModelTester((1, 10, 32, 32), (1, 30, 32, 32))(model)\n\n    @unittest.skipUnless(torch.cuda.is_available(), \"need cuda\")\n    def test_model_tester_cuda(self):\n        tester = mu.ModelTester((1, 10, 32, 32), (1, 20, 32, 32)).cuda()\n        model = tester(nn.Conv2d(10, 20, 3, padding=1).cuda())\n        with self.assertRaises(ShapeError):\n            mu.ModelTester((1, 10, 32, 32), (1, 30, 32, 32)).cuda()(model)\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/test_utils/test_partial_cls.py",
    "content": "import unittest\nimport inferno.utils.model_utils as mu\nfrom inferno.utils.partial_cls import register_partial_cls\nimport torch\nimport torch.nn as nn\n\n\nclass TestCls(object):\n    def __init__(self, a, b, c=1, d=2):\n        self.a = a\n        self.b = b\n        self.c = c\n        self.d = d\n\nclass PartialClsTester(unittest.TestCase):\n\n    def test_partial_cls(self):\n        register_partial_cls(TestCls, 'TestA', \n            fix=dict(a='a'),\n            default=dict(b='b'),\n            module=__name__\n        )\n        assert 'TestA' in globals()\n\n        inst = TestA()\n        assert inst.a == 'a'\n        assert inst.b == 'b'\n        assert inst.c == 1\n        assert inst.d == 2\n\n        inst = TestA('fu','bar','fubar')\n        assert inst.a == 'a'\n        assert inst.b == 'fu'\n        assert inst.c == 'bar'\n        assert inst.d == 'fubar'\n\n        with self.assertRaises(TypeError):\n            inst = TestA(a=2)\n\n    def test_update_existing_default_cls(self):\n        register_partial_cls(TestCls, 'TestA', \n            fix=dict(a='a'),\n            default=dict(d=3),\n            module=__name__\n        )\n        assert 'TestA' in globals()\n\n        inst = TestA(42)\n        assert inst.a == 'a'\n        assert inst.b == 42\n        assert inst.c == 1\n        assert inst.d == 3\n\n        with self.assertRaises(TypeError):\n            inst = TestA()\n\n    def test_fix_nothing(self):\n        register_partial_cls(TestCls, 'TestA',\n            module=__name__\n        )\n        assert 'TestA' in globals()\n\n        inst = TestA(1,2,3,4)\n        assert inst.a == 1\n        assert inst.b == 2\n        assert inst.c == 3\n        assert inst.d == 4\n\n        with self.assertRaises(TypeError):\n            inst = TestA()\n\n    def test_fix_all(self):\n        register_partial_cls(TestCls, 'TestA',\n            module=__name__,\n            fix=dict(a=4, b=3, c=2, d=1)\n        )\n        assert 'TestA' 
in globals()\n\n        inst = TestA()\n        assert inst.a == 4\n        assert inst.b == 3\n        assert inst.c == 2\n        assert inst.d == 1\n\n        with self.assertRaises(TypeError):\n            inst = TestA('a')\n\n        with self.assertRaises(TypeError):\n            inst = TestA(a=1)\n        with self.assertRaises(TypeError):\n            inst = TestA(b=1)\n        with self.assertRaises(TypeError):\n            inst = TestA(c=1)\n        with self.assertRaises(TypeError):\n            inst = TestA(d=1)\n\n\n    def test_default_all(self):\n        register_partial_cls(TestCls, 'TestA',\n            module=__name__,\n            default=dict(a=4, b=3, c=2, d=1)\n        )\n        assert 'TestA' in globals()\n\n        inst = TestA()\n        assert inst.a == 4\n        assert inst.b == 3\n        assert inst.c == 2\n        assert inst.d == 1\n\n\n        inst = TestA(2)\n        assert inst.a == 2\n        assert inst.b == 3\n        assert inst.c == 2\n        assert inst.d == 1\n\n        inst = TestA(2,3,4,5)\n        assert inst.a == 2\n        assert inst.b == 3\n        assert inst.c == 4\n        assert inst.d == 5\n\n        with self.assertRaises(TypeError):\n            inst = TestA(3,4,5,a=2)\n            \n        inst = TestA(3,4,5,d=2)\n        assert inst.a == 3\n        assert inst.b == 4\n        assert inst.c == 5\n        assert inst.d == 2\n\n\n      \n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/test_utils/test_train_utils.py",
    "content": "import unittest\nimport inferno.utils.train_utils as tu\nimport numpy as np\n\n\nclass FrequencyTest(unittest.TestCase):\n    def test_from_string(self):\n        frequency = tu.Frequency.from_string('10 epochs')\n        self.assertFalse(frequency.match(epoch_count=9))\n        self.assertTrue(frequency.match(epoch_count=10))\n        frequency = tu.Frequency.from_string('1 iteration')\n        self.assertEqual(frequency.units, 'iterations')\n        self.assertTrue(frequency.match(iteration_count=10))\n        frequency = tu.Frequency.from_string('never')\n        self.assertFalse(frequency.match(epoch_count=9))\n        frequency = tu.Frequency.from_string('inf epochs')\n        self.assertFalse(frequency.match(epoch_count=9))\n\n    def test_from_tuple(self):\n        frequency = tu.Frequency.build_from((np.inf, 'epoch'))\n        self.assertFalse(frequency.match(epoch_count=9))\n        self.assertFalse(frequency.match(epoch_count=10))\n\n    def test_is_consistent(self):\n        frequency = tu.Frequency.build_from('10 epochs')\n        frequency._units = 'banana'\n        self.assertFalse(frequency.is_consistent)\n\n    def test_init(self):\n        frequency = tu.Frequency()\n        self.assertEqual(frequency.value, np.inf)\n        self.assertEqual(frequency.units, frequency.UNIT_PRIORITY)\n\n    def test_duration(self):\n        duration = tu.Duration.build_from((3, 'iterations'))\n        self.assertFalse(duration.match(iteration_count=2))\n        self.assertFalse(duration.match(iteration_count=3))\n        self.assertTrue(duration.match(iteration_count=3, when_equal_return=True))\n        self.assertTrue(duration.match(iteration_count=4))\n        self.assertEqual(duration.compare(iteration_count=1, epoch_count=3).get('iterations'),\n                         2)\n        with self.assertRaises(ValueError):\n            duration.match(epoch_count=2)\n\n\nif __name__ == '__main__':\n    unittest.main()"
  }
]