Repository: inferno-pytorch/inferno
Branch: master
Commit: 789c6d00b34c
Files: 179
Total size: 610.4 KB
Directory structure:
gitextract_cy4q7gy0/
├── .editorconfig
├── .github/
│ └── ISSUE_TEMPLATE.md
├── .gitignore
├── .travis.yml
├── AUTHORS.rst
├── CONTRIBUTING.rst
├── HISTORY.rst
├── LICENSE
├── MANIFEST.in
├── Makefile
├── README.rst
├── add2path.sh
├── build_docs.sh
├── conda-recipe/
│ ├── build.sh
│ └── meta.yaml
├── docs/
│ ├── .gitignore
│ ├── Makefile
│ ├── _templates/
│ │ ├── layout.html
│ │ └── template_module.rst
│ ├── authors.rst
│ ├── conf.py
│ ├── contributing.rst
│ ├── environment.yml
│ ├── examples.rst
│ ├── history.rst
│ ├── index.rst
│ ├── inferno-apidoc/
│ │ ├── inferno.extensions.containers.rst
│ │ ├── inferno.extensions.criteria.rst
│ │ ├── inferno.extensions.initializers.rst
│ │ ├── inferno.extensions.layers.rst
│ │ ├── inferno.extensions.metrics.rst
│ │ ├── inferno.extensions.optimizers.rst
│ │ ├── inferno.extensions.rst
│ │ ├── inferno.io.box.rst
│ │ ├── inferno.io.core.rst
│ │ ├── inferno.io.rst
│ │ ├── inferno.io.transform.rst
│ │ ├── inferno.io.volumetric.rst
│ │ ├── inferno.rst
│ │ ├── inferno.trainers.callbacks.logging.rst
│ │ ├── inferno.trainers.callbacks.rst
│ │ ├── inferno.trainers.rst
│ │ ├── inferno.utils.rst
│ │ └── modules.rst
│ ├── installation.rst
│ ├── make.bat
│ ├── readme.rst
│ ├── refs.bib
│ ├── usage.rst
│ └── zbibliography.rst
├── examples/
│ ├── README.txt
│ ├── plot_cheap_unet.py
│ ├── plot_train_side_loss_unet.py
│ ├── plot_unet_tutorial.py
│ ├── regularized_mnist.py
│ └── trainer.py
├── inferno/
│ ├── __init__.py
│ ├── extensions/
│ │ ├── __init__.py
│ │ ├── containers/
│ │ │ ├── __init__.py
│ │ │ ├── graph.py
│ │ │ └── sequential.py
│ │ ├── criteria/
│ │ │ ├── __init__.py
│ │ │ ├── core.py
│ │ │ ├── elementwise_measures.py
│ │ │ ├── regularized.py
│ │ │ └── set_similarity_measures.py
│ │ ├── initializers/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ └── presets.py
│ │ ├── layers/
│ │ │ ├── __init__.py
│ │ │ ├── activations.py
│ │ │ ├── convolutional.py
│ │ │ ├── convolutional_blocks.py
│ │ │ ├── device.py
│ │ │ ├── identity.py
│ │ │ ├── normalization.py
│ │ │ ├── reshape.py
│ │ │ └── sampling.py
│ │ ├── metrics/
│ │ │ ├── __init__.py
│ │ │ ├── arand.py
│ │ │ ├── base.py
│ │ │ ├── categorical.py
│ │ │ ├── cremi_score.py
│ │ │ └── voi.py
│ │ ├── models/
│ │ │ ├── __init__.py
│ │ │ ├── res_unet.py
│ │ │ └── unet.py
│ │ └── optimizers/
│ │ ├── __init__.py
│ │ ├── adam.py
│ │ ├── annealed_adam.py
│ │ └── ranger.py
│ ├── inferno.py
│ ├── io/
│ │ ├── __init__.py
│ │ ├── box/
│ │ │ ├── __init__.py
│ │ │ ├── binary_blobs.py
│ │ │ ├── camvid.py
│ │ │ ├── cifar.py
│ │ │ └── cityscapes.py
│ │ ├── core/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── concatenate.py
│ │ │ ├── data_utils.py
│ │ │ └── zip.py
│ │ ├── transform/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── generic.py
│ │ │ ├── image.py
│ │ │ └── volume.py
│ │ └── volumetric/
│ │ ├── __init__.py
│ │ ├── lazy_volume_loader.py
│ │ ├── volume.py
│ │ └── volumetric_utils.py
│ ├── trainers/
│ │ ├── __init__.py
│ │ ├── basic.py
│ │ └── callbacks/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── console.py
│ │ ├── essentials.py
│ │ ├── gradients.py
│ │ ├── logging/
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ └── tensorboard.py
│ │ ├── scheduling.py
│ │ ├── tqdm.py
│ │ └── tqdmstub.py
│ ├── utils/
│ │ ├── __init__.py
│ │ ├── exceptions.py
│ │ ├── io_utils.py
│ │ ├── math_utils.py
│ │ ├── model_utils.py
│ │ ├── partial_cls.py
│ │ ├── python_utils.py
│ │ ├── test_utils.py
│ │ ├── torch_utils.py
│ │ └── train_utils.py
│ └── version.py
├── readthedocs.yml
├── requirements.txt
├── requirements_dev.txt
├── setup.py
└── tests/
├── __init__.py
├── test_extensions/
│ ├── __init__.py
│ ├── test_containers/
│ │ └── test_graph.py
│ ├── test_criteria/
│ │ ├── test_core.py
│ │ ├── test_elementwise_measures.py
│ │ └── test_set_similarity_measures.py
│ ├── test_layers/
│ │ ├── deprecated/
│ │ │ └── building_blocks.py
│ │ ├── test_activations.py
│ │ ├── test_convolutional.py
│ │ ├── test_device.py
│ │ └── test_reshape.py
│ ├── test_metrics/
│ │ └── categorical.py
│ └── test_models/
│ ├── __init__.py
│ ├── test_res_unet.py
│ └── test_unet.py
├── test_inferno.py
├── test_io/
│ ├── __init__.py
│ ├── test_box/
│ │ ├── __init__.py
│ │ ├── test_camvid.py
│ │ └── test_cityscapes.py
│ ├── test_core/
│ │ ├── __init__.py
│ │ ├── test_concatenate.py
│ │ └── test_zip.py
│ └── test_volumetric/
│ ├── __init__.py
│ ├── test_lazy_volume_loader.py
│ └── test_volume_loader.py
├── test_training/
│ ├── __init__.py
│ ├── test_basic.py
│ └── test_callbacks/
│ ├── __init__.py
│ ├── test_base.py
│ ├── test_essentials.py
│ ├── test_logging/
│ │ ├── __init__.py
│ │ ├── test_base.py
│ │ └── test_tensorboard.py
│ └── test_scheduling.py
└── test_utils/
├── __init__.py
├── test_model_utils.py
├── test_partial_cls.py
└── test_train_utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .editorconfig
================================================
# http://editorconfig.org
root = true
[*]
indent_style = space
indent_size = 4
trim_trailing_whitespace = true
insert_final_newline = true
charset = utf-8
end_of_line = lf
[*.bat]
indent_style = tab
end_of_line = crlf
[LICENSE]
insert_final_newline = false
[Makefile]
indent_style = tab
================================================
FILE: .github/ISSUE_TEMPLATE.md
================================================
* inferno version:
* Python version:
* Operating System:
### Description
Describe what you were trying to get done.
Tell us what happened, what went wrong, and what you expected to happen.
### What I Did
```
Paste the command(s) you ran and the output.
If there was a crash, please include the traceback here.
```
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/
# Translations
*.mo
*.pot
# Django stuff:
*.log
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# pyenv python configuration file
.python-version
================================================
FILE: .travis.yml
================================================
language: python
dist: xenial
python:
- 3.7
env:
- PYTORCH_CONDA="pytorch" TORCHVISION_CONDA="torchvision" TORCHVISION_CHANNEL=pytorch
install:
- wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh;
- bash miniconda.sh -b -p $HOME/miniconda
- export PATH="$HOME/miniconda/bin:$PATH"
- conda config --set always_yes yes --set changeps1 no
- conda update -q conda
- conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION
- source activate test-environment
- conda install -c conda-forge networkx h5py scikit-image pyyaml dill tensorboardx
- conda install -c pytorch $PYTORCH_CONDA
- conda install -c $TORCHVISION_CHANNEL $TORCHVISION_CONDA
deploy:
provider: pypi
user: nasimrahaman
password:
secure: !!binary |
bWwzZitLcEpibHBaUWNhUVA4UUlGa2JsZDQxVkx3eFlkY1FiYWJqYkFvWm5pdDErRzlKRXZFM0hR
ZE15V0tIWm5JQlJRSGlveXdYNjAzQVc1UFV3ZjNBOG0zc21vK3RaZjVSYnM5aE5ySE93ajBXc1N4
akNHNGhOSnF6UnBDY2kwakxPeWhxaEwxQkR0empSaFdJbWVlOE81RDVPY2pSdGw1TDQ3QjhwVGor
TVREdlpSYTVFd2xNNXdadTJYWFVXL3ZQY0VLZE9xckFoVk5PSHpkTTh5MGM1S1lHaS9nNThVK2JO
OVp5RkFROVpuOEY3YmxPdzBQZnAvL202ZUkxamlKSmxhaE13UU4zV2tJRWRpNklVSTE0RUp1ck5s
Q28xL2kzNER0dGVkZzI0eVhULzcxRFl5Y0pZQWMrcWtoa1VVVUo4NEZKV3JjUjNqTnF5bVI3Ykty
cFJrR3JydjV0dUpGUnBhc2NIdEdKVUswMkdJWEJUc3JJWGg4bS9oRGtMaVJaMExBeitJQWR4b2tF
MzB0OWppZ0x5VXFSMmxnVmNvZERzRWZMRnJEMTBHeTJVS2FueVhlYmpsck9qK3V5S1dtZm5UTXg4
bGNzN09HWEZiUmo2K0ZuYTg5a00xN3poSXhzc3pSMnRGSVJwamV4a0gzZUpyZlpYY1daTFZ3QnV0
clUwZW10VEsxeGFmOGFjNTd3Wll1R3JXNEZJT1h2bmxoeS9pV0FMVlE4YnVFZFFjQnJ5YWFiRjUy
RkZvZk1SUnp3aDFhZ3Q3cUxVa0FIbXVuZ1NYQWZxMUlOTkVNYXRTcFVJUURJM3huWmNPeTNhSWFP
YkVpSlFHY1lrWlhXZ1Z2cVdvcktPOW53a29Hem5BSm1HRVZHYU11dDYwaGg2SGU1MVJPTll3WHc9
on:
all_branches: false
tags: true
script:
- source activate test-environment
- python setup.py install
- python -m unittest discover -s tests -v
================================================
FILE: AUTHORS.rst
================================================
=======
Credits
=======
Development Lead
----------------
* `Nasim Rahaman <https://github.com/nasimrahaman>`_ @ `Image Analysis and Learning Lab <https://hci.iwr.uni-heidelberg.de/mip>`_ , `Heidelberg Collaboratory for Image Processing <https://hci.iwr.uni-heidelberg.de/>`_ ,
Contributors
------------
In no particular order,
* `Steffen Wolf <https://github.com/Steffen-Wolf>`_ @
`Image Analysis and Learning Lab <https://hci.iwr.uni-heidelberg.de/mip>`_ ,
`Heidelberg Collaboratory for Image Processing <https://hci.iwr.uni-heidelberg.de/>`_ ,
* `Maurice Weiler <https://github.com/mauriceweiler>`_ @
`Amsterdam Machine Learning Lab <http://amlab.science.uva.nl/>`_ ,
`University of Amsterdam <http://www.uva.nl/en/home>`_ ,
* `Constantin Pape <https://github.com/constantinpape>`_ @
`Image Analysis and Learning Lab <https://hci.iwr.uni-heidelberg.de/mip>`_ ,
`Heidelberg Collaboratory for Image Processing <https://hci.iwr.uni-heidelberg.de/>`_ ,
* `Sven Peter <https://github.com/svenpeter42>`_ @
`Image Analysis and Learning Lab <https://hci.iwr.uni-heidelberg.de/mip>`_ ,
`Heidelberg Collaboratory for Image Processing <https://hci.iwr.uni-heidelberg.de/>`_ ,
* `Manuel Haussmann <https://github.com/manuelhaussmann>`_ @
`Image Analysis and Learning Lab <https://hci.iwr.uni-heidelberg.de/mip>`_ ,
`Heidelberg Collaboratory for Image Processing <https://hci.iwr.uni-heidelberg.de/>`_ ,
* `Thorsten Beier <https://github.com/DerThorsten>`_ @
`Image Analysis and Learning Lab <https://hci.iwr.uni-heidelberg.de/mip>`_ ,
`Heidelberg Collaboratory for Image Processing <https://hci.iwr.uni-heidelberg.de/>`_ ,
* `Benjamin Striner <https://github.com/bstriner>`_ @
`Machine Learning Department <https://www.ml.cmu.edu/>`_ ,
`Carnegie Mellon University <https://www.cmu.edu/>`_ ,
================================================
FILE: CONTRIBUTING.rst
================================================
.. highlight:: shell
============
Contributing
============
Contributions are welcome, and they are greatly appreciated! Every
little bit helps, and credit will always be given.
You can contribute in many ways:
Types of Contributions
----------------------
Report Bugs
~~~~~~~~~~~
Report bugs at https://github.com/nasimrahaman/inferno/issues.
If you are reporting a bug, please include:
* Your operating system name and version.
* Any details about your local setup that might be helpful in troubleshooting.
* Detailed steps to reproduce the bug.
Fix Bugs
~~~~~~~~
Look through the GitHub issues for bugs. Anything tagged with "bug"
and "help wanted" is open to whoever wants to implement it.
Implement Features
~~~~~~~~~~~~~~~~~~
Look through the GitHub issues for features. Anything tagged with "enhancement"
and "help wanted" is open to whoever wants to implement it.
Write Documentation
~~~~~~~~~~~~~~~~~~~
inferno could always use more documentation, whether as part of the
official inferno docs, in docstrings, or even on the web in blog posts,
articles, and such.
Submit Feedback
~~~~~~~~~~~~~~~
The best way to send feedback is to file an issue at https://github.com/nasimrahaman/inferno/issues.
If you are proposing a feature:
* Explain in detail how it would work.
* Keep the scope as narrow as possible, to make it easier to implement.
* Remember that this is a volunteer-driven project, and that contributions
are welcome :)
Get Started!
------------
Ready to contribute? Here's how to set up `inferno` for local development.
1. Fork the `inferno` repo on GitHub.
2. Clone your fork locally::
$ git clone git@github.com:your_name_here/inferno.git
3. Install your local copy into a virtualenv. Assuming you have virtualenvwrapper installed, this is how you set up your fork for local development::
$ mkvirtualenv inferno
$ cd inferno/
$ python setup.py develop
4. Create a branch for local development::
$ git checkout -b name-of-your-bugfix-or-feature
Now you can make your changes locally.
5. When you're done making changes, check that your changes pass flake8 and the tests, including testing other Python versions with tox::
$ flake8 inferno tests
$ python setup.py test or py.test
$ tox
To get flake8 and tox, just pip install them into your virtualenv.
6. Commit your changes and push your branch to GitHub::
$ git add .
$ git commit -m "Your detailed description of your changes."
$ git push origin name-of-your-bugfix-or-feature
7. Submit a pull request through the GitHub website.
Pull Request Guidelines
-----------------------
Before you submit a pull request, check that it meets these guidelines:
1. The pull request should include tests.
2. If the pull request adds functionality, the docs should be updated. Put
your new functionality into a function with a docstring, and add the
feature to the list in README.rst.
3. The pull request should work for Python 3.5 and 3.6. Check
https://travis-ci.org/nasimrahaman/inferno/pull_requests
and make sure that the tests pass for all supported Python versions.
Tips
----
To run a subset of tests::
$ python -m unittest tests.test_inferno
Sphinx Apidoc
-------------
Before building the documentation, one needs to generate the
auto-generated Sphinx API documentation.
These files need to be committed to the GitHub repository.
.. code:: bash
cd docs
sphinx-apidoc -o inferno-apidoc ../inferno
.. warning::
Do not make changes to `inferno/docs/inferno-apidoc`. This folder is auto-generated
by the above-mentioned command.
The following combines all the commands necessary to build the HTML documentation:
.. code:: bash
./build_docs.sh
================================================
FILE: HISTORY.rst
================================================
=======
History
=======
0.1.0 (2017-08-24)
------------------
* First early release on PyPI
0.1.1 (2017-08-24)
------------------
* Version Increment
0.1.2 (2017-08-24)
------------------
* Version Increment
0.1.3 (2017-08-24)
------------------
* Updated Documentation
0.1.4 (2017-08-24)
------------------
* travis auto-deployment on pypi
0.1.5 (2017-08-24)
------------------
* travis changes to run unittest
0.1.6 (2017-08-24)
------------------
* travis missing packages for unittesting
* fixed inconsistent version numbers
0.1.7 (2017-08-25)
------------------
* setup.py: critical bugfix in the install procedure
CURRENT CHANGES
-----------------
* Flexible Unet
================================================
FILE: LICENSE
================================================
Apache Software License 2.0
Copyright (c) 2017, Inferno Developers
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: MANIFEST.in
================================================
include AUTHORS.rst
include CONTRIBUTING.rst
include HISTORY.rst
include LICENSE
include README.rst
recursive-include tests *
recursive-exclude * __pycache__
recursive-exclude * *.py[co]
recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif
================================================
FILE: Makefile
================================================
.PHONY: clean clean-test clean-pyc clean-build docs help
.DEFAULT_GOAL := help
define BROWSER_PYSCRIPT
import os, webbrowser, sys
try:
from urllib import pathname2url
except:
from urllib.request import pathname2url
webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1])))
endef
export BROWSER_PYSCRIPT
define PRINT_HELP_PYSCRIPT
import re, sys
for line in sys.stdin:
match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line)
if match:
target, help = match.groups()
print("%-20s %s" % (target, help))
endef
export PRINT_HELP_PYSCRIPT
BROWSER := python -c "$$BROWSER_PYSCRIPT"
help:
@python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST)
clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts
clean-build: ## remove build artifacts
rm -fr build/
rm -fr dist/
rm -fr .eggs/
find . -name '*.egg-info' -exec rm -fr {} +
find . -name '*.egg' -exec rm -f {} +
clean-pyc: ## remove Python file artifacts
find . -name '*.pyc' -exec rm -f {} +
find . -name '*.pyo' -exec rm -f {} +
find . -name '*~' -exec rm -f {} +
find . -name '__pycache__' -exec rm -fr {} +
clean-test: ## remove test and coverage artifacts
rm -fr .tox/
rm -f .coverage
rm -fr htmlcov/
lint: ## check style with flake8
flake8 inferno tests
test: ## run tests quickly with the default Python
python setup.py test
test-all: ## run tests on every Python version with tox
tox
coverage: ## check code coverage quickly with the default Python
coverage run --source inferno setup.py test
coverage report -m
coverage html
$(BROWSER) htmlcov/index.html
docs: ## generate Sphinx HTML documentation, including API docs
rm -f docs/inferno.rst
rm -f docs/modules.rst
sphinx-apidoc -o docs/ inferno
$(MAKE) -C docs clean
$(MAKE) -C docs html
$(BROWSER) docs/_build/html/index.html
servedocs: docs ## compile the docs watching for changes
watchmedo shell-command -p '*.rst' -c '$(MAKE) -C docs html' -R -D .
release: clean ## package and upload a release
python setup.py sdist upload
python setup.py bdist_wheel upload
dist: clean ## builds source and wheel package
python setup.py sdist
python setup.py bdist_wheel
ls -l dist
install: clean ## install the package to the active Python's site-packages
python setup.py install
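The `help` target in the Makefile above works because every documented target carries a `## description` comment on its rule line, which the embedded `PRINT_HELP_PYSCRIPT` extracts with a regex. A standalone sketch of that parsing logic (the function name `parse_help` is ours, for illustration; the pattern is the same one used above, with Make's `$$` unescaped to `$`):

```python
import re


def parse_help(makefile_text):
    """Extract (target, description) pairs from lines like 'docs: ## build docs'."""
    entries = []
    for line in makefile_text.splitlines():
        # Same pattern as PRINT_HELP_PYSCRIPT in the Makefile above.
        match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$', line)
        if match:
            entries.append(match.groups())
    return entries


example = ("clean: clean-build clean-pyc ## remove all build artifacts\n"
           "lint: ## check style with flake8\n"
           "undocumented: deps\n")
for target, description in parse_help(example):
    print("%-20s %s" % (target, description))
```

Only the first two lines of `example` are reported, since `undocumented` has no `## ` comment.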
================================================
FILE: README.rst
================================================
=======
Inferno
=======
.. image:: https://anaconda.org/conda-forge/inferno/badges/version.svg
:target: https://anaconda.org/conda-forge/inferno
.. image:: https://travis-ci.org/inferno-pytorch/inferno.svg?branch=master
:target: https://travis-ci.org/inferno-pytorch/inferno
..
TODO new docs shield goes here, see https://github.com/inferno-pytorch/inferno/issues/139
.. image:: https://readthedocs.org/projects/inferno-pytorch/badge/?version=latest
:target: http://inferno-pytorch.readthedocs.io/en/latest/?badge=latest
:alt: Documentation Status
.. image:: http://svgshare.com/i/2j7.svg
Inferno is a little library providing utilities and convenience functions/classes around
`PyTorch <https://github.com/pytorch/pytorch>`_.
It's a work-in-progress, but the releases from v0.4 on should be fairly stable!
* Free software: Apache Software License 2.0
* Documentation: http://inferno-pytorch.readthedocs.io (Work in Progress).
Features
--------
Current features include:
* a basic
`Trainer class <https://github.com/nasimrahaman/inferno/tree/master/docs#preparing-the-trainer>`_
to encapsulate the training boilerplate (iteration/epoch loops, validation and checkpoint creation),
* a `graph API <https://github.com/nasimrahaman/inferno/blob/master/inferno/extensions/containers/graph.py>`_ for building models with complex architectures, powered by `networkx <https://github.com/networkx/networkx>`_.
* `easy data-parallelism <https://github.com/nasimrahaman/inferno/tree/master/docs#using-gpus>`_ over multiple GPUs,
* `a submodule <https://github.com/nasimrahaman/inferno/blob/master/inferno/extensions/initializers>`_ for `torch.nn.Module`-level parameter initialization,
* `a submodule <https://github.com/nasimrahaman/inferno/blob/master/inferno/io/transform>`_ for data preprocessing / transforms,
* `support <https://github.com/nasimrahaman/inferno/tree/master/docs#using-tensorboard>`_ for `Tensorboard <https://www.tensorflow.org/get_started/summaries_and_tensorboard>`_ (best with at least `tensorflow-cpu <https://github.com/tensorflow/tensorflow>`_ installed),
* `a callback API <https://github.com/nasimrahaman/inferno/tree/master/docs#setting-up-callbacks>`_ to enable flexible interaction with the trainer,
* `various utility layers <https://github.com/nasimrahaman/inferno/tree/master/inferno/extensions/layers>`_ with more underway,
* `a submodule <https://github.com/nasimrahaman/inferno/blob/master/inferno/io/volumetric>`_ for volumetric datasets, and more!
.. code:: python
import torch.nn as nn
from inferno.io.box.cifar import get_cifar10_loaders
from inferno.trainers.basic import Trainer
from inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger
from inferno.extensions.layers.convolutional import ConvELU2D
from inferno.extensions.layers.reshape import Flatten
# Fill these in:
LOG_DIRECTORY = '...'
SAVE_DIRECTORY = '...'
DATASET_DIRECTORY = '...'
DOWNLOAD_CIFAR = True
USE_CUDA = True
# Build torch model
model = nn.Sequential(
ConvELU2D(in_channels=3, out_channels=256, kernel_size=3),
nn.MaxPool2d(kernel_size=2, stride=2),
ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),
nn.MaxPool2d(kernel_size=2, stride=2),
ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),
nn.MaxPool2d(kernel_size=2, stride=2),
Flatten(),
nn.Linear(in_features=(256 * 4 * 4), out_features=10),
nn.LogSoftmax(dim=1)
)
# Load loaders
train_loader, validate_loader = get_cifar10_loaders(DATASET_DIRECTORY,
download=DOWNLOAD_CIFAR)
# Build trainer
trainer = Trainer(model) \
.build_criterion('NLLLoss') \
.build_metric('CategoricalError') \
.build_optimizer('Adam') \
.validate_every((2, 'epochs')) \
.save_every((5, 'epochs')) \
.save_to_directory(SAVE_DIRECTORY) \
.set_max_num_epochs(10) \
.build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'),
log_images_every='never'),
log_directory=LOG_DIRECTORY)
# Bind loaders
trainer \
.bind_loader('train', train_loader) \
.bind_loader('validate', validate_loader)
if USE_CUDA:
trainer.cuda()
# Go!
trainer.fit()
To visualize the training progress, navigate to `LOG_DIRECTORY` and fire up tensorboard with
.. code:: bash
$ tensorboard --logdir=${PWD} --port=6007
and navigate to `localhost:6007` with your browser.
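The `(2, 'epochs')`- and `(1, 'iteration')`-style arguments passed to `validate_every`, `save_every`, and the logger above are frequency specifications. A hypothetical helper (not inferno's actual implementation) sketching how such a spec could be interpreted at each training step:

```python
def is_due(spec, iteration_count, epoch_count):
    """Decide whether an action with frequency `spec` should fire now.

    `spec` is either the string 'never' or an (N, unit) tuple with unit
    naming iterations or epochs, mirroring the tuples in the snippet
    above. This is an illustrative sketch, not inferno's own logic.
    """
    if spec == 'never':
        return False
    every, unit = spec
    # Pick the counter the unit refers to.
    count = epoch_count if unit.startswith('epoch') else iteration_count
    return count > 0 and count % every == 0


print(is_due((2, 'epochs'), iteration_count=500, epoch_count=4))   # fires: 4 % 2 == 0
print(is_due((5, 'epochs'), iteration_count=500, epoch_count=4))   # does not fire
print(is_due('never', iteration_count=500, epoch_count=4))         # never fires
```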
Installation
------------------------
Conda packages for Python >= 3.6 are available for all platforms on conda-forge:
.. code:: bash
$ conda install -c pytorch -c conda-forge inferno
Future Features
------------------------
Planned features include:
* a class to encapsulate Hogwild! training over multiple GPUs,
* minimal shape inference with a dry-run,
* proper packaging and documentation,
* cutting-edge fresh-off-the-press implementations of what the future has in store. :)
Credits
---------
All contributors are listed here_.
.. _here: https://inferno-pytorch.github.io/inferno/html/authors.html
This package was partially generated with Cookiecutter_ and the `audreyr/cookiecutter-pypackage`_ project template + lots of work by Thorsten.
.. _Cookiecutter: https://github.com/audreyr/cookiecutter
.. _`audreyr/cookiecutter-pypackage`: https://github.com/audreyr/cookiecutter-pypackage
================================================
FILE: add2path.sh
================================================
#!/usr/bin/env bash
# Run this script from within the directory.
export PYTHONPATH=${PYTHONPATH}:${PWD}
================================================
FILE: build_docs.sh
================================================
#!/bin/bash
cd docs
rm -r -f inferno-apidoc
sphinx-apidoc -o inferno-apidoc ../inferno
make html
cd ..
================================================
FILE: conda-recipe/build.sh
================================================
PY_VER=$(python -c "import sys; print('{}.{}'.format(*sys.version_info[:2]))")
# Install python modules
mkdir -p ${PREFIX}/inferno
cp -r inferno/* ${PREFIX}/inferno
echo "${PREFIX}" > ${PREFIX}/lib/python${PY_VER}/site-packages/inferno.pth
python -m compileall ${PREFIX}/inferno
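Rather than copying the package into `site-packages`, the build script above drops a `.pth` file there: at interpreter startup, Python appends each path listed in a `.pth` file to `sys.path`. A minimal stdlib demonstration of the mechanism (the directories here are temporary ones we create, standing in for the conda `${PREFIX}`):

```python
import os
import site
import sys
import tempfile

# A fake "prefix" to expose, and a separate directory playing site-packages.
prefix = tempfile.mkdtemp()
site_dir = tempfile.mkdtemp()

# Each line of a .pth file is a path to append to sys.path.
with open(os.path.join(site_dir, "demo.pth"), "w") as f:
    f.write(prefix + "\n")

# Processing the site dir (as Python does for site-packages at startup)
# puts the listed prefix on sys.path.
site.addsitedir(site_dir)
print(prefix in sys.path)  # → True
```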
================================================
FILE: conda-recipe/meta.yaml
================================================
package:
name: inferno
{% set tagged_version = GIT_DESCRIBE_TAG|replace("v","")|replace("-", ".") %}
# If we're using a non-tagged revision, append '.postN' to the version
{% if GIT_DESCRIBE_NUMBER|int != 0 %}
{% set tagged_version = tagged_version + '.post' + GIT_DESCRIBE_NUMBER %}
{% endif %}
version: {{tagged_version}}
source:
path: ..
build:
number: 1
string: py_{{PKG_BUILDNUM}}_g{{GIT_FULL_HASH[:7]}}
requirements:
build:
- python {{PY_VER}}*
run:
- python {{PY_VER}}*
- pytorch
- torchvision
- pyyaml
- scipy
- scikit-image
- scikit-learn
- h5py
- dill
- networkx 1.11
- tensorboardx
- sphinx_rtd_theme
test:
imports:
- inferno
about:
license: Apache License 2.0
summary: A utility library around PyTorch
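The Jinja logic at the top of this recipe turns a `git describe` tag into a conda version string, appending `.postN` when the build is N commits past the last tag. The same transformation as a plain Python sketch (the function name `conda_version` is ours, for illustration):

```python
def conda_version(git_describe_tag, git_describe_number):
    """Mirror the recipe's Jinja logic: strip 'v', dashes to dots, '.postN' suffix."""
    version = git_describe_tag.replace("v", "").replace("-", ".")
    # Non-tagged revisions get a '.postN' suffix, N = commits since the tag.
    if int(git_describe_number) != 0:
        version = version + ".post" + git_describe_number
    return version


print(conda_version("v0.1.7", "0"))   # → 0.1.7 (exact tagged release)
print(conda_version("v0.1.7", "12"))  # → 0.1.7.post12 (12 commits past the tag)
```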
================================================
FILE: docs/.gitignore
================================================
/inferno.rst
/inferno.*.rst
/modules.rst
================================================
FILE: docs/Makefile
================================================
# Makefile for Sphinx documentation
#
# You can set these variables from the command line.
SPHINXOPTS =
SPHINXBUILD = sphinx-build
PAPER =
BUILDDIR = _build
# User-friendly check for sphinx-build
ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1)
$(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/)
endif
# Internal variables.
PAPEROPT_a4 = -D latex_paper_size=a4
PAPEROPT_letter = -D latex_paper_size=letter
ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
# the i18n builder cannot share the environment and doctrees with the others
I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext
help:
@echo "Please use \`make <target>' where <target> is one of"
@echo " html to make standalone HTML files"
@echo " dirhtml to make HTML files named index.html in directories"
@echo " singlehtml to make a single large HTML file"
@echo " pickle to make pickle files"
@echo " json to make JSON files"
@echo " htmlhelp to make HTML files and a HTML help project"
@echo " qthelp to make HTML files and a qthelp project"
@echo " devhelp to make HTML files and a Devhelp project"
@echo " epub to make an epub"
@echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter"
@echo " latexpdf to make LaTeX files and run them through pdflatex"
@echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx"
@echo " text to make text files"
@echo " man to make manual pages"
@echo " texinfo to make Texinfo files"
@echo " info to make Texinfo files and run them through makeinfo"
@echo " gettext to make PO message catalogs"
@echo " changes to make an overview of all changed/added/deprecated items"
@echo " xml to make Docutils-native XML files"
@echo " pseudoxml to make pseudoxml-XML files for display purposes"
@echo " linkcheck to check all external links for integrity"
@echo " doctest to run all doctests embedded in the documentation (if enabled)"
clean:
rm -rf $(BUILDDIR)/*
html:
$(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html
@echo
@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
dirhtml:
$(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
@echo
@echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml."
singlehtml:
$(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml
@echo
@echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml."
pickle:
$(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle
@echo
@echo "Build finished; now you can process the pickle files."
json:
$(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json
@echo
@echo "Build finished; now you can process the JSON files."
htmlhelp:
$(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp
@echo
@echo "Build finished; now you can run HTML Help Workshop with the" \
".hhp project file in $(BUILDDIR)/htmlhelp."
qthelp:
$(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp
@echo
@echo "Build finished; now you can run "qcollectiongenerator" with the" \
".qhcp project file in $(BUILDDIR)/qthelp, like this:"
@echo "# qcollectiongenerator $(BUILDDIR)/qthelp/inferno.qhcp"
@echo "To view the help file:"
@echo "# assistant -collectionFile $(BUILDDIR)/qthelp/inferno.qhc"
devhelp:
$(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp
@echo
@echo "Build finished."
@echo "To view the help file:"
@echo "# mkdir -p $$HOME/.local/share/devhelp/inferno"
@echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/inferno"
@echo "# devhelp"
epub:
$(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub
@echo
@echo "Build finished. The epub file is in $(BUILDDIR)/epub."
latex:
$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
@echo
@echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex."
@echo "Run \`make' in that directory to run these through (pdf)latex" \
"(use \`make latexpdf' here to do that automatically)."
latexpdf:
$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
@echo "Running LaTeX files through pdflatex..."
$(MAKE) -C $(BUILDDIR)/latex all-pdf
@echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
latexpdfja:
$(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex
@echo "Running LaTeX files through platex and dvipdfmx..."
$(MAKE) -C $(BUILDDIR)/latex all-pdf-ja
@echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex."
text:
$(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text
@echo
@echo "Build finished. The text files are in $(BUILDDIR)/text."
man:
$(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man
@echo
@echo "Build finished. The manual pages are in $(BUILDDIR)/man."
texinfo:
$(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
@echo
@echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo."
@echo "Run \`make' in that directory to run these through makeinfo" \
"(use \`make info' here to do that automatically)."
info:
$(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
@echo "Running Texinfo files through makeinfo..."
make -C $(BUILDDIR)/texinfo info
@echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo."
gettext:
$(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale
@echo
@echo "Build finished. The message catalogs are in $(BUILDDIR)/locale."
changes:
$(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes
@echo
@echo "The overview file is in $(BUILDDIR)/changes."
linkcheck:
$(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck
@echo
@echo "Link check complete; look for any errors in the above output " \
"or in $(BUILDDIR)/linkcheck/output.txt."
doctest:
$(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest
@echo "Testing of doctests in the sources finished, look at the " \
"results in $(BUILDDIR)/doctest/output.txt."
xml:
$(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml
@echo
@echo "Build finished. The XML files are in $(BUILDDIR)/xml."
pseudoxml:
$(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml
@echo
@echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml."
================================================
FILE: docs/_templates/layout.html
================================================
{# layout.html #}
{# Import the theme's layout. #}
{% extends "!layout.html" %}
{% set css_files = css_files + ['_static/pygments.css'] %}
================================================
FILE: docs/_templates/template_module.rst
================================================
{{ fullname }}
{{ underline }}
.. automodule:: {{ fullname }}
{% block functions %}
{% if functions %}
Functions
---------
{% for item in functions %}
.. autofunction:: {{ item }}
.. include:: backreferences/{{fullname}}.{{item}}.examples
.. raw:: html
<div style='clear:both'></div>
{%- endfor %}
{% endif %}
{% endblock %}
{% block classes %}
{% if classes %}
Classes
-------
{% for item in classes %}
.. autoclass:: {{ item }}
:members:
.. include:: backreferences/{{fullname}}.{{item}}.examples
.. raw:: html
<div style='clear:both'></div>
{%- endfor %}
{% endif %}
{% endblock %}
{% block exceptions %}
{% if exceptions %}
Exceptions
----------
.. autosummary::
{% for item in exceptions %}
{{ item }}
{%- endfor %}
{% endif %}
{% endblock %}
================================================
FILE: docs/authors.rst
================================================
.. include:: ../AUTHORS.rst
================================================
FILE: docs/conf.py
================================================
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# inferno documentation build configuration file, created by
# sphinx-quickstart on Tue Jul 9 22:26:36 2013.
#
# This file is execfile()d with the current directory set to its
# containing dir.
#
# Note that not all possible configuration values are present in this
# autogenerated file.
#
# All configuration values have a default; values that are commented out
# serve to show the default.
import matplotlib
matplotlib.use('Agg')
import sphinx_gallery
import sys
from unittest.mock import MagicMock
class Mock(MagicMock):
    @classmethod
    def __getattr__(cls, name):
        return MagicMock()
# MOCK_MODULES = ['pygtk',
# 'hdf5',
# 'skimage',
# 'argparse',
# 'pandas',
# 'torch',
# 'torch.nn', 'torch.nn.init', 'torch.nn.functional',
# 'torch.nn.parallel', 'torch.nn.parallel.data_parallel',
# 'torch.multiprocessing', 'torch.autograd',
# 'torch.utils', 'torch.utils.data',
# 'torch.optim', 'torch.sparse', 'torch.cuda']
# sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES)
import os
# If extensions (or modules to document with autodoc) are in another
# directory, add these directories to sys.path here. If the directory is
# relative to the documentation root, use os.path.abspath to make it
# absolute, like shown here.
#sys.path.insert(0, os.path.abspath('.'))
# Get the project root dir, which is the parent dir of this
cwd = os.getcwd()
project_root = os.path.dirname(cwd)
# Insert the project root dir as the first element in the PYTHONPATH.
# This lets us ensure that the source package is imported, and that its
# version is used.
sys.path.insert(0, project_root)
import inferno
import inferno.extensions
import inferno.extensions.layers
from inferno.extensions.layers import *
from inferno.extensions.layers.reshape import *
# -- General configuration ---------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.viewcode',
'sphinx.ext.autosummary',
'sphinx.ext.doctest',
'sphinx.ext.todo',
'sphinx.ext.ifconfig',
'sphinx.ext.mathjax',
'sphinx.ext.graphviz',
'sphinx_gallery.gen_gallery',
'sphinxcontrib.bibtex',
'sphinx.ext.napoleon',
'sphinxcontrib.inlinesyntaxhighlight'
]
sphinx_gallery_conf = {
    # path to your examples scripts
    'examples_dirs': '../examples',
    # path where to save gallery generated examples
    'gallery_dirs': 'auto_examples',
    'backreferences_dir': 'gen_modules/backreferences',
    'scan_used_functions': True,
    'doc_module': ('inferno', 'inferno.extensions',
                   'inferno.extensions.layers',
                   'inferno.extensions.layers.convolutional'),
    'docs_resolv': True,
    'parallel_read_safe': True,
    'reference_url': {
        # The locally documented module uses None
        'inferno': None,
        # External python modules use their documentation websites
        # 'matplotlib': 'http://matplotlib.org',
        'numpy': 'http://docs.scipy.org/doc/numpy-1.13.0',
    },
}
# Napoleon settings
napoleon_google_docstring = True
napoleon_numpy_docstring = True
napoleon_include_init_with_doc = False
napoleon_include_private_with_doc = False
napoleon_include_special_with_doc = True
napoleon_use_admonition_for_examples = False
napoleon_use_admonition_for_notes = False
napoleon_use_admonition_for_references = False
napoleon_use_ivar = False
napoleon_use_param = True
napoleon_use_rtype = True
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# generate autosummary even if no references
autosummary_generate = True
# The suffix of source filenames.
source_suffix = '.rst'
# The encoding of source files.
#source_encoding = 'utf-8-sig'
# The master toctree document.
master_doc = 'index'
# General information about the project.
project = u'inferno'
copyright = u"2018, Inferno Team"
# The version info for the project you're documenting, acts as replacement
# for |version| and |release|, also used in various other places throughout
# the built documents.
#
# The short X.Y version.
version = inferno.__version__
# The full version, including alpha/beta/rc tags.
release = inferno.__version__
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#language = None
# There are two options for replacing |today|: either, you set today to
# some non-false value, then it is used:
#today = ''
# Else, today_fmt is used as the format for a strftime call.
#today_fmt = '%B %d, %Y'
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
exclude_patterns = ['_build']
# The reST default role (used for this markup: `text`) to use for all
# documents.
#default_role = None
# If true, '()' will be appended to :func: etc. cross-reference text.
#add_function_parentheses = True
# If true, the current module name will be prepended to all description
# unit titles (such as .. function::).
#add_module_names = True
# If true, sectionauthor and moduleauthor directives will be shown in the
# output. They are ignored by default.
#show_authors = False
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'
# A list of ignored prefixes for module index sorting.
#modindex_common_prefix = []
# If true, keep warnings as "system message" paragraphs in the built
# documents.
#keep_warnings = False
# -- Options for HTML output -------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
html_theme = 'sphinx_rtd_theme'
# Theme options are theme-specific and customize the look and feel of a
# theme further. For a list of options available for each theme, see the
# documentation.
#html_theme_options = {}
# Add any paths that contain custom themes here, relative to this directory.
#html_theme_path = []
# The name for this set of Sphinx documents. If None, it defaults to
# "<project> v<release> documentation".
#html_title = None
# A shorter title for the navigation bar. Default is the same as
# html_title.
#html_short_title = None
# The name of an image file (relative to this directory) to place at the
# top of the sidebar.
#html_logo = None
# The name of an image file (within the static path) to use as favicon
# of the docs. This file should be a Windows icon file (.ico) being
# 16x16 or 32x32 pixels large.
#html_favicon = None
# Add any paths that contain custom static files (such as style sheets)
# here, relative to this directory. They are copied after the builtin
# static files, so a file named "default.css" will overwrite the builtin
# "default.css".
html_static_path = ['_static']
# If not '', a 'Last updated on:' timestamp is inserted at every page
# bottom, using the given strftime format.
#html_last_updated_fmt = '%b %d, %Y'
# If true, SmartyPants will be used to convert quotes and dashes to
# typographically correct entities.
#html_use_smartypants = True
# Custom sidebar templates, maps document names to template names.
#html_sidebars = {}
# Additional templates that should be rendered to pages, maps page names
# to template names.
#html_additional_pages = {}
# If false, no module index is generated.
#html_domain_indices = True
# If false, no index is generated.
#html_use_index = True
# If true, the index is split into individual pages for each letter.
#html_split_index = False
# If true, links to the reST sources are added to the pages.
#html_show_sourcelink = True
# If true, "Created using Sphinx" is shown in the HTML footer.
# Default is True.
#html_show_sphinx = True
# If true, "(C) Copyright ..." is shown in the HTML footer.
# Default is True.
#html_show_copyright = True
# If true, an OpenSearch description file will be output, and all pages
# will contain a <link> tag referring to it. The value of this option
# must be the base URL from which the finished HTML is served.
#html_use_opensearch = ''
# This is the file name suffix for HTML files (e.g. ".xhtml").
#html_file_suffix = None
# Output file base name for HTML help builder.
htmlhelp_basename = 'infernodoc'
# -- Options for LaTeX output ------------------------------------------
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#'preamble': '',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title, author, documentclass
# [howto/manual]).
latex_documents = [
('index', 'inferno.tex',
u'inferno Documentation',
u'Inferno Team', 'manual'),
]
# The name of an image file (relative to this directory) to place at
# the top of the title page.
#latex_logo = None
# For "manual" documents, if this is true, then toplevel headings
# are parts, not chapters.
#latex_use_parts = False
# If true, show page references after internal links.
#latex_show_pagerefs = False
# If true, show URL addresses after external links.
#latex_show_urls = False
# Documents to append as an appendix to all manuals.
#latex_appendices = []
# If false, no module index is generated.
#latex_domain_indices = True
# -- Options for manual page output ------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
('index', 'inferno',
u'inferno Documentation',
[u'Inferno Team'], 1)
]
# If true, show URL addresses after external links.
#man_show_urls = False
# -- Options for Texinfo output ----------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
('index', 'inferno',
u'inferno Documentation',
u'Inferno Team',
'inferno',
'Utilities and convenience functions/classes around PyTorch.',
'Miscellaneous'),
]
# Documents to append as an appendix to all manuals.
#texinfo_appendices = []
# If false, no module index is generated.
#texinfo_domain_indices = True
# How to display URL addresses: 'footnote', 'no', or 'inline'.
#texinfo_show_urls = 'footnote'
# If true, do not generate a @detailmenu in the "Top" node's menu.
#texinfo_no_detailmenu = False
================================================
FILE: docs/contributing.rst
================================================
.. include:: ../CONTRIBUTING.rst
================================================
FILE: docs/environment.yml
================================================
name: inferno_docs
channels:
- soumith
- anaconda
dependencies:
- python==3.5
- pytorch>=0.1.12
- torchvision
- scikit-image
- pip:
- scipy>=0.13.0
- h5py
- scikit-image
- pyyaml
- dill
- sphinx-gallery
- sphinxcontrib-napoleon
- sphinxcontrib-bibtex
- sphinxcontrib-inlinesyntaxhighlight
================================================
FILE: docs/examples.rst
================================================
.. _inferno_examples_gallery:
Inferno Examples Gallery
============================
.. toctree::
:maxdepth: 5
../auto_examples/index
================================================
FILE: docs/history.rst
================================================
.. include:: ../HISTORY.rst
================================================
FILE: docs/index.rst
================================================
Welcome to inferno's documentation!
======================================
Contents:
.. toctree::
:maxdepth: 1
readme
installation
usage
examples
contributing
inferno-apidoc/modules
authors
history
zbibliography
.. automodule:: inferno
Indices and tables
==================
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
================================================
FILE: docs/inferno-apidoc/inferno.extensions.containers.rst
================================================
inferno.extensions.containers package
=====================================
Submodules
----------
inferno.extensions.containers.graph module
------------------------------------------
.. automodule:: inferno.extensions.containers.graph
:members:
:undoc-members:
:show-inheritance:
inferno.extensions.containers.sequential module
-----------------------------------------------
.. automodule:: inferno.extensions.containers.sequential
:members:
:undoc-members:
:show-inheritance:
Module contents
---------------
.. automodule:: inferno.extensions.containers
:members:
:undoc-members:
:show-inheritance:
================================================
FILE: docs/inferno-apidoc/inferno.extensions.criteria.rst
================================================
inferno.extensions.criteria package
===================================
Submodules
----------
inferno.extensions.criteria.core module
---------------------------------------
.. automodule:: inferno.extensions.criteria.core
:members:
:undoc-members:
:show-inheritance:
inferno.extensions.criteria.elementwise\_measures module
--------------------------------------------------------
.. automodule:: inferno.extensions.criteria.elementwise_measures
:members:
:undoc-members:
:show-inheritance:
inferno.extensions.criteria.regularized module
----------------------------------------------
.. automodule:: inferno.extensions.criteria.regularized
:members:
:undoc-members:
:show-inheritance:
inferno.extensions.criteria.set\_similarity\_measures module
------------------------------------------------------------
.. automodule:: inferno.extensions.criteria.set_similarity_measures
:members:
:undoc-members:
:show-inheritance:
Module contents
---------------
.. automodule:: inferno.extensions.criteria
:members:
:undoc-members:
:show-inheritance:
================================================
FILE: docs/inferno-apidoc/inferno.extensions.initializers.rst
================================================
inferno.extensions.initializers package
=======================================
Submodules
----------
inferno.extensions.initializers.base module
-------------------------------------------
.. automodule:: inferno.extensions.initializers.base
:members:
:undoc-members:
:show-inheritance:
inferno.extensions.initializers.presets module
----------------------------------------------
.. automodule:: inferno.extensions.initializers.presets
:members:
:undoc-members:
:show-inheritance:
Module contents
---------------
.. automodule:: inferno.extensions.initializers
:members:
:undoc-members:
:show-inheritance:
================================================
FILE: docs/inferno-apidoc/inferno.extensions.layers.rst
================================================
inferno.extensions.layers package
=================================
Submodules
----------
inferno.extensions.layers.activations module
--------------------------------------------
.. automodule:: inferno.extensions.layers.activations
:members:
:undoc-members:
:show-inheritance:
inferno.extensions.layers.building\_blocks module
-------------------------------------------------
.. automodule:: inferno.extensions.layers.building_blocks
:members:
:undoc-members:
:show-inheritance:
inferno.extensions.layers.convolutional module
----------------------------------------------
.. automodule:: inferno.extensions.layers.convolutional
:members:
:undoc-members:
:show-inheritance:
inferno.extensions.layers.device module
---------------------------------------
.. automodule:: inferno.extensions.layers.device
:members:
:undoc-members:
:show-inheritance:
inferno.extensions.layers.identity module
-----------------------------------------
.. automodule:: inferno.extensions.layers.identity
:members:
:undoc-members:
:show-inheritance:
inferno.extensions.layers.prefab module
---------------------------------------
.. automodule:: inferno.extensions.layers.prefab
:members:
:undoc-members:
:show-inheritance:
inferno.extensions.layers.res\_unet module
------------------------------------------
.. automodule:: inferno.extensions.layers.res_unet
:members:
:undoc-members:
:show-inheritance:
inferno.extensions.layers.reshape module
----------------------------------------
.. automodule:: inferno.extensions.layers.reshape
:members:
:undoc-members:
:show-inheritance:
inferno.extensions.layers.sampling module
-----------------------------------------
.. automodule:: inferno.extensions.layers.sampling
:members:
:undoc-members:
:show-inheritance:
inferno.extensions.layers.unet\_base module
-------------------------------------------
.. automodule:: inferno.extensions.layers.unet_base
:members:
:undoc-members:
:show-inheritance:
Module contents
---------------
.. automodule:: inferno.extensions.layers
:members:
:undoc-members:
:show-inheritance:
================================================
FILE: docs/inferno-apidoc/inferno.extensions.metrics.rst
================================================
inferno.extensions.metrics package
==================================
Submodules
----------
inferno.extensions.metrics.arand module
---------------------------------------
.. automodule:: inferno.extensions.metrics.arand
:members:
:undoc-members:
:show-inheritance:
inferno.extensions.metrics.base module
--------------------------------------
.. automodule:: inferno.extensions.metrics.base
:members:
:undoc-members:
:show-inheritance:
inferno.extensions.metrics.categorical module
---------------------------------------------
.. automodule:: inferno.extensions.metrics.categorical
:members:
:undoc-members:
:show-inheritance:
inferno.extensions.metrics.cremi\_score module
----------------------------------------------
.. automodule:: inferno.extensions.metrics.cremi_score
:members:
:undoc-members:
:show-inheritance:
inferno.extensions.metrics.voi module
-------------------------------------
.. automodule:: inferno.extensions.metrics.voi
:members:
:undoc-members:
:show-inheritance:
Module contents
---------------
.. automodule:: inferno.extensions.metrics
:members:
:undoc-members:
:show-inheritance:
================================================
FILE: docs/inferno-apidoc/inferno.extensions.optimizers.rst
================================================
inferno.extensions.optimizers package
=====================================
Submodules
----------
inferno.extensions.optimizers.adam module
-----------------------------------------
.. automodule:: inferno.extensions.optimizers.adam
:members:
:undoc-members:
:show-inheritance:
inferno.extensions.optimizers.annealed\_adam module
---------------------------------------------------
.. automodule:: inferno.extensions.optimizers.annealed_adam
:members:
:undoc-members:
:show-inheritance:
Module contents
---------------
.. automodule:: inferno.extensions.optimizers
:members:
:undoc-members:
:show-inheritance:
================================================
FILE: docs/inferno-apidoc/inferno.extensions.rst
================================================
inferno.extensions package
==========================
Subpackages
-----------
.. toctree::
inferno.extensions.containers
inferno.extensions.criteria
inferno.extensions.initializers
inferno.extensions.layers
inferno.extensions.metrics
inferno.extensions.optimizers
Module contents
---------------
.. automodule:: inferno.extensions
:members:
:undoc-members:
:show-inheritance:
================================================
FILE: docs/inferno-apidoc/inferno.io.box.rst
================================================
inferno.io.box package
======================
Submodules
----------
inferno.io.box.binary\_blobs module
-----------------------------------
.. automodule:: inferno.io.box.binary_blobs
:members:
:undoc-members:
:show-inheritance:
inferno.io.box.camvid module
----------------------------
.. automodule:: inferno.io.box.camvid
:members:
:undoc-members:
:show-inheritance:
inferno.io.box.cifar module
---------------------------
.. automodule:: inferno.io.box.cifar
:members:
:undoc-members:
:show-inheritance:
inferno.io.box.cityscapes module
--------------------------------
.. automodule:: inferno.io.box.cityscapes
:members:
:undoc-members:
:show-inheritance:
Module contents
---------------
.. automodule:: inferno.io.box
:members:
:undoc-members:
:show-inheritance:
================================================
FILE: docs/inferno-apidoc/inferno.io.core.rst
================================================
inferno.io.core package
=======================
Submodules
----------
inferno.io.core.base module
---------------------------
.. automodule:: inferno.io.core.base
:members:
:undoc-members:
:show-inheritance:
inferno.io.core.concatenate module
----------------------------------
.. automodule:: inferno.io.core.concatenate
:members:
:undoc-members:
:show-inheritance:
inferno.io.core.data\_utils module
----------------------------------
.. automodule:: inferno.io.core.data_utils
:members:
:undoc-members:
:show-inheritance:
inferno.io.core.zip module
--------------------------
.. automodule:: inferno.io.core.zip
:members:
:undoc-members:
:show-inheritance:
Module contents
---------------
.. automodule:: inferno.io.core
:members:
:undoc-members:
:show-inheritance:
================================================
FILE: docs/inferno-apidoc/inferno.io.rst
================================================
inferno.io package
==================
Subpackages
-----------
.. toctree::
inferno.io.box
inferno.io.core
inferno.io.transform
inferno.io.volumetric
Module contents
---------------
.. automodule:: inferno.io
:members:
:undoc-members:
:show-inheritance:
================================================
FILE: docs/inferno-apidoc/inferno.io.transform.rst
================================================
inferno.io.transform package
============================
Submodules
----------
inferno.io.transform.base module
--------------------------------
.. automodule:: inferno.io.transform.base
:members:
:undoc-members:
:show-inheritance:
inferno.io.transform.generic module
-----------------------------------
.. automodule:: inferno.io.transform.generic
:members:
:undoc-members:
:show-inheritance:
inferno.io.transform.image module
---------------------------------
.. automodule:: inferno.io.transform.image
:members:
:undoc-members:
:show-inheritance:
inferno.io.transform.volume module
----------------------------------
.. automodule:: inferno.io.transform.volume
:members:
:undoc-members:
:show-inheritance:
Module contents
---------------
.. automodule:: inferno.io.transform
:members:
:undoc-members:
:show-inheritance:
================================================
FILE: docs/inferno-apidoc/inferno.io.volumetric.rst
================================================
inferno.io.volumetric package
=============================
Submodules
----------
inferno.io.volumetric.lazy\_volume\_loader module
-------------------------------------------------
.. automodule:: inferno.io.volumetric.lazy_volume_loader
:members:
:undoc-members:
:show-inheritance:
inferno.io.volumetric.volume module
-----------------------------------
.. automodule:: inferno.io.volumetric.volume
:members:
:undoc-members:
:show-inheritance:
inferno.io.volumetric.volumetric\_utils module
----------------------------------------------
.. automodule:: inferno.io.volumetric.volumetric_utils
:members:
:undoc-members:
:show-inheritance:
Module contents
---------------
.. automodule:: inferno.io.volumetric
:members:
:undoc-members:
:show-inheritance:
================================================
FILE: docs/inferno-apidoc/inferno.rst
================================================
inferno package
===============
Subpackages
-----------
.. toctree::
inferno.extensions
inferno.io
inferno.trainers
inferno.utils
Submodules
----------
inferno.inferno module
----------------------
.. automodule:: inferno.inferno
:members:
:undoc-members:
:show-inheritance:
Module contents
---------------
.. automodule:: inferno
:members:
:undoc-members:
:show-inheritance:
================================================
FILE: docs/inferno-apidoc/inferno.trainers.callbacks.logging.rst
================================================
inferno.trainers.callbacks.logging package
==========================================
Submodules
----------
inferno.trainers.callbacks.logging.base module
----------------------------------------------
.. automodule:: inferno.trainers.callbacks.logging.base
:members:
:undoc-members:
:show-inheritance:
inferno.trainers.callbacks.logging.tensorboard module
-----------------------------------------------------
.. automodule:: inferno.trainers.callbacks.logging.tensorboard
:members:
:undoc-members:
:show-inheritance:
Module contents
---------------
.. automodule:: inferno.trainers.callbacks.logging
:members:
:undoc-members:
:show-inheritance:
================================================
FILE: docs/inferno-apidoc/inferno.trainers.callbacks.rst
================================================
inferno.trainers.callbacks package
==================================
Subpackages
-----------
.. toctree::
inferno.trainers.callbacks.logging
Submodules
----------
inferno.trainers.callbacks.base module
--------------------------------------
.. automodule:: inferno.trainers.callbacks.base
:members:
:undoc-members:
:show-inheritance:
inferno.trainers.callbacks.console module
-----------------------------------------
.. automodule:: inferno.trainers.callbacks.console
:members:
:undoc-members:
:show-inheritance:
inferno.trainers.callbacks.essentials module
--------------------------------------------
.. automodule:: inferno.trainers.callbacks.essentials
:members:
:undoc-members:
:show-inheritance:
inferno.trainers.callbacks.scheduling module
--------------------------------------------
.. automodule:: inferno.trainers.callbacks.scheduling
:members:
:undoc-members:
:show-inheritance:
inferno.trainers.callbacks.tqdm module
--------------------------------------
.. automodule:: inferno.trainers.callbacks.tqdm
:members:
:undoc-members:
:show-inheritance:
inferno.trainers.callbacks.tqdmstub module
------------------------------------------
.. automodule:: inferno.trainers.callbacks.tqdmstub
:members:
:undoc-members:
:show-inheritance:
Module contents
---------------
.. automodule:: inferno.trainers.callbacks
:members:
:undoc-members:
:show-inheritance:
================================================
FILE: docs/inferno-apidoc/inferno.trainers.rst
================================================
inferno.trainers package
========================
Subpackages
-----------
.. toctree::
inferno.trainers.callbacks
Submodules
----------
inferno.trainers.basic module
-----------------------------
.. automodule:: inferno.trainers.basic
:members:
:undoc-members:
:show-inheritance:
Module contents
---------------
.. automodule:: inferno.trainers
:members:
:undoc-members:
:show-inheritance:
================================================
FILE: docs/inferno-apidoc/inferno.utils.rst
================================================
inferno.utils package
=====================
Submodules
----------
inferno.utils.exceptions module
-------------------------------
.. automodule:: inferno.utils.exceptions
:members:
:undoc-members:
:show-inheritance:
inferno.utils.io\_utils module
------------------------------
.. automodule:: inferno.utils.io_utils
:members:
:undoc-members:
:show-inheritance:
inferno.utils.math\_utils module
--------------------------------
.. automodule:: inferno.utils.math_utils
:members:
:undoc-members:
:show-inheritance:
inferno.utils.model\_utils module
---------------------------------
.. automodule:: inferno.utils.model_utils
:members:
:undoc-members:
:show-inheritance:
inferno.utils.python\_utils module
----------------------------------
.. automodule:: inferno.utils.python_utils
:members:
:undoc-members:
:show-inheritance:
inferno.utils.test\_utils module
--------------------------------
.. automodule:: inferno.utils.test_utils
:members:
:undoc-members:
:show-inheritance:
inferno.utils.torch\_utils module
---------------------------------
.. automodule:: inferno.utils.torch_utils
:members:
:undoc-members:
:show-inheritance:
inferno.utils.train\_utils module
---------------------------------
.. automodule:: inferno.utils.train_utils
:members:
:undoc-members:
:show-inheritance:
Module contents
---------------
.. automodule:: inferno.utils
:members:
:undoc-members:
:show-inheritance:
================================================
FILE: docs/inferno-apidoc/modules.rst
================================================
inferno
=======
.. toctree::
:maxdepth: 4
inferno
================================================
FILE: docs/installation.rst
================================================
.. highlight:: shell
==================================
Installation
==================================
Install on Linux and OSX
------------------------
Developers
~~~~~~~~~~~~~~~~~~~~~~
First, make sure `you have PyTorch installed <http://pytorch.org/>`_.
Then, clone this repository with:
.. code:: console
$ git clone https://github.com/nasimrahaman/inferno.git
Next, install the dependencies.
.. code:: console
$ cd inferno
$ pip install -r requirements.txt
If you use Python from the shell:
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Finally, add *inferno* to your ``PYTHONPATH`` with:
.. code:: bash
source add2path.sh
If you use PyCharm:
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Refer to this `QA <https://askubuntu.com/questions/684550/importing-a-python-module-works-from-command-line-but-not-from-pycharm>`_ about setting up paths with PyCharm.
======================================================
Installation via PyPI / pip / setup.py (experimental)
======================================================
You need to install pytorch via pip before installing
inferno. Follow the `pytorch installation guide`_.
Stable release
--------------
To install inferno, run this command in your terminal:
.. code-block:: console
$ pip install inferno-pytorch
This is the preferred method to install inferno, as it will always install the most recent stable release.
If you don't have `pip`_ installed, this `Python installation guide`_ can guide
you through the process.
.. _pip: https://pip.pypa.io
.. _Python installation guide: http://docs.python-guide.org/en/latest/starting/installation/
.. _pytorch installation guide: http://pytorch.org/
From sources
------------------------
First, make sure `you have PyTorch installed <http://pytorch.org/>`_.
The sources for inferno can be downloaded from the `Github repo`_.
You can either clone the public repository:
.. code-block:: console
$ git clone git://github.com/nasimrahaman/inferno
Or download the `tarball`_:
.. code-block:: console
$ curl -OL https://github.com/nasimrahaman/inferno/tarball/master
Once you have a copy of the source, you can install it with:
.. code-block:: console
$ python setup.py install
.. _Github repo: https://github.com/nasimrahaman/inferno
.. _tarball: https://github.com/nasimrahaman/inferno/tarball/master
================================================
FILE: docs/make.bat
================================================
@ECHO OFF
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set BUILDDIR=_build
set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% .
set I18NSPHINXOPTS=%SPHINXOPTS% .
if NOT "%PAPER%" == "" (
set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS%
set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS%
)
if "%1" == "" goto help
if "%1" == "help" (
:help
echo.Please use `make ^<target^>` where ^<target^> is one of
echo. html to make standalone HTML files
echo. dirhtml to make HTML files named index.html in directories
echo. singlehtml to make a single large HTML file
echo. pickle to make pickle files
echo. json to make JSON files
echo. htmlhelp to make HTML files and a HTML help project
echo. qthelp to make HTML files and a qthelp project
echo. devhelp to make HTML files and a Devhelp project
echo. epub to make an epub
echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter
echo. text to make text files
echo. man to make manual pages
echo. texinfo to make Texinfo files
echo. gettext to make PO message catalogs
echo. changes to make an overview over all changed/added/deprecated items
echo. xml to make Docutils-native XML files
echo. pseudoxml to make pseudoxml-XML files for display purposes
echo. linkcheck to check all external links for integrity
echo. doctest to run all doctests embedded in the documentation if enabled
goto end
)
if "%1" == "clean" (
for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i
del /q /s %BUILDDIR%\*
goto end
)
%SPHINXBUILD% 2> nul
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.http://sphinx-doc.org/
exit /b 1
)
if "%1" == "html" (
%SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The HTML pages are in %BUILDDIR%/html.
goto end
)
if "%1" == "dirhtml" (
%SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml.
goto end
)
if "%1" == "singlehtml" (
%SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml.
goto end
)
if "%1" == "pickle" (
%SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle
if errorlevel 1 exit /b 1
echo.
echo.Build finished; now you can process the pickle files.
goto end
)
if "%1" == "json" (
%SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json
if errorlevel 1 exit /b 1
echo.
echo.Build finished; now you can process the JSON files.
goto end
)
if "%1" == "htmlhelp" (
%SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp
if errorlevel 1 exit /b 1
echo.
echo.Build finished; now you can run HTML Help Workshop with the ^
.hhp project file in %BUILDDIR%/htmlhelp.
goto end
)
if "%1" == "qthelp" (
%SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp
if errorlevel 1 exit /b 1
echo.
echo.Build finished; now you can run "qcollectiongenerator" with the ^
.qhcp project file in %BUILDDIR%/qthelp, like this:
echo.^> qcollectiongenerator %BUILDDIR%\qthelp\inferno.qhcp
echo.To view the help file:
echo.^> assistant -collectionFile %BUILDDIR%\qthelp\inferno.qhc
goto end
)
if "%1" == "devhelp" (
%SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp
if errorlevel 1 exit /b 1
echo.
echo.Build finished.
goto end
)
if "%1" == "epub" (
%SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The epub file is in %BUILDDIR%/epub.
goto end
)
if "%1" == "latex" (
%SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
if errorlevel 1 exit /b 1
echo.
echo.Build finished; the LaTeX files are in %BUILDDIR%/latex.
goto end
)
if "%1" == "latexpdf" (
%SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
cd %BUILDDIR%/latex
make all-pdf
cd %BUILDDIR%/..
echo.
echo.Build finished; the PDF files are in %BUILDDIR%/latex.
goto end
)
if "%1" == "latexpdfja" (
%SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
cd %BUILDDIR%/latex
make all-pdf-ja
cd %BUILDDIR%/..
echo.
echo.Build finished; the PDF files are in %BUILDDIR%/latex.
goto end
)
if "%1" == "text" (
%SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The text files are in %BUILDDIR%/text.
goto end
)
if "%1" == "man" (
%SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The manual pages are in %BUILDDIR%/man.
goto end
)
if "%1" == "texinfo" (
%SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo.
goto end
)
if "%1" == "gettext" (
%SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The message catalogs are in %BUILDDIR%/locale.
goto end
)
if "%1" == "changes" (
%SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes
if errorlevel 1 exit /b 1
echo.
echo.The overview file is in %BUILDDIR%/changes.
goto end
)
if "%1" == "linkcheck" (
%SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck
if errorlevel 1 exit /b 1
echo.
echo.Link check complete; look for any errors in the above output ^
or in %BUILDDIR%/linkcheck/output.txt.
goto end
)
if "%1" == "doctest" (
%SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest
if errorlevel 1 exit /b 1
echo.
echo.Testing of doctests in the sources finished, look at the ^
results in %BUILDDIR%/doctest/output.txt.
goto end
)
if "%1" == "xml" (
%SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The XML files are in %BUILDDIR%/xml.
goto end
)
if "%1" == "pseudoxml" (
%SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml
if errorlevel 1 exit /b 1
echo.
echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml.
goto end
)
:end
================================================
FILE: docs/readme.rst
================================================
.. include:: ../README.rst
================================================
FILE: docs/refs.bib
================================================
@inproceedings{alush_2013_simbad,
title={Break and Conquer: Efficient Correlation Clustering for Image Segmentation},
author={Alush, Amir and Goldberger, Jacob},
booktitle={2nd International Workshop on Similarity-Based Pattern Analysis and Recognition},
year={2013}
}
================================================
FILE: docs/usage.rst
================================================
=====
Usage
=====
Inferno is a utility library built around `PyTorch <http://pytorch.org/>`_, designed to help you train and even build complex PyTorch models. In this tutorial, we'll see how! If you're new to PyTorch, we highly recommend you work through the `PyTorch tutorials <http://pytorch.org/tutorials/>`_ first.
Building a PyTorch Model
~~~~~~~~~~~~~~~~~~~~~~~~~~
Inferno's training machinery works with just about any valid `PyTorch module <http://pytorch.org/docs/master/nn.html#torch.nn.Module>`_. However, to make things even easier, we also provide pre-configured layers that work out-of-the-box. Let's use them to build a convolutional neural network for CIFAR-10.
.. code:: python
import torch.nn as nn
from inferno.extensions.layers.convolutional import ConvELU2D
from inferno.extensions.layers.reshape import Flatten
`ConvELU2D` is a 2-dimensional convolutional layer with orthogonal weight initialization and `ELU <http://pytorch.org/docs/master/nn.html#torch.nn.ELU>`_ activation. `Flatten` reshapes the 4-dimensional activation tensor to a matrix. Let's use the `Sequential` container to chain together a bunch of convolutional and pooling layers, followed by a linear and a softmax layer.
.. code:: python
model = nn.Sequential(
ConvELU2D(in_channels=3, out_channels=256, kernel_size=3),
nn.MaxPool2d(kernel_size=2, stride=2),
ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),
nn.MaxPool2d(kernel_size=2, stride=2),
ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),
nn.MaxPool2d(kernel_size=2, stride=2),
Flatten(),
nn.Linear(in_features=(256 * 4 * 4), out_features=10),
nn.Softmax()
)
A model this size doesn't win competitions anymore, but it'll do for our purposes.
Data Logistics
**************************
With our model built, it's time to worry about the data generators. Or is it?
.. code:: python
from inferno.io.box.cifar import get_cifar10_loaders
train_loader, validate_loader = get_cifar10_loaders('path/to/cifar10',
download=True,
train_batch_size=128,
test_batch_size=100)
CIFAR-10 works out-of-the-`box` (pun very much intended) with all the fancy data augmentation and normalization. Of course, it's perfectly fine if you have your own `DataLoader <http://pytorch.org/docs/master/data.html#torch.utils.data.DataLoader>`_.
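If you do bring your own loaders, anything built as a `torch.utils.data.DataLoader` can be bound to the trainer later on. As a minimal sketch (the tensors below are random placeholders, not real CIFAR-10 data):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Placeholder data: 100 RGB images of size 32x32, each with one of 10 labels
images = torch.randn(100, 3, 32, 32)
labels = torch.randint(0, 10, (100,))

# Any loader constructed like this can stand in for the CIFAR-10 loaders
my_train_loader = DataLoader(TensorDataset(images, labels),
                             batch_size=25, shuffle=True)
```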
Preparing the Trainer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
With our model and data loaders good to go, it's finally time to build the trainer. To start, let's initialize one.
.. code:: python
from inferno.trainers.basic import Trainer
trainer = Trainer(model)
# Tell trainer about the data loaders
trainer.bind_loader('train', train_loader).bind_loader('validate', validate_loader)
Now to the things we could do with it.
Setting up Checkpointing
***************************************
When training a model for days, it's usually a good idea to store the current training state to disk every once in a while. To set this up, we tell `trainer` where to store these *checkpoints* and how often.
.. code:: python
trainer.save_to_directory('path/to/save/directory').save_every((25, 'epochs'))
So we're saving once every 25 epochs. But what if an epoch takes forever, and you don't wish to wait that long?
.. code:: python
trainer.save_every((1000, 'iterations'))
In this setting, you're saving once every 1000 iterations (= batches). But we might also want to create a checkpoint whenever the validation score is at its best. Easy as 1, 2, 3:
.. code:: python
trainer.save_at_best_validation_score()
Remember that a checkpoint contains the entire training state, and not just the model. Everything is included in the checkpoint file, including the optimizer, criterion, and callbacks, but **not** the data loaders.
Setting up Validation
**************************
Let's say you wish to validate once every 2 epochs.
.. code:: python
trainer.validate_every((2, 'epochs'))
To be able to validate, you'll need to specify a validation metric.
.. code:: python
trainer.build_metric('CategoricalError')
Inferno looks for a metric `'CategoricalError'` in `inferno.extensions.metrics`. To specify your own metric, subclass `inferno.extensions.metrics.base.Metric` and implement the `forward` method. With that done, you could:
.. code:: python
trainer.build_metric(MyMetric)
or
.. code:: python
trainer.build_metric(MyMetric, **my_metric_kwargs)
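For illustration, here is what the `forward` logic of such a metric might look like. This is only a sketch of computing a categorical error; in actual use, `MyMetric` would subclass `inferno.extensions.metrics.base.Metric` rather than `object`:

```python
import torch

class MyMetric(object):
    # Sketch only: in inferno, subclass
    # inferno.extensions.metrics.base.Metric instead of object.
    def forward(self, prediction, target):
        # prediction: (N, C) class scores; target: (N,) class indices
        predicted_classes = prediction.argmax(dim=1)
        # fraction of misclassified samples
        return (predicted_classes != target).float().mean()

metric = MyMetric()
prediction = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
target = torch.tensor([0, 1, 1])
error = metric.forward(prediction, target)  # one of three samples is wrong
```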
A metric might be way too expensive to evaluate at every training iteration. If this is the case and you'd like to evaluate the metric only every (say) 10 *training* iterations:
.. code:: python
trainer.evaluate_metric_every((10, 'iterations'))
However, while validating, the metric is evaluated once every iteration.
Setting up the Criterion and Optimizer
***************************************
With that out of the way, let's set up a training criterion and an optimizer.
.. code:: python
# set up the criterion
trainer.build_criterion('CrossEntropyLoss')
The `trainer` looks for a `'CrossEntropyLoss'` in `torch.nn`, which it finds. But any of the following would have worked:
.. code:: python
trainer.build_criterion(nn.CrossEntropyLoss)
or
.. code:: python
trainer.build_criterion(nn.CrossEntropyLoss())
What this means is that if you have your own loss criterion with the same API as the criteria found in `torch.nn`, you should be fine just plugging it in.
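To make "the same API" concrete: a criterion is just an `nn.Module` whose `forward` takes `(input, target)` and returns a scalar loss. The class below is a hypothetical sketch, not part of inferno or torch:

```python
import torch
import torch.nn as nn

class ScaledCrossEntropyLoss(nn.Module):
    """Hypothetical criterion: rescales the usual cross-entropy loss."""
    def __init__(self, scale=1.0):
        super(ScaledCrossEntropyLoss, self).__init__()
        self.scale = scale
        self.base = nn.CrossEntropyLoss()

    def forward(self, input, target):
        # same (input, target) -> scalar signature as the torch.nn criteria
        return self.scale * self.base(input, target)

criterion = ScaledCrossEntropyLoss(scale=2.0)
scores = torch.tensor([[2.0, 0.5], [0.1, 1.5]])
labels = torch.tensor([0, 1])
loss = criterion(scores, labels)
```

Since it subclasses `nn.Module` and exposes the standard call signature, such a class should plug into `build_criterion` just like the built-ins.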
The same holds for the optimizer:
.. code:: python
trainer.build_optimizer('Adam', weight_decay=0.0005)
Like for criteria, the `trainer` looks for an `'Adam'` in `torch.optim` (among other places), and initializes it with the `model`'s parameters. Any keyword arguments you might use for `torch.optim.Adam` can be passed along to the `build_optimizer` method.
Or alternatively, you could use:
.. code:: python
from torch.optim import Adam
trainer.build_optimizer(Adam, weight_decay=0.0005)
If you implemented your own optimizer (by subclassing `torch.optim.Optimizer`), you should be able to use it instead of `Adam`. Alternatively, if you already have an optimizer *instance*, you could do:
.. code:: python
optimizer = MyOptimizer(model.parameters(), **optimizer_kwargs)
trainer.build_optimizer(optimizer)
Setting up Training Duration
********************************
You probably don't want to train forever, in which case you must specify:
.. code:: python
trainer.set_max_num_epochs(100)
or
.. code:: python
trainer.set_max_num_iterations(10000)
If you'd like to train indefinitely (or until you're happy with the results), use:
.. code:: python
trainer.set_max_num_iterations('inf')
In this case, you'll need to interrupt the training manually with a `KeyboardInterrupt`.
Setting up Callbacks
*********************
Callbacks are pretty handy when it comes to interacting with the `Trainer`. More precisely: `Trainer` defines a number of events as 'triggers' for callbacks. Currently, these are:
.. code:: python
BEGIN_OF_FIT,
END_OF_FIT,
BEGIN_OF_TRAINING_RUN,
END_OF_TRAINING_RUN,
BEGIN_OF_EPOCH,
END_OF_EPOCH,
BEGIN_OF_TRAINING_ITERATION,
END_OF_TRAINING_ITERATION,
BEGIN_OF_VALIDATION_RUN,
END_OF_VALIDATION_RUN,
BEGIN_OF_VALIDATION_ITERATION,
END_OF_VALIDATION_ITERATION,
BEGIN_OF_SAVE,
END_OF_SAVE
As an example, let's build a simple callback to interrupt the training on NaNs. We check at the end of every training iteration whether the training loss is NaN, and accordingly raise a `RuntimeError`.
.. code:: python
import numpy as np
from inferno.trainers.callbacks.base import Callback
class NaNDetector(Callback):
def end_of_training_iteration(self, **_):
# The callback object has the trainer as an attribute.
# The trainer populates its 'states' with torch tensors (NOT VARIABLES!)
training_loss = self.trainer.get_state('training_loss')
# Extract float from torch tensor
training_loss = training_loss[0]
if np.isnan(training_loss):
raise RuntimeError("NaNs detected!")
With the callback defined, all we need to do is register it with the trainer:
.. code:: python
trainer.register_callback(NaNDetector())
So the next time you get ``RuntimeError: NaNs detected!``, you know the drill.
Using Tensorboard
**************************
Inferno supports logging scalars and images to Tensorboard out-of-the-box, though this requires that you have at least `tensorflow-cpu <https://github.com/tensorflow/tensorflow>`_ installed. Let's say you want to log scalars every iteration and images every 20 iterations:
.. code:: python
from inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger
trainer.build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'),
log_images_every=(20, 'iterations')),
log_directory='/path/to/log/directory')
After you've started training, use a bash shell to fire up tensorboard with:
.. code:: bash
$ tensorboard --logdir=/path/to/log/directory --port=6007
and navigate to `localhost:6007` with your favorite browser.
Fine print: omitting the `log_images_every` keyword argument to `TensorboardLogger` will result in images being logged every iteration. If you don't have a fast hard drive, this might actually slow down the training. To not log images at all, use `log_images_every='never'`.
Using GPUs
*************
To use just one GPU:
.. code:: python
trainer.cuda()
For multi-GPU data-parallel training, simply pass `trainer.cuda` a list of devices:
.. code:: python
trainer.cuda(devices=[0, 1, 2, 3])
**Pro-tip**: Say you only want to use GPUs 0, 3, 5 and 7 (your colleagues might love you for this). Before running your training script, simply:
.. code:: bash
$ export CUDA_VISIBLE_DEVICES=0,3,5,7
$ python train.py
This maps device 0 to 0, 3 to 1, 5 to 2 and 7 to 3.
One more thing
**************************
Once you have everything configured, use
.. code:: python
trainer.fit()
to commence training! This last step is kinda important.
Cherries
~~~~~~~~~~~~~~~~~~~~~~
Building Complex Models with the Graph API
****************************************************
Work in Progress:
Parameter Initialization
**************************
Work in Progress:
Support
*************
Work in Progress:
================================================
FILE: docs/zbibliography.rst
================================================
.. _inferno_bibliography:
Bibliography
============================
The bibliography:
.. bibliography:: refs.bib
:style: alpha
================================================
FILE: examples/README.txt
================================================
.. _examples-index:
Gallery of Examples
===================
================================================
FILE: examples/plot_cheap_unet.py
================================================
"""
UNet Tutorial
================================
A UNet example which can be run without a GPU
"""
##############################################################################
# Preface
# --------------
# We start with some unspectacular multi-purpose imports needed for this example
import matplotlib.pyplot as plt
import torch
from torch import nn
import numpy
##############################################################################
# determine whether we have a gpu
# and should use cuda
USE_CUDA = torch.cuda.is_available()
##############################################################################
# Dataset
# --------------
# For simplicity we will use a toy dataset where we need to perform
# a binary segmentation task.
from inferno.io.box.binary_blobs import get_binary_blob_loaders
# convert labels from long to float as needed by
# binary cross entropy loss
def label_transform(x):
return torch.from_numpy(x).float()
#label_transform = lambda x : torch.from_numpy(x).float()
train_loader, test_loader, validate_loader = get_binary_blob_loaders(
size=8, # how many images per {train,test,validate}
train_batch_size=2,
length=256, # <= size of the images
gaussian_noise_sigma=1.4, # <= how noisy the images are
train_label_transform = label_transform,
validate_label_transform = label_transform
)
image_channels = 1 # <-- number of channels of the image
pred_channels = 1 # <-- number of channels needed for the prediction
if False:
##############################################################################
# Visualize Dataset
# ~~~~~~~~~~~~~~~~~~~~~~
fig = plt.figure()
for i,(image, target) in enumerate(train_loader):
ax = fig.add_subplot(1, 2, 1)
ax.imshow(image[0,0,...])
ax.set_title('raw data')
ax = fig.add_subplot(1, 2, 2)
ax.imshow(target[0,...])
ax.set_title('ground truth')
break
fig.tight_layout()
plt.show()
##############################################################################
# Training
# ----------------------------
# To train the UNet, we use inferno's Trainer class.
# Since we train many models later on in this example we encapsulate
# the training in a function (see :ref:`sphx_glr_auto_examples_trainer.py` for
# an example dedicated to the trainer itself).
from inferno.trainers import Trainer
from inferno.utils.python_utils import ensure_dir
def train_model(model, loaders, **kwargs):
trainer = Trainer(model)
trainer.build_criterion('BCEWithLogitsLoss')
trainer.build_optimizer('Adam', lr=kwargs.get('lr', 0.0001))
#trainer.validate_every((kwargs.get('validate_every', 10), 'epochs'))
#trainer.save_every((kwargs.get('save_every', 10), 'epochs'))
#trainer.save_to_directory(ensure_dir(kwargs.get('save_dir', 'save_dor')))
trainer.set_max_num_epochs(kwargs.get('max_num_epochs', 20))
# bind the loaders
trainer.bind_loader('train', loaders[0])
trainer.bind_loader('validate', loaders[1])
if USE_CUDA:
trainer.cuda()
# do the training
trainer.fit()
return trainer
##############################################################################
# Prediction
# ----------------------------
# The trainer contains the trained model and we can do predictions.
# We use :code:`unwrap` to convert the results to numpy arrays.
# Since we want to make many predictions we encapsulate
# the prediction in a function
from inferno.utils.torch_utils import unwrap
def predict(trainer, test_loader, save_dir=None):
trainer.eval_mode()
for image, target in test_loader:
# transfer image to gpu
image = image.cuda() if USE_CUDA else image
# predict the whole batch at once
prediction = trainer.apply_model(image)
prediction = torch.nn.functional.sigmoid(prediction)
# convert everything to numpy before plotting
image = unwrap(image, as_numpy=True, to_cpu=True)
prediction = unwrap(prediction, as_numpy=True, to_cpu=True)
target = unwrap(target, as_numpy=True, to_cpu=True)
# plot each sample in the batch
batch_size = image.shape[0]
for b in range(batch_size):
fig = plt.figure()
ax = fig.add_subplot(2, 2, 1)
ax.imshow(image[b,0,...])
ax.set_title('raw data')
ax = fig.add_subplot(2, 2, 2)
ax.imshow(target[b,...])
ax.set_title('ground truth')
ax = fig.add_subplot(2, 2, 4)
ax.imshow(prediction[b,...])
ax.set_title('prediction')
fig.tight_layout()
plt.show()
##############################################################################
# Custom UNet
# ----------------------------
# Often one needs to have a UNet with custom layers.
# Here we show how to implement such a customized UNet.
# To this end we derive from :code:`UNetBase`.
# For the sake of this example we will create
# a UNet which uses depthwise convolutions and can be trained on a CPU
from inferno.extensions.models import UNetBase
from inferno.extensions.layers import ConvSELU2D, ConvReLU2D, ConvELU2D, ConvSigmoid2D,Conv2D,ConvActivation
class CheapConv(nn.Module):
def __init__(self, in_channels, out_channels, activated):
super(CheapConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
if activated:
self.convs = torch.nn.Sequential(
ConvActivation(in_channels=in_channels, out_channels=in_channels, depthwise=True, kernel_size=(3, 3), activation='ReLU', dim=2),
ConvReLU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=(1,1))
)
else:
self.convs = torch.nn.Sequential(
ConvActivation(in_channels=in_channels, out_channels=in_channels, depthwise=True, kernel_size=(3, 3), activation='ReLU', dim=2),
Conv2D(in_channels=in_channels, out_channels=out_channels, kernel_size=(1,1))
)
def forward(self, x):
assert x.shape[1] == self.in_channels,"input has wrong number of channels"
x = self.convs(x)
assert x.shape[1] == self.out_channels,"output has wrong number of channels"
return x
class CheapConvBlock(nn.Module):
def __init__(self, in_channels, out_channels, activated):
super(CheapConvBlock, self).__init__()
self.activated = activated
self.in_channels = in_channels
self.out_channels = out_channels
if(in_channels != out_channels):
self.start = ConvReLU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=(1,1))
else:
self.start = None
self.conv_a = CheapConv(in_channels=out_channels, out_channels=out_channels, activated=True)
self.conv_b = CheapConv(in_channels=out_channels, out_channels=out_channels, activated=False)
self.activation = torch.nn.ReLU()
def forward(self, x):
x_input = x
if self.start is not None:
x_input = self.start(x_input)
x = self.conv_a(x_input)
x = self.conv_b(x)
x = x + x_input
if self.activated:
x = self.activation(x)
return x
class MySimple2DCpUnet(UNetBase):
def __init__(self, in_channels, out_channels, depth=3, residual=False, **kwargs):
super(MySimple2DCpUnet, self).__init__(in_channels=in_channels, out_channels=out_channels,
dim=2, depth=depth, **kwargs)
def conv_op_factory(self, in_channels, out_channels, part, index):
# last?
last = part == 'up' and index==0
return CheapConvBlock(in_channels=in_channels, out_channels=out_channels, activated=not last),False
from inferno.extensions.layers import RemoveSingletonDimension
model_b = torch.nn.Sequential(
CheapConv(in_channels=image_channels, out_channels=4, activated=True),
MySimple2DCpUnet(in_channels=4, out_channels=pred_channels) ,
RemoveSingletonDimension(dim=1)
)
###################################################
# do the training (with the same functions as before)
trainer = train_model(model=model_b, loaders=[train_loader, validate_loader], save_dir='model_b', lr=0.001)
###################################################
# do the prediction (with the same function as before)
predict(trainer=trainer, test_loader=test_loader)
================================================
FILE: examples/plot_train_side_loss_unet.py
================================================
"""
Train Side Loss UNet Example
================================
In this example, a UNet with side supervision
and auxiliary loss is implemented
"""
##############################################################################
# Imports needed for this example
import torch
import torch.nn as nn
from inferno.io.box.binary_blobs import get_binary_blob_loaders
from inferno.trainers.basic import Trainer
from inferno.extensions.layers.convolutional import Conv2D
from inferno.extensions.models.res_unet import _ResBlock as ResBlock
from inferno.extensions.models import ResBlockUNet
from inferno.utils.torch_utils import unwrap
from inferno.utils.python_utils import ensure_dir
import pylab
##############################################################################
# To create a UNet with side loss we create a new nn.Module class
# which has a ResBlockUNet as member.
# The ResBlockUNet is configured such that the results of the
# bottom convolution and all the results of the up-stream
# convolutions are returned as (side)-output.
# A 1x1 convolution is used to give the side outputs
# the right number of out_channels, and upsampling is
# used to resize all side outputs to the full resolution
# of the input. These `side-predictions` are
# returned by our MySideLossUNet.
# Furthermore, all `side-predictions` are concatenated
# and fed through another two residual blocks to make
# the final prediction.
class MySideLossUNet(nn.Module):
def __init__(self, in_channels, out_channels, depth=3):
super(MySideLossUNet, self).__init__()
self.depth = depth
self.unet = ResBlockUNet(in_channels=in_channels, out_channels=in_channels*2,
dim=2, unet_kwargs=dict(depth=depth),
side_out_parts=['bottom', 'up'])
# number of out channels
self.n_channels_per_output = self.unet.n_channels_per_output
# 1x1 conv to give the side outs of the unet
# the right number of channels
# and a Upsampling to give the right shape
upscale_factor = 2**self.depth
conv_and_scale = []
for n_channels in self.n_channels_per_output:
# conv blocks
conv = Conv2D(in_channels=n_channels, out_channels=out_channels, kernel_size=1)
if upscale_factor > 1:
upsample = nn.Upsample(scale_factor=upscale_factor)
conv_and_scale.append(nn.Sequential(conv, upsample))
else:
conv_and_scale.append(conv)
upscale_factor //= 2
self.conv_and_scale = nn.ModuleList(conv_and_scale)
# combined number of channels after concat
# concat side output predictions with main output of unet
self.n_channels_combined = (self.depth + 1)* out_channels + in_channels*2
self.final_block = nn.Sequential(
ResBlock(dim=2,in_channels=self.n_channels_combined, out_channels=self.n_channels_combined),
ResBlock(in_channels=self.n_channels_combined, out_channels=out_channels,
dim=2, activated=False),
)
def forward(self, input):
outs = self.unet(input)
assert len(outs) == len(self.n_channels_per_output)
# convert the unet output into the right number of
preds = [None] * len(outs)
for i,out in enumerate(outs):
preds[i] = self.conv_and_scale[i](out)
# this is the side output
preds = tuple(preds)
# concat side output predictions with main output of unet
combined = torch.cat(preds + (outs[-1],), 1)
final_res = self.final_block(combined)
# return everything
return preds + (final_res,)
##############################################################################
# We use a custom loss function which applies CrossEntropyLoss
# to all side outputs.
# The side outputs are weighted with exponentially increasing
# weights (doubling with each output) and summed up
# into a single value
class MySideLoss(nn.Module):
"""Wrap a criterion. Collect regularization losses from model and combine with wrapped criterion.
"""
def __init__(self):
super(MySideLoss, self).__init__()
self.criterion = nn.CrossEntropyLoss(reduce=True)
w = 1.0
l = None
def forward(self, predictions, target):
w = 1.0
l = None
for p in predictions:
ll = self.criterion(p, target)*w
if l is None:
l = ll
else:
l += ll
w *= 2
return l
##############################################################################
# Training boilerplate (see :ref:`sphx_glr_auto_examples_trainer.py`)
LOG_DIRECTORY = ensure_dir('log')
SAVE_DIRECTORY = ensure_dir('save')
DATASET_DIRECTORY = ensure_dir('dataset')
USE_CUDA = torch.cuda.is_available()
# Build a residual unet where the last layer is not activated
sl_unet = MySideLossUNet(in_channels=5, out_channels=2)
model = nn.Sequential(
ResBlock(dim=2, in_channels=1, out_channels=5),
sl_unet
)
train_loader, test_loader, validate_loader = get_binary_blob_loaders(
train_batch_size=3,
length=512, # <= size of the images
gaussian_noise_sigma=1.5 # <= how noisy the images are
)
# Build trainer
trainer = Trainer(model)
trainer.build_criterion(MySideLoss())
trainer.build_optimizer('Adam')
trainer.validate_every((10, 'epochs'))
#trainer.save_every((10, 'epochs'))
#trainer.save_to_directory(SAVE_DIRECTORY)
trainer.set_max_num_epochs(40)
# Bind loaders
trainer \
.bind_loader('train', train_loader)\
.bind_loader('validate', validate_loader)
if USE_CUDA:
trainer.cuda()
# Go!
trainer.fit()
##############################################################################
# Predict with the trained network
# and visualize the results
# predict:
#trainer.load(best=True)
trainer.bind_loader('train', train_loader)
trainer.bind_loader('validate', validate_loader)
trainer.eval_mode()
if USE_CUDA:
trainer.cuda()
# look at an example
for img,target in test_loader:
if USE_CUDA:
img = img.cuda()
# softmax on each of the prediction
preds = trainer.apply_model(img)
preds = [nn.functional.softmax(pred,dim=1) for pred in preds]
preds = [unwrap(pred, as_numpy=True, to_cpu=True) for pred in preds]
img = unwrap(img, as_numpy=True, to_cpu=True)
target = unwrap(target, as_numpy=True, to_cpu=True)
n_plots = len(preds) + 2
batch_size = preds[0].shape[0]
for b in range(batch_size):
fig = pylab.figure()
ax1 = fig.add_subplot(2,4,1)
ax1.set_title('image')
ax1.imshow(img[b,0,...])
ax2 = fig.add_subplot(2,4,2)
ax2.set_title('ground truth')
ax2.imshow(target[b,...])
for i,pred in enumerate(preds):
axn = fig.add_subplot(2,4, 3+i)
axn.imshow(pred[b,1,...])
if i + 1 < len(preds):
axn.set_title('side prediction %d'%i)
else:
axn.set_title('combined prediction')
pylab.show()
break
================================================
FILE: examples/plot_unet_tutorial.py
================================================
"""
UNet Tutorial
================================
A tentative tutorial on the usage
of the UNet framework in inferno
"""
##############################################################################
# Preface
# --------------
# We start with some unspectacular multi-purpose imports needed for this example
import matplotlib.pyplot as plt
import torch
import numpy
##############################################################################
# determine whether we have a gpu
# and should use cuda
USE_CUDA = torch.cuda.is_available()
##############################################################################
# Dataset
# --------------
# For simplicity we will use a toy dataset where we need to perform
# a binary segmentation task.
from inferno.io.box.binary_blobs import get_binary_blob_loaders
# convert labels from long to float as needed by
# binary cross entropy loss
def label_transform(x):
return torch.from_numpy(x).float()
#label_transform = lambda x : torch.from_numpy(x).float()
train_loader, test_loader, validate_loader = get_binary_blob_loaders(
size=8, # how many images per {train,test,validate}
train_batch_size=2,
length=256, # <= size of the images
    gaussian_noise_sigma=1.4,  # <= how noisy the images are
    train_label_transform=label_transform,
    validate_label_transform=label_transform
)
image_channels = 1 # <-- number of channels of the image
pred_channels = 1 # <-- number of channels needed for the prediction
##############################################################################
# Visualize Dataset
# ~~~~~~~~~~~~~~~~~~~~~~
fig = plt.figure()
for i,(image, target) in enumerate(train_loader):
ax = fig.add_subplot(1, 2, 1)
ax.imshow(image[0,0,...])
ax.set_title('raw data')
ax = fig.add_subplot(1, 2, 2)
ax.imshow(target[0,...])
ax.set_title('ground truth')
break
fig.tight_layout()
plt.show()
##############################################################################
# Simple UNet
# ----------------------------
# We start with a very simple predefined
# res block UNet. By default, this UNet uses ReLUs (in conjunction with batchnorm)
# as nonlinearities. With :code:`activated=False` we make sure that the last layer
# is not activated, since we chain the UNet with a sigmoid
# activation function.
from inferno.extensions.models import ResBlockUNet
from inferno.extensions.layers import RemoveSingletonDimension
model = torch.nn.Sequential(
ResBlockUNet(dim=2, in_channels=image_channels, out_channels=pred_channels, activated=False),
RemoveSingletonDimension(dim=1),
torch.nn.Sigmoid()
)
##############################################################################
# While the model above will work in principle, it has some drawbacks.
# Within the UNet, the number of features is increased by a multiplicative
# factor while going down, the so-called gain. The default value for the gain is 2.
# Since we start with only a single channel, we could either increase the gain,
# or use some convolutions to increase the number of channels
# before the UNet.
from inferno.extensions.layers import ConvReLU2D
model_a = torch.nn.Sequential(
ConvReLU2D(in_channels=image_channels, out_channels=5, kernel_size=3),
    ResBlockUNet(dim=2, in_channels=5, out_channels=pred_channels, activated=False,
                 res_block_kwargs=dict(batchnorm=True, size=2)),
RemoveSingletonDimension(dim=1)
# torch.nn.Sigmoid()
)
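##############################################################################
# The channel bookkeeping behind the gain can be sketched in plain Python.
# This is only an illustration of the arithmetic (the exact feature counts
# inside :code:`ResBlockUNet` depend on its implementation):

```python
def encoder_channels(in_channels, depth, gain=2):
    # Each downsampling step multiplies the channel count by the gain.
    return [in_channels * gain ** d for d in range(depth + 1)]

# Starting from a single input channel, the deeper levels are starved of features,
print(encoder_channels(1, 3))   # [1, 2, 4, 8]
# while a few initial convolutions (here to 5 channels) give the net more room.
print(encoder_channels(5, 3))   # [5, 10, 20, 40]
```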
##############################################################################
# Training
# ----------------------------
# To train the UNet, we use inferno's Trainer class.
# Since we train many models later on in this example, we encapsulate
# the training in a function (see :ref:`sphx_glr_auto_examples_trainer.py` for
# an example dedicated to the trainer itself).
from inferno.trainers import Trainer
from inferno.utils.python_utils import ensure_dir
def train_model(model, loaders, **kwargs):
trainer = Trainer(model)
trainer.build_criterion('BCEWithLogitsLoss')
trainer.build_optimizer('Adam', lr=kwargs.get('lr', 0.0001))
#trainer.validate_every((kwargs.get('validate_every', 10), 'epochs'))
#trainer.save_every((kwargs.get('save_every', 10), 'epochs'))
    #trainer.save_to_directory(ensure_dir(kwargs.get('save_dir', 'save_dir')))
trainer.set_max_num_epochs(kwargs.get('max_num_epochs', 200))
# bind the loaders
trainer.bind_loader('train', loaders[0])
trainer.bind_loader('validate', loaders[1])
if USE_CUDA:
trainer.cuda()
# do the training
trainer.fit()
return trainer
trainer = train_model(model=model_a, loaders=[train_loader, validate_loader], save_dir='model_a', lr=0.01)
##############################################################################
# Prediction
# ----------------------------
# The trainer contains the trained model, which we can use for prediction.
# We use :code:`unwrap` to convert the results to numpy arrays.
# Since we want to make many predictions, we encapsulate the
# prediction in a function
from inferno.utils.torch_utils import unwrap
def predict(trainer, test_loader, save_dir=None):
    trainer.eval_mode()
    for image, target in test_loader:
        # transfer image to gpu
        image = image.cuda() if USE_CUDA else image
        # predict the whole batch at once and squash the logits to probabilities
        prediction = trainer.apply_model(image)
        prediction = torch.sigmoid(prediction)
        # convert everything to numpy before plotting
        image = unwrap(image, as_numpy=True, to_cpu=True)
        prediction = unwrap(prediction, as_numpy=True, to_cpu=True)
        target = unwrap(target, as_numpy=True, to_cpu=True)
        # get batch size from image
        batch_size = image.shape[0]
        for b in range(batch_size):
fig = plt.figure()
ax = fig.add_subplot(2, 2, 1)
ax.imshow(image[b,0,...])
ax.set_title('raw data')
ax = fig.add_subplot(2, 2, 2)
ax.imshow(target[b,...])
ax.set_title('ground truth')
ax = fig.add_subplot(2, 2, 4)
ax.imshow(prediction[b,...])
ax.set_title('prediction')
fig.tight_layout()
plt.show()
###################################################
# do the prediction
predict(trainer=trainer, test_loader=test_loader)
##############################################################################
# Custom UNet
# ----------------------------
# Often one needs to have a UNet with custom layers.
# Here we show how to implement such a customized UNet.
# To this end we derive from :code:`UNetBase`.
# For the sake of this example we will create
# a rather exotic UNet which uses different types
# of convolutions/non-linearities in the different branches
# of the UNet.
from inferno.extensions.models import UNetBase
from inferno.extensions.layers import ConvSELU2D, ConvReLU2D, ConvELU2D, ConvSigmoid2D,Conv2D
from inferno.extensions.layers.sampling import Upsample
class MySimple2DUnet(UNetBase):
def __init__(self, in_channels, out_channels, depth=3, **kwargs):
super(MySimple2DUnet, self).__init__(in_channels=in_channels, out_channels=out_channels,
dim=2, depth=depth, **kwargs)
def conv_op_factory(self, in_channels, out_channels, part, index):
if part == 'down':
return torch.nn.Sequential(
ConvELU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=3),
ConvELU2D(in_channels=out_channels, out_channels=out_channels, kernel_size=3)
), False
elif part == 'bottom':
return torch.nn.Sequential(
ConvReLU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=3),
ConvReLU2D(in_channels=out_channels, out_channels=out_channels, kernel_size=3),
), False
elif part == 'up':
# are we in the very last block?
if index == 0:
return torch.nn.Sequential(
ConvELU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=3),
Conv2D(in_channels=out_channels, out_channels=out_channels, kernel_size=3)
), False
else:
return torch.nn.Sequential(
ConvELU2D(in_channels=in_channels, out_channels=out_channels, kernel_size=3),
ConvReLU2D(in_channels=out_channels, out_channels=out_channels, kernel_size=3)
), False
else:
            raise RuntimeError("unexpected UNet part '{}'".format(part))
# this function CAN be implemented, if not, MaxPooling is used by default
def downsample_op_factory(self, index):
return torch.nn.MaxPool2d(kernel_size=2, stride=2)
# this function CAN be implemented, if not, Upsampling is used by default
def upsample_op_factory(self, index):
        return Upsample(mode='bilinear', align_corners=False, scale_factor=2)
model_b = torch.nn.Sequential(
ConvReLU2D(in_channels=image_channels, out_channels=5, kernel_size=3),
    MySimple2DUnet(in_channels=5, out_channels=pred_channels),
RemoveSingletonDimension(dim=1)
)
###################################################
# do the training (with the same functions as before)
trainer = train_model(model=model_b, loaders=[train_loader, validate_loader], save_dir='model_b', lr=0.001)
###################################################
# do the prediction (with the same function as before)
predict(trainer=trainer, test_loader=test_loader)
================================================
FILE: examples/regularized_mnist.py
================================================
"""
Regularized MNIST Example
================================
This example demonstrates adding and logging arbitrary regularization losses, in this case,
L2 activity regularization and L1 weight regularization.
- Add a `_losses` dictionary to any module containing loss names and values
- Use a criterion from `inferno.extensions.criteria.regularized` that will collect and add those losses
- Call `Trainer.observe_training_and_validation_states` to log the losses as well
"""
import argparse
import sys
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from inferno.extensions.layers.reshape import Flatten
from inferno.trainers.basic import Trainer
from inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger
class RegularizedLinear(nn.Linear):
def __init__(self, *args, ar_weight=1e-3, l1_weight=1e-3, **kwargs):
super(RegularizedLinear, self).__init__(*args, **kwargs)
self.ar_weight = ar_weight
self.l1_weight = l1_weight
self._losses = {}
def forward(self, input):
output = super(RegularizedLinear, self).forward(input)
self._losses['activity_regularization'] = (output * output).sum() * self.ar_weight
self._losses['l1_weight_regularization'] = torch.abs(self.weight).sum() * self.l1_weight
return output
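The `_losses` mechanism can be illustrated without torch: a regularized criterion walks over the model's modules and sums whatever each one stored in its `_losses` dictionary. This is a hedged sketch of the idea only; `FakeModule` and `total_regularization_loss` are illustrative names, not inferno's actual implementation.

```python
class FakeModule:
    """Stands in for an nn.Module that recorded regularization terms."""
    def __init__(self, losses):
        self._losses = losses

def total_regularization_loss(modules):
    # Sum every named loss found in any module's _losses dict.
    return sum(value
               for module in modules
               for value in getattr(module, '_losses', {}).values())

modules = [FakeModule({'activity_regularization': 0.5, 'l1_weight_regularization': 0.25}),
           FakeModule({'l1_weight_regularization': 0.25})]
print(total_regularization_loss(modules))  # 1.0
```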
def model_fn():
return nn.Sequential(
Flatten(),
RegularizedLinear(in_features=784, out_features=256),
nn.LeakyReLU(),
RegularizedLinear(in_features=256, out_features=128),
nn.LeakyReLU(),
RegularizedLinear(in_features=128, out_features=10)
)
def mnist_data_loaders(args):
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('./data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('./data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.test_batch_size, shuffle=True, **kwargs)
return train_loader, test_loader
def train_model(args):
model = model_fn()
train_loader, validate_loader = mnist_data_loaders(args)
# Build trainer
trainer = Trainer(model) \
.build_criterion('RegularizedCrossEntropyLoss') \
.build_metric('CategoricalError') \
.build_optimizer('Adam') \
.validate_every((1, 'epochs')) \
.save_every((1, 'epochs')) \
.save_to_directory(args.save_directory) \
.set_max_num_epochs(args.epochs) \
.build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'),
log_images_every='never'),
log_directory=args.save_directory)
# Record regularization losses
trainer.logger.observe_training_and_validation_states([
'main_loss',
'total_regularization_loss',
'activity_regularization',
'l1_weight_regularization'
])
# Bind loaders
trainer \
.bind_loader('train', train_loader) \
.bind_loader('validate', validate_loader)
if args.cuda:
trainer.cuda()
# Go!
trainer.fit()
def main(argv):
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--save-directory', type=str, default='output/mnist/v1',
help='output directory')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=20, metavar='N',
help='number of epochs to train (default: 20)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
args = parser.parse_args(argv)
args.cuda = not args.no_cuda and torch.cuda.is_available()
train_model(args)
if __name__ == '__main__':
main(sys.argv[1:])
================================================
FILE: examples/trainer.py
================================================
"""
Trainer Example
================================
This example should illustrate how to use the trainer class.
"""
import torch.nn as nn
from inferno.io.box.cifar import get_cifar10_loaders
from inferno.trainers.basic import Trainer
from inferno.trainers.callbacks.logging.tensorboard import TensorboardLogger
from inferno.extensions.layers import ConvELU2D
from inferno.extensions.layers import Flatten
from inferno.utils.python_utils import ensure_dir
from inferno.extensions.layers import SELU
##################################################
# change directories to your needs
LOG_DIRECTORY = ensure_dir('log')
SAVE_DIRECTORY = ensure_dir('save')
DATASET_DIRECTORY = ensure_dir('dataset')
##################################################
# shall the dataset be downloaded
DOWNLOAD_CIFAR = True
USE_CUDA = True
##################################################
# Build torch model
model = nn.Sequential(
ConvELU2D(in_channels=3, out_channels=256, kernel_size=3),
nn.MaxPool2d(kernel_size=2, stride=2),
ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),
nn.MaxPool2d(kernel_size=2, stride=2),
ConvELU2D(in_channels=256, out_channels=256, kernel_size=3),
nn.MaxPool2d(kernel_size=2, stride=2),
Flatten(),
    # no softmax here: CrossEntropyLoss expects raw (unnormalized) logits
    nn.Linear(in_features=(256 * 4 * 4), out_features=10)
)
##################################################
# data loaders
train_loader, validate_loader = get_cifar10_loaders(DATASET_DIRECTORY,
download=DOWNLOAD_CIFAR)
##################################################
# Build trainer
trainer = Trainer(model)
trainer.build_criterion('CrossEntropyLoss')
trainer.build_metric('CategoricalError')
trainer.build_optimizer('Adam')
trainer.validate_every((2, 'epochs'))
trainer.save_every((5, 'epochs'))
trainer.save_to_directory(SAVE_DIRECTORY)
trainer.set_max_num_epochs(10)
trainer.build_logger(TensorboardLogger(log_scalars_every=(1, 'iteration'),
log_images_every='never'),
log_directory=LOG_DIRECTORY)
##################################################
# Bind loaders
trainer.bind_loader('train', train_loader)
trainer.bind_loader('validate', validate_loader)
##################################################
# activate cuda
if USE_CUDA:
trainer.cuda()
##################################################
# fit
trainer.fit()
================================================
FILE: inferno/__init__.py
================================================
# -*- coding: utf-8 -*-
"""Top-level package for inferno."""
from . import extensions
from . import io
from . import trainers
from . import utils
from .version import __version__
__all__ = ['extensions', 'io', 'trainers', 'utils']
__author__ = """Nasim Rahaman"""
__email__ = 'nasim.rahaman@iwr.uni-heidelberg.de'
================================================
FILE: inferno/extensions/__init__.py
================================================
from . import containers
from . import criteria
from . import initializers
from . import layers
from . import metrics
from . import optimizers
from . import models
# Backward support
from . import models as model
__all__ = ['containers', 'criteria', 'initializers', 'layers', 'metrics', 'optimizers',
'models', 'model']
================================================
FILE: inferno/extensions/containers/__init__.py
================================================
from .graph import *
from .sequential import *
================================================
FILE: inferno/extensions/containers/graph.py
================================================
from collections import OrderedDict
import sys
import threading
import multiprocessing as mp
import copy
import gc
import networkx as nx
from networkx import is_directed_acyclic_graph, topological_sort
from torch import nn as nn
from ...utils import python_utils as pyu
from ...utils.exceptions import assert_
from ..layers.device import OnDevice
from ..layers.identity import Identity
__all__ = ['NNGraph', 'Graph']
class NNGraph(nx.DiGraph):
"""A NetworkX DiGraph, except that node and edge ordering matters."""
    # We don't copy torch tensors (edge payloads), since they would only be deleted anyway.
ATTRIBUTES_TO_NOT_COPY = {'payload'}
node_dict_factory = OrderedDict
adjlist_dict_factory = OrderedDict
def copy(self, **init_kwargs):
new = type(self)(**init_kwargs)
# Remove all attributes and copy only the graph structure
        # networkx 2.x: edges() replaces the removed edges_iter()
        for source, target in self.edges():
# Add new nodes
new.add_node(source)
new.add_node(target)
# Copy attributes
new.node[source].update(copy.deepcopy({key: value
for key, value in self.node[source].items()
if key not in self.ATTRIBUTES_TO_NOT_COPY}))
new.node[target].update(copy.deepcopy({key: value
for key, value in self.node[target].items()
if key not in self.ATTRIBUTES_TO_NOT_COPY}))
# Add new edge
new.add_edge(copy.deepcopy(source), copy.deepcopy(target))
old_edge_attributes = self[source][target]
new_edge_attributes = {key: value for key, value in old_edge_attributes.items()
if key not in self.ATTRIBUTES_TO_NOT_COPY}
new_edge_attributes = copy.deepcopy(new_edge_attributes)
new[source][target].update(new_edge_attributes)
return new
class Graph(nn.Module):
"""
A graph structure to build networks with complex architectures. The resulting graph model
can be used like any other `torch.nn.Module`. The graph structure used behind the scenes
is a `networkx.DiGraph`. This internal graph is exposed by the `apply_on_graph` method,
which can be used with any NetworkX function (e.g. for plotting with matplotlib or GraphViz).
Examples
--------
    The naive inception module (without the max-pooling for simplicity) with ELU-layers of 64 units
    can be built as follows (assuming 64 input channels):
>>> from inferno.extensions.layers.reshape import Concatenate
>>> from inferno.extensions.layers.convolutional import ConvELU2D
>>> import torch
>>> # Build the model
>>> inception_module = Graph()
>>> inception_module.add_input_node('input')
    >>> inception_module.add_node('conv1x1', ConvELU2D(64, 64, 1), previous='input')
    >>> inception_module.add_node('conv3x3', ConvELU2D(64, 64, 3), previous='input')
    >>> inception_module.add_node('conv5x5', ConvELU2D(64, 64, 5), previous='input')
>>> inception_module.add_node('cat', Concatenate(),
>>> previous=['conv1x1', 'conv3x3', 'conv5x5'])
>>> inception_module.add_output_node('output', 'cat')
>>> # Build dummy variable
>>> input = torch.rand(1, 64, 100, 100)
>>> # Get output
>>> output = inception_module(input)
"""
def __init__(self, graph=None):
"""
Construct the graph object.
Parameters
----------
graph : networkx.DiGraph or NNGraph
Graph to build the object from (optional).
"""
super(Graph, self).__init__()
# Privates
self._thread_to_graph_mapping = {}
self._creator_thread = threading.get_ident()
self._creator_pid = mp.current_process().pid
# Publics
if graph is not None:
self.graph = graph
else:
self.graph = NNGraph()
@property
def graph(self):
# `graph` needs to be different for every thread, because torch.nn.parallel.replicate does
# not make a copy.
graph = self._thread_to_graph_mapping.get(threading.get_ident())
if graph is None:
creator_thread_graph = self._thread_to_graph_mapping.get(self._creator_thread)
assert creator_thread_graph is not None
graph = creator_thread_graph.copy()
# We don't need to clear payloads because the copy method of NNGraph copies only the
# graph structure and not the attributes
self._thread_to_graph_mapping.update({threading.get_ident(): graph})
return graph
@graph.setter
def graph(self, value):
assert_(isinstance(value, NNGraph), exception_type=TypeError)
self._thread_to_graph_mapping.update({threading.get_ident(): value})
def is_node_in_graph(self, name):
"""
Checks whether a node is in the graph.
Parameters
----------
name : str
Name of the node.
Returns
-------
bool
"""
return name in self.graph.nodes
def is_source_node(self, name):
"""
Checks whether a given node (by name) is a source node.
A source node has no incoming edges.
Parameters
----------
name : str
Name of the node.
Returns
-------
bool
Raises
------
AssertionError
if node is not found in the graph.
"""
assert self.is_node_in_graph(name)
return self.graph.in_degree(name) == 0
def is_sink_node(self, name):
"""
Checks whether a given node (by name) is a sink node.
A sink node has no outgoing edges.
Parameters
----------
name : str
Name of the node.
Returns
-------
bool
Raises
------
AssertionError
if node is not found in the graph.
"""
assert self.is_node_in_graph(name)
return self.graph.out_degree(name) == 0
@property
def output_nodes(self):
"""
Gets a list of output nodes. The order is relevant and is the same as that
in which the forward method returns its outputs.
Returns
-------
list
A list of names (str) of the output nodes.
"""
return [name for name, node_attributes in self.graph.nodes.items()
if node_attributes.get('is_output_node', False)]
@property
def input_nodes(self):
"""
Gets a list of input nodes. The order is relevant and is the same as that
in which the forward method accepts its inputs.
Returns
-------
list
A list of names (str) of the input nodes.
"""
return [name for name, node_attributes in self.graph.nodes.items()
if node_attributes.get('is_input_node', False)]
@property
def graph_is_valid(self):
"""Checks if the graph is valid."""
# Check if the graph is a DAG
is_dag = is_directed_acyclic_graph(self.graph)
# Check if output nodes are sinks
output_nodes_are_sinks = all([self.is_sink_node(name) for name in self.output_nodes])
        # Check if input nodes are sources
input_nodes_are_sources = all([self.is_source_node(name) for name in self.input_nodes])
# TODO Check whether only input nodes are sources and only output nodes are sinks
# Conclude
is_valid = is_dag and output_nodes_are_sinks and input_nodes_are_sources
return is_valid
def assert_graph_is_valid(self):
"""Asserts that the graph is valid."""
assert is_directed_acyclic_graph(self.graph), "Graph is not a DAG."
for name in self.output_nodes:
assert self.is_sink_node(name), "Output node {} is not a sink.".format(name)
assert not self.is_source_node(name), "Output node {} is a source node. " \
"Make sure it's connected.".format(name)
for name in self.input_nodes:
assert self.is_source_node(name), "Input node {} is not a source.".format(name)
assert not self.is_sink_node(name), "Input node {} is a sink node. " \
"Make sure it's connected.".format(name)
def add_node(self, name, module, previous=None):
"""
Add a node to the graph.
Parameters
----------
name : str
Name of the node. Nodes are identified by their names.
module : torch.nn.Module
Torch module for this node.
previous : str or list of str
(List of) name(s) of the previous node(s).
Returns
-------
Graph
self
"""
assert isinstance(module, nn.Module)
self.add_module(name, module)
self.graph.add_node(name)
if previous is not None:
for _previous in pyu.to_iterable(previous):
self.add_edge(_previous, name)
return self
def add_input_node(self, name):
"""
Add an input to the graph. The order in which input nodes are added is the
order in which the forward method accepts its inputs.
Parameters
----------
name : str
Name of the input node.
Returns
-------
Graph
self
"""
self.add_module(name, Identity())
self.graph.add_node(name, is_input_node=True)
return self
def add_output_node(self, name, previous=None):
"""
Add an output to the graph. The order in which output nodes are added is the
order in which the forward method returns its outputs.
Parameters
----------
name : str
Name of the output node.
Returns
-------
Graph
self
"""
self.graph.add_node(name, is_output_node=True)
if previous is not None:
for _previous in pyu.to_iterable(previous):
self.add_edge(_previous, name)
return self
def add_edge(self, from_node, to_node):
"""
Add an edge between two nodes.
Parameters
----------
from_node : str
Name of the source node.
to_node : str
Name of the target node.
Returns
-------
Graph
self
Raises
------
AssertionError
if either of the two nodes is not in the graph,
or if the edge is not 'legal'.
"""
assert self.is_node_in_graph(from_node)
assert self.is_node_in_graph(to_node)
self.graph.add_edge(from_node, to_node)
assert self.graph_is_valid
return self
def apply_on_graph(self, function, *args, **kwargs):
"""Applies a `function` on the internal graph."""
return function(self, *args, **kwargs)
def get_module_for_nodes(self, names):
"""
Gets the `torch.nn.Module` object for nodes corresponding to `names`.
Parameters
----------
names : str or list of str
Names of the nodes to fetch the modules of.
Returns
-------
list or torch.nn.Module
Module or a list of modules corresponding to `names`.
"""
names = pyu.to_iterable(names)
modules = []
for name in names:
assert self.is_node_in_graph(name), "Node '{}' is not in graph.".format(name)
module = getattr(self, name, None)
assert module is not None, "Node '{}' is in the graph but could not find a module " \
"corresponding to it.".format(name)
modules.append(module)
return pyu.from_iterable(modules)
def to_device(self, names, target_device, device_ordinal=None, asynchronous=False):
"""Transfer nodes in the network to a specified device."""
names = pyu.to_iterable(names)
for name in names:
assert self.is_node_in_graph(name), "Node '{}' is not in graph.".format(name)
module = getattr(self, name, None)
assert module is not None, "Node '{}' is in the graph but could not find a module " \
"corresponding to it.".format(name)
# Transfer
module_on_device = OnDevice(module, target_device,
device_ordinal=device_ordinal,
asynchronous=asynchronous)
setattr(self, name, module_on_device)
return self
def get_parameters_for_nodes(self, names, named=False):
"""Get parameters of all nodes listed in `names`."""
if not named:
parameters = (parameter
for module in pyu.to_iterable(self.get_module_for_nodes(names))
for parameter in module.parameters())
else:
parameters = ((name, parameter)
for module in pyu.to_iterable(self.get_module_for_nodes(names))
for name, parameter in module.named_parameters())
return parameters
def clear_payloads(self, graph=None):
graph = self.graph if graph is None else graph
for edge in list(graph.edges(data=True)):
source, target, _ = edge
if 'payload' in graph[source][target]:
del graph[source][target]['payload']
def forward_through_node(self, name, input=None):
# If input is a tuple/list, it will NOT be unpacked.
# Make sure the node is in the graph
if input is None:
# Make sure the node is not a source node
assert not self.is_source_node(name), \
"Node '{}' did not get an input but is a source node.".format(name)
# Get input from payload
incoming_edges = self.graph.in_edges(name)
input = []
for incoming, this in incoming_edges:
# Append to input
input.append(self.graph[incoming][this]['payload'])
# Clear reference for the garbage collector to do its thing
del self.graph[incoming][this]['payload']
else:
assert self.is_node_in_graph(name)
# Convert input to list
input = [input]
# Get outputs
try:
outputs = pyu.to_iterable(getattr(self, name)(*input))
except Exception as e:
input_spec_string = "\n".join(["--[{}]-{}-->[{}]".format(incoming,
tuple(_input.size()),
this)
for (incoming, this), _input in
zip(self.graph.in_edges(name), input)])
message = "In node '{}': {}\n" \
"Inputs to this node were:\n{}"\
.format(name, str(e), input_spec_string)
raise type(e)(message).with_traceback(sys.exc_info()[2])
# Distribute outputs to outgoing payloads if required
if not self.is_sink_node(name):
outgoing_edges = self.graph.out_edges(name)
if len(outputs) == 1:
# Support for replication
outputs *= len(outgoing_edges)
# Make sure the number of outputs check out
assert len(outputs) == len(outgoing_edges), \
"Number of outputs from the model ({}) does not match the number " \
"of out-edges ({}) in the graph for this node ('{}').".format(len(outputs),
len(outgoing_edges),
name)
for (this, outgoing), output in zip(outgoing_edges, outputs):
self.graph[this][outgoing].update({'payload': output})
# Collect garbage to free some GPU memory?
del input
gc.collect()
# Return outputs
return pyu.from_iterable(outputs)
def forward(self, *inputs):
self.assert_graph_is_valid()
input_nodes = self.input_nodes
output_nodes = self.output_nodes
assert len(inputs) == len(input_nodes), "Was expecting {} " \
"arguments for as many input nodes, got {}."\
.format(len(input_nodes), len(inputs))
# Unpack inputs to input nodes
for input, input_node in zip(inputs, input_nodes):
self.forward_through_node(input_node, input=input)
# Toposort the graph
toposorted = topological_sort(self.graph)
# Remove all input and output nodes
toposorted = [name for name in toposorted
if name not in input_nodes and name not in output_nodes]
# Since we'll be clearing payloads anyway, it makes no sense whatsoever
# to evaluate sink nodes
toposorted = [name for name in toposorted if not self.is_sink_node(name)]
# Forward
for node in toposorted:
self.forward_through_node(node)
# Read outputs from output nodes
outputs = []
for output_node in output_nodes:
# Get all incoming edges to output node
outputs_from_node = [self.graph[incoming][this]['payload']
for incoming, this in self.graph.in_edges(output_node)]
outputs.append(pyu.from_iterable(outputs_from_node))
# Clear payloads for next pass
self.clear_payloads()
# Done.
return pyu.from_iterable(outputs)
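The forward pass above — topologically sort the DAG, then push each node's output along its outgoing edges as a "payload" — can be sketched with a minimal pure-Python DAG. The node names and the `forward` helper here are hypothetical, not inferno's API:

```python
from graphlib import TopologicalSorter  # stdlib, Python >= 3.9

# Toy DAG: node -> set of predecessors (the orientation TopologicalSorter expects).
predecessors = {'input': set(), 'double': {'input'},
                'relu': {'double'}, 'output': {'relu'}}
ops = {'double': lambda v: 2 * v, 'relu': lambda v: max(v, 0)}

def forward(x):
    payloads = {'input': x}                   # node name -> computed output
    for node in TopologicalSorter(predecessors).static_order():
        if node in ops:                       # input/output nodes just pass values through
            (pred,) = predecessors[node]      # single-input nodes, for simplicity
            payloads[node] = ops[node](payloads[pred])
    (pred,) = predecessors['output']
    return payloads[pred]                     # the output node reads its incoming payload

print(forward(4), forward(-3))  # 8 0
```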
================================================
FILE: inferno/extensions/containers/sequential.py
================================================
import torch.nn as nn
from ...utils import python_utils as pyu
__all__ = ['Sequential1', 'Sequential2']
class Sequential1(nn.Sequential):
"""Like torch.nn.Sequential, but with a few extra methods."""
def __len__(self):
return len(self._modules.values())
class Sequential2(Sequential1):
"""Another sequential container.
    Identical to torch.nn.Sequential, except that modules may return multiple outputs and
accept multiple inputs.
"""
def forward(self, *input):
for module in self._modules.values():
input = pyu.to_iterable(module(*pyu.to_iterable(input)))
return pyu.from_iterable(input)
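The packing trick Sequential2 relies on can be sketched in plain Python: wrap every intermediate result in a tuple so each callable can accept and return multiple values. The `to_iterable`/`from_iterable` stand-ins below mirror the names in `python_utils` but are not copied from it:

```python
def to_iterable(x):
    return x if isinstance(x, (list, tuple)) else (x,)

def from_iterable(x):
    return x[0] if len(x) == 1 else x

def sequential2(modules, *inputs):
    # Feed each module's (possibly multiple) outputs to the next module.
    for module in modules:
        inputs = to_iterable(module(*to_iterable(inputs)))
    return from_iterable(inputs)

def split(x):
    return x, x + 1   # one input, two outputs

def add(a, b):
    return a + b      # two inputs, one output

print(sequential2([split, add], 3))  # (3, 4) -> 7
```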
================================================
FILE: inferno/extensions/criteria/__init__.py
================================================
from .set_similarity_measures import *
from .elementwise_measures import *
from .core import *
from .regularized import *
__all__ = ['set_similarity_measures', 'elementwise_measures', 'core', 'regularized']
================================================
FILE: inferno/extensions/criteria/core.py
================================================
import torch.nn as nn
from functools import reduce
from ...utils.exceptions import assert_, ShapeError, NotTorchModuleError
__all__ = ['Criteria', 'As2DCriterion']
class Criteria(nn.Module):
"""Aggregate multiple criteria to one."""
def __init__(self, *criteria):
super(Criteria, self).__init__()
if len(criteria) == 1 and isinstance(criteria[0], (list, tuple)):
criteria = list(criteria[0])
else:
criteria = list(criteria)
# Validate criteria
assert all([isinstance(criterion, nn.Module) for criterion in criteria]), \
"Criterion must be a torch module."
        # Use a ModuleList so the criteria's parameters (if any) are registered
        self.criteria = nn.ModuleList(criteria)
def forward(self, prediction, target):
assert isinstance(prediction, (list, tuple)), \
"`prediction` must be a list or a tuple, got {} instead."\
.format(type(prediction).__name__)
        assert isinstance(target, (list, tuple)), \
            "`target` must be a list or a tuple, got {} instead." \
            .format(type(target).__name__)
assert len(prediction) == len(target), \
"Number of predictions must equal the number of targets. " \
"Got {} predictions but {} targets.".format(len(prediction), len(target))
        # Compute losses, pairing each criterion with its prediction and target
        losses = [criterion(_prediction, _target)
                  for _prediction, _target, criterion in zip(prediction, target, self.criteria)]
        # Aggregate losses
loss = reduce(lambda x, y: x + y, losses)
# Done
return loss
class As2DCriterion(nn.Module):
"""
Makes a given criterion applicable on (N, C, H, W) prediction and (N, H, W) target tensors,
if they're applicable to (N, C) prediction and (N,) target tensors .
"""
def __init__(self, criterion):
super(As2DCriterion, self).__init__()
assert_(isinstance(criterion, nn.Module),
"Criterion must be a module, got a {} instead."
.format(type(criterion).__name__),
NotTorchModuleError)
self.criterion = criterion
def forward(self, prediction, target):
# Validate input
assert_(prediction.dim() == 4, "`prediction` is expected to be a 4D tensor of shape "
"(N, C, H, W), got a {}D "
"tensor instead.".format(prediction.dim()),
ShapeError)
assert_(target.dim() == 3, "`target` is expected to be a 3D tensor of shape "
"(N, H, W), got a {}D "
"tensor instead.".format(target.dim()),
ShapeError)
# prediction is assumed to be NCHW, and target NHW.
# this makes target (NHW,)
target = target.contiguous().view(-1)
# This makes prediction (N, H, W, C) --> (NHW, C)
num_channels = prediction.size(1)
prediction = prediction.permute(0, 2, 3, 1).contiguous().view(-1, num_channels)
# Now, the criterion should be applicable as is
loss = self.criterion(prediction, target)
return loss
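The permute-and-flatten step in `As2DCriterion.forward` turns an (N, C, H, W) prediction into an (N·H·W, C) matrix where row i holds the C class scores of pixel i, aligned with the flattened (N·H·W,) target. A minimal pure-Python sketch of that reshaping on nested lists (toy shapes, no torch; purely illustrative):

```python
def flatten_as_2d(prediction, target):
    """Mimic As2DCriterion's reshaping on nested lists.

    prediction: nested list of shape (N, C, H, W)
    target:     nested list of shape (N, H, W)
    Returns (rows, flat_target) where rows has shape (N*H*W, C).
    """
    rows, flat_target = [], []
    num_channels = len(prediction[0])
    for n, sample in enumerate(prediction):
        height = len(sample[0])
        width = len(sample[0][0])
        for h in range(height):
            for w in range(width):
                # (N, C, H, W) -> permute to (N, H, W, C) -> flatten
                rows.append([sample[c][h][w] for c in range(num_channels)])
                flat_target.append(target[n][h][w])
    return rows, flat_target

# Toy example: N=1, C=2, H=W=2
pred = [[[[0.1, 0.2], [0.3, 0.4]],   # channel 0
         [[0.9, 0.8], [0.7, 0.6]]]]  # channel 1
tgt = [[[1, 1], [0, 0]]]
rows, flat = flatten_as_2d(pred, tgt)
```

After this reshaping, any (N, C)-vs-(N,) criterion such as `nn.NLLLoss` applies directly.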
================================================
FILE: inferno/extensions/criteria/elementwise_measures.py
================================================
import torch.nn as nn
from ...utils.exceptions import assert_
class WeightedMSELoss(nn.Module):
NEGATIVE_CLASS_WEIGHT = 1.
def __init__(self, positive_class_weight=1., positive_class_value=1., size_average=True):
super(WeightedMSELoss, self).__init__()
assert_(positive_class_weight >= 0,
"Positive class weight can't be less than zero, got {}."
.format(positive_class_weight),
ValueError)
self.mse = nn.MSELoss(size_average=size_average)
self.positive_class_weight = positive_class_weight
self.positive_class_value = positive_class_value
def forward(self, input, target):
# Get a mask
positive_class_mask = target.data.eq(self.positive_class_value).type_as(target.data)
# Get differential weights (positive_weight - negative_weight,
# i.e. subtract 1, assuming the negative weight is gauged at 1)
weight_differential = (positive_class_mask
.mul_(self.positive_class_weight - self.NEGATIVE_CLASS_WEIGHT))
# Get final weight by adding weight differential to a tensor with negative weights
weights = weight_differential.add_(self.NEGATIVE_CLASS_WEIGHT)
# `weights` should be positive if NEGATIVE_CLASS_WEIGHT is not messed with.
sqrt_weights = weights.sqrt_()
return self.mse(input * sqrt_weights, target * sqrt_weights)
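The square-root trick in `WeightedMSELoss` relies on the identity w·(x − y)² = (√w·x − √w·y)², so per-element weights can be folded into the inputs of an unweighted MSE. A quick pure-Python check of that identity (illustrative only, plain floats instead of tensors):

```python
import math

def weighted_mse(xs, ys, ws):
    """Directly weighted mean squared error."""
    return sum(w * (x - y) ** 2 for x, y, w in zip(xs, ys, ws)) / len(xs)

def sqrt_trick_mse(xs, ys, ws):
    """Unweighted MSE on sqrt(w)-scaled inputs, as WeightedMSELoss does."""
    sx = [math.sqrt(w) * x for x, w in zip(xs, ws)]
    sy = [math.sqrt(w) * y for y, w in zip(ys, ws)]
    return sum((a - b) ** 2 for a, b in zip(sx, sy)) / len(xs)

xs, ys, ws = [1.0, 2.0, 3.0], [0.5, 2.5, 2.0], [2.0, 1.0, 4.0]
```

The two computations agree up to floating-point error, which is why the module can delegate to a stock `nn.MSELoss`.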
================================================
FILE: inferno/extensions/criteria/regularized.py
================================================
import warnings
import torch
from torch import nn
from . import set_similarity_measures, core
__all__ = [
'RegularizedLoss',
'RegularizedCrossEntropyLoss',
'RegularizedBCEWithLogitsLoss',
'RegularizedBCELoss',
'RegularizedMSELoss',
'RegularizedNLLLoss'
]
def collect_losses(module):
"""Collect `_losses` dictionaries from module and children
:param module: a Module to be searched for losses
:return: dictionary of loss names to values
"""
losses = {}
def _collect(m):
if hasattr(m, '_losses'):
for k, v in m._losses.items():
if k in losses:
losses[k] = losses[k] + v
else:
losses[k] = v
module.apply(_collect)
return losses
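The merge performed by `_collect` is a plain dictionary sum: losses sharing a key across submodules are added, distinct keys are kept separately. A sketch of that merge semantics on plain dicts (hypothetical loss names, no torch):

```python
def merge_losses(loss_dicts):
    """Combine per-module `_losses` dicts the way collect_losses does:
    an existing key is summed into, a new key is inserted."""
    losses = {}
    for d in loss_dicts:
        for k, v in d.items():
            losses[k] = losses[k] + v if k in losses else v
    return losses

# Two submodules both report an 'l2' penalty; one also reports 'l1'.
merged = merge_losses([{'l2': 0.1}, {'l2': 0.2, 'l1': 0.5}])
```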
def build_criterion(criterion, *args, **kwargs):
"""Build a criterion
:param criterion: criterion class, name of criterion class, or instance of criterion
:param args: args for constructor
:param kwargs: kwargs for constructor
:return: instance of criterion
"""
if isinstance(criterion, str):
for module in [nn, core, set_similarity_measures]:
criterion_class = getattr(module, criterion, None)
if criterion_class is not None:
break
assert criterion_class is not None, "Criterion {} not found.".format(criterion)
elif callable(criterion) and isinstance(criterion, type):
criterion_class = criterion
elif isinstance(criterion, torch.nn.Module):
return criterion
else:
raise NotImplementedError
return criterion_class(*args, **kwargs)
class RegularizedLoss(nn.Module):
"""Wrap a criterion. Collect regularization losses from model and combine with wrapped criterion.
"""
def __init__(self, criterion, *args, **kwargs):
super(RegularizedLoss, self).__init__()
self.criterion = build_criterion(criterion, *args, **kwargs)
    def forward(self, *args, trainer=None, model=None, **kwargs):
        # Calculate the wrapped (main) loss
        main_loss = self.criterion(*args, **kwargs)
# If no trainer, we cannot record states
if trainer is None:
warnings.warn('No trainer parameter provided. Not logging regularization losses.')
elif model is None:
model = trainer.model
# If no model or trainer, we cannot record states or collect losses
if model is None:
warnings.warn('No model or trainer parameter provided. Not calculating regularization losses.')
regularization_losses = {}
total_regularization_loss = None
total_loss = main_loss
else:
regularization_losses = collect_losses(model)
total_regularization_loss = sum(regularization_losses.values())
total_loss = main_loss + total_regularization_loss
# Record losses if trainer provided
if trainer is not None:
# prefix depending on mode
if self.training:
prefix = 'training'
else:
prefix = 'validation'
# main loss
updates = {'{}_main_loss'.format(prefix): main_loss}
            # total regularization loss
if total_regularization_loss is not None:
updates['{}_total_regularization_loss'.format(prefix)] = total_regularization_loss
# detailed regularization losses
for k, v in regularization_losses.items():
updates['{}_{}'.format(prefix, k)] = v
# record state
trainer.update_state_from_dictionary(updates)
return total_loss
# Convenience wrappers for common losses
class RegularizedCrossEntropyLoss(RegularizedLoss):
def __init__(self, *args, **kwargs):
super(RegularizedCrossEntropyLoss, self).__init__(nn.CrossEntropyLoss, *args, **kwargs)
class RegularizedBCEWithLogitsLoss(RegularizedLoss):
def __init__(self, *args, **kwargs):
super(RegularizedBCEWithLogitsLoss, self).__init__(nn.BCEWithLogitsLoss, *args, **kwargs)
class RegularizedBCELoss(RegularizedLoss):
def __init__(self, *args, **kwargs):
super(RegularizedBCELoss, self).__init__(nn.BCELoss, *args, **kwargs)
class RegularizedMSELoss(RegularizedLoss):
def __init__(self, *args, **kwargs):
super(RegularizedMSELoss, self).__init__(nn.MSELoss, *args, **kwargs)
class RegularizedNLLLoss(RegularizedLoss):
def __init__(self, *args, **kwargs):
super(RegularizedNLLLoss, self).__init__(nn.NLLLoss, *args, **kwargs)
================================================
FILE: inferno/extensions/criteria/set_similarity_measures.py
================================================
import torch.nn as nn
from ...utils.torch_utils import flatten_samples
__all__ = ['SorensenDiceLoss', 'GeneralizedDiceLoss']
class SorensenDiceLoss(nn.Module):
"""
Computes a loss scalar, which when minimized maximizes the Sorensen-Dice similarity
between the input and the target. For both inputs and targets it must be the case that
`input_or_target.size(1) = num_channels`.
"""
def __init__(self, weight=None, channelwise=True, eps=1e-6):
"""
Parameters
----------
weight : torch.FloatTensor or torch.cuda.FloatTensor
Class weights. Applies only if `channelwise = True`.
        channelwise : bool
            Whether to apply the loss channelwise and sum the results (True)
            or to apply it on all channels jointly (False).
        eps : float
            Epsilon value used to clamp the denominator for numerical stability.
        """
super(SorensenDiceLoss, self).__init__()
self.register_buffer('weight', weight)
self.channelwise = channelwise
self.eps = eps
def forward(self, input, target):
"""
input: torch.FloatTensor or torch.cuda.FloatTensor
target: torch.FloatTensor or torch.cuda.FloatTensor
Expected shape of the inputs: (batch_size, nb_channels, ...)
"""
assert input.size() == target.size()
if not self.channelwise:
numerator = (input * target).sum()
denominator = (input * input).sum() + (target * target).sum()
loss = -2. * (numerator / denominator.clamp(min=self.eps))
else:
# TODO This should be compatible with Pytorch 0.2, but check
# Flatten input and target to have the shape (C, N),
# where N is the number of samples
input = flatten_samples(input)
target = flatten_samples(target)
# Compute numerator and denominator (by summing over samples and
# leaving the channels intact)
numerator = (input * target).sum(-1)
denominator = (input * input).sum(-1) + (target * target).sum(-1)
channelwise_loss = -2 * (numerator / denominator.clamp(min=self.eps))
if self.weight is not None:
# With pytorch < 0.2, channelwise_loss.size = (C, 1).
if channelwise_loss.dim() == 2:
channelwise_loss = channelwise_loss.squeeze(1)
assert self.weight.size() == channelwise_loss.size()
# Apply weight
channelwise_loss = self.weight * channelwise_loss
# Sum over the channels to compute the total loss
loss = channelwise_loss.sum()
return loss
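For reference, the non-channelwise branch computes loss = −2·⟨x, y⟩ / (⟨x, x⟩ + ⟨y, y⟩), which reaches its minimum of −1 for a perfect match. A pure-Python sketch of that formula on flat vectors (illustrative; the eps clamp mirrors the module):

```python
def sorensen_dice_loss(xs, ys, eps=1e-6):
    """Scalar Sorensen-Dice loss on flat sequences (non-channelwise case)."""
    numerator = sum(x * y for x, y in zip(xs, ys))
    denominator = sum(x * x for x in xs) + sum(y * y for y in ys)
    return -2.0 * numerator / max(denominator, eps)

perfect = sorensen_dice_loss([1.0, 0.0, 1.0], [1.0, 0.0, 1.0])   # -> -1.0
disjoint = sorensen_dice_loss([1.0, 0.0], [0.0, 1.0])            # -> 0.0
```

The channelwise branch applies the same expression per channel and sums, optionally scaling each channel by `weight` first.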
class GeneralizedDiceLoss(nn.Module):
"""
Computes the scalar Generalized Dice Loss defined in https://arxiv.org/abs/1707.03237
This version works for multiple classes and expects predictions for every class (e.g. softmax output) and
one-hot targets for every class.
"""
def __init__(self, weight=None, channelwise=False, eps=1e-6):
super(GeneralizedDiceLoss, self).__init__()
self.register_buffer('weight', weight)
self.channelwise = channelwise
self.eps = eps
def forward(self, input, target):
"""
input: torch.FloatTensor or torch.cuda.FloatTensor
target: torch.FloatTensor or torch.cuda.FloatTensor
Expected shape of the inputs:
- if not channelwise: (batch_size, nb_classes, ...)
- if channelwise: (batch_size, nb_channels, nb_classes, ...)
"""
assert input.size() == target.size()
if not self.channelwise:
# Flatten input and target to have the shape (nb_classes, N),
# where N is the number of samples
input = flatten_samples(input)
target = flatten_samples(target)
            # Find class weights:
            sum_targets = target.sum(-1)
            class_weights = 1. / (sum_targets * sum_targets).clamp(min=self.eps)
            # Compute generalized Dice loss:
            numer = ((input * target).sum(-1) * class_weights).sum()
            denom = ((input + target).sum(-1) * class_weights).sum()
            loss = 1. - 2. * numer / denom.clamp(min=self.eps)
else:
def flatten_and_preserve_channels(tensor):
tensor_dim = tensor.dim()
assert tensor_dim >= 3
num_channels = tensor.size(1)
num_classes = tensor.size(2)
# Permute the channel axis to first
permute_axes = list(range(tensor_dim))
permute_axes[0], permute_axes[1], permute_axes[2] = permute_axes[1], permute_axes[2], permute_axes[0]
permuted = tensor.permute(*permute_axes).contiguous()
flattened = permuted.view(num_channels, num_classes, -1)
return flattened
# Flatten input and target to have the shape (nb_channels, nb_classes, N)
input = flatten_and_preserve_channels(input)
target = flatten_and_preserve_channels(target)
            # Find class weights:
            sum_targets = target.sum(-1)
            class_weights = 1. / (sum_targets * sum_targets).clamp(min=self.eps)
            # Compute generalized Dice loss:
            numer = ((input * target).sum(-1) * class_weights).sum(-1)
            denom = ((input + target).sum(-1) * class_weights).sum(-1)
channelwise_loss = 1. - 2. * numer / denom.clamp(min=self.eps)
if self.weight is not None:
if channelwise_loss.dim() == 2:
channelwise_loss = channelwise_loss.squeeze(1)
assert self.weight.size() == channelwise_loss.size(),\
"""`weight` should have shape (nb_channels, ),
`target` should have shape (batch_size, nb_channels, nb_classes, ...)"""
# Apply channel weights:
channelwise_loss = self.weight * channelwise_loss
loss = channelwise_loss.sum()
return loss
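The non-channelwise branch above weights each class by w_c = 1 / (Σ t_c)² and computes loss = 1 − 2·(Σ_c w_c ⟨x_c, t_c⟩) / (Σ_c w_c Σ(x_c + t_c)), following the paper linked in the docstring. A pure-Python sketch with per-class flat vectors (illustrative; eps clamping as in the module):

```python
def generalized_dice_loss(inputs, targets, eps=1e-6):
    """inputs/targets: per-class flat lists of shape (nb_classes, N)."""
    numer = denom = 0.0
    for x_c, t_c in zip(inputs, targets):
        sum_t = sum(t_c)
        # Class weight: inverse squared target volume
        w_c = 1.0 / max(sum_t * sum_t, eps)
        numer += w_c * sum(x * t for x, t in zip(x_c, t_c))
        denom += w_c * sum(x + t for x, t in zip(x_c, t_c))
    return 1.0 - 2.0 * numer / max(denom, eps)

# Perfect one-hot prediction over two classes and three samples -> loss 0
pred = [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
loss = generalized_dice_loss(pred, pred)
```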
================================================
FILE: inferno/extensions/initializers/__init__.py
================================================
from .base import *
from .presets import *
================================================
FILE: inferno/extensions/initializers/base.py
================================================
import torch.nn.init as init
__all__ = ['Initializer',
'Initialization',
'WeightInitFunction',
'BiasInitFunction',
'TensorInitFunction']
class Initializer(object):
"""
Base class for all initializers.
"""
# TODO Support LSTMs and GRUs
VALID_LAYERS = {'Conv1d', 'Conv2d', 'Conv3d',
'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d',
'Linear', 'Bilinear',
'Embedding'}
def __call__(self, module):
module_class_name = module.__class__.__name__
if module_class_name in self.VALID_LAYERS:
# Apply to weight and bias
try:
if hasattr(module, 'weight'):
self.call_on_weight(module.weight.data)
except NotImplementedError:
# Don't cry if it's not implemented
pass
try:
if hasattr(module, 'bias'):
self.call_on_bias(module.bias.data)
except NotImplementedError:
pass
return module
def call_on_bias(self, tensor):
return self.call_on_tensor(tensor)
def call_on_weight(self, tensor):
return self.call_on_tensor(tensor)
def call_on_tensor(self, tensor):
raise NotImplementedError
@classmethod
def initializes_weight(cls):
return 'call_on_tensor' in cls.__dict__ or 'call_on_weight' in cls.__dict__
@classmethod
def initializes_bias(cls):
return 'call_on_tensor' in cls.__dict__ or 'call_on_bias' in cls.__dict__
class Initialization(Initializer):
def __init__(self, weight_initializer=None, bias_initializer=None):
if weight_initializer is None:
self.weight_initializer = Initializer()
else:
if isinstance(weight_initializer, Initializer):
assert weight_initializer.initializes_weight()
self.weight_initializer = weight_initializer
elif isinstance(weight_initializer, str):
init_function = getattr(init, weight_initializer, None)
assert init_function is not None
self.weight_initializer = WeightInitFunction(init_function=init_function)
else:
                # Provision for weight_initializer to be a function
assert callable(weight_initializer)
self.weight_initializer = WeightInitFunction(init_function=weight_initializer)
if bias_initializer is None:
self.bias_initializer = Initializer()
else:
if isinstance(bias_initializer, Initializer):
                assert bias_initializer.initializes_bias()
self.bias_initializer = bias_initializer
elif isinstance(bias_initializer, str):
init_function = getattr(init, bias_initializer, None)
assert init_function is not None
self.bias_initializer = BiasInitFunction(init_function=init_function)
else:
assert callable(bias_initializer)
self.bias_initializer = BiasInitFunction(init_function=bias_initializer)
def call_on_weight(self, tensor):
return self.weight_initializer.call_on_weight(tensor)
def call_on_bias(self, tensor):
return self.bias_initializer.call_on_bias(tensor)
class WeightInitFunction(Initializer):
def __init__(self, init_function, *init_function_args, **init_function_kwargs):
super(WeightInitFunction, self).__init__()
assert callable(init_function)
self.init_function = init_function
self.init_function_args = init_function_args
self.init_function_kwargs = init_function_kwargs
def call_on_weight(self, tensor):
return self.init_function(tensor, *self.init_function_args, **self.init_function_kwargs)
class BiasInitFunction(Initializer):
def __init__(self, init_function, *init_function_args, **init_function_kwargs):
super(BiasInitFunction, self).__init__()
assert callable(init_function)
self.init_function = init_function
self.init_function_args = init_function_args
self.init_function_kwargs = init_function_kwargs
def call_on_bias(self, tensor):
return self.init_function(tensor, *self.init_function_args, **self.init_function_kwargs)
class TensorInitFunction(Initializer):
def __init__(self, init_function, *init_function_args, **init_function_kwargs):
super(TensorInitFunction, self).__init__()
assert callable(init_function)
self.init_function = init_function
self.init_function_args = init_function_args
self.init_function_kwargs = init_function_kwargs
def call_on_tensor(self, tensor):
return self.init_function(tensor, *self.init_function_args, **self.init_function_kwargs)
================================================
FILE: inferno/extensions/initializers/presets.py
================================================
import numpy as np
import torch.nn.init as init
from functools import partial
from .base import Initialization, Initializer
__all__ = ['Constant', 'NormalWeights',
'SELUWeightsZeroBias',
'ELUWeightsZeroBias',
'OrthogonalWeightsZeroBias',
'KaimingNormalWeightsZeroBias']
class Constant(Initializer):
"""Initialize with a constant."""
def __init__(self, constant):
self.constant = constant
def call_on_tensor(self, tensor):
tensor.fill_(self.constant)
return tensor
class NormalWeights(Initializer):
"""
Initialize weights with random numbers drawn from the normal distribution at
`mean` and `stddev`.
"""
def __init__(self, mean=0., stddev=1., sqrt_gain_over_fan_in=None):
self.mean = mean
self.stddev = stddev
self.sqrt_gain_over_fan_in = sqrt_gain_over_fan_in
def compute_fan_in(self, tensor):
if tensor.dim() == 2:
return tensor.size(1)
else:
return np.prod(list(tensor.size())[1:])
def call_on_weight(self, tensor):
# Compute stddev if required
if self.sqrt_gain_over_fan_in is not None:
stddev = self.stddev * \
np.sqrt(self.sqrt_gain_over_fan_in / self.compute_fan_in(tensor))
else:
stddev = self.stddev
# Init
tensor.normal_(self.mean, stddev)
class OrthogonalWeightsZeroBias(Initialization):
def __init__(self, orthogonal_gain=1.):
# This prevents a deprecated warning in Pytorch 0.4+
orthogonal = getattr(init, 'orthogonal_', init.orthogonal)
super(OrthogonalWeightsZeroBias, self)\
.__init__(weight_initializer=partial(orthogonal, gain=orthogonal_gain),
bias_initializer=Constant(0.))
class KaimingNormalWeightsZeroBias(Initialization):
def __init__(self, relu_leakage=0):
# This prevents a deprecated warning in Pytorch 0.4+
kaiming_normal = getattr(init, 'kaiming_normal_', init.kaiming_normal)
super(KaimingNormalWeightsZeroBias, self)\
.__init__(weight_initializer=partial(kaiming_normal, a=relu_leakage),
bias_initializer=Constant(0.))
class SELUWeightsZeroBias(Initialization):
def __init__(self):
super(SELUWeightsZeroBias, self)\
.__init__(weight_initializer=NormalWeights(sqrt_gain_over_fan_in=1.),
bias_initializer=Constant(0.))
class ELUWeightsZeroBias(Initialization):
def __init__(self):
super(ELUWeightsZeroBias, self)\
.__init__(weight_initializer=NormalWeights(sqrt_gain_over_fan_in=1.5505188080679277),
bias_initializer=Constant(0.))
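The SELU/ELU presets above boil down to drawing weights from N(0, gain/fan_in), where the fan-in of a weight tensor of shape (out_c, in_c, k, ...) is the product of all dimensions after the first. A pure-Python sketch of the stddev computation mirroring `NormalWeights.compute_fan_in` (illustrative only):

```python
import math

def fan_in(shape):
    """Fan-in of a weight tensor shape: product of all dims but the first."""
    prod = 1
    for s in shape[1:]:
        prod *= s
    return prod

def init_stddev(shape, gain):
    """Stddev used when sqrt_gain_over_fan_in=gain: stddev * sqrt(gain / fan_in),
    shown here with the base stddev of 1 that the presets use."""
    return math.sqrt(gain / fan_in(shape))

# 3x3 conv with 64 input channels; the SELU preset uses gain=1.
std = init_stddev((128, 64, 3, 3), gain=1.0)
```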
================================================
FILE: inferno/extensions/layers/__init__.py
================================================
__all__ = []
from .activations import *
from .convolutional import *
from .device import *
from .reshape import *
from .convolutional_blocks import *
from .identity import *
#######################################################
# the following is to make the sphinx example
# gallery makes proper cross-references
from .activations import _all as _activations_all
from .convolutional import _all as _convolutional_all
from .device import _all as _device_all
from .reshape import _all as _reshape_all
from .convolutional_blocks import _all as _convolutional_blocks_all
from .identity import _all as _identity_all
__all__.extend(_activations_all)
__all__.extend(_convolutional_all)
__all__.extend(_device_all)
__all__.extend(_reshape_all)
__all__.extend(_convolutional_blocks_all)
__all__.extend(_identity_all)
_all = __all__
================================================
FILE: inferno/extensions/layers/activations.py
================================================
import torch.nn.functional as F
import torch.nn as nn
from ...utils.torch_utils import where
__all__ = ['SELU']
_all = __all__
class SELU(nn.Module):
def forward(self, input):
return self.selu(input)
@staticmethod
def selu(x):
alpha = 1.6732632423543772848170429916717
scale = 1.0507009873554804934193349852946
# noinspection PyTypeChecker
return scale * where(x >= 0, x, alpha * F.elu(x))
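The constants above define selu(x) = scale·x for x ≥ 0 and scale·alpha·(exp(x) − 1) otherwise; the module applies this elementwise via `where` and `F.elu`. A scalar pure-Python version for reference (same constants):

```python
import math

ALPHA = 1.6732632423543772848170429916717
SCALE = 1.0507009873554804934193349852946

def selu_scalar(x):
    """Scalar SELU: scale * x if x >= 0 else scale * alpha * (exp(x) - 1)."""
    return SCALE * x if x >= 0 else SCALE * ALPHA * (math.exp(x) - 1.0)
```

For large negative inputs the activation saturates at −scale·alpha, the property that makes SELU self-normalizing.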
================================================
FILE: inferno/extensions/layers/convolutional.py
================================================
import torch.nn as nn
import sys
import functools
from ..initializers import (
OrthogonalWeightsZeroBias,
KaimingNormalWeightsZeroBias,
SELUWeightsZeroBias,
)
from ..initializers import Initializer
from .normalization import BatchNormND
from .activations import SELU
from ...utils.exceptions import assert_, ShapeError
from ...utils.partial_cls import register_partial_cls
# we append to this later on
__all__ = [
"GlobalConv2D",
]
_all = __all__
register_partial_cls_here = functools.partial(register_partial_cls, module=__name__)
class ConvActivation(nn.Module):
"""Convolutional layer with 'SAME' padding by default followed by an activation."""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
dim,
activation,
stride=1,
dilation=1,
groups=None,
depthwise=False,
bias=True,
deconv=False,
initialization=None,
valid_conv=False,
):
super(ConvActivation, self).__init__()
# Validate dim
assert_(
dim in [1, 2, 3],
"`dim` must be one of [1, 2, 3], got {}.".format(dim),
ShapeError,
)
self.dim = dim
# Check if depthwise
if depthwise:
# We know that in_channels == out_channels, but we also want a consistent API.
# As a compromise, we allow that out_channels be None or 'auto'.
            out_channels = in_channels if out_channels in [None, "auto"] else out_channels
assert_(
in_channels == out_channels,
"For depthwise convolutions, number of input channels (given: {}) "
"must equal the number of output channels (given {}).".format(
in_channels, out_channels
),
ValueError,
)
assert_(
groups is None or groups == in_channels,
"For depthwise convolutions, groups (given: {}) must "
"equal the number of channels (given: {}).".format(groups, in_channels),
)
groups = in_channels
else:
groups = 1 if groups is None else groups
self.depthwise = depthwise
if valid_conv:
self.conv = getattr(nn, "Conv{}d".format(self.dim))(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
groups=groups,
bias=bias,
)
elif not deconv:
# Get padding
padding = self.get_padding(kernel_size, dilation)
self.conv = getattr(nn, "Conv{}d".format(self.dim))(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
padding=padding,
stride=stride,
dilation=dilation,
groups=groups,
bias=bias,
)
else:
self.conv = getattr(nn, "ConvTranspose{}d".format(self.dim))(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
groups=groups,
bias=bias,
)
if initialization is None:
pass
elif isinstance(initialization, Initializer):
self.conv.apply(initialization)
else:
raise NotImplementedError
if isinstance(activation, str):
self.activation = getattr(nn, activation)()
elif isinstance(activation, nn.Module):
self.activation = activation
elif activation is None:
self.activation = None
else:
raise NotImplementedError
def forward(self, input):
conved = self.conv(input)
if self.activation is not None:
activated = self.activation(conved)
else:
# No activation
activated = conved
return activated
def _pair_or_triplet(self, object_):
if isinstance(object_, (list, tuple)):
assert len(object_) == self.dim
return object_
else:
object_ = [object_] * self.dim
return object_
def _get_padding(self, _kernel_size, _dilation):
assert isinstance(_kernel_size, int)
assert isinstance(_dilation, int)
assert _kernel_size % 2 == 1
return ((_kernel_size - 1) // 2) * _dilation
def get_padding(self, kernel_size, dilation):
kernel_size = self._pair_or_triplet(kernel_size)
dilation = self._pair_or_triplet(dilation)
padding = [
self._get_padding(_kernel_size, _dilation)
for _kernel_size, _dilation in zip(kernel_size, dilation)
]
return tuple(padding)
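`get_padding` implements the usual 'SAME' rule for odd kernels: padding = ((k − 1) / 2)·dilation, which leaves the spatial size unchanged at stride 1. A quick pure-Python check against the standard convolution output-size formula (assumes an odd kernel and stride 1, like the method above):

```python
def same_padding(kernel_size, dilation=1):
    """Padding that preserves spatial size for odd kernels at stride 1."""
    assert kernel_size % 2 == 1, "'SAME' padding here requires an odd kernel"
    return ((kernel_size - 1) // 2) * dilation

def conv_output_size(size, kernel_size, padding, stride=1, dilation=1):
    """Standard convolution output-size formula."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

pad = same_padding(5, dilation=2)
out = conv_output_size(64, kernel_size=5, padding=pad, dilation=2)
```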
# for consistency
ConvActivationND = ConvActivation
# noinspection PyUnresolvedReferences
class _BNReLUSomeConv(object):
def forward(self, input):
normed = self.batchnorm(input)
activated = self.activation(normed)
conved = self.conv(activated)
return conved
class BNReLUConvBaseND(_BNReLUSomeConv, ConvActivation):
def __init__(self, in_channels, out_channels, kernel_size, dim, stride=1, dilation=1, deconv=False):
super(BNReLUConvBaseND, self).__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
dim=dim,
stride=stride,
activation=nn.ReLU(inplace=True),
dilation=dilation,
deconv=deconv,
initialization=KaimingNormalWeightsZeroBias(0),
)
self.batchnorm = BatchNormND(dim, in_channels)
def _register_conv_cls(conv_name, fix=None, default=None):
if fix is None:
fix = {}
if default is None:
default = {}
# simple conv activation
activations = ["ReLU", "ELU", "Sigmoid", "SELU", ""]
init_map = {
"ReLU": KaimingNormalWeightsZeroBias,
"SELU": SELUWeightsZeroBias
}
for activation_str in activations:
        cls_name = "{}{}ND".format(conv_name, activation_str)
__all__.append(cls_name)
initialization_cls = init_map.get(activation_str, OrthogonalWeightsZeroBias)
if activation_str == "":
activation = None
_fix = {**fix}
_default = {'activation':None}
elif activation_str == "SELU":
activation = nn.SELU(inplace=True)
_fix={**fix, 'activation':activation}
_default = {**default}
else:
activation = activation_str
_fix={**fix, 'activation':activation}
_default = {**default}
register_partial_cls_here(ConvActivation, cls_name,
fix=_fix,
default={**_default, 'initialization':initialization_cls()}
)
for dim in [1, 2, 3]:
cls_name = "{}{}{}D".format(conv_name,activation_str, dim)
__all__.append(cls_name)
register_partial_cls_here(ConvActivation, cls_name,
fix={**_fix, 'dim':dim},
default={**_default, 'initialization':initialization_cls()}
)
def _register_bnr_conv_cls(conv_name, fix=None, default=None):
if fix is None:
fix = {}
if default is None:
default = {}
    # Register the dimension-agnostic ND class once (not per dim)
    cls_name = "BNReLU{}ND".format(conv_name)
    __all__.append(cls_name)
    register_partial_cls_here(BNReLUConvBaseND, cls_name, fix=fix, default=default)
for dim in [1, 2, 3]:
cls_name = "BNReLU{}{}D".format(conv_name, dim)
__all__.append(cls_name)
register_partial_cls_here(BNReLUConvBaseND, cls_name,
fix={**fix, 'dim':dim},
default=default)
# conv classes
_register_conv_cls("Conv")
_register_conv_cls("ValidConv", fix=dict(valid_conv=True))
_register_conv_cls("Deconv", fix=dict(deconv=True), default=dict(kernel_size=2, stride=2))
_register_conv_cls("StridedConv", default=dict(stride=2))
_register_conv_cls("DilatedConv", fix=dict(dilation=2))
_register_conv_cls("DepthwiseConv", fix=dict(deconv=False, depthwise=True), default=dict(out_channels='auto'))
# BatchNormRelu classes
_register_bnr_conv_cls("Conv", fix=dict(deconv=False))
_register_bnr_conv_cls("Deconv", fix=dict(deconv=True))
_register_bnr_conv_cls("StridedConv", default=dict(stride=2))
_register_bnr_conv_cls("DilatedConv", default=dict(dilation=2))
_register_bnr_conv_cls("DepthwiseConv", fix=dict(deconv=False, depthwise=True), default=dict(out_channels='auto'))
del _register_conv_cls
del _register_bnr_conv_cls
class GlobalConv2D(nn.Module):
"""From https://arxiv.org/pdf/1703.02719.pdf
Main idea: we can have a bigger kernel size computationally acceptable
if we separate 2D-conv in 2 1D-convs """
def __init__(
self,
in_channels,
out_channels,
kernel_size,
local_conv_type,
activation=None,
use_BN=False,
**kwargs
):
super(GlobalConv2D, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
assert isinstance(kernel_size, (int, list, tuple))
if isinstance(kernel_size, int):
kernel_size = (kernel_size,) * 2
self.kwargs = kwargs
self.conv1a = local_conv_type(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=(kernel_size[0], 1),
**kwargs
)
self.conv1b = local_conv_type(
in_channels=self.out_channels,
out_channels=self.out_channels,
kernel_size=(1, kernel_size[1]),
**kwargs
)
self.conv2a = local_conv_type(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=(1, kernel_size[1]),
**kwargs
)
self.conv2b = local_conv_type(
in_channels=self.out_channels,
out_channels=self.out_channels,
kernel_size=(kernel_size[0], 1),
**kwargs
)
if use_BN:
self.batchnorm = nn.BatchNorm2d(self.out_channels)
else:
self.batchnorm = None
self.activation = activation
def forward(self, input_):
out1 = self.conv1a(input_)
out1 = self.conv1b(out1)
out2 = self.conv2a(input_)
out2 = self.conv2b(out2)
        out = out1 + out2
if self.activation is not None:
out = self.activation(out)
if self.batchnorm is not None:
out = self.batchnorm(out)
return out
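The saving that `GlobalConv2D`'s docstring alludes to is easy to quantify: a dense k×k convolution has k²·C_in·C_out weights, while each (k×1 then 1×k) branch costs k·C_in·C_out + k·C_out·C_out. A pure-Python comparison (bias terms ignored; numbers purely illustrative):

```python
def full_conv_params(k, c_in, c_out):
    """Weights in a dense k x k convolution (no bias)."""
    return k * k * c_in * c_out

def global_conv_params(k, c_in, c_out):
    """Weights in GlobalConv2D's two (k x 1)->(1 x k) branches (no bias)."""
    per_branch = k * c_in * c_out + k * c_out * c_out
    return 2 * per_branch

k, c_in, c_out = 15, 64, 64
dense = full_conv_params(k, c_in, c_out)
separated = global_conv_params(k, c_in, c_out)
```

For a 15x15 kernel the separated form is several times smaller, which is what makes very large effective receptive fields affordable.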
================================================
FILE: inferno/extensions/layers/convolutional_blocks.py
================================================
import torch.nn as nn
from .convolutional import BNReLUConv2D, BNReLUDeconv2D, Conv2D, Deconv2D
from ...utils import python_utils as pyu
from ...utils.exceptions import assert_
__all__ = ['ResidualBlock', 'PreActSimpleResidualBlock']
_all = __all__
class ResidualBlock(nn.Module):
def __init__(self, layers, resample=None):
super(ResidualBlock, self).__init__()
assert pyu.is_listlike(layers)
self.layers = nn.Sequential(*layers)
self.resample = resample
def forward(self, input):
preaddition = self.layers(input)
if self.resample is not None:
skip = self.resample(input)
else:
skip = input
output = preaddition + skip
return output
class PreActSimpleResidualBlock(ResidualBlock):
def __init__(self, in_channels, num_hidden_channels, upsample=False, downsample=False):
layers = []
if downsample:
assert_(not upsample, "Both downsample and upsample is set to true.", ValueError)
layers.append(BNReLUConv2D(in_channels=in_channels,
out_channels=num_hidden_channels,
kernel_size=3,
stride=2))
resample = nn.Sequential(Conv2D(in_channels=in_channels,
out_channels=in_channels,
kernel_size=1, stride=2),
nn.BatchNorm2d(in_channels))
elif upsample:
layers.append(BNReLUDeconv2D(in_channels=in_channels,
out_channels=num_hidden_channels,
kernel_size=2,
stride=2))
resample = nn.Sequential(Deconv2D(in_channels=in_channels,
out_channels=in_channels,
kernel_size=2, stride=2),
nn.BatchNorm2d(in_channels))
else:
layers.append(BNReLUConv2D(in_channels=in_channels,
out_channels=num_hidden_channels,
kernel_size=3))
resample = None
layers.append(BNReLUConv2D(in_channels=num_hidden_channels,
out_channels=in_channels,
kernel_size=3))
super(PreActSimpleResidualBlock, self).__init__(layers, resample)
# TODO PreActBottleneckResidualBlock
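The control flow of `ResidualBlock` reduces to output = layers(input) + (resample(input) if resample else input). A minimal torch-free sketch with toy callables standing in for the layer stacks (the lambdas are hypothetical, purely illustrative):

```python
class ToyResidualBlock:
    """Mimics ResidualBlock.forward on plain numbers instead of tensors."""
    def __init__(self, layers, resample=None):
        self.layers = layers
        self.resample = resample

    def __call__(self, x):
        preaddition = self.layers(x)
        # The skip path is resampled only when shapes would otherwise mismatch
        skip = self.resample(x) if self.resample is not None else x
        return preaddition + skip

identity_block = ToyResidualBlock(layers=lambda x: 2 * x)
resampled_block = ToyResidualBlock(layers=lambda x: 2 * x,
                                   resample=lambda x: x / 2)
```

In `PreActSimpleResidualBlock`, the resample path is a strided 1x1 conv (downsample) or a stride-2 deconv (upsample), so the skip matches the main branch's output shape.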
================================================
FILE: inferno/extensions/layers/device.py
================================================
import torch.nn as nn
from ...utils.python_utils import from_iterable, to_iterable
from ...utils.exceptions import assert_, DeviceError
__all__ = ['DeviceTransfer', 'OnDevice']
_all = __all__
class DeviceTransfer(nn.Module):
"""Layer to transfer variables to a specified device."""
def __init__(self, target_device, device_ordinal=None, asynchronous=False):
"""
Parameters
----------
target_device : {'cpu', 'cuda'}
Device to transfer to.
device_ordinal : int
Device ordinal if target_device == 'cuda'.
asynchronous : bool
Whether to use asynchronous transfers.
"""
super(DeviceTransfer, self).__init__()
# Validate arguments
assert_(target_device in ['cpu', 'cuda'],
"Target device must either be 'cpu' or 'cuda'.",
DeviceError)
if target_device == 'cpu':
assert_(device_ordinal is None,
"'device_ordinal' must be None if target_device is 'cpu'.",
DeviceError)
self.target_device = target_device
        self.device_ordinal = device_ordinal
        self.asynchronous = asynchronous
def forward(self, *inputs):
if self.target_device == 'cuda':
transferred = tuple(input_.cuda(device=self.device_ordinal,
non_blocking=self.asynchronous)
for input_ in inputs)
elif self.target_device == 'cpu':
transferred = tuple(input_.cpu() for input_ in inputs)
else:
raise NotImplementedError
return from_iterable(transferred)
class OnDevice(nn.Module):
"""
Moves a module to a device. The advantage of using this over `torch.nn.Module.cuda` is
that the inputs are transferred to the same device as the module, enabling easy model
parallelism.
"""
def __init__(self, module, target_device, device_ordinal=None, asynchronous=False):
"""
Parameters
----------
module : torch.nn.Module
Module to transfer to device.
target_device : {'cuda', 'cpu'}
The device to move `module` to. Must be either 'cuda' or 'cpu'.
device_ordinal : int
Ordinal of the GPU device if `target_device = 'cuda'`.
asynchronous : bool
Whether to use asynchronous transfers.
"""
super(OnDevice, self).__init__()
# Validate arguments
assert_(target_device in ['cpu', 'cuda'],
"Target device must either be 'cpu' or 'cuda'.",
DeviceError)
if target_device == 'cpu':
assert_(device_ordinal is None,
"'device_ordinal' must be None if target_device is 'cpu'.",
DeviceError)
self.target_device = target_device
self.device_ordinal = device_ordinal
self.asynchronous = asynchronous
        # Transferring is a no-op if the inputs are already on the right device
self.device_transfer = DeviceTransfer(self.target_device,
device_ordinal=self.device_ordinal,
asynchronous=self.asynchronous)
self.module = self.transfer_module(module)
def transfer_module(self, module):
if self.target_device == 'cuda':
            return module.cuda(device=self.device_ordinal)
elif self.target_device == 'cpu':
return module.cpu()
else:
raise NotImplementedError
def forward(self, *inputs):
# Transfer inputs (no-op if they're already on the right device)
transferred = to_iterable(self.device_transfer(*inputs))
output = self.module(*transferred)
return output
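The argument validation shared by `DeviceTransfer` and `OnDevice` reduces to two rules. A standalone sketch (the helper name `check_device_args` is hypothetical, not part of inferno) makes them explicit:

```python
def check_device_args(target_device, device_ordinal=None):
    # mirrors the checks in DeviceTransfer/OnDevice: only 'cpu'/'cuda'
    # are accepted, and a device ordinal only makes sense for 'cuda'
    if target_device not in ('cpu', 'cuda'):
        raise ValueError("Target device must either be 'cpu' or 'cuda'.")
    if target_device == 'cpu' and device_ordinal is not None:
        raise ValueError("'device_ordinal' must be None if target_device is 'cpu'.")
    return target_device, device_ordinal

check_device_args('cuda', 0)   # ok
check_device_args('cpu')       # ok
```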
================================================
FILE: inferno/extensions/layers/identity.py
================================================
import torch.nn as nn
__all__ = ['Identity']
_all = __all__
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
================================================
FILE: inferno/extensions/layers/normalization.py
================================================
import torch.nn as nn
class BatchNormND(nn.Module):
def __init__(self, dim, num_features,
                 eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
super(BatchNormND, self).__init__()
assert dim in [1, 2, 3]
self.bn = getattr(nn, 'BatchNorm{}d'.format(dim))(num_features=num_features,
eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
def forward(self, x):
return self.bn(x)
================================================
FILE: inferno/extensions/layers/reshape.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...utils.exceptions import assert_, ShapeError
from ...utils import python_utils as pyu
__all__ = ['View', 'AsMatrix', 'Flatten',
'As3D', 'As2D',
'Concatenate', 'Cat',
'ResizeAndConcatenate', 'PoolCat',
'GlobalMeanPooling', 'GlobalMaxPooling',
           'Sum', 'SplitChannels', 'Squeeze', 'RemoveSingletonDimension']
_all = __all__
class View(nn.Module):
def __init__(self, as_shape):
super(View, self).__init__()
self.as_shape = self.validate_as_shape(as_shape)
def validate_as_shape(self, as_shape):
assert all([isinstance(_s, int) or _s == 'x' for _s in as_shape])
all_int_indices = [_n for _n, _s in enumerate(as_shape) if isinstance(_s, int)]
if all_int_indices:
first_int_at_index = all_int_indices[0]
assert all([isinstance(_s, int) for _s in as_shape[first_int_at_index:]])
return as_shape
def forward(self, input):
input_shape = list(input.size())
reshaped_shape = [_s if isinstance(_s, int) else input_shape[_n]
for _n, _s in enumerate(self.as_shape)]
output = input.view(*reshaped_shape)
return output
class AsMatrix(View):
def __init__(self):
super(AsMatrix, self).__init__(as_shape=['x', 'x'])
class Flatten(View):
def __init__(self):
super(Flatten, self).__init__(as_shape=['x', -1])
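The shape substitution performed by `View.forward` (and thus by `AsMatrix` and `Flatten`) can be sketched without tensors: `'x'` copies the corresponding input dimension, while integers, including `-1`, pass through. The helper name below is illustrative only:

```python
def substituted_shape(as_shape, input_shape):
    # 'x' entries keep the matching input dimension,
    # integer entries (including -1) are taken literally
    return [s if isinstance(s, int) else input_shape[n]
            for n, s in enumerate(as_shape)]

substituted_shape(['x', -1], [2, 3, 4, 5])  # Flatten: [2, -1]
substituted_shape(['x', 'x'], [2, 10])      # AsMatrix: [2, 10]
```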
class As3D(nn.Module):
def __init__(self, channel_as_z=False, num_channels_or_num_z_slices=1):
super(As3D, self).__init__()
self.channel_as_z = channel_as_z
self.num_channels_or_num_z_slices = num_channels_or_num_z_slices
def forward(self, input):
if input.dim() == 5:
# If input is a batch of 3D volumes - return as is
return input
elif input.dim() == 4:
# If input is a batch of 2D images, reshape
b, c, _0, _1 = list(input.size())
assert_(c % self.num_channels_or_num_z_slices == 0,
"Number of channels of the 4D image tensor (= {}) must be "
"divisible by the set number of channels or number of z slices "
"of the 5D volume tensor (= {})."
.format(c, self.num_channels_or_num_z_slices),
ShapeError)
c //= self.num_channels_or_num_z_slices
if self.channel_as_z:
# Move channel axis to z
return input.view(b, self.num_channels_or_num_z_slices, c, _0, _1)
else:
# Keep channel axis where it is, but add a singleton dimension for z
return input.view(b, c, self.num_channels_or_num_z_slices, _0, _1)
elif input.dim() == 2:
# We have a matrix which we wish to turn to a 3D batch
b, c = list(input.size())
return input.view(b, c, 1, 1, 1)
else:
raise NotImplementedError
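The shape bookkeeping of `As3D.forward` can be traced in pure Python (the helper name is hypothetical; `n` stands for `num_channels_or_num_z_slices`):

```python
def as3d_shape(input_shape, channel_as_z=False, n=1):
    # mirrors As3D.forward's case analysis on the input dimensionality
    if len(input_shape) == 5:
        return list(input_shape)          # already a batch of volumes
    if len(input_shape) == 4:
        b, c, h, w = input_shape
        assert c % n == 0
        c //= n
        # channel axis becomes z, or a new z axis of length n is inserted
        return [b, n, c, h, w] if channel_as_z else [b, c, n, h, w]
    if len(input_shape) == 2:
        b, c = input_shape
        return [b, c, 1, 1, 1]            # matrix -> singleton volume
    raise NotImplementedError

as3d_shape([2, 4, 8, 8], channel_as_z=True, n=2)  # [2, 2, 2, 8, 8]
```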
class As2D(nn.Module):
def __init__(self, z_as_channel=True):
super(As2D, self).__init__()
self.z_as_channel = z_as_channel
def forward(self, input):
if input.dim() == 5:
b, c, _0, _1, _2 = list(input.size())
if not self.z_as_channel:
assert _0 == 1
# Reshape
return input.view(b, c * _0, _1, _2)
elif input.dim() == 4:
# Nothing to do here - input is already 2D
return input
        elif input.dim() == 2:
            # We make singleton spatial dimensions
            b, c = list(input.size())
            return input.view(b, c, 1, 1)
        else:
            raise NotImplementedError
class Concatenate(nn.Module):
"""Concatenate input tensors along a specified dimension."""
def __init__(self, dim=1):
super(Concatenate, self).__init__()
self.dim = dim
def forward(self, *inputs):
return torch.cat(inputs, dim=self.dim)
class ResizeAndConcatenate(nn.Module):
"""
Resize input tensors spatially (to a specified target size) before concatenating
them along the a given dim (channel, i.e. 1 by default). The down-sampling mode can
be specified ('average' or 'max'), but the up-sampling is always 'nearest'.
"""
POOL_MODE_MAPPING = {'avg': 'avg',
'average': 'avg',
'mean': 'avg',
'max': 'max'}
def __init__(self, target_size, pool_mode='average', dim=1):
super(ResizeAndConcatenate, self).__init__()
self.target_size = target_size
assert_(pool_mode in self.POOL_MODE_MAPPING.keys(),
"`pool_mode` must be one of {}, got {} instead."
.format(self.POOL_MODE_MAPPING.keys(), pool_mode),
ValueError)
self.pool_mode = self.POOL_MODE_MAPPING.get(pool_mode)
self.dim = dim
def forward(self, *inputs):
dim = inputs[0].dim()
assert_(dim in [4, 5],
'Input tensors must either be 4 or 5 '
'dimensional, but inputs[0] is {}D.'.format(dim),
ShapeError)
# Get resize function
spatial_dim = {4: 2, 5: 3}[dim]
resize_function = getattr(F, 'adaptive_{}_pool{}d'.format(self.pool_mode,
spatial_dim))
target_size = pyu.as_tuple_of_len(self.target_size, spatial_dim)
# Do the resizing
resized_inputs = []
for input_num, input in enumerate(inputs):
# Make sure the dim checks out
assert_(input.dim() == dim,
"Expected inputs[{}] to be a {}D tensor, got a {}D "
"tensor instead.".format(input_num, dim, input.dim()),
ShapeError)
resized_inputs.append(resize_function(input, target_size))
# Concatenate along the channel axis
if len(resized_inputs) > 1:
concatenated = torch.cat(tuple(resized_inputs), self.dim)
else:
concatenated = resized_inputs[0]
# Done
return concatenated
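The pool-function lookup in `forward` resolves to a function name in `torch.nn.functional`; the resolution itself is plain string formatting and can be checked standalone (the helper name is illustrative):

```python
POOL_MODE_MAPPING = {'avg': 'avg', 'average': 'avg', 'mean': 'avg', 'max': 'max'}

def resize_function_name(pool_mode, tensor_dim):
    # tensor_dim is the full dim incl. batch and channel axes: 4 -> 2D, 5 -> 3D
    spatial_dim = {4: 2, 5: 3}[tensor_dim]
    return 'adaptive_{}_pool{}d'.format(POOL_MODE_MAPPING[pool_mode], spatial_dim)

resize_function_name('mean', 5)  # 'adaptive_avg_pool3d'
resize_function_name('max', 4)   # 'adaptive_max_pool2d'
```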
class Cat(Concatenate):
"""An alias for `Concatenate`. Hey, everyone knows who Cat is."""
pass
class PoolCat(ResizeAndConcatenate):
"""Alias for `ResizeAndConcatenate`, just to annoy snarky web developers."""
pass
class GlobalMeanPooling(ResizeAndConcatenate):
"""Global mean pooling layer."""
def __init__(self):
super(GlobalMeanPooling, self).__init__((1, 1), 'average')
class GlobalMaxPooling(ResizeAndConcatenate):
"""Global max pooling layer."""
def __init__(self):
super(GlobalMaxPooling, self).__init__((1, 1), 'max')
class Sum(nn.Module):
"""Sum all inputs."""
def forward(self, *inputs):
return torch.stack(inputs, dim=0).sum(0)
class SplitChannels(nn.Module):
"""Split input at a given index along the channel axis."""
def __init__(self, channel_index):
super(SplitChannels, self).__init__()
self.channel_index = channel_index
def forward(self, input):
if isinstance(self.channel_index, int):
split_location = self.channel_index
elif self.channel_index == 'half':
split_location = input.size(1) // 2
else:
raise NotImplementedError
assert split_location < input.size(1)
split_0 = input[:, 0:split_location, ...]
split_1 = input[:, split_location:, ...]
return split_0, split_1
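The split-point rule in `SplitChannels` (an explicit int versus the `'half'` shorthand) boils down to a small dispatch, sketched standalone (helper name hypothetical):

```python
def split_location(channel_index, num_channels):
    # mirrors SplitChannels.forward's dispatch on channel_index
    if isinstance(channel_index, int):
        location = channel_index
    elif channel_index == 'half':
        location = num_channels // 2
    else:
        raise NotImplementedError
    assert location < num_channels
    return location

split_location('half', 6)  # 3
```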
class Squeeze(nn.Module):
def __init__(self):
super(Squeeze, self).__init__()
    def forward(self, x):
        # NOTE: squeezes *all* singleton dimensions, including the
        # batch axis when the batch size is 1
        return x.squeeze()
class RemoveSingletonDimension(nn.Module):
def __init__(self, dim=1):
super(RemoveSingletonDimension, self).__init__()
        self.dim = dim
def forward(self, x):
size = list(x.size())
if size[self.dim] != 1:
raise RuntimeError("RemoveSingletonDimension expects a single channel at dim %d, shape=%s"%(self.dim,str(size)))
        slicing = [slice(0, s) for s in size]
        slicing[self.dim] = 0
        # index with a tuple; indexing with a list is deprecated in newer torch
        return x[tuple(slicing)]
================================================
FILE: inferno/extensions/layers/sampling.py
================================================
import torch.nn as nn
__all__ = ['AnisotropicUpsample', 'AnisotropicPool', 'Upsample', 'AnisotropicUpsample2D', 'AnisotropicPool2D']
# torch is deprecating nn.Upsample in favor of nn.functional.interpolate
# we wrap interpolate here so that Upsample can still be used as a class
class Upsample(nn.Module):
def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):
self.size = size
self.scale_factor = scale_factor
self.mode = mode
self.align_corners = align_corners
super(Upsample, self).__init__()
        # interpolate was only introduced in torch 0.4.1; for backward
        # compatibility we check for the attribute and fall back to
        # nn.Upsample otherwise
if hasattr(nn.functional, 'interpolate'):
self.have_interpolate = True
else:
self.have_interpolate = False
self.sampler = nn.Upsample(size=size, scale_factor=scale_factor,
mode=mode, align_corners=align_corners)
def forward(self, input):
if self.have_interpolate:
return nn.functional.interpolate(input, self.size, self.scale_factor,
self.mode, self.align_corners)
else:
return self.sampler(input)
class AnisotropicUpsample(nn.Module):
def __init__(self, scale_factor):
super(AnisotropicUpsample, self).__init__()
self.upsampler = Upsample(scale_factor=scale_factor)
def forward(self, input):
# input is 3D of shape NCDHW
N, C, D, H, W = input.size()
# Fold C and D axes in one
folded = input.view(N, C * D, H, W)
# Upsample
upsampled = self.upsampler(folded)
# Unfold out the C and D axes
unfolded = upsampled.view(N, C, D,
self.upsampler.scale_factor * H,
self.upsampler.scale_factor * W)
# Done
return unfolded
class AnisotropicPool(nn.MaxPool3d):
def __init__(self, downscale_factor):
ds = downscale_factor
super(AnisotropicPool, self).__init__(kernel_size=(1, ds + 1, ds + 1),
stride=(1, ds, ds),
padding=(0, 1, 1))
class AnisotropicUpsample2D(nn.Module):
def __init__(self, scale_factor):
super(AnisotropicUpsample2D, self).__init__()
self.upsampler = nn.Upsample(scale_factor=scale_factor)
def forward(self, input):
        # input is 2D of shape NCDW (or NCDH, it makes no difference)
N, C, D, W = input.size()
# Fold C and D axes in one
folded = input.view(N, C * D, W)
# Upsample
upsampled = self.upsampler(folded)
# Unfold out the C and D axes
unfolded = upsampled.view(N, C, D,
self.upsampler.scale_factor * W)
# Done
return unfolded
class AnisotropicPool2D(nn.MaxPool2d):
def __init__(self, downscale_factor):
ds = downscale_factor
super(AnisotropicPool2D, self).__init__(kernel_size=(1, ds + 1),
stride=(1, ds),
padding=(0, 1))
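The pooling parameters in `AnisotropicPool`/`AnisotropicPool2D` leave one axis untouched (kernel 1, stride 1, padding 0) while downscaling the others by `ds` (kernel `ds + 1`, stride `ds`, padding 1). The standard floor-mode pooling output-size formula makes this concrete (helper name illustrative):

```python
def pooled_size(size, kernel, stride, pad):
    # standard max-pool output-size formula (floor mode)
    return (size + 2 * pad - kernel) // stride + 1

# with downscale_factor ds = 2 on a 16-voxel axis:
ds = 2
pooled_size(16, 1, 1, 0)        # untouched axis stays 16
pooled_size(16, ds + 1, ds, 1)  # downscaled axis becomes 8
```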
================================================
FILE: inferno/extensions/metrics/__init__.py
================================================
from .categorical import *
from .arand import *
================================================
FILE: inferno/extensions/metrics/arand.py
================================================
from .base import Metric
import numpy as np
import scipy.sparse as sparse
import logging
class ArandScore(Metric):
"""Arand Score, as defined in [1].
References
----------
[1]: http://journal.frontiersin.org/article/10.3389/fnana.2015.00142/full#h3
"""
def __init__(self, average_slices=True):
self.average_slices = average_slices
# compute the arand score for a prediction target pair
def _arand_for_tensor(self, prediction, target):
# check if we need to average over slices
average_slices = self.average_slices and prediction.ndim == 3
score_is_invalid = False
# average the rand score over 3d slices
if average_slices:
# average the arand values over the 3d slices
evaluation_values = [adapted_rand(pred, targ) for pred, targ in zip(prediction, target)]
# check if the score is invalid
if all(ev_val is None for ev_val in evaluation_values):
score_is_invalid = True
score = 0
else:
score = np.mean([eval_val[0] for eval_val in evaluation_values if eval_val is not None])
# compute rand score on whole image / volume
else:
score = adapted_rand(prediction, target)
# check if the score is invalid
if score is None:
score_is_invalid = True
score = 0
else:
score = score[0]
if score_is_invalid:
logger = logging.getLogger(__name__)
logger.warning("All slices were invalid, returning worst possible score")
return score
def forward(self, prediction, target):
assert(prediction.shape == target.shape), "%s, %s" % (str(prediction.shape),
str(target.shape))
assert prediction.shape[1] == 1, "Expect singleton channel axis"
prediction = prediction.cpu().numpy()
target = target.cpu().numpy()
ndim = prediction.ndim
        assert ndim in (4, 5), "Expect 2d or 3d input with additional batch and channel axes"
# return the average arand error over the batches
return np.mean([self._arand_for_tensor(pred[0], targ[0])
for pred, targ in zip(prediction, target)])
class ArandError(ArandScore):
"""Arand Error = 1 - <arand score>"""
def __init__(self, **super_kwargs):
super(ArandError, self).__init__(**super_kwargs)
def forward(self, prediction, target):
return 1. - super(ArandError, self).forward(prediction, target)
# Evaluation code courtesy of Juan Nunez-Iglesias, taken from
# https://github.com/janelia-flyem/gala/blob/master/gala/evaluate.py
def adapted_rand(seg, gt):
"""Compute Adapted Rand error as defined by the SNEMI3D contest [1]
Formula is given as 1 - the maximal F-score of the Rand index
(excluding the zero component of the original labels). Adapted
from the SNEMI3D MATLAB script, hence the strange style.
Parameters
----------
seg : np.ndarray
the segmentation to score, where each value is the label at that point
gt : np.ndarray, same shape as seg
the groundtruth to score against, where each value is a label
    Returns
    -------
    f_score : float
        The adapted Rand score, equal to $\frac{2pr}{p + r}$, where $p$ and
        $r$ are the precision and recall described below. The adapted Rand
        error is 1 minus this value. None is returned if either input is
        all zeros.
    prec : float
        The adapted Rand precision.
    rec : float
        The adapted Rand recall.
References
----------
[1]: http://brainiac2.mit.edu/SNEMI3D/evaluation
"""
assert seg.shape == gt.shape, "%s, %s" % (str(seg.shape), str(gt.shape))
logger = logging.getLogger(__name__)
if np.any(seg == 0):
logger.debug("Zeros in segmentation, treating as background.")
if np.any(gt == 0):
logger.debug("Zeros in ground truth, 0's will be ignored.")
seg_zeros = np.all(seg == 0)
gt_zeros = np.all(gt == 0)
    # return None if either gt or segmentation are all zeros
    if seg_zeros or gt_zeros:
        logger.debug("Either segmentation or groundtruth are all zeros, returning None.")
        return None
# segA is truth, segB is query
segA = np.ravel(gt)
segB = np.ravel(seg)
# mask to foreground in A
mask = (segA > 0)
segA = segA[mask]
segB = segB[mask]
# number of nonzero pixels in original segA
n = segA.size
n_labels_A = int(np.amax(segA)) + 1
n_labels_B = int(np.amax(segB)) + 1
ones_data = np.ones(n)
p_ij = sparse.csr_matrix((ones_data, (segA.ravel(), segB.ravel())),
shape=(n_labels_A, n_labels_B),
dtype=np.uint64)
# In the paper where adapted rand is proposed, they treat each background
# pixel in segB as a different value (i.e., unique label for each pixel).
# To do this, we sum them differently than others
    # p_ij is indexed as (label_gt, label_seg); drop the column of 0 seg labels
B_nonzero = p_ij[:, 1:]
B_zero = p_ij[:, 0]
# this is a count
num_B_zero = B_zero.sum()
# sum of the joint distribution
# separate sum of B>0 and B=0 parts
sum_p_ij = (B_nonzero).power(2).sum() + num_B_zero
# these are marginal probabilities
# sum over all seg labels overlapping one gt label (except 0 labels)
a_i = p_ij.sum(1)
b_i = B_nonzero.sum(0)
sum_a = np.power(a_i, 2).sum()
sum_b = np.power(b_i, 2).sum() + num_B_zero
precision = float(sum_p_ij) / sum_b
recall = float(sum_p_ij) / sum_a
f_score = 2.0 * precision * recall / (precision + recall)
return f_score, precision, recall
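On a toy labeling, the contingency-table arithmetic above reduces to a few counters. This pure-Python sketch (names hypothetical) ignores the special treatment of zero labels in `seg` for brevity:

```python
from collections import Counter

def adapted_rand_toy(seg, gt):
    # joint and marginal label-pair counts over the gt foreground
    pairs = [(g, s) for g, s in zip(gt, seg) if g > 0]
    p = Counter(pairs)                # joint counts n_ij
    a = Counter(g for g, _ in pairs)  # gt marginals
    b = Counter(s for _, s in pairs)  # seg marginals
    sum_p = sum(v * v for v in p.values())
    sum_a = sum(v * v for v in a.values())
    sum_b = sum(v * v for v in b.values())
    precision = sum_p / sum_b
    recall = sum_p / sum_a
    f_score = 2.0 * precision * recall / (precision + recall)
    return f_score, precision, recall

adapted_rand_toy([1, 1, 2, 2], [1, 1, 2, 2])  # perfect: (1.0, 1.0, 1.0)
```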
================================================
FILE: inferno/extensions/metrics/base.py
================================================
class Metric(object):
def forward(self, *args, **kwargs):
raise NotImplementedError
def __call__(self, prediction, target, **kwargs):
# We might have listlike predictions (e.g. multi-scale)
# If so, we evaluate the metric on the first prediction,
# which should be at the original scale
if isinstance(prediction, (list, tuple)):
prediction = prediction[0]
# same is true for the target
if isinstance(target, (list, tuple)):
target = target[0]
# Make sure prediction and target live on the same device.
# If they don't, move target to the right device.
if not prediction.is_cuda:
# Move to CPU
target = target.cpu()
else:
# Find device to move to
device_ordinal = prediction.get_device()
target = target.cuda(device_ordinal)
return self.forward(prediction, target, **kwargs)
================================================
FILE: inferno/extensions/metrics/categorical.py
================================================
import torch
from .base import Metric
from ...utils.torch_utils import flatten_samples, is_label_tensor
from ...utils.exceptions import assert_, DTypeError, ShapeError
class CategoricalError(Metric):
"""Categorical error."""
def __init__(self, aggregation_mode='mean'):
assert aggregation_mode in ['mean', 'sum']
self.aggregation_mode = aggregation_mode
def forward(self, prediction, target):
# Check if prediction is binary or not
is_binary = len(prediction.size()) == 1 or prediction.size(1) == 1
if len(target.size()) > 1:
target = target.squeeze(1)
assert len(target.size()) == 1
if is_binary:
# Binary classification
prediction = prediction > 0.5
incorrect = prediction.type_as(target).ne(target).float()
if self.aggregation_mode == 'mean':
return incorrect.mean()
else:
return incorrect.sum()
else:
            # Multiclass classification
_, predicted_class = torch.max(prediction, 1)
if predicted_class.dim() == prediction.dim():
# Support for Pytorch 0.1.12
predicted_class = predicted_class.squeeze(1)
incorrect = predicted_class.type_as(target).ne(target).float()
if self.aggregation_mode == 'mean':
return incorrect.mean()
else:
return incorrect.sum()
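The multiclass branch above is argmax-then-mismatch-rate; a pure-Python sketch of the 'mean' aggregation (function name illustrative):

```python
def categorical_error(scores, targets):
    # argmax over the class axis, then the mean mismatch rate
    predicted = [max(range(len(row)), key=row.__getitem__) for row in scores]
    wrong = sum(int(p != t) for p, t in zip(predicted, targets))
    return wrong / len(targets)

categorical_error([[0.1, 0.9], [0.8, 0.2]], [1, 1])  # 0.5
```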
class IOU(Metric):
"""Intersection over Union. """
def __init__(self, ignore_class=None, sharpen_prediction=False, eps=1e-6):
super(IOU, self).__init__()
self.eps = eps
self.ignore_class = ignore_class
self.sharpen_prediction = sharpen_prediction
def forward(s
SYMBOL INDEX (1067 symbols across 93 files)
FILE: docs/conf.py
class Mock (line 25) | class Mock(MagicMock):
method __getattr__ (line 27) | def __getattr__(cls, name):
FILE: examples/plot_cheap_unet.py
function label_transform (line 33) | def label_transform(x):
function train_model (line 79) | def train_model(model, loaders, **kwargs):
function predict (line 113) | def predict(trainer, test_loader, save_dir=None):
class CheapConv (line 164) | class CheapConv(nn.Module):
method __init__ (line 165) | def __init__(self, in_channels, out_channels, activated):
method forward (line 179) | def forward(self, x):
class CheapConvBlock (line 186) | class CheapConvBlock(nn.Module):
method __init__ (line 187) | def __init__(self, in_channels, out_channels, activated):
method forward (line 199) | def forward(self, x):
class MySimple2DCpUnet (line 213) | class MySimple2DCpUnet(UNetBase):
method __init__ (line 214) | def __init__(self, in_channels, out_channels, depth=3, residual=False,...
method conv_op_factory (line 218) | def conv_op_factory(self, in_channels, out_channels, part, index):
FILE: examples/plot_train_side_loss_unet.py
class MySideLossUNet (line 39) | class MySideLossUNet(nn.Module):
method __init__ (line 40) | def __init__(self, in_channels, out_channels, depth=3):
method forward (line 80) | def forward(self, input):
class MySideLoss (line 105) | class MySideLoss(nn.Module):
method __init__ (line 109) | def __init__(self):
method forward (line 116) | def forward(self, predictions, target):
FILE: examples/plot_unet_tutorial.py
function label_transform (line 32) | def label_transform(x):
function train_model (line 112) | def train_model(model, loaders, **kwargs):
function predict (line 148) | def predict(trainer, test_loader, save_dir=None):
class MySimple2DUnet (line 206) | class MySimple2DUnet(UNetBase):
method __init__ (line 207) | def __init__(self, in_channels, out_channels, depth=3, **kwargs):
method conv_op_factory (line 211) | def conv_op_factory(self, in_channels, out_channels, part, index):
method downsample_op_factory (line 242) | def downsample_op_factory(self, index):
method upsample_op_factory (line 246) | def upsample_op_factory(self, index):
FILE: examples/regularized_mnist.py
class RegularizedLinear (line 25) | class RegularizedLinear(nn.Linear):
method __init__ (line 26) | def __init__(self, *args, ar_weight=1e-3, l1_weight=1e-3, **kwargs):
method forward (line 32) | def forward(self, input):
function model_fn (line 39) | def model_fn():
function mnist_data_loaders (line 50) | def mnist_data_loaders(args):
function train_model (line 68) | def train_model(args):
function main (line 105) | def main(argv):
FILE: inferno/extensions/containers/graph.py
class NNGraph (line 20) | class NNGraph(nx.DiGraph):
method copy (line 27) | def copy(self, **init_kwargs):
class Graph (line 51) | class Graph(nn.Module):
method __init__ (line 81) | def __init__(self, graph=None):
method graph (line 102) | def graph(self):
method graph (line 116) | def graph(self, value):
method is_node_in_graph (line 120) | def is_node_in_graph(self, name):
method is_source_node (line 135) | def is_source_node(self, name):
method is_sink_node (line 157) | def is_sink_node(self, name):
method output_nodes (line 180) | def output_nodes(self):
method input_nodes (line 194) | def input_nodes(self):
method graph_is_valid (line 208) | def graph_is_valid(self):
method assert_graph_is_valid (line 221) | def assert_graph_is_valid(self):
method add_node (line 233) | def add_node(self, name, module, previous=None):
method add_input_node (line 261) | def add_input_node(self, name):
method add_output_node (line 280) | def add_output_node(self, name, previous=None):
method add_edge (line 301) | def add_edge(self, from_node, to_node):
method apply_on_graph (line 329) | def apply_on_graph(self, function, *args, **kwargs):
method get_module_for_nodes (line 333) | def get_module_for_nodes(self, names):
method to_device (line 358) | def to_device(self, names, target_device, device_ordinal=None, asynchr...
method get_parameters_for_nodes (line 373) | def get_parameters_for_nodes(self, names, named=False):
method clear_payloads (line 385) | def clear_payloads(self, graph=None):
method forward_through_node (line 392) | def forward_through_node(self, name, input=None):
method forward (line 445) | def forward(self, *inputs):
FILE: inferno/extensions/containers/sequential.py
class Sequential1 (line 8) | class Sequential1(nn.Sequential):
method __len__ (line 10) | def __len__(self):
class Sequential2 (line 14) | class Sequential2(Sequential1):
method forward (line 19) | def forward(self, *input):
FILE: inferno/extensions/criteria/core.py
class Criteria (line 9) | class Criteria(nn.Module):
method __init__ (line 11) | def __init__(self, *criteria):
method forward (line 22) | def forward(self, prediction, target):
class As2DCriterion (line 41) | class As2DCriterion(nn.Module):
method __init__ (line 46) | def __init__(self, criterion):
method forward (line 54) | def forward(self, prediction, target):
FILE: inferno/extensions/criteria/elementwise_measures.py
class WeightedMSELoss (line 5) | class WeightedMSELoss(nn.Module):
method __init__ (line 8) | def __init__(self, positive_class_weight=1., positive_class_value=1., ...
method forward (line 18) | def forward(self, input, target):
FILE: inferno/extensions/criteria/regularized.py
function collect_losses (line 18) | def collect_losses(module):
function build_criterion (line 38) | def build_criterion(criterion, *args, **kwargs):
class RegularizedLoss (line 61) | class RegularizedLoss(nn.Module):
method __init__ (line 65) | def __init__(self, criterion, *args, **kwargs):
method forward (line 69) | def forward(self, *args, trainer=None, model=None, **kwargs):
class RegularizedCrossEntropyLoss (line 112) | class RegularizedCrossEntropyLoss(RegularizedLoss):
method __init__ (line 113) | def __init__(self, *args, **kwargs):
class RegularizedBCEWithLogitsLoss (line 117) | class RegularizedBCEWithLogitsLoss(RegularizedLoss):
method __init__ (line 118) | def __init__(self, *args, **kwargs):
class RegularizedBCELoss (line 122) | class RegularizedBCELoss(RegularizedLoss):
method __init__ (line 123) | def __init__(self, *args, **kwargs):
class RegularizedMSELoss (line 127) | class RegularizedMSELoss(RegularizedLoss):
method __init__ (line 128) | def __init__(self, *args, **kwargs):
class RegularizedNLLLoss (line 132) | class RegularizedNLLLoss(RegularizedLoss):
method __init__ (line 133) | def __init__(self, *args, **kwargs):
FILE: inferno/extensions/criteria/set_similarity_measures.py
class SorensenDiceLoss (line 7) | class SorensenDiceLoss(nn.Module):
method __init__ (line 13) | def __init__(self, weight=None, channelwise=True, eps=1e-6):
method forward (line 28) | def forward(self, input, target):
class GeneralizedDiceLoss (line 63) | class GeneralizedDiceLoss(nn.Module):
method __init__ (line 70) | def __init__(self, weight=None, channelwise=False, eps=1e-6):
method forward (line 76) | def forward(self, input, target):
FILE: inferno/extensions/initializers/base.py
class Initializer (line 11) | class Initializer(object):
method __call__ (line 22) | def __call__(self, module):
method call_on_bias (line 41) | def call_on_bias(self, tensor):
method call_on_weight (line 44) | def call_on_weight(self, tensor):
method call_on_tensor (line 47) | def call_on_tensor(self, tensor):
method initializes_weight (line 51) | def initializes_weight(cls):
method initializes_bias (line 55) | def initializes_bias(cls):
class Initialization (line 59) | class Initialization(Initializer):
method __init__ (line 60) | def __init__(self, weight_initializer=None, bias_initializer=None):
method call_on_weight (line 90) | def call_on_weight(self, tensor):
method call_on_bias (line 93) | def call_on_bias(self, tensor):
class WeightInitFunction (line 97) | class WeightInitFunction(Initializer):
method __init__ (line 98) | def __init__(self, init_function, *init_function_args, **init_function...
method call_on_weight (line 105) | def call_on_weight(self, tensor):
class BiasInitFunction (line 109) | class BiasInitFunction(Initializer):
method __init__ (line 110) | def __init__(self, init_function, *init_function_args, **init_function...
method call_on_bias (line 117) | def call_on_bias(self, tensor):
class TensorInitFunction (line 121) | class TensorInitFunction(Initializer):
method __init__ (line 122) | def __init__(self, init_function, *init_function_args, **init_function...
method call_on_tensor (line 129) | def call_on_tensor(self, tensor):
FILE: inferno/extensions/initializers/presets.py
class Constant (line 15) | class Constant(Initializer):
method __init__ (line 17) | def __init__(self, constant):
method call_on_tensor (line 20) | def call_on_tensor(self, tensor):
class NormalWeights (line 25) | class NormalWeights(Initializer):
method __init__ (line 30) | def __init__(self, mean=0., stddev=1., sqrt_gain_over_fan_in=None):
method compute_fan_in (line 35) | def compute_fan_in(self, tensor):
method call_on_weight (line 41) | def call_on_weight(self, tensor):
class OrthogonalWeightsZeroBias (line 52) | class OrthogonalWeightsZeroBias(Initialization):
method __init__ (line 53) | def __init__(self, orthogonal_gain=1.):
class KaimingNormalWeightsZeroBias (line 61) | class KaimingNormalWeightsZeroBias(Initialization):
method __init__ (line 62) | def __init__(self, relu_leakage=0):
class SELUWeightsZeroBias (line 70) | class SELUWeightsZeroBias(Initialization):
method __init__ (line 71) | def __init__(self):
class ELUWeightsZeroBias (line 77) | class ELUWeightsZeroBias(Initialization):
method __init__ (line 78) | def __init__(self):
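The `Initializer` interface above suggests a dispatch pattern: `__call__` inspects a module and routes its `weight`/`bias` tensors to `call_on_weight`/`call_on_bias`, which presets like `Constant` override. A hedged, torch-free sketch of that pattern (names taken from the outline; the real torch-based behavior may differ in detail):

```python
class SketchInitializer:
    """Illustrative stand-in for inferno's Initializer dispatch."""

    def __call__(self, module):
        # Route each parameter the module exposes to its specific hook.
        if getattr(module, "weight", None) is not None:
            module.weight = self.call_on_weight(module.weight)
        if getattr(module, "bias", None) is not None:
            module.bias = self.call_on_bias(module.bias)
        return module

    def call_on_weight(self, tensor):
        return tensor

    def call_on_bias(self, tensor):
        return tensor


class SketchConstant(SketchInitializer):
    """Fill every entry with a constant, like the Constant preset."""

    def __init__(self, constant):
        self.constant = constant

    def call_on_weight(self, tensor):
        return [self.constant] * len(tensor)

    def call_on_bias(self, tensor):
        return [self.constant] * len(tensor)
```

In the real library the initializer would be applied via `module.apply(initializer)` on an `nn.Module` tree.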
FILE: inferno/extensions/layers/activations.py
class SELU (line 8) | class SELU(nn.Module):
method forward (line 9) | def forward(self, input):
method selu (line 13) | def selu(x):
FILE: inferno/extensions/layers/convolutional.py
class ConvActivation (line 24) | class ConvActivation(nn.Module):
method __init__ (line 27) | def __init__(
method forward (line 123) | def forward(self, input):
method _pair_or_triplet (line 132) | def _pair_or_triplet(self, object_):
method _get_padding (line 140) | def _get_padding(self, _kernel_size, _dilation):
method get_padding (line 146) | def get_padding(self, kernel_size, dilation):
class _BNReLUSomeConv (line 160) | class _BNReLUSomeConv(object):
method forward (line 161) | def forward(self, input):
class BNReLUConvBaseND (line 167) | class BNReLUConvBaseND(_BNReLUSomeConv, ConvActivation):
method __init__ (line 168) | def __init__(self, in_channels, out_channels, kernel_size, dim, stride...
function _register_conv_cls (line 184) | def _register_conv_cls(conv_name, fix=None, default=None):
function _register_bnr_conv_cls (line 225) | def _register_bnr_conv_cls(conv_name, fix=None, default=None):
class GlobalConv2D (line 264) | class GlobalConv2D(nn.Module):
method __init__ (line 269) | def __init__(
method forward (line 317) | def forward(self, input_):
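`ConvActivation.get_padding` presumably computes 'same'-style padding per spatial dimension. A sketch of the usual arithmetic (assumed formula; exact only for odd effective kernel sizes):

```python
def same_padding(kernel_size, dilation=1):
    """Padding that keeps spatial size under stride 1 (assumed semantics).

    The effective kernel size grows with dilation; pad half of it per side.
    """
    effective = dilation * (kernel_size - 1) + 1
    return effective // 2
```

For example a 3x3 kernel needs padding 1, and the same kernel at dilation 2 needs padding 2.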
FILE: inferno/extensions/layers/convolutional_blocks.py
class ResidualBlock (line 10) | class ResidualBlock(nn.Module):
method __init__ (line 11) | def __init__(self, layers, resample=None):
method forward (line 17) | def forward(self, input):
class PreActSimpleResidualBlock (line 27) | class PreActSimpleResidualBlock(ResidualBlock):
method __init__ (line 28) | def __init__(self, in_channels, num_hidden_channels, upsample=False, d...
FILE: inferno/extensions/layers/device.py
class DeviceTransfer (line 9) | class DeviceTransfer(nn.Module):
method __init__ (line 11) | def __init__(self, target_device, device_ordinal=None, asynchronous=Fa...
method forward (line 34) | def forward(self, *inputs):
class OnDevice (line 46) | class OnDevice(nn.Module):
method __init__ (line 52) | def __init__(self, module, target_device, device_ordinal=None, asynchr...
method transfer_module (line 84) | def transfer_module(self, module):
method forward (line 92) | def forward(self, *inputs):
FILE: inferno/extensions/layers/identity.py
class Identity (line 5) | class Identity(nn.Module):
method __init__ (line 6) | def __init__(self):
method forward (line 9) | def forward(self, x):
FILE: inferno/extensions/layers/normalization.py
class BatchNormND (line 4) | class BatchNormND(nn.Module):
method __init__ (line 5) | def __init__(self, dim, num_features,
method forward (line 13) | def forward(self, x):
FILE: inferno/extensions/layers/reshape.py
class View (line 16) | class View(nn.Module):
method __init__ (line 17) | def __init__(self, as_shape):
method validate_as_shape (line 21) | def validate_as_shape(self, as_shape):
method forward (line 30) | def forward(self, input):
class AsMatrix (line 38) | class AsMatrix(View):
method __init__ (line 39) | def __init__(self):
class Flatten (line 43) | class Flatten(View):
method __init__ (line 44) | def __init__(self):
class As3D (line 48) | class As3D(nn.Module):
method __init__ (line 49) | def __init__(self, channel_as_z=False, num_channels_or_num_z_slices=1):
method forward (line 54) | def forward(self, input):
class As2D (line 82) | class As2D(nn.Module):
method __init__ (line 83) | def __init__(self, z_as_channel=True):
method forward (line 87) | def forward(self, input):
class Concatenate (line 103) | class Concatenate(nn.Module):
method __init__ (line 105) | def __init__(self, dim=1):
method forward (line 109) | def forward(self, *inputs):
class ResizeAndConcatenate (line 113) | class ResizeAndConcatenate(nn.Module):
method __init__ (line 125) | def __init__(self, target_size, pool_mode='average', dim=1):
method forward (line 135) | def forward(self, *inputs):
class Cat (line 164) | class Cat(Concatenate):
class PoolCat (line 169) | class PoolCat(ResizeAndConcatenate):
class GlobalMeanPooling (line 174) | class GlobalMeanPooling(ResizeAndConcatenate):
method __init__ (line 176) | def __init__(self):
class GlobalMaxPooling (line 180) | class GlobalMaxPooling(ResizeAndConcatenate):
method __init__ (line 182) | def __init__(self):
class Sum (line 186) | class Sum(nn.Module):
method forward (line 188) | def forward(self, *inputs):
class SplitChannels (line 192) | class SplitChannels(nn.Module):
method __init__ (line 194) | def __init__(self, channel_index):
method forward (line 198) | def forward(self, input):
class Squeeze (line 212) | class Squeeze(nn.Module):
method __init__ (line 213) | def __init__(self):
method forward (line 215) | def forward(self, x):
class RemoveSingletonDimension (line 218) | class RemoveSingletonDimension(nn.Module):
method __init__ (line 219) | def __init__(self, dim=1):
method forward (line 222) | def forward(self, x):
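The reshape helpers above (`View`, `Flatten`, `AsMatrix`) are thin wrappers around shape bookkeeping. As an assumed illustration, `Flatten` maps `(N, C, *spatial)` to `(N, C * prod(spatial))`:

```python
from functools import reduce
from operator import mul


def flatten_shape(shape):
    """Shape produced by a Flatten-style view (illustrative only)."""
    batch, rest = shape[0], shape[1:]
    return (batch, reduce(mul, rest, 1))
```

So a batch of shape `(4, 3, 8, 8)` flattens to `(4, 192)`, ready for a fully connected layer.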
FILE: inferno/extensions/layers/sampling.py
class Upsample (line 8) | class Upsample(nn.Module):
method __init__ (line 9) | def __init__(self, size=None, scale_factor=None, mode='nearest', align...
method forward (line 24) | def forward(self, input):
class AnisotropicUpsample (line 32) | class AnisotropicUpsample(nn.Module):
method __init__ (line 33) | def __init__(self, scale_factor):
method forward (line 37) | def forward(self, input):
class AnisotropicPool (line 52) | class AnisotropicPool(nn.MaxPool3d):
method __init__ (line 53) | def __init__(self, downscale_factor):
class AnisotropicUpsample2D (line 59) | class AnisotropicUpsample2D(nn.Module):
method __init__ (line 60) | def __init__(self, scale_factor):
method forward (line 64) | def forward(self, input):
class AnisotropicPool2D (line 78) | class AnisotropicPool2D(nn.MaxPool2d):
method __init__ (line 79) | def __init__(self, downscale_factor):
FILE: inferno/extensions/metrics/arand.py
class ArandScore (line 7) | class ArandScore(Metric):
method __init__ (line 14) | def __init__(self, average_slices=True):
method _arand_for_tensor (line 18) | def _arand_for_tensor(self, prediction, target):
method forward (line 49) | def forward(self, prediction, target):
class ArandError (line 64) | class ArandError(ArandScore):
method __init__ (line 66) | def __init__(self, **super_kwargs):
method forward (line 69) | def forward(self, prediction, target):
function adapted_rand (line 75) | def adapted_rand(seg, gt):
FILE: inferno/extensions/metrics/base.py
class Metric (line 3) | class Metric(object):
method forward (line 5) | def forward(self, *args, **kwargs):
method __call__ (line 8) | def __call__(self, prediction, target, **kwargs):
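The `Metric` base class above shows the extension contract: subclasses implement `forward(prediction, target)`, and instances are invoked through `__call__`. A minimal sketch of that pattern with a hypothetical error-rate metric:

```python
class SketchMetric:
    """Mirror of the Metric contract in metrics/base.py (illustrative)."""

    def forward(self, prediction, target):
        raise NotImplementedError

    def __call__(self, prediction, target, **kwargs):
        return self.forward(prediction, target, **kwargs)


class SketchError(SketchMetric):
    """Fraction of mismatched entries (hypothetical example metric)."""

    def forward(self, prediction, target):
        wrong = sum(p != t for p, t in zip(prediction, target))
        return wrong / len(target)
```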
FILE: inferno/extensions/metrics/categorical.py
class CategoricalError (line 7) | class CategoricalError(Metric):
method __init__ (line 9) | def __init__(self, aggregation_mode='mean'):
method forward (line 13) | def forward(self, prediction, target):
class IOU (line 42) | class IOU(Metric):
method __init__ (line 44) | def __init__(self, ignore_class=None, sharpen_prediction=False, eps=1e...
method forward (line 50) | def forward(self, prediction, target):
class NegativeIOU (line 133) | class NegativeIOU(IOU):
method forward (line 134) | def forward(self, prediction, target):
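The `IOU` metric presumably reduces to intersection-over-union per class; for the binary case the formula can be sketched without torch (the real class additionally supports `ignore_class` and prediction sharpening):

```python
def iou(prediction, target, eps=1e-6):
    """Binary intersection-over-union on flat 0/1 sequences (illustrative).

    `eps` avoids division by zero when both masks are empty, matching the
    eps default visible in the IOU signature above.
    """
    intersection = sum(p and t for p, t in zip(prediction, target))
    union = sum(p or t for p, t in zip(prediction, target))
    return intersection / (union + eps)
```

`NegativeIOU` would then just negate this value so it can be minimized as a loss.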
FILE: inferno/extensions/metrics/cremi_score.py
function cremi_metrics (line 9) | def cremi_metrics(seg, gt, no_seg_ignore=True):
FILE: inferno/extensions/metrics/voi.py
class VoiScore (line 7) | class VoiScore(Metric):
method forward (line 15) | def forward(self, prediction, target):
function voi (line 29) | def voi(seg, gt, ignore_reconstruction=[], ignore_groundtruth=[0]):
function split_vi (line 65) | def split_vi(x, y=None, ignore_x=[0], ignore_y=[0]):
function vi_tables (line 102) | def vi_tables(x, y=None, ignore_x=[0], ignore_y=[0]):
function contingency_table (line 156) | def contingency_table(seg, gt, ignore_seg=[0], ignore_gt=[0], norm=True):
function divide_columns (line 196) | def divide_columns(matrix, row, in_place=False):
function divide_rows (line 235) | def divide_rows(matrix, column, in_place=False):
function xlogx (line 274) | def xlogx(x, out=None, in_place=False):
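The `xlogx` helper at the end of voi.py supports the entropy terms used in variation-of-information. A sketch of the scalar case, assuming the standard `0 * log 0 := 0` convention (base 2 assumed here; the array version in the file likely vectorizes this):

```python
import math


def xlogx(x):
    """x * log2(x) with the 0 log 0 := 0 convention (illustrative scalar form)."""
    return 0.0 if x == 0 else x * math.log2(x)
```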
FILE: inferno/extensions/models/res_unet.py
class _ResBlockBase (line 15) | class _ResBlockBase(nn.Module):
method __init__ (line 16) | def __init__(self, in_channels, out_channels, dim,
method activated_skip_op_factory (line 50) | def activated_skip_op_factory(self, in_channels, out_channels):
method nonactivated_conv_op_factory (line 53) | def nonactivated_conv_op_factory(self, in_channels, out_channels, index):
method activation_op_factory (line 56) | def activation_op_factory(self, index):
method forward (line 59) | def forward(self, input):
class _ResBlock (line 90) | class _ResBlock(_ResBlockBase):
method __init__ (line 91) | def __init__(self, in_channels, out_channels, dim, size=2, activated=T...
method activated_skip_op_factory (line 130) | def activated_skip_op_factory(self, in_channels, out_channels):
method nonactivated_conv_op_factory (line 139) | def nonactivated_conv_op_factory(self, in_channels, out_channels, index):
method activation_op_factory (line 148) | def activation_op_factory(self, index):
method batchnorm_op_factory (line 151) | def batchnorm_op_factory(self, in_channels):
class ResBlockUNet (line 159) | class ResBlockUNet(UNetBase):
method __init__ (line 171) | def __init__(self, in_channels, dim, out_channels, unet_kwargs=None,
method conv_op_factory (line 192) | def conv_op_factory(self, in_channels, out_channels, part, index):
FILE: inferno/extensions/models/unet.py
class UNetBase (line 13) | class UNetBase(nn.Module):
method __init__ (line 38) | def __init__(self, in_channels, dim, out_channels=None, depth=3,
method _get_num_channels (line 122) | def _get_num_channels(self, depth):
method _init__downstream (line 126) | def _init__downstream(self):
method _init__bottom (line 152) | def _init__bottom(self):
method _init__upstream (line 166) | def _init__upstream(self):
method _make_upsample_kwargs (line 201) | def _make_upsample_kwargs(self, upsample_mode):
method _forward_sanity_check (line 220) | def _forward_sanity_check(self, input):
method _check_scaling (line 235) | def _check_scaling(self, input):
method forward (line 242) | def forward(self, input):
method downsample_op_factory (line 308) | def downsample_op_factory(self, index):
method upsample_op_factory (line 312) | def upsample_op_factory(self, index):
method conv_op_factory (line 316) | def conv_op_factory(self, in_channels, out_channels, part, index):
method _dropout (line 319) | def _dropout(self, x):
class UNet (line 327) | class UNet(UNetBase):
method __init__ (line 332) | def __init__(self, in_channels, out_channels, dim,
method forward (line 364) | def forward(self, input):
method conv_op_factory (line 370) | def conv_op_factory(self, in_channels, out_channels, part, index):
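`UNetBase._get_num_channels` presumably derives the feature-channel count per depth level. A common schedule, sketched under the assumption of a doubling gain (the actual rule in `unet.py` may differ):

```python
def channels_per_depth(initial_channels, depth, gain=2):
    """Hypothetical UNet channel schedule: channels scale by `gain` per level."""
    return [initial_channels * gain ** d for d in range(depth)]
```

With `initial_channels=16` and `depth=3` this yields `[16, 32, 64]` for the downstream path, mirrored on the upstream side.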
FILE: inferno/extensions/optimizers/adam.py
class Adam (line 5) | class Adam(Optimizer):
method __init__ (line 24) | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
method step (line 31) | def step(self, closure=None):
FILE: inferno/extensions/optimizers/annealed_adam.py
class AnnealedAdam (line 4) | class AnnealedAdam(Adam):
method __init__ (line 26) | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
method step (line 33) | def step(self, closure=None):
FILE: inferno/io/box/binary_blobs.py
class BinaryBlobs (line 7) | class BinaryBlobs(data.Dataset):
method __init__ (line 10) | def __init__(self, size=20, length=512, blob_size_fraction=0.1,
method __getitem__ (line 44) | def __getitem__(self, index):
method __len__ (line 108) | def __len__(self):
function get_binary_blob_loaders (line 112) | def get_binary_blob_loaders(train_batch_size=1, test_batch_size=1,
FILE: inferno/io/box/camvid.py
function is_image_file (line 22) | def is_image_file(filename):
function make_dataset (line 74) | def make_dataset(dir):
function label_to_long_tensor (line 85) | def label_to_long_tensor(pic):
function label_to_pil_image (line 92) | def label_to_pil_image(label):
class CamVid (line 109) | class CamVid(data.Dataset):
method __init__ (line 123) | def __init__(self, root, split='train',
method __getitem__ (line 145) | def __getitem__(self, index):
method __len__ (line 158) | def __len__(self):
method download (line 161) | def download(self):
function get_camvid_loaders (line 168) | def get_camvid_loaders(root_directory, image_shape=(360, 480), labels_as...
FILE: inferno/io/box/cifar.py
function get_cifar10_loaders (line 8) | def get_cifar10_loaders(root_directory, train_batch_size=128, test_batch...
function get_cifar100_loaders (line 64) | def get_cifar100_loaders(root_directory, train_batch_size=128, test_batc...
FILE: inferno/io/box/cityscapes.py
function get_matching_labelimage_file (line 173) | def get_matching_labelimage_file(f, groundtruth):
function get_filelist (line 180) | def get_filelist(path):
function make_dataset (line 190) | def make_dataset(path, split):
function extract_image (line 211) | def extract_image(path, image_path):
class Cityscapes (line 219) | class Cityscapes(data.Dataset):
method __init__ (line 237) | def __init__(self, root_folder, split='train', read_from_zip_archive=T...
method __getitem__ (line 263) | def __getitem__(self, index):
method __len__ (line 284) | def __len__(self):
method download (line 287) | def download(self):
method get_image_and_label_roots (line 292) | def get_image_and_label_roots(self):
function make_transforms (line 312) | def make_transforms(image_shape, labels_as_onehot):
function get_cityscapes_loaders (line 351) | def get_cityscapes_loaders(root_directory, image_shape=(1024, 2048), lab...
FILE: inferno/io/core/base.py
class SyncableDataset (line 4) | class SyncableDataset(Dataset):
method __init__ (line 5) | def __init__(self, base_sequence=None):
method sync_with (line 8) | def sync_with(self, dataset):
method __len__ (line 13) | def __len__(self):
class IndexSpec (line 22) | class IndexSpec(object):
method __init__ (line 28) | def __init__(self, index=None, base_sequence_at_index=None):
method __int__ (line 32) | def __int__(self):
FILE: inferno/io/core/concatenate.py
class Concatenate (line 6) | class Concatenate(Dataset):
method __init__ (line 11) | def __init__(self, *datasets, transforms=None):
method map_index (line 18) | def map_index(self, index):
method __getitem__ (line 40) | def __getitem__(self, index):
method __len__ (line 51) | def __len__(self):
method __repr__ (line 54) | def __repr__(self):
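`Concatenate.map_index` presumably translates a global sample index into a `(dataset_index, local_index)` pair over the concatenated datasets. The mapping can be sketched with plain lengths (assumed semantics):

```python
def map_index(index, lengths):
    """Map a global index to (dataset_index, local_index) given dataset lengths."""
    for dataset_index, length in enumerate(lengths):
        if index < length:
            return dataset_index, index
        # Skip past this dataset and keep searching.
        index -= length
    raise IndexError("index out of range for concatenated datasets")
```

For two datasets of lengths 3 and 2, global index 4 maps to the second item of the second dataset.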
FILE: inferno/io/core/data_utils.py
function implements_sync_primitives (line 2) | def implements_sync_primitives(dataset):
function defines_base_sequence (line 6) | def defines_base_sequence(dataset):
FILE: inferno/io/core/zip.py
class Zip (line 11) | class Zip(SyncableDataset):
method __init__ (line 17) | def __init__(self, *datasets, sync=False, transforms=None):
method sync_datasets (line 40) | def sync_datasets(self):
method sync_with (line 46) | def sync_with(self, dataset):
method __getitem__ (line 53) | def __getitem__(self, index):
method __len__ (line 63) | def __len__(self):
method __repr__ (line 69) | def __repr__(self):
class ZipReject (line 79) | class ZipReject(Zip):
method __init__ (line 85) | def __init__(self, *datasets, sync=False, transforms=None,
method remove_rejected (line 131) | def remove_rejected(self):
method __len__ (line 140) | def __len__(self):
method next_index_to_try (line 147) | def next_index_to_try(self, index):
method fetch_from_rejection_datasets (line 153) | def fetch_from_rejection_datasets(self, index):
method __getitem__ (line 158) | def __getitem__(self, index):
FILE: inferno/io/transform/base.py
class Transform (line 5) | class Transform(object):
method __init__ (line 20) | def __init__(self, apply_to=None):
method build_random_variables (line 31) | def build_random_variables(self, **kwargs):
method clear_random_variables (line 34) | def clear_random_variables(self):
method get_random_variable (line 37) | def get_random_variable(self, key, default=None, build=True,
method set_random_variable (line 48) | def set_random_variable(self, key, value):
method __call__ (line 51) | def __call__(self, *tensors, **transform_function_kwargs):
method _apply_tensor_function (line 81) | def _apply_tensor_function(self, tensor, **transform_function_kwargs):
method _apply_image_function (line 87) | def _apply_image_function(self, tensor, **transform_function_kwargs):
method _apply_volume_function (line 115) | def _apply_volume_function(self, tensor, **transform_function_kwargs):
class Compose (line 142) | class Compose(object):
method __init__ (line 144) | def __init__(self, *transforms):
method add (line 154) | def add(self, transform):
method remove (line 159) | def remove(self, name):
method __call__ (line 169) | def __call__(self, *tensors):
class DTypeMapping (line 176) | class DTypeMapping(object):
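`Compose` above chains `Transform` instances into a single callable pipeline. A torch-free sketch of that pattern (order assumed: first added, first applied):

```python
class SketchCompose:
    """Illustrative stand-in for transform.base.Compose."""

    def __init__(self, *transforms):
        self.transforms = list(transforms)

    def add(self, transform):
        # Append a transform and return self so calls can be chained.
        self.transforms.append(transform)
        return self

    def __call__(self, tensor):
        # Apply every transform in registration order.
        for transform in self.transforms:
            tensor = transform(tensor)
        return tensor
```

For example, `SketchCompose(lambda x: x + 1, lambda x: x * 2)(3)` first increments, then doubles.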
FILE: inferno/io/transform/generic.py
class Normalize (line 7) | class Normalize(Transform):
method __init__ (line 9) | def __init__(self, eps=1e-4, mean=None, std=None, ignore_value=None, *...
method tensor_function (line 28) | def tensor_function(self, tensor):
class NormalizeRange (line 50) | class NormalizeRange(Transform):
method __init__ (line 52) | def __init__(self, normalize_by=255., **super_kwargs):
method tensor_function (line 64) | def tensor_function(self, tensor):
class Project (line 68) | class Project(Transform):
method __init__ (line 74) | def __init__(self, projection, **super_kwargs):
method tensor_function (line 86) | def tensor_function(self, tensor):
class Label2OneHot (line 93) | class Label2OneHot(Transform, DTypeMapping):
method __init__ (line 95) | def __init__(self, num_classes, dtype='float', **super_kwargs):
method tensor_function (line 110) | def tensor_function(self, tensor):
class Cast (line 121) | class Cast(Transform, DTypeMapping):
method __init__ (line 123) | def __init__(self, dtype='float', **super_kwargs):
method tensor_function (line 136) | def tensor_function(self, tensor):
class AsTorchBatch (line 140) | class AsTorchBatch(Transform):
method __init__ (line 148) | def __init__(self, dimensionality, add_channel_axis_if_necessary=True,...
method _to_batch (line 167) | def _to_batch(self, tensor):
method tensor_function (line 196) | def tensor_function(self, tensor):
FILE: inferno/io/transform/image.py
class PILImage2NumPyArray (line 13) | class PILImage2NumPyArray(Transform):
method tensor_function (line 19) | def tensor_function(self, tensor):
class Scale (line 33) | class Scale(Transform):
method __init__ (line 44) | def __init__(self, output_image_shape, interpolation_order=3, zoom_kwa...
method image_function (line 69) | def image_function(self, image):
class RandomCrop (line 87) | class RandomCrop(Transform):
method __init__ (line 99) | def __init__(self, output_image_shape, **super_kwargs):
method clear_random_variables (line 122) | def clear_random_variables(self):
method build_random_variables (line 126) | def build_random_variables(self, height_leeway, width_leeway):
method image_function (line 134) | def image_function(self, image):
class RandomSizedCrop (line 168) | class RandomSizedCrop(Transform):
method __init__ (line 175) | def __init__(self, ratio_between=None, height_ratio_between=None, widt...
method build_random_variables (line 222) | def build_random_variables(self, image_shape):
method image_function (line 252) | def image_function(self, image):
class RandomGammaCorrection (line 282) | class RandomGammaCorrection(Transform):
method __init__ (line 291) | def __init__(self, gamma_between=(0.5, 2.), gain=1, **super_kwargs):
method build_random_variables (line 306) | def build_random_variables(self):
method image_function (line 312) | def image_function(self, image):
class ElasticTransform (line 319) | class ElasticTransform(Transform):
method __init__ (line 324) | def __init__(self, alpha, sigma, order=1, invert=False, **super_kwargs):
method build_random_variables (line 332) | def build_random_variables(self, **kwargs):
method cast (line 352) | def cast(self, image):
method uncast (line 358) | def uncast(self, image):
method image_function (line 364) | def image_function(self, image):
class AdditiveGaussianNoise (line 380) | class AdditiveGaussianNoise(Transform):
method __init__ (line 382) | def __init__(self, sigma, **super_kwargs):
method build_random_variables (line 386) | def build_random_variables(self, **kwargs):
method image_function (line 391) | def image_function(self, image):
class RandomRotate (line 396) | class RandomRotate(Transform):
method __init__ (line 398) | def __init__(self, **super_kwargs):
method build_random_variables (line 401) | def build_random_variables(self, **kwargs):
method image_function (line 405) | def image_function(self, image):
class RandomTranspose (line 409) | class RandomTranspose(Transform):
method __init__ (line 411) | def __init__(self, **super_kwargs):
method build_random_variables (line 414) | def build_random_variables(self, **kwargs):
method image_function (line 418) | def image_function(self, image):
class RandomFlip (line 424) | class RandomFlip(Transform):
method __init__ (line 426) | def __init__(self, allow_lr_flips=True, allow_ud_flips=True, **super_k...
method build_random_variables (line 431) | def build_random_variables(self, **kwargs):
method image_function (line 436) | def image_function(self, image):
class CenterCrop (line 444) | class CenterCrop(Transform):
method __init__ (line 446) | def __init__(self, size, **super_kwargs):
method image_function (line 451) | def image_function(self, image):
class BinaryMorphology (line 463) | class BinaryMorphology(Transform):
method __init__ (line 468) | def __init__(self, mode, num_iterations=1, morphology_kwargs=None, **s...
method image_function (line 492) | def image_function(self, image):
class BinaryDilation (line 505) | class BinaryDilation(BinaryMorphology):
method __init__ (line 507) | def __init__(self, num_iterations=1, morphology_kwargs=None, **super_k...
class BinaryErosion (line 513) | class BinaryErosion(BinaryMorphology):
method __init__ (line 515) | def __init__(self, num_iterations=1, morphology_kwargs=None, **super_k...
class FineRandomRotations (line 521) | class FineRandomRotations(Transform):
method __init__ (line 533) | def __init__(self, angle_range, axes=(1,2), mask_label=0, **super_kwar...
method build_random_variables (line 539) | def build_random_variables(self):
method batch_function (line 545) | def batch_function(self, image):
class RandomScaleSegmentation (line 551) | class RandomScaleSegmentation(Transform):
method __init__ (line 561) | def __init__(self, scale_range, resize=True, pad_const=0, **super_kwar...
method build_random_variables (line 567) | def build_random_variables(self):
method batch_function (line 573) | def batch_function(self, image):
FILE: inferno/io/transform/volume.py
class RandomFlip3D (line 9) | class RandomFlip3D(Transform):
method __init__ (line 10) | def __init__(self, **super_kwargs):
method build_random_variables (line 13) | def build_random_variables(self, **kwargs):
method volume_function (line 19) | def volume_function(self, volume):
class RandomRot3D (line 29) | class RandomRot3D(Transform):
method __init__ (line 30) | def __init__(self, rot_range, p=0.125, reshape=False, order=0, mode='n...
method build_random_variables (line 38) | def build_random_variables(self, **kwargs):
method volume_function (line 49) | def volume_function(self, volume):
class AdditiveRandomNoise3D (line 73) | class AdditiveRandomNoise3D(Transform):
method __init__ (line 83) | def __init__(self, shape, std, **super_kwargs):
method build_random_variables (line 88) | def build_random_variables(self, **kwargs):
method volume_function (line 93) | def volume_function(self, volume):
class AdditiveNoise (line 99) | class AdditiveNoise(Transform):
method __init__ (line 109) | def __init__(self, sigma, mode='gaussian', **super_kwargs):
method tensor_function (line 115) | def tensor_function(self, volume):
class CentralSlice (line 120) | class CentralSlice(Transform):
method volume_function (line 121) | def volume_function(self, volume):
class VolumeCenterCrop (line 126) | class VolumeCenterCrop(Transform):
method __init__ (line 128) | def __init__(self, size, **super_kwargs):
method volume_function (line 134) | def volume_function(self, volume):
class VolumeAsymmetricCrop (line 143) | class VolumeAsymmetricCrop(Transform):
method __init__ (line 145) | def __init__(self, crop_left, crop_right, **super_kwargs):
method volume_function (line 154) | def volume_function(self, volume):
class Slices2Channels (line 160) | class Slices2Channels(Transform):
method __init__ (line 164) | def __init__(self, num_channels, downsampling=1, **super_kwargs):
method batch_function (line 169) | def batch_function(self, batch):
class RandomScale3D (line 187) | class RandomScale3D(Transform):
method __init__ (line 189) | def __init__(self, zoom_factor_range, interpolation_order=0, p=0.5,
method build_random_variables (line 218) | def build_random_variables(self, **kwargs):
method volume_function (line 227) | def volume_function(self, volume):
class RandomBinaryMorphology3D (line 243) | class RandomBinaryMorphology3D(Transform):
method __init__ (line 248) | def __init__(self, p=0.5, num_iter_range=(1, 5), morphology_kwargs=Non...
method build_random_variables (line 272) | def build_random_variables(self, **kwargs):
method volume_function (line 278) | def volume_function(self, volume):
class CropPad2Divisible (line 295) | class CropPad2Divisible(Transform):
method __init__ (line 302) | def __init__(self, divisor=16, crop_pad_threshold=0.2,
method volume_function (line 330) | def volume_function(self, volume):
class CropPad2Size (line 349) | class CropPad2Size(Transform):
method __init__ (line 354) | def __init__(self, output_size, mode='constant',
method volume_function (line 375) | def volume_function(self, volume):
FILE: inferno/io/volumetric/lazy_volume_loader.py
function filter_base_sequence (line 26) | def filter_base_sequence(input_path, input_key,
class LazyVolumeLoaderBase (line 54) | class LazyVolumeLoaderBase(SyncableDataset):
method __init__ (line 55) | def __init__(self, dataset, window_size, stride, downsampling_ratio=No...
method load_base_sequence (line 98) | def load_base_sequence(base_sequence):
method normalize_slice (line 109) | def normalize_slice(self, data_slice):
method get_shape (line 120) | def get_shape(self):
method make_sliding_windows (line 131) | def make_sliding_windows(self):
method __getitem__ (line 139) | def __getitem__(self, index):
method clone (line 200) | def clone(self, dataset=None, transforms=None, name=None):
method __repr__ (line 216) | def __repr__(self):
class LazyVolumeLoader (line 221) | class LazyVolumeLoader(LazyVolumeLoaderBase):
method __init__ (line 222) | def __init__(self, file_impl, path,
method validate_data_slice (line 269) | def validate_data_slice(self, data_slice):
class LazyHDF5VolumeLoader (line 274) | class LazyHDF5VolumeLoader(LazyVolumeLoader):
method __init__ (line 275) | def __init__(self, path, path_in_h5_dataset=None, data_slice=None, tra...
method __del__ (line 284) | def __del__(self):
class LazyN5VolumeLoader (line 288) | class LazyN5VolumeLoader(LazyVolumeLoader):
method __init__ (line 289) | def __init__(self, path, path_in_file=None, data_slice=None, transform...
class LazyZarrVolumeLoader (line 301) | class LazyZarrVolumeLoader(LazyVolumeLoader):
method __init__ (line 302) | def __init__(self, path, path_in_file=None, data_slice=None, transform...
FILE: inferno/io/volumetric/volume.py
class VolumeLoader (line 13) | class VolumeLoader(SyncableDataset):
method __init__ (line 40) | def __init__(self, volume, window_size, stride, downsampling_ratio=Non...
method pad_volume (line 94) | def pad_volume(self, padding=None):
method make_sliding_windows (line 108) | def make_sliding_windows(self):
method __getitem__ (line 117) | def __getitem__(self, index):
method clone (line 133) | def clone(self, volume=None, transforms=None, name=None):
method __repr__ (line 149) | def __repr__(self):
class HDF5VolumeLoader (line 153) | class HDF5VolumeLoader(VolumeLoader):
method is_h5 (line 182) | def is_h5(file_path):
method __init__ (line 191) | def __init__(self, path, path_in_h5_dataset=None, data_slice=None, tra...
class TIFVolumeLoader (line 246) | class TIFVolumeLoader(VolumeLoader):
method __init__ (line 248) | def __init__(self, path, data_slice=None, transforms=None, name=None, ...
FILE: inferno/io/volumetric/volumetric_utils.py
function slidingwindowslices (line 5) | def slidingwindowslices(shape, window_size, strides,
function slidingwindowslices_depr (line 63) | def slidingwindowslices_depr(shape, nhoodsize, stride=1, ds=1, window=No...
function parse_data_slice (line 137) | def parse_data_slice(data_slice):
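`slidingwindowslices` generates the window slices that the volume loaders iterate over. A 1D illustration of the idea, under the assumption that a final (possibly overlapping) window is added so the trailing voxels are always covered (the real function handles full N-D shapes and strides):

```python
def sliding_windows_1d(length, window_size, stride):
    """Yield (start, stop) windows covering [0, length) (illustrative)."""
    starts = list(range(0, max(length - window_size, 0) + 1, stride))
    # Clamp a last window to the end if the stride left a remainder uncovered.
    if starts[-1] + window_size < length:
        starts.append(length - window_size)
    return [(start, start + window_size) for start in starts]
```

With `length=10`, `window_size=4`, `stride=4` the remainder forces an overlapping final window `(6, 10)`.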
FILE: inferno/trainers/basic.py
class Trainer (line 39) | class Trainer(object):
method __init__ (line 56) | def __init__(self, model=None):
method mixed_precision (line 148) | def mixed_precision(self):
method mixed_precision (line 153) | def mixed_precision(self, mp):
method apex_opt_level (line 167) | def apex_opt_level(self):
method apex_opt_level (line 171) | def apex_opt_level(self, opt_level):
method console (line 177) | def console(self):
method set_console (line 181) | def set_console(self, console):
method quiet (line 186) | def quiet(self):
method callbacks (line 191) | def callbacks(self):
method register_callback (line 195) | def register_callback(self, callback, trigger='auto', **callback_kwargs):
method model (line 221) | def model(self):
method model (line 227) | def model(self, value):
method bind_model (line 230) | def bind_model(self, model):
method model_is_defined (line 254) | def model_is_defined(self):
method retain_graph (line 258) | def retain_graph(self):
method retain_graph (line 262) | def retain_graph(self, value):
method backprop_every (line 267) | def backprop_every(self):
method backprop_every (line 271) | def backprop_every(self, value):
method set_backprop_every (line 274) | def set_backprop_every(self, num_steps):
method optimizer (line 293) | def optimizer(self):
method optimizer (line 299) | def optimizer(self, value):
method optimizer_is_defined (line 308) | def optimizer_is_defined(self):
method build_optimizer (line 311) | def build_optimizer(self, method, param_groups=None, **kwargs):
method criterion (line 357) | def criterion(self):
method criterion (line 363) | def criterion(self, value):
method build_criterion (line 372) | def build_criterion(self, method, **kwargs):
method criterion_is_defined (line 426) | def criterion_is_defined(self):
method validation_criterion (line 430) | def validation_criterion(self):
method validation_criterion (line 437) | def validation_criterion(self, value):
method build_validation_criterion (line 446) | def build_validation_criterion(self, method, **kwargs):
method validation_criterion_is_train_criterion (line 499) | def validation_criterion_is_train_criterion(self, yes=True):
method validation_criterion_is_defined (line 506) | def validation_criterion_is_defined(self):
method metric (line 510) | def metric(self):
method metric (line 516) | def metric(self, value):
method evaluating_metric_every (line 523) | def evaluating_metric_every(self):
method evaluate_metric_every (line 526) | def evaluate_metric_every(self, frequency):
method evaluate_metric_now (line 547) | def evaluate_metric_now(self):
method evaluate_metric_now (line 568) | def evaluate_metric_now(self, value):
method build_metric (line 571) | def build_metric(self, method, **kwargs):
method metric_is_defined (line 608) | def metric_is_defined(self):
method eval_mode (line 612) | def eval_mode(self):
method train_mode (line 622) | def train_mode(self):
method train_loader (line 633) | def train_loader(self):
method train_loader (line 638) | def train_loader(self, value):
method validate_loader (line 643) | def validate_loader(self):
method validate_loader (line 648) | def validate_loader(self, value):
method logger (line 653) | def logger(self):
method logger (line 658) | def logger(self, value):
method log_directory (line 665) | def log_directory(self):
method log_directory (line 670) | def log_directory(self, value):
method pickle_module (line 676) | def pickle_module(self):
method pickle_module (line 684) | def pickle_module(self, value):
method saving_every (line 692) | def saving_every(self):
method save_at_best_validation_score (line 696) | def save_at_best_validation_score(self, yes=True):
method save_now (line 702) | def save_now(self):
method save_now (line 725) | def save_now(self, value):
method save_every (line 729) | def save_every(self, frequency, to_directory=None,
method save_directory (line 755) | def save_directory(self):
method save_to_directory (line 758) | def save_to_directory(self, to_directory=None, checkpoint_filename=None,
method validating_every (line 776) | def validating_every(self):
method validate_now (line 780) | def validate_now(self):
method validate_now (line 805) | def validate_now(self, value):
method validate_every (line 808) | def validate_every(self, frequency, for_num_iterations=None):
method iteration_count (line 833) | def iteration_count(self):
method epoch_count (line 837) | def epoch_count(self):
method target_batch_dim (line 841) | def target_batch_dim(self):
method target_batch_dim (line 845) | def target_batch_dim(self, value):
method set_target_batch_dim (line 851) | def set_target_batch_dim(self, value):
method build_logger (line 855) | def build_logger(self, logger=None, log_directory=None, **kwargs):
method set_log_directory (line 893) | def set_log_directory(self, log_directory):
method update_state (line 917) | def update_state(self, key, value):
method update_state_from_dictionary (line 923) | def update_state_from_dictionary(self, dictionary):
method update_state_from_model_state_hooks (line 929) | def update_state_from_model_state_hooks(self):
method get_state (line 935) | def get_state(self, key, default=None):
method current_learning_rate (line 942) | def current_learning_rate(self):
method get_current_learning_rate (line 945) | def get_current_learning_rate(self):
method to (line 960) | def to(self, device):
method cuda (line 976) | def cuda(self, devices=None, base_device=None):
method cpu (line 1019) | def cpu(self):
method is_cuda (line 1036) | def is_cuda(self):
method to_device (line 1040) | def to_device(self, objects):
method apply_model (line 1046) | def apply_model(self, *inputs):
method cast (line 1058) | def cast(self, objects):
method set_precision (line 1073) | def set_precision(self, dtype):
method dtype (line 1093) | def dtype(self):
method dtype (line 1097) | def dtype(self, value):
method bind_loader (line 1100) | def bind_loader(self, name, loader, num_inputs=None, num_targets=1):
method get_loader_specs (line 1151) | def get_loader_specs(self, name):
method fetch_next_batch (line 1157) | def fetch_next_batch(self, from_loader='train', restart_exhausted_gene...
method verify_batch (line 1183) | def verify_batch(self, batch, from_loader):
method split_batch (line 1201) | def split_batch(self, batch, from_loader):
method restart_generators (line 1219) | def restart_generators(self, of_loader=None):
method wrap_batch (line 1231) | def wrap_batch(self, batch, from_loader=None, requires_grad=False):
method next_iteration (line 1277) | def next_iteration(self):
method next_epoch (line 1280) | def next_epoch(self):
method stop_fitting (line 1294) | def stop_fitting(self, max_num_iterations=None, max_num_epochs=None):
method set_max_num_iterations (line 1313) | def set_max_num_iterations(self, max_num_iterations):
method set_max_num_epochs (line 1338) | def set_max_num_epochs(self, max_num_epochs):
method fit (line 1361) | def fit(self, max_num_iterations=None, max_num_epochs=None):
method apply_model_and_loss (line 1426) | def apply_model_and_loss(self, inputs, target, backward=True, mode=None):
method train_for (line 1457) | def train_for(self, num_iterations=None, break_callback=None):
method validate_for (line 1535) | def validate_for(self, num_iterations=None, loader_name='validate'):
method record_validation_results (line 1657) | def record_validation_results(self, validation_loss, validation_error):
method get_config (line 1671) | def get_config(self, exclude_loader=True):
method set_config (line 1683) | def set_config(self, config_dict):
method save (line 1692) | def save(self, exclude_loader=True, stash_best_checkpoint=True):
method save_model (line 1731) | def save_model(self, to_directory=None):
method load (line 1739) | def load(self, from_directory=None, best=False, filename=None, map_loc...
method load_model (line 1776) | def load_model(self, from_directory=None, filename=None):
method load_ (line 1786) | def load_(self, *args, **kwargs):
method print (line 1791) | def print(self, message):
method build (line 1795) | def build(cls, model=None, **trainer_config):
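The `Trainer` methods above (`build_optimizer`, `build_criterion`, `set_max_num_epochs`, ...) follow a fluent-builder idiom: each configuration call returns the trainer itself so calls can be chained. A toy stand-in (the class `TinyTrainer` and its config dict are illustrative, not the real implementation) shows the pattern:

```python
class TinyTrainer:
    """Toy sketch of the fluent-builder idiom the Trainer API suggests:
    every configuration method returns self, enabling method chaining."""

    def __init__(self, model=None):
        self.model = model
        self.config = {}

    def build_optimizer(self, method, **kwargs):
        # store the optimizer spec; the real Trainer instantiates it
        self.config['optimizer'] = (method, kwargs)
        return self

    def set_max_num_epochs(self, max_num_epochs):
        self.config['max_num_epochs'] = max_num_epochs
        return self

# chained configuration, mirroring the listed Trainer methods
trainer = TinyTrainer('model').build_optimizer('Adam', lr=1e-3).set_max_num_epochs(10)
```

Returning `self` from every setter is what lets a full training setup read as one expression ending in `.fit()`.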
FILE: inferno/trainers/callbacks/base.py
class CallbackEngine (line 4) | class CallbackEngine(object):
method __init__ (line 46) | def __init__(self):
method register_new_trigger (line 52) | def register_new_trigger(self, trigger_name):
method bind_trainer (line 56) | def bind_trainer(self, trainer):
method unbind_trainer (line 60) | def unbind_trainer(self):
method trainer_is_bound (line 65) | def trainer_is_bound(self):
method register_callback (line 68) | def register_callback(self, callback, trigger='auto', bind_trainer=True):
method rebind_trainer_to_all_callbacks (line 92) | def rebind_trainer_to_all_callbacks(self):
method call (line 103) | def call(self, trigger, **kwargs):
method get_config (line 109) | def get_config(self):
method set_config (line 115) | def set_config(self, config_dict):
method __getstate__ (line 119) | def __getstate__(self):
method __setstate__ (line 122) | def __setstate__(self, state):
class Callback (line 126) | class Callback(object):
method __init__ (line 128) | def __init__(self):
method register_instance (line 134) | def register_instance(cls, instance):
method get_instances (line 141) | def get_instances(cls):
method trainer (line 148) | def trainer(self):
method bind_trainer (line 151) | def bind_trainer(self, trainer):
method unbind_trainer (line 155) | def unbind_trainer(self):
method __call__ (line 159) | def __call__(self, **kwargs):
method get_config (line 165) | def get_config(self):
method set_config (line 170) | def set_config(self, config_dict):
method __getstate__ (line 174) | def __getstate__(self):
method __setstate__ (line 177) | def __setstate__(self, state):
method toggle_debug (line 180) | def toggle_debug(self):
method debug_print (line 184) | def debug_print(self, message):
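`CallbackEngine.call(trigger, **kwargs)` together with `Callback.__call__` suggests trigger-name dispatch: the engine invokes the method matching the trigger name on each registered callback. A toy sketch of that mechanism (`ToyEngine` / `ToyCallback` are illustrative names, not the real classes):

```python
class ToyCallback:
    """Stand-in for a Callback subclass: defines a method named after a trigger."""

    def end_of_training_iteration(self, **kwargs):
        self.seen = kwargs  # record what the engine passed in


class ToyEngine:
    """Sketch of CallbackEngine.call: look up the trigger-named method on each
    callback registered for that trigger and invoke it with trainer state."""

    def __init__(self):
        self._registry = {}  # trigger name -> list of callbacks

    def register_callback(self, callback, trigger):
        self._registry.setdefault(trigger, []).append(callback)
        return self

    def call(self, trigger, **kwargs):
        for callback in self._registry.get(trigger, []):
            getattr(callback, trigger)(**kwargs)


engine = ToyEngine()
cb = ToyCallback()
engine.register_callback(cb, 'end_of_training_iteration')
engine.call('end_of_training_iteration', iteration_count=7)
```

This is why user callbacks only need to define methods with well-known names (`begin_of_fit`, `end_of_epoch`, ...) and accept `**kwargs`.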
FILE: inferno/trainers/callbacks/console.py
class StdoutPrinter (line 4) | class StdoutPrinter(object):
method print (line 5) | def print(self, message):
class Console (line 9) | class Console(object):
method __init__ (line 15) | def __init__(self, printer=StdoutPrinter()):
method set_console (line 19) | def set_console(self, console):
method _print (line 22) | def _print(self, message, level):
method info (line 28) | def info(self, message):
method print (line 31) | def print(self, message):
method progress (line 34) | def progress(self, message):
method warning (line 37) | def warning(self, message):
method debug (line 40) | def debug(self, message):
method _toggle (line 43) | def _toggle(self, level, state):
method toggle_info (line 50) | def toggle_info(self, state):
method toggle_progress (line 53) | def toggle_progress(self, state):
method toggle_warning (line 56) | def toggle_warning(self, state):
class ShowMinimalConsoleInfo (line 61) | class ShowMinimalConsoleInfo(Callback):
method __init__ (line 67) | def __init__(self, *args, **kwargs):
method begin_of_fit (line 70) | def begin_of_fit(self,**_):
method end_of_epoch (line 73) | def end_of_epoch(self, **_):
FILE: inferno/trainers/callbacks/essentials.py
class NaNDetector (line 12) | class NaNDetector(Callback):
method end_of_training_iteration (line 13) | def end_of_training_iteration(self, **_):
class PersistentSave (line 22) | class PersistentSave(Callback):
method __init__ (line 23) | def __init__(self, template='checkpoint.pytorch.epoch{epoch_count}.ite...
method begin_of_save (line 27) | def begin_of_save(self, **kwargs):
method end_of_save (line 31) | def end_of_save(self, save_to_directory, **_):
class DumpHDF5Every (line 41) | class DumpHDF5Every(Callback):
method __init__ (line 43) | def __init__(self, frequency, to_directory,
method dump_every (line 64) | def dump_every(self):
method dump_every (line 68) | def dump_every(self, value):
method dump_now (line 75) | def dump_now(self):
method add_to_dump_cache (line 80) | def add_to_dump_cache(self, key, value):
method clear_dump_cache (line 87) | def clear_dump_cache(self):
method dump_state (line 90) | def dump_state(self, key, dump_while='training'):
method dump_states (line 113) | def dump_states(self, keys, dump_while='training'):
method get_file_path (line 118) | def get_file_path(self, mode):
method dump (line 131) | def dump(self, mode):
method end_of_training_iteration (line 147) | def end_of_training_iteration(self, **_):
method end_of_validation_run (line 160) | def end_of_validation_run(self, **_):
class SaveAtBestValidationScore (line 173) | class SaveAtBestValidationScore(Callback):
method __init__ (line 179) | def __init__(self, smoothness=0, verbose=False):
method end_of_validation_run (line 188) | def end_of_validation_run(self, **_):
class ParameterEMA (line 223) | class ParameterEMA(Callback):
method __init__ (line 225) | def __init__(self, momentum):
method maintain (line 239) | def maintain(self):
method apply (line 245) | def apply(self):
method end_of_training_iteration (line 252) | def end_of_training_iteration(self, **_):
class GradientClip (line 256) | class GradientClip(Callback):
method __init__ (line 257) | def __init__(self, clip_value=None, clip_norm=None):
method mode (line 270) | def mode(self):
method norm_or_value (line 274) | def norm_or_value(self):
method after_model_and_loss_is_applied (line 277) | def after_model_and_loss_is_applied(self, **_):
class GarbageCollection (line 281) | class GarbageCollection(Callback):
method end_of_training_iteration (line 287) | def end_of_training_iteration(self, **_):
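`NaNDetector` hooks `end_of_training_iteration`, which suggests it inspects the latest training loss and aborts on NaN. A hedged sketch of that behavior (`ToyNaNDetector` and the `training_loss` keyword are assumptions about how state is passed):

```python
import math


class ToyNaNDetector:
    """Sketch of a NaNDetector-style callback: raise as soon as the
    training loss goes NaN, so a diverged run fails fast instead of
    silently wasting compute."""

    def end_of_training_iteration(self, training_loss, **_):
        if math.isnan(training_loss):
            raise RuntimeError("NaN detected in training loss")


detector = ToyNaNDetector()
detector.end_of_training_iteration(training_loss=0.5)  # fine, no error
```

Failing at the first NaN iteration preserves the last good checkpoint and a short log tail, which makes the divergence much easier to diagnose.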
FILE: inferno/trainers/callbacks/gradients.py
class LogOutputGradients (line 6) | class LogOutputGradients(Callback):
method __init__ (line 9) | def __init__(self, frequency):
method log_every (line 16) | def log_every(self):
method log_every (line 20) | def log_every(self, value):
method hook (line 26) | def hook(self, module, grad_input, grad_output):
method add_hook (line 38) | def add_hook(self):
method begin_of_fit (line 41) | def begin_of_fit(self, **kwargs):
method begin_of_save (line 46) | def begin_of_save(self, **_):
method end_of_save (line 52) | def end_of_save(self, **_):
FILE: inferno/trainers/callbacks/logging/__init__.py
function get_logger (line 10) | def get_logger(name):
FILE: inferno/trainers/callbacks/logging/base.py
class Logger (line 5) | class Logger(Callback):
method __init__ (line 13) | def __init__(self, log_directory=None):
method log_directory (line 20) | def log_directory(self):
method log_directory (line 29) | def log_directory(self, value):
method set_log_directory (line 32) | def set_log_directory(self, log_directory):
FILE: inferno/trainers/callbacks/logging/tensorboard.py
class TaggedImage (line 12) | class TaggedImage(object):
method __init__ (line 13) | def __init__(self, array, tag):
class TensorboardLogger (line 18) | class TensorboardLogger(Logger):
method __init__ (line 27) | def __init__(self, log_directory=None,
method writer (line 85) | def writer(self):
method log_scalars_every (line 91) | def log_scalars_every(self):
method log_scalars_every (line 97) | def log_scalars_every(self, value):
method log_scalars_now (line 101) | def log_scalars_now(self):
method log_images_every (line 109) | def log_images_every(self):
method log_images_every (line 115) | def log_images_every(self, value):
method log_images_now (line 119) | def log_images_now(self):
method log_histograms_every (line 127) | def log_histograms_every(self):
method log_histograms_every (line 133) | def log_histograms_every(self, value):
method log_histograms_now (line 137) | def log_histograms_now(self):
method observe_state (line 144) | def observe_state(self, key, observe_while='training'):
method unobserve_state (line 167) | def unobserve_state(self, key, observe_while='training'):
method unobserve_states (line 176) | def unobserve_states(self, keys, observe_while='training'):
method observe_training_and_validation_state (line 181) | def observe_training_and_validation_state(self, key):
method observe_states (line 185) | def observe_states(self, keys, observe_while='training'):
method observe_training_and_validation_states (line 190) | def observe_training_and_validation_states(self, keys):
method log_object (line 195) | def log_object(self, tag, object_,
method end_of_training_iteration (line 235) | def end_of_training_iteration(self, **_):
method end_of_validation_run (line 253) | def end_of_validation_run(self, **_):
method _tag_image (line 266) | def _tag_image(self, image, base_tag, prefix=None, instance_num=None, ...
method extract_images_from_batch (line 279) | def extract_images_from_batch(self, batch, base_tag=None, prefix=None):
method log_image_or_volume_batch (line 359) | def log_image_or_volume_batch(self, tag, batch, step=None):
method log_scalar (line 365) | def log_scalar(self, tag, value, step):
method log_images (line 377) | def log_images(self, tag, images, step, image_format='CHW'):
method _order_image_axes (line 401) | def _order_image_axes(image, image_format='CHW', target_format='CHW'):
method _normalize_image (line 436) | def _normalize_image(image):
method log_histogram (line 443) | def log_histogram(self, tag, values, step, bins=1000):
method get_config (line 448) | def get_config(self):
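The `log_scalars_every` / `log_scalars_now` pairing above suggests frequency-gated logging: values are only written when the iteration count matches the configured frequency. A toy version of that gating (`ToyScalarLogger` is illustrative; the real logger writes to TensorBoard rather than a list):

```python
class ToyScalarLogger:
    """Sketch of 'log every N iterations' gating: a predicate decides
    whether this iteration should be logged, and the iteration hook
    only records when it returns True."""

    def __init__(self, log_every):
        self.log_every = log_every
        self.logged = []

    def log_scalars_now(self, iteration_count):
        return iteration_count % self.log_every == 0

    def end_of_training_iteration(self, iteration_count, loss, **_):
        if self.log_scalars_now(iteration_count):
            self.logged.append((iteration_count, loss))


logger = ToyScalarLogger(log_every=5)
for i in range(10):
    logger.end_of_training_iteration(iteration_count=i, loss=0.1 * i)
```

Separate frequencies for scalars, images, and histograms (as listed above) make sense because image and histogram summaries are far more expensive to write than scalars.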
FILE: inferno/trainers/callbacks/scheduling.py
class _Scheduler (line 8) | class _Scheduler(Callback):
method __init__ (line 9) | def __init__(self, monitor='auto', monitor_momentum=0., monitor_while=...
method monitor (line 20) | def monitor(self):
method monitor (line 25) | def monitor(self, value):
method monitor_value (line 29) | def monitor_value(self):
method monitor_while (line 33) | def monitor_while(self):
method monitor_while (line 46) | def monitor_while(self, value):
method get_monitor_value (line 58) | def get_monitor_value(self):
method maintain_monitor_moving_average (line 88) | def maintain_monitor_moving_average(self):
class AutoLR (line 94) | class AutoLR(_Scheduler):
method __init__ (line 101) | def __init__(self, factor, patience, required_minimum_relative_improve...
method patience (line 163) | def patience(self):
method patience (line 168) | def patience(self, value):
method cooldown_duration (line 172) | def cooldown_duration(self):
method cooldown_duration (line 176) | def cooldown_duration(self, value):
method duration_since_last_decay (line 181) | def duration_since_last_decay(self):
method duration_since_last_improvment (line 201) | def duration_since_last_improvment(self):
method out_of_patience (line 221) | def out_of_patience(self):
method in_cooldown (line 225) | def in_cooldown(self):
method decay (line 231) | def decay(self):
method maintain_monitor_moving_average (line 244) | def maintain_monitor_moving_average(self):
method monitor_value_has_significantly_improved (line 250) | def monitor_value_has_significantly_improved(self):
method end_of_training_iteration (line 278) | def end_of_training_iteration(self, **_):
method end_of_validation_run (line 289) | def end_of_validation_run(self, **_):
method is_significantly_less_than (line 303) | def is_significantly_less_than(x, y, min_relative_delta):
class AutoLRDecay (line 311) | class AutoLRDecay(AutoLR):
class DecaySpec (line 321) | class DecaySpec(object):
method __init__ (line 323) | def __init__(self, duration, factor):
method match (line 330) | def match(self, iteration_count=None, epoch_count=None, when_equal_ret...
method new (line 342) | def new(self):
method build_from (line 346) | def build_from(cls, args):
class ManualLR (line 357) | class ManualLR(Callback):
method __init__ (line 358) | def __init__(self, decay_specs, exclude_param_groups=None):
method match (line 365) | def match(self):
method decay (line 379) | def decay(self, factor):
method end_of_training_iteration (line 390) | def end_of_training_iteration(self, **_):
class SaveModelRegularly (line 397) | class SaveModelRegularly(Callback):
method __init__ (line 400) | def __init__(self, frequency):
method save_now (line 405) | def save_now(self):
method end_of_training_iteration (line 410) | def end_of_training_iteration(self, **_):
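`ManualLR.decay(factor)` with an `exclude_param_groups` option suggests multiplying each optimizer param group's learning rate by a factor, skipping excluded groups. A sketch under that assumption (`decay_learning_rates` is an illustrative name; `param_groups` mimics the list-of-dicts layout of `torch.optim` optimizers):

```python
def decay_learning_rates(param_groups, factor, exclude=None):
    """Multiply the 'lr' of each param group by `factor`, skipping
    groups whose index is in `exclude`. Mutates the groups in place,
    as torch.optim schedulers conventionally do."""
    exclude = set(exclude or ())
    for index, group in enumerate(param_groups):
        if index in exclude:
            continue
        group['lr'] *= factor
    return param_groups


param_groups = [{'lr': 1.0}, {'lr': 0.1}]
decay_learning_rates(param_groups, factor=0.5, exclude=[1])
```

Per-group exclusion is useful when, say, a pretrained backbone and a fresh head sit in separate param groups and only one of them should be decayed.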
FILE: inferno/trainers/callbacks/tqdm.py
class TQDMPrinter (line 7) | class TQDMPrinter(object):
method __init__ (line 8) | def __init__(self, progress):
method print (line 11) | def print(self, message):
class TQDMConsole (line 19) | class TQDMConsole(Console):
method __init__ (line 20) | def __init__(self):
class TQDMProgressBar (line 24) | class TQDMProgressBar(Callback):
method __init__ (line 25) | def __init__(self, *args, **kwargs):
method bind_trainer (line 32) | def bind_trainer(self, *args, **kwargs):
method _init_epoch_bar_train (line 37) | def _init_epoch_bar_train(self):
method print (line 43) | def print(self, message, **_):
method begin_of_fit (line 50) | def begin_of_fit(self, max_num_epochs, **_):
method end_of_fit (line 57) | def end_of_fit(self, **_):
method begin_of_epoch (line 62) | def begin_of_epoch(self, **_):
method end_of_epoch (line 66) | def end_of_epoch(self, **_):
method begin_of_training_iteration (line 70) | def begin_of_training_iteration(self, **_):
method begin_of_validation_iteration (line 78) | def begin_of_validation_iteration(self, **_):
method begin_of_training_run (line 82) | def begin_of_training_run(self, **_):
method end_of_training_run (line 85) | def end_of_training_run(self, **_):
method begin_of_validation_run (line 91) | def begin_of_validation_run(self, num_iterations, num_iterations_in_ge...
method end_of_validation_run (line 100) | def end_of_validation_run(self, **_):
FILE: inferno/trainers/callbacks/tqdmstub.py
class TQDMProgressBar (line 3) | class TQDMProgressBar(Callback):
method __init__ (line 4) | def __init__(self, *args, **kwargs):
method bind_trainer (line 7) | def bind_trainer(self, *args, **kwargs):
method begin_of_fit (line 11) | def begin_of_fit(self, **_):
FILE: inferno/utils/exceptions.py
function assert_ (line 4) | def assert_(condition, message='', exception_type=AssertionError):
class ShapeError (line 13) | class ShapeError(ValueError):
class FrequencyValueError (line 17) | class FrequencyValueError(ValueError):
class DeviceError (line 21) | class DeviceError(ValueError):
class NotSetError (line 25) | class NotSetError(ValueError):
class NotTorchModuleError (line 32) | class NotTorchModuleError(TypeError):
class FrequencyTypeError (line 36) | class FrequencyTypeError(TypeError):
class DTypeError (line 40) | class DTypeError(TypeError):
class ClassNotFoundError (line 47) | class ClassNotFoundError(LookupError):
class NotUnwrappableError (line 54) | class NotUnwrappableError(NotImplementedError):
FILE: inferno/utils/io_utils.py
function fromh5 (line 9) | def fromh5(path, datapath=None, dataslice=None, asnumpy=True, preptrain=...
function toh5 (line 29) | def toh5(data, path, datapath='data', compression=None, chunks=None):
function fromz5 (line 35) | def fromz5(path, datapath, dataslice=None, n_threads=8):
function yaml2dict (line 47) | def yaml2dict(path):
function print_tensor (line 56) | def print_tensor(tensor, prefix, directory):
FILE: inferno/utils/math_utils.py
function max_allowed_ds_steps (line 3) | def max_allowed_ds_steps(shape, factor):
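From its name and arguments, `max_allowed_ds_steps(shape, factor)` plausibly computes how many times a volume can be downsampled by `factor` before some dimension stops dividing evenly — the usual constraint when choosing U-Net depth. A hedged reimplementation sketch (the `_sketch` suffix marks it as illustrative):

```python
def max_allowed_ds_steps_sketch(shape, factor):
    """Largest number of downsampling steps by `factor` such that every
    dimension in `shape` remains integer-divisible at each step.
    Sketch only; the real function may impose extra constraints."""

    def steps_for(size):
        steps = 0
        while size % factor == 0 and size // factor >= 1:
            size //= factor
            steps += 1
        return steps

    # the most restrictive dimension bounds the whole volume
    return min(steps_for(s) for s in shape)
```

For example, a `(12, 8)` image with factor 2 allows only 2 steps, because 12 reaches the odd value 3 after two halvings.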
FILE: inferno/utils/model_utils.py
function is_model_cuda (line 5) | def is_model_cuda(model):
class ModelTester (line 13) | class ModelTester(object):
method __init__ (line 14) | def __init__(self, input_shape, expected_output_shape):
method cuda (line 19) | def cuda(self):
method get_input (line 23) | def get_input(self):
method __call__ (line 30) | def __call__(self, model):
class MultiscaleModelTester (line 49) | class MultiscaleModelTester(ModelTester):
method __call__ (line 50) | def __call__(self, model):
FILE: inferno/utils/partial_cls.py
function partial_cls (line 13) | def partial_cls(base_cls, name, module, fix=None, default=None):
function register_partial_cls (line 120) | def register_partial_cls(base_cls, name, module, fix=None, default=None):
class Conv (line 130) | class Conv(object):
method __init__ (line 131) | def __init__(self, dim, activation, stride=1):
FILE: inferno/utils/python_utils.py
function ensure_dir (line 11) | def ensure_dir(directory):
function require_dict_kwargs (line 27) | def require_dict_kwargs(kwargs, msg=None):
function is_listlike (line 53) | def is_listlike(x):
function to_iterable (line 57) | def to_iterable(x):
function from_iterable (line 61) | def from_iterable(x):
function robust_len (line 65) | def robust_len(x):
function as_tuple_of_len (line 69) | def as_tuple_of_len(x, len_):
function has_callable_attr (line 79) | def has_callable_attr(object_, name):
function is_maybe_list_of (line 83) | def is_maybe_list_of(check_function):
class delayed_keyboard_interrupt (line 92) | class delayed_keyboard_interrupt(object):
method __enter__ (line 100) | def __enter__(self):
method handler (line 106) | def handler(self, sig, frame):
method __exit__ (line 109) | def __exit__(self, type, value, traceback):
function get_config_for_name (line 116) | def get_config_for_name(config, name):
function deprecated (line 129) | def deprecated(reason):
FILE: inferno/utils/test_utils.py
function generate_random_data (line 7) | def generate_random_data(num_samples, shape, num_classes, hardness=0.3, ...
function generate_random_dataset (line 18) | def generate_random_dataset(num_samples, shape, num_classes, hardness=0....
function generate_random_dataloader (line 29) | def generate_random_dataloader(num_samples, shape, num_classes, hardness...
FILE: inferno/utils/torch_utils.py
function unwrap (line 8) | def unwrap(input_, to_cpu=True, as_numpy=False, extract_item=False):
function is_tensor (line 37) | def is_tensor(object_):
function is_label_tensor (line 42) | def is_label_tensor(object_):
function is_image_tensor (line 46) | def is_image_tensor(object_):
function is_volume_tensor (line 50) | def is_volume_tensor(object_):
function is_image_or_volume_tensor (line 54) | def is_image_or_volume_tensor(object_):
function is_label_image_tensor (line 58) | def is_label_image_tensor(object_):
function is_label_volume_tensor (line 62) | def is_label_volume_tensor(object_):
function is_label_image_or_volume_tensor (line 66) | def is_label_image_or_volume_tensor(object_):
function is_matrix_tensor (line 70) | def is_matrix_tensor(object_):
function is_scalar_tensor (line 74) | def is_scalar_tensor(object_):
function is_vector_tensor (line 78) | def is_vector_tensor(object_):
function assert_same_size (line 82) | def assert_same_size(tensor_1, tensor_2):
function where (line 88) | def where(condition, if_true, if_false):
function flatten_samples (line 118) | def flatten_samples(input_):
function clip_gradients_ (line 143) | def clip_gradients_(parameters, mode, norm_or_value):
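`flatten_samples` is the kind of utility channelwise losses (e.g. the Sørensen-Dice criteria listed earlier) rely on: collapse a `(N, C, ...)` batch into a `(C, N * spatial)` matrix so each channel can be scored independently. A NumPy sketch of that reshaping (the torch version presumably does the same with tensors; the function name here is illustrative):

```python
import numpy as np


def flatten_samples_sketch(array):
    """Move the channel axis (axis 1 in N, C, ... layout) to the front
    and flatten everything else, yielding a (C, N * spatial) matrix."""
    num_channels = array.shape[1]
    return np.moveaxis(array, 1, 0).reshape(num_channels, -1)


x = np.arange(24).reshape(2, 3, 4)   # batch of 2 samples, 3 channels, 4 elements
flat = flatten_samples_sketch(x)     # shape (3, 8)
```

Row `c` of the result concatenates channel `c` of every sample, which is exactly the layout a per-channel reduction wants.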
FILE: inferno/utils/train_utils.py
class AverageMeter (line 6) | class AverageMeter(object):
method __init__ (line 11) | def __init__(self):
method reset (line 17) | def reset(self):
method update (line 23) | def update(self, val, n=1):
class MovingAverage (line 30) | class MovingAverage(object):
method __init__ (line 32) | def __init__(self, momentum=0):
method reset (line 37) | def reset(self):
method update (line 40) | def update(self, val):
method relative_change (line 49) | def relative_change(self):
class CLUI (line 57) | class CLUI(object):
method __call__ (line 60) | def __call__(self, f):
class Frequency (line 94) | class Frequency(object):
method __init__ (line 96) | def __init__(self, value=None, units=None):
method value (line 106) | def value(self):
method value (line 110) | def value(self, value):
method units (line 124) | def units(self):
method units (line 128) | def units(self, value):
method assert_value_consistent (line 134) | def assert_value_consistent(self, value=None):
method assert_units_consistent (line 146) | def assert_units_consistent(self, units=None):
method is_consistent (line 154) | def is_consistent(self):
method epoch (line 162) | def epoch(self):
method iteration (line 166) | def iteration(self):
method by_epoch (line 171) | def by_epoch(self):
method by_iteration (line 175) | def by_iteration(self):
method every (line 178) | def every(self, value):
method match (line 182) | def match(self, iteration_count=None, epoch_count=None, persistent=Fal...
method __str__ (line 199) | def __str__(self):
method __repr__ (line 202) | def __repr__(self):
method from_string (line 206) | def from_string(cls, string):
method build_from (line 221) | def build_from(cls, args, priority='iterations'):
class Duration (line 234) | class Duration(Frequency):
method match (line 236) | def match(self, iteration_count=None, epoch_count=None, when_equal_ret...
method compare (line 245) | def compare(self, iteration_count=None, epoch_count=None):
method __sub__ (line 254) | def __sub__(self, other):
class NoLogger (line 265) | class NoLogger(object):
method __init__ (line 266) | def __init__(self, logdir=None):
method log_value (line 269) | def log_value(self, *kwargs):
function set_state (line 273) | def set_state(module, key, value):
function get_state (line 285) | def get_state(module, key, default=None):
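The `Frequency` class above (with `from_string`, `units`, and `match`) appears to encode schedules like "every 2 epochs" or "every 100 iterations" as a value/units pair. A toy sketch of that idea (`ToyFrequency` is illustrative; the real class also validates consistency and supports building from tuples):

```python
class ToyFrequency:
    """Sketch of a (value, units) schedule: match() tests whether the
    current iteration or epoch count is a multiple of the value."""

    def __init__(self, value, units):
        assert units in ('iterations', 'epochs')
        self.value, self.units = value, units

    @classmethod
    def from_string(cls, string):
        # e.g. '5 iterations' or '2 epochs'
        value, units = string.split()
        return cls(int(value), units)

    def match(self, iteration_count=None, epoch_count=None):
        count = iteration_count if self.units == 'iterations' else epoch_count
        return count is not None and count % self.value == 0


freq = ToyFrequency.from_string('5 iterations')
```

Keeping the units explicit lets the same `match` predicate drive both iteration-based and epoch-based triggers (validation, saving, logging) without ambiguity.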
FILE: tests/test_extensions/test_containers/test_graph.py
class TestGraph (line 6) | class TestGraph(unittest.TestCase):
method setUp (line 7) | def setUp(self):
method test_graph_dummy_basic (line 31) | def test_graph_dummy_basic(self):
method test_graph_dummy_inception (line 64) | def test_graph_dummy_inception(self):
method test_graph_basic (line 89) | def test_graph_basic(self):
method test_graph_device_transfers (line 102) | def test_graph_device_transfers(self):
method test_multi_gpu (line 119) | def test_multi_gpu(self):
FILE: tests/test_extensions/test_criteria/test_core.py
class TestCore (line 6) | class TestCore(unittest.TestCase):
method test_as_2d_criterion (line 7) | def test_as_2d_criterion(self):
FILE: tests/test_extensions/test_criteria/test_elementwise_measures.py
class TestElementwiseMeasures (line 6) | class TestElementwiseMeasures(unittest.TestCase):
method test_weighted_mse_loss (line 7) | def test_weighted_mse_loss(self):
FILE: tests/test_extensions/test_criteria/test_set_similarity_measures.py
class SetSimilarityTest (line 5) | class SetSimilarityTest(unittest.TestCase):
method get_dummy_variables (line 6) | def get_dummy_variables(self):
method get_dummy_variables_with_channels_and_classes (line 11) | def get_dummy_variables_with_channels_and_classes(self):
class TestSorensenDice (line 18) | class TestSorensenDice(SetSimilarityTest):
method test_channelwise (line 20) | def test_channelwise(self):
class TestGeneralizedSorensenDice (line 35) | class TestGeneralizedSorensenDice(SetSimilarityTest):
method test_channelwise (line 36) | def test_channelwise(self):
FILE: tests/test_extensions/test_layers/deprecated/building_blocks.py
class ResBlockTest (line 6) | class ResBlockTest(unittest.TestCase):
method test_2D_simple_ (line 8) | def test_2D_simple_(self):
method test_3D_simple_ (line 16) | def test_3D_simple_(self):
method test_2D_simple_2 (line 24) | def test_2D_simple_2(self):
method test_2D_simple_3 (line 32) | def test_2D_simple_3(self):
method test_2D_simple_4 (line 40) | def test_2D_simple_4(self):
method test_2D_simple_5 (line 49) | def test_2D_simple_5(self):
method test_2D_simple_6 (line 58) | def test_2D_simple_6(self):
method test_3D_simple_6 (line 67) | def test_3D_simple_6(self):
FILE: tests/test_extensions/test_layers/test_activations.py
class ActivationTest (line 6) | class ActivationTest(unittest.TestCase):
method test_selu (line 7) | def test_selu(self):
FILE: tests/test_extensions/test_layers/test_convolutional.py
class TestConvolutional (line 6) | class TestConvolutional(unittest.TestCase):
method test_bn_relu_depthwise_conv2d_pyinn (line 8) | def test_bn_relu_depthwise_conv2d_pyinn(self):
FILE: tests/test_extensions/test_layers/test_device.py
class TransferTest (line 6) | class TransferTest(unittest.TestCase):
method test_device_transfer (line 8) | def test_device_transfer(self):
method test_on_device (line 22) | def test_on_device(self):
FILE: tests/test_extensions/test_layers/test_reshape.py
class TestReshape (line 5) | class TestReshape(unittest.TestCase):
method _get_input_variable (line 6) | def _get_input_variable(self, *shape):
method test_as_matrix (line 9) | def test_as_matrix(self):
method test_flatten (line 17) | def test_flatten(self):
method test_as_2d (line 25) | def test_as_2d(self):
method test_as_3d (line 39) | def test_as_3d(self):
FILE: tests/test_extensions/test_metrics/categorical.py
class TestCategorical (line 6) | class TestCategorical(unittest.TestCase):
method test_iou_basic (line 7) | def test_iou_basic(self):
method test_iou_with_ignore_class (line 17) | def test_iou_with_ignore_class(self):
method test_multiclass_iou (line 26) | def test_multiclass_iou(self):
method test_multiclass_iou_with_ignore_class (line 37) | def test_multiclass_iou_with_ignore_class(self):
FILE: tests/test_extensions/test_models/test_res_unet.py
class ResUNetTest (line 7) | class ResUNetTest(unittest.TestCase):
method test_res_unet_2d (line 8) | def test_res_unet_2d(self):
method test_res_unet_3d (line 15) | def test_res_unet_3d(self):
method test_2d_side_out_bot_up (line 23) | def test_2d_side_out_bot_up(self):
method test_2d_side_out_up (line 42) | def test_2d_side_out_up(self):
method test_2d_side_out_down (line 60) | def test_2d_side_out_down(self):
FILE: tests/test_extensions/test_models/test_unet.py
class _MultiscaleUNet (line 6) | class _MultiscaleUNet(UNet):
method conv_op_factory (line 7) | def conv_op_factory(self, in_channels, out_channels, part, index):
method forward (line 10) | def forward(self, input):
class UNetTest (line 17) | class UNetTest(unittest.TestCase):
method test_unet_2d (line 18) | def test_unet_2d(self):
method test_unet_3d (line 24) | def test_unet_3d(self):
method test_monochannel_unet_3d (line 31) | def test_monochannel_unet_3d(self):
method test_inverse_pyramid_unet_2d (line 44) | def test_inverse_pyramid_unet_2d(self):
FILE: tests/test_inferno.py
class TestInferno (line 24) | class TestInferno(unittest.TestCase):
method read_environment_variables (line 32) | def read_environment_variables(self):
method setUp (line 40) | def setUp(self):
method setUpDatasets (line 44) | def setUpDatasets(self):
method setUpCallbacks (line 65) | def setUpCallbacks(self):
method generate_random_data (line 82) | def generate_random_data(self, num_samples, shape, num_classes,
method tearDown (line 92) | def tearDown(self):
method build_graph_model (line 97) | def build_graph_model(self):
method test_training_cpu (line 113) | def test_training_cpu(self):
FILE: tests/test_io/test_box/test_camvid.py
function _camvid_available (line 10) | def _camvid_available():
class TestCamvid (line 14) | class TestCamvid(unittest.TestCase):
method get_camvid_root (line 18) | def get_camvid_root(self):
method test_camvid_dataset_without_transforms (line 26) | def test_camvid_dataset_without_transforms(self):
method _test_camvid_dataset_with_transforms (line 37) | def _test_camvid_dataset_with_transforms(self):
method test_camvid_dataset_with_transforms (line 74) | def test_camvid_dataset_with_transforms(self):
method test_camvid_dataset_with_transforms_onehot (line 96) | def test_camvid_dataset_with_transforms_onehot(self):
FILE: tests/test_io/test_box/test_cityscapes.py
function _cityscapes_available (line 10) | def _cityscapes_available():
class TestCityscapes (line 14) | class TestCityscapes(unittest.TestCase):
method get_cityscapes_root (line 19) | def get_cityscapes_root(self):
method test_cityscapes_dataset_without_transforms (line 27) | def test_cityscapes_dataset_without_transforms(self):
method test_cityscapes_dataset_without_transforms_unzipped (line 38) | def test_cityscapes_dataset_without_transforms_unzipped(self):
method test_cityscapes_dataset_with_transforms (line 50) | def test_cityscapes_dataset_with_transforms(self):
method test_cityscapes_dataset_with_transforms_unzipped (line 82) | def test_cityscapes_dataset_with_transforms_unzipped(self):
FILE: tests/test_io/test_core/test_concatenate.py
class ConcatenateTest (line 4) | class ConcatenateTest(unittest.TestCase):
method test_concatenate (line 5) | def test_concatenate(self):
FILE: tests/test_io/test_core/test_zip.py
class ZipTest (line 4) | class ZipTest(unittest.TestCase):
method test_zip_minimal (line 5) | def test_zip_minimal(self):
method test_zip_sync (line 28) | def test_zip_sync(self):
method test_zip_reject (line 32) | def test_zip_reject(self):
FILE: tests/test_io/test_volumetric/test_lazy_volume_loader.py
class TestLazyVolumeLoader (line 19) | class TestLazyVolumeLoader(unittest.TestCase):
method tearDown (line 21) | def tearDown(self):
method test_h5_loader (line 28) | def test_h5_loader(self):
method test_h5_loader_data_slice (line 47) | def test_h5_loader_data_slice(self):
method test_h5_loader_pad (line 69) | def test_h5_loader_pad(self):
method test_h5_loader_data_slice_pad (line 91) | def test_h5_loader_data_slice_pad(self):
FILE: tests/test_io/test_volumetric/test_volume_loader.py
class TestVolumeLoader (line 9) | class TestVolumeLoader(unittest.TestCase):
method setUp (line 11) | def setUp(self):
method test_loader (line 14) | def test_loader(self):
class TestHDF5VolumeLoader (line 26) | class TestHDF5VolumeLoader(unittest.TestCase):
method setUp (line 28) | def setUp(self):
method tearDown (line 37) | def tearDown(self):
method test_hdf5_loader (line 43) | def test_hdf5_loader(self):
FILE: tests/test_training/test_basic.py
class TestTrainer (line 8) | class TestTrainer(TestCase):
method _make_test_model (line 16) | def _make_test_model():
method test_cifar (line 33) | def test_cifar(self):
method test_multi_io (line 64) | def test_multi_io(self):
method test_serialization (line 109) | def test_serialization(self):
method test_multi_gpu (line 132) | def test_multi_gpu(self):
method test_save (line 160) | def test_save(self):
method test_multi_gpu_setup (line 169) | def test_multi_gpu_setup(self):
FILE: tests/test_training/test_callbacks/test_base.py
class DummyCallback (line 10) | class DummyCallback(Callback):
method end_of_training_iteration (line 11) | def end_of_training_iteration(self, **_):
class WrongDummyCallback (line 15) | class WrongDummyCallback(Callback):
method end_of_iteration (line 16) | def end_of_iteration(self):
class CallbackMechTest (line 20) | class CallbackMechTest(unittest.TestCase):
method setUp (line 23) | def setUp(self):
method tearDown (line 26) | def tearDown(self):
method test_serialization (line 30) | def test_serialization(self):
method test_auto_registry (line 45) | def test_auto_registry(self):
method test_instance_registry (line 55) | def test_instance_registry(self):
FILE: tests/test_training/test_callbacks/test_essentials.py
class TestEssentials (line 13) | class TestEssentials(unittest.TestCase):
method setUp (line 16) | def setUp(self):
method test_dump_hdf5_every (line 41) | def test_dump_hdf5_every(self):
method tearDown (line 81) | def tearDown(self):
FILE: tests/test_training/test_callbacks/test_logging/test_base.py
class DummyLogger (line 7) | class DummyLogger(Logger):
method end_of_training_iteration (line 8) | def end_of_training_iteration(self, **_):
class TestLogger (line 12) | class TestLogger(unittest.TestCase):
method test_serialization (line 15) | def test_serialization(self):
FILE: tests/test_training/test_callbacks/test_logging/test_tensorboard.py
class TestTensorboard (line 15) | class TestTensorboard(unittest.TestCase):
method _make_test_model (line 22) | def _make_test_model(input_channels):
method tearDown (line 36) | def tearDown(self):
method get_random_dataloaders (line 43) | def get_random_dataloaders(self, input_channels=3):
method get_trainer (line 59) | def get_trainer(self, input_channels):
method test_tensorboard (line 81) | def test_tensorboard(self):
method test_tensorboard_grayscale (line 85) | def test_tensorboard_grayscale(self):
method test_serialization (line 89) | def test_serialization(self):
FILE: tests/test_training/test_callbacks/test_scheduling.py
class TestSchedulers (line 7) | class TestSchedulers(unittest.TestCase):
method test_manual_lr (line 9) | def test_manual_lr(self):
FILE: tests/test_utils/test_model_utils.py
class ModelUtilTester (line 8) | class ModelUtilTester(unittest.TestCase):
method test_model_tester (line 9) | def test_model_tester(self):
method test_model_tester_cuda (line 15) | def test_model_tester_cuda(self):
FILE: tests/test_utils/test_partial_cls.py
class TestCls (line 8) | class TestCls(object):
method __init__ (line 9) | def __init__(self, a, b, c=1, d=2):
class PartialClsTester (line 15) | class PartialClsTester(unittest.TestCase):
method test_partial_cls (line 17) | def test_partial_cls(self):
method test_update_existing_default_cls (line 40) | def test_update_existing_default_cls(self):
method test_fix_nothing (line 57) | def test_fix_nothing(self):
method test_fix_all (line 72) | def test_fix_all(self):
method test_default_all (line 98) | def test_default_all(self):
FILE: tests/test_utils/test_train_utils.py
class FrequencyTest (line 6) | class FrequencyTest(unittest.TestCase):
method test_from_string (line 7) | def test_from_string(self):
method test_from_tuple (line 19) | def test_from_tuple(self):
method test_is_consistent (line 24) | def test_is_consistent(self):
method test_init (line 29) | def test_init(self):
method test_duration (line 34) | def test_duration(self):
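The test modules indexed above follow plain `unittest` conventions: a `TestCase` subclass with optional `setUp`/`tearDown` fixtures, `test_*` methods picked up by the runner, and underscore-prefixed helper methods that are skipped. A minimal self-contained sketch of that layout (class and method names here are hypothetical, not taken from inferno):

```python
import unittest


class ExampleLayerTest(unittest.TestCase):
    def setUp(self):
        # Build fixtures shared by the tests, as e.g. TestVolumeLoader.setUp does.
        self.data = list(range(4))

    def _get_input(self, n):
        # Helpers prefixed with '_' (cf. _get_input_variable, _make_test_model)
        # are not collected as tests by the runner.
        return self.data[:n]

    def test_length(self):
        self.assertEqual(len(self._get_input(2)), 2)

    def tearDown(self):
        # Release fixtures (temp files, handles, ...) after each test method.
        self.data = None


# Run the case programmatically, as test discovery would.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(ExampleLayerTest)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

Given this layout, the whole suite can typically be run from the repository root with `python -m unittest discover tests` (with the package on `PYTHONPATH`, e.g. via `add2path.sh`).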
Condensed preview — 179 files, each showing path, character count, and a content snippet (657K chars of structured content in total).
[
{
"path": ".editorconfig",
"chars": 292,
"preview": "# http://editorconfig.org\n\nroot = true\n\n[*]\nindent_style = space\nindent_size = 4\ntrim_trailing_whitespace = true\ninsert_"
},
{
"path": ".github/ISSUE_TEMPLATE.md",
"chars": 318,
"preview": "* inferno version:\n* Python version:\n* Operating System:\n\n### Description\n\nDescribe what you were trying to get done.\nTe"
},
{
"path": ".gitignore",
"chars": 777,
"preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
},
{
"path": ".travis.yml",
"chars": 1958,
"preview": "language: python\n\ndist: xenial\n\npython:\n - 3.7\n\nenv:\n - PYTORCH_CONDA=\"pytorch\" TORCHVISION_CONDA=\"torchvision\" TORCHV"
},
{
"path": "AUTHORS.rst",
"chars": 1917,
"preview": "=======\nCredits\n=======\n\nDevelopment Lead\n----------------\n\n* `Nasim Rahaman <https://github.com/nasimrahaman>`_ @ `Ima"
},
{
"path": "CONTRIBUTING.rst",
"chars": 3740,
"preview": ".. highlight:: shell\n\n============\nContributing\n============\n\nContributions are welcome, and they are greatly appreciate"
},
{
"path": "HISTORY.rst",
"chars": 690,
"preview": "=======\nHistory\n=======\n\n0.1.0 (2017-08-24)\n------------------\n\n* First early release on PyPI\n\n0.1.1 (2017-08-24)\n------"
},
{
"path": "LICENSE",
"chars": 591,
"preview": "\nApache Software License 2.0\n\nCopyright (c) 2017, Inferno Developers\n\nLicensed under the Apache License, Version 2.0 (th"
},
{
"path": "MANIFEST.in",
"chars": 262,
"preview": "include AUTHORS.rst\ninclude CONTRIBUTING.rst\ninclude HISTORY.rst\ninclude LICENSE\ninclude README.rst\n\nrecursive-include t"
},
{
"path": "Makefile",
"chars": 2287,
"preview": ".PHONY: clean clean-test clean-pyc clean-build docs help\n.DEFAULT_GOAL := help\ndefine BROWSER_PYSCRIPT\nimport os, webbro"
},
{
"path": "README.rst",
"chars": 5548,
"preview": "\n=======\nInferno\n=======\n\n.. image:: https://anaconda.org/conda-forge/inferno/badges/version.svg \n :target: htt"
},
{
"path": "add2path.sh",
"chars": 103,
"preview": "#!/usr/bin/env bash\n# Run this script from within the directory.\nexport PYTHONPATH=${PYTHONPATH}:${PWD}"
},
{
"path": "build_docs.sh",
"chars": 102,
"preview": "#!/bin/bash\ncd docs\nrm -r -f inferno-apidoc\nsphinx-apidoc -o inferno-apidoc ../inferno\nmake html\ncd .."
},
{
"path": "conda-recipe/build.sh",
"chars": 280,
"preview": "PY_VER=$(python -c \"import sys; print('{}.{}'.format(*sys.version_info[:2]))\")\n\n# Install python modules\nmkdir -p ${PREF"
},
{
"path": "conda-recipe/meta.yaml",
"chars": 899,
"preview": "package:\n name: inferno\n\n {% set tagged_version = GIT_DESCRIBE_TAG|replace(\"v\",\"\")|replace(\"-\", \".\") %}\n\n # If "
},
{
"path": "docs/.gitignore",
"chars": 41,
"preview": "/inferno.rst\n/inferno.*.rst\n/modules.rst\n"
},
{
"path": "docs/Makefile",
"chars": 6766,
"preview": "# Makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line.\nSPHINXOPTS =\nSPHINXBUILD "
},
{
"path": "docs/_templates/layout.html",
"chars": 139,
"preview": "{# layout.html #}\n{# Import the theme's layout. #}\n{% extends \"!layout.html\" %}\n\n{% set css_files = css_files + ['_stati"
},
{
"path": "docs/_templates/template_module.rst",
"chars": 910,
"preview": "{{ fullname }}\n{{ underline }}\n\n.. automodule:: {{ fullname }}\n \n {% block functions %}\n {% if functions %}\n\n F"
},
{
"path": "docs/authors.rst",
"chars": 28,
"preview": ".. include:: ../AUTHORS.rst\n"
},
{
"path": "docs/conf.py",
"chars": 10920,
"preview": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n#\n# inferno documentation build configuration file, created by\n# sphinx-qu"
},
{
"path": "docs/contributing.rst",
"chars": 33,
"preview": ".. include:: ../CONTRIBUTING.rst\n"
},
{
"path": "docs/environment.yml",
"chars": 351,
"preview": "name: inferno_docs\n\nchannels:\n - soumith\n - anaconda\n\ndependencies:\n - python==3.5\n - pytorch>=0.1.12\n - torchvisio"
},
{
"path": "docs/examples.rst",
"chars": 146,
"preview": ".. _inferno_examples_gallery:\n\nInferno Examples Gallery\n============================\n\n\n.. toctree::\n :maxdepth: 5\n\n "
},
{
"path": "docs/history.rst",
"chars": 28,
"preview": ".. include:: ../HISTORY.rst\n"
},
{
"path": "docs/index.rst",
"chars": 362,
"preview": "Welcome to inferno's documentation!\n======================================\n\nContents:\n\n.. toctree::\n :maxdepth: 1\n\n "
},
{
"path": "docs/inferno-apidoc/inferno.extensions.containers.rst",
"chars": 646,
"preview": "inferno.extensions.containers package\n=====================================\n\nSubmodules\n----------\n\ninferno.extensions.c"
},
{
"path": "docs/inferno-apidoc/inferno.extensions.criteria.rst",
"chars": 1115,
"preview": "inferno.extensions.criteria package\n===================================\n\nSubmodules\n----------\n\ninferno.extensions.crite"
},
{
"path": "docs/inferno-apidoc/inferno.extensions.initializers.rst",
"chars": 652,
"preview": "inferno.extensions.initializers package\n=======================================\n\nSubmodules\n----------\n\ninferno.extensio"
},
{
"path": "docs/inferno-apidoc/inferno.extensions.layers.rst",
"chars": 2204,
"preview": "inferno.extensions.layers package\n=================================\n\nSubmodules\n----------\n\ninferno.extensions.layers.ac"
},
{
"path": "docs/inferno-apidoc/inferno.extensions.metrics.rst",
"chars": 1197,
"preview": "inferno.extensions.metrics package\n==================================\n\nSubmodules\n----------\n\ninferno.extensions.metrics"
},
{
"path": "docs/inferno-apidoc/inferno.extensions.optimizers.rst",
"chars": 654,
"preview": "inferno.extensions.optimizers package\n=====================================\n\nSubmodules\n----------\n\ninferno.extensions.o"
},
{
"path": "docs/inferno-apidoc/inferno.extensions.rst",
"chars": 417,
"preview": "inferno.extensions package\n==========================\n\nSubpackages\n-----------\n\n.. toctree::\n\n inferno.extensions.con"
},
{
"path": "docs/inferno-apidoc/inferno.io.box.rst",
"chars": 841,
"preview": "inferno.io.box package\n======================\n\nSubmodules\n----------\n\ninferno.io.box.binary\\_blobs module\n--------------"
},
{
"path": "docs/inferno-apidoc/inferno.io.core.rst",
"chars": 841,
"preview": "inferno.io.core package\n=======================\n\nSubmodules\n----------\n\ninferno.io.core.base module\n--------------------"
},
{
"path": "docs/inferno-apidoc/inferno.io.rst",
"chars": 286,
"preview": "inferno.io package\n==================\n\nSubpackages\n-----------\n\n.. toctree::\n\n inferno.io.box\n inferno.io.core\n "
},
{
"path": "docs/inferno-apidoc/inferno.io.transform.rst",
"chars": 896,
"preview": "inferno.io.transform package\n============================\n\nSubmodules\n----------\n\ninferno.io.transform.base module\n-----"
},
{
"path": "docs/inferno-apidoc/inferno.io.volumetric.rst",
"chars": 813,
"preview": "inferno.io.volumetric package\n=============================\n\nSubmodules\n----------\n\ninferno.io.volumetric.lazy\\_volume\\_"
},
{
"path": "docs/inferno-apidoc/inferno.rst",
"chars": 425,
"preview": "inferno package\n===============\n\nSubpackages\n-----------\n\n.. toctree::\n\n inferno.extensions\n inferno.io\n infern"
},
{
"path": "docs/inferno-apidoc/inferno.trainers.callbacks.logging.rst",
"chars": 691,
"preview": "inferno.trainers.callbacks.logging package\n==========================================\n\nSubmodules\n----------\n\ninferno.tr"
},
{
"path": "docs/inferno-apidoc/inferno.trainers.callbacks.rst",
"chars": 1474,
"preview": "inferno.trainers.callbacks package\n==================================\n\nSubpackages\n-----------\n\n.. toctree::\n\n infern"
},
{
"path": "docs/inferno-apidoc/inferno.trainers.rst",
"chars": 427,
"preview": "inferno.trainers package\n========================\n\nSubpackages\n-----------\n\n.. toctree::\n\n inferno.trainers.callbacks"
},
{
"path": "docs/inferno-apidoc/inferno.utils.rst",
"chars": 1524,
"preview": "inferno.utils package\n=====================\n\nSubmodules\n----------\n\ninferno.utils.exceptions module\n--------------------"
},
{
"path": "docs/inferno-apidoc/modules.rst",
"chars": 58,
"preview": "inferno\n=======\n\n.. toctree::\n :maxdepth: 4\n\n inferno\n"
},
{
"path": "docs/installation.rst",
"chars": 2372,
"preview": ".. highlight:: shell\n\n==================================\nInstallation\n==================================\n\nInstall on Lin"
},
{
"path": "docs/make.bat",
"chars": 6461,
"preview": "@ECHO OFF\n\nREM Command file for Sphinx documentation\n\nif \"%SPHINXBUILD%\" == \"\" (\n\tset SPHINXBUILD=sphinx-build\n)\nset BUI"
},
{
"path": "docs/readme.rst",
"chars": 27,
"preview": ".. include:: ../README.rst\n"
},
{
"path": "docs/refs.bib",
"chars": 270,
"preview": "\n@inproceedings{alush_2013_simbad,\ntitle={Break and Conquer: Efficient Correlation Clustering for Image Segmentation},\na"
},
{
"path": "docs/usage.rst",
"chars": 10738,
"preview": "=====\nUsage\n=====\n\n\nInferno is a utility library built around [PyTorch](http://pytorch.org/), designed to help you train"
},
{
"path": "docs/zbibliography.rst",
"chars": 134,
"preview": ".. _inferno_bibliography:\n\nBibliography\n============================\n\nThe bibliography: \n\n.. bibliography:: refs.bib\n "
},
{
"path": "examples/README.txt",
"chars": 63,
"preview": "\n.. _examples-index:\n\nGallery of Examples\n===================\n\n"
},
{
"path": "examples/plot_cheap_unet.py",
"chars": 8435,
"preview": "\"\"\"\nUNet Tutorial\n================================\nA unet example which can be run without a gpu\n\"\"\"\n\n##################"
},
{
"path": "examples/plot_train_side_loss_unet.py",
"chars": 7112,
"preview": "\"\"\"\nTrain Side Loss UNet Example\n================================\n\nIn this example a UNet with side supervision\nand auxi"
},
{
"path": "examples/plot_unet_tutorial.py",
"chars": 9582,
"preview": "\"\"\"\nUNet Tutorial\n================================\nA tentative tutorial on the usage\nof the unet framework in inferno\n\"\""
},
{
"path": "examples/regularized_mnist.py",
"chars": 4545,
"preview": "\"\"\"\nRegularized MNIST Example\n================================\n\nThis example demonstrates adding and logging arbitrary r"
},
{
"path": "examples/trainer.py",
"chars": 2399,
"preview": "\"\"\"\nTrainer Example\n================================\n\nThis example should illustrate how to use the trainer class.\n\n\"\"\"\n"
},
{
"path": "inferno/__init__.py",
"chars": 319,
"preview": "# -*- coding: utf-8 -*-\n\n\"\"\"Top-level package for inferno.\"\"\"\n\nfrom . import extensions\nfrom . import io\nfrom . import t"
},
{
"path": "inferno/extensions/__init__.py",
"chars": 331,
"preview": "from . import containers\nfrom . import criteria\nfrom . import initializers\nfrom . import layers\nfrom . import metrics\nfr"
},
{
"path": "inferno/extensions/containers/__init__.py",
"chars": 47,
"preview": "from .graph import *\nfrom .sequential import *\n"
},
{
"path": "inferno/extensions/containers/graph.py",
"chars": 18135,
"preview": "from collections import OrderedDict\nimport sys\nimport threading\nimport multiprocessing as mp\nimport copy\nimport gc\n\nimpo"
},
{
"path": "inferno/extensions/containers/sequential.py",
"chars": 658,
"preview": "import torch.nn as nn\nfrom ...utils import python_utils as pyu\n\n\n__all__ = ['Sequential1', 'Sequential2']\n\n\nclass Sequen"
},
{
"path": "inferno/extensions/criteria/__init__.py",
"chars": 205,
"preview": "from .set_similarity_measures import *\nfrom .elementwise_measures import *\nfrom .core import *\nfrom .regularized import "
},
{
"path": "inferno/extensions/criteria/core.py",
"chars": 3139,
"preview": "import torch.nn as nn\nfrom functools import reduce\nfrom ...utils.exceptions import assert_, ShapeError, NotTorchModuleEr"
},
{
"path": "inferno/extensions/criteria/elementwise_measures.py",
"chars": 1434,
"preview": "import torch.nn as nn\nfrom ...utils.exceptions import assert_\n\n\nclass WeightedMSELoss(nn.Module):\n NEGATIVE_CLASS_WEI"
},
{
"path": "inferno/extensions/criteria/regularized.py",
"chars": 4589,
"preview": "import warnings\n\nimport torch\nfrom torch import nn\n\nfrom . import set_similarity_measures, core\n\n__all__ = [\n 'Regula"
},
{
"path": "inferno/extensions/criteria/set_similarity_measures.py",
"chars": 6051,
"preview": "import torch.nn as nn\nfrom ...utils.torch_utils import flatten_samples\n\n__all__ = ['SorensenDiceLoss', 'GeneralizedDiceL"
},
{
"path": "inferno/extensions/initializers/__init__.py",
"chars": 44,
"preview": "from .base import *\nfrom .presets import *\n\n"
},
{
"path": "inferno/extensions/initializers/base.py",
"chars": 4904,
"preview": "import torch.nn.init as init\n\n\n__all__ = ['Initializer',\n 'Initialization',\n 'WeightInitFunction',\n "
},
{
"path": "inferno/extensions/initializers/presets.py",
"chars": 2751,
"preview": "import numpy as np\nimport torch.nn.init as init\nfrom functools import partial\n\nfrom .base import Initialization, Initial"
},
{
"path": "inferno/extensions/layers/__init__.py",
"chars": 847,
"preview": "__all__ = []\nfrom .activations import *\nfrom .convolutional import *\nfrom .device import *\nfrom .reshape import *\nfrom ."
},
{
"path": "inferno/extensions/layers/activations.py",
"chars": 444,
"preview": "import torch.nn.functional as F\nimport torch.nn as nn\nfrom ...utils.torch_utils import where\n\n__all__ = ['SELU']\n_all = "
},
{
"path": "inferno/extensions/layers/convolutional.py",
"chars": 10967,
"preview": "import torch.nn as nn\nimport sys\nimport functools\nfrom ..initializers import (\n OrthogonalWeightsZeroBias,\n Kaimin"
},
{
"path": "inferno/extensions/layers/convolutional_blocks.py",
"chars": 2616,
"preview": "import torch.nn as nn\nfrom .convolutional import BNReLUConv2D, BNReLUDeconv2D, Conv2D, Deconv2D\nfrom ...utils import pyt"
},
{
"path": "inferno/extensions/layers/device.py",
"chars": 3792,
"preview": "import torch.nn as nn\nfrom ...utils.python_utils import from_iterable, to_iterable\nfrom ...utils.exceptions import asser"
},
{
"path": "inferno/extensions/layers/identity.py",
"chars": 198,
"preview": "import torch.nn as nn\n__all__ = ['identity']\n_all = __all__\n\nclass Identity(nn.Module): \n def __init__(self):\n "
},
{
"path": "inferno/extensions/layers/normalization.py",
"chars": 504,
"preview": "import torch.nn as nn\n\n\nclass BatchNormND(nn.Module):\n def __init__(self, dim, num_features, \n eps=1e"
},
{
"path": "inferno/extensions/layers/reshape.py",
"chars": 8186,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom ...utils.exceptions import assert_, ShapeError\nf"
},
{
"path": "inferno/extensions/layers/sampling.py",
"chars": 3269,
"preview": "import torch.nn as nn\n\n__all__ = ['AnisotropicUpsample', 'AnisotropicPool', 'Upsample', 'AnisotropicUpsample2D', 'Anisot"
},
{
"path": "inferno/extensions/metrics/__init__.py",
"chars": 48,
"preview": "from .categorical import *\nfrom .arand import *\n"
},
{
"path": "inferno/extensions/metrics/arand.py",
"chars": 5733,
"preview": "from .base import Metric\nimport numpy as np\nimport scipy.sparse as sparse\nimport logging\n\n\nclass ArandScore(Metric):\n "
},
{
"path": "inferno/extensions/metrics/base.py",
"chars": 966,
"preview": "\n\nclass Metric(object):\n\n def forward(self, *args, **kwargs):\n raise NotImplementedError\n\n def __call__(sel"
},
{
"path": "inferno/extensions/metrics/categorical.py",
"chars": 6159,
"preview": "import torch\nfrom .base import Metric\nfrom ...utils.torch_utils import flatten_samples, is_label_tensor\nfrom ...utils.ex"
},
{
"path": "inferno/extensions/metrics/cremi_score.py",
"chars": 359,
"preview": "import numpy as np\nfrom .voi import voi\nfrom .arand import adapted_rand\n\n\n# TODO build metrics object\n\n\ndef cremi_metric"
},
{
"path": "inferno/extensions/metrics/voi.py",
"chars": 10010,
"preview": "from .base import Metric\n\nimport numpy as np\nimport scipy.sparse as sparse\n\n\nclass VoiScore(Metric):\n \"\"\"\n Compute"
},
{
"path": "inferno/extensions/models/__init__.py",
"chars": 68,
"preview": "from .unet import UNet, UNetBase\nfrom .res_unet import ResBlockUNet\n"
},
{
"path": "inferno/extensions/models/res_unet.py",
"chars": 8404,
"preview": "import torch\nimport torch.nn as nn\nfrom ..layers.convolutional import ConvActivation\nfrom .unet import UNetBase\nfrom ..."
},
{
"path": "inferno/extensions/models/unet.py",
"chars": 14315,
"preview": "import torch\nimport torch.nn as nn\nfrom ..layers.identity import Identity\nfrom ..layers.convolutional import ConvELU2D, "
},
{
"path": "inferno/extensions/optimizers/__init__.py",
"chars": 110,
"preview": "from .adam import Adam\nfrom .annealed_adam import AnnealedAdam\nfrom .ranger import Ranger, RangerQH, RangerVA\n"
},
{
"path": "inferno/extensions/optimizers/adam.py",
"chars": 3106,
"preview": "import math\nfrom torch.optim import Optimizer\n\n\nclass Adam(Optimizer):\n \"\"\"Implements Adam algorithm with the option "
},
{
"path": "inferno/extensions/optimizers/annealed_adam.py",
"chars": 1833,
"preview": "from .adam import Adam\n\n\nclass AnnealedAdam(Adam):\n \"\"\"Implements Adam algorithm with learning rate annealing and opt"
},
{
"path": "inferno/extensions/optimizers/ranger.py",
"chars": 248,
"preview": "# easy support for additional ranger optimizers from\n# https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer\ntry:\n"
},
{
"path": "inferno/inferno.py",
"chars": 44,
"preview": "# -*- coding: utf-8 -*-\n\n\"\"\"Main module.\"\"\"\n"
},
{
"path": "inferno/io/__init__.py",
"chars": 85,
"preview": "from . import box\nfrom . import core\nfrom . import transform\nfrom . import volumetric"
},
{
"path": "inferno/io/box/__init__.py",
"chars": 349,
"preview": "\"\"\"Things that work out of the box. ;)\"\"\"\n\nfrom .camvid import CamVid, get_camvid_loaders\nfrom .cityscapes import Citysc"
},
{
"path": "inferno/io/box/binary_blobs.py",
"chars": 5580,
"preview": "import torch.utils.data as data\nimport skimage.data\nimport numpy\nfrom operator import mul\nfrom functools import reduce\n\n"
},
{
"path": "inferno/io/box/camvid.py",
"chars": 8776,
"preview": "# Adapted from felixgwu's PR here:\n# https://github.com/felixgwu/vision/blob/cf491d301f62ae9c77ff7250fb7def5cd55ec963/to"
},
{
"path": "inferno/io/box/cifar.py",
"chars": 6337,
"preview": "import os\nimport torch\nimport torchvision\nimport torchvision.transforms as transforms\nfrom torch.utils.data.sampler impo"
},
{
"path": "inferno/io/box/cityscapes.py",
"chars": 12804,
"preview": "import zipfile\nimport io\nimport os\nimport torch.utils.data as data\nfrom PIL import Image\nfrom os.path import join, relpa"
},
{
"path": "inferno/io/core/__init__.py",
"chars": 103,
"preview": "from .base import SyncableDataset\nfrom .zip import Zip, ZipReject\nfrom .concatenate import Concatenate\n"
},
{
"path": "inferno/io/core/base.py",
"chars": 1170,
"preview": "from torch.utils.data.dataset import Dataset\n\n\nclass SyncableDataset(Dataset):\n def __init__(self, base_sequence=None"
},
{
"path": "inferno/io/core/concatenate.py",
"chars": 2559,
"preview": "import numpy as np\nfrom torch.utils.data.dataset import Dataset\nfrom ...utils import python_utils as pyu\n\n\nclass Concate"
},
{
"path": "inferno/io/core/data_utils.py",
"chars": 248,
"preview": "\ndef implements_sync_primitives(dataset):\n return hasattr(dataset, 'sync_with') and callable(getattr(dataset, 'sync_w"
},
{
"path": "inferno/io/core/zip.py",
"chars": 10048,
"preview": "from torch.utils.data.dataset import Dataset\nimport torch.multiprocessing as mp\nimport numpy as np\nfrom . import data_ut"
},
{
"path": "inferno/io/transform/__init__.py",
"chars": 100,
"preview": "from .base import Transform, Compose\nfrom . import generic\nfrom . import image\nfrom . import volume\n"
},
{
"path": "inferno/io/transform/base.py",
"chars": 8578,
"preview": "from ...utils import python_utils as pyu\nimport numpy as np\n\n\nclass Transform(object):\n \"\"\"\n Base class for a Tran"
},
{
"path": "inferno/io/transform/generic.py",
"chars": 8277,
"preview": "import numpy as np\nimport torch\nfrom .base import Transform, DTypeMapping\nfrom ...utils.exceptions import assert_, DType"
},
{
"path": "inferno/io/transform/image.py",
"chars": 26648,
"preview": "import numpy as np\nfrom scipy.ndimage import zoom\nfrom scipy.ndimage.filters import gaussian_filter\nfrom scipy.ndimage.i"
},
{
"path": "inferno/io/transform/volume.py",
"chars": 16485,
"preview": "import numpy as np\nimport scipy\nfrom scipy.ndimage import zoom\nfrom scipy.ndimage.morphology import binary_dilation, bin"
},
{
"path": "inferno/io/volumetric/__init__.py",
"chars": 163,
"preview": "from .volume import VolumeLoader, HDF5VolumeLoader, TIFVolumeLoader\rfrom .lazy_volume_loader import LazyHDF5VolumeLoader"
},
{
"path": "inferno/io/volumetric/lazy_volume_loader.py",
"chars": 13337,
"preview": "import numpy as np\nimport os\nimport pickle\nfrom concurrent import futures\n\n# try to load io libraries (h5py and z5py)\ntr"
},
{
"path": "inferno/io/volumetric/volume.py",
"chars": 11762,
"preview": "import numpy as np\nimport os\nimport skimage.io\n\nfrom ..core.base import SyncableDataset\nfrom ..core.base import IndexSpe"
},
{
"path": "inferno/io/volumetric/volumetric_utils.py",
"chars": 6557,
"preview": "import random\nimport itertools as it\n\n\ndef slidingwindowslices(shape, window_size, strides,\n ds=1"
},
{
"path": "inferno/trainers/__init__.py",
"chars": 113,
"preview": "from . import basic\nfrom . import callbacks\nfrom . basic import Trainer\n__all__ = ['basic','callbacks','Trainer']"
},
{
"path": "inferno/trainers/basic.py",
"chars": 72802,
"preview": "from datetime import datetime\nfrom inspect import signature\nimport os\nimport shutil\n\n# These are fetched from globals, t"
},
{
"path": "inferno/trainers/callbacks/__init__.py",
"chars": 384,
"preview": "__all__ = ['CallbackEngine', 'Callback', 'Console', 'essentials', 'scheduling', 'gradients']\n\nfrom .base import Callback"
},
{
"path": "inferno/trainers/callbacks/base.py",
"chars": 6659,
"preview": "from ...utils import python_utils as pyu\n\n\nclass CallbackEngine(object):\n \"\"\"\n Gathers and manages callbacks.\n\n "
},
{
"path": "inferno/trainers/callbacks/console.py",
"chars": 2617,
"preview": "from datetime import datetime\nfrom .base import Callback\n\nclass StdoutPrinter(object):\n def print(self, message):\n "
},
{
"path": "inferno/trainers/callbacks/essentials.py",
"chars": 12316,
"preview": "import numpy as np\nimport os\nimport h5py as h5\nfrom ...utils import torch_utils as tu\nfrom ...utils.train_utils import F"
},
{
"path": "inferno/trainers/callbacks/gradients.py",
"chars": 1841,
"preview": "from ...utils.train_utils import Frequency\nfrom ...utils.exceptions import assert_, FrequencyValueError\nfrom .base impor"
},
{
"path": "inferno/trainers/callbacks/logging/__init__.py",
"chars": 374,
"preview": "__all__ = ['get_logger']\ntry:\n INFERNO_WITH_TENSORBOARD_LOGGER = True\n from .tensorboard import TensorboardLogger\n"
},
{
"path": "inferno/trainers/callbacks/logging/base.py",
"chars": 1249,
"preview": "import os\nfrom ..base import Callback\n\n\nclass Logger(Callback):\n \"\"\"\n A special callback for logging.\n\n Loggers"
},
{
"path": "inferno/trainers/callbacks/logging/tensorboard.py",
"chars": 20981,
"preview": "import warnings\nimport numpy as np\nfrom torch.utils.tensorboard import SummaryWriter\n\nfrom .base import Logger\nfrom ...."
},
{
"path": "inferno/trainers/callbacks/scheduling.py",
"chars": 18376,
"preview": "from ...utils.train_utils import Frequency, Duration, MovingAverage\nfrom ...utils import python_utils as pyu\nfrom ...uti"
},
{
"path": "inferno/trainers/callbacks/tqdm.py",
"chars": 3542,
"preview": "from .base import Callback\nfrom tqdm import tqdm\nfrom datetime import datetime\nfrom .console import Console\n\n\nclass TQDM"
},
{
"path": "inferno/trainers/callbacks/tqdmstub.py",
"chars": 429,
"preview": "from .base import Callback\n\nclass TQDMProgressBar(Callback):\n def __init__(self, *args, **kwargs):\n super(TQDM"
},
{
"path": "inferno/utils/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "inferno/utils/exceptions.py",
"chars": 778,
"preview": "\"\"\"Exceptions and Error Handling\"\"\"\n\n\ndef assert_(condition, message='', exception_type=AssertionError):\n \"\"\"Like ass"
},
{
"path": "inferno/utils/io_utils.py",
"chars": 2704,
"preview": "import os\nimport h5py as h5\nimport numpy as np\nimport yaml\nfrom skimage.io import imsave\n\n\n# Function to load in a datas"
},
{
"path": "inferno/utils/math_utils.py",
"chars": 873,
"preview": "\n\ndef max_allowed_ds_steps(shape, factor):\n \"\"\"How often can a shape be down-sampled by a given factor\n such t"
},
{
"path": "inferno/utils/model_utils.py",
"chars": 2688,
"preview": "import torch\nfrom .exceptions import assert_, NotTorchModuleError, ShapeError\n\n\ndef is_model_cuda(model):\n try:\n "
},
{
"path": "inferno/utils/partial_cls.py",
"chars": 4415,
"preview": "import functools\nimport sys\nimport types\nimport inspect\n\n\n__all__ = [\n 'partial_cls',\n 'register_partial_cls'\n]\n\n"
},
{
"path": "inferno/utils/python_utils.py",
"chars": 5870,
"preview": "\"\"\"Utility functions with no external dependencies.\"\"\"\nimport signal\nimport warnings\nimport functools\nimport inspect\nimp"
},
{
"path": "inferno/utils/test_utils.py",
"chars": 1901,
"preview": "import torch\nfrom torch.utils.data.dataset import TensorDataset\nfrom torch.utils.data.dataloader import DataLoader\nimpor"
},
{
"path": "inferno/utils/torch_utils.py",
"chars": 4749,
"preview": "import numpy as np\nimport torch\n\nfrom .python_utils import delayed_keyboard_interrupt\nfrom .exceptions import assert_, S"
},
{
"path": "inferno/utils/train_utils.py",
"chars": 9484,
"preview": "\"\"\"Utilities for training.\"\"\"\nimport numpy as np\nfrom .exceptions import assert_, FrequencyTypeError, FrequencyValueErro"
},
{
"path": "inferno/version.py",
"chars": 22,
"preview": "__version__ = '0.4.0'\n"
},
{
"path": "readthedocs.yml",
"chars": 81,
"preview": "conda:\n file: docs/environment.yml\npython:\n version: 3.5\n pip_install: false"
},
{
"path": "requirements.txt",
"chars": 54,
"preview": "dill\npyyaml\nscipy>=0.13.0\nh5py\nnumpy>=1.8\nscikit-image"
},
{
"path": "requirements_dev.txt",
"chars": 277,
"preview": "pip==8.1.2\nbumpversion==0.5.3\nwheel==0.29.0\nwatchdog==0.8.3\nflake8==2.6.0\ntox==2.3.1\ncoverage==4.1\nSphinx==1.4.8\ncryptog"
},
{
"path": "setup.py",
"chars": 2255,
"preview": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n\n\"\"\"The setup script.\"\"\"\n\nfrom setuptools import setup, find_packages\nimpo"
},
{
"path": "tests/__init__.py",
"chars": 62,
"preview": "# -*- coding: utf-8 -*-\n\n\"\"\"Unit test package for inferno.\"\"\"\n"
},
{
"path": "tests/test_extensions/__init__.py",
"chars": 62,
"preview": "# -*- coding: utf-8 -*-\n\n\"\"\"Unit test package for inferno.\"\"\"\n"
},
{
"path": "tests/test_extensions/test_containers/test_graph.py",
"chars": 5288,
"preview": "import unittest\nfrom functools import reduce\nimport torch\n\n\nclass TestGraph(unittest.TestCase):\n def setUp(self):\n "
},
{
"path": "tests/test_extensions/test_criteria/test_core.py",
"chars": 507,
"preview": "import unittest\nimport torch\nimport torch.nn as nn\n\n\nclass TestCore(unittest.TestCase):\n def test_as_2d_criterion(sel"
},
{
"path": "tests/test_extensions/test_criteria/test_elementwise_measures.py",
"chars": 644,
"preview": "import unittest\nimport inferno.extensions.criteria.elementwise_measures as em\nimport torch\n\n\nclass TestElementwiseMeasur"
},
{
"path": "tests/test_extensions/test_criteria/test_set_similarity_measures.py",
"chars": 1999,
"preview": "import unittest\nimport torch\n\n\nclass SetSimilarityTest(unittest.TestCase):\n def get_dummy_variables(self):\n x "
},
{
"path": "tests/test_extensions/test_layers/deprecated/building_blocks.py",
"chars": 2372,
"preview": "import unittest\nimport torch\nimport inferno.extensions.layers.building_blocks as bb\n\n\nclass ResBlockTest(unittest.TestCa"
},
{
"path": "tests/test_extensions/test_layers/test_activations.py",
"chars": 325,
"preview": "import unittest\nimport torch\nimport inferno.extensions.layers.activations as activations\n\n\nclass ActivationTest(unittest"
},
{
"path": "tests/test_extensions/test_layers/test_convolutional.py",
"chars": 615,
"preview": "import unittest\nimport torch\nfrom inferno.utils.model_utils import ModelTester\n\n\nclass TestConvolutional(unittest.TestCa"
},
{
"path": "tests/test_extensions/test_layers/test_device.py",
"chars": 1215,
"preview": "import unittest\nfrom inferno.extensions.layers.device import DeviceTransfer, OnDevice\nimport torch\n\n\nclass TransferTest("
},
{
"path": "tests/test_extensions/test_layers/test_reshape.py",
"chars": 2420,
"preview": "import unittest\nimport torch\n\n\nclass TestReshape(unittest.TestCase):\n def _get_input_variable(self, *shape):\n "
},
{
"path": "tests/test_extensions/test_metrics/categorical.py",
"chars": 2048,
"preview": "import unittest\nimport torch\nfrom inferno.extensions.metrics import IOU\n\n\nclass TestCategorical(unittest.TestCase):\n "
},
{
"path": "tests/test_extensions/test_models/__init__.py",
"chars": 62,
"preview": "# -*- coding: utf-8 -*-\n\n\"\"\"Unit test package for inferno.\"\"\"\n"
},
{
"path": "tests/test_extensions/test_models/test_res_unet.py",
"chars": 3023,
"preview": "import unittest\nimport torch\nimport torch.cuda as cuda\nfrom inferno.utils.model_utils import ModelTester\n\n\nclass ResUNet"
},
{
"path": "tests/test_extensions/test_models/test_unet.py",
"chars": 2199,
"preview": "import unittest\nimport torch.cuda as cuda\nfrom inferno.utils.model_utils import ModelTester, MultiscaleModelTester\nfrom "
},
{
"path": "tests/test_inferno.py",
"chars": 6088,
"preview": "#!/usr/bin/env python\n# -*- coding: utf-8 -*-\n\n\"\"\"Tests for `inferno` package.\"\"\"\n\n\nimport unittest\nimport numpy as np\ni"
},
{
"path": "tests/test_io/__init__.py",
"chars": 62,
"preview": "# -*- coding: utf-8 -*-\n\n\"\"\"Unit test package for inferno.\"\"\"\n"
},
{
"path": "tests/test_io/test_box/__init__.py",
"chars": 62,
"preview": "# -*- coding: utf-8 -*-\n\n\"\"\"Unit test package for inferno.\"\"\"\n"
},
{
"path": "tests/test_io/test_box/test_camvid.py",
"chars": 5616,
"preview": "import os\nfrom os.path import join, dirname, exists, isdir\nimport unittest\nimport numpy as np\n\n\n_CAMVID_ROOT = None\n\n\nde"
},
{
"path": "tests/test_io/test_box/test_cityscapes.py",
"chars": 5462,
"preview": "import os\nfrom os.path import join, dirname, exists, isdir\nimport unittest\nimport numpy as np\nimport time\n\n_CITYSCAPES_R"
},
{
"path": "tests/test_io/test_core/__init__.py",
"chars": 62,
"preview": "# -*- coding: utf-8 -*-\n\n\"\"\"Unit test package for inferno.\"\"\"\n"
},
{
"path": "tests/test_io/test_core/test_concatenate.py",
"chars": 945,
"preview": "import unittest\n\n\nclass ConcatenateTest(unittest.TestCase):\n def test_concatenate(self):\n from inferno.io.core"
},
{
"path": "tests/test_io/test_core/test_zip.py",
"chars": 2058,
"preview": "import unittest\n\n\nclass ZipTest(unittest.TestCase):\n def test_zip_minimal(self):\n \"\"\"Minimal test with python "
},
{
"path": "tests/test_io/test_volumetric/__init__.py",
"chars": 62,
"preview": "# -*- coding: utf-8 -*-\n\n\"\"\"Unit test package for inferno.\"\"\"\n"
},
{
"path": "tests/test_io/test_volumetric/test_lazy_volume_loader.py",
"chars": 4451,
"preview": "import unittest\nimport os\nimport numpy as np\n\n# try to load io libraries (h5py and z5py)\ntry:\n import h5py\n WITH_H"
},
{
"path": "tests/test_io/test_volumetric/test_volume_loader.py",
"chars": 1743,
"preview": "import unittest\nimport os\nfrom shutil import rmtree\n\nimport numpy as np\nimport h5py\n\n\nclass TestVolumeLoader(unittest.Te"
},
{
"path": "tests/test_training/__init__.py",
"chars": 62,
"preview": "# -*- coding: utf-8 -*-\n\n\"\"\"Unit test package for inferno.\"\"\"\n"
},
{
"path": "tests/test_training/test_basic.py",
"chars": 7264,
"preview": "from unittest import TestCase, skipUnless\nimport torch\nfrom unittest import main\nimport time\nfrom os.path import join, d"
},
{
"path": "tests/test_training/test_callbacks/__init__.py",
"chars": 62,
"preview": "# -*- coding: utf-8 -*-\n\n\"\"\"Unit test package for inferno.\"\"\"\n"
},
{
"path": "tests/test_training/test_callbacks/test_base.py",
"chars": 2391,
"preview": "import unittest\nimport torch\nfrom inferno.trainers.callbacks.base import Callback, CallbackEngine\nfrom inferno.trainers."
},
{
"path": "tests/test_training/test_callbacks/test_essentials.py",
"chars": 3964,
"preview": "import unittest\nimport shutil\nimport h5py as h5\nfrom os.path import dirname, join\nfrom os import listdir\nfrom inferno.tr"
},
{
"path": "tests/test_training/test_callbacks/test_logging/__init__.py",
"chars": 62,
"preview": "# -*- coding: utf-8 -*-\n\n\"\"\"Unit test package for inferno.\"\"\"\n"
},
{
"path": "tests/test_training/test_callbacks/test_logging/test_base.py",
"chars": 1007,
"preview": "import unittest\nfrom inferno.trainers.callbacks.logging.base import Logger\nfrom inferno.trainers.basic import Trainer\nfr"
},
{
"path": "tests/test_training/test_callbacks/test_logging/test_tensorboard.py",
"chars": 3992,
"preview": "import unittest\n\nimport os\nfrom shutil import rmtree\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom inferno."
},
{
"path": "tests/test_training/test_callbacks/test_scheduling.py",
"chars": 1256,
"preview": "import unittest\nfrom inferno.trainers.callbacks.scheduling import ManualLR\nfrom torch import nn\nfrom torch.optim import "
},
{
"path": "tests/test_utils/__init__.py",
"chars": 62,
"preview": "# -*- coding: utf-8 -*-\n\n\"\"\"Unit test package for inferno.\"\"\"\n"
},
{
"path": "tests/test_utils/test_model_utils.py",
"chars": 832,
"preview": "import unittest\nimport inferno.utils.model_utils as mu\nfrom inferno.utils.exceptions import ShapeError\nimport torch\nimpo"
},
{
"path": "tests/test_utils/test_partial_cls.py",
"chars": 3316,
"preview": "import unittest\nimport inferno.utils.model_utils as mu\nfrom inferno.utils.partial_cls import register_partial_cls\nimport"
},
{
"path": "tests/test_utils/test_train_utils.py",
"chars": 1900,
"preview": "import unittest\nimport inferno.utils.train_utils as tu\nimport numpy as np\n\n\nclass FrequencyTest(unittest.TestCase):\n "
}
]
About this extraction
This page contains the full source code of the inferno-pytorch/inferno GitHub repository, extracted and formatted as plain text. The extraction includes 179 files (610.4 KB), approximately 144.7k tokens, and a symbol index with 1067 extracted functions, classes, methods, constants, and types.