[
  {
    "path": ".codecov.yml",
    "content": "comment: off\n\ncoverage:\n  status:\n    project:\n      default:\n        target: auto\n        threshold: 0.50\n        base: auto\n    patch: off\nignore:\n  - \"tests/\"\n  - \"notebooks/\"\n  - \"*/__init.py\"\n  \n"
  },
  {
    "path": ".gitattributes",
    "content": "delira/_version.py export-subst\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/bug_report.md",
    "content": "---\nname: Bug report\nabout: Report a bug and give us a minimal example to reproduce it\ntitle: \"[Bug]\"\nlabels: bug\nassignees: ''\n\n---\n\n**Description**\nWhat happens? What should happen?\n\n**Environment**\n* OS:\n* Python version:\n* `delira` version\n* How did you install `delira`? [ pip | source | conda | docker ]\n\n**Reproduction**\nGive us a minimal example to reproduce the error\n\n**Additional context**\nAdd any other context about the problem here.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/feature_request.md",
    "content": "---\nname: Feature request\nabout: Request a feature\ntitle: \"[FeatureRequest]\"\nlabels: new feature\nassignees: ''\n\n---\n\n**Description**\nWhat should be added/changed?\n\n**Feature History**\nWhat have you tried so far?\n\n**Proposal**\nHow could the feature be implemented? \n*Are you able/willing to implement the feature yourself (with some guidance from us)?\n\n**Additional context**\nAdd any other context about the feature request here.\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE/question.md",
    "content": "---\nname: Question\nabout: Ask a question/for support\ntitle: \"[Question]\"\nlabels: question\nassignees: ''\n\n---\n\n**Description**\nWhat happens? What should happen?\n\n**Environment**\n* OS:\n* Python version:\n* `delira` version\n* How did you install `delira`? [ pip | source | conda | docker ]\n* Machine Specs:\n* Minimal working Example:\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# pyenv\n.python-version\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n\n# pycharm\n.idea/\n.DS_Store\n\n.idea\n.vscode\n.pytest_cache\n\n# delira config\n*/.delira\n\n# backend extensions\n*.pkl\nevents.*\n*.pt\n*.pth\n*.ptj\n*.chain\n*.meta\n\n# Test results\n*/UnnamedExperiment/*\n"
  },
  {
    "path": ".readthedocs.yml",
    "content": "# .readthedocs.yml\nversion: 2\n\nformats: \n    - epub\n    - pdf\n    - htmlzip\n\n# python:\n#     version: 3.7\n#     install:\n#         - requirements: docs/requirements.txt\n#         - method: setuptools\n#     system_packages: false\n\nbuild:\n  image: latest\n \nconda:\n    environment: docs/conda.yml\n"
  },
  {
    "path": ".travis.yml",
    "content": "language: python\n\nmatrix:\n    include:\n        # basic tests withut a backend\n        - name: \"Unittests Python 3.5 No Backend\"\n          python: 3.5\n          dist: xenial\n          env:\n              - TEST_TYPE=\"unittests\"\n              - BACKEND=\"None\"\n        - name: \"Unittests Python 3.6 No Backend\"\n          python: 3.6\n          dist: xenial\n          env:\n              - TEST_TYPE=\"unittests\"\n              - BACKEND=\"None\"\n        - name: \"Unittests Python 3.7 No Backend\"\n          python: 3.7\n          dist: xenial\n          env:\n              - TEST_TYPE=\"unittests\"\n              - BACKEND=\"None\"\n    # SCIKIT-LEARN BACKEND TESTS\n        - name: \"Unittests Python 3.5 Sklearn Backend\"\n          python: 3.5\n          dist: xenial\n          env:\n              - TEST_TYPE=\"unittests\"\n              - BACKEND=\"Sklearn\"\n        - name: \"Unittests Python 3.6 Sklearn Backend\"\n          python: 3.6\n          dist: xenial\n          env:\n              - TEST_TYPE=\"unittests\"\n              - BACKEND=\"Sklearn\"\n        - name: \"Unittests Python 3.7 Sklearn Backend\"\n          python: 3.7\n          dist: xenial\n          env:\n              - TEST_TYPE=\"unittests\"\n              - BACKEND=\"Sklearn\"\n              \n    # TENSORFLOW EAGER BACKEND TESTS\n        - name: \"Unittests Python 3.5 TF Eager Backend\"\n          python: 3.5\n          dist: xenial\n          env:\n              - TEST_TYPE=\"unittests\"\n              - BACKEND=\"TFEager\"\n        - name: \"Unittests Python 3.6 TF Eager Backend\"\n          python: 3.6\n          dist: xenial\n          env:\n              - TEST_TYPE=\"unittests\"\n              - BACKEND=\"TFEager\"\n        - name: \"Unittests Python 3.7 TF Eager Backend\"\n          python: 3.7\n          dist: xenial\n          env:\n              - TEST_TYPE=\"unittests\"\n              - BACKEND=\"TFEager\"\n\n        # TENSORFLOW GRAPH BACKEND 
TESTS\n        -   name: \"Unittests Python 3.5 TF Graph Backend\"\n            python: 3.5\n            dist: xenial\n            env:\n                - TEST_TYPE=\"unittests\"\n                - BACKEND=\"TFGraph\"\n        -   name: \"Unittests Python 3.6 TF Graph Backend\"\n            python: 3.6\n            dist: xenial\n            env:\n                - TEST_TYPE=\"unittests\"\n                - BACKEND=\"TFGraph\"\n        -   name: \"Unittests Python 3.7 TF Graph Backend\"\n            python: 3.7\n            dist: xenial\n            env:\n                - TEST_TYPE=\"unittests\"\n                - BACKEND=\"TFGraph\"\n\n      # PYTORCH BACKEND TESTS\n        - name: \"Unittests Python 3.5 Torch Backend\"\n          python: 3.5\n          dist: xenial\n          env:\n              - TEST_TYPE=\"unittests\"\n              - BACKEND=\"Torch\"\n        - name: \"Unittests Python 3.6 Torch Backend\"\n          python: 3.6\n          dist: xenial\n          env:\n              - TEST_TYPE=\"unittests\"\n              - BACKEND=\"Torch\"\n        - name: \"Unittests Python 3.7 Torch Backend\"\n          python: 3.7\n          dist: xenial\n          env:\n              - TEST_TYPE=\"unittests\"\n              - BACKEND=\"Torch\"\n\n        # TORCHSCRIPT BACKEND TESTS\n        -   name: \"Unittests Python 3.5 TorchScript Backend\"\n            python: 3.5\n            dist: xenial\n            env:\n                - TEST_TYPE=\"unittests\"\n                - BACKEND=\"TorchScript\"\n        -   name: \"Unittests Python 3.6 TorchScript Backend\"\n            python: 3.6\n            dist: xenial\n            env:\n                - TEST_TYPE=\"unittests\"\n                - BACKEND=\"TorchScript\"\n        -   name: \"Unittests Python 3.7 TorchScript Backend\"\n            python: 3.7\n            dist: xenial\n            env:\n                - TEST_TYPE=\"unittests\"\n                - BACKEND=\"TorchScript\"\n\n      # CHAINER BACKEND TESTS\n        - 
name: \"Unittests Python 3.5 Chainer Backend\"\n          python: 3.5\n          dist: xenial\n          env:\n              - TEST_TYPE=\"unittests\"\n              - BACKEND=\"Chainer\"\n        - name: \"Unittests Python 3.6 Chainer Backend\"\n          python: 3.6\n          dist: xenial\n          env:\n              - TEST_TYPE=\"unittests\"\n              - BACKEND=\"Chainer\"\n        - name: \"Unittests Python 3.7 Chainer Backend\"\n          python: 3.7\n          dist: xenial\n          env:\n              - TEST_TYPE=\"unittests\"\n              - BACKEND=\"Chainer\"\n              \n      # STATIC CHECKS\n        - name: \"Static Style Checks\"\n          python: 3.7\n          dist: xenial\n          env:\n              - TEST_TYPE=\"style-check\"\n        - name: \"Documentation\"\n          python: 3.7\n          dist: xenial\n          env:\n              - TEST_TYPE=\"docs\"\n\n# command to install dependencies\nbefore_install:\n    - if [[ \"$TEST_TYPE\" == \"unittests\" ]]; then\n          bash scripts/ci/install_before_tests.sh;\n      elif [[ \"$TEST_TYPE\" == \"docs\" ]]; then\n          bash scripts/ci/install_before_docs.sh;\n      else\n          bash scripts/ci/install_before_style_check.sh;\n          pip install -r docs/requirements.txt;\n      fi\n\ninstall:\n    - pip install --no-deps .\n  \n# command to run tests\nscript:\n    # run tests or stylechecks\n    - if [[ \"$TEST_TYPE\" == \"unittests\" ]]; then\n          bash scripts/ci/run_tests.sh;\n      elif [[ \"$TEST_TYPE\" == \"docs\" ]]; then\n          bash scripts/ci/build_docs.sh;\n      else\n          bash scripts/ci/run_style_checks.sh;\n      fi\n\nafter_script:\n  - if [[ \"$TEST_TYPE\" == \"unittests\" ]]; then\n      codecov;\n\nbefore_deploy:\n    - cd $TRAVIS_BUILD_DIR\n\ndeploy:\n        - provider: pages\n          skip_cleanup: true\n          github_token: $GITHUB_TOKEN  # Set in travis-ci.org dashboard, marked secure\n          keep-history: true\n          
on:\n              branch: master\n              condition: $TEST_TYPE = Docs\n              local_dir: docs/_build/html\n        - provider: pypi\n          user: $PYPI_USERNAME\n          password: $PYPI_PASSWORD\n          on:\n              tags: true\n              distributions: \"sdist bdist_wheel\"\n              skip_existing: true\n              condition: $TEST_TYPE = style-check\n"
  },
  {
    "path": "AUTHORS.rst",
    "content": "Authors\n==========\n\n\n**Core Development Team:**\n\n- Justus Schock: `GitHub <https://github.com/justusschock>`_ | `LinkedIn <https://www.linkedin.com/in/justus-schock/>`_ | `Google Scholar <https://scholar.google.de/citations?hl=de&user=KYf-ZHoAAAAJ>`_ | `E-Mail <mailto:justus.schock@rwth-aachen.de>`_\n- Michael Baumgartner: `GitHub <https://github.com/mibaumgartner>`_ | `LinkedIn <https://www.linkedin.com/in/michael-baumgartner-/>`_\n- Oliver Rippel: `GitHub <https://github.com/ORippler>`_ | `LinkedIn <https://www.linkedin.com/in/oliver-rippel-70361113a/>`_ | `Google Scholar <https://scholar.google.de/citations?user=DaTF8RsAAAAJ&hl=de>`_\n- Christoph Haarburger: `GitHub <https://github.com/haarburger>`_ | `LinkedIn <https://www.linkedin.com/in/chaarburger/>`_ | `Google Scholar <https://scholar.google.de/citations?user=Lb8DcccAAAAJ&hl=de>`_ \n\n**Contributions:**\n\n- Nicolas Horst\n\n- Alexander Moriz\n"
  },
  {
    "path": "CODEOWNERS",
    "content": "# Use this CODEOWNERS file for automatically request reviews from owners at PRs. \n# For Details see https://help.github.com/en/articles/about-code-owners\n# The order of the codeowners is simply alphabetically. \n\n# General Namespace (versioning backend resolution etc.)\n/delira/* @justusschock\n\n# DataLoading\n/delira/data_loading/ @justusschock @mibaumgartner\n\n# IO\n/delira/io/ @justusschock\n/delira/io/tf.py @ORippler\n\n# Logging\n/delira/logging/ @justusschock @ORippler\n\n# Models\n/delira/models/* @justusschock\n/delira/models/backends/* @justusschock\n/delira/models/backends/chainer/ @justusschock\n/delira/models/backends/sklearn/ @justusschock\n/delira/models/backends/tf_eager/ @justusschock @ORippler\n/delira/models/backends/tf_graph/ @ORippler\n/delira/models/backends/torch/ @justusschock @mibaumgartner\n/delira/models/backends/torchscript/ @justusschock\n\n# Training\n/delira/training/__init__.py @justusschock\n/delira/training/base_experiment.py @justusschock @mibaumgartner @ORippler\n/delira/training/base_trainer.py @justusschock @mibaumgartner @ORippler\n/delira/training/losses.py @mibaumgartner\n/delira/training/metrics.py @justusschock @mibaumgartner\n/delira/training/parameters.py @justusschock @mibaumgartner\n/delira/training/predictor.py @justusschock @mibaumgartner @ORippler\n/delira/training/utils.py @justusschock\n/delira/training/backends/* @justusschock\n/delira/training/backends/chainer/ @justusschock\n/delira/training/backends/sklearn/ @justusschock\n/delira/training/backends/tf_eager/ @justusschock @ORippler\n/delira/training/backends/tf_graph/ @ORippler\n/delira/taining/backends/torch/ @justusschock @mibaumgartner\n/delira/training/backends/torchscript/ @justusschock\n/delira/training/callbacks/ @justusschock\n\n# Utils\n/delira/utils/ @justusschock @mibaumgartner\n\n\n# Global repo stuff\n/* @justusschock\n/docker/ @haarburger\n/docs/ @justusschock\n/notebooks/* @mibaumgartner\n/paper/ @haarburger\n/requirements/ 
@haarburger @justusschock @mibaumgartner @ORippler\n/scripts/ci/ @justusschock\n\n# Tests\n/tests/* @justusschock\n/tests/data_loading @justusschock @mibaumgartner\n/tests/io/ @justusschock @ORippler\n/logging/ @justusschock @ORippler\n/tests/models/ @justusschock\n/tests/training/* @mibaumgarnter\n/tests/training/backends/ @justusschock\n"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "# Contributing to `delira`\n\nIf you are interested in contributing to `delira`, you will either\n\n* implement a new feature\n\nor \n\n* fix a bug.\n\nFor both types of contribution, the process is roughly the same:\n\n1. File an issue at [this repo] and discuss \nthe issue with us! Maybe we can give you some hints towards \nimplementation/fixing.\n\n2. Create your own fork of `delira`\n\n3. In your own fork, start a new branch for the implementation of your issue. \nMake sure to include basic unittests (We know, that the current code is not \nthat well tested so far, but we want to change this in future).\n\n> **Note:** To improve readability and maintainability, [PEP8 Style](https://www.python.org/dev/peps/pep-0008/) should always be followed (no exceptions).\n\n> **Note:** To ensure our CI/CD running correctly, you should *never* use relative imports but absolute ones.\n\n> **Note:** If you added a feature, you should also add it to the documentation\n\n4. After finishing the coding part, send a pull request to \n[this repo]\n\n5. Afterwards, have a look at your pull request since we might suggest some \nchanges.\n\n\nIf you are not familiar with creating a Pull Request, here are some guides:\n- http://stackoverflow.com/questions/14680711/how-to-do-a-github-pull-request\n- https://help.github.com/articles/creating-a-pull-request/\n\n\n## Development Install\n\nTo develop `delira` on your machine, here are some tips:\n\n1. Uninstall all existing installs of `delira`:\n```\nconda uninstall delira\npip uninstall delira\npip uninstall delira # run this command twice\n```\n\n2. Clone a copy of `delira` from source:\n\n```\ngit clone https://github.com/justusschock/delira.git\ncd delira\n```\n\n3. 
Install `delira` in `build develop` mode:\n\nInstall it via \n\n```\npython setup.py build develop\n```\n\nor \n\n```\npip install -e .\n```\n\nThis mode will symlink the python files from the current local source tree into the\npython install.\n\nHence, if you modify a python file, you do not need to reinstall `delira` \nagain and again\n\nIn case you want to reinstall, make sure that you uninstall `delira` first by running `pip uninstall delira`\nand `python setup.py clean`. Then you can install in `build develop` mode again.\n\n\n## Unit testing\n\nUnittests are located under `test/`. Run the entire test suite with\n\n```\npython test/run_test.py\n```\n\nor run individual test files, like `python test/test_dummy.py`, for individual test suites.\n\n### Better local unit tests with unittest\nTesting is done with a `unittest` suite\n\n## Writing documentation\n\n`delira` uses [numpy style](http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_numpy.html)\nfor formatting docstrings. Length of line inside docstrings block must be limited to 80 characters to\nfit into Jupyter documentation popups.\n\n[this repo]: https://github.com/delira-dev/delira\n"
  },
  {
    "path": "LICENSE",
    "content": "                    GNU AFFERO GENERAL PUBLIC LICENSE\n                       Version 3, 19 November 2007\n\n Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>\n Everyone is permitted to copy and distribute verbatim copies\n of this license document, but changing it is not allowed.\n\n                            Preamble\n\n  The GNU Affero General Public License is a free, copyleft license for\nsoftware and other kinds of works, specifically designed to ensure\ncooperation with the community in the case of network server software.\n\n  The licenses for most software and other practical works are designed\nto take away your freedom to share and change the works.  By contrast,\nour General Public Licenses are intended to guarantee your freedom to\nshare and change all versions of a program--to make sure it remains free\nsoftware for all its users.\n\n  When we speak of free software, we are referring to freedom, not\nprice.  Our General Public Licenses are designed to make sure that you\nhave the freedom to distribute copies of free software (and charge for\nthem if you wish), that you receive source code or can get it if you\nwant it, that you can change the software or use pieces of it in new\nfree programs, and that you know you can do these things.\n\n  Developers that use our General Public Licenses protect your rights\nwith two steps: (1) assert copyright on the software, and (2) offer\nyou this License which gives you legal permission to copy, distribute\nand/or modify the software.\n\n  A secondary benefit of defending all users' freedom is that\nimprovements made in alternate versions of the program, if they\nreceive widespread use, become available for other developers to\nincorporate.  Many developers of free software are heartened and\nencouraged by the resulting cooperation.  
However, in the case of\nsoftware used on network servers, this result may fail to come about.\nThe GNU General Public License permits making a modified version and\nletting the public access it on a server without ever releasing its\nsource code to the public.\n\n  The GNU Affero General Public License is designed specifically to\nensure that, in such cases, the modified source code becomes available\nto the community.  It requires the operator of a network server to\nprovide the source code of the modified version running there to the\nusers of that server.  Therefore, public use of a modified version, on\na publicly accessible server, gives the public access to the source\ncode of the modified version.\n\n  An older license, called the Affero General Public License and\npublished by Affero, was designed to accomplish similar goals.  This is\na different license, not a version of the Affero GPL, but Affero has\nreleased a new version of the Affero GPL which permits relicensing under\nthis license.\n\n  The precise terms and conditions for copying, distribution and\nmodification follow.\n\n                       TERMS AND CONDITIONS\n\n  0. Definitions.\n\n  \"This License\" refers to version 3 of the GNU Affero General Public License.\n\n  \"Copyright\" also means copyright-like laws that apply to other kinds of\nworks, such as semiconductor masks.\n\n  \"The Program\" refers to any copyrightable work licensed under this\nLicense.  Each licensee is addressed as \"you\".  \"Licensees\" and\n\"recipients\" may be individuals or organizations.\n\n  To \"modify\" a work means to copy from or adapt all or part of the work\nin a fashion requiring copyright permission, other than the making of an\nexact copy.  
The resulting work is called a \"modified version\" of the\nearlier work or a work \"based on\" the earlier work.\n\n  A \"covered work\" means either the unmodified Program or a work based\non the Program.\n\n  To \"propagate\" a work means to do anything with it that, without\npermission, would make you directly or secondarily liable for\ninfringement under applicable copyright law, except executing it on a\ncomputer or modifying a private copy.  Propagation includes copying,\ndistribution (with or without modification), making available to the\npublic, and in some countries other activities as well.\n\n  To \"convey\" a work means any kind of propagation that enables other\nparties to make or receive copies.  Mere interaction with a user through\na computer network, with no transfer of a copy, is not conveying.\n\n  An interactive user interface displays \"Appropriate Legal Notices\"\nto the extent that it includes a convenient and prominently visible\nfeature that (1) displays an appropriate copyright notice, and (2)\ntells the user that there is no warranty for the work (except to the\nextent that warranties are provided), that licensees may convey the\nwork under this License, and how to view a copy of this License.  If\nthe interface presents a list of user commands or options, such as a\nmenu, a prominent item in the list meets this criterion.\n\n  1. Source Code.\n\n  The \"source code\" for a work means the preferred form of the work\nfor making modifications to it.  
\"Object code\" means any non-source\nform of a work.\n\n  A \"Standard Interface\" means an interface that either is an official\nstandard defined by a recognized standards body, or, in the case of\ninterfaces specified for a particular programming language, one that\nis widely used among developers working in that language.\n\n  The \"System Libraries\" of an executable work include anything, other\nthan the work as a whole, that (a) is included in the normal form of\npackaging a Major Component, but which is not part of that Major\nComponent, and (b) serves only to enable use of the work with that\nMajor Component, or to implement a Standard Interface for which an\nimplementation is available to the public in source code form.  A\n\"Major Component\", in this context, means a major essential component\n(kernel, window system, and so on) of the specific operating system\n(if any) on which the executable work runs, or a compiler used to\nproduce the work, or an object code interpreter used to run it.\n\n  The \"Corresponding Source\" for a work in object code form means all\nthe source code needed to generate, install, and (for an executable\nwork) run the object code and to modify the work, including scripts to\ncontrol those activities.  However, it does not include the work's\nSystem Libraries, or general-purpose tools or generally available free\nprograms which are used unmodified in performing those activities but\nwhich are not part of the work.  
For example, Corresponding Source\nincludes interface definition files associated with source files for\nthe work, and the source code for shared libraries and dynamically\nlinked subprograms that the work is specifically designed to require,\nsuch as by intimate data communication or control flow between those\nsubprograms and other parts of the work.\n\n  The Corresponding Source need not include anything that users\ncan regenerate automatically from other parts of the Corresponding\nSource.\n\n  The Corresponding Source for a work in source code form is that\nsame work.\n\n  2. Basic Permissions.\n\n  All rights granted under this License are granted for the term of\ncopyright on the Program, and are irrevocable provided the stated\nconditions are met.  This License explicitly affirms your unlimited\npermission to run the unmodified Program.  The output from running a\ncovered work is covered by this License only if the output, given its\ncontent, constitutes a covered work.  This License acknowledges your\nrights of fair use or other equivalent, as provided by copyright law.\n\n  You may make, run and propagate covered works that you do not\nconvey, without conditions so long as your license otherwise remains\nin force.  You may convey covered works to others for the sole purpose\nof having them make modifications exclusively for you, or provide you\nwith facilities for running those works, provided that you comply with\nthe terms of this License in conveying all material for which you do\nnot control copyright.  Those thus making or running the covered works\nfor you must do so exclusively on your behalf, under your direction\nand control, on terms that prohibit them from making any copies of\nyour copyrighted material outside their relationship with you.\n\n  Conveying under any other circumstances is permitted solely under\nthe conditions stated below.  Sublicensing is not allowed; section 10\nmakes it unnecessary.\n\n  3. 
Protecting Users' Legal Rights From Anti-Circumvention Law.\n\n  No covered work shall be deemed part of an effective technological\nmeasure under any applicable law fulfilling obligations under article\n11 of the WIPO copyright treaty adopted on 20 December 1996, or\nsimilar laws prohibiting or restricting circumvention of such\nmeasures.\n\n  When you convey a covered work, you waive any legal power to forbid\ncircumvention of technological measures to the extent such circumvention\nis effected by exercising rights under this License with respect to\nthe covered work, and you disclaim any intention to limit operation or\nmodification of the work as a means of enforcing, against the work's\nusers, your or third parties' legal rights to forbid circumvention of\ntechnological measures.\n\n  4. Conveying Verbatim Copies.\n\n  You may convey verbatim copies of the Program's source code as you\nreceive it, in any medium, provided that you conspicuously and\nappropriately publish on each copy an appropriate copyright notice;\nkeep intact all notices stating that this License and any\nnon-permissive terms added in accord with section 7 apply to the code;\nkeep intact all notices of the absence of any warranty; and give all\nrecipients a copy of this License along with the Program.\n\n  You may charge any price or no price for each copy that you convey,\nand you may offer support or warranty protection for a fee.\n\n  5. Conveying Modified Source Versions.\n\n  You may convey a work based on the Program, or the modifications to\nproduce it from the Program, in the form of source code under the\nterms of section 4, provided that you also meet all of these conditions:\n\n    a) The work must carry prominent notices stating that you modified\n    it, and giving a relevant date.\n\n    b) The work must carry prominent notices stating that it is\n    released under this License and any conditions added under section\n    7.  
This requirement modifies the requirement in section 4 to\n    \"keep intact all notices\".\n\n    c) You must license the entire work, as a whole, under this\n    License to anyone who comes into possession of a copy.  This\n    License will therefore apply, along with any applicable section 7\n    additional terms, to the whole of the work, and all its parts,\n    regardless of how they are packaged.  This License gives no\n    permission to license the work in any other way, but it does not\n    invalidate such permission if you have separately received it.\n\n    d) If the work has interactive user interfaces, each must display\n    Appropriate Legal Notices; however, if the Program has interactive\n    interfaces that do not display Appropriate Legal Notices, your\n    work need not make them do so.\n\n  A compilation of a covered work with other separate and independent\nworks, which are not by their nature extensions of the covered work,\nand which are not combined with it such as to form a larger program,\nin or on a volume of a storage or distribution medium, is called an\n\"aggregate\" if the compilation and its resulting copyright are not\nused to limit the access or legal rights of the compilation's users\nbeyond what the individual works permit.  Inclusion of a covered work\nin an aggregate does not cause this License to apply to the other\nparts of the aggregate.\n\n  6. 
Conveying Non-Source Forms.\n\n  You may convey a covered work in object code form under the terms\nof sections 4 and 5, provided that you also convey the\nmachine-readable Corresponding Source under the terms of this License,\nin one of these ways:\n\n    a) Convey the object code in, or embodied in, a physical product\n    (including a physical distribution medium), accompanied by the\n    Corresponding Source fixed on a durable physical medium\n    customarily used for software interchange.\n\n    b) Convey the object code in, or embodied in, a physical product\n    (including a physical distribution medium), accompanied by a\n    written offer, valid for at least three years and valid for as\n    long as you offer spare parts or customer support for that product\n    model, to give anyone who possesses the object code either (1) a\n    copy of the Corresponding Source for all the software in the\n    product that is covered by this License, on a durable physical\n    medium customarily used for software interchange, for a price no\n    more than your reasonable cost of physically performing this\n    conveying of source, or (2) access to copy the\n    Corresponding Source from a network server at no charge.\n\n    c) Convey individual copies of the object code with a copy of the\n    written offer to provide the Corresponding Source.  This\n    alternative is allowed only occasionally and noncommercially, and\n    only if you received the object code with such an offer, in accord\n    with subsection 6b.\n\n    d) Convey the object code by offering access from a designated\n    place (gratis or for a charge), and offer equivalent access to the\n    Corresponding Source in the same way through the same place at no\n    further charge.  You need not require recipients to copy the\n    Corresponding Source along with the object code.  
If the place to\n    copy the object code is a network server, the Corresponding Source\n    may be on a different server (operated by you or a third party)\n    that supports equivalent copying facilities, provided you maintain\n    clear directions next to the object code saying where to find the\n    Corresponding Source.  Regardless of what server hosts the\n    Corresponding Source, you remain obligated to ensure that it is\n    available for as long as needed to satisfy these requirements.\n\n    e) Convey the object code using peer-to-peer transmission, provided\n    you inform other peers where the object code and Corresponding\n    Source of the work are being offered to the general public at no\n    charge under subsection 6d.\n\n  A separable portion of the object code, whose source code is excluded\nfrom the Corresponding Source as a System Library, need not be\nincluded in conveying the object code work.\n\n  A \"User Product\" is either (1) a \"consumer product\", which means any\ntangible personal property which is normally used for personal, family,\nor household purposes, or (2) anything designed or sold for incorporation\ninto a dwelling.  In determining whether a product is a consumer product,\ndoubtful cases shall be resolved in favor of coverage.  For a particular\nproduct received by a particular user, \"normally used\" refers to a\ntypical or common use of that class of product, regardless of the status\nof the particular user or of the way in which the particular user\nactually uses, or expects or is expected to use, the product.  
A product\nis a consumer product regardless of whether the product has substantial\ncommercial, industrial or non-consumer uses, unless such uses represent\nthe only significant mode of use of the product.\n\n  \"Installation Information\" for a User Product means any methods,\nprocedures, authorization keys, or other information required to install\nand execute modified versions of a covered work in that User Product from\na modified version of its Corresponding Source.  The information must\nsuffice to ensure that the continued functioning of the modified object\ncode is in no case prevented or interfered with solely because\nmodification has been made.\n\n  If you convey an object code work under this section in, or with, or\nspecifically for use in, a User Product, and the conveying occurs as\npart of a transaction in which the right of possession and use of the\nUser Product is transferred to the recipient in perpetuity or for a\nfixed term (regardless of how the transaction is characterized), the\nCorresponding Source conveyed under this section must be accompanied\nby the Installation Information.  But this requirement does not apply\nif neither you nor any third party retains the ability to install\nmodified object code on the User Product (for example, the work has\nbeen installed in ROM).\n\n  The requirement to provide Installation Information does not include a\nrequirement to continue to provide support service, warranty, or updates\nfor a work that has been modified or installed by the recipient, or for\nthe User Product in which it has been modified or installed.  
Access to a\nnetwork may be denied when the modification itself materially and\nadversely affects the operation of the network or violates the rules and\nprotocols for communication across the network.\n\n  Corresponding Source conveyed, and Installation Information provided,\nin accord with this section must be in a format that is publicly\ndocumented (and with an implementation available to the public in\nsource code form), and must require no special password or key for\nunpacking, reading or copying.\n\n  7. Additional Terms.\n\n  \"Additional permissions\" are terms that supplement the terms of this\nLicense by making exceptions from one or more of its conditions.\nAdditional permissions that are applicable to the entire Program shall\nbe treated as though they were included in this License, to the extent\nthat they are valid under applicable law.  If additional permissions\napply only to part of the Program, that part may be used separately\nunder those permissions, but the entire Program remains governed by\nthis License without regard to the additional permissions.\n\n  When you convey a copy of a covered work, you may at your option\nremove any additional permissions from that copy, or from any part of\nit.  (Additional permissions may be written to require their own\nremoval in certain cases when you modify the work.)  
You may place\nadditional permissions on material, added by you to a covered work,\nfor which you have or can give appropriate copyright permission.\n\n  Notwithstanding any other provision of this License, for material you\nadd to a covered work, you may (if authorized by the copyright holders of\nthat material) supplement the terms of this License with terms:\n\n    a) Disclaiming warranty or limiting liability differently from the\n    terms of sections 15 and 16 of this License; or\n\n    b) Requiring preservation of specified reasonable legal notices or\n    author attributions in that material or in the Appropriate Legal\n    Notices displayed by works containing it; or\n\n    c) Prohibiting misrepresentation of the origin of that material, or\n    requiring that modified versions of such material be marked in\n    reasonable ways as different from the original version; or\n\n    d) Limiting the use for publicity purposes of names of licensors or\n    authors of the material; or\n\n    e) Declining to grant rights under trademark law for use of some\n    trade names, trademarks, or service marks; or\n\n    f) Requiring indemnification of licensors and authors of that\n    material by anyone who conveys the material (or modified versions of\n    it) with contractual assumptions of liability to the recipient, for\n    any liability that these contractual assumptions directly impose on\n    those licensors and authors.\n\n  All other non-permissive additional terms are considered \"further\nrestrictions\" within the meaning of section 10.  If the Program as you\nreceived it, or any part of it, contains a notice stating that it is\ngoverned by this License along with a term that is a further\nrestriction, you may remove that term.  
If a license document contains\na further restriction but permits relicensing or conveying under this\nLicense, you may add to a covered work material governed by the terms\nof that license document, provided that the further restriction does\nnot survive such relicensing or conveying.\n\n  If you add terms to a covered work in accord with this section, you\nmust place, in the relevant source files, a statement of the\nadditional terms that apply to those files, or a notice indicating\nwhere to find the applicable terms.\n\n  Additional terms, permissive or non-permissive, may be stated in the\nform of a separately written license, or stated as exceptions;\nthe above requirements apply either way.\n\n  8. Termination.\n\n  You may not propagate or modify a covered work except as expressly\nprovided under this License.  Any attempt otherwise to propagate or\nmodify it is void, and will automatically terminate your rights under\nthis License (including any patent licenses granted under the third\nparagraph of section 11).\n\n  However, if you cease all violation of this License, then your\nlicense from a particular copyright holder is reinstated (a)\nprovisionally, unless and until the copyright holder explicitly and\nfinally terminates your license, and (b) permanently, if the copyright\nholder fails to notify you of the violation by some reasonable means\nprior to 60 days after the cessation.\n\n  Moreover, your license from a particular copyright holder is\nreinstated permanently if the copyright holder notifies you of the\nviolation by some reasonable means, this is the first time you have\nreceived notice of violation of this License (for any work) from that\ncopyright holder, and you cure the violation prior to 30 days after\nyour receipt of the notice.\n\n  Termination of your rights under this section does not terminate the\nlicenses of parties who have received copies or rights from you under\nthis License.  
If your rights have been terminated and not permanently\nreinstated, you do not qualify to receive new licenses for the same\nmaterial under section 10.\n\n  9. Acceptance Not Required for Having Copies.\n\n  You are not required to accept this License in order to receive or\nrun a copy of the Program.  Ancillary propagation of a covered work\noccurring solely as a consequence of using peer-to-peer transmission\nto receive a copy likewise does not require acceptance.  However,\nnothing other than this License grants you permission to propagate or\nmodify any covered work.  These actions infringe copyright if you do\nnot accept this License.  Therefore, by modifying or propagating a\ncovered work, you indicate your acceptance of this License to do so.\n\n  10. Automatic Licensing of Downstream Recipients.\n\n  Each time you convey a covered work, the recipient automatically\nreceives a license from the original licensors, to run, modify and\npropagate that work, subject to this License.  You are not responsible\nfor enforcing compliance by third parties with this License.\n\n  An \"entity transaction\" is a transaction transferring control of an\norganization, or substantially all assets of one, or subdividing an\norganization, or merging organizations.  If propagation of a covered\nwork results from an entity transaction, each party to that\ntransaction who receives a copy of the work also receives whatever\nlicenses to the work the party's predecessor in interest had or could\ngive under the previous paragraph, plus a right to possession of the\nCorresponding Source of the work from the predecessor in interest, if\nthe predecessor has it or can get it with reasonable efforts.\n\n  You may not impose any further restrictions on the exercise of the\nrights granted or affirmed under this License.  
For example, you may\nnot impose a license fee, royalty, or other charge for exercise of\nrights granted under this License, and you may not initiate litigation\n(including a cross-claim or counterclaim in a lawsuit) alleging that\nany patent claim is infringed by making, using, selling, offering for\nsale, or importing the Program or any portion of it.\n\n  11. Patents.\n\n  A \"contributor\" is a copyright holder who authorizes use under this\nLicense of the Program or a work on which the Program is based.  The\nwork thus licensed is called the contributor's \"contributor version\".\n\n  A contributor's \"essential patent claims\" are all patent claims\nowned or controlled by the contributor, whether already acquired or\nhereafter acquired, that would be infringed by some manner, permitted\nby this License, of making, using, or selling its contributor version,\nbut do not include claims that would be infringed only as a\nconsequence of further modification of the contributor version.  For\npurposes of this definition, \"control\" includes the right to grant\npatent sublicenses in a manner consistent with the requirements of\nthis License.\n\n  Each contributor grants you a non-exclusive, worldwide, royalty-free\npatent license under the contributor's essential patent claims, to\nmake, use, sell, offer for sale, import and otherwise run, modify and\npropagate the contents of its contributor version.\n\n  In the following three paragraphs, a \"patent license\" is any express\nagreement or commitment, however denominated, not to enforce a patent\n(such as an express permission to practice a patent or covenant not to\nsue for patent infringement).  
To \"grant\" such a patent license to a\nparty means to make such an agreement or commitment not to enforce a\npatent against the party.\n\n  If you convey a covered work, knowingly relying on a patent license,\nand the Corresponding Source of the work is not available for anyone\nto copy, free of charge and under the terms of this License, through a\npublicly available network server or other readily accessible means,\nthen you must either (1) cause the Corresponding Source to be so\navailable, or (2) arrange to deprive yourself of the benefit of the\npatent license for this particular work, or (3) arrange, in a manner\nconsistent with the requirements of this License, to extend the patent\nlicense to downstream recipients.  \"Knowingly relying\" means you have\nactual knowledge that, but for the patent license, your conveying the\ncovered work in a country, or your recipient's use of the covered work\nin a country, would infringe one or more identifiable patents in that\ncountry that you have reason to believe are valid.\n\n  If, pursuant to or in connection with a single transaction or\narrangement, you convey, or propagate by procuring conveyance of, a\ncovered work, and grant a patent license to some of the parties\nreceiving the covered work authorizing them to use, propagate, modify\nor convey a specific copy of the covered work, then the patent license\nyou grant is automatically extended to all recipients of the covered\nwork and works based on it.\n\n  A patent license is \"discriminatory\" if it does not include within\nthe scope of its coverage, prohibits the exercise of, or is\nconditioned on the non-exercise of one or more of the rights that are\nspecifically granted under this License.  
You may not convey a covered\nwork if you are a party to an arrangement with a third party that is\nin the business of distributing software, under which you make payment\nto the third party based on the extent of your activity of conveying\nthe work, and under which the third party grants, to any of the\nparties who would receive the covered work from you, a discriminatory\npatent license (a) in connection with copies of the covered work\nconveyed by you (or copies made from those copies), or (b) primarily\nfor and in connection with specific products or compilations that\ncontain the covered work, unless you entered into that arrangement,\nor that patent license was granted, prior to 28 March 2007.\n\n  Nothing in this License shall be construed as excluding or limiting\nany implied license or other defenses to infringement that may\notherwise be available to you under applicable patent law.\n\n  12. No Surrender of Others' Freedom.\n\n  If conditions are imposed on you (whether by court order, agreement or\notherwise) that contradict the conditions of this License, they do not\nexcuse you from the conditions of this License.  If you cannot convey a\ncovered work so as to satisfy simultaneously your obligations under this\nLicense and any other pertinent obligations, then as a consequence you may\nnot convey it at all.  For example, if you agree to terms that obligate you\nto collect a royalty for further conveying from those to whom you convey\nthe Program, the only way you could satisfy both those terms and this\nLicense would be to refrain entirely from conveying the Program.\n\n  13. 
Remote Network Interaction; Use with the GNU General Public License.\n\n  Notwithstanding any other provision of this License, if you modify the\nProgram, your modified version must prominently offer all users\ninteracting with it remotely through a computer network (if your version\nsupports such interaction) an opportunity to receive the Corresponding\nSource of your version by providing access to the Corresponding Source\nfrom a network server at no charge, through some standard or customary\nmeans of facilitating copying of software.  This Corresponding Source\nshall include the Corresponding Source for any work covered by version 3\nof the GNU General Public License that is incorporated pursuant to the\nfollowing paragraph.\n\n  Notwithstanding any other provision of this License, you have\npermission to link or combine any covered work with a work licensed\nunder version 3 of the GNU General Public License into a single\ncombined work, and to convey the resulting work.  The terms of this\nLicense will continue to apply to the part which is the covered work,\nbut the work with which it is combined will remain governed by version\n3 of the GNU General Public License.\n\n  14. Revised Versions of this License.\n\n  The Free Software Foundation may publish revised and/or new versions of\nthe GNU Affero General Public License from time to time.  Such new versions\nwill be similar in spirit to the present version, but may differ in detail to\naddress new problems or concerns.\n\n  Each version is given a distinguishing version number.  If the\nProgram specifies that a certain numbered version of the GNU Affero General\nPublic License \"or any later version\" applies to it, you have the\noption of following the terms and conditions either of that numbered\nversion or of any later version published by the Free Software\nFoundation.  
If the Program does not specify a version number of the\nGNU Affero General Public License, you may choose any version ever published\nby the Free Software Foundation.\n\n  If the Program specifies that a proxy can decide which future\nversions of the GNU Affero General Public License can be used, that proxy's\npublic statement of acceptance of a version permanently authorizes you\nto choose that version for the Program.\n\n  Later license versions may give you additional or different\npermissions.  However, no additional obligations are imposed on any\nauthor or copyright holder as a result of your choosing to follow a\nlater version.\n\n  15. Disclaimer of Warranty.\n\n  THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY\nAPPLICABLE LAW.  EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT\nHOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM \"AS IS\" WITHOUT WARRANTY\nOF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,\nTHE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR\nPURPOSE.  THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM\nIS WITH YOU.  SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF\nALL NECESSARY SERVICING, REPAIR OR CORRECTION.\n\n  16. Limitation of Liability.\n\n  IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING\nWILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS\nTHE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY\nGENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE\nUSE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF\nDATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD\nPARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),\nEVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF\nSUCH DAMAGES.\n\n  17. 
Interpretation of Sections 15 and 16.\n\n  If the disclaimer of warranty and limitation of liability provided\nabove cannot be given local legal effect according to their terms,\nreviewing courts shall apply local law that most closely approximates\nan absolute waiver of all civil liability in connection with the\nProgram, unless a warranty or assumption of liability accompanies a\ncopy of the Program in return for a fee.\n\n                     END OF TERMS AND CONDITIONS\n\n            How to Apply These Terms to Your New Programs\n\n  If you develop a new program, and you want it to be of the greatest\npossible use to the public, the best way to achieve this is to make it\nfree software which everyone can redistribute and change under these terms.\n\n  To do so, attach the following notices to the program.  It is safest\nto attach them to the start of each source file to most effectively\nstate the exclusion of warranty; and each file should have at least\nthe \"copyright\" line and a pointer to where the full notice is found.\n\n    <one line to give the program's name and a brief idea of what it does.>\n    Copyright (C) <year>  <name of author>\n\n    This program is free software: you can redistribute it and/or modify\n    it under the terms of the GNU Affero General Public License as published\n    by the Free Software Foundation, either version 3 of the License, or\n    (at your option) any later version.\n\n    This program is distributed in the hope that it will be useful,\n    but WITHOUT ANY WARRANTY; without even the implied warranty of\n    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n    GNU Affero General Public License for more details.\n\n    You should have received a copy of the GNU Affero General Public License\n    along with this program.  
If not, see <https://www.gnu.org/licenses/>.\n\nAlso add information on how to contact you by electronic and paper mail.\n\n  If your software can interact with users remotely through a computer\nnetwork, you should also make sure that it provides a way for users to\nget its source.  For example, if your program is a web application, its\ninterface could display a \"Source\" link that leads users to an archive\nof the code.  There are many ways you could offer source, and different\nsolutions will be better for different programs; see section 13 for the\nspecific requirements.\n\n  You should also get your employer (if you work as a programmer) or school,\nif any, to sign a \"copyright disclaimer\" for the program, if necessary.\nFor more information on this, and how to apply and follow the GNU AGPL, see\n<https://www.gnu.org/licenses/>.\n"
  },
  {
    "path": "MANIFEST.in",
    "content": "include requirements/*.txt\ninclude *.md\ninclude LICENSE\ninclude notebooks/*.ipynb\ninclude setup.cfg\ninclude versioneer.py\ninclude delira/_version.py\n"
  },
  {
    "path": "README.md",
    "content": "[<img src=\"https://img.shields.io/badge/chat-slack%20channel-75BBC4.svg\">](https://join.slack.com/t/deliradev/shared_invite/enQtNjI1MjA4MjQzMzQ2LTUzNTQ0MjQyNjJjNzgyODczY2Y1YjYxNjA3ZmQ0MGFhODhkYzQ4M2RjMGM1YWM3YWU5MDM0ZjdiNTQ4MmQ0ZDk)\n[![PyPI version](https://badge.fury.io/py/delira.svg)](https://badge.fury.io/py/delira) [![Build Status](https://travis-ci.com/delira-dev/delira.svg?branch=master)](https://travis-ci.com/delira-dev/delira) [![Documentation Status](https://readthedocs.org/projects/delira/badge/?version=master)](https://delira.readthedocs.io/en/master/?badge=master) [![codecov](https://codecov.io/gh/justusschock/delira/branch/master/graph/badge.svg)](https://codecov.io/gh/delira-dev/delira)\n[![DOI](http://joss.theoj.org/papers/10.21105/joss.01488/status.svg)](https://doi.org/10.21105/joss.01488)\n\n![logo](docs/_static/logo/delira.svg \"delira - A Backend Agnostic High Level Deep Learning Library\")\n\n# delira - A Backend Agnostic High Level Deep Learning Library\nAuthors: [Justus Schock, Michael Baumgartner, Oliver Rippel, Christoph Haarburger](AUTHORS.rst)\n\nCopyright (C) 2020 by RWTH Aachen University                      \nhttp://www.rwth-aachen.de                                             \n                                                                         \nLicense:                                                                                                                                       \nThis software is dual-licensed under:                                 \n• Commercial license (please contact: lfb@lfb.rwth-aachen.de)         \n• AGPL (GNU Affero General Public License) open source license        \n\n## Introduction\n`delira` is designed to work as a backend agnostic high level deep learning library. 
You can choose among several computation [backends](#choose-backend).\nIt allows you to compare different models written for different backends without rewriting them.\n\nFor this purpose, `delira` bundles the entire training and prediction logic into backend-agnostic modules to achieve identical training behavior across all backends.\n\n`delira` is designed in a very modular way so that almost everything is easily exchangeable or customizable.\n\nA (non-comprehensive) list of the features included in `delira`:\n* Dataset loading\n* Dataset sampling\n* Augmentation (multi-threaded) including 3D images with any number of channels (based on [`batchgenerators`](https://github.com/MIC-DKFZ/batchgenerators))\n* A generic trainer class that implements the training process for all [backends](#choose-backend)\n* Training monitoring using [Visdom](https://github.com/facebookresearch/visdom) or [TensorBoard](https://www.tensorflow.org/guide/summaries_and_tensorboard)\n* Model save and load functions\n* Already implemented datasets\n* Many operations and utilities for medical imaging\n\n## What about the name?\n`delira` started as a library to enable deep learning research and fast prototyping in medical imaging (especially in radiology).\nThat's also where the name comes from: `delira` was an acronym for **DE**ep **L**earning **I**n **RA**diology.\nTo accommodate many other use cases, we have since broadened the framework's focus considerably, although we still have many medical-related utilities\nand are constantly working on factoring them out.\n\n\n## Installation\n\n### Choose Backend\n\nYou may choose a backend from the list below. 
If your desired backend is not listed and you want to add it, please open an issue (it should not be hard at all) and we will guide you during the process of doing so.\n\n\n| Backend                                                   | Binary Installation               | Source Installation                                                                               | Notes                                                                                                                                                 |\n|-----------------------------------------------------------|-----------------------------------|---------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------|\n| None                                                      | `pip install delira`              | `pip install git+https://github.com/delira-dev/delira.git`                                      | Training not possible if backend is not installed separately                                                                                          |\n| [`torch`](https://pytorch.org)                            | `pip install delira[torch]`       | `git clone https://github.com/delira-dev/delira.git && cd delira && pip install .[torch]`       | `delira` with `torch` backend supports mixed-precision training via [NVIDIA/apex](https://github.com/NVIDIA/apex.git) (must be installed separately). 
|\n| [`torchscript`](https://pytorch.org/docs/stable/jit.html) | `pip install delira[torchscript]` | `git clone https://github.com/delira-dev/delira.git && cd delira && pip install .[torchscript]` | The `torchscript` backend currently supports only single-GPU training                                                                                 |\n| [`tensorflow eager`](https://www.tensorflow.org/)         | `pip install delira[tensorflow]`  | `git clone https://github.com/delira-dev/delira.git && cd delira && pip install .[tensorflow]`  | The `tensorflow` backend is still very experimental and lacks some [features](https://github.com/delira-dev/delira/issues/47)                       |\n| [`tensorflow graph`](https://www.tensorflow.org/)         | `pip install delira[tensorflow]`  | `git clone https://github.com/delira-dev/delira.git && cd delira && pip install .[tensorflow]`  | The `tensorflow` backend is still very experimental and lacks some [features](https://github.com/delira-dev/delira/issues/47)                       |\n| [`scikit-learn`](https://scikit-learn.org/stable/)        | `pip install delira`              | `pip install git+https://github.com/delira-dev/delira.git`                                      | /                                                                                                                                                     |\n| [`chainer`](https://chainer.org/)                         | `pip install delira[chainer]`     | `git clone https://github.com/delira-dev/delira.git && cd delira && pip install .[chainer]`     | /                                                                                                                                                     |\n| Full                                                      | `pip install delira[full]`        | `git clone https://github.com/delira-dev/delira.git && cd delira && pip install .[full]`        | All backends will be installed.                                                                                                                       
|\n\n### Docker\nThe easiest way to use `delira` is via Docker (with the [nvidia-runtime](https://github.com/NVIDIA/nvidia-docker) for GPU support), using the [Dockerfile](docker/Dockerfile) or the [prebuilt images](https://cloud.docker.com/u/justusschock/repository/docker/justusschock/delira).\n\n### Chat\nWe have a [community chat on Slack](https://deliradev.slack.com). If you need an invitation, just follow [this link](https://join.slack.com/t/deliradev/shared_invite/enQtNjI1MjA4MjQzMzQ2LTUzNTQ0MjQyNjJjNzgyODczY2Y1YjYxNjA3ZmQ0MGFhODhkYzQ4M2RjMGM1YWM3YWU5MDM0ZjdiNTQ4MmQ0ZDk).\n\n## Getting Started\nThe best way to learn how to use `delira` is to have a look at the [tutorial notebook](notebooks/tutorial_delira.ipynb).\nExample implementations for classification problems, segmentation approaches and GANs are also provided in the [notebooks](notebooks) folder.\n\n## Documentation\nThe docs are hosted on [ReadTheDocs/Delira](https://delira.rtfd.io).\nThe documentation of the latest master branch can always be found on the project's [GitHub page](https://delira-dev.github.io/delira/).\n\n## Contributing\nIf you find a bug or have an idea for an improvement, please have a look at our [contribution guideline](CONTRIBUTING.md).\n"
  },
  {
    "path": "delira/__init__.py",
    "content": "from delira._debug_mode import get_current_debug_mode, switch_debug_mode, \\\n    set_debug_mode\nfrom delira._backends import get_backends, seed_all\n\nfrom ._version import get_versions as _get_versions\n\nimport warnings\nwarnings.simplefilter('default', DeprecationWarning)\nwarnings.simplefilter('ignore', ImportWarning)\n\n\n__version__ = _get_versions()['version']\ndel _get_versions\n"
  },
  {
    "path": "delira/_backends.py",
    "content": "import os\nimport json\nfrom delira._version import get_versions as _get_versions\n\n# to register new possible backends, they have to be added to this list.\n# each backend should consist of a tuple of length 2 with the first entry\n# being the package import name and the second being the backend abbreviation.\n# E.g. TensorFlow's package is named 'tensorflow' but if the package is found,\n# it will be considered as 'tf' later on\n__POSSIBLE_BACKENDS = ((\"torch\", \"torch\"),\n                       (\"tensorflow\", \"tf\"),\n                       (\"chainer\", \"chainer\"),\n                       (\"sklearn\", \"sklearn\"))\n__BACKENDS = ()\n\n\ndef _determine_backends():\n    \"\"\"\n    Internal Helper Function to determine the currently valid backends by\n    trying to import them. The valid backends are not returned, but appended\n    to the global ``__BACKENDS`` variable\n\n    \"\"\"\n\n    _config_file = __file__.replace(\"_backends.py\", \".delira\")\n    # look for config file to determine backend\n    # if file exists: load config into environment variables\n\n    if not os.path.isfile(_config_file):\n        _backends = {}\n        # try to import all possible backends to determine valid backends\n\n        import importlib\n        for curr_backend in __POSSIBLE_BACKENDS:\n            try:\n                assert len(curr_backend) == 2\n                assert all([isinstance(_tmp, str) for _tmp in curr_backend]), \\\n                    \"All entries in current backend must be strings\"\n\n                # check if backend can be imported\n                bcknd = importlib.util.find_spec(curr_backend[0])\n\n                if bcknd is not None:\n                    _backends[curr_backend[1]] = True\n                else:\n                    _backends[curr_backend[1]] = False\n                del bcknd\n\n            except ValueError:\n                _backends[curr_backend[1]] = False\n\n        with open(_config_file, \"w\") as 
f:\n            json.dump({\"version\": _get_versions()['version'],\n                       \"backend\": _backends},\n                      f, sort_keys=True, indent=4)\n\n        del _backends\n\n    # set values from config file to variable and empty Backend-List before\n    global __BACKENDS\n    __BACKENDS = []\n    with open(_config_file) as f:\n        _config_dict = json.load(f)\n    for key, val in _config_dict.pop(\"backend\").items():\n        if val:\n            __BACKENDS.append(key.upper())\n    del _config_dict\n\n    del _config_file\n\n    # make __BACKENDS immutable\n    __BACKENDS = tuple(__BACKENDS)\n\n\ndef get_backends():\n    \"\"\"\n    Return List of currently available backends\n\n    Returns\n    -------\n    list\n        list of strings containing the currently installed backends\n    \"\"\"\n    global __BACKENDS\n\n    if not __BACKENDS:\n        _determine_backends()\n    return __BACKENDS\n\n\ndef seed_all(seed):\n    \"\"\"\n    Helper Function to seed all available backends\n\n    Parameters\n    ----------\n    seed : int\n        the new random seed\n\n    \"\"\"\n    import sys\n\n    import numpy as np\n    np.random.seed(seed)\n\n    import random\n    # call random.seed (assigning to it would overwrite the function)\n    random.seed(seed)\n\n    # use independent ifs so that every imported backend gets seeded,\n    # not just the first one found\n    if \"torch\" in sys.modules and \"TORCH\" in get_backends():\n        import torch\n        torch.random.manual_seed(seed)\n\n    if \"tensorflow\" in sys.modules and \"TF\" in get_backends():\n        import tensorflow as tf\n        tf.random.set_random_seed(seed)\n\n    if \"chainer\" in sys.modules and \"CHAINER\" in get_backends():\n        try:\n            import cupy\n            cupy.random.seed(seed)\n        except ImportError:\n            pass\n"
  },
  {
    "path": "delira/_debug_mode.py",
    "content": "__DEBUG_MODE = False\n\n# Functions to get and set the internal __DEBUG_MODE variable. This variable\n# currently only defines whether to use multiprocessing or not. At the moment\n# this is only used inside the DataManager, which either returns a\n# MultiThreadedAugmenter or a SingleThreadedAugmenter depending on the current\n# debug mode.\n# All other functions using multiprocessing should be aware of this and\n# implement a functionality without multiprocessing\n# (even if this slows down things a lot!).\n\n\ndef get_current_debug_mode():\n    \"\"\"\n    Getter function for the current debug mode\n\n    Returns\n    -------\n    bool\n        current debug mode\n    \"\"\"\n    return __DEBUG_MODE\n\n\ndef switch_debug_mode():\n    \"\"\"\n    Alternates the current debug mode\n    \"\"\"\n    set_debug_mode(not get_current_debug_mode())\n\n\ndef set_debug_mode(mode: bool):\n    \"\"\"\n    Sets a new debug mode\n\n    Parameters\n    ----------\n    mode : bool\n        the new debug mode\n    \"\"\"\n    global __DEBUG_MODE\n    __DEBUG_MODE = mode\n"
  },
  {
    "path": "delira/_version.py",
    "content": "# This file helps to compute a version number in source trees obtained from\n# git-archive tarball (such as those provided by githubs download-from-tag\n# feature). Distribution tarballs (built by setup.py sdist) and build\n# directories (produced by setup.py build) will contain a much shorter file\n# that just contains the computed version number.\n\n# This file is released into the public domain. Generated by\n# versioneer-0.18 (https://github.com/warner/python-versioneer)\n\n\"\"\"Git implementation of _version.py.\"\"\"\n\nimport errno\nimport os\nimport re\nimport subprocess\nimport sys\n\n\ndef get_keywords():\n    \"\"\"Get the keywords needed to look up the version information.\"\"\"\n    # these strings will be replaced by git during git-archive.\n    # setup.py/versioneer.py will grep for the variable names, so they must\n    # each be defined on a line of their own. _version.py will just call\n    # get_keywords().\n    git_refnames = \"$Format:%d$\"\n    git_full = \"$Format:%H$\"\n    git_date = \"$Format:%ci$\"\n    keywords = {\"refnames\": git_refnames, \"full\": git_full, \"date\": git_date}\n    return keywords\n\n\nclass VersioneerConfig:\n    \"\"\"Container for Versioneer configuration parameters.\"\"\"\n\n\ndef get_config():\n    \"\"\"Create, populate and return the VersioneerConfig() object.\"\"\"\n    # these strings are filled in when 'setup.py versioneer' creates\n    # _version.py\n    cfg = VersioneerConfig()\n    cfg.VCS = \"git\"\n    cfg.style = \"pep440\"\n    cfg.tag_prefix = \"\"\n    cfg.parentdir_prefix = \"\"\n    cfg.versionfile_source = \"delira/_version.py\"\n    cfg.verbose = False\n    return cfg\n\n\nclass NotThisMethod(Exception):\n    \"\"\"Exception raised if a method is not valid for the current scenario.\"\"\"\n\n\nLONG_VERSION_PY = {}\nHANDLERS = {}\n\n\ndef register_vcs_handler(vcs, method):  # decorator\n    \"\"\"Decorator to mark a method as the handler for a particular VCS.\"\"\"\n    def 
decorate(f):\n        \"\"\"Store f in HANDLERS[vcs][method].\"\"\"\n        if vcs not in HANDLERS:\n            HANDLERS[vcs] = {}\n        HANDLERS[vcs][method] = f\n        return f\n    return decorate\n\n\ndef run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,\n                env=None):\n    \"\"\"Call the given command(s).\"\"\"\n    assert isinstance(commands, list)\n    p = None\n    for c in commands:\n        try:\n            dispcmd = str([c] + args)\n            # remember shell=False, so use git.cmd on windows, not just git\n            p = subprocess.Popen([c] + args, cwd=cwd, env=env,\n                                 stdout=subprocess.PIPE,\n                                 stderr=(subprocess.PIPE if hide_stderr\n                                         else None))\n            break\n        except EnvironmentError:\n            e = sys.exc_info()[1]\n            if e.errno == errno.ENOENT:\n                continue\n            if verbose:\n                print(\"unable to run %s\" % dispcmd)\n                print(e)\n            return None, None\n    else:\n        if verbose:\n            print(\"unable to find command, tried %s\" % (commands,))\n        return None, None\n    stdout = p.communicate()[0].strip()\n    if sys.version_info[0] >= 3:\n        stdout = stdout.decode()\n    if p.returncode != 0:\n        if verbose:\n            print(\"unable to run %s (error)\" % dispcmd)\n            print(\"stdout was %s\" % stdout)\n        return None, p.returncode\n    return stdout, p.returncode\n\n\ndef versions_from_parentdir(parentdir_prefix, root, verbose):\n    \"\"\"Try to determine the version from the parent directory name.\n\n    Source tarballs conventionally unpack into a directory that includes both\n    the project name and a version string. 
We will also support searching up\n    two directory levels for an appropriately named parent directory\n    \"\"\"\n    rootdirs = []\n\n    for i in range(3):\n        dirname = os.path.basename(root)\n        if dirname.startswith(parentdir_prefix):\n            return {\"version\": dirname[len(parentdir_prefix):],\n                    \"full-revisionid\": None,\n                    \"dirty\": False, \"error\": None, \"date\": None}\n        else:\n            rootdirs.append(root)\n            root = os.path.dirname(root)  # up a level\n\n    if verbose:\n        print(\"Tried directories %s but none started with prefix %s\" %\n              (str(rootdirs), parentdir_prefix))\n    raise NotThisMethod(\"rootdir doesn't start with parentdir_prefix\")\n\n\n@register_vcs_handler(\"git\", \"get_keywords\")\ndef git_get_keywords(versionfile_abs):\n    \"\"\"Extract version information from the given file.\"\"\"\n    # the code embedded in _version.py can just fetch the value of these\n    # keywords. When used from setup.py, we don't want to import _version.py,\n    # so we do it with a regexp instead. 
This function is not used from\n    # _version.py.\n    keywords = {}\n    try:\n        f = open(versionfile_abs, \"r\")\n        for line in f.readlines():\n            if line.strip().startswith(\"git_refnames =\"):\n                mo = re.search(r'=\\s*\"(.*)\"', line)\n                if mo:\n                    keywords[\"refnames\"] = mo.group(1)\n            if line.strip().startswith(\"git_full =\"):\n                mo = re.search(r'=\\s*\"(.*)\"', line)\n                if mo:\n                    keywords[\"full\"] = mo.group(1)\n            if line.strip().startswith(\"git_date =\"):\n                mo = re.search(r'=\\s*\"(.*)\"', line)\n                if mo:\n                    keywords[\"date\"] = mo.group(1)\n        f.close()\n    except EnvironmentError:\n        pass\n    return keywords\n\n\n@register_vcs_handler(\"git\", \"keywords\")\ndef git_versions_from_keywords(keywords, tag_prefix, verbose):\n    \"\"\"Get version information from git keywords.\"\"\"\n    if not keywords:\n        raise NotThisMethod(\"no keywords at all, weird\")\n    date = keywords.get(\"date\")\n    if date is not None:\n        # git-2.2.0 added \"%cI\", which expands to an ISO-8601 -compliant\n        # datestamp. 
However we prefer \"%ci\" (which expands to an \"ISO-8601\n        # -like\" string, which we must then edit to make compliant), because\n        # it's been around since git-1.5.3, and it's too difficult to\n        # discover which version we're using, or to work around using an\n        # older one.\n        date = date.strip().replace(\" \", \"T\", 1).replace(\" \", \"\", 1)\n    refnames = keywords[\"refnames\"].strip()\n    if refnames.startswith(\"$Format\"):\n        if verbose:\n            print(\"keywords are unexpanded, not using\")\n        raise NotThisMethod(\"unexpanded keywords, not a git-archive tarball\")\n    refs = set([r.strip() for r in refnames.strip(\"()\").split(\",\")])\n    # starting in git-1.8.3, tags are listed as \"tag: foo-1.0\" instead of\n    # just \"foo-1.0\". If we see a \"tag: \" prefix, prefer those.\n    TAG = \"tag: \"\n    tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)])\n    if not tags:\n        # Either we're using git < 1.8.3, or there really are no tags. We use\n        # a heuristic: assume all version tags have a digit. The old git %d\n        # expansion behaves like git log --decorate=short and strips out the\n        # refs/heads/ and refs/tags/ prefixes that would let us distinguish\n        # between branches and tags. By ignoring refnames without digits, we\n        # filter out many common branch names like \"release\" and\n        # \"stabilization\", as well as \"HEAD\" and \"master\".\n        tags = set([r for r in refs if re.search(r'\\d', r)])\n        if verbose:\n            print(\"discarding '%s', no digits\" % \",\".join(refs - tags))\n    if verbose:\n        print(\"likely tags: %s\" % \",\".join(sorted(tags)))\n    for ref in sorted(tags):\n        # sorting will prefer e.g. 
\"2.0\" over \"2.0rc1\"\n        if ref.startswith(tag_prefix):\n            r = ref[len(tag_prefix):]\n            if verbose:\n                print(\"picking %s\" % r)\n            return {\"version\": r,\n                    \"full-revisionid\": keywords[\"full\"].strip(),\n                    \"dirty\": False, \"error\": None,\n                    \"date\": date}\n    # no suitable tags, so version is \"0+unknown\", but full hex is still there\n    if verbose:\n        print(\"no suitable tags, using unknown + full revision id\")\n    return {\"version\": \"0+unknown\",\n            \"full-revisionid\": keywords[\"full\"].strip(),\n            \"dirty\": False, \"error\": \"no suitable tags\", \"date\": None}\n\n\n@register_vcs_handler(\"git\", \"pieces_from_vcs\")\ndef git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):\n    \"\"\"Get version from 'git describe' in the root of the source tree.\n\n    This only gets called if the git-archive 'subst' keywords were *not*\n    expanded, and _version.py hasn't already been rewritten with a short\n    version string, meaning we're inside a checked out source tree.\n    \"\"\"\n    GITS = [\"git\"]\n    if sys.platform == \"win32\":\n        GITS = [\"git.cmd\", \"git.exe\"]\n\n    out, rc = run_command(GITS, [\"rev-parse\", \"--git-dir\"], cwd=root,\n                          hide_stderr=True)\n    if rc != 0:\n        if verbose:\n            print(\"Directory %s not under git control\" % root)\n        raise NotThisMethod(\"'git rev-parse --git-dir' returned error\")\n\n    # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]\n    # if there isn't one, this yields HEX[-dirty] (no NUM)\n    describe_out, rc = run_command(GITS, [\"describe\", \"--tags\", \"--dirty\",\n                                          \"--always\", \"--long\",\n                                          \"--match\", \"%s*\" % tag_prefix],\n                                   cwd=root)\n    # --long 
was added in git-1.5.5\n    if describe_out is None:\n        raise NotThisMethod(\"'git describe' failed\")\n    describe_out = describe_out.strip()\n    full_out, rc = run_command(GITS, [\"rev-parse\", \"HEAD\"], cwd=root)\n    if full_out is None:\n        raise NotThisMethod(\"'git rev-parse' failed\")\n    full_out = full_out.strip()\n\n    pieces = {}\n    pieces[\"long\"] = full_out\n    pieces[\"short\"] = full_out[:7]  # maybe improved later\n    pieces[\"error\"] = None\n\n    # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]\n    # TAG might have hyphens.\n    git_describe = describe_out\n\n    # look for -dirty suffix\n    dirty = git_describe.endswith(\"-dirty\")\n    pieces[\"dirty\"] = dirty\n    if dirty:\n        git_describe = git_describe[:git_describe.rindex(\"-dirty\")]\n\n    # now we have TAG-NUM-gHEX or HEX\n\n    if \"-\" in git_describe:\n        # TAG-NUM-gHEX\n        mo = re.search(r'^(.+)-(\\d+)-g([0-9a-f]+)$', git_describe)\n        if not mo:\n            # unparseable. 
Maybe git-describe is misbehaving?\n            pieces[\"error\"] = (\"unable to parse git-describe output: '%s'\"\n                               % describe_out)\n            return pieces\n\n        # tag\n        full_tag = mo.group(1)\n        if not full_tag.startswith(tag_prefix):\n            if verbose:\n                fmt = \"tag '%s' doesn't start with prefix '%s'\"\n                print(fmt % (full_tag, tag_prefix))\n            pieces[\"error\"] = (\"tag '%s' doesn't start with prefix '%s'\"\n                               % (full_tag, tag_prefix))\n            return pieces\n        pieces[\"closest-tag\"] = full_tag[len(tag_prefix):]\n\n        # distance: number of commits since tag\n        pieces[\"distance\"] = int(mo.group(2))\n\n        # commit: short hex revision ID\n        pieces[\"short\"] = mo.group(3)\n\n    else:\n        # HEX: no tags\n        pieces[\"closest-tag\"] = None\n        count_out, rc = run_command(GITS, [\"rev-list\", \"HEAD\", \"--count\"],\n                                    cwd=root)\n        pieces[\"distance\"] = int(count_out)  # total number of commits\n\n    # commit date: see ISO-8601 comment in git_versions_from_keywords()\n    date = run_command(GITS, [\"show\", \"-s\", \"--format=%ci\", \"HEAD\"],\n                       cwd=root)[0].strip()\n    pieces[\"date\"] = date.strip().replace(\" \", \"T\", 1).replace(\" \", \"\", 1)\n\n    return pieces\n\n\ndef plus_or_dot(pieces):\n    \"\"\"Return a + if we don't already have one, else return a .\"\"\"\n    if \"+\" in pieces.get(\"closest-tag\", \"\"):\n        return \".\"\n    return \"+\"\n\n\ndef render_pep440(pieces):\n    \"\"\"Build up version string, with post-release \"local version identifier\".\n\n    Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you\n    get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty\n\n    Exceptions:\n    1: no tags. git_describe was just HEX. 
0+untagged.DISTANCE.gHEX[.dirty]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += plus_or_dot(pieces)\n            rendered += \"%d.g%s\" % (pieces[\"distance\"], pieces[\"short\"])\n            if pieces[\"dirty\"]:\n                rendered += \".dirty\"\n    else:\n        # exception #1\n        rendered = \"0+untagged.%d.g%s\" % (pieces[\"distance\"],\n                                          pieces[\"short\"])\n        if pieces[\"dirty\"]:\n            rendered += \".dirty\"\n    return rendered\n\n\ndef render_pep440_pre(pieces):\n    \"\"\"TAG[.post.devDISTANCE] -- No -dirty.\n\n    Exceptions:\n    1: no tags. 0.post.devDISTANCE\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"]:\n            rendered += \".post.dev%d\" % pieces[\"distance\"]\n    else:\n        # exception #1\n        rendered = \"0.post.dev%d\" % pieces[\"distance\"]\n    return rendered\n\n\ndef render_pep440_post(pieces):\n    \"\"\"TAG[.postDISTANCE[.dev0]+gHEX] .\n\n    The \".dev0\" means dirty. Note that .dev0 sorts backwards\n    (a dirty tree will appear \"older\" than the corresponding clean one),\n    but you shouldn't be releasing software with -dirty anyways.\n\n    Exceptions:\n    1: no tags. 
0.postDISTANCE[.dev0]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += \".post%d\" % pieces[\"distance\"]\n            if pieces[\"dirty\"]:\n                rendered += \".dev0\"\n            rendered += plus_or_dot(pieces)\n            rendered += \"g%s\" % pieces[\"short\"]\n    else:\n        # exception #1\n        rendered = \"0.post%d\" % pieces[\"distance\"]\n        if pieces[\"dirty\"]:\n            rendered += \".dev0\"\n        rendered += \"+g%s\" % pieces[\"short\"]\n    return rendered\n\n\ndef render_pep440_old(pieces):\n    \"\"\"TAG[.postDISTANCE[.dev0]] .\n\n    The \".dev0\" means dirty.\n\n    Exceptions:\n    1: no tags. 0.postDISTANCE[.dev0]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += \".post%d\" % pieces[\"distance\"]\n            if pieces[\"dirty\"]:\n                rendered += \".dev0\"\n    else:\n        # exception #1\n        rendered = \"0.post%d\" % pieces[\"distance\"]\n        if pieces[\"dirty\"]:\n            rendered += \".dev0\"\n    return rendered\n\n\ndef render_git_describe(pieces):\n    \"\"\"TAG[-DISTANCE-gHEX][-dirty].\n\n    Like 'git describe --tags --dirty --always'.\n\n    Exceptions:\n    1: no tags. 
HEX[-dirty]  (note: no 'g' prefix)\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"]:\n            rendered += \"-%d-g%s\" % (pieces[\"distance\"], pieces[\"short\"])\n    else:\n        # exception #1\n        rendered = pieces[\"short\"]\n    if pieces[\"dirty\"]:\n        rendered += \"-dirty\"\n    return rendered\n\n\ndef render_git_describe_long(pieces):\n    \"\"\"TAG-DISTANCE-gHEX[-dirty].\n\n    Like 'git describe --tags --dirty --always --long'.\n    The distance/hash is unconditional.\n\n    Exceptions:\n    1: no tags. HEX[-dirty]  (note: no 'g' prefix)\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        rendered += \"-%d-g%s\" % (pieces[\"distance\"], pieces[\"short\"])\n    else:\n        # exception #1\n        rendered = pieces[\"short\"]\n    if pieces[\"dirty\"]:\n        rendered += \"-dirty\"\n    return rendered\n\n\ndef render(pieces, style):\n    \"\"\"Render the given version pieces into the requested style.\"\"\"\n    if pieces[\"error\"]:\n        return {\"version\": \"unknown\",\n                \"full-revisionid\": pieces.get(\"long\"),\n                \"dirty\": None,\n                \"error\": pieces[\"error\"],\n                \"date\": None}\n\n    if not style or style == \"default\":\n        style = \"pep440\"  # the default\n\n    if style == \"pep440\":\n        rendered = render_pep440(pieces)\n    elif style == \"pep440-pre\":\n        rendered = render_pep440_pre(pieces)\n    elif style == \"pep440-post\":\n        rendered = render_pep440_post(pieces)\n    elif style == \"pep440-old\":\n        rendered = render_pep440_old(pieces)\n    elif style == \"git-describe\":\n        rendered = render_git_describe(pieces)\n    elif style == \"git-describe-long\":\n        rendered = render_git_describe_long(pieces)\n    else:\n        raise ValueError(\"unknown style '%s'\" % style)\n\n    return {\"version\": 
rendered, \"full-revisionid\": pieces[\"long\"],\n            \"dirty\": pieces[\"dirty\"], \"error\": None,\n            \"date\": pieces.get(\"date\")}\n\n\ndef get_versions():\n    \"\"\"Get version information or return default if unable to do so.\"\"\"\n    # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have\n    # __file__, we can work backwards from there to the root. Some\n    # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which\n    # case we can only use expanded keywords.\n\n    cfg = get_config()\n    verbose = cfg.verbose\n\n    try:\n        return git_versions_from_keywords(get_keywords(), cfg.tag_prefix,\n                                          verbose)\n    except NotThisMethod:\n        pass\n\n    try:\n        root = os.path.realpath(__file__)\n        # versionfile_source is the relative path from the top of the source\n        # tree (where the .git directory might live) to this file. Invert\n        # this to find the root from __file__.\n        for i in cfg.versionfile_source.split('/'):\n            root = os.path.dirname(root)\n    except NameError:\n        return {\"version\": \"0+unknown\", \"full-revisionid\": None,\n                \"dirty\": None,\n                \"error\": \"unable to find root of source tree\",\n                \"date\": None}\n\n    try:\n        pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)\n        return render(pieces, cfg.style)\n    except NotThisMethod:\n        pass\n\n    try:\n        if cfg.parentdir_prefix:\n            return versions_from_parentdir(cfg.parentdir_prefix, root, verbose)\n    except NotThisMethod:\n        pass\n\n    return {\"version\": \"0+unknown\", \"full-revisionid\": None,\n            \"dirty\": None,\n            \"error\": \"unable to compute version\", \"date\": None}\n"
  },
  {
    "path": "delira/data_loading/__init__.py",
    "content": "# basic imports\nfrom delira.data_loading.data_loader import DataLoader\nfrom delira.data_loading.dataset import AbstractDataset, IterableDataset, \\\n    DictDataset, BaseCacheDataset, BaseExtendCacheDataset, BaseLazyDataset, \\\n    ConcatDataset\nfrom delira.data_loading.augmenter import Augmenter\nfrom delira.data_loading.data_manager import DataManager\nfrom delira.data_loading.load_utils import LoadSample, LoadSampleLabel\n\nfrom delira.data_loading.sampler import *\nfrom delira import get_backends as _get_backends\n\n# if numba is installed: Import Numba Transforms\ntry:\n    from delira.data_loading.numba_transform import NumbaTransform, \\\n        NumbaTransformWrapper, NumbaCompose\nexcept ImportError:\n    pass\n"
  },
  {
    "path": "delira/data_loading/augmenter.py",
    "content": "import multiprocessing\nfrom multiprocessing import connection as mpconnection\nfrom collections import Callable\nimport abc\nimport os\nimport sys\nimport numpy as np\nimport random\n\nfrom delira.data_loading.sampler import AbstractSampler, BatchSampler\nfrom delira.data_loading.data_loader import DataLoader\nfrom delira import get_current_debug_mode\n\n\nclass AbstractAugmenter(object):\n    \"\"\"\n    Basic Augmenter Class providing a general Augmenter API\n    \"\"\"\n\n    def __init__(\n            self,\n            data_loader,\n            batchsize,\n            sampler,\n            transforms=None,\n            seed=1,\n            drop_last=False):\n        \"\"\"\n        Parameters\n        ----------\n        data_loader : :class:`DataLoader`\n            the dataloader, loading samples for given indices\n        batchsize : int\n            the batchsize to use for sampling\n        sampler : :class:`AbstractSampler`\n            the sampler_old (may be batch sampler_old or usual sampler_old),\n            defining the actual sampling strategy; Is an iterable yielding\n            indices\n        transforms : :class:`collections.Callable`\n            the transforms to apply; defaults to None\n        seed : int\n            the basic seed; default: 1\n        drop_last : bool\n            whether to drop the last (possibly smaller) batch or not\n        \"\"\"\n\n        self._data_loader = data_loader\n\n        if not isinstance(sampler, BatchSampler):\n            if isinstance(sampler, AbstractSampler):\n                sampler = BatchSampler(sampler, batchsize,\n                                       drop_last=drop_last)\n            else:\n                raise TypeError(\"Invalid Sampler given: %s\" % str(sampler))\n\n        self._sampler = sampler\n\n        self._drop_last = drop_last\n\n        self._transforms = transforms\n        self._seed = seed\n\n        # seed numpy.random and random as these are the random 
number\n        # generators, which might be used for sampling\n        np.random.seed(seed)\n        random.seed(seed)\n\n    @abc.abstractmethod\n    def __iter__(self):\n        raise NotImplementedError\n\n\nclass _ParallelAugmenter(AbstractAugmenter):\n    \"\"\"\n    An Augmenter that loads and augments multiple batches in parallel\n    \"\"\"\n\n    def __init__(self, data_loader, batchsize, sampler, num_processes=None,\n                 transforms=None, seed=1, drop_last=False):\n        \"\"\"\n        Parameters\n        ----------\n        data_loader : :class:`DataLoader`\n            the dataloader, loading samples for given indices\n         batchsize : int\n            the batchsize to use for sampling\n        sampler : :class:`AbstractSampler`\n            the sampler_old (may be batch sampler_old or usual sampler_old),\n            defining the actual sampling strategy; Is an iterable yielding\n            indices\n        num_processes : int\n            the number of processes to use for dataloading + augmentation;\n            if None: the number of available CPUs will be used as number of\n            processes\n        transforms : :class:`collections.Callable`\n            the transforms to apply; defaults to None\n        seed : int\n            the basic seed; default: 1\n        drop_last : bool\n            whether to drop the last (possibly smaller) batch or not\n        \"\"\"\n\n        super().__init__(data_loader, batchsize, sampler, transforms, seed,\n                         drop_last)\n\n        if num_processes is None:\n            num_processes = os.cpu_count()\n\n        self._num_processes = num_processes\n\n        self._processes = []\n\n        self._index_pipes = []\n        self._data_pipes = []\n\n        self._index_pipe_counter = 0\n        self._data_pipe_counter = 0\n        self._abort_event = None\n        self._data_queued = []\n        self._processes_running = False\n\n    @property\n    def 
abort_event(self):\n        \"\"\"\n        Property to access the abortion Event\n\n        Returns\n        -------\n        :class:`multiprocessing.Event`\n            the abortion event\n        \"\"\"\n        return self._abort_event\n\n    @abort_event.setter\n    def abort_event(self, new_event):\n        \"\"\"\n        Setter for the abortion Event;\n\n        Parameters\n        ----------\n        new_event : class:`multiprocessing.Event`\n            the new event\n        \"\"\"\n\n        self._abort_event = new_event\n\n    def _start_processes(self):\n        \"\"\"\n        Starts new processes and pipes for interprocess communication\n        \"\"\"\n\n        # reset abortion event\n        self.abort_event = multiprocessing.Event()\n\n        # for each process do:\n        for i in range(self._num_processes):\n            # start two oneway pipes (one for passing index to workers\n            # and one for passing back data to main process)\n            recv_conn_out, send_conn_out = multiprocessing.Pipe(duplex=False)\n            recv_conn_in, send_conn_in = multiprocessing.Pipe(duplex=False)\n\n            # create the actual process\n            process = _WorkerProcess(dataloader=self._data_loader,\n                                     output_pipe=send_conn_out,\n                                     index_pipe=recv_conn_in,\n                                     transforms=self._transforms,\n                                     abort_event=self._abort_event,\n                                     process_id=i)\n            process.daemon = True\n            process.start()\n            # wait until process was created and started\n            while not process.is_alive():\n                pass\n\n            # append process and pipes to list\n            self._processes.append(process)\n            self._index_pipes.append(send_conn_in),\n            self._data_pipes.append(recv_conn_out)\n            self._data_queued.append(0)\n           
 self._processes_running = True\n\n    def _shutdown_processes(self):\n        \"\"\"\n        Shuts down the processes and resets all related flags and counters\n        \"\"\"\n\n        # create copy to avoid modifying the list we iterate over\n        worker = list(\n            zip(self._data_pipes, self._index_pipes, self._processes))\n\n        for _data_conn, _index_conn, _process in worker:\n\n            _index_conn.send(None)\n\n            _process.join()\n            if sys.version_info >= (3, 7):\n                _process.close()\n            else:\n                _process.terminate()\n\n            _index_conn.close()\n            _data_conn.close()\n\n            self._data_pipes.pop()\n            self._data_queued.pop()\n            self._index_pipes.pop()\n            self._processes.pop()\n\n        # reset running process flag and counters\n        self._processes_running = False\n        self._data_pipe_counter = 0\n        self._index_pipe_counter = 0\n\n    @property\n    def _next_index_pipe(self):\n        \"\"\"\n        Property implementing switch to next index pipe\n        \"\"\"\n        ctr = self._index_pipe_counter\n        new_ctr = (self._index_pipe_counter + 1) % self._num_processes\n        self._index_pipe_counter = new_ctr\n\n        return ctr\n\n    @property\n    def _next_data_pipe(self):\n        \"\"\"\n        Property implementing switch to next data pipe\n        \"\"\"\n        ctr = self._data_pipe_counter\n        new_ctr = (self._data_pipe_counter + 1) % self._num_processes\n        self._data_pipe_counter = new_ctr\n\n        return ctr\n\n    def _enqueue_indices(self, sample_idxs):\n        \"\"\"\n        Enqueues a set of indices to workers while iterating over workers in\n        cyclic way\n        Parameters\n        ----------\n        sample_idxs : list\n            the indices to enqueue to the workers\n        \"\"\"\n\n        # iterating over all batch indices\n        for idxs in sample_idxs:\n   
         # switch to next counter\n            index_pipe_ctr = self._next_index_pipe\n            # increase number of queued batches for current worker\n            self._data_queued[index_pipe_ctr] += 1\n            # enqueue indices to worker\n            self._index_pipes[index_pipe_ctr].send(idxs)\n\n    def _receive_data(self):\n        \"\"\"\n        Receives data from worker\n        \"\"\"\n        # switching to next worker\n        _data_pipe = self._next_data_pipe\n\n        # receive data from worker\n        data = self._data_pipes[_data_pipe].recv()\n        # decrease number of enqueued batches for current worker\n        self._data_queued[_data_pipe] -= 1\n\n        return data\n\n    def __iter__(self):\n        self._start_processes()\n\n        sampler_iter = iter(self._sampler)\n        all_sampled = False\n\n        try:\n            # start by enqueuing two items per process as buffer\n            _indices = []\n            try:\n                for i in range(self._num_processes * 2):\n                    idxs = next(sampler_iter)\n                    _indices.append(idxs)\n            except StopIteration:\n                all_sampled = True\n\n            self._enqueue_indices(_indices)\n\n            # iterate while not all data has been sampled and any data is\n            # enqueued\n            while True:\n\n                if self.abort_event.is_set():\n                    raise RuntimeError(\"Abort Event was set in one of the \"\n                                       \"workers\")\n\n                # enqueue additional indices if sampler_old was not already\n                # exhausted\n                try:\n                    if not all_sampled:\n                        idxs = next(sampler_iter)\n                        self._enqueue_indices([idxs])\n                except StopIteration:\n                    all_sampled = True\n\n                # receive data from workers\n                if any(self._data_queued):\n           
         yield self._receive_data()\n                else:\n                    break\n\n        except Exception as e:\n            # set abort event to shut down workers\n            self._abort_event.set()\n            raise e\n\n        finally:\n            if self._processes_running:\n                self._shutdown_processes()\n\n\nclass _WorkerProcess(multiprocessing.Process):\n    \"\"\"\n    A Process running an infinite loop of loading data for given indices\n    \"\"\"\n\n    def __init__(self, dataloader: DataLoader,\n                 output_pipe: mpconnection.Connection,\n                 index_pipe: mpconnection.Connection,\n                 abort_event: multiprocessing.Event,\n                 transforms: Callable,\n                 process_id):\n        \"\"\"\n        Parameters\n        ----------\n        dataloader : :class:`DataLoader`\n            the data loader which loads the data corresponding to the given\n            indices\n        output_pipe : :class:`multiprocessing.connection.Connection`\n            the pipe, the loaded data should be sent to\n        index_pipe : :class:`multiprocessing.connection.Connection`\n            the pipe to accept the indices\n        abort_event : :class:`multiprocessing.Event`\n            the abortion event; will be set on every Exception;\n            if set, the worker terminates\n        transforms : :class:`collections.abc.Callable`\n            the transforms to transform the data\n        process_id : int\n            the process id\n        \"\"\"\n        super().__init__()\n\n        self._data_loader = dataloader\n        self._output_pipe = output_pipe\n        self._input_pipe = index_pipe\n        self._abort_event = abort_event\n        self._process_id = process_id\n        self._transforms = transforms\n\n    def run(self) -> None:\n        # set the process id\n        self._data_loader.process_id = self._process_id\n\n        try:\n            while True:\n                # check if worker 
should terminate\n                if self._abort_event.is_set():\n                    raise RuntimeError(\"Abort Event has been set externally\")\n\n                # get indices if available (with timeout to frequently check\n                # for abortions)\n                if self._input_pipe.poll(timeout=0.2):\n                    idxs = self._input_pipe.recv()\n\n                    # final indices (sentinel) -> shut down worker\n                    if idxs is None:\n                        break\n\n                    # load data\n                    data = self._data_loader(idxs)\n\n                    # apply transforms if given\n                    if self._transforms is not None:\n                        data = self._transforms(**data)\n\n                    self._output_pipe.send(data)\n\n        except Exception as e:\n            self._abort_event.set()\n            raise e\n\n\nclass _SequentialAugmenter(AbstractAugmenter):\n    \"\"\"\n    An Augmenter that loads and augments batches sequentially without any\n    parallelism\n    \"\"\"\n\n    def __init__(\n            self,\n            data_loader,\n            batchsize,\n            sampler,\n            transforms=None,\n            seed=1,\n            drop_last=False):\n        \"\"\"\n        Parameters\n        ----------\n        data_loader : :class:`DataLoader`\n            the dataloader, loading samples for given indices\n        batchsize : int\n            the number of samples per batch\n        sampler : :class:`AbstractSampler`\n            the sampler (may be a batch sampler or a usual sampler),\n            defining the actual sampling strategy; is an iterable yielding\n            indices\n        transforms : :class:`collections.Callable`\n            the transforms to apply; defaults to None\n        seed : int\n            the basic seed; default: 1\n        drop_last : bool\n            whether to drop the last (possibly smaller) batch or not\n        \"\"\"\n        super().__init__(data_loader=data_loader, batchsize=batchsize,\n                         sampler=sampler, 
transforms=transforms, seed=seed,\n                         drop_last=drop_last)\n\n    def __iter__(self):\n        # create sampler iterator\n        sampler_iter = iter(self._sampler)\n\n        # for every index load and augment the data\n        for idxs in sampler_iter:\n\n            # load data\n            data = self._data_loader(idxs)\n\n            # transform data if transforms given\n            if self._transforms is not None:\n                data = self._transforms(**data)\n\n            yield data\n\n\nclass Augmenter(object):\n    \"\"\"\n    The actual Augmenter wrapping the :class:`_SequentialAugmenter` and the\n    :class:`_ParallelAugmenter`, switching between them depending on the\n    given arguments and the debug mode\n    \"\"\"\n\n    def __init__(self, data_loader, batchsize, sampler, num_processes=None,\n                 transforms=None, seed=1, drop_last=False):\n        \"\"\"\n        Parameters\n        ----------\n        data_loader : :class:`DataLoader`\n            the dataloader, loading samples for given indices\n        batchsize : int\n            the number of samples per batch\n        sampler : :class:`AbstractSampler`\n            the sampler (may be a batch sampler or a usual sampler),\n            defining the actual sampling strategy; is an iterable yielding\n            indices\n        num_processes : int\n            the number of processes to use for dataloading + augmentation;\n            if None: the number of available CPUs will be used as number of\n            processes\n        transforms : :class:`collections.Callable`\n            the transforms to apply; defaults to None\n        seed : int\n            the basic seed; default: 1\n        drop_last : bool\n            whether to drop the last (possibly smaller) batch or not\n        \"\"\"\n\n        self._augmenter = self._resolve_augmenter_cls(num_processes,\n                                                      data_loader=data_loader,\n                                                      batchsize=batchsize,\n                  
                                     sampler=sampler,\n                                                      transforms=transforms,\n                                                      seed=seed,\n                                                      drop_last=drop_last)\n\n    @staticmethod\n    def _resolve_augmenter_cls(num_processes, **kwargs):\n        \"\"\"\n        Resolves the augmenter class by the number of specified processes and\n        the debug mode and creates an instance of the chosen class\n\n        Parameters\n        ----------\n        num_processes : int\n            the number of processes to use for dataloading + augmentation;\n            if None: the number of available CPUs will be used as number of\n            processes\n        **kwargs :\n            additional keyword arguments, used for instantiation of the chosen\n            class\n\n        Returns\n        -------\n        :class:`AbstractAugmenter`\n            an instance of the chosen augmenter class\n        \"\"\"\n        if get_current_debug_mode() or num_processes == 0:\n            return _SequentialAugmenter(**kwargs)\n        return _ParallelAugmenter(num_processes=num_processes, **kwargs)\n\n    def __iter__(self):\n        \"\"\"\n        Makes the Augmenter iterable by delegating to the wrapped augmenter\n\n        Returns\n        -------\n        Generator\n            a generator yielding the augmented batches\n        \"\"\"\n        yield from self._augmenter\n"
  },
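The worker loop in `_WorkerProcess.run` above polls its input pipe with a short timeout so the abort event is checked frequently, and treats `None` as a shutdown sentinel. A minimal thread-based sketch of the same pattern, assuming a hypothetical `load_fn` in place of the real `DataLoader` and using `queue.Queue` instead of processes and pipes:

```python
# Sketch of the poll-with-timeout worker loop (threads + queues stand in
# for processes + pipes; ``load_fn`` is a made-up loader for illustration).
import queue
import threading


def worker_loop(in_q, out_q, abort, load_fn):
    while True:
        if abort.is_set():  # external abort -> terminate worker
            return
        try:
            # analogue of ``pipe.poll(timeout=0.2)`` followed by ``recv()``
            idxs = in_q.get(timeout=0.2)
        except queue.Empty:
            continue  # nothing enqueued yet, re-check the abort flag
        if idxs is None:  # shutdown sentinel
            return
        out_q.put(load_fn(idxs))  # "load" the data and send it back


def run_demo():
    in_q, out_q = queue.Queue(), queue.Queue()
    abort = threading.Event()
    t = threading.Thread(
        target=worker_loop,
        args=(in_q, out_q, abort, lambda idxs: [i * 2 for i in idxs]))
    t.start()
    in_q.put([0, 1, 2])          # enqueue indices
    result = out_q.get(timeout=5)  # receive "loaded" batch
    in_q.put(None)               # request shutdown via sentinel
    t.join(timeout=5)
    return result
```

The timeout on the blocking read is what keeps the worker responsive to the abort event: without it, a worker blocked on an empty channel could never notice that it should terminate.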
  {
    "path": "delira/data_loading/data_loader.py",
"content": "import numpy as np\nfrom delira.data_loading.dataset import AbstractDataset, DictDataset, \\\n    IterableDataset\nfrom collections import defaultdict\nfrom collections.abc import Iterable\n\n\nclass DataLoader:\n    \"\"\"\n    Basic DataLoader class that returns the data for a given set of indices\n    and combines it into batches\n    \"\"\"\n\n    def __init__(self, data):\n        \"\"\"\n        Parameters\n        ----------\n        data : Any\n            the data to use; ideally this is a dataset, an iterable or a\n            dict, but in general it must only be indexable, have a length\n            and return a dict of arrays when indexed\n        \"\"\"\n        self._process_id = None\n        if isinstance(data, AbstractDataset):\n            dataset = data\n\n        else:\n            # wrap it into a dataset depending on the datatype\n            if isinstance(data, dict):\n                dataset = DictDataset(data)\n            elif isinstance(data, Iterable):\n                dataset = IterableDataset(data)\n            else:\n                raise TypeError(\"Invalid dataset type: %s\"\n                                % type(data).__name__)\n\n        self.dataset = dataset\n\n    def __call__(self, indices):\n        \"\"\"\n        Loads the data for given indices and combines it into batches\n\n        Parameters\n        ----------\n        indices : list\n            a list of integers specifying the data indices\n\n        Returns\n        -------\n        dict\n            a dict of numpy arrays (specifying the batches)\n        \"\"\"\n\n        # get data for all indices\n        data = [self.dataset[idx] for idx in indices]\n\n        data_dict = defaultdict(list)\n\n        # collect dict entries by key\n        for _result_dict in data:\n            for key, val in _result_dict.items():\n                data_dict[key].append(val)\n\n        # convert lists to numpy arrays\n        for key, val_list in data_dict.items():\n            data_dict[key] = 
np.asarray(val_list)\n\n        return data_dict\n\n    @property\n    def process_id(self):\n        \"\"\"\n        Property to access the process id\n\n        Returns\n        -------\n        int\n            the process id\n        \"\"\"\n        if self._process_id is None:\n            return 0\n        return self._process_id\n\n    @process_id.setter\n    def process_id(self, new_id):\n        \"\"\"\n        Setter for :attr:`process_id`; makes sure that the process id is\n        only set once\n\n        Parameters\n        ----------\n        new_id : int\n            the new process id\n\n        Raises\n        ------\n        AttributeError\n            if the process id has already been set once\n        \"\"\"\n        if self._process_id is not None:\n            raise AttributeError(\"Attribute 'process_id' can be set only once\")\n\n        self._process_id = new_id\n"
  },
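`DataLoader.__call__` above collates a list of per-sample dicts into a single batch dict before converting each value list with `np.asarray`. The collation step in isolation (plain lists instead of numpy arrays; the sample dicts are made up for illustration):

```python
# Sketch of the key-wise collation used by ``DataLoader.__call__``:
# a list of per-sample dicts becomes one dict of per-key lists.
from collections import defaultdict


def collate(samples):
    # merge the samples' entries, grouping values under their shared keys
    batch = defaultdict(list)
    for sample in samples:
        for key, val in sample.items():
            batch[key].append(val)
    return dict(batch)


samples = [{"data": [1, 2], "label": 0},
           {"data": [3, 4], "label": 1}]
batch = collate(samples)
# batch == {"data": [[1, 2], [3, 4]], "label": [0, 1]}
```

In the real loader each resulting list is additionally passed through `np.asarray`, stacking per-sample arrays into one batch array per key.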
  {
    "path": "delira/data_loading/data_manager.py",
"content": "import logging\n\nfrom batchgenerators.transforms import AbstractTransform\n\nfrom delira import get_current_debug_mode\nfrom delira.data_loading.data_loader import DataLoader\nfrom delira.data_loading.sampler import SequentialSampler, AbstractSampler\nfrom delira.data_loading.augmenter import Augmenter\nfrom delira.data_loading.dataset import DictDataset, IterableDataset, \\\n    AbstractDataset\nfrom collections.abc import Iterable\nimport inspect\n\nlogger = logging.getLogger(__name__)\n\n\nclass DataManager(object):\n    \"\"\"\n    Class to Handle Data\n    Creates Dataset (if necessary), DataLoader and Augmenter\n\n    \"\"\"\n\n    def __init__(self, data, batch_size, n_process_augmentation,\n                 transforms, sampler_cls=SequentialSampler,\n                 drop_last=False, data_loader_cls=None,\n                 **sampler_kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        data : str or Dataset\n            if str: Path to data samples\n            if dataset: Dataset\n        batch_size : int\n            Number of samples per batch\n        n_process_augmentation : int\n            Number of processes for augmentations\n        transforms :\n            Data transformations for augmentation\n        sampler_cls : AbstractSampler\n            class defining the sampling strategy\n        drop_last : bool\n            whether to drop the last (possibly smaller) batch\n        data_loader_cls : subclass of DataLoader\n            DataLoader class\n        **sampler_kwargs :\n            other keyword arguments (passed to sampler_cls)\n\n        Raises\n        ------\n        TypeError\n            if ``data_loader_cls`` is not :obj:`None` and not a subclass of\n            ``DataLoader``, or if ``sampler_cls`` is not a subclass of\n            ``AbstractSampler``\n\n        See Also\n        --------\n        :class:`AbstractDataset`\n\n        \"\"\"\n\n        # Instantiate 
Hidden variables for property access\n        self._batch_size = None\n        self._n_process_augmentation = None\n        self._transforms = None\n        self._data_loader_cls = None\n        self._sampler = None\n        self.drop_last = drop_last\n\n        # set actual values to properties\n        self.batch_size = batch_size\n\n        self.n_process_augmentation = n_process_augmentation\n        self.transforms = transforms\n\n        if data_loader_cls is None:\n            logger.info(\"No DataLoader class specified. Using DataLoader\")\n            data_loader_cls = DataLoader\n        else:\n            if not inspect.isclass(data_loader_cls):\n                raise TypeError(\n                    \"data_loader_cls must be a class, not an instance\")\n\n            if not issubclass(data_loader_cls, DataLoader):\n                raise TypeError(\n                    \"data_loader_cls must be a subclass of DataLoader\")\n\n        self.data_loader_cls = data_loader_cls\n\n        self.data = data\n\n        if not (inspect.isclass(sampler_cls) and issubclass(sampler_cls,\n                                                            AbstractSampler)):\n            raise TypeError(\n                \"sampler_cls must be a subclass of AbstractSampler\")\n\n        self.sampler_cls = sampler_cls\n        self.sampler_kwargs = sampler_kwargs\n\n    def get_batchgen(self, seed=1):\n        \"\"\"\n        Create DataLoader and Batchgenerator\n\n        Parameters\n        ----------\n        seed : int\n            seed for Random Number Generator\n\n        Returns\n        -------\n        Augmenter\n           The actual iterable batchgenerator\n\n        Raises\n        ------\n        AssertionError\n            :attr:`DataManager.n_batches` is smaller than or equal to zero\n\n        \"\"\"\n        assert self.n_batches > 0\n\n        data_loader = self.data_loader_cls(\n            self.data\n        )\n\n        sampler = 
self.sampler_cls.from_dataset(data_loader.dataset,\n                                                **self.sampler_kwargs)\n\n        return Augmenter(data_loader=data_loader,\n                         batchsize=self.batch_size,\n                         sampler=sampler,\n                         num_processes=self.n_process_augmentation,\n                         transforms=self.transforms,\n                         seed=seed,\n                         drop_last=self.drop_last\n                         )\n\n    def get_subset(self, indices):\n        \"\"\"\n        Returns a Subset of the current datamanager based on given indices\n\n        Parameters\n        ----------\n        indices : iterable\n            valid indices to extract subset from current dataset\n\n        Returns\n        -------\n        :class:`DataManager`\n            manager containing the subset\n\n        \"\"\"\n\n        subset_kwargs = {\n            \"batch_size\": self.batch_size,\n            \"n_process_augmentation\": self.n_process_augmentation,\n            \"transforms\": self.transforms,\n            \"sampler_cls\": self.sampler_cls,\n            \"data_loader_cls\": self.data_loader_cls,\n            \"drop_last\": self.drop_last,\n            **self.sampler_kwargs\n        }\n\n        return self.__class__(\n            self.data.get_subset(indices),\n            **subset_kwargs)\n\n    def update_state_from_dict(self, new_state: dict):\n        \"\"\"\n        Updates internal state and therefore the behavior from dict.\n        If a key is not specified, the old attribute value will be used\n\n        Parameters\n        ----------\n        new_state : dict\n            The dict to update the state from.\n            Valid keys are:\n\n                * ``batch_size``\n                * ``n_process_augmentation``\n                * ``data_loader_cls``\n                * ``sampler_cls``\n                * ``sampler_kwargs``\n                * ``transforms``\n\n          
  If a key is not specified, the old value of the corresponding\n            attribute will be used\n\n        Raises\n        ------\n        KeyError\n            if invalid keys are specified\n\n        \"\"\"\n\n        # update batch_size if specified\n        self.batch_size = new_state.pop(\"batch_size\", self.batch_size)\n        # update n_process_augmentation if specified\n        self.n_process_augmentation = new_state.pop(\n            \"n_process_augmentation\", self.n_process_augmentation)\n        # update data_loader_cls if specified\n        self.data_loader_cls = new_state.pop(\"data_loader_cls\",\n                                             self.data_loader_cls)\n        # update sampler\n        self.sampler_cls = new_state.pop(\"sampler_cls\", self.sampler_cls)\n        self.sampler_kwargs = new_state.pop(\"sampler_kwargs\",\n                                            self.sampler_kwargs)\n\n        self.transforms = new_state.pop(\"transforms\", self.transforms)\n\n        if new_state:\n            raise KeyError(\"Invalid Keys in new_state given: %s\"\n                           % (','.join(map(str, new_state.keys()))))\n\n    @property\n    def batch_size(self):\n        \"\"\"\n        Property to access the batchsize\n\n        Returns\n        -------\n        int\n            the batchsize\n        \"\"\"\n\n        return self._batch_size\n\n    @batch_size.setter\n    def batch_size(self, new_batch_size):\n        \"\"\"\n        Setter for current batchsize, casts to int before setting the attribute\n\n        Parameters\n        ----------\n        new_batch_size : int, Any\n            the new batchsize; should be int but can be of any type that can\n            be cast to an int\n\n        \"\"\"\n\n        self._batch_size = int(new_batch_size)\n\n    @property\n    def n_process_augmentation(self):\n        \"\"\"\n        Property to access the number of augmentation processes\n\n        Returns\n        -------\n        int\n 
           number of augmentation processes\n        \"\"\"\n\n        if get_current_debug_mode():\n            return 0\n        return self._n_process_augmentation\n\n    @n_process_augmentation.setter\n    def n_process_augmentation(self, new_process_number):\n        \"\"\"\n        Setter for number of augmentation processes, casts to int before\n        setting the attribute\n\n        Parameters\n        ----------\n        new_process_number : int, Any\n            new number of augmentation processes; should be int but can be of\n            any type that can be cast to an int\n\n        \"\"\"\n\n        self._n_process_augmentation = int(new_process_number)\n\n    @property\n    def transforms(self):\n        \"\"\"\n        Property to access the current data transforms\n\n        Returns\n        -------\n        None, ``AbstractTransform``\n            The transformation, can either be None or an instance of\n            ``AbstractTransform``\n        \"\"\"\n\n        return self._transforms\n\n    @transforms.setter\n    def transforms(self, new_transforms):\n        \"\"\"\n        Setter for data transforms, checks that the transforms are of valid\n        type (either None or an instance of ``AbstractTransform``)\n\n        Parameters\n        ----------\n        new_transforms : None, ``AbstractTransform``\n            the new transforms\n\n        \"\"\"\n\n        if new_transforms is not None and not isinstance(\n                new_transforms, AbstractTransform):\n            raise TypeError(\n                \"transforms must be None or an instance of \"\n                \"AbstractTransform\")\n\n        self._transforms = new_transforms\n\n    @property\n    def data_loader_cls(self):\n        \"\"\"\n        Property to access the current data loader class\n\n        Returns\n        -------\n        type\n            Subclass of ``DataLoader``\n        \"\"\"\n\n        return self._data_loader_cls\n\n    @data_loader_cls.setter\n    def data_loader_cls(self, new_loader_cls):\n        \"\"\"\n        Setter for current data loader class, 
checks that the new class is of\n        valid type\n        (must be a class and a subclass of ``DataLoader``)\n\n        Parameters\n        ----------\n        new_loader_cls : type\n            the new data loader class\n\n        \"\"\"\n\n        if not (inspect.isclass(new_loader_cls) and issubclass(\n                new_loader_cls, DataLoader)):\n            raise TypeError(\n                \"data_loader_cls must be a subclass of DataLoader\")\n\n        self._data_loader_cls = new_loader_cls\n\n    @property\n    def n_samples(self):\n        \"\"\"\n        Number of Samples\n\n        Returns\n        -------\n        int\n            Number of Samples\n\n        \"\"\"\n        return len(self.dataset)\n\n    @property\n    def n_batches(self):\n        \"\"\"\n        Returns Number of Batches based on batchsize and number of samples\n\n        Returns\n        -------\n        int\n            Number of Batches\n\n        Raises\n        ------\n        AssertionError\n            :attr:`DataManager.n_samples` is smaller than or equal to zero\n\n        \"\"\"\n        assert self.n_samples > 0\n\n        n_batches = self.n_samples // self.batch_size\n\n        truncated_batch = self.n_samples % self.batch_size\n\n        n_batches += int(bool(truncated_batch) and not self.drop_last)\n\n        return n_batches\n\n    @property\n    def dataset(self):\n        return self.data\n\n    @dataset.setter\n    def dataset(self, new_dset):\n        if not isinstance(new_dset, AbstractDataset):\n            raise TypeError(\n                \"dataset must be an instance of AbstractDataset\")\n\n        self.data = new_dset\n\n    def __iter__(self):\n        \"\"\"\n        Built-in function to create an iterator. First creates an\n        :class:`Augmenter` and then returns an iterator over the created\n        augmenter\n\n        Returns\n        -------\n        Generator object\n            generator object to iterate over the augmented batches\n\n        \"\"\"\n        return iter(self.get_batchgen())\n"
  },
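`DataManager.n_batches` above derives the batch count from the sample count, the batch size and `drop_last`: full batches plus one truncated batch unless `drop_last` suppresses it. The same arithmetic as a standalone sketch:

```python
# Sketch of the batch-count arithmetic from ``DataManager.n_batches``.
def n_batches(n_samples, batch_size, drop_last=False):
    full = n_samples // batch_size        # number of full batches
    truncated = n_samples % batch_size    # leftover samples, if any
    # count the truncated batch only when it exists and is not dropped
    return full + int(bool(truncated) and not drop_last)
```

For example, 10 samples with batch size 3 yield 4 batches (3 full plus one of size 1), or 3 batches when `drop_last` is set.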
  {
    "path": "delira/data_loading/dataset.py",
"content": "import abc\nimport os\nimport typing\n\nimport numpy as np\nfrom skimage.transform import resize\nfrom sklearn.model_selection import train_test_split\nfrom collections.abc import Iterable\nfrom tqdm import tqdm\n\nfrom delira.utils import subdirs\n\n\nclass AbstractDataset:\n    \"\"\"\n    Base Class for Dataset\n\n    \"\"\"\n\n    def __init__(self, data_path: str, load_fn: typing.Callable):\n        \"\"\"\n\n        Parameters\n        ----------\n        data_path : str\n            path to data samples\n        load_fn : function\n            function to load a single sample\n        \"\"\"\n        self.data_path = data_path\n        self._load_fn = load_fn\n        self.data = []\n\n    @abc.abstractmethod\n    def _make_dataset(self, path: str):\n        \"\"\"\n        Create dataset\n\n        Parameters\n        ----------\n        path : str\n            path to data samples\n\n        Returns\n        -------\n        list\n            data: List of sample paths if lazy; List of samples if not\n\n        \"\"\"\n        pass\n\n    @abc.abstractmethod\n    def __getitem__(self, index):\n        \"\"\"\n        return data with given index (and loads it before if lazy)\n\n        Parameters\n        ----------\n        index : int\n            index of data\n\n        Returns\n        -------\n        dict\n            data\n\n        \"\"\"\n        pass\n\n    def __len__(self):\n        \"\"\"\n        Return number of samples\n\n        Returns\n        -------\n        int\n            number of samples\n        \"\"\"\n        return len(self.data)\n\n    def __iter__(self):\n        \"\"\"\n        Return an iterator for the dataset\n\n        Returns\n        -------\n        _DatasetIter\n            an iterator over the dataset's samples\n        \"\"\"\n        return _DatasetIter(self)\n\n    def get_sample_from_index(self, index):\n        \"\"\"\n        Returns the data sample for a given index\n        (without any loading if it would be necessary)\n    
     This implements the base case and can be subclassed\n        for index mappings.\n        The actual loading behaviour (lazy or cached) should be\n        implemented in ``__getitem__``\n\n        See Also\n        --------\n        :meth:`ConcatDataset.get_sample_from_index`\n        :meth:`BaseLazyDataset.__getitem__`\n        :meth:`BaseCacheDataset.__getitem__`\n\n        Parameters\n        ----------\n        index : int\n            index corresponding to targeted sample\n\n        Returns\n        -------\n        Any\n            sample corresponding to given index\n        \"\"\"\n\n        return self.data[index]\n\n    def get_subset(self, indices):\n        \"\"\"\n        Returns a Subset of the current dataset based on given indices\n\n        Parameters\n        ----------\n        indices : iterable\n            valid indices to extract subset from current dataset\n\n        Returns\n        -------\n        :class:`BlankDataset`\n            the subset\n\n        \"\"\"\n\n        # extract other important attributes from current dataset\n        kwargs = {}\n\n        for key, val in vars(self).items():\n            if not (key.startswith(\"__\") and key.endswith(\"__\")):\n\n                if key == \"data\":\n                    continue\n                kwargs[key] = val\n\n        kwargs[\"old_getitem\"] = self.__class__.__getitem__\n        subset_data = [self.get_sample_from_index(idx) for idx in indices]\n\n        return BlankDataset(subset_data, **kwargs)\n\n\nclass _DatasetIter(object):\n    \"\"\"\n    Iterator for dataset\n    \"\"\"\n\n    def __init__(self, dset):\n        \"\"\"\n\n        Parameters\n        ----------\n        dset : :class:`AbstractDataset`\n            the dataset which should be iterated\n        \"\"\"\n        self._dset = dset\n        self._curr_index = 0\n\n    def __iter__(self):\n        return self\n\n    def __next__(self):\n        if self._curr_index >= len(self._dset):\n            raise 
StopIteration\n\n        sample = self._dset[self._curr_index]\n        self._curr_index += 1\n        return sample\n\n\nclass DictDataset(AbstractDataset):\n    \"\"\"\n    Dataset to wrap a dict of keys and iterables.\n    \"\"\"\n\n    def __init__(self, data: dict):\n        \"\"\"\n\n        Parameters\n        ----------\n        data : dict\n            dictionary consisting of keys and iterables.\n            The iterables should contain an item for each index\n        \"\"\"\n        super().__init__(None, None)\n        self._data = data\n\n    def __getitem__(self, index: int):\n        \"\"\"\n        Function to make the dataset indexable. Returns the sample\n        corresponding to the given index\n\n        Parameters\n        ----------\n        index : int\n            the index specifying the sample to return\n\n        Returns\n        -------\n        dict\n            the sample corresponding to :param:`index`\n\n        \"\"\"\n        return {k: v[index] for k, v in self._data.items()}\n\n    def get_sample_from_index(self, index):\n        \"\"\"\n        Mapping from index to sample\n\n        Parameters\n        ----------\n        index : int\n            the index specifying the sample to return\n\n        Returns\n        -------\n        dict\n            the sample corresponding to :param:`index`\n\n        \"\"\"\n        return self[index]\n\n    def _make_dataset(self, path: str):\n        \"\"\"\n        Function to create the dataset\n        (not necessary here, since the data is already in memory)\n\n        Parameters\n        ----------\n        path : str\n            the path to load the data from\n\n        \"\"\"\n        pass\n\n    def __len__(self):\n        \"\"\"\n        Function to determine the dataset's length\n\n        Returns\n        -------\n        int\n            the number of samples\n        \"\"\"\n        return min([len(v) for v in self._data.values()])\n\n\nclass 
IterableDataset(AbstractDataset):\n    \"\"\"\n    Dataset to wrap a list of dicts.\n    \"\"\"\n\n    def __init__(self, data: Iterable):\n        \"\"\"\n\n        Parameters\n        ----------\n        data : Iterable\n            an iterable of dicts each representing a single sample\n        \"\"\"\n        super().__init__(None, None)\n        self._data = data\n\n    def __getitem__(self, index):\n        \"\"\"\n        Function to make the dataset indexable. Returns the sample\n        corresponding to the given index\n\n        Parameters\n        ----------\n        index : int\n            the index specifying the sample to return\n\n        Returns\n        -------\n        dict\n            the sample corresponding to :param:`index`\n\n        \"\"\"\n        return self._data[index]\n\n    def get_sample_from_index(self, index):\n        \"\"\"\n        Mapping from index to sample\n\n        Parameters\n        ----------\n        index : int\n            the index specifying the sample to return\n\n        Returns\n        -------\n        dict\n            the sample corresponding to :param:`index`\n\n        \"\"\"\n        return self[index]\n\n    def _make_dataset(self, path: str):\n        \"\"\"\n        Function to create the dataset\n        (not necessary here, since the data is already in memory)\n\n        Parameters\n        ----------\n        path : str\n            the path to load the data from\n\n        \"\"\"\n        pass\n\n    def __len__(self):\n        \"\"\"\n        Function to determine the dataset's length\n\n        Returns\n        -------\n        int\n            the number of samples\n        \"\"\"\n        return len(self._data)\n\n\nclass BlankDataset(AbstractDataset):\n    \"\"\"\n    Blank Dataset returning the data passed in its ``__init__``,\n    loading each sample via the previous dataset's ``__getitem__``\n\n    \"\"\"\n\n    def __init__(self, data, old_getitem, **kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n 
       data : iterable\n            data to load\n        old_getitem : function\n            get item method of previous dataset\n        **kwargs :\n            additional keyword arguments (are set as instance attributes)\n\n        \"\"\"\n        super().__init__(None, None)\n\n        self.data = data\n        self._old_getitem = old_getitem\n\n        for key, val in kwargs.items():\n            setattr(self, key, val)\n\n    def __getitem__(self, index):\n        \"\"\"\n        returns the single sample corresponding to ``index`` via the previous\n        dataset's ``__getitem__`` (``_old_getitem``)\n\n        Parameters\n        ----------\n        index : int\n            index specifying the data to load\n\n        Returns\n        -------\n        dict\n            dictionary containing a single sample\n\n        \"\"\"\n        return self._old_getitem(self, index)\n\n    def __len__(self):\n        \"\"\"\n        returns the length of the dataset\n\n        Returns\n        -------\n        int\n            number of samples\n\n        \"\"\"\n        return len(self.data)\n\n\nclass BaseCacheDataset(AbstractDataset):\n    \"\"\"\n    Dataset to preload and cache data\n\n    Notes\n    -----\n    data needs to fit completely into RAM!\n\n    \"\"\"\n\n    def __init__(self, data_path: typing.Union[str, list],\n                 load_fn: typing.Callable, **load_kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        data_path : str or list\n            if data_path is a string, ``load_fn`` is called for all items\n            inside the specified directory\n            if data_path is a list, ``load_fn`` is called for all elements\n            in the list\n        load_fn : function\n            function to load a single data sample\n        **load_kwargs :\n            additional loading keyword arguments (image shape,\n            channel number, ...); passed to ``load_fn``\n\n        \"\"\"\n        super().__init__(data_path, load_fn)\n        self._load_kwargs = load_kwargs\n        
self.data = self._make_dataset(data_path)\n\n    def _make_dataset(self, path: typing.Union[str, list]):\n        \"\"\"\n        Helper Function to make a dataset containing all samples in a certain\n        directory\n\n        Parameters\n        ----------\n        path : str or list\n            if path is a string, ``load_fn`` is called for all items inside\n            the specified directory\n            if path is a list, ``load_fn`` is called for all elements in the\n            list\n\n        Returns\n        -------\n        list\n            list of items which were returned from ``load_fn`` (typically dict)\n\n        Raises\n        ------\n        AssertionError\n            if `path` is not a list and is not a valid directory\n\n        \"\"\"\n        data = []\n        if isinstance(path, list):\n            # iterate over all elements\n            for p in tqdm(path, unit='samples', desc=\"Loading samples\"):\n                data.append(self._load_fn(p, **self._load_kwargs))\n        else:\n            # call load_fn for all elements inside directory\n            assert os.path.isdir(path), '%s is not a valid directory' % path\n            for p in tqdm(os.listdir(path), unit='samples',\n                          desc=\"Loading samples\"):\n                data.append(self._load_fn(os.path.join(path, p),\n                                          **self._load_kwargs))\n        return data\n\n    def __getitem__(self, index):\n        \"\"\"\n        return data sample specified by index\n\n        Parameters\n        ----------\n        index : int\n            index to specify which data sample to return\n\n        Returns\n        -------\n        dict\n            data sample\n\n        \"\"\"\n        data_dict = self.get_sample_from_index(index)\n        return data_dict\n\n\nclass BaseLazyDataset(AbstractDataset):\n    \"\"\"\n    Dataset to load data in a lazy way\n\n    \"\"\"\n\n    def __init__(self, data_path: 
typing.Union[str, list],\n                 load_fn: typing.Callable, **load_kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        data_path : str or list\n            if data_path is a string, _sample_fn is called for all items inside\n            the specified directory\n            if data_path is a list, _sample_fn is called for elements in the\n            list\n        load_fn : function\n            function to load single data sample\n        **load_kwargs :\n            additional loading keyword arguments (image shape,\n            channel number, ...); passed to _sample_fn\n\n        \"\"\"\n        super().__init__(data_path, load_fn)\n        self._load_kwargs = load_kwargs\n        self.data = self._make_dataset(self.data_path)\n\n    def _make_dataset(self, path: typing.Union[str, list]):\n        \"\"\"\n        Helper Function to make a dataset containing paths to all images in a\n        certain directory\n\n        Parameters\n        ----------\n        path : str or list\n            path to data samples\n\n        Returns\n        -------\n        list\n            list of sample paths\n\n        Raises\n        ------\n        AssertionError\n            if `path` is not a valid directory\n\n        \"\"\"\n        if isinstance(path, list):\n            # generate list from iterable\n            data = list(path)\n        else:\n            # generate list from all items\n            assert os.path.isdir(path), '%s is not a valid directory' % path\n            data = [os.path.join(path, p) for p in os.listdir(path)]\n        return data\n\n    def __getitem__(self, index):\n        \"\"\"\n        load data sample specified by index\n\n        Parameters\n        ----------\n        index : int\n            index to specifiy which data sample to load\n\n        Returns\n        -------\n        dict\n            loaded data sample\n        \"\"\"\n        data_dict = self._load_fn(self.get_sample_from_index(index),\n     
                             **self._load_kwargs)\n        return data_dict\n\n\nclass BaseExtendCacheDataset(BaseCacheDataset):\n    \"\"\"\n    Dataset to preload and cache data. Function to load sample is expected\n    to return an iterable which can contain multiple samples\n\n    Notes\n    -----\n    data needs to fit completely into RAM!\n\n    \"\"\"\n\n    def __init__(self, data_path: typing.Union[str, list],\n                 load_fn: typing.Callable, **load_kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        data_path : str or list\n            if data_path is a string, _sample_fn is called for all items inside\n            the specified directory\n            if data_path is a list, _sample_fn is called for elements in the\n            list\n        load_fn : function\n            function to load a multiple data samples at once. Needs to return\n            an iterable which extends the internal list.\n        **load_kwargs :\n            additional loading keyword arguments (image shape,\n            channel number, ...); passed to _sample_fn\n\n        See Also\n        --------\n        :class: `BaseCacheDataset`\n\n        \"\"\"\n        super().__init__(data_path, load_fn, **load_kwargs)\n\n    def _make_dataset(self, path: typing.Union[str, list]):\n        \"\"\"\n        Helper Function to make a dataset containing all samples in a certain\n        directory\n\n        Parameters\n        ----------\n        path: str or iterable\n            if data_path is a string, _sample_fn is called for all items inside\n            the specified directory\n            if data_path is a list, _sample_fn is called for elements in the\n            list\n\n        Returns\n        -------\n        list\n            list of items which where returned from _sample_fn (typically dict)\n\n        Raises\n        ------\n        AssertionError\n            if `path` is not a list and is not a valid directory\n\n        \"\"\"\n        
data = []\n        if isinstance(path, list):\n            # iterate over all elements\n            for p in tqdm(path, unit='samples', desc=\"Loading samples\"):\n                data.extend(self._load_fn(p, **self._load_kwargs))\n        else:\n            # call _sample_fn for all elements inside directory\n            assert os.path.isdir(path), '%s is not a valid directory' % dir\n            for p in tqdm(os.listdir(path), unit='samples',\n                          desc=\"Loading samples\"):\n                data.extend(self._load_fn(os.path.join(path, p),\n                                          **self._load_kwargs))\n        return data\n\n\nclass ConcatDataset(AbstractDataset):\n    def __init__(self, *datasets):\n        \"\"\"\n        Concatenate multiple datasets to one\n\n        Parameters\n        ----------\n        datasets:\n            variable number of datasets\n        \"\"\"\n        super().__init__(None, None)\n\n        # TODO: Why should datasets[0] be a list not a AbstractDataset?\n\n        # check if first item in datasets is list and datasets is of length 1\n        if (len(datasets) == 1) and isinstance(datasets[0], list):\n            datasets = datasets[0]\n\n        self.data = datasets\n\n    def get_sample_from_index(self, index):\n        \"\"\"\n        Returns the data sample for a given index\n        (without any loading if it would be necessary)\n        This method implements the index mapping of a global index to\n        the subindices for each dataset.\n        The actual loading behaviour (lazy or cached) should be\n        implemented in ``__getitem__``\n\n        See Also\n        --------\n        :method:AbstractDataset.get_sample_from_index\n        :method:BaseLazyDataset.__getitem__\n        :method:BaseCacheDataset.__getitem__\n\n        Parameters\n        ----------\n        index : int\n            index corresponding to targeted sample\n\n        Returns\n        -------\n        Any\n            sample 
corresponding to given index\n        \"\"\"\n\n        curr_max_index = 0\n        for dset in self.data:\n            prev_max_index = curr_max_index\n            curr_max_index += len(dset)\n\n            if prev_max_index <= index < curr_max_index:\n                return dset[index - prev_max_index]\n\n            else:\n                continue\n\n        raise IndexError(\"Index %d is out of range for %d items in datasets\" %\n                         (index, len(self)))\n\n    def __getitem__(self, index):\n        return self.get_sample_from_index(index)\n\n    def __len__(self):\n        return sum([len(dset) for dset in self.data])\n"
  },
  {
    "path": "delira/data_loading/load_utils.py",
    "content": "import collections\nimport os\n\nimport numpy as np\nfrom skimage.io import imread\nfrom skimage.transform import resize\n\n\ndef norm_range(mode):\n    \"\"\"\n    Closure function for range normalization\n    Parameters\n    ----------\n    mode : str\n        '-1,1' normalizes data to range [-1, 1], while '0,1'\n        normalizes data to range [0, 1]\n    Returns\n    -------\n    callable\n        normalization function\n    \"\"\"\n    def norm_fn(data):\n        \"\"\"\n        Returns the input data normalized to the range\n        Parameters\n        ----------\n        data : np.ndarray\n            data which should be normalized\n        Returns\n        -------\n        np.ndarary\n            normalized data\n        \"\"\"\n        norm = data - data.min()\n        norm = norm / norm.max()\n        if mode == '-1,1':\n            norm = norm - 0.5\n            norm = norm * 2\n        elif mode == '0,1':\n            pass\n        else:\n            raise ValueError('{mode} not supported.')\n        return norm\n    return norm_fn\n\n\ndef norm_zero_mean_unit_std(data):\n    \"\"\"\n    Return normalized data with mean 0, standard deviation 1\n    Parameters\n    ----------\n    data : np.nadarray\n    Returns\n    -------\n    np.ndarray\n        normalized data\n    \"\"\"\n    return (data - np.mean(data)) / np.std(data)\n\n\nclass LoadSample:\n    \"\"\"\n    Provides a callable to load a single sample from multiple files in a folder\n    \"\"\"\n\n    def __init__(self,\n                 sample_ext: dict,\n                 sample_fn: collections.abc.Callable,\n                 dtype: dict = None, normalize: tuple = (),\n                 norm_fn=norm_range('-1,1'),\n                 **kwargs):\n        \"\"\"\n        Parameters\n        ----------\n        sample_ext : dict of iterable\n            Defines the data _sample_ext. 
The dict key defines the position of\n            the sample inside the returned data dict, while the list defines\n            the the files which should be loaded inside the data dict.\n        sample_fn : function\n            function to load a single sample\n        dtype : dict\n            defines the data type which should be used for the respective key\n        normalize : iterable of hashable\n            list of hashable which should be normalized. Can contain\n            entire keys of extension (normalizes each element individually)\n            or provide the file name which should be normalized\n        norm_fn : function\n            function to normalize input. Default: normalize range to [-1, 1]\n        kwargs :\n            variable number of keyword arguments passed to load function\n        Examples\n        --------\n        Simple loading function which returns a dict with `data`\n        >>> from delira.data_loading.nii import load_nii\n        >>> load_fn = LoadSample({'data:': ['data.nii']}, load_nii)\n        Loading function for data (casted to float32 and normalized) and\n        segmentation (casted to unit8)\n        >>> from delira.data_loading.nii import load_nii\n        >>> load_fn = LoadSample({'data:': ['data.nii'], 'seg': ['seg.nii']},\n        >>>                      load_nii, dtype={'data': 'float32',\n        >>>                                       'seg': 'uint8'},\n        >>>                      normalize=('data',))\n        \"\"\"\n        if dtype is None:\n            dtype = {}\n        self._sample_ext = sample_ext\n        self._sample_fn = sample_fn\n        self._dtype = dtype\n        self._normalize = normalize\n        self._norm_fn = norm_fn\n        self._kwargs = kwargs\n\n    def __call__(self, path) -> dict:\n        \"\"\"\n        Load sample from multiple files\n        Parameters\n        ----------\n        path : str\n            defines patch to folder which contain the _sample_ext\n        
Returns\n        -------\n        dict\n            dict with data defines by _sample_ext\n        \"\"\"\n        sample_dict = {}\n        for key, item in self._sample_ext.items():\n            data_list = []\n            for f in item:\n                data = self._sample_fn(os.path.join(path, f), **self._kwargs)\n\n                # _normalize data if necessary\n                if (key in self._normalize) or (f in self._normalize):\n                    data = self._norm_fn(data)\n\n                # cast data to type\n                if key in self._dtype:\n                    data = data.astype(self._dtype[key])\n\n                # append data\n                data_list.append(data)\n            if len(data_list) == 1:\n                sample_dict[key] = data_list[0][np.newaxis]\n            else:\n                sample_dict[key] = np.stack(data_list)\n        return sample_dict\n\n\nclass LoadSampleLabel(LoadSample):\n    def __init__(self,\n                 sample_ext: dict,\n                 sample_fn: collections.abc.Callable,\n                 label_ext: str,\n                 label_fn: collections.abc.Callable,\n                 dtype: dict = None, normalize: tuple = (),\n                 norm_fn=norm_range('-1,1'),\n                 sample_kwargs=None, **kwargs):\n        \"\"\"\n        Load sample and label from folder\n        Parameters\n        ----------\n        sample_ext : dict of list\n            Defines the data _sample_ext. 
The dict key defines the position of\n            the sample inside the returned data dict, while the list defines\n            the the files which should be loaded inside the data dict.\n            Passed to LoadSample.\n        sample_fn : function\n            function to load a single sample\n            Passed to LoadSample.\n        label_ext : str\n            extension for label\n        label_fn: function\n            functions which returns the label inside a dict\n        dtype : dict\n            defines the data type which should be used for the respective key\n        normalize : iterable of hashable\n            list of hashable which should be normalized. Can contain\n            entire keys of extension (normalizes each element individually)\n            or provide the file name which should be normalized\n        norm_fn : function\n            function to normalize input. Default: normalize range to [-1, 1]\n        sample_kwargs :\n            additional keyword arguments passed to LoadSample\n        kwargs :\n            variable number of keyword arguments passed to _label_fn\n        See Also\n        --------\n        :class: `LoadSample`\n        \"\"\"\n        if sample_kwargs is None:\n            sample_kwargs = {}\n\n        super().__init__(sample_ext=sample_ext, sample_fn=sample_fn,\n                         dtype=dtype, normalize=normalize, norm_fn=norm_fn,\n                         **sample_kwargs)\n        self._label_ext = label_ext\n        self._label_fn = label_fn\n        self._label_kwargs = kwargs\n\n    def __call__(self, path) -> dict:\n        \"\"\"\n        Loads a sample and a label\n        Parameters\n        ----------\n        path : str\n        Returns\n        -------\n        dict\n            dict with data and label\n        \"\"\"\n        sample_dict = super().__call__(path)\n        label_dict = self._label_fn(os.path.join(path, self._label_ext),\n                                    
**self._label_kwargs)\n        sample_dict.update(label_dict)\n        return sample_dict\n"
  },
  {
    "path": "delira/data_loading/numba_transform.py",
    "content": "from batchgenerators.transforms import AbstractTransform, Compose\n\nimport logging\nfrom delira import get_current_debug_mode\nimport numba\n\nlogger = logging.getLogger(__name__)\n\n\nclass NumbaTransformWrapper(AbstractTransform):\n    def __init__(self, transform: AbstractTransform, nopython=True,\n                 target=\"cpu\", parallel=False, **options):\n\n        if get_current_debug_mode():\n            # set options for debug mode\n            logging.debug(\"Debug mode detected. Overwriting numba options \"\n                          \"nopython to False and target to cpu\")\n            nopython = False\n            target = \"cpu\"\n\n        transform.__call__ = numba.jit(transform.__call__, nopython=nopython,\n                                       target=target,\n                                       parallel=parallel, **options)\n        self._transform = transform\n\n    def __call__(self, **kwargs):\n        return self._transform(**kwargs)\n\n\nclass NumbaTransform(NumbaTransformWrapper):\n    def __init__(self, transform_cls, nopython=True, target=\"cpu\",\n                 parallel=False, **kwargs):\n        trafo = transform_cls(**kwargs)\n\n        super().__init__(trafo, nopython=nopython, target=target,\n                         parallel=parallel)\n\n\nclass NumbaCompose(Compose):\n    def __init__(self, transforms):\n        super().__init__(transforms=[NumbaTransformWrapper(trafo)\n                                     for trafo in transforms])\n"
  },
  {
    "path": "delira/data_loading/sampler/__init__.py",
    "content": "from delira.data_loading.sampler.abstract import AbstractSampler\nfrom delira.data_loading.sampler.batch import BatchSampler\nfrom delira.data_loading.sampler.random import RandomSampler, \\\n    RandomSamplerNoReplacement, RandomSamplerWithReplacement\nfrom delira.data_loading.sampler.sequential import SequentialSampler\nfrom delira.data_loading.sampler.weighted import WeightedRandomSampler, \\\n    PrevalenceRandomSampler\n"
  },
  {
    "path": "delira/data_loading/sampler/abstract.py",
    "content": "from delira.data_loading.dataset import AbstractDataset\n\n\nclass AbstractSampler(object):\n    \"\"\"\n    Abstract Class defining a sampler interface\n    \"\"\"\n\n    def __init__(self, indices):\n        \"\"\"\n\n        Parameters\n        ----------\n        indices : list\n            the indices containing the classes to sample from\n        \"\"\"\n        self._indices = indices\n\n    def __iter__(self):\n        \"\"\"\n        Returns an iterator, must be overwritten in subclasses\n\n        Raises\n        ------\n        NotImplementedError\n            if not overwritten in subclass\n\n        \"\"\"\n        raise NotImplementedError\n\n    def __len__(self):\n        \"\"\"\n        Defines the class length\n\n        Returns\n        -------\n        int\n            the number of samples\n\n        \"\"\"\n        return len(self._indices)\n\n    @classmethod\n    def from_dataset(cls, dset: AbstractDataset, **kwargs):\n        \"\"\"\n        Class Method to create a sampler from a given dataset\n\n        Parameters\n        ----------\n        dset : :class:`AbstractDataset`\n            the dataset to create the sampler from\n        **kwargs :\n            additional keyword arguments\n\n        \"\"\"\n        if hasattr(dset, \"__len__\"):\n            length = len(dset)\n        else:\n            length = len([tmp for tmp in dset])\n        return cls(list(range(length)), **kwargs)\n"
  },
  {
    "path": "delira/data_loading/sampler/batch.py",
    "content": "from delira.data_loading.sampler.abstract import AbstractSampler\n\n\nclass BatchSampler(object):\n    \"\"\"\n    A Sampler-Wrapper combining the single indices sampled by a sampler to\n    batches of a given size\n    \"\"\"\n\n    def __init__(self, sampler: AbstractSampler, batch_size, drop_last=False):\n        \"\"\"\n\n        Parameters\n        ----------\n        sampler : :class:`AbstractSampler`\n            the actual sampler producing single-sized samples\n        batch_size : int\n            the size of each batch\n        drop_last : bool\n            whether or not to discard the last (possibly smaller) batch\n        \"\"\"\n        self._sampler = sampler\n        self._batchsize = batch_size\n        self._drop_last = drop_last\n\n    def __iter__(self):\n        \"\"\"\n        Iterator holding lists of sample-indices. Each list contains indices\n        for a single batch\n\n        Yields\n        ------\n        list\n            a list containing the sample indices of the current batch\n\n        \"\"\"\n        batch_idxs = []\n\n        for idx in self._sampler:\n            batch_idxs.append(idx)\n\n            if len(batch_idxs) == self._batchsize:\n                yield batch_idxs\n\n                batch_idxs = []\n\n        if not self._drop_last and batch_idxs:\n            yield batch_idxs\n\n    def __len__(self):\n        \"\"\"\n        Defines the class length\n\n        Returns\n        -------\n        int\n            number of samples\n\n        \"\"\"\n        num_batches = len(self._sampler) // self._batchsize\n\n        if not self._drop_last:\n            num_batches += int(bool(len(self._sampler) % self._batchsize))\n\n        return num_batches\n"
  },
  {
    "path": "delira/data_loading/sampler/random.py",
    "content": "from delira.data_loading.sampler.abstract import AbstractSampler\nimport numpy as np\n\n\nclass RandomSampler(AbstractSampler):\n    \"\"\"\n    A Generic Random Sampler\n    \"\"\"\n\n    def __init__(self, indices, replacement=False, num_samples=None):\n        \"\"\"\n\n        Parameters\n        ----------\n        indices : list\n            the indices containing the classes to sample from\n        replacement : bool\n            whether to sample with or without replacement\n        num_samples : int\n            the number of samples to provide. Must only be specified\n            if :param:`replacement` is True; If not specified, it defaults to\n            the number of samples present in :param:`indices`\n        \"\"\"\n        super().__init__(indices)\n\n        if replacement and num_samples is None:\n            num_samples = len(self._indices)\n\n        self._replacement = replacement\n        self._num_samples = num_samples\n\n    def __iter__(self):\n        \"\"\"\n        Returns an iterator returning random samples\n\n        Returns\n        -------\n        Iterator\n            an iterator returning random samples\n\n        \"\"\"\n        n = len(self._indices)\n\n        if self._replacement:\n            return iter(np.random.randint(n, size=self._num_samples).tolist())\n\n        possible_samples = np.arange(n)\n        np.random.shuffle(possible_samples)\n\n        return iter(possible_samples)\n\n    def __len__(self):\n        \"\"\"\n        Defines the length of the sampler\n\n        Returns\n        -------\n        int\n            the number of samples\n        \"\"\"\n        if self._replacement:\n            return self._num_samples\n        else:\n            return super().__len__()\n\n\nclass RandomSamplerNoReplacement(RandomSampler):\n    \"\"\"\n    A Random Sampler without replacement\n    \"\"\"\n\n    def __init__(self, indices):\n        \"\"\"\n\n        Parameters\n        ----------\n        
 indices : list\n            the indices containing the classes to sample from\n\n        \"\"\"\n        super().__init__(indices, False, None)\n\n\nclass RandomSamplerWithReplacement(RandomSampler):\n    \"\"\"\n    A Random Sampler With Replacement\n    \"\"\"\n\n    def __init__(self, indices, num_samples=None):\n        \"\"\"\n\n        Parameters\n        ----------\n        indices : list\n            the indices containing the classes to sample from\n        num_samples : int\n            number of samples to provide; if not specified, defaults to the\n            number of values given in :param:`indices`\n\n        \"\"\"\n        super().__init__(indices, True, num_samples)\n"
  },
  {
    "path": "delira/data_loading/sampler/sequential.py",
    "content": "from delira.data_loading.sampler.abstract import AbstractSampler\n\n\nclass SequentialSampler(AbstractSampler):\n    \"\"\"\n    Class to implement sequential sampling\n    \"\"\"\n\n    def __iter__(self):\n        \"\"\"\n        Creates an iterator returning sequential samples\n\n        Returns\n        -------\n        Iterator\n            iterator returning samples in a sequential manner\n        \"\"\"\n        return iter(range(len(self._indices)))\n"
  },
  {
    "path": "delira/data_loading/sampler/weighted.py",
    "content": "from delira.data_loading.sampler.abstract import AbstractSampler\nfrom delira.data_loading.dataset import AbstractDataset\nimport numpy as np\n\n\nclass WeightedRandomSampler(AbstractSampler):\n    \"\"\"\n    Class implementing Weighted Random Sampling\n    \"\"\"\n\n    def __init__(self, weights, num_samples=None):\n        \"\"\"\n\n        Parameters\n        ----------\n        weights : list\n            per-sample weights\n        num_samples : int\n            number of samples to provide. If not specified this defaults to\n            the amount of values given in :param:`num_samples´\n        \"\"\"\n        if num_samples is None:\n            num_samples = len(weights)\n\n        self._num_samples = num_samples\n        super().__init__(np.arange(num_samples))\n        self._weights = weights\n\n    def __iter__(self):\n        \"\"\"\n        Defines the actual weighted random sampling\n\n        Returns\n        -------\n        Iterator\n            iterator producing random samples\n        \"\"\"\n        return iter(np.random.choice(self._indices, size=self._num_samples,\n                                     p=self._weights))\n\n    def __len__(self):\n        \"\"\"\n        Defines the length of the sampler\n\n        Returns\n        -------\n        int\n            the number of samples\n        \"\"\"\n        return self._num_samples\n\n\nclass PrevalenceRandomSampler(WeightedRandomSampler):\n    \"\"\"\n    Class implementing prevalence weighted sampling\n    \"\"\"\n\n    def __init__(self, indices):\n        \"\"\"\n\n        Parameters\n        ----------\n        indices : list\n            list of class indices to calculate a weighting from\n        \"\"\"\n\n        weights = np.array(indices).astype(np.float)\n        classes, classes_count = np.unique(indices, return_counts=True)\n\n        # compute probabilities\n        target_prob = 1 / classes.shape[0]\n\n        # generate weight matrix\n        for i, c in 
enumerate(classes):\n            weights[weights == c] = (target_prob / classes_count[i])\n\n        super().__init__(weights, num_samples=len(indices))\n\n    @classmethod\n    def from_dataset(cls, dset: AbstractDataset, key=\"label\", **kwargs):\n        \"\"\"\n        CLass function to create an instance of this sampler by giving it a\n        dataset\n\n        Parameters\n        ----------\n        dset : :class:`AbstractDataset`\n            the dataset to create weightings from\n        key : str\n            the key holding the class index for each sample\n        **kwargs :\n            Additional keyword arguments\n\n        \"\"\"\n        return cls([_sample[key] for _sample in dset], **kwargs)\n"
  },
  {
    "path": "delira/io/__init__.py",
    "content": "from delira import get_backends\n\nif \"TORCH\" in get_backends():\n    from delira.io.torch import save_checkpoint_torch as torch_save_checkpoint\n    from delira.io.torch import load_checkpoint_torch as torch_load_checkpoint\n\n    from delira.io.torch import save_checkpoint_torchscript \\\n        as torchscript_save_checkpoint\n    from delira.io.torch import load_checkpoint_torchscript \\\n        as torchscript_load_checkpoint\n\nif \"TF\" in get_backends():\n    from delira.io.tf import save_checkpoint as tf_save_checkpoint\n    from delira.io.tf import load_checkpoint as tf_load_checkpoint\n\n    from delira.io.tf import save_checkpoint_eager as tf_eager_save_checkpoint\n    from delira.io.tf import load_checkpoint_eager as tf_eager_load_checkpoint\n\nif \"CHAINER\" in get_backends():\n    from delira.io.chainer import save_checkpoint as chainer_save_checkpoint\n    from delira.io.chainer import load_checkpoint as chainer_load_checkpoint\n\nif \"SKLEARN\" in get_backends():\n    from delira.io.sklearn import load_checkpoint as sklearn_load_checkpoint\n    from delira.io.sklearn import save_checkpoint as sklearn_save_checkpoint\n"
  },
  {
    "path": "delira/io/chainer.py",
    "content": "import chainer\nimport zipfile\nimport os\nimport json\n\n\ndef save_checkpoint(file, model=None, optimizers=None, epoch=None):\n    \"\"\"\n    Saves the given checkpoint\n\n    Parameters\n    ----------\n    file : str\n        string containing the path, the state should be saved to\n    model : :class:`AbstractChainerNetwork`\n    optimizers : dict\n        dictionary containing all optimizers\n    epoch : int\n        the current epoch\n\n    \"\"\"\n    # config file for path mapping insde the archive\n    save_config = {}\n    # files to write to archive and delete afterwards\n    del_files = []\n\n    # save model to hdf5\n    if model is not None:\n        # temporary filename\n        _curr_file = file.replace(\"chain\", \"model\")\n        # serialize to temporary file\n        chainer.serializers.save_hdf5(_curr_file, model)\n        # add to config (without path to navigate inside archive)\n        save_config[\"model\"] = os.path.basename(_curr_file)\n        # append to files to process\n        del_files.append(_curr_file)\n\n    # save all optimizers to hdf5\n    if optimizers is not None:\n        # dict for mapping optimizer names to files\n        optim_config = {}\n        for k, v in optimizers.items():\n            # temporary file\n            _curr_file = file.replace(\"chain\", \"optim.%s\" % str(k))\n            # serialize to temporary file\n            chainer.serializers.save_hdf5(_curr_file, v)\n            # add to optimizer config (without path to navigate inside archive)\n            optim_config[k] = os.path.basename(_curr_file)\n            # append to files to process\n            del_files.append(_curr_file)\n\n        # add optimizer path mapping to config\n        save_config[\"optimizers\"] = optim_config\n\n    # add epoch to config\n    if epoch is not None:\n        save_config[\"epoch\"] = epoch\n    # temporary config file\n    _curr_file = file.replace(\"chain\", \"config\")\n    # serialize config 
dict to temporary json config file\n    with open(_curr_file, \"w\") as f:\n        json.dump(save_config, f)\n    # append to files to process\n    del_files.append(_curr_file)\n\n    # create the actual archive\n    with zipfile.ZipFile(file, mode=\"w\") as f:\n        for _file in del_files:\n            # write temporary file to archive and remove it afterwards\n            f.write(_file, os.path.basename(_file))\n            os.remove(_file)\n\n\ndef _deserialize_and_load(archive: zipfile.ZipFile, file: str, obj,\n                          temp_dir: str):\n    \"\"\"\n    Helper Function to temporarily extract a file from a given archive,\n    deserialize the object in this file and remove the temporary file\n\n    Parameters\n    ----------\n    archive : :class:`zipfile.Zipfile`\n        the archive containing the file to deserialize\n    file : str\n        identifier specifying the file inside the archive to extract and\n        deserialize\n    obj : Any\n        the object to load the deserialized state to. 
Must provide a\n        `serialize` function\n    temp_dir : str\n        the directory the file will be temporarily extracted to\n\n    Returns\n    -------\n    Any\n        the object with the loaded and deserialized state\n\n    \"\"\"\n    # temporary extract file\n    archive.extract(file, temp_dir)\n    # deserialize object\n    chainer.serializers.load_hdf5(os.path.join(temp_dir, file), obj)\n    # remove temporary file\n    os.remove(os.path.join(temp_dir, file))\n    return obj\n\n\ndef load_checkpoint(file, old_state: dict = None,\n                    model: chainer.link.Link = None, optimizers: dict = None):\n    \"\"\"\n    Loads a state from a given file\n\n    Parameters\n    ----------\n    file : str\n        string containing the path to the file containing the saved state\n    old_state : dict\n        dictionary containing the modules to load the states to\n    model : :class:`chainer.link.Link`\n        the model the state should be loaded to;\n        overwrites the ``model`` key in ``old_state`` if not None\n    optimizers : dict\n        dictionary containing all optimizers.\n        overwrites the ``optimizers`` key in ``old_state`` if not None\n\n    Returns\n    -------\n    dict\n        the loaded state\n\n    \"\"\"\n    if old_state is None:\n        old_state = {}\n\n    if model is not None:\n        old_state[\"model\"] = model\n    if optimizers is not None:\n        old_state[\"optimizers\"] = optimizers\n\n    loaded_state = {}\n\n    # open zip archive\n    with zipfile.ZipFile(file) as f:\n\n        # load config\n        _curr_file = file.replace(\"chain\", \"config\")\n        # temporarily extract json file to dir\n        f.extract(os.path.basename(_curr_file),\n                  os.path.dirname(file))\n        # load config dict\n        with open(_curr_file) as _file:\n            config = json.load(_file)\n        # remove temporary json file\n        os.remove(_curr_file)\n\n        # load model if path is inside 
config\n        if \"model\" in config:\n            # open file in archive by temporary extracting it\n            loaded_state[\"model\"] = _deserialize_and_load(\n                f, config[\"model\"], old_state[\"model\"], os.path.dirname(file))\n\n        # load optimizers if path mapping is inside config\n        if \"optimizers\" in config:\n            loaded_state[\"optimizers\"] = {}\n            optimizer_config = config[\"optimizers\"]\n\n            for k, v in optimizer_config.items():\n                # open file in archive by temporary extracting it\n                loaded_state[\"optimizers\"][k] = _deserialize_and_load(\n                    f, v, old_state[\"optimizers\"][k], os.path.dirname(file))\n\n        # load epoch from config if possible\n        if \"epoch\" in config:\n            loaded_state[\"epoch\"] = config[\"epoch\"]\n\n    return loaded_state\n"
  },
  {
    "path": "delira/io/sklearn.py",
    "content": "import logging\nimport joblib\nlogger = logging.getLogger(__name__)\n\n\ndef save_checkpoint(file: str, model=None, epoch=None, **kwargs):\n    \"\"\"\n    Save model's parameters\n\n    Parameters\n    ----------\n    file : str\n        filepath the model should be saved to\n    model : AbstractNetwork or None\n        the model which should be saved\n        if None: empty dict will be saved as state dict\n    epoch : int\n        current epoch (will also be pickled)\n\n    \"\"\"\n\n    return_val = joblib.dump({\"model\": model, \"epoch\": epoch}, file, **kwargs)\n    return return_val\n\n\ndef load_checkpoint(file, **kwargs):\n    \"\"\"\n    Loads a saved model\n\n    Parameters\n    ----------\n    file : str\n        filepath to a file containing a saved model\n    **kwargs:\n        Additional keyword arguments (passed to torch.load)\n        Especially \"map_location\" is important to change the device the\n        state_dict should be loaded to\n\n    Returns\n    -------\n    OrderedDict\n        checkpoint state_dict\n\n    \"\"\"\n    return joblib.load(file, **kwargs)\n"
  },
  {
    "path": "delira/io/tf.py",
    "content": "from delira.models.backends.tf_eager import AbstractTfEagerNetwork\nimport typing\nimport logging\n\nimport tensorflow as tf\n\nlogger = logging.getLogger(__name__)\n\n\ndef save_checkpoint(file: str, model=None):\n    \"\"\"\n    Save model's parameters contained in it's graph\n\n    Parameters\n    ----------\n    file : str\n        filepath the model should be saved to\n    model : TfNetwork\n        the model which should be saved\n    \"\"\"\n    tf.train.Saver().save(model._sess, file)\n\n\ndef load_checkpoint(file: str, model=None):\n    \"\"\"\n    Loads a saved model\n\n    Parameters\n    ----------\n    file : str\n        filepath to a file containing a saved model\n    model : TfNetwork\n        the model which should be loaded\n    \"\"\"\n\n    # following operation adds AssignVariableOps to the graph, keep an eye on\n    # this for memory leak\n    tf.train.Saver().restore(model._sess, file)\n    return {}\n\n\ndef _create_varlist(model: AbstractTfEagerNetwork = None,\n                    optimizer: typing.Dict[str, tf.train.Optimizer] = None):\n    variable_list = []\n\n    if model is not None:\n        variable_list += model.variables\n\n    if optimizer is not None:\n        for k, v in optimizer.items():\n            variable_list += v.variables()\n\n    return variable_list\n\n\ndef save_checkpoint_eager(file,\n                          model: AbstractTfEagerNetwork = None,\n                          optimizer: typing.Dict[str,\n                                                 tf.train.Optimizer] = None,\n                          epoch=None):\n    variable_list = _create_varlist(model, optimizer)\n\n    # can only save if variables exist, this is not the case if there was no\n    # input forwarded through the network (yet)\n    if variable_list:\n        saver = tf.contrib.eager.Saver(variable_list)\n        saver.save(file, global_step=epoch)\n        return\n    logging.warning(\"Could not save any variables because they 
don't exist \"\n                    \"(yet). If you haven't forwarded any input through your \"\n                    \"network yet, this is not an error, but expected behavior\")\n\n\ndef load_checkpoint_eager(file,\n                          model: AbstractTfEagerNetwork = None,\n                          optimizer: typing.Dict[str,\n                                                 tf.train.Optimizer] = None):\n\n    variable_list = _create_varlist(model, optimizer)\n\n    if variable_list:\n        saver = tf.contrib.eager.Saver(variable_list)\n        saver.restore(file)\n\n        return {\"model\": model, \"optimizer\": optimizer}\n\n    raise RuntimeError(\n        \"No Variables found to restore, probably no variables \"\n        \"exist, because they aren't yet created. Make sure, you \"\n        \"have at least once forwarded an input through your \"\n        \"model!\")\n"
  },
  {
    "path": "delira/io/torch.py",
    "content": "from delira.models.backends.torchscript import AbstractTorchScriptNetwork\nfrom delira.models.backends.torch import AbstractPyTorchNetwork\nimport torch\nimport logging\nimport os\nfrom collections import OrderedDict\n\nlogger = logging.getLogger(__name__)\n\n\ndef save_checkpoint_torch(file: str, model=None, optimizers=None,\n                          epoch=None, **kwargs):\n    \"\"\"\n    Save checkpoint\n\n    Parameters\n    ----------\n    file : str\n        filepath the model should be saved to\n    model : AbstractNetwork or None\n        the model which should be saved\n        if None: empty dict will be saved as state dict\n    optimizers : dict\n        dictionary containing all optimizers\n    epoch : int\n        current epoch (will also be pickled)\n\n    \"\"\"\n    if optimizers is None:\n        optimizers = {}\n    if isinstance(model, torch.nn.DataParallel):\n        _model = model.module\n    else:\n        _model = model\n\n    if isinstance(_model, (AbstractPyTorchNetwork,\n                           AbstractTorchScriptNetwork)):\n        model_state = _model.state_dict()\n    else:\n        model_state = {}\n        logger.debug(\"Saving checkpoint without Model\")\n\n    optim_state = OrderedDict()\n    for key, val in optimizers.items():\n        if isinstance(val, torch.optim.Optimizer):\n            optim_state[key] = val.state_dict()\n\n    if not optim_state:\n        logger.debug(\"Saving checkpoint without Optimizer\")\n\n    if epoch is None:\n        epoch = 0\n\n    state = {\"optimizer\": optim_state,\n             \"model\": model_state,\n             \"epoch\": epoch}\n\n    torch.save(state, file, **kwargs)\n\n\ndef load_checkpoint_torch(file, **kwargs):\n    \"\"\"\n    Loads a saved model\n\n    Parameters\n    ----------\n    file : str\n        filepath to a file containing a saved model\n    **kwargs:\n        Additional keyword arguments (passed to torch.load)\n        Especially \"map_location\" is 
important to change the device the\n        state_dict should be loaded to\n\n    Returns\n    -------\n    OrderedDict\n        checkpoint state_dict\n\n    \"\"\"\n    checkpoint = torch.load(file, **kwargs)\n\n    if not all([_key in checkpoint\n                for _key in [\"model\", \"optimizer\", \"epoch\"]]):\n        return checkpoint['state_dict']\n    return checkpoint\n\n\ndef save_checkpoint_torchscript(file: str, model=None, optimizers=None,\n                                epoch=None, **kwargs):\n    \"\"\"\n    Save current checkpoint to two different files:\n        1.) ``file + \".model.ptj\"``: Will include the state of the model\n            (including the graph; in contrast to\n            :func:`save_checkpoint_torch`)\n        2.) ``file + \".trainer_state.pt\"``: Will include the states of all\n            optimizers and the current epoch (if given)\n\n    Parameters\n    ----------\n    file : str\n        filepath the model should be saved to\n    model : AbstractTorchScriptNetwork or None\n        the model which should be saved\n        if None: no model file will be saved\n    optimizers : dict\n        dictionary containing all optimizers\n    epoch : int\n        current epoch (will also be pickled)\n\n    \"\"\"\n\n    # remove file extension if given\n    if optimizers is None:\n        optimizers = {}\n    if any([file.endswith(ext) for ext in [\".pth\", \".pt\", \".ptj\"]]):\n\n        file, old_ext = file.rsplit(\".\", 1)\n\n        if old_ext != \"ptj\":\n            logger.info(\"File extension was changed from %s to ptj to \"\n                        \"indicate that the current module is a \"\n                        \"torchscript module (including the graph)\", old_ext)\n\n    if isinstance(model, AbstractTorchScriptNetwork):\n        torch.jit.save(model, file + \".model.ptj\")\n\n    if optimizers or epoch is not None:\n        save_checkpoint_torch(file + \".trainer_state.pt\", None,\n                              optimizers=optimizers, epoch=epoch, **kwargs)\n\n\ndef load_checkpoint_torchscript(file: str, **kwargs):\n    \"\"\"\n    Loads a saved checkpoint consisting of two files\n    (see :func:`save_checkpoint_torchscript` for details)\n\n    Parameters\n    ----------\n    file : str\n        filepath to a file containing a saved model\n    **kwargs:\n        Additional keyword arguments (passed to torch.load)\n        Especially \"map_location\" is important to change the device the\n        state_dict should be loaded to\n\n    Returns\n    -------\n    OrderedDict\n        checkpoint state_dict\n\n    \"\"\"\n\n    # load model\n    if os.path.isfile(file):\n        model_file = file\n    elif os.path.isfile(file.replace(\".ptj\", \".model.ptj\")):\n        model_file = file.replace(\".ptj\", \".model.ptj\")\n    else:\n        raise ValueError(\"No Model File found for %s\" % file)\n\n    # load trainer state (if possible)\n    trainer_file = model_file.replace(\".model.ptj\", \".trainer_state.pt\")\n    if os.path.isfile(trainer_file):\n        trainer_state = load_checkpoint_torch(trainer_file, **kwargs)\n\n    else:\n        trainer_state = {\"optimizer\": {},\n                         \"epoch\": None}\n\n    trainer_state.update({\"model\": torch.jit.load(model_file)})\n\n    return trainer_state\n"
  },
  {
    "path": "delira/logging/__init__.py",
    "content": "from delira.logging.tensorboard_backend import TensorboardBackend\nfrom delira.logging.visdom_backend import VisdomBackend\nfrom delira.logging.base_backend import BaseBackend\nfrom delira.logging.writer_backend import WriterLoggingBackend\nfrom delira.logging.base_logger import Logger, SingleThreadedLogger, \\\n    make_logger\nfrom delira.logging.registry import unregister_logger, register_logger, \\\n    get_logger, logger_exists, log as _log, get_available_loggers\nfrom delira.logging.logging_context import LoggingContext\n\nlog = _log\n"
  },
  {
    "path": "delira/logging/base_backend.py",
    "content": "\nfrom queue import Empty\nfrom abc import abstractmethod, ABCMeta\nfrom threading import Event\nfrom queue import Queue\nimport warnings\n\n_FUNCTIONS_WITHOUT_STEP = (\"graph_pytorch\", \"graph_tf\", \"graph_onnx\",\n                           \"embedding\")\n\n# Deprecated Keys with their future alternative\n_DEPRECATED_KEYS = {\"img\": \"image\", \"picture\": \"image\", \"imgs\": \"images\",\n                    \"pictures\": \"images\", \"bounding_boxes\": \"image_with_boxes\",\n                    \"bboxes\": \"image_with_boxes\", \"value\": \"scalar\",\n                    \"values\": \"scalar\", \"hist\": \"histogram\", \"fig\": \"figure\",\n                    \"sound\": \"audio\", \"pr\": \"pr_curve\", \"curve\": \"line\",\n                    \"hm\": \"heatmap\"}\n\n\nclass BaseBackend(object, metaclass=ABCMeta):\n    \"\"\"\n    The basic Logging Backend, Provides an abstract interface to log\n    different value types and some keyword mappings\n    \"\"\"\n\n    class FigureManager:\n        \"\"\"\n        A Figure Manager, which creates a figure during entrance and pushes\n        the figure to logging writer during exit\n        \"\"\"\n\n        def __init__(self, push_fn, figure_kwargs: dict, push_kwargs: dict):\n            \"\"\"\n\n            Parameters\n            ----------\n            push_fn : function\n                A function accepting a figure and some keyword arguments\n                to push it to the logging writer\n            figure_kwargs : dict\n                dictionary containing all keyword arguments to create the\n                figure\n            push_kwargs : dict\n                dictionary containing all keyword arguments to push the figure\n                to the loggging writer\n            \"\"\"\n            self._push_fn = push_fn\n            self._figure_kwargs = figure_kwargs\n            self._push_kwargs = push_kwargs\n            self._fig = None\n\n        def __enter__(self):\n          
  \"\"\"\n            Function to be executed during context-manager entrance;\n            Will create a figure with the figure kwargs\n\n            \"\"\"\n            from matplotlib.pyplot import figure\n            self._fig = figure(**self._figure_kwargs)\n\n        def __exit__(self, *args):\n            \"\"\"\n            Function to be executed during context-manager exit;\n            Will push the figure to the logging writer and destroy it\n            afterwards\n\n            Parameters\n            ----------\n            *args :\n                arbitrary positional arguments; Necessary to be compatible\n                with other context managers, but not used in this one\n\n            \"\"\"\n            from matplotlib.pyplot import close\n            self._push_fn(figure=self._fig, **self._push_kwargs)\n\n            close(self._fig)\n            self._fig = None\n\n    def __init__(self, abort_event: Event = None, queue: Queue = None):\n        \"\"\"\n\n        Parameters\n        ----------\n        abort_event : :class:`threading.Event`\n            the event to signalize, when the logger must be destroyed\n        queue : :class:`queue.Queue`\n            the queue to enqueue all tuples of mapped functions and the\n            corresponding arguments before their execution\n\n        \"\"\"\n        super().__init__()\n        self.KEYWORD_FN_MAPPING = {}\n\n        self.daemon = True\n\n        self._queue = queue\n        self._abort_event = abort_event\n        self._global_steps = {}\n        # create Keyword mapping\n        self.KEYWORD_FN_MAPPING.update(**{\n            \"image\": self._image,\n            \"img\": self._image,\n            \"picture\": self._image,\n            \"images\": self._images,\n            \"imgs\": self._images,\n            \"pictures\": self._images,\n            \"image_with_boxes\": self._image_with_boxes,\n            \"bounding_boxes\": self._image_with_boxes,\n            \"bboxes\": 
self._image_with_boxes,\n            \"scalar\": self._scalar,\n            \"value\": self._scalar,\n            \"scalars\": self._scalars,\n            \"values\": self._scalars,\n            \"histogram\": self._histogram,\n            \"hist\": self._histogram,\n            \"figure\": self._figure,\n            \"fig\": self._figure,\n            \"audio\": self._audio,\n            \"sound\": self._audio,\n            \"video\": self._video,\n            \"text\": self._text,\n            \"graph_pytorch\": self._graph_pytorch,\n            \"graph_tf\": self._graph_tf,\n            \"graph_onnx\": self._graph_onnx,\n            \"embedding\": self._embedding,\n            \"pr_curve\": self._pr_curve,\n            \"pr\": self._pr_curve,\n            \"scatter\": self._scatter,\n            \"line\": self._line,\n            \"curve\": self._line,\n            \"stem\": self._stem,\n            \"heatmap\": self._heatmap,\n            \"hm\": self._heatmap,\n            \"bar\": self._bar,\n            \"boxplot\": self._boxplot,\n            \"surface\": self._surface,\n            \"contour\": self._contour,\n            \"quiver\": self._quiver,\n            # \"mesh\": self._mesh\n        })\n\n    def _log_item(self):\n        \"\"\"\n        Internal helper function to log an item of the queue\n\n        Raises\n        ------\n        ValueError\n            if the item to log is not a dict\n\n        \"\"\"\n        # get item from queue\n        process_item = self._queue.get(timeout=0.001)\n        # log item if it is a dict\n        if isinstance(process_item, dict):\n\n            for key, val in process_item.items():\n                # raise DeprecationWarning for deprecated keys\n                if key in _DEPRECATED_KEYS:\n                    warnings.warn(\"The Key %s is deprecated and will\"\n                                  \" be removed in the next release. 
\"\n                                  \"Please use %s instead!\"\n                                  % (key, _DEPRECATED_KEYS[key]),\n                                  DeprecationWarning)\n\n                # performs the actual mapping\n                execute_fn = self.KEYWORD_FN_MAPPING[str(key).lower()]\n\n                # resolve the global step\n                val = self._resolve_global_step(str(key).lower(), **val)\n\n                # execute the logging function\n                self._call_exec_fn(execute_fn, val)\n\n        # item is no dict -> raise Error\n        else:\n            raise ValueError(\"Invalid Value passed for logging: %s\"\n                             % str(process_item))\n\n    def _resolve_global_step(self, key, **val):\n        \"\"\"\n        Helper function to resolve the global step from given Arguments\n\n        Parameters\n        ----------\n        key : str\n            the function key to resolve the step for\n        **val :\n            kwargs which may contain the step information\n\n        Returns\n        -------\n        int\n            the global step\n\n        Raises\n        ------\n        ValueError\n            If no valid tag was found although a tag should exist\n\n        \"\"\"\n        # check if function should be processed statically\n        # (no time update possible)\n        if str(key).lower() not in _FUNCTIONS_WITHOUT_STEP:\n\n            # check for different step names\n            if \"tag\" in val:\n                tag = \"tag\"\n            elif \"main_tag\" in val:\n                tag = \"main_tag\"\n            else:\n                raise ValueError(\"No valid tag found to extract global step\")\n\n            # check if global step is given\n            if \"global_step\" not in val or val[\"global_step\"] is None:\n\n                # check if tag is already part of internal global steps\n                if val[tag] in self._global_steps:\n                    # if already existent: 
increment step for given tag\n                    self._global_steps[val[tag]] += 1\n                    step = self._global_steps[val[tag]]\n\n                else:\n                    # if not existent: set step for given tag to zero\n                    step = 0\n                    self._global_steps[val[tag]] = step\n\n                val.update({\"global_step\": step})\n\n            elif \"global_step\" in val:\n                self._global_steps[val[tag]] = val[\"global_step\"]\n\n        return val\n\n    def run(self):\n        \"\"\"\n        Main function which executes the logging, catches exceptions and sets\n        the abortion event if necessary\n\n        \"\"\"\n        try:\n            self._log_item()\n\n        except Empty:\n            pass\n\n        except Exception as e:\n            self._abort_event.set()\n            raise e\n\n    def set_queue(self, queue: Queue):\n        \"\"\"\n        Setter function for the queue\n\n        Parameters\n        ----------\n        queue : :class:`queue.Queue`\n            the new queue\n\n        \"\"\"\n        self._queue = queue\n\n    def set_event(self, event: Event):\n        \"\"\"\n        Setter function for the abortion event\n\n        Parameters\n        ----------\n        event : :class:`threading.Event`\n            the new abortion event\n\n        \"\"\"\n        self._abort_event = event\n\n    def _call_exec_fn(self, exec_fn, args):\n        \"\"\"\n        Helper function calling the actual mapped function\n\n        Parameters\n        ----------\n        exec_fn : function\n            the function which will execute the actual logging\n        args : iterable (listlike) or mapping (dictlike)\n            the arguments passed to the ``exec_fn``\n\n        Returns\n        -------\n        Any\n            the return value obtained by the ``exec_fn``\n\n        Raises\n        ------\n        TypeError\n            if the given ``args`` are neither of type dict nor 
tuple/list\n\n        \"\"\"\n\n        if isinstance(args, dict):\n            ret_val = exec_fn(**args)\n        elif isinstance(args, (tuple, list)):\n            ret_val = exec_fn(*args)\n\n        else:\n            raise TypeError(\"Invalid type for args. Must be either dict, \"\n                            \"tuple or list, but got %s.\"\n                            % args.__class__.__name__)\n\n        return ret_val\n\n    @abstractmethod\n    def _image(self, *args, **kwargs):\n        \"\"\"\n        Abstract Interface Function to log a single image\n\n        Parameters\n        ----------\n        *args\n            arbitrary positional arguments\n        **kwargs\n            arbitrary keyword arguments\n\n        Raises\n        ------\n        NotImplementedError\n            If not overwritten in subclass\n\n        \"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def _images(self, *args, **kwargs):\n        \"\"\"\n        Abstract Interface Function to log multiple images\n\n        Parameters\n        ----------\n        *args\n        **kwargs\n\n        Raises\n        ------\n        NotImplementedError\n            If not overwritten in subclass\n\n        \"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def _image_with_boxes(self, *args, **kwargs):\n        \"\"\"\n        Abstract Interface Function to log a single image with bounding boxes\n\n        Parameters\n        ----------\n        *args\n            arbitrary positional arguments\n        **kwargs\n            arbitrary keyword arguments\n\n        Raises\n        ------\n        NotImplementedError\n            If not overwritten in subclass\n\n        \"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def _scalar(self, *args, **kwargs):\n        \"\"\"\n        Abstract Interface Function to log a single scalar value\n\n        Parameters\n        ----------\n        *args\n            arbitrary positional arguments\n  
      **kwargs\n            arbitrary keyword arguments\n\n        Raises\n        ------\n        NotImplementedError\n            If not overwritten in subclass\n\n        \"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def _scalars(self, *args, **kwargs):\n        \"\"\"\n        Abstract Interface Function to log multiple scalar values\n\n        Parameters\n        ----------\n        *args\n            arbitrary positional arguments\n        **kwargs\n            arbitrary keyword arguments\n\n        Raises\n        ------\n        NotImplementedError\n            If not overwritten in subclass\n\n        \"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def _histogram(self, *args, **kwargs):\n        \"\"\"\n        Abstract Interface Function to create and log a histogram out of given\n        values\n\n        Parameters\n        ----------\n        *args\n            arbitrary positional arguments\n        **kwargs\n            arbitrary keyword arguments\n\n        Raises\n        ------\n        NotImplementedError\n            If not overwritten in subclass\n\n        \"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def _figure(self, *args, **kwargs):\n        \"\"\"\n        Abstract Interface Function to log a single ``matplotlib`` figure\n\n        Parameters\n        ----------\n        *args\n            arbitrary positional arguments\n        **kwargs\n            arbitrary keyword arguments\n\n        Raises\n        ------\n        NotImplementedError\n            If not overwritten in subclass\n\n        \"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def _audio(self, *args, **kwargs):\n        \"\"\"\n        Abstract Interface Function to log a single audio signal\n\n        Parameters\n        ----------\n        *args\n            arbitrary positional arguments\n        **kwargs\n            arbitrary keyword arguments\n\n        Raises\n        ------\n  
      NotImplementedError\n            If not overwritten in subclass\n\n        \"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def _video(self, *args, **kwargs):\n        \"\"\"\n        Abstract Interface Function to log a single video\n\n        Parameters\n        ----------\n        *args\n            arbitrary positional arguments\n        **kwargs\n            arbitrary keyword arguments\n\n        Raises\n        ------\n        NotImplementedError\n            If not overwritten in subclass\n\n        \"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def _text(self, *args, **kwargs):\n        \"\"\"\n        Abstract Interface Function to log a single string as text\n\n        Parameters\n        ----------\n        *args\n            arbitrary positional arguments\n        **kwargs\n            arbitrary keyword arguments\n\n        Raises\n        ------\n        NotImplementedError\n            If not overwritten in subclass\n\n        \"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def _graph_pytorch(self, *args, **kwargs):\n        \"\"\"\n        Abstract Interface Function to log a ``PyTorch`` Graph\n\n        Parameters\n        ----------\n        *args\n            arbitrary positional arguments\n        **kwargs\n            arbitrary keyword arguments\n\n        Raises\n        ------\n        NotImplementedError\n            If not overwritten in subclass\n\n        \"\"\"\n\n        raise NotImplementedError\n\n    @abstractmethod\n    def _graph_tf(self, *args, **kwargs):\n        \"\"\"\n        Abstract Interface Function to log a TF Graph\n\n        Parameters\n        ----------\n        *args\n            arbitrary positional arguments\n        **kwargs\n            arbitrary keyword arguments\n\n        Raises\n        ------\n        NotImplementedError\n            If not overwritten in subclass\n\n        \"\"\"\n        raise NotImplementedError\n\n    
@abstractmethod\n    def _graph_onnx(self, *args, **kwargs):\n        \"\"\"\n        Abstract Interface Function to log an ONNX Graph\n\n        Parameters\n        ----------\n        *args\n            arbitrary positional arguments\n        **kwargs\n            arbitrary keyword arguments\n\n        Raises\n        ------\n        NotImplementedError\n            If not overwritten in subclass\n\n        \"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def _embedding(self, *args, **kwargs):\n        \"\"\"\n        Abstract Interface Function to create and log an embedding\n\n        Parameters\n        ----------\n        *args\n            arbitrary positional arguments\n        **kwargs\n            arbitrary keyword arguments\n\n        Raises\n        ------\n        NotImplementedError\n            If not overwritten in subclass\n\n        \"\"\"\n        raise NotImplementedError\n\n    @abstractmethod\n    def _pr_curve(self, *args, **kwargs):\n        \"\"\"\n        Abstract Interface Function to calculate and log a PR curve out of\n        given values\n\n        Parameters\n        ----------\n        *args\n            arbitrary positional arguments\n        **kwargs\n            arbitrary keyword arguments\n\n        Raises\n        ------\n        NotImplementedError\n            If not overwritten in subclass\n\n        \"\"\"\n        raise NotImplementedError\n\n    def _scatter(self, plot_kwargs: dict, figure_kwargs: dict = None,\n                 **kwargs):\n        \"\"\"\n        Function to create a scatter plot and push it\n\n        Parameters\n        ----------\n        plot_kwargs : dict\n            the arguments for plotting\n        figure_kwargs : dict\n            the arguments to actually create the figure\n        **kwargs :\n            additional keyword arguments for pushing the created figure to the\n            logging writer\n\n        \"\"\"\n\n        if figure_kwargs is None:\n            
figure_kwargs = {}\n        with self.FigureManager(self._figure, figure_kwargs, kwargs):\n            from matplotlib.pyplot import scatter\n\n            scatter(**plot_kwargs)\n\n    def _line(self, plot_kwargs=None, figure_kwargs=None, **kwargs):\n        \"\"\"\n        Function to create a line plot and push it\n\n        Parameters\n        ----------\n        plot_kwargs : dict\n            the arguments for plotting\n        figure_kwargs : dict\n            the arguments to actually create the figure\n        **kwargs :\n            additional keyword arguments for pushing the created figure to the\n            logging writer\n\n        \"\"\"\n\n        if figure_kwargs is None:\n            figure_kwargs = {}\n        if plot_kwargs is None:\n            plot_kwargs = {}\n        with self.FigureManager(self._figure, figure_kwargs, kwargs):\n            from matplotlib.pyplot import plot\n            plot(**plot_kwargs)\n\n    def _stem(self, plot_kwargs=None, figure_kwargs=None, **kwargs):\n        \"\"\"\n        Function to create a stem plot and push it\n\n        Parameters\n        ----------\n        plot_kwargs : dict\n            the arguments for plotting\n        figure_kwargs : dict\n            the arguments to actually create the figure\n        **kwargs :\n            additional keyword arguments for pushing the created figure to the\n            logging writer\n\n        \"\"\"\n        if figure_kwargs is None:\n            figure_kwargs = {}\n        if plot_kwargs is None:\n            plot_kwargs = {}\n        with self.FigureManager(self._figure, figure_kwargs, kwargs):\n            from matplotlib.pyplot import stem\n            stem(**plot_kwargs)\n\n    def _heatmap(self, plot_kwargs=None, figure_kwargs=None, **kwargs):\n        \"\"\"\n        Function to create a heatmap plot and push it\n\n        Parameters\n        ----------\n        plot_kwargs : dict\n            the arguments for plotting\n        figure_kwargs : 
dict\n            the arguments to actually create the figure\n        **kwargs :\n            additional keyword arguments for pushing the created figure to the\n            logging writer\n\n        \"\"\"\n        if figure_kwargs is None:\n            figure_kwargs = {}\n        if plot_kwargs is None:\n            plot_kwargs = {}\n        with self.FigureManager(self._figure, figure_kwargs, kwargs):\n            from seaborn import heatmap\n            heatmap(**plot_kwargs)\n\n    def _bar(self, plot_kwargs=None, figure_kwargs=None, **kwargs):\n        \"\"\"\n        Function to create a bar plot and push it\n\n        Parameters\n        ----------\n        plot_kwargs : dict\n            the arguments for plotting\n        figure_kwargs : dict\n            the arguments to actually create the figure\n        **kwargs :\n            additional keyword arguments for pushing the created figure to the\n            logging writer\n\n        \"\"\"\n        if figure_kwargs is None:\n            figure_kwargs = {}\n        if plot_kwargs is None:\n            plot_kwargs = {}\n        with self.FigureManager(self._figure, figure_kwargs, kwargs):\n            from matplotlib.pyplot import bar\n            bar(**plot_kwargs)\n\n    def _boxplot(self, plot_kwargs=None, figure_kwargs=None, **kwargs):\n        \"\"\"\n        Function to create a boxplot and push it\n\n        Parameters\n        ----------\n        plot_kwargs : dict\n            the arguments for plotting\n        figure_kwargs : dict\n            the arguments to actually create the figure\n        **kwargs :\n            additional keyword arguments for pushing the created figure to the\n            logging writer\n\n        \"\"\"\n        if plot_kwargs is None:\n            plot_kwargs = {}\n        if figure_kwargs is None:\n            figure_kwargs = {}\n        with self.FigureManager(self._figure, figure_kwargs, kwargs):\n            from matplotlib.pyplot import boxplot\n            
boxplot(**plot_kwargs)\n\n    def _surface(self, plot_kwargs=None, figure_kwargs=None, **kwargs):\n        \"\"\"\n        Function to create a surface plot and push it\n\n        Parameters\n        ----------\n        plot_kwargs : dict\n            the arguments for plotting\n        figure_kwargs : dict\n            the arguments to actually create the figure\n        **kwargs :\n            additional keyword arguments for pushing the created figure to the\n            logging writer\n\n        \"\"\"\n        if figure_kwargs is None:\n            figure_kwargs = {}\n        if plot_kwargs is None:\n            plot_kwargs = {}\n        with self.FigureManager(self._figure, figure_kwargs, kwargs):\n            from seaborn import kdeplot\n\n            kdeplot(**plot_kwargs)\n\n    def _contour(self, plot_kwargs=None, figure_kwargs=None, **kwargs):\n        \"\"\"\n        Function to create a contour plot and push it\n\n        Parameters\n        ----------\n        plot_kwargs : dict\n            the arguments for plotting\n        figure_kwargs : dict\n            the arguments to actually create the figure\n        **kwargs :\n            additional keyword arguments for pushing the created figure to the\n            logging writer\n\n        \"\"\"\n        if figure_kwargs is None:\n            figure_kwargs = {}\n        if plot_kwargs is None:\n            plot_kwargs = {}\n        with self.FigureManager(self._figure, figure_kwargs, kwargs):\n            from matplotlib.pyplot import contour\n\n            contour(**plot_kwargs)\n\n    def _quiver(self, plot_kwargs=None, figure_kwargs=None, **kwargs):\n        \"\"\"\n        Function to create a quiver plot and push it\n\n        Parameters\n        ----------\n        plot_kwargs : dict\n            the arguments for plotting\n        figure_kwargs : dict\n            the arguments to actually create the figure\n        **kwargs :\n            additional keyword arguments for pushing the created 
figure to the\n            logging writer\n\n        \"\"\"\n        if plot_kwargs is None:\n            plot_kwargs = {}\n        if figure_kwargs is None:\n            figure_kwargs = {}\n        with self.FigureManager(self._figure, figure_kwargs, kwargs):\n            from matplotlib.pyplot import quiver\n            quiver(**plot_kwargs)\n\n    @property\n    def name(self):\n        return \"BaseBackend\"\n"
  },
  {
    "path": "delira/logging/base_logger.py",
    "content": "from multiprocessing.queues import Queue as MpQueue\nfrom threading import Event\nfrom queue import Queue, Full\nfrom delira.logging.base_backend import BaseBackend\nfrom delira.utils.dict_reductions import get_reduction, possible_reductions, \\\n    reduce_dict\nimport logging\nfrom types import FunctionType\n\n\nclass Logger(object):\n    \"\"\"\n    The actual Logger Frontend, passing logging messages to the assigned\n    logging backend if appropriate or to python's logging module if not\n    \"\"\"\n\n    def __init__(self, backend: BaseBackend, max_queue_size: int = None,\n                 logging_frequencies=None, reduce_types=None,\n                 level=logging.NOTSET):\n        \"\"\"\n\n        Parameters\n        ----------\n        backend : :class:`delira.logging.base_backend.BaseBackend`\n            the logging backend to use\n        max_queue_size : int\n            the maximum size for the queue; if queue is full, all additional\n            logging tasks will be dropped until some tasks inside the queue\n            were executed; Per default no maximum size is applied\n        logging_frequencies : int or dict\n            specifies how often to log for each key.\n            If int: integer will be applied to all valid keys\n            if dict: should contain a frequency per valid key. Missing keys\n            will be filled with a frequency of 1 (log every time)\n            None is equal to empty dict here.\n        reduce_types : str of FunctionType or dict\n            Values are logged in each iteration. This argument specifies,\n            how to reduce them to a single value if a logging_frequency\n            besides 1 is passed\n\n            if str:\n                specifies the reduction type to use. 
Valid types are\n                'last' | 'first' | 'mean' | 'median' | 'max' | 'min'.\n                The given type will be mapped to all valid keys.\n            if FunctionType:\n                specifies the actual reduction function. Will be applied for\n                all keys.\n            if dict: should contain pairs of valid logging keys and either str\n                or FunctionType. Specifies the logging value per key.\n                Missing keys will be filled with a default value of 'last'.\n                Valid types for strings are\n                'last' | 'first' | 'mean' | 'max' | 'min'.\n        level : int\n            the logging level to use if passing the logging message to\n            python's logging module because it is not appropriate for logging\n            with the assigned logging backend\n\n        Warnings\n        --------\n        Since the intermediate values between two logging steps are stored in\n        memory to enable reduction, this might cause OOM errors easily\n        (especially if the logged items are still on GPU).\n        If this occurs you may want to choose a lower logging frequency.\n\n        \"\"\"\n\n        # 0 means unlimited size, but None is more readable\n        if max_queue_size is None:\n            max_queue_size = 0\n\n        # convert to empty dict if None\n        if logging_frequencies is None:\n            logging_frequencies = {}\n\n        # if int: assign int to all possible keys\n        if isinstance(logging_frequencies, int):\n            logging_frequencies = {\n                k: logging_frequencies\n                for k in backend.KEYWORD_FN_MAPPING.keys()}\n        # if dict: update missing keys with 1 and make sure other values\n        # are ints\n        elif isinstance(logging_frequencies, dict):\n            for k in backend.KEYWORD_FN_MAPPING.keys():\n                if k not in logging_frequencies:\n                    logging_frequencies[k] = 1\n   
             else:\n                    logging_frequencies[k] = int(logging_frequencies[k])\n        else:\n            raise TypeError(\"Invalid Type for logging frequencies: %s\"\n                            % type(logging_frequencies).__name__)\n\n        # assign frequencies and create empty queues\n        self._logging_frequencies = logging_frequencies\n        self._logging_queues = {}\n\n        default_reduce_type = \"last\"\n        if reduce_types is None:\n            reduce_types = default_reduce_type\n\n        # map string and function to all valid keys\n        if isinstance(reduce_types, (str, FunctionType)):\n            reduce_types = {\n                k: reduce_types\n                for k in backend.KEYWORD_FN_MAPPING.keys()}\n\n        # should be dict by now!\n        if isinstance(reduce_types, dict):\n            # check all valid keys for occurrences\n            for k in backend.KEYWORD_FN_MAPPING.keys():\n                # use default reduce type if necessary\n                if k not in reduce_types:\n                    reduce_types[k] = default_reduce_type\n                # check it is either valid string or already function type\n                else:\n                    if not isinstance(reduce_types[k], FunctionType):\n                        assert reduce_types[k] in possible_reductions()\n                        reduce_types[k] = str(reduce_types[k])\n                # map all strings to actual functions\n                if isinstance(reduce_types[k], str):\n                    reduce_types[k] = get_reduction(reduce_types[k])\n\n        else:\n            raise TypeError(\"Invalid Type for logging reductions: %s\"\n                            % type(reduce_types).__name__)\n\n        self._reduce_types = reduce_types\n\n        self._abort_event = Event()\n        self._flush_queue = Queue(max_queue_size)\n        self._backend = backend\n        self._backend.set_queue(self._flush_queue)\n        
self._backend.set_event(self._abort_event)\n        self._level = level\n\n    def log(self, log_message: dict):\n        \"\"\"\n        Main Logging Function, Decides whether to log with the assigned\n        backend or python's internal module\n\n        Parameters\n        ----------\n        log_message : dict\n            the message to log; Should be a dict, where the keys indicate the\n            logging function to execute, and the corresponding value holds\n            the arguments necessary to execute this function\n\n        Raises\n        ------\n        RuntimeError\n            If the abort event was set externally\n\n        \"\"\"\n\n        try:\n            if self._abort_event.is_set():\n                self.close()\n                raise RuntimeError(\"Abort-Event in logging process was set: %s\"\n                                   % self._backend.name)\n\n            # convert tuple to dict if necessary\n            if isinstance(log_message, (tuple, list)):\n                if len(log_message) == 2:\n                    log_message = (log_message,)\n                log_message = dict(log_message)\n\n            # try logging and drop item if queue is full\n            try:\n                # logging appropriate message with backend\n                if isinstance(log_message, dict):\n                    # multiple logging instances at once possible with\n                    # different keys\n                    for k, v in log_message.items():\n                        # append tag if tag is given, because otherwise we\n                        # would enqueue same types but different tags in same\n                        # queue\n                        if \"tag\" in v:\n                            queue_key = k + \".\" + v[\"tag\"]\n                        else:\n                            queue_key = k\n\n                        # create queue if necessary\n                        if queue_key not in self._logging_queues:\n                
            self._logging_queues[queue_key] = []\n\n                        # append current message to queue\n                        self._logging_queues[queue_key].append({k: v})\n                        # check if logging should be executed\n                        if (len(self._logging_queues[queue_key])\n                                % self._logging_frequencies[k] == 0):\n                            # reduce elements inside queue\n                            reduce_message = reduce_dict(\n                                self._logging_queues[queue_key],\n                                self._reduce_types[k])\n                            # flush reduced elements\n                            self._flush_queue.put_nowait(reduce_message)\n                            # empty queue\n                            self._logging_queues[queue_key] = []\n                else:\n                    # logging inappropriate message with python's logging\n                    logging.log(self._level, log_message)\n            except Full:\n                pass\n\n        # if an exception was raised anywhere, the abort event will be set\n        except Exception as e:\n            self._abort_event.set()\n            raise e\n\n    def __call__(self, log_message: dict):\n        \"\"\"\n        Makes the class callable and forwards the call to\n        :meth:`delira.logging.base_logger.Logger.log`\n\n        Parameters\n        ----------\n        log_message : dict\n            the logging message to log\n\n        Returns\n        -------\n        Any\n            the return value obtained by\n            :meth:`delira.logging.base_logger.Logger.log`\n\n        \"\"\"\n        return self.log(log_message)\n\n    def close(self):\n        \"\"\"\n        Function to close the actual logger; Waits for queue closing and sets\n        the abortion event\n\n        \"\"\"\n        if hasattr(self, \"_flush_queue\"):\n            if isinstance(self._flush_queue, MpQueue):\n        
        self._flush_queue.close()\n                self._flush_queue.join_thread()\n\n        if hasattr(self, \"_abort_event\"):\n            self._abort_event.set()\n\n    def __del__(self):\n        \"\"\"\n        Function to be executed, when class instance will be deleted;\n        Calls :meth:`delira.logging.base_logger.Logger.close`\n\n        \"\"\"\n\n        self.close()\n\n\nclass SingleThreadedLogger(Logger):\n    \"\"\"\n    A single threaded Logger which executes the backend after logging\n    a single element\n    \"\"\"\n\n    def log(self, log_message: dict):\n        \"\"\"\n        Function to log an actual logging message; Calls the backend to\n        execute the logging right after pushing it to the queue\n\n        Parameters\n        ----------\n        log_message : dict\n            the message to log; Should be a dict, where the keys indicate the\n            logging function to execute, and the corresponding value holds\n            the arguments necessary to execute this function\n\n        \"\"\"\n        super().log(log_message)\n        self._backend.run()\n\n\ndef make_logger(backend: BaseBackend, max_queue_size: int = None,\n                logging_frequencies=None, reduce_types=None,\n                level=logging.NOTSET):\n    \"\"\"\n    Function to create a logger\n\n    Parameters\n    ----------\n    backend : :class:`delira.logging.base_backend.BaseBackend`\n        the logging backend\n    max_queue_size : int\n        the maximum queue size\n    logging_frequencies : int or dict\n            specifies how often to log for each key.\n            If int: integer will be applied to all valid keys\n            if dict: should contain a frequency per valid key. Missing keys\n            will be filled with a frequency of 1 (log every time)\n            None is equal to empty dict here.\n    reduce_types : str or FunctionType or dict\n        if str:\n            specifies the reduction type to use. 
Valid types are\n            'last' | 'first' | 'mean' | 'max' | 'min'.\n            The given type will be mapped to all valid keys.\n        if FunctionType:\n            specifies the actual reduction function. Will be applied for\n            all keys.\n        if dict: should contain pairs of valid logging keys and either str\n            or FunctionType. Specifies the logging value per key.\n            Missing keys will be filled with a default value of 'last'.\n            Valid types for strings are\n            'last' | 'first' | 'mean' | 'max' | 'min'.\n    level : int\n        the logging level for python's internal logging module\n\n    Notes\n    -----\n    This function shall be used to create\n    Loggers (if possible), since it may be extended with new functionalities\n    in the future\n\n    Returns\n    -------\n    :class:`Logger`\n        the instance of a newly created logger\n\n    \"\"\"\n\n    return SingleThreadedLogger(backend=backend, max_queue_size=max_queue_size,\n                                logging_frequencies=logging_frequencies,\n                                reduce_types=reduce_types, level=level)\n"
  },
  {
    "path": "delira/logging/logging_context.py",
    "content": "from delira.logging.registry import logger_exists, register_logger, \\\n    unregister_logger, log as _log\nfrom delira.logging.base_logger import make_logger\n\nlog = _log\n\n\nclass LoggingContext(object):\n    \"\"\"\n    Contextmanager to set a new logging context\n    \"\"\"\n\n    def __init__(\n            self,\n            name,\n            initialize_if_missing=False,\n            destroy_on_exit=None,\n            **kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        name : str\n            the name of the logger to use\n        initialize_if_missing : bool\n            whether to create a logger if it does not yet exist\n        destroy_on_exit : bool\n            whether to destroy the logger on exit; If None, the logger will\n            only be destroyed, if it was created here\n        **kwargs:\n            additional keyword arguments to create a logger if necessary\n\n        Raises\n        ------\n        ValueError\n            if the logger does not exist already and shall not be created\n        \"\"\"\n\n        # Logger does exist already\n        if logger_exists(name):\n            self._name = name\n            if destroy_on_exit is None:\n                destroy_on_exit = False\n\n        # logger will be created\n        elif initialize_if_missing:\n            register_logger(make_logger(**kwargs), name)\n            if destroy_on_exit is None:\n                destroy_on_exit = True\n            self._name = name\n\n        # logger does not exist and shall not be created\n        else:\n            raise ValueError(\"No valid logger for name %s and \"\n                             \"'initialize_if_missing' is False\" % name)\n\n        self._destroy_on_exit = destroy_on_exit\n\n    def __enter__(self):\n        \"\"\"\n        Function to be executed during entrance;\n        Resets the logging context\n\n        Returns\n        -------\n        :class:`LoggingContext`\n            self\n   
     \"\"\"\n        global log\n        log = self.log\n        return self\n\n    def __exit__(self, *args):\n        \"\"\"\n        Function to be called during exiting the context manager;\n        Destroys the logger if necessary and resets the old logging context\n\n        Parameters\n        ----------\n        *args\n            Postional arguments to be compatible with other context managers\n\n        Returns\n        -------\n\n        \"\"\"\n        if self._destroy_on_exit:\n            _logger = unregister_logger(self._name)\n            del _logger\n\n        global log\n        log = _log\n\n    def log(self, msg: dict):\n        \"\"\"\n        Main Logging Function, Decides whether to log with the assigned\n        backend or python's internal module\n\n        Parameters\n        ----------\n        msg : dict\n            the message to log; Should be a dict, where the keys indicate the\n            logging function to execute, and the corresponding value holds\n            the arguments necessary to execute this function\n        \"\"\"\n\n        _log(msg, self._name)\n\n    def __call__(self, log_message: dict):\n        \"\"\"\n        Makes the class callable and forwards the call to\n        :meth:`delira.logging.base_logger.Logger.log`\n\n        Parameters\n        ----------\n        log_message : dict\n            the logging message to log\n\n        Returns\n        -------\n        Any\n            the return value obtained by\n            :meth:`LoggingContext.log`\n\n        \"\"\"\n        return self.log(log_message)\n"
  },
  {
    "path": "delira/logging/registry.py",
    "content": "from delira.logging.base_logger import Logger\nfrom collections import OrderedDict\n\n# Registry dict containing all registered available Loggers\n# Use Ordered Dict here to use first logger for logging if no name was given\n_AVAILABLE_LOGGERS = OrderedDict()\n\n\ndef log(msg: dict, name=None):\n    \"\"\"\n    Global logging function\n\n    Parameters\n    ----------\n    msg : dict\n        the message to log; Should be a dict, where the keys indicate the\n        logging function to execute, and the corresponding value holds\n        the arguments necessary to execute this function\n    name : str\n        the name of the logger to use;\n        if None: the last logger will be used\n\n    Raises\n    ------\n    AssertionError\n        if the logger with the specified name does not exist\n    AssertionError\n        if the returned object is not a logger\n\n    Returns\n    -------\n    Any\n        the value obtained by the loggers ``log`` function\n\n    \"\"\"\n\n    # use last name if no name is present\n    if name is None:\n        name = get_available_loggers()[-1]\n\n    assert logger_exists(name)\n    _logger = get_logger(name)\n\n    assert isinstance(_logger, Logger)\n\n    return _logger.log(msg)\n\n\ndef logger_exists(name: str):\n    \"\"\"\n    Check if logger exists\n\n    Parameters\n    ----------\n    name : str\n        the name to check the existence for\n\n    Returns\n    -------\n    bool\n        whether a logger with the given name exists\n\n    \"\"\"\n    return name in _AVAILABLE_LOGGERS\n\n\ndef register_logger(logger: Logger, name: str, overwrite=False):\n    \"\"\"\n    Register a new logger to the Registry\n\n    Parameters\n    ----------\n    logger : :class:`delira.logging.base_logger.Logger`\n        the logger to register\n    name : str\n        the corresponding name, to register the logger at\n    overwrite : bool\n        whether or not to overwrite existing loggers if necessary\n\n    Returns\n    
-------\n    :class:`delira.logging.base_logger.Logger`\n        the registered logger object\n\n    \"\"\"\n\n    if not logger_exists(name) or overwrite:\n        _AVAILABLE_LOGGERS[name] = logger\n\n    return get_logger(name)\n\n\ndef unregister_logger(name: str):\n    \"\"\"\n    Unregisters a logger from the registry\n\n    Parameters\n    ----------\n    name : str\n        the name of the logger to unregister\n\n    Returns\n    -------\n    :class:`delira.logging.base_logger.Logger`\n        the registered logger object\n    \"\"\"\n    return _AVAILABLE_LOGGERS.pop(name)\n\n\ndef get_logger(name):\n    \"\"\"\n    Returns a logger from the registry\n\n    Parameters\n    ----------\n    name : str\n        the name indicating the logger to return\n\n    Returns\n    -------\n    :class:`delira.logging.base_logger.Logger`\n        the specified logger object\n\n    \"\"\"\n    return _AVAILABLE_LOGGERS[name]\n\n\ndef get_available_loggers():\n    \"\"\"\n    Gets names for all registered loggers\n\n    Returns\n    -------\n    tuple\n        a tuple of strings specifying the names of all registered loggers\n\n    \"\"\"\n    return tuple(_AVAILABLE_LOGGERS.keys())\n"
  },
  {
    "path": "delira/logging/tensorboard_backend.py",
    "content": "from threading import Event\nfrom queue import Queue\n\nfrom delira.logging.writer_backend import WriterLoggingBackend\n\n# use torch SummaryWriter if possible, since this one has latest pytorch\n# capabilities\ntry:\n    from torch.utils.tensorboard import SummaryWriter\n    LOGDIR_KWARG = \"log_dir\"\nexcept ImportError:\n    from tensorboardX import SummaryWriter\n    LOGDIR_KWARG = \"logdir\"\n\n\nclass TensorboardBackend(WriterLoggingBackend):\n    \"\"\"\n    A Tensorboard logging backend\n    \"\"\"\n\n    def __init__(self, writer_kwargs=None,\n                 abort_event: Event = None, queue: Queue = None):\n        \"\"\"\n\n        Parameters\n        ----------\n        writer_kwargs : dict\n            arguments to initialize a writer\n        abort_event : :class:`threading.Event`\n            the abortion event\n        queue : :class:`queue.Queue`\n            the queue holding all logging tasks\n        \"\"\"\n\n        if writer_kwargs is None:\n            writer_kwargs = {}\n\n        if \"logdir\" in writer_kwargs:\n            writer_kwargs[LOGDIR_KWARG] = writer_kwargs.pop(\"logdir\")\n        elif \"log_dir\" in writer_kwargs:\n            writer_kwargs[LOGDIR_KWARG] = writer_kwargs.pop(\"log_dir\")\n\n        super().__init__(SummaryWriter, writer_kwargs,\n                         abort_event, queue)\n\n    def _call_exec_fn(self, exec_fn, args):\n        \"\"\"\n        Helper Function calling the actual mapped function and flushing\n        results to the writer afterwards\n\n        Parameters\n        ----------\n        exec_fn : function\n            the function which will execute the actual logging\n        args : iterable (listlike) or mapping (dictlike)\n            the arguments passed to the ``exec_fn``\n\n        Returns\n        -------\n        Any\n            the return value obtained by the ``exec_fn``\n\n        \"\"\"\n        ret_val = super()._call_exec_fn(exec_fn, args)\n\n        
self._writer.file_writer.flush()\n\n        return ret_val\n\n    def __del__(self):\n        \"\"\"\n        Function to be executed at deletion;\n        Flushes all unsaved changes\n\n        \"\"\"\n        self._writer.file_writer.flush()\n\n    def _graph_pytorch(self, model, input_to_model=None, verbose=False,\n                       **kwargs):\n        \"\"\"\n        Function to log a PyTorch graph\n\n        Parameters\n        ----------\n        model : :class:`AbstractPyTorchNetwork`\n            the model, whose graph shall be logged\n        input_to_model : :class:`torch.Tensor`\n            the input to the model; necessary for graph traversal\n        verbose : bool\n            verbosity option\n        **kwargs :\n            additional keyword arguments\n\n        \"\"\"\n        converted_args, converted_kwargs = self.convert_to_npy(\n            model=model, input_to_model=input_to_model,\n            verbose=verbose, **kwargs)\n\n        self._writer.add_graph(*converted_args, **converted_kwargs)\n\n    def _graph_tf(self, graph, run_metadata=None):\n        \"\"\"\n        Function to log a TensorFlow Graph\n\n        Parameters\n        ----------\n        graph : :class:`tensorflow.Graph` or :class:`tensorflow.GraphDef`\n        run_metadata :\n            the run metadata\n\n        Raises\n        ------\n        TypeError\n            if given graph cannot be converted to graphdef\n\n        \"\"\"\n        import tensorflow as tf\n        from tensorboardX.proto.event_pb2 import Event, TaggedRunMetadata\n\n        # convert to graphdef\n        if isinstance(graph, tf.Graph):\n            graphdef = graph.as_graph_def()\n        elif isinstance(graph, tf.GraphDef):\n            graphdef = graph\n        elif hasattr(graph, \"SerializeToString\"):\n            graphdef = graph\n        else:\n            raise TypeError(\"Invalid type given for graph: %s\" %\n                            graph.__class__.__name__)\n\n        if 
run_metadata:\n            run_metadata = TaggedRunMetadata(\n                tag='step1', run_metadata=run_metadata.SerializeToString())\n\n        self._writer._get_file_writer().add_event(\n            Event(\n                graph_def=graphdef.SerializeToString(),\n                tagged_run_metadata=run_metadata))\n\n    def _graph_onnx(self, prototxt):\n        \"\"\"\n        Function to log a ONNX graph to file\n\n        Parameters\n        ----------\n        prototxt : str\n            filepath to a given prototxt file containing an ONNX graph\n\n        \"\"\"\n        converted_args, converted_kwargs = self.convert_to_npy(\n            prototxt=prototxt)\n        self._writer.add_onnx_graph(*converted_args, **converted_kwargs)\n\n    def _embedding(self, mat, metadata=None, label_img=None, global_step=None,\n                   tag='default', metadata_header=None):\n        \"\"\"\n        Function to create an embedding of given data\n\n        Parameters\n        ----------\n        mat : array-like\n            an arraylike object, which can be converted to a numpy array;\n            holds the actual embedding value\n        metadata :\n            the embeddings metadata\n        label_img : array-like\n            an arraylike object, which can be converted to a numpy array;\n            holds the label image\n        global_step : int\n            the global step\n        tag : str\n            the tag to store the embedding at\n        metadata_header :\n            the metadata header\n\n        \"\"\"\n        converted_args, converted_kwargs = self.convert_to_npy(\n            mat=mat, metadata=metadata, label_img=label_img,\n            global_step=global_step\n        )\n        self._writer.add_embedding(*converted_args, **converted_kwargs)\n\n    def _scalars(self, main_tag: str, tag_scalar_dict: dict, global_step=None,\n                 walltime=None, sep=\"/\"):\n        \"\"\"\n        Function to log multiple scalars at once. 
Unlike the base\n        function, this is done sequentially rather than in parallel to avoid\n        creating new event files\n\n        Parameters\n        ----------\n        main_tag : str\n            the main tag, will be combined with the subtags inside the\n            ``tag_scalar_dict``\n        tag_scalar_dict : dict\n            dictionary of (key, scalar) pairs\n        global_step : int\n            the global step\n        walltime :\n            the overall time\n        sep : str\n            the character separating main tag and subtag in the final tag\n\n        \"\"\"\n\n        # log scalars sequentially\n        for key, val in tag_scalar_dict.items():\n            # combine tags\n            new_tag = main_tag + sep + key\n            self._scalar(new_tag, val, global_step=global_step,\n                         walltime=walltime)\n\n    @property\n    def name(self):\n        return \"TensorboardBackend\"\n"
  },
  {
    "path": "delira/logging/visdom_backend.py",
    "content": "import tensorboardX\nfrom threading import Event\nfrom queue import Queue\n\nfrom delira.logging.writer_backend import WriterLoggingBackend\n\n\nclass VisdomBackend(WriterLoggingBackend):\n    \"\"\"\n    A Visdom Logging backend\n    \"\"\"\n\n    def __init__(self, writer_kwargs: dict = None,\n                 abort_event: Event = None, queue: Queue = None):\n        \"\"\"\n\n        Parameters\n        ----------\n        writer_kwargs : dict\n            arguments to initialize a writer\n        abort_event : :class:`threading.Event`\n            the abortion event\n        queue : :class:`queue.Queue`\n            the queue holding all logging tasks\n        \"\"\"\n\n        if writer_kwargs is None:\n            writer_kwargs = {}\n\n        super().__init__(\n            tensorboardX.visdom_writer.VisdomWriter,\n            writer_kwargs,\n            abort_event,\n            queue)\n\n    @property\n    def name(self):\n        return \"VisdomBackend\"\n"
  },
  {
    "path": "delira/logging/writer_backend.py",
    "content": "\nfrom delira.logging.base_backend import BaseBackend\nfrom queue import Queue\nfrom threading import Event\n\n\nclass WriterLoggingBackend(BaseBackend):\n    \"\"\"\n    A Basic Writer Backend for a unspecified writer class\n    \"\"\"\n\n    def __init__(self, writer_cls, writer_kwargs: dict,\n                 abort_event: Event = None, queue: Queue = None):\n        super().__init__(abort_event, queue)\n\n        self._writer = writer_cls(**writer_kwargs)\n\n    @staticmethod\n    def convert_to_npy(*args, **kwargs):\n        \"\"\"\n        Function to convert all positional args and keyword args to numpy\n        (returns identity per default, but can be overwritten in subclass to\n        log more complex types)\n\n        Parameters\n        ----------\n        *args :\n            positional arguments of arbitrary number and type\n        **kwargs :\n            keyword arguments of arbitrary number and type\n        Returns\n        -------\n        tuple\n            converted positional arguments\n        dict\n            converted keyword arguments\n        \"\"\"\n        return args, kwargs\n\n    def _image(self, tag, img_tensor, global_step=None, walltime=None,\n               dataformats='CHW'):\n        \"\"\"\n        Function to log a single image\n\n        Parameters\n        ----------\n        tag : str\n            the tag to store the image at\n        img_tensor : array-like\n            an array-like object containing the actual image; Must be\n            convertible to numpy\n        global_step : int\n            the global step\n        walltime :\n            the overall time\n        dataformats : str\n            string specifying the image format\n\n        \"\"\"\n        converted_args, converted_kwargs = self.convert_to_npy(\n            tag=tag, img_tensor=img_tensor, global_step=global_step,\n            walltime=walltime, dataformats=dataformats)\n\n        self._writer.add_image(*converted_args, 
**converted_kwargs)\n\n    def _images(self, tag, img_tensor, global_step=None, walltime=None,\n                dataformats='NCHW'):\n        \"\"\"\n        Function to log multiple images\n\n        Parameters\n        ----------\n        tag : str\n            the tag to store the images at\n        img_tensor : array-like\n            an array-like object containing the actual images; Must be\n            convertible to numpy\n        global_step : int\n            the global step\n        walltime :\n            the overall time\n        dataformats : str\n            string specifying the image format\n\n        \"\"\"\n        converted_args, converted_kwargs = self.convert_to_npy(\n            tag=tag, img_tensor=img_tensor, global_step=global_step,\n            walltime=walltime, dataformats=dataformats)\n\n        self._writer.add_images(*converted_args, **converted_kwargs)\n\n    def _image_with_boxes(self, tag, img_tensor, box_tensor, global_step=None,\n                          walltime=None, dataformats='CHW', **kwargs):\n        \"\"\"\n        Function to log a single image with bounding boxes\n\n        Parameters\n        ----------\n        tag : str\n            the tag to store the image at\n        img_tensor : array-like\n            an array-like object containing the actual image; Must be\n            convertible to numpy\n        box_tensor : array-like\n            an array-like object containing the actual bounding boxes in xyxy\n            format; must be convertible to numpy\n        global_step : int\n            the global step\n        walltime :\n            the overall time\n        dataformats : str\n            string specifying the image format\n        **kwargs :\n            additional keyword arguments\n\n        \"\"\"\n        converted_args, converted_kwargs = self.convert_to_npy(\n            tag=tag, img_tensor=img_tensor, box_tensor=box_tensor,\n            global_step=global_step, walltime=walltime,\n            
dataformats=dataformats, **kwargs)\n\n        self._writer.add_image_with_boxes(*converted_args, **converted_kwargs)\n\n    def _scalar(self, tag, scalar_value, global_step=None, walltime=None):\n        \"\"\"\n        Function to log a single scalar value\n\n        Parameters\n        ----------\n        tag : str\n            the tag to store the scalar at\n        scalar_value : int or float\n            the scalar value to log\n        global_step : int\n            the global step\n        walltime :\n            the overall time\n\n        \"\"\"\n        converted_args, converted_kwargs = self.convert_to_npy(\n            tag=tag, scalar_value=scalar_value, global_step=global_step,\n            walltime=walltime)\n        self._writer.add_scalar(*converted_args, **converted_kwargs)\n\n    def _scalars(self, main_tag, tag_scalar_dict, global_step=None,\n                 walltime=None):\n        \"\"\"\n        Function to log multiple scalars\n\n        Parameters\n        ----------\n        main_tag : str\n            the main tag to store the scalars at\n        tag_scalar_dict : dict\n            a dictionary containing tags as keys and the corresponding scalar\n            values\n        global_step : int\n            the global step\n        walltime :\n            the overall time\n\n        \"\"\"\n        converted_args, converted_kwargs = self.convert_to_npy(\n            main_tag=main_tag, tag_scalar_dict=tag_scalar_dict,\n            global_step=global_step, walltime=walltime)\n\n        self._writer.add_scalars(*converted_args, **converted_kwargs)\n\n    def _histogram(self, tag, values, global_step=None, bins='tensorflow',\n                   walltime=None):\n        \"\"\"\n        Function to create and log a histogram out of given values\n\n        Parameters\n        ----------\n        tag : str\n            the tag to store the histogram at\n        values : arraylike\n            an arraylike object containing the raw data to create a 
histogram\n            from; Must be convertible to numpy\n        global_step : int\n            global step\n        bins : str\n            string indicating the bins format\n        walltime :\n            the overall time\n\n        \"\"\"\n        converted_args, converted_kwargs = self.convert_to_npy(\n            tag=tag, values=values, global_step=global_step, bins=bins,\n            walltime=walltime)\n        self._writer.add_histogram(*converted_args, **converted_kwargs)\n\n    def _figure(self, tag, figure, global_step=None, close=True,\n                walltime=None):\n        \"\"\"\n        Function to log a ``matplotlib.pyplot`` figure\n\n        Parameters\n        ----------\n        tag : str\n            the tag to store the figure at\n        figure : :class:`matplotlib.pyplot.Figure`\n            the figure to log\n        global_step : int\n            the global step\n        close : bool\n            whether to close the figure after pushing it\n        walltime :\n            the overall time\n\n        \"\"\"\n        converted_args, converted_kwargs = self.convert_to_npy(\n            tag=tag, figure=figure, global_step=global_step, close=close,\n            walltime=walltime)\n        self._writer.add_figure(*converted_args, **converted_kwargs)\n\n    def _audio(self, tag, snd_tensor, global_step=None, sample_rate=44100,\n               walltime=None):\n        \"\"\"\n        Function to log a single audio signal\n\n        Parameters\n        ----------\n        tag : str\n            the tag to store the sound signal at\n        snd_tensor : arraylike\n            arraylike object containing the sound signal;\n            must be convertible to numpy\n        global_step : int\n            the global step\n        sample_rate : int\n            the sampling rate for the sound signal\n        walltime :\n            the overall time\n\n        \"\"\"\n        converted_args, converted_kwargs = self.convert_to_npy(\n            tag=tag, snd_tensor=snd_tensor, 
global_step=global_step,\n            sample_rate=sample_rate, walltime=walltime\n        )\n        self._writer.add_audio(*converted_args, **converted_kwargs)\n\n    def _text(self, tag, text_string, global_step=None, walltime=None):\n        \"\"\"\n        Function to log a single string as text\n\n        Parameters\n        ----------\n        tag : str\n            the tag to store the text at\n        text_string : str\n            the text string to log\n        global_step : int\n            the global step\n        walltime :\n            the overall time\n\n        \"\"\"\n        converted_args, converted_kwargs = self.convert_to_npy(\n            tag=tag, text_string=text_string, global_step=global_step,\n            walltime=walltime)\n        self._writer.add_text(*converted_args, **converted_kwargs)\n\n    def _pr_curve(self, tag, labels, predictions, global_step=None,\n                  num_thresholds=127, weights=None, walltime=None):\n        \"\"\"\n        Function to create and log a PR curve out of given predictions and\n        labels\n\n        Parameters\n        ----------\n        tag : str\n            the tag to store the curve at\n        labels : arraylike\n            arraylike object containing the groundtruth data; must be\n            convertible to numpy\n        predictions : arraylike\n            arraylike object containing the predictions; must be convertible\n            to numpy\n        global_step : int\n            the global step\n        num_thresholds : int\n            number of thresholds to apply for PR calculation\n        weights : arraylike\n            arraylike object containing sample weights, must be convertible to\n            numpy\n        walltime :\n            overall time\n\n        \"\"\"\n        converted_args, converted_kwargs = self.convert_to_npy(\n            tag=tag, labels=labels, predictions=predictions,\n            global_step=global_step, num_thresholds=num_thresholds,\n            
weights=weights, walltime=walltime)\n        self._writer.add_pr_curve(*converted_args, **converted_kwargs)\n\n    def _video(self, tag, vid_tensor, global_step=None, fps=4, walltime=None):\n        \"\"\"\n        Function to log a single video\n\n        Parameters\n        ----------\n        tag : str\n            the tag to store the video at\n        vid_tensor : arraylike\n            arraylike object containing the video frames; must be convertible\n            to numpy\n        global_step : int\n            the global step\n        fps : int\n            frames per second to display\n        walltime :\n            the overall time\n\n        \"\"\"\n        converted_args, converted_kwargs = self.convert_to_npy(\n            tag=tag, vid_tensor=vid_tensor, global_step=global_step, fps=fps,\n            walltime=walltime)\n        self._writer.add_video(*converted_args, **converted_kwargs)\n\n    @property\n    def name(self):\n        return \"WriterBackend\"\n"
  },
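The `convert_to_npy` hook above is the single place a subclass has to touch to support richer tensor types. The following is a minimal, stand-alone sketch of such an override; `_NpyConversionMixin` and its stand-in logic are illustrative names invented for this example, not part of delira:

```python
import numpy as np


class _NpyConversionMixin:
    """Hypothetical sketch of a convert_to_npy override.

    In delira you would override the static method on a
    WriterLoggingBackend subclass; this stand-alone mixin only
    demonstrates the (args, kwargs) -> (args, kwargs) contract.
    """

    @staticmethod
    def convert_to_npy(*args, **kwargs):
        def _convert(value):
            # coerce nested lists/tuples and arrays to numpy, leave
            # everything else (tags, steps, format strings) untouched
            if isinstance(value, (list, tuple, np.ndarray)):
                return np.asarray(value)
            return value

        return (tuple(_convert(a) for a in args),
                {k: _convert(v) for k, v in kwargs.items()})


args, kwargs = _NpyConversionMixin.convert_to_npy(
    tag="img", img_tensor=[[0.0, 1.0], [1.0, 0.0]], global_step=3)
```

Because every `_image`/`_scalar`/`_histogram` method routes its arguments through this hook before calling the writer, one override like this covers all logging methods at once.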
  {
    "path": "delira/models/__init__.py",
    "content": "from delira.models.abstract_network import AbstractNetwork\nfrom delira.models.backends import *\n"
  },
  {
    "path": "delira/models/abstract_network.py",
    "content": "import abc\nimport logging\n\nfile_logger = logging.getLogger(__name__)\n\n\nclass AbstractNetwork(object):\n    \"\"\"\n    Abstract class all networks should be derived from\n\n    \"\"\"\n\n    _init_kwargs = {}\n\n    @abc.abstractmethod\n    def __init__(self, **kwargs):\n        \"\"\"\n        Init function to register init kwargs (should be called from all\n        subclasses)\n\n        Parameters\n        ----------\n        **kwargs\n            keyword arguments (will be registered to `self.init_kwargs`)\n\n        \"\"\"\n        super().__init__()\n        for key, val in kwargs.items():\n            self._init_kwargs[key] = val\n\n    @abc.abstractmethod\n    def __call__(self, *args, **kwargs):\n        \"\"\"\n        AbstractMethod to specify that each model should be able to be called\n        for predictions\n\n        Parameters\n        ----------\n        *args :\n            Positional arguments\n        **kwargs :\n            Keyword Arguments\n\n        Raises\n        ------\n        NotImplementedError\n            if not overwritten by subclass\n\n        \"\"\"\n        raise NotImplementedError()\n\n    @staticmethod\n    @abc.abstractmethod\n    def closure(model, data_dict: dict, optimizers: dict, losses: dict,\n                iter_num: int, fold=0, **kwargs):\n        \"\"\"\n        Function which handles prediction from batch, logging, loss calculation\n        and optimizer step\n\n        Parameters\n        ----------\n        model : :class:`AbstractNetwork`\n            model to forward data through\n        data_dict : dict\n            dictionary containing the data\n        optimizers : dict\n            dictionary containing all optimizers to perform parameter update\n        losses : dict\n            Functions or classes to calculate losses\n        iter_num: int\n            the number of of the current iteration in the current epoch;\n            Will be restarted at zero at the beginning of every 
epoch\n        fold : int\n            Current Fold in Crossvalidation (default: 0)\n        kwargs : dict\n            additional keyword arguments\n\n        Returns\n        -------\n        dict\n            Loss values (with same keys as input dict losses)\n        dict\n            Arbitrary number of predictions\n\n        Raises\n        ------\n        NotImplementedError\n            If not overwritten by subclass\n\n        \"\"\"\n        raise NotImplementedError()\n\n    @staticmethod\n    def prepare_batch(batch: dict, input_device, output_device):\n        \"\"\"\n        Converts a numpy batch of data and labels to suitable datatype and\n        pushes them to correct devices\n\n        Parameters\n        ----------\n        batch : dict\n            dictionary containing the batch (must have keys 'data' and 'label')\n        input_device :\n            device for network inputs\n        output_device :\n            device for network outputs\n\n        Returns\n        -------\n        dict\n            dictionary containing all necessary data in right format and type\n            and on the correct device\n\n        Raises\n        ------\n        NotImplementedError\n            If not overwritten by subclass\n\n        \"\"\"\n\n        raise NotImplementedError()\n\n    @property\n    def init_kwargs(self):\n        \"\"\"\n        Returns all arguments registered as init kwargs\n\n        Returns\n        -------\n        dict\n            init kwargs\n\n        \"\"\"\n        return self._init_kwargs\n"
  },
  {
    "path": "delira/models/backends/__init__.py",
    "content": "from delira import get_backends as _get_backends\n\nif \"CHAINER\" in _get_backends():\n    from delira.models.backends.chainer import *\n\nif \"SKLEARN\" in _get_backends():\n    from delira.models.backends.sklearn import *\n\nif \"TF\" in _get_backends():\n    from delira.models.backends.tf_eager import *\n    from delira.models.backends.tf_graph import *\n\nif \"TORCH\" in _get_backends():\n    from delira.models.backends.torch import *\n    from delira.models.backends.torchscript import *\n"
  },
  {
    "path": "delira/models/backends/chainer/__init__.py",
    "content": "from delira import get_backends as _get_backends\n\nif \"CHAINER\" in _get_backends():\n    from delira.models.backends.chainer.abstract_network import \\\n        AbstractChainerNetwork\n    from delira.models.backends.chainer.data_parallel import \\\n        DataParallelChainerNetwork\n    from delira.models.backends.chainer.data_parallel import \\\n        DataParallelChainerOptimizer\n    from delira.models.backends.chainer.data_parallel import \\\n        ParallelOptimizerUpdateModelParameters\n    from delira.models.backends.chainer.data_parallel import \\\n        ParallelOptimizerCumulateGradientsHook\n"
  },
  {
    "path": "delira/models/backends/chainer/abstract_network.py",
    "content": "import abc\nimport chainer\nimport numpy as np\n\nfrom delira.models.abstract_network import AbstractNetwork\n\n\n# Use this Mixin Class to set __call__ to None, because there is an\n# internal check inside chainer.Link.__call__ for other __call__ methods\n# of parent classes to be not None. If this would be the case,\n# this function would be executed instead of our forward\nclass ChainerMixin(AbstractNetwork):\n    __call__ = None\n\n\nclass AbstractChainerNetwork(chainer.Chain, ChainerMixin):\n    \"\"\"\n    Abstract Class for Chainer Networks\n    \"\"\"\n\n    def __init__(self, **kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        **kwargs :\n            keyword arguments of arbitrary number and type\n            (will be registered as ``init_kwargs``)\n\n        \"\"\"\n        chainer.Chain.__init__(self)\n        AbstractNetwork.__init__(self, **kwargs)\n\n    @abc.abstractmethod\n    def forward(self, *args, **kwargs) -> dict:\n        \"\"\"\n        Feeds Arguments through the network\n\n        Parameters\n        ----------\n        *args :\n            positional arguments of arbitrary number and type\n        **kwargs :\n            keyword arguments of arbitrary number and type\n\n        Returns\n        -------\n        dict\n            dictionary containing all computation results\n\n        \"\"\"\n        raise NotImplementedError\n\n    def __call__(self, *args, **kwargs) -> dict:\n        \"\"\"\n        Makes instances of this class callable.\n        Calls the ``forward`` method.\n\n        Parameters\n        ----------\n        *args :\n            positional arguments of arbitrary number and type\n        **kwargs :\n            keyword arguments of arbitrary number and type\n\n        Returns\n        -------\n        dict\n            dictionary containing all computation results\n\n        \"\"\"\n\n        return chainer.Chain.__call__(self, *args, **kwargs)\n\n    @staticmethod\n    def 
prepare_batch(batch: dict, input_device, output_device):\n        \"\"\"\n        Helper Function to prepare Network Inputs and Labels (convert them\n        to correct type and shape and push them to correct devices)\n\n        Parameters\n        ----------\n        batch : dict\n            dictionary containing all the data\n        input_device : chainer.backend.Device or string\n            device for network inputs\n        output_device : chainer.backend.Device or string\n            device for network outputs\n\n        Returns\n        -------\n        dict\n            dictionary containing data in correct type and shape and on\n            correct device\n\n        \"\"\"\n        new_batch = {k: chainer.as_variable(v.astype(np.float32))\n                     for k, v in batch.items()}\n\n        for k, v in new_batch.items():\n            if k == \"data\":\n                device = input_device\n            else:\n                device = output_device\n\n            # makes modification inplace!\n            v.to_device(device)\n\n        return new_batch\n\n    @staticmethod\n    def closure(model, data_dict: dict, optimizers: dict, losses: dict,\n                iter_num, fold=0, **kwargs):\n        \"\"\"\n        default closure method to do a single training step;\n        Could be overwritten for more advanced models\n\n        Parameters\n        ----------\n        model : :class:`AbstractChainerNetwork`\n            trainable model\n        data_dict : dict\n            dictionary containing the data\n        optimizers : dict\n            dictionary of optimizers to optimize model's parameters\n        losses : dict\n            dict holding the losses to calculate errors\n        iter_num: int\n            the number of the current iteration in the current epoch;\n            Will be restarted at zero at the beginning of every 
epoch\n        fold : int\n            Current Fold in Crossvalidation (default: 0)\n        **kwargs:\n            additional keyword arguments\n\n        Returns\n        -------\n        dict\n            Loss values (with same keys as input dict losses)\n        dict\n            dictionary containing all predictions\n\n        \"\"\"\n        assert (optimizers and losses) or not optimizers, \\\n            \"Criterion dict cannot be empty if optimizers are passed\"\n\n        loss_vals = {}\n        total_loss = 0\n\n        inputs = data_dict[\"data\"]\n        preds = model(inputs)\n\n        for key, crit_fn in losses.items():\n            _loss_val = crit_fn(preds[\"pred\"], data_dict[\"label\"])\n            loss_vals[key] = _loss_val.item()\n            total_loss += _loss_val\n\n        model.cleargrads()\n        total_loss.backward()\n        optimizers['default'].update()\n        for k, v in preds.items():\n            v.unchain()\n        return loss_vals, preds\n"
  },
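The closure contract above (forward the batch, accumulate every criterion in `losses`, step the optimizer, return `(loss_vals, preds)`) can be sketched without any framework. All names below are illustrative stand-ins; the real chainer implementation additionally calls `cleargrads()`, `total_loss.backward()` and `optimizers['default'].update()` on framework objects:

```python
def closure_sketch(model, data_dict, losses):
    """Backend-agnostic sketch of the closure's control flow."""
    # forward pass: models return a dict of predictions
    preds = model(data_dict["data"])

    # accumulate every loss criterion over the same prediction/label pair
    loss_vals, total_loss = {}, 0.0
    for name, crit_fn in losses.items():
        _loss_val = crit_fn(preds["pred"], data_dict["label"])
        loss_vals[name] = _loss_val
        total_loss += _loss_val

    # a real closure performs the gradient step on ``total_loss`` here
    return loss_vals, preds


# toy stand-ins: a "model" doubling its input and a plain-Python MSE
_model = lambda x: {"pred": [2.0 * xi for xi in x]}
_mse = lambda p, y: sum((pi - yi) ** 2 for pi, yi in zip(p, y)) / len(p)

loss_vals, preds = closure_sketch(
    _model, {"data": [1.0, 2.0], "label": [2.0, 4.0]}, {"mse": _mse})
# loss_vals == {"mse": 0.0}
```

Keeping this shape identical across backends is what lets delira's trainers call any network's closure without knowing which framework is underneath.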
  {
    "path": "delira/models/backends/chainer/data_parallel.py",
    "content": "from delira.models.backends.chainer.abstract_network import \\\n    AbstractChainerNetwork\nimport chainer\n\n\ndef _apply_scatter(inputs: chainer.Variable, target_devices: list,\n                   dim: int = 0):\n    \"\"\"\n    Scatters inputs to target devices; Slicing will be done against a given\n    dimension\n\n    Parameters\n    ----------\n    inputs : :class:`chainer.Variable`\n        the input variable to scatter\n    target_devices : list\n        the target devices to scatter to\n    dim : int\n        the dimension to use for slicing\n\n    Returns\n    -------\n    list\n        list of variable slices on correct devices\n\n    \"\"\"\n\n    def _slice_inputs(input_var, dim, num_dims, start, end, target_device):\n        \"\"\"\n        Slices the input variable along a given dimension from start to end\n        and pushes it to correct device\n\n        Parameters\n        ----------\n        input_var : :class:`chainer.Variable`\n            the variable to slice\n        dim : int\n            the dimension to slice along\n        num_dims : int\n            the dimensionality of ``input_var``\n        start : int\n            the start value for slicing (included)\n        end : int\n            the end value for slicing (excluded)\n        target_device: str or :class:`chainer.backend.Device`\n            the device to push to\n\n        Returns\n        -------\n        :class:`chainer.Variable`\n            the slice of the variable\n\n        \"\"\"\n        slc = [slice(None)] * num_dims\n        slc[dim] = slice(start, end)\n        sliced_var = input_var[slc]\n        sliced_var.to_device(target_device)\n        output_shape = list(input_var.shape)\n        output_shape[dim] = -1\n        return sliced_var.reshape(output_shape)\n\n    # create empty sliced input list\n    scattered_inputs = []\n\n    # calculate constant only once\n    num_devices = len(target_devices)\n    samples_per_device = inputs.shape[dim] // 
num_devices\n    num_dims = len(inputs.shape)\n\n    # iterate over number of devices and slice accordingly\n    # (exclude last device)\n    # iterating until min(num_devices, inputs.shape[dim]) - 1\n    # ensures that if the batchsize is too small to be scattered across all\n    # devices, we will only scatter across as many devices as possible\n    for i in range(min(num_devices, inputs.shape[dim]) - 1):\n        start, end = i * samples_per_device, (i + 1) * samples_per_device\n        scattered_inputs.append(_slice_inputs(inputs, dim,\n                                              num_dims, start, end,\n                                              target_devices[i]))\n\n    # all remaining samples (not yet sliced) are appended now\n    # (all samples used; will be pushed to last device later)\n    scattered_inputs.append(_slice_inputs(\n        inputs, dim, num_dims,\n        (num_devices - 1) * samples_per_device,\n        inputs.shape[dim], target_devices[-1]))\n\n    return scattered_inputs\n\n\ndef _apply_gather(target_device, dim, *outputs):\n    for _output in outputs:\n        _output.to_device(target_device)\n\n    return chainer.functions.concat(outputs, dim)\n\n\ndef _scatter(inputs, target_devices: list, dim):\n    \"\"\"\n    Scatters all inputs across given target_devices\n\n    Parameters\n    ----------\n    inputs : Any\n    target_devices : list\n        list of devices to scatter to\n    dim : int\n        dimension to use for slicing\n\n    Returns\n    -------\n    list\n        list of scattered inputs\n\n    \"\"\"\n\n    def _scatter_map(inputs):\n        \"\"\"\n        Scatters all inputs across given target_devices\n\n        Parameters\n        ----------\n        inputs : Any\n\n        Returns\n        -------\n        list\n            list of scattered inputs\n\n        \"\"\"\n\n        # directly apply the scattering on variable\n        if isinstance(inputs, chainer.Variable):\n            return 
_apply_scatter(inputs, target_devices, dim)\n\n        # map _scatter_map recursively to all samples in tuple\n        if isinstance(inputs, tuple) and inputs:\n            return list(zip(*map(_scatter_map, inputs)))\n\n        # map _scatter_map recursively to all samples in list\n        if isinstance(inputs, list) and inputs:\n            return list(map(list, zip(*map(_scatter_map,\n                                           inputs))))\n\n        # map _scatter_map recursively to all samples in dict\n        if isinstance(inputs, dict) and inputs:\n            return list(map(type(inputs), zip(*map(_scatter_map,\n                                                   inputs.items()))))\n\n        # try to convert inputs to chainer variable first and afterwards\n        # apply _scatter_map again\n\n        try:\n            return _scatter_map(chainer.as_variable(inputs))\n        except TypeError:\n            return [inputs for targets in target_devices]\n\n    # After scatter_map is called, a scatter_map cell will exist. This cell\n    # has a reference to the actual function scatter_map, which has\n    # references to a closure that has a reference to the scatter_map cell\n    # (because the fn is recursive). 
To avoid this reference cycle, we set\n    # the function to None, clearing the cell\n\n    try:\n        return _scatter_map(inputs)\n    finally:\n        _scatter_map = None\n\n\ndef _gather(outputs, target_device, dim=0):\n    r\"\"\"\n    Gathers variables from different devices onto a specified target device.\n    \"\"\"\n\n    def gather_map(outputs):\n        out = outputs[0]\n        if isinstance(out, chainer.Variable):\n            return _apply_gather(target_device, dim, *outputs)\n        if out is None:\n            return None\n        if isinstance(out, dict):\n            if not all((len(out) == len(d) for d in outputs)):\n                raise ValueError(\n                    'All dicts must have the same number of keys')\n\n            return type(out)(((k, gather_map([d[k] for d in outputs]))\n                              for k in out))\n        return type(out)(map(gather_map, zip(*outputs)))\n\n    # Recursive function calls like this create reference cycles.\n    # Setting the function to None clears the refcycle.\n    try:\n        return gather_map(outputs)\n    finally:\n        gather_map = None\n\n\nclass DataParallelChainerNetwork(AbstractChainerNetwork):\n    \"\"\"\n    A Wrapper around a :class:`AbstractChainerNetwork` instance to implement\n    parallel training by splitting the batches\n    \"\"\"\n\n    def __init__(self, module: AbstractChainerNetwork, devices: list,\n                 output_device=None,\n                 batch_dim=0):\n        \"\"\"\n\n        Parameters\n        ----------\n        module : :class:`AbstractChainerNetwork`\n            the module to wrap (will be replicated on all devices)\n        devices : list\n            a list containing the devices to use (either as strings or as\n            :class:`chainer.backend.Device`).\n        output_device : str or :class:`chainer.backend.Device`\n            The output device.\n            Make sure your labels are also on this device\n            
for loss calculation!\n            If not specified, the second device of ``devices`` will be used\n            for output gathering.\n        batch_dim : int\n            the index of the batch dimension (usually 0, but can become\n            e.g. 1 in NLP tasks)\n\n        \"\"\"\n        super().__init__()\n\n        modules = [module.copy() for _ in devices]\n\n        for _module, _device in zip(modules, devices):\n            _module.to_device(_device)\n\n        with self.init_scope():\n            self.modules = chainer.ChainList(*modules)\n\n        self.devices = devices\n\n        if output_device is None:\n            output_device = devices[1]\n\n        self._output_device = output_device\n        assert self._output_device in self.devices\n        self._output_device_idx = self.devices.index(self._output_device)\n        self.dim = batch_dim\n\n    def forward(self, *args, **kwargs):\n        \"\"\"\n        Scatters the inputs (both positional and keyword arguments) across\n        all devices, feeds them through model replicas and re-builds\n        batches on output device\n\n        Parameters\n        ----------\n        *args :\n            positional arguments of arbitrary number and type\n        **kwargs :\n            keyword arguments of arbitrary number and type\n\n        Returns\n        -------\n        Any\n            combined output from all scattered models\n\n        \"\"\"\n        scattered_args, scattered_kwargs = self._scatter(args, kwargs,\n                                                         self.devices,\n                                                         self.dim)\n        predictions = []\n\n        for _args, _kwargs, _module in zip(scattered_args,\n                                           scattered_kwargs,\n                                           self.modules):\n\n            predictions.append(_module(*_args, **_kwargs))\n\n        predictions = self._gather(predictions, self.dim,\n                       
            self._output_device)\n\n        return predictions\n\n    def params(self, include_uninit=True):\n        \"\"\"\n        Only the parameters of the module on the first device will actually\n        be updated, all the other parameters will be replicated by the\n        optimizer after an update\n\n        Parameters\n        ----------\n        include_uninit : bool\n\n        Returns\n        -------\n        a generator holding the root module's parameters\n        \"\"\"\n        return self.modules[0].params(include_uninit)\n\n    @staticmethod\n    def _scatter(inputs, kwargs, target_devices: list, dim=0):\n        \"\"\"\n        Scatters all inputs (args and kwargs) to target devices and splits\n        along given dimension\n\n        Parameters\n        ----------\n        inputs : list or tuple\n            positional arguments\n        kwargs : dict\n            keyword arguments\n        target_devices : list\n            list of target device (either string or chainer.backend.Device)\n        dim : int\n            the dimension, which should be used for splitting the batch\n\n        Returns\n        -------\n        tuple\n            scattered positional arguments\n        tuple\n            scattered keyword arguments\n\n        \"\"\"\n\n        # scatter inputs if given\n        inputs = _scatter(inputs, target_devices, dim) if inputs else []\n        # scatter kwargs if given\n        kwargs = _scatter(kwargs, target_devices, dim) if kwargs else []\n\n        # extend lengths by empty tuples if necessary\n        if len(inputs) < len(kwargs):\n            inputs.extend([() for _ in range(len(kwargs) - len(inputs))])\n        elif len(kwargs) < len(inputs):\n            kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])\n\n        inputs = tuple(inputs)\n        kwargs = tuple(kwargs)\n\n        return inputs, kwargs\n\n    @staticmethod\n    def _gather(predictions, dim, target_device):\n        \"\"\"\n        Re-Builds 
batches on the target device\n\n        Parameters\n        ----------\n        predictions : list\n            list containing the predictions from all replicated models\n        dim : int\n            dimension to use for concatenating single predictions\n        target_device : str or chainer.backend.Device\n            the device the re-built batch should lie on\n\n        Returns\n        -------\n        Any\n            the rebuilt batch (lying on ``target_device``)\n\n        \"\"\"\n        return _gather(predictions, target_device, dim)\n\n    def cleargrads(self):\n        for module in self.modules:\n            module.cleargrads()\n\n    def zerograds(self):\n        for module in self.modules:\n            module.zerograds()\n\n    @property\n    def closure(self):\n        return self.modules[0].closure\n\n    @property\n    def prepare_batch(self):\n        return self.modules[0].prepare_batch\n\n\nclass ParallelOptimizerCumulateGradientsHook(object):\n    \"\"\"\n    A hook which sums up all replications' gradients in a\n    DataParallel-Scenario\n    \"\"\"\n\n    name = \"DataParallelCumulateGradients\"\n    call_for_each_param = False\n    timing = 'pre'\n\n    def __call__(self, optimizer: chainer.Optimizer):\n        \"\"\"\n        Sums up all gradients if the target is an instance of\n        ``DataParallel``\n\n        Parameters\n        ----------\n        optimizer : chainer.Optimizer\n            the optimizer holding the target, whose gradients should be\n            summed across the replications\n\n        \"\"\"\n        if isinstance(optimizer.target, DataParallelChainerNetwork):\n            for module in optimizer.target.modules[1:]:\n                optimizer.target.modules[0].addgrads(module)\n\n\nclass ParallelOptimizerUpdateModelParameters(object):\n    \"\"\"\n    A hook to replicate all parameters from the root model, to all\n    model-replicas after the optimizer step\n    \"\"\"\n\n    name = 
\"DataParallelUpdateModelParams\"\n    call_for_each_param = False\n    timing = \"post\"\n\n    def __call__(self, optimizer: chainer.Optimizer):\n        if isinstance(optimizer.target, DataParallelChainerNetwork):\n            for module in optimizer.target.modules[1:]:\n                module.copyparams(optimizer.target.modules[0])\n\n\nclass DataParallelChainerOptimizer(chainer.Optimizer):\n    \"\"\"\n    An Optimizer-Wrapper to enable DataParallel. Basically this forwards\n    all functions to the interal optimizer, but registers the additional\n    hooks needed for DataParallel (namely\n    :class:`ParallelOptimizerUpdateModelParameters` as a post-update hook\n    and :class:`ParallelOptimizerCumulateGradientsHook` as a pre-update hook)\n\n    \"\"\"\n\n    def __init__(self, optimizer):\n        \"\"\"\n\n        Parameters\n        ----------\n        optimizer : :class:`chainer.Optimizer`\n            the optimizer to wrap\n\n        \"\"\"\n        if isinstance(optimizer, chainer.Optimizer):\n            self._optimizer = optimizer\n\n        else:\n            raise RuntimeError(\"Invalid optimizer class given: Expected \"\n                               \"instance of chainer.Optimizer, but got %s\"\n                               % optimizer.__class__.__name__)\n\n    @classmethod\n    def from_optimizer_class(cls, optim_cls, *args, **kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        optim_cls : subclass of :class:`chainer.Optimizer`\n            the optimizer to use internally\n        *args :\n            arbitrary positional arguments (will be used for\n            initialization of internally used optimizer)\n        **kwargs :\n            arbitrary keyword arguments (will be used for initialization\n            of internally used optimizer)\n\n        \"\"\"\n        if optim_cls is not None and issubclass(optim_cls,\n                                                chainer.Optimizer):\n            _optim = 
optim_cls(*args, **kwargs)\n        else:\n            raise RuntimeError(\"Invalid optimizer class given: Expected \"\n                               \"Subclass of chainer.Optimizer, but got %s\"\n                               % optim_cls.__name__)\n        return cls(_optim)\n\n    def setup(self, link):\n        \"\"\"\n        Calls the setup method of the internal optimizer and registers the\n        necessary hooks for data-parallel behavior\n\n        Parameters\n        ----------\n        link : :class:`DataParallel`\n            the target whose parameters should be updated\n\n        \"\"\"\n        self._optimizer.setup(link)\n\n        self._optimizer.add_hook(ParallelOptimizerCumulateGradientsHook())\n        self._optimizer.add_hook(ParallelOptimizerUpdateModelParameters())\n\n    @property\n    def target(self):\n        return self._optimizer.target\n\n    @property\n    def epoch(self):\n        return self._optimizer.epoch\n\n    @property\n    def _pre_update_hooks(self):\n        return self._optimizer._pre_update_hooks\n\n    @property\n    def _loss_scale(self):\n        return self._optimizer._loss_scale\n\n    @property\n    def _loss_scale_max(self):\n        return self._optimizer._loss_scale_max\n\n    @property\n    def _loss_scaling_is_dynamic(self):\n        return self._optimizer._loss_scaling_is_dynamic\n\n    @property\n    def use_auto_new_epoch(self):\n        return self._optimizer.use_auto_new_epoch\n\n    @property\n    def update(self):\n        return self._optimizer.update\n\n    @property\n    def new_epoch(self):\n        return self._optimizer.new_epoch\n\n    @property\n    def add_hook(self):\n        return self._optimizer.add_hook\n\n    @property\n    def remove_hook(self):\n        return self._optimizer.remove_hook\n\n    @property\n    def call_hooks(self):\n        return self._optimizer.call_hooks\n\n    @property\n    def serialize(self):\n        return self._optimizer.serialize\n\n    @property\n    def 
loss_scaling(self):\n        return self._optimizer.loss_scaling\n\n    @property\n    def set_loss_scale(self):\n        return self._optimizer.set_loss_scale\n\n    @property\n    def check_nan_in_grads(self):\n        return self._optimizer.check_nan_in_grads\n\n    @property\n    def is_safe_to_update(self):\n        return self._optimizer.is_safe_to_update\n\n    @property\n    def update_loss_scale(self):\n        return self._optimizer.update_loss_scale\n"
  },
  {
    "path": "delira/models/backends/sklearn/__init__.py",
    "content": "from delira import get_backends as _get_backends\nif \"SKLEARN\" in _get_backends():\n    from delira.models.backends.sklearn.abstract_network import \\\n        SklearnEstimator\n"
  },
  {
    "path": "delira/models/backends/sklearn/abstract_network.py",
    "content": "from inspect import signature as get_signature\nfrom sklearn.base import BaseEstimator\n\nfrom delira.models.abstract_network import AbstractNetwork\n\n\nclass SklearnEstimator(AbstractNetwork):\n    \"\"\"\n    Wrapper Class to wrap all ``sklearn`` estimators and provide delira\n    compatibility\n    \"\"\"\n\n    def __init__(self, module: BaseEstimator):\n        \"\"\"\n\n        Parameters\n        ----------\n        module : :class:`sklearn.base.BaseEstimator`\n            the module to wrap\n        \"\"\"\n\n        super().__init__()\n\n        self.module = module\n\n        # forwards methods to self.module if necessary\n\n        for key in [\"fit\", \"partial_fit\", \"predict\"]:\n            if hasattr(self.module, key):\n                setattr(self, key, getattr(self.module, key))\n\n        # if estimator is build dynamically based on input, classes have to\n        # be passed at least at first time (we pass it every time), because\n        # not every class is present in  every batch\n        # variable is initialized here, but feeded during the training\n        if (self.iterative_training and \"classes\" in get_signature(\n                self.partial_fit).parameters):\n            self.classes = None\n\n    def __call__(self, *args, **kwargs):\n        \"\"\"\n        Calls ``self.predict`` with args and kwargs\n\n        Parameters\n        ----------\n        *args :\n            positional arguments of arbitrary number and type\n        **kwargs :\n            keyword arguments of arbitrary number and type\n\n        Returns\n        -------\n        dict\n            dictionary containing the predictions under key 'pred'\n\n        \"\"\"\n        return {\"pred\": self.predict(*args, **kwargs)}\n\n    @property\n    def iterative_training(self):\n        \"\"\"\n        Property indicating, whether a the current module can be\n        trained iteratively (batchwise)\n\n        Returns\n        -------\n        bool\n     
       True: if current module can be trained iteratively\n            False: else\n\n        \"\"\"\n        return hasattr(self, \"partial_fit\")\n\n    @staticmethod\n    def prepare_batch(batch: dict, input_device, output_device):\n        \"\"\"\n        Helper Function to prepare Network Inputs and Labels (convert them to\n        correct type and shape and push them to correct devices)\n\n        Parameters\n        ----------\n        batch : dict\n            dictionary containing all the data\n        input_device : Any\n            device for module inputs (will be ignored here; just given for\n            compatibility)\n        output_device : Any\n            device for module outputs (will be ignored here; just given for\n            compatibility)\n\n        Returns\n        -------\n        dict\n            dictionary containing data in correct type and shape and on correct\n            device\n\n        \"\"\"\n\n        new_batch = {\"X\": batch[\"data\"].reshape(batch[\"data\"].shape[0], -1)}\n        if \"label\" in batch:\n            new_batch[\"y\"] = batch[\"label\"].ravel()\n\n        return new_batch\n\n    @staticmethod\n    def closure(model, data_dict: dict, optimizers: dict, losses: dict,\n                iter_num: int, fold=0, **kwargs):\n        \"\"\"\n        default closure method to do a single training step;\n        Could be overwritten for more advanced models\n\n        Parameters\n        ----------\n        model : :class:`SklearnEstimator`\n            trainable model\n        data_dict : dict\n            dictionary containing the data\n        optimizers : dict\n            dictionary of optimizers to optimize model's parameters;\n            ignored here, just passed for compatibility reasons\n        losses : dict\n            dict holding the losses to calculate errors;\n            ignored here, just passed for compatibility reasons\n        iter_num: int\n            the number of the current iteration in the 
current epoch;\n            Will be restarted at zero at the beginning of every epoch\n        fold : int\n            Current Fold in Crossvalidation (default: 0)\n        **kwargs:\n            additional keyword arguments\n\n        Returns\n        -------\n        dict\n            Loss values (with same keys as input dict losses; will always\n            be empty here)\n        dict\n            dictionary containing all predictions\n\n        \"\"\"\n\n        if model.iterative_training:\n            fit_fn = model.partial_fit\n\n        else:\n            fit_fn = model.fit\n\n        if hasattr(model, \"classes\"):\n            # classes must be specified here, because not all classes are\n            # necessarily present in each batch and some estimators are built\n            # dynamically\n            fit_fn(**data_dict, classes=model.classes)\n        else:\n            fit_fn(**data_dict)\n\n        preds = model(data_dict[\"X\"])\n\n        return {}, preds\n
  },
  {
    "path": "delira/models/backends/tf_eager/__init__.py",
    "content": "from delira import get_backends as _get_backends\n\nif \"TF\" in _get_backends():\n    from delira.models.backends.tf_eager.abstract_network import \\\n        AbstractTfEagerNetwork\n    from delira.models.backends.tf_eager.data_parallel import \\\n        DataParallelTfEagerNetwork\n"
  },
  {
    "path": "delira/models/backends/tf_eager/abstract_network.py",
    "content": "import abc\nimport typing\nimport tensorflow as tf\nimport numpy as np\nfrom delira.models.abstract_network import AbstractNetwork\n\n\nclass AbstractTfEagerNetwork(AbstractNetwork, tf.keras.layers.Layer):\n    \"\"\"\n    Abstract Network for TF eager execution backend.\n    All models to use with this backend should be derived from this class\n\n    \"\"\"\n\n    def __init__(self, data_format=\"channels_first\", trainable=True,\n                 name=None, dtype=None, **kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        data_format : str\n            the accepted data format (default: 'channels_first')\n        trainable : wheter or not the model is trainable (default: True)\n        name : str\n            the network's name\n        dtype :\n            the dtype to use for the model's parameters\n        **kwargs :\n            additional keyword arguments (will be registered as\n            ``init_kwargs``)\n\n        \"\"\"\n        tf.keras.layers.Layer.__init__(self, trainable=trainable,\n                                       name=name, dtype=dtype)\n\n        AbstractNetwork.__init__(self, **kwargs)\n\n        self.data_format = data_format\n        self.device = \"/cpu:0\"\n\n    @abc.abstractmethod\n    def call(self, *args, **kwargs):\n        \"\"\"\n        Defines the model's forward pass\n\n        Parameters\n        ----------\n        *args :\n            arbitrary positional arguments\n        **kwargs :\n            arbbitrary keyword arguments\n\n        Raises\n        ------\n        NotImplementedError\n            If not overwritten by subclass\n\n        \"\"\"\n        raise NotImplementedError\n\n    def __call__(self, *args, **kwargs):\n        \"\"\"\n        Executes the modules forward pass\n\n        Parameters\n        ----------\n        *args :\n            arbitrary positional arguments\n        **kwargs :\n            arbitrary keyword arguments\n\n        \"\"\"\n\n        return 
self.call(*args, **kwargs)\n\n    @staticmethod\n    def prepare_batch(batch: dict, input_device, output_device):\n        \"\"\"\n        Helper Function to prepare Network Inputs and Labels (convert them to\n        correct type and shape and push them to correct devices)\n\n        Parameters\n        ----------\n        batch : dict\n            dictionary containing all the data\n        input_device : str\n            device for module inputs\n        output_device : str\n            device for module outputs\n\n        Returns\n        -------\n        dict\n            dictionary containing data in correct type and shape and on correct\n            device\n\n        \"\"\"\n        new_batch = {}\n        with tf.device(output_device):\n            new_batch[\"label\"] = tf.convert_to_tensor(\n                batch[\"label\"].astype(np.float32))\n\n        with tf.device(input_device):\n            for k, v in batch.items():\n                if k == \"label\":\n                    continue\n                new_batch[k] = tf.convert_to_tensor(v.astype(np.float32))\n\n        return new_batch\n\n    @staticmethod\n    def closure(model, data_dict: dict,\n                optimizers: typing.Dict[str, tf.train.Optimizer], losses: dict,\n                iter_num: int, fold=0, **kwargs):\n        \"\"\"\n        default closure method to do a single training step;\n        Could be overwritten for more advanced models\n\n        Parameters\n        ----------\n        model : :class:`AbstractTfEagerNetwork`\n            trainable model\n        data_dict : dict\n            dictionary containing the data\n        optimizers : dict\n            dictionary of optimizers to optimize model's parameters\n        losses : dict\n            dict holding the losses to calculate errors\n            (gradients from different losses will be accumulated)\n        iter_num: int\n            the number of the current 
iteration in the current epoch;\n            Will be restarted at zero at the beginning of every epoch\n        fold : int\n            Current Fold in Crossvalidation (default: 0)\n        **kwargs:\n            additional keyword arguments\n\n        Returns\n        -------\n        dict\n            Loss values (with same keys as input dict losses)\n        dict\n            dictionary containing all predictions\n\n        \"\"\"\n\n        loss_vals = {}\n\n        # calculate loss within a gradient tape so gradients can be computed\n        with tf.GradientTape() as tape:\n            preds = model(data_dict[\"data\"])\n            total_loss = None\n            for k, loss_fn in losses.items():\n                _loss_val = loss_fn(preds[\"pred\"],\n                                    data_dict[\"label\"])\n                loss_vals[k] = _loss_val.numpy()\n                if total_loss is None:\n                    total_loss = _loss_val\n                else:\n                    total_loss += _loss_val\n\n        # calculate gradients\n        grads = tape.gradient(total_loss,\n                              model.trainable_variables)\n\n        # perform optimization step\n        optimizers[\"default\"].apply_gradients(\n            zip(grads, model.trainable_variables))\n\n        return loss_vals, preds\n
  },
  {
    "path": "delira/models/backends/tf_eager/data_parallel.py",
    "content": "import tensorflow as tf\nfrom delira.models.backends.tf_eager.abstract_network import \\\n    AbstractTfEagerNetwork\n\n\nclass DataParallelTfEagerNetwork(AbstractTfEagerNetwork):\n    \"\"\"\n    DataParallel Module for the TF eager execution backend\n\n    Warnings\n    --------\n    This Module is highly experimental and not guaranteed to work properly!\n    \"\"\"\n\n    def __init__(self, module, devices):\n        \"\"\"\n\n        Parameters\n        ----------\n        module : :class:`AbstractTfEagerNetwork`\n            the module to scatter across different devices\n        devices : list\n            list of ints specifying the GPU indices\n        \"\"\"\n        super().__init__()\n\n        self._closure = module.closure\n        self._prepare_batch = module.pepare_batch\n\n        self.module = tf.keras.utils.multi_gpu_model(module, devices, True)\n\n    def call(self, *args, **kwargs):\n        \"\"\"\n        Defines the forward pass of the module\n\n        Parameters\n        ----------\n        *args :\n            arbitrary positional arguments\n        **kwargs :\n            arbitrary keyword arguments\n\n        \"\"\"\n        return self.module.call(*args, **kwargs)\n\n    @property\n    def closure(self):\n        return self._closure\n\n    @property\n    def prepare_batch(self):\n        return self._prepare_batch\n"
  },
  {
    "path": "delira/models/backends/tf_graph/__init__.py",
    "content": "from delira import get_backends as _get_backends\n\nif \"TF\" in _get_backends():\n    from delira.models.backends.tf_graph.abstract_network import \\\n        AbstractTfGraphNetwork\n"
  },
  {
    "path": "delira/models/backends/tf_graph/abstract_network.py",
    "content": "import abc\nimport logging\nimport tensorflow as tf\nimport numpy as np\n\nfrom delira.models.abstract_network import AbstractNetwork\n\n\nclass AbstractTfGraphNetwork(AbstractNetwork, metaclass=abc.ABCMeta):\n    \"\"\"\n    Abstract Class for Tf Networks\n\n    See Also\n    --------\n    :class:`AbstractNetwork`\n\n    \"\"\"\n\n    @abc.abstractmethod\n    def __init__(self, sess=tf.Session, **kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        **kwargs :\n            keyword arguments (are passed to :class:`AbstractNetwork`'s `\n            __init__ to register them as init kwargs\n\n        \"\"\"\n        AbstractNetwork.__init__(self, **kwargs)\n        self._sess = sess()\n        self.inputs = {}\n        self.outputs_train = {}\n        self.outputs_eval = {}\n        self._losses = None\n        self._optims = None\n        self.training = True\n\n    def __call__(self, *args, **kwargs):\n        \"\"\"\n        Wrapper for calling self.run in eval setting\n\n        Parameters\n        ----------\n        *args :\n            positional arguments (passed to `self.run`)\n        **kwargs:\n            keyword arguments (passed to `self.run`)\n\n        Returns\n        -------\n        Any\n            result: module results of arbitrary type and number\n\n        \"\"\"\n        self.training = False\n        return self.run(*args, **kwargs)\n\n    def run(self, *args, **kwargs):\n        \"\"\"\n        Evaluates `self.outputs_train` or `self.outputs_eval` based on\n        `self.training`\n\n        Parameters\n        ----------\n        *args :\n            currently unused, exist for compatibility reasons\n        **kwargs :\n            kwargs used to feed as ``self.inputs``. 
Same keys as for\n            ``self.inputs`` must be used\n\n        Returns\n        -------\n        dict\n            same keys as outputs_train or outputs_eval,\n            containing evaluated expressions as values\n\n        \"\"\"\n        _feed_dict = {}\n\n        for feed_key, feed_value in kwargs.items():\n            assert feed_key in self.inputs.keys(), \\\n                \"{} not found in self.inputs\".format(feed_key)\n            _feed_dict[self.inputs[feed_key]] = feed_value\n\n        if self.training:\n            return self._sess.run(self.outputs_train, feed_dict=_feed_dict)\n\n        return self._sess.run(self.outputs_eval, feed_dict=_feed_dict)\n\n    def _add_losses(self, losses: dict):\n        \"\"\"\n        Adds losses to model that are to be used by optimizers or\n        during evaluation. Can be overwritten for more advanced loss behavior\n\n        Parameters\n        ----------\n        losses : dict\n            dictionary containing all losses. Individual losses are averaged\n\n        \"\"\"\n        if self._losses is not None and losses:\n            logging.warning('Change of losses is not yet supported')\n            raise NotImplementedError()\n\n        elif self._losses is not None and not losses:\n            pass\n\n        else:\n            self._losses = {}\n            for name, _loss in losses.items():\n                self._losses[name] = _loss(self.inputs[\"label\"],\n                                           self.outputs_train[\"pred\"])\n\n            total_loss = tf.reduce_mean(list(self._losses.values()), axis=0)\n\n            self._losses['total'] = total_loss\n            self.outputs_train[\"losses\"] = self._losses\n            self.outputs_eval[\"losses\"] = self._losses\n\n    def _add_optims(self, optims: dict):\n        \"\"\"\n        Adds optimizers to the model that are to be used during\n        training. 
Can be overwritten for more advanced optimizers\n\n        Parameters\n        ----------\n        optims : dict\n            dictionary containing all optimizers; optimizers should be of\n            type :class:`tf.train.Optimizer`\n\n        \"\"\"\n        if self._optims is not None and optims:\n            logging.warning('Change of optims is not yet supported')\n        elif self._optims is not None and not optims:\n            pass\n        else:\n            self._optims = optims['default']\n            grads = self._optims.compute_gradients(self._losses['total'])\n            step = self._optims.apply_gradients(grads)\n            self.outputs_train[\"default_step\"] = step\n\n    @staticmethod\n    def prepare_batch(batch: dict, input_device, output_device):\n        \"\"\"\n        Helper Function to prepare Network Inputs and Labels (convert them to\n        correct type and shape and push them to correct devices)\n\n        Parameters\n        ----------\n        batch : dict\n            dictionary containing all the data\n        input_device : Any\n            device for module inputs (will be ignored here; just given for\n            compatibility)\n        output_device : Any\n            device for module outputs (will be ignored here; just given for\n            compatibility)\n\n        Returns\n        -------\n        dict\n            dictionary containing data in correct type and shape and on correct\n            device\n\n        \"\"\"\n        return {k: v.astype(np.float32) for k, v in batch.items()}\n\n    @staticmethod\n    def closure(model, data_dict: dict, optimizers: dict, losses: dict,\n                iter_num: int, fold=0, **kwargs):\n        \"\"\"\n        default closure method to do a single training step;\n        Could be overwritten for more advanced models\n\n        Parameters\n        ----------\n        model : :class:`AbstractTfGraphNetwork`\n            trainable model\n        data_dict : dict\n            dictionary containing 
the data\n        optimizers : dict\n            dictionary of optimizers to optimize model's parameters;\n            ignored here, just passed for compatibility reasons\n        losses : dict\n            dict holding the losses to calculate errors;\n            ignored here, just passed for compatibility reasons\n        iter_num: int\n            the number of the current iteration in the current epoch;\n            Will be restarted at zero at the beginning of every epoch\n        fold : int\n            Current Fold in Crossvalidation (default: 0)\n        **kwargs:\n            additional keyword arguments\n\n        Returns\n        -------\n        dict\n            Loss values (with same keys as input dict losses)\n        dict\n            dictionary containing all predictions\n\n        \"\"\"\n\n        inputs = data_dict['data']\n\n        outputs = model.run(data=inputs, label=data_dict['label'])\n        loss_vals = outputs['losses']\n\n        return loss_vals, outputs\n
  },
  {
    "path": "delira/models/backends/torch/__init__.py",
    "content": "from delira import get_backends as _get_backends\n\nif \"TORCH\" in _get_backends():\n    from delira.models.backends.torch.abstract_network import \\\n        AbstractPyTorchNetwork\n    from delira.models.backends.torch.data_parallel import \\\n        DataParallelPyTorchNetwork\n    from delira.models.backends.torch.utils import scale_loss\n"
  },
  {
    "path": "delira/models/backends/torch/abstract_network.py",
    "content": "import abc\nimport torch\nfrom delira.models.abstract_network import AbstractNetwork\n\nfrom delira.models.backends.torch.utils import scale_loss\n\n\nclass AbstractPyTorchNetwork(AbstractNetwork, torch.nn.Module):\n    \"\"\"\n    Abstract Class for PyTorch Networks\n\n    See Also\n    --------\n    `torch.nn.Module`\n    :class:`AbstractNetwork`\n\n    \"\"\"\n    @abc.abstractmethod\n    def __init__(self, **kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        **kwargs :\n            keyword arguments (are passed to :class:`AbstractNetwork`'s `\n            __init__ to register them as init kwargs\n\n        \"\"\"\n        torch.nn.Module.__init__(self)\n        AbstractNetwork.__init__(self, **kwargs)\n\n    @abc.abstractmethod\n    def forward(self, *inputs):\n        \"\"\"\n        Forward inputs through module (defines module behavior)\n        Parameters\n        ----------\n        inputs : list\n            inputs of arbitrary type and number\n\n        Returns\n        -------\n        Any\n            result: module results of arbitrary type and number\n\n        \"\"\"\n        raise NotImplementedError()\n\n    def __call__(self, *args, **kwargs):\n        \"\"\"\n        Calls Forward method\n\n        Parameters\n        ----------\n        *args :\n            positional arguments (passed to `forward`)\n        **kwargs :\n            keyword arguments (passed to `forward`)\n\n        Returns\n        -------\n        Any\n            result: module results of arbitrary type and number\n\n        \"\"\"\n        return torch.jit.ScriptModule.__call__(self, *args, **kwargs)\n\n    @staticmethod\n    def prepare_batch(batch: dict, input_device, output_device):\n        \"\"\"\n        Helper Function to prepare Network Inputs and Labels (convert them\n        to correct type and shape and push them to correct devices)\n\n        Parameters\n        ----------\n        batch : dict\n            dictionary 
containing all the data\n        input_device : torch.device\n            device for network inputs\n        output_device : torch.device\n            device for network outputs\n\n        Returns\n        -------\n        dict\n            dictionary containing data in correct type and shape and on\n            correct device\n\n        \"\"\"\n        return_dict = {\"data\": torch.from_numpy(batch[\"data\"]).to(\n            input_device).to(torch.float)}\n\n        for key, vals in batch.items():\n            if key == \"data\":\n                continue\n            return_dict[key] = torch.from_numpy(vals).to(output_device).to(\n                torch.float)\n\n        return return_dict\n\n    @staticmethod\n    def closure(model, data_dict: dict, optimizers: dict, losses: dict,\n                iter_num: int, fold=0, **kwargs):\n        \"\"\"\n        closure method to do a single backpropagation step\n\n        Parameters\n        ----------\n        model : :class:`AbstractPyTorchNetwork`\n            trainable model\n        data_dict : dict\n            dictionary containing the data\n        optimizers : dict\n            dictionary of optimizers to optimize model's parameters\n        losses : dict\n            dict holding the losses to calculate errors\n            (gradients from different losses will be accumulated)\n        iter_num: int\n            the number of the current iteration in the current epoch;\n            Will be restarted at zero at the beginning of every epoch\n        fold : int\n            Current Fold in Crossvalidation (default: 0)\n        **kwargs:\n            additional keyword arguments\n\n        Returns\n        -------\n        dict\n            Loss values (with same keys as input dict losses)\n        dict\n            Arbitrary number of predictions as detached tensors\n\n        \"\"\"\n\n        loss_vals = {}\n        total_loss = 0\n\n        with torch.enable_grad():\n\n            # predict\n            inputs 
= data_dict[\"data\"]\n            preds = model(inputs)\n\n            # calculate losses\n            for key, crit_fn in losses.items():\n                _loss_val = crit_fn(preds[\"pred\"], data_dict[\"label\"])\n                loss_vals[key] = _loss_val.item()\n                total_loss += _loss_val\n\n            optimizers['default'].zero_grad()\n            # perform loss scaling via apex if half precision is enabled\n            with scale_loss(total_loss, optimizers[\"default\"]) as scaled_loss:\n                scaled_loss.backward()\n            optimizers['default'].step()\n\n        return loss_vals, {k: v.detach()\n                           for k, v in preds.items()}\n"
  },
  {
    "path": "delira/models/backends/torch/data_parallel.py",
    "content": "import torch\n\nfrom delira.models.backends.torch.abstract_network import \\\n    AbstractPyTorchNetwork\n\n\nclass DataParallelPyTorchNetwork(AbstractPyTorchNetwork,\n                                 torch.nn.DataParallel):\n    \"\"\"\n    A Wrapper around a :class:`AbstractPyTorchNetwork` instance to\n    implement parallel training by splitting the batches\n    \"\"\"\n\n    def __init__(self, module: AbstractPyTorchNetwork,\n                 device_ids=None, output_device=None, dim=0):\n        \"\"\"\n\n        Parameters\n        ----------\n        module : :class:`AbstractPyTorchNetwork`\n            the module to wrap (will be replicated on all devices)\n        device_ids : list\n            a list containing the devices to use (either as strings or as\n            :class:`chainer.backend.Device`).\n        output_device : str or :class:`chainer.backend.Device`\n            The output device\n            Make sure, your labels are also on this device\n            for loss calculation!\n            If not specified, the second device of ``devices`` will be used\n            for output gathering.\n        dim : int\n            the index of the batchdimension (usually 0, but can become\n            e.g. 
1 in NLP tasks)\n\n        \"\"\"\n\n        AbstractPyTorchNetwork.__init__(self)\n        torch.nn.DataParallel.__init__(self, module, device_ids, output_device,\n                                       dim)\n\n    def forward(self, *args, **kwargs):\n        \"\"\"\n        Scatters the inputs (both positional and keyword arguments) across\n        all devices, feeds them through model replicas and re-builds\n        batches on output device\n\n        Parameters\n        ----------\n        *args :\n            positional arguments of arbitrary number and type\n        **kwargs :\n            keyword arguments of arbitrary number and type\n\n        Returns\n        -------\n        Any\n            combined output from all scattered models\n\n        \"\"\"\n        return torch.nn.DataParallel.forward(self, *args, **kwargs)\n\n    @property\n    def closure(self):\n        return self.module.closure\n\n    @property\n    def prepare_batch(self):\n        return self.module.prepare_batch\n
  },
  {
    "path": "delira/models/backends/torch/utils.py",
    "content": "import contextlib\n\ntry:\n    # use apex loss scaling if possible\n    # (and enabled, this is done internally by apex)\n    from apex import amp\nexcept ImportError:\n    # use no loss scaling with same API if apex is unavailable\n    amp = None\n\n\n@contextlib.contextmanager\ndef scale_loss(loss,\n               optimizers,\n               loss_id=0,\n               model=None,\n               delay_unscale=False,\n               **kwargs):\n    \"\"\"\n    Context Manager which automatically switches between loss scaling via\n    apex.amp (if apex is available) and no loss scaling\n\n    Parameters\n    ----------\n    loss : :class:`torch.Tensor`\n        a pytorch tensor containing the loss value\n    optimizers : list\n        a list of :class:`torch.optim.Optimizer` containing all optimizers,\n        which are holding paraneters affected by the backpropagation of the\n        current loss value\n    loss_id : int\n        When used in conjunction with the ``num_losses`` argument\n        to ``amp.initialize``, enables Amp to use a different loss scale per\n        loss.  ``loss_id`` must be an integer between 0 and ``num_losses`` that\n        tells Amp which loss is being used for the current backward pass.\n        If ``loss_id`` is left unspecified, Amp will use the default global\n        loss scaler for this backward pass.\n    model : :class:`AbstractPyTorchNetwork` or None\n        Currently unused, reserved to enable future optimizations.\n    delay_unscale : bool\n        ``delay_unscale`` is never necessary, and the default value of\n        ``False`` is strongly recommended. If ``True``, Amp will not unscale\n        the gradients or perform model->master gradient copies on\n        context manager exit. 
``delay_unscale=True`` is a minor ninja\n        performance optimization and can result\n        in weird gotchas (especially with multiple models/optimizers/losses),\n        so only use it if you know what you're doing.\n    **kwargs :\n        additional keyword arguments; currently unused, but provided for the\n        case amp decides to extend the functionality here\n\n    Yields\n    ------\n    :class:`torch.Tensor`\n        the new loss value (scaled if apex.amp is available and was configured\n        to do so, unscaled in all other cases)\n\n    \"\"\"\n\n    if amp is None:\n        yield loss\n\n    else:\n        with amp.scale_loss(loss=loss, optimizers=optimizers,\n                            loss_id=loss_id, model=model,\n                            delay_unscale=delay_unscale,\n                            **kwargs) as _loss:\n\n            yield _loss\n"
  },
  {
    "path": "delira/models/backends/torchscript/__init__.py",
    "content": "from delira import get_backends as _get_backends\n\nif \"TORCH\" in _get_backends():\n    from .abstract_network import AbstractTorchScriptNetwork\n"
  },
  {
    "path": "delira/models/backends/torchscript/abstract_network.py",
    "content": "import abc\nimport torch\nfrom delira.models.abstract_network import AbstractNetwork\n\n\nclass AbstractTorchScriptNetwork(AbstractNetwork, torch.jit.ScriptModule):\n    \"\"\"\n    Abstract Interface Class for TorchScript Networks. For more information\n    have a look at https://pytorch.org/docs/stable/jit.html#torchscript\n\n    Warnings\n    --------\n    In addition to the here defined API, a forward function must be\n    implemented and decorated with ``@torch.jit.script_method``\n\n    \"\"\"\n\n    @abc.abstractmethod\n    def __init__(self, optimize=True, **kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        optimize : bool\n            whether to optimize the network graph or not; default: True\n        **kwargs :\n            additional keyword arguments\n            (passed to :class:`AbstractNetwork`)\n        \"\"\"\n        torch.jit.ScriptModule.__init__(self, optimize=optimize)\n        AbstractNetwork.__init__(self, **kwargs)\n\n    def __call__(self, *args, **kwargs):\n        \"\"\"\n        Calls Forward method\n\n        Parameters\n        ----------\n        *args :\n            positional arguments (passed to `forward`)\n        **kwargs :\n            keyword arguments (passed to `forward`)\n\n        Returns\n        -------\n        Any\n            result: module results of arbitrary type and number\n\n        \"\"\"\n        return torch.jit.ScriptModule.__call__(self, *args, **kwargs)\n\n    @staticmethod\n    def prepare_batch(batch: dict, input_device, output_device):\n        \"\"\"\n        Helper Function to prepare Network Inputs and Labels (convert them\n        to correct type and shape and push them to correct devices)\n\n        Parameters\n        ----------\n        batch : dict\n            dictionary containing all the data\n        input_device : torch.device\n            device for network inputs\n        output_device : torch.device\n            device for network outputs\n\n    
    Returns\n        -------\n        dict\n            dictionary containing data in correct type and shape and on\n            correct device\n\n        \"\"\"\n        return_dict = {\"data\": torch.from_numpy(batch[\"data\"]).to(\n            input_device).to(torch.float)}\n\n        for key, vals in batch.items():\n            if key == \"data\":\n                continue\n            return_dict[key] = torch.from_numpy(vals).to(output_device).to(\n                torch.float)\n\n        return return_dict\n\n    @staticmethod\n    def closure(model, data_dict: dict, optimizers: dict, losses: dict,\n                iter_num: int, fold=0, **kwargs):\n        \"\"\"\n        closure method to do a single backpropagation step\n\n        Parameters\n        ----------\n        model : :class:`AbstractTorchScriptNetwork`\n            trainable model\n        data_dict : dict\n            dictionary containing the data\n        optimizers : dict\n            dictionary of optimizers to optimize model's parameters\n        losses : dict\n            dict holding the losses to calculate errors\n            (gradients from different losses will be accumulated)\n        iter_num : int\n            the number of the current iteration in the current epoch;\n            will be reset to zero at the beginning of every epoch\n        fold : int\n            Current Fold in Crossvalidation (default: 0)\n        **kwargs :\n            additional keyword arguments\n\n        Returns\n        -------\n        dict\n            Loss values (with same keys as input dict losses)\n        dict\n            Arbitrary number of predictions as numpy array\n\n        \"\"\"\n\n        loss_vals = {}\n        total_loss = 0\n\n        with torch.enable_grad():\n\n            # predict\n            inputs = data_dict[\"data\"]\n            preds = model(inputs)\n\n            # calculate losses\n            for key, crit_fn in losses.items():\n                _loss_val = 
crit_fn(preds[\"pred\"], data_dict[\"label\"])\n                loss_vals[key] = _loss_val.item()\n                total_loss += _loss_val\n\n            optimizers['default'].zero_grad()\n            # apex does not yet support torchscript\n            total_loss.backward()\n            optimizers['default'].step()\n\n        return loss_vals, {k: v.detach()\n                           for k, v in preds.items()}\n"
  },
  {
    "path": "delira/training/__init__.py",
    "content": "\nfrom delira.training.base_experiment import BaseExperiment\nfrom delira.training.base_trainer import BaseNetworkTrainer\nfrom delira.training.predictor import Predictor\n\nfrom delira.training.backends import *\n"
  },
  {
    "path": "delira/training/backends/__init__.py",
    "content": "from delira import get_backends as _get_backends\n\n\nif \"CHAINER\" in _get_backends():\n    from delira.training.backends.chainer import *\n\nif \"SKLEARN\" in _get_backends():\n    from delira.training.backends.sklearn import *\n\nif \"TF\" in _get_backends():\n    from delira.training.backends.tf_graph import *\n    from delira.training.backends.tf_eager import *\n\nif \"TORCH\" in _get_backends():\n    from delira.training.backends.torch import *\n    from delira.training.backends.torchscript import *\n"
  },
  {
    "path": "delira/training/backends/chainer/__init__.py",
    "content": "from delira import get_backends as _get_backends\n\nif \"CHAINER\" in _get_backends():\n    from delira.training.backends.chainer.trainer import ChainerNetworkTrainer\n    from delira.training.backends.chainer.experiment import ChainerExperiment\n    from delira.training.backends.chainer.utils import convert_to_numpy \\\n        as convert_chainer_to_numpy\n    from delira.training.backends.chainer.utils import create_optims_default \\\n        as create_chainer_optims_default\n"
  },
  {
    "path": "delira/training/backends/chainer/experiment.py",
    "content": "import typing\nfrom functools import partial\n\nfrom delira.models.backends.chainer import AbstractChainerNetwork\nfrom delira.data_loading import DataManager\nfrom delira.training.base_experiment import BaseExperiment\nfrom delira.utils import DeliraConfig\n\nfrom delira.training.backends.chainer.utils import create_optims_default\nfrom delira.training.backends.chainer.utils import convert_to_numpy\nfrom delira.training.backends.chainer.trainer import ChainerNetworkTrainer\n\n\nclass ChainerExperiment(BaseExperiment):\n    def __init__(self,\n                 config: typing.Union[str, DeliraConfig],\n                 model_cls: AbstractChainerNetwork,\n                 n_epochs=None,\n                 name=None,\n                 save_path=None,\n                 key_mapping=None,\n                 val_score_key=None,\n                 optim_builder=create_optims_default,\n                 checkpoint_freq=1,\n                 trainer_cls=ChainerNetworkTrainer,\n                 **kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        config : :class:`DeliraConfig` or str\n            the training config, if string is passed,\n            it is treated as a path to a file, where the\n            config is loaded from\n        model_cls : Subclass of :class:`AbstractChainerNetwork`\n            the class implementing the model to train\n        n_epochs : int or None\n            the number of epochs to train, if None: can be specified later\n            during actual training\n        name : str or None\n            the Experiment's name\n        save_path : str or None\n            the path to save the results and checkpoints to.\n            if None: Current working directory will be used\n        key_mapping : dict\n            mapping between data_dict and model inputs (necessary for\n            prediction with :class:`Predictor`-API), if no key_mapping is\n            given, a default key_mapping of {\"x\": \"data\"} will 
be used here\n        val_score_key : str or None\n            key defining which metric to use for validation (determining\n            best model and scheduling lr); if None: No validation-based\n            operations will be done (model might still get validated,\n            but validation metrics can only be logged and not used further)\n        optim_builder : function\n            Function returning a dict of backend-specific optimizers.\n            Defaults to :func:`create_optims_default`\n        checkpoint_freq : int\n            frequency of saving checkpoints (1 denotes saving every epoch,\n            2 denotes saving every second epoch etc.); default: 1\n        trainer_cls : subclass of :class:`ChainerNetworkTrainer`\n            the trainer class to use for training the model, defaults to\n            :class:`ChainerNetworkTrainer`\n        **kwargs :\n            additional keyword arguments\n\n        \"\"\"\n\n        if key_mapping is None:\n            key_mapping = {\"x\": \"data\"}\n        super().__init__(config=config, model_cls=model_cls,\n                         n_epochs=n_epochs, name=name, save_path=save_path,\n                         key_mapping=key_mapping,\n                         val_score_key=val_score_key,\n                         optim_builder=optim_builder,\n                         checkpoint_freq=checkpoint_freq,\n                         trainer_cls=trainer_cls,\n                         **kwargs)\n\n    def test(self, network: AbstractChainerNetwork,\n             test_data: DataManager,\n             metrics: dict, metric_keys=None,\n             verbose=False, prepare_batch=None,\n             convert_fn=convert_to_numpy, **kwargs):\n        \"\"\"\n        Setup and run testing on a given network\n\n        Parameters\n        ----------\n        network : :class:`AbstractChainerNetwork`\n            the (trained) network to test\n        test_data : :class:`DataManager`\n            the data to use for testing\n 
       metrics : dict\n            the metrics to calculate\n        metric_keys : dict of tuples\n            the batch_dict keys to use for each metric to calculate.\n            Should contain a value for each key in ``metrics``.\n            If no values are given for a key, per default ``pred`` and\n            ``label`` will be used for metric calculation\n        verbose : bool\n            verbosity of the test process\n        prepare_batch : function\n            function to convert a batch-dict to a format accepted by the\n            model. This conversion typically includes dtype-conversion,\n            reshaping, wrapping to backend-specific tensors and\n            pushing to correct devices. If not further specified uses the\n            ``network``'s ``prepare_batch`` with CPU devices\n        convert_fn : function\n            function to convert a batch of tensors to numpy;\n            if not specified, defaults to\n            :func:`convert_chainer_to_numpy`\n        **kwargs :\n            additional keyword arguments\n\n        Returns\n        -------\n        dict\n            all predictions obtained by feeding the ``test_data`` through\n            the ``network``\n        dict\n            all metrics calculated upon the ``test_data`` and the obtained\n            predictions\n\n        \"\"\"\n\n        # use backend-specific and model-specific prepare_batch fn\n        # (runs on same device as passed network per default)\n\n        device = network.device\n        if prepare_batch is None:\n            prepare_batch = partial(network.prepare_batch,\n                                    input_device=device,\n                                    output_device=device)\n\n        return super().test(network=network, test_data=test_data,\n                            metrics=metrics, metric_keys=metric_keys,\n                            verbose=verbose, prepare_batch=prepare_batch,\n                            
convert_fn=convert_fn, **kwargs)\n"
  },
  {
    "path": "delira/training/backends/chainer/trainer.py",
    "content": "from delira.training.backends.chainer.utils import convert_to_numpy\nfrom delira.training.backends.chainer.utils import create_optims_default\nfrom delira.training.callbacks.logging_callback import DefaultLoggingCallback\nfrom delira.io.chainer import load_checkpoint, save_checkpoint\nfrom delira.models.backends.chainer import AbstractChainerNetwork, \\\n    DataParallelChainerNetwork, \\\n    DataParallelChainerOptimizer\nfrom delira.training.base_trainer import BaseNetworkTrainer\nimport chainer\nfrom batchgenerators.dataloading import MultiThreadedAugmenter\nimport os\nimport logging\nfrom functools import partial\nlogger = logging.getLogger(__name__)\n\n\nclass ChainerNetworkTrainer(BaseNetworkTrainer):\n    \"\"\"\n    Train and Validate a Network\n\n    See Also\n    --------\n    :class:`AbstractNetwork`\n\n    \"\"\"\n\n    def __init__(self,\n                 network: AbstractChainerNetwork,\n                 save_path: str,\n                 key_mapping,\n                 losses=None,\n                 optimizer_cls=None,\n                 optimizer_params=None,\n                 metrics=None,\n                 lr_scheduler_cls=None,\n                 lr_scheduler_params=None,\n                 gpu_ids=None,\n                 save_freq=1,\n                 optim_fn=create_optims_default,\n                 logging_type=\"tensorboardx\",\n                 logging_kwargs=None,\n                 logging_callback_cls=DefaultLoggingCallback,\n                 logging_frequencies=None,\n                 logging_reduce_types=None,\n                 fold=0,\n                 callbacks=None,\n                 start_epoch=1,\n                 metric_keys=None,\n                 convert_batch_to_npy_fn=convert_to_numpy,\n                 mixed_precision=False,\n                 val_freq=1,\n                 **kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        network : :class:`AbstractChainerNetwork`\n            the network 
to train\n        save_path : str\n            path to save networks to\n        key_mapping : dict\n            a dictionary containing the mapping from the ``data_dict`` to\n            the actual model's inputs.\n            E.g. if a model accepts one input named 'x' and the data_dict\n            contains one entry named 'data' this argument would have to\n            be ``{'x': 'data'}``\n        losses : dict\n            dictionary containing the training losses\n        optimizer_cls : subclass of chainer.Optimizer\n            optimizer class implementing the optimization algorithm of\n            choice\n        optimizer_params : dict\n            keyword arguments passed to optimizer during construction\n        metrics : dict, optional\n            metrics, which will be evaluated during train and validation phase\n            (should work on numpy arrays)\n        lr_scheduler_cls : Any\n            learning rate schedule class: must implement step() method\n        lr_scheduler_params : dict\n            keyword arguments passed to lr scheduler during construction\n        gpu_ids : list\n            list containing ids of GPUs to use; if empty: use cpu instead\n        save_freq : int\n            integer specifying how often to save the current model's state.\n            State is saved every ``save_freq`` epochs\n        optim_fn : function\n            creates a dictionary containing all necessary optimizers\n        logging_type : str or callable\n            the type of logging. 
If string: it must be one of\n            [\"visdom\", \"tensorboardx\"]\n            If callable: it must be a logging handler class\n        logging_kwargs : dict\n            dictionary containing all logging keyword arguments\n        logging_callback_cls : class\n            the callback class to create and register for logging\n        logging_frequencies : int or dict\n            specifies how often to log for each key.\n            If int: integer will be applied to all valid keys\n            if dict: should contain a frequency per valid key. Missing keys\n                will be filled with a frequency of 1 (log every time)\n            None is equal to empty dict here.\n        logging_reduce_types : str, FunctionType or dict\n            if str:\n                specifies the reduction type to use. Valid types are\n                'last' | 'first' | 'mean' | 'median' | 'max' | 'min'.\n                The given type will be mapped to all valid keys.\n            if FunctionType:\n                specifies the actual reduction function. Will be applied\n                for all keys.\n            if dict: should contain pairs of valid logging keys and either\n                str or FunctionType. 
Specifies the reduction per key.\n                Missing keys will be filled with a default value of 'last'.\n                Valid types for strings are\n                'last' | 'first' | 'mean' | 'median' | 'max' | 'min'.\n        fold : int\n            current cross validation fold (0 per default)\n        callbacks : list\n            initial callbacks to register\n        start_epoch : int\n            epoch to start training at\n        metric_keys : dict\n            dict specifying which batch_dict entry to use for which metric\n            as target; default: None, which will result in key \"label\" for\n            all metrics\n        convert_batch_to_npy_fn : type, optional\n            function converting a batch-tensor to numpy, per default this\n            is a function, which detaches the tensor, moves it to cpu and\n            then calls ``.array`` on it\n        mixed_precision : bool\n            whether to use mixed precision or not (False per default)\n        val_freq : int\n            validation frequency specifying how often to validate the\n            trained model (a value of 1 denotes validating every epoch,\n            a value of 2 denotes validating every second epoch etc.);\n            defaults to 1\n        **kwargs :\n            additional keyword arguments\n\n        \"\"\"\n\n        # prevent mutable defaults\n        if callbacks is None:\n            callbacks = []\n        if logging_kwargs is None:\n            logging_kwargs = {}\n        if gpu_ids is None:\n            gpu_ids = []\n        if lr_scheduler_params is None:\n            lr_scheduler_params = {}\n        if metrics is None:\n            metrics = {}\n        if optimizer_params is None:\n            optimizer_params = {}\n\n        super().__init__(network=network,\n                         save_path=save_path,\n                         losses=losses,\n                         optimizer_cls=optimizer_cls,\n                         
optimizer_params=optimizer_params,\n                         metrics=metrics,\n                         lr_scheduler_cls=lr_scheduler_cls,\n                         lr_scheduler_params=lr_scheduler_params,\n                         gpu_ids=gpu_ids,\n                         save_freq=save_freq,\n                         optim_fn=optim_fn,\n                         key_mapping=key_mapping,\n                         logging_type=logging_type,\n                         logging_kwargs=logging_kwargs,\n                         logging_callback_cls=logging_callback_cls,\n                         logging_frequencies=logging_frequencies,\n                         logging_reduce_types=logging_reduce_types,\n                         fold=fold,\n                         callbacks=callbacks,\n                         start_epoch=start_epoch,\n                         metric_keys=metric_keys,\n                         convert_batch_to_npy_fn=convert_batch_to_npy_fn,\n                         val_freq=val_freq,\n                         **kwargs\n                         )\n\n        self._setup(network, optim_fn, optimizer_cls, optimizer_params,\n                    lr_scheduler_cls, lr_scheduler_params, gpu_ids,\n                    key_mapping, convert_batch_to_npy_fn,\n                    mixed_precision, callbacks)\n\n        for key, val in kwargs.items():\n            setattr(self, key, val)\n\n    def _setup(self, network, optim_fn, optimizer_cls, optimizer_params,\n               lr_scheduler_cls, lr_scheduler_params, gpu_ids,\n               key_mapping, convert_batch_to_npy_fn, mixed_precision,\n               callbacks):\n        \"\"\"\n        Defines the Trainer's Setup\n\n        Parameters\n        ----------\n        network : :class:`AbstractChainerNetwork`\n            the network to train\n        optim_fn : function\n            creates a dictionary containing all necessary optimizers\n        optimizer_cls : subclass of chainer.Optimizer\n            
optimizer class implementing the optimization algorithm of\n            choice\n        optimizer_params : dict\n            keyword arguments passed to optimizer during construction\n        lr_scheduler_cls : Any\n            learning rate schedule class: must implement step() method\n        lr_scheduler_params : dict\n            keyword arguments passed to lr scheduler during construction\n        gpu_ids : list\n            list containing ids of GPUs to use; if empty: use cpu instead\n        convert_batch_to_npy_fn : type\n            function converting a batch-tensor to numpy\n        mixed_precision : bool\n            whether to use mixed precision or not (False per default)\n        callbacks : list\n            initial callbacks to register\n\n        \"\"\"\n\n        self.optimizers = optim_fn(network, optimizer_cls,\n                                   **optimizer_params)\n\n        super()._setup(network, None, lr_scheduler_params,\n                       gpu_ids, key_mapping, convert_batch_to_npy_fn,\n                       network.prepare_batch, callbacks)\n\n        if mixed_precision:\n            # enable chainer mixed precision globally\n            chainer.global_config.dtype = chainer.mixed16\n\n        # Load latest epoch file if available\n        if os.path.isdir(self.save_path):\n            latest_state_path, latest_epoch = self._search_for_prev_state(\n                self.save_path)\n\n            if latest_state_path is not None:\n\n                logger.info(\"Attempting to load state from previous \"\n                            \"training from %s\" % latest_state_path)\n\n                self.update_state(latest_state_path)\n                self.start_epoch = latest_epoch\n\n        if chainer.chainerx.is_available():\n            gpu_device_prefix = \"cuda:\"\n            cpu_device_prefix = \"native\"\n       
 else:\n            gpu_device_prefix = \"@cupy:\"\n            cpu_device_prefix = \"@numpy\"\n\n        if gpu_ids:\n            try:\n                if chainer.cuda.check_cuda_available():\n                    self.use_gpu = True\n                    if len(gpu_ids) > 1:\n                        # use GPU 0 as default input GPU\n\n                        self.input_device = chainer.get_device(\n                            gpu_device_prefix + str(gpu_ids[0]))\n\n                        # Train on multiple GPUs and use GPU 0 as output\n                        # device\n                        self.module = DataParallelChainerNetwork(\n                            self.module.to_device(\"@numpy\"),\n                            devices=[chainer.get_device(\n                                gpu_device_prefix + str(_id))\n                                for _id in gpu_ids])\n\n                        # ToDo: Creating Multiple DataParallelOptimizers is\n                        #  kinda tricky right now, since we need to add the\n                        #  class itself to the parameters and use\n                        #  DataParallelOptimizer as optimizer class.\n                        #  Should look for other possibility,\n                        #  but currently I don't know any\n                        self.optimizers = optim_fn(\n                            DataParallelChainerOptimizer,\n                            {**optimizer_params,\n                             \"optim_cls\": optimizer_cls})\n\n                        self.output_device = chainer.get_device(\n                            gpu_device_prefix + str(gpu_ids[0]))\n                    else:\n                        # use the only available GPU as input device\n                        self.input_device = chainer.get_device(\n                            gpu_device_prefix + str(gpu_ids[0]))\n                        self.module = self.module.to_device(\n                            self.input_device)\n\n                      
  # use the only available GPU as output device\n                        self.output_device = chainer.get_device(\n                            gpu_device_prefix + str(gpu_ids[0]))\n                else:\n                    # cuda unavailable -> no GPU support\n                    self.use_gpu = False\n                    self.input_device = chainer.get_device(\n                        cpu_device_prefix)\n                    self.output_device = chainer.get_device(\n                        cpu_device_prefix)\n                    self.module = self.module.to_device(self.input_device)\n\n            # thrown if Cupy is unavailable -> no GPU support\n            except RuntimeError as e:\n                logging.exception(e)\n                self.use_gpu = False\n                self.input_device = chainer.get_device(cpu_device_prefix)\n                self.output_device = chainer.get_device(cpu_device_prefix)\n                self.module = self.module.to_device(self.input_device)\n\n        # no gpu indices given\n        else:\n            self.use_gpu = False\n            self.input_device = chainer.get_device(cpu_device_prefix)\n            self.output_device = chainer.get_device(cpu_device_prefix)\n            self.module = self.module.to_device(self.input_device)\n\n        self._prepare_batch = partial(\n            self._prepare_batch, input_device=self.input_device,\n            output_device=self.output_device)\n\n    def _at_training_begin(self, *args, **kwargs):\n        \"\"\"\n        Defines behaviour at beginning of training\n\n        Parameters\n        ----------\n        *args :\n            positional arguments\n        **kwargs :\n            keyword arguments\n\n        \"\"\"\n        for cbck in self._callbacks:\n            self._update_state(cbck.at_training_begin(self, *args, **kwargs))\n\n        self.save_state(os.path.join(\n            self.save_path, \"checkpoint_epoch_%d\" % self.start_epoch),\n            self.start_epoch)\n\n    def 
_at_training_end(self, *args, **kwargs):\n        \"\"\"\n        Defines Behaviour at end of training: Loads best model if\n        available\n\n        Returns\n        -------\n        :class:`AbstractChainerNetwork`\n            best network\n\n        \"\"\"\n        if os.path.isfile(os.path.join(self.save_path,\n                                       'checkpoint_best.chain')):\n\n            # load best model and return it\n            self.update_state(os.path.join(self.save_path,\n                                           'checkpoint_best.chain'))\n\n        return super()._at_training_end(*args, **kwargs)\n\n    def _at_epoch_end(self, metrics_val, val_score_key, epoch, is_best,\n                      **kwargs):\n        \"\"\"\n        Defines behaviour at the end of each epoch: Executes every\n        callback's `at_epoch_end` method and saves the current state if\n        necessary\n\n        Parameters\n        ----------\n        metrics_val : dict\n            validation metrics\n        val_score_key : str\n            validation score key\n        epoch : int\n            current epoch\n        is_best : bool\n            whether current model is best one so far\n        **kwargs :\n            keyword arguments\n\n        \"\"\"\n\n        for cb in self._callbacks:\n\n            self._update_state(cb.at_epoch_end(self,\n                                               val_metrics=metrics_val,\n                                               val_score_key=val_score_key,\n                                               curr_epoch=epoch))\n\n        if epoch % self.save_freq == 0:\n            self.save_state(\n                os.path.join(\n                    self.save_path,\n                    \"checkpoint_epoch_%d.chain\" %\n                    epoch),\n                epoch)\n\n        if is_best:\n            self.save_state(os.path.join(self.save_path,\n                                
         \"checkpoint_best.chain\"),\n                            epoch)\n\n    def _train_single_epoch(self, batchgen: MultiThreadedAugmenter, epoch,\n                            verbose=False):\n        \"\"\"\n        Trains the network a single epoch\n\n        Parameters\n        ----------\n        batchgen : MultiThreadedAugmenter\n            Generator yielding the training batches\n        epoch : int\n            current epoch\n\n        \"\"\"\n\n        chainer.global_config.train = True\n\n        return super()._train_single_epoch(batchgen, epoch,\n                                           verbose=verbose)\n\n    def predict_data_mgr(self, datamgr, batchsize=None, metrics={},\n                         metric_keys={}, verbose=False, **kwargs):\n        \"\"\"\n        Defines a routine to predict data obtained from a batchgenerator\n\n        Parameters\n        ----------\n        datamgr : :class:`DataManager`\n            Manager producing a generator holding the batches\n        batchsize : int\n            Artificial batchsize (sampling will be done with batchsize\n            1 and sampled data will be stacked to match the artificial\n            batchsize)(default: None)\n        metrics : dict\n            the metrics to calculate\n        metric_keys : dict\n            the ``batch_dict`` items to use for metric calculation\n        verbose : bool\n            whether to show a progress-bar or not, default: False\n        **kwargs :\n            additional keyword arguments\n\n        Returns\n        -------\n        dict\n            predictions\n        dict\n            calculated metrics\n\n        \"\"\"\n        chainer.global_config.train = False\n\n        return super().predict_data_mgr(datamgr, batchsize, metrics,\n                                        metric_keys, verbose, **kwargs)\n\n    def save_state(self, file_name, epoch, **kwargs):\n        \"\"\"\n        saves the current state via\n        
:func:`delira.io.chainer.save_checkpoint`\n\n        Parameters\n        ----------\n        file_name : str\n            filename to save the state to\n        epoch : int\n            current epoch (will be saved for mapping back)\n        **kwargs :\n            keyword arguments\n\n        \"\"\"\n        if not file_name.endswith(\".chain\"):\n            file_name = file_name + \".chain\"\n        save_checkpoint(file_name, self.module, self.optimizers,\n                        **kwargs)\n\n    @staticmethod\n    def load_state(file_name, **kwargs):\n        \"\"\"\n        Loads the new state from file via\n        :func:`delira.io.chainer.load_checkpoint`\n\n        Parameters\n        ----------\n        file_name : str\n            the file to load the state from\n        **kwargs : keyword arguments\n\n        Returns\n        -------\n        dict\n            new state\n\n        \"\"\"\n\n        if not file_name.endswith(\".chain\"):\n            file_name = file_name + \".chain\"\n\n        return load_checkpoint(file_name, **kwargs)\n\n    def update_state(self, file_name, *args, **kwargs):\n        \"\"\"\n        Update internal state from a loaded state\n\n        Parameters\n        ----------\n        file_name : str\n            file containing the new state to load\n        *args :\n            positional arguments\n        **kwargs :\n            keyword arguments\n\n        Returns\n        -------\n        :class:`BaseNetworkTrainer`\n            the trainer with a modified state\n\n        \"\"\"\n        return self._update_state(self.load_state(file_name,\n                                           old_state={\n                                               \"model\": self.module,\n                                               \"optimizers\": self.optimizers},\n                                           **kwargs))\n\n    def _update_state(self, new_state):\n        \"\"\"\n        Update the 
state from a given new state\n\n        Parameters\n        ----------\n        new_state : dict\n            new state to update internal state from\n\n        Returns\n        -------\n        :class:`ChainerNetworkTrainer`\n            the trainer with a modified state\n\n        \"\"\"\n\n        if \"model\" in new_state:\n            self.module = new_state.pop(\"model\")\n\n        if \"optimizers\" in new_state and new_state[\"optimizers\"]:\n            self.optimizers = new_state.pop(\"optimizers\")\n\n        if \"epoch\" in new_state:\n            self.start_epoch = new_state.pop(\"epoch\")\n\n        return super()._update_state(new_state)\n\n    @staticmethod\n    def _search_for_prev_state(path, extensions=None):\n        \"\"\"\n        Helper function to search in a given path for previous epoch states\n        (indicated by extensions)\n\n        Parameters\n        ----------\n        path : str\n            the path to search in\n        extensions : list\n            list of strings containing valid file extensions for checkpoint\n            files\n\n        Returns\n        -------\n        str\n            the file containing the latest checkpoint (if available)\n        None\n            if no latest checkpoint was found\n        int\n            the latest epoch (1 if no checkpoint was found)\n\n        \"\"\"\n        if extensions is None:\n            extensions = [\".chain\"]\n        return BaseNetworkTrainer._search_for_prev_state(path, extensions)\n"
  },
  {
    "path": "delira/training/backends/chainer/utils.py",
    "content": "import chainer\nfrom delira.models.backends.chainer import DataParallelChainerOptimizer\nfrom delira.training.utils import convert_to_numpy_identity, \\\n    recursively_convert_elements\n\n\ndef _single_element_tensor_conversion(element):\n    element.to_cpu()\n    return element.array\n\n\ndef convert_to_numpy(*args, **kwargs):\n    \"\"\"\n    Converts all chainer variables in args and kwargs to numpy array\n\n    Parameters\n    ----------\n    *args :\n        positional arguments of arbitrary number and type\n    **kwargs :\n        keyword arguments of arbitrary number and type\n\n    Returns\n    -------\n    list\n        converted positional arguments\n    dict\n        converted keyboard arguments\n    \"\"\"\n    args = recursively_convert_elements(args, chainer.Variable,\n                                        _single_element_tensor_conversion)\n\n    kwargs = recursively_convert_elements(kwargs, chainer.Variable,\n                                          _single_element_tensor_conversion)\n\n    return convert_to_numpy_identity(*args, **kwargs)\n\n\ndef create_optims_default(model, optim_cls, **optimizer_params):\n    \"\"\"\n    Default function to create a single optimizer for chainer\n    (also supports Data-Parallel)\n\n    Parameters\n    ----------\n    model : :class:`chainer.Link`\n        the model, which should be updated by the optimizer\n    optim_cls : type\n        the optimizer class implementing the actual parameter update\n    optimizer_params : dict\n        the params used for initializing an instance of ``optim_cls``\n\n    Returns\n    -------\n    dict\n        dictionary containing the created optimizer (key: \"default\")\n\n    \"\"\"\n    if issubclass(optim_cls, DataParallelChainerOptimizer):\n        optim = optim_cls.from_optimizer_class(**optimizer_params)\n\n    else:\n        optim = optim_cls(**optimizer_params)\n\n    optim = optim.setup(model)\n\n    return {\"default\": optim}\n"
  },
  {
    "path": "delira/training/backends/sklearn/__init__.py",
    "content": "from delira import get_backends as _get_backends\n\nif \"SKLEARN\" in _get_backends():\n    from delira.training.backends.sklearn.trainer import \\\n        SklearnEstimatorTrainer\n    from delira.training.backends.sklearn.experiment import SklearnExperiment\n    from delira.training.backends.sklearn.utils import create_optims_default \\\n        as create_sklearn_optims_default\n"
  },
  {
    "path": "delira/training/backends/sklearn/experiment.py",
    "content": "from functools import partial\nimport typing\nimport os\n\nfrom sklearn.base import BaseEstimator\n\nfrom delira.models.backends.sklearn import SklearnEstimator\n\nfrom delira.training.base_experiment import BaseExperiment\nfrom delira.utils import DeliraConfig\n\nfrom delira.training.backends.sklearn.trainer import SklearnEstimatorTrainer\n\n\nclass SklearnExperiment(BaseExperiment):\n    def __init__(self,\n                 config: typing.Union[str, DeliraConfig],\n                 model_cls: BaseEstimator,\n                 n_epochs=None,\n                 name=None,\n                 save_path=None,\n                 key_mapping=None,\n                 val_score_key=None,\n                 checkpoint_freq=1,\n                 trainer_cls=SklearnEstimatorTrainer,\n                 model_wrapper_cls=SklearnEstimator,\n                 **kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        config : :class:`DeliraConfig` or str\n            the training config, if string is passed,\n            it is treated as a path to a file, where the\n            config is loaded from\n        model_cls : Subclass of :class:`sklearn.base.BaseEstimator`\n            the class implementing the model to train (will be wrapped by\n            :class:`SkLearnEstimator`)\n        n_epochs : int or None\n            the number of epochs to train, if None: can be specified later\n            during actual training\n        name : str or None\n            the Experiment's name\n        save_path : str or None\n            the path to save the results and checkpoints to.\n            if None: Current working directory will be used\n        key_mapping : dict\n            mapping between data_dict and model inputs (necessary for\n            prediction with :class:`Predictor`-API), if no keymapping is\n            given, a default key_mapping of {\"X\": \"X\"} will be used here\n        checkpoint_freq : int\n            frequency of saving 
checkpoints (1 denotes saving every epoch,\n            2 denotes saving every second epoch etc.); default: 1\n        trainer_cls : subclass of :class:`SklearnEstimatorTrainer`\n            the trainer class to use for training the model, defaults to\n            :class:`SklearnEstimatorTrainer`\n        model_wrapper_cls : subclass of :class:`SklearnEstimator`\n            class wrapping the actual sklearn model to provide delira\n            compatibility\n        **kwargs :\n            additional keyword arguments\n\n        \"\"\"\n\n        if key_mapping is None:\n            key_mapping = {\"X\": \"X\"}\n\n        super().__init__(config=config,\n                         model_cls=model_cls,\n                         n_epochs=n_epochs,\n                         name=name,\n                         save_path=save_path,\n                         key_mapping=key_mapping,\n                         val_score_key=val_score_key,\n                         checkpoint_freq=checkpoint_freq,\n                         trainer_cls=trainer_cls,\n                         **kwargs)\n        self._model_wrapper_cls = model_wrapper_cls\n\n    def _setup_training(self, config, **kwargs):\n        \"\"\"\n        Handles the setup for the training case\n\n        Parameters\n        ----------\n        config : :class:`DeliraConfig`\n            the config containing the model and training kwargs\n        **kwargs :\n            additional keyword arguments\n\n        Returns\n        -------\n        :class:`BaseNetworkTrainer`\n            the created trainer\n        \"\"\"\n        model_kwargs = config.model_params\n        model_kwargs = {**model_kwargs[\"variable\"], **model_kwargs[\"fixed\"]}\n\n        _model = self.model_cls(**model_kwargs)\n        model = self._model_wrapper_cls(_model)\n\n        training_params = config.training_params\n        metrics = training_params.nested_get(\"metrics\")\n\n        # necessary for 
resuming training from a given path\n        save_path = kwargs.pop(\"save_path\", os.path.join(\n            self.save_path,\n            \"checkpoints\",\n            \"run_%02d\" % self._run))\n\n        return self.trainer_cls(\n            estimator=model,\n            save_path=save_path,\n            key_mapping=self.key_mapping,\n            metrics=metrics,\n            save_freq=self.checkpoint_freq,\n            **kwargs\n        )\n\n    def _setup_test(self, config, model, convert_batch_to_npy_fn,\n                    prepare_batch_fn, **kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        config : :class:`DeliraConfig`\n            the config containing the model and training kwargs\n            (ignored here, just passed for subclassing and unified API)\n        model : :class:`sklearn.base.BaseEstimator`\n            the model to test\n        convert_batch_to_npy_fn : function\n            function to convert a batch of tensors to numpy\n        prepare_batch_fn : function\n            function to convert a batch-dict to a format accepted by the\n            model. This conversion typically includes dtype-conversion,\n            reshaping, wrapping to backend-specific tensors and pushing to\n            correct devices\n        **kwargs :\n            additional keyword arguments\n\n        Returns\n        -------\n        :class:`Predictor`\n            the created predictor\n\n        \"\"\"\n        if not isinstance(model, SklearnEstimator):\n            model = SklearnEstimator(model)\n\n        if prepare_batch_fn is None:\n            prepare_batch_fn = partial(model.prepare_batch,\n                                       input_device=\"cpu\",\n                                       output_device=\"cpu\")\n\n        return super()._setup_test(config, model, convert_batch_to_npy_fn,\n                                   prepare_batch_fn, **kwargs)\n"
  },
  {
    "path": "delira/training/backends/sklearn/trainer.py",
    "content": "from delira.training.backends.sklearn.utils import create_optims_default\nfrom delira.training.utils import convert_to_numpy_identity as \\\n    convert_to_numpy\nfrom delira.training.base_trainer import BaseNetworkTrainer\nfrom delira.io.sklearn import save_checkpoint, load_checkpoint\nfrom delira.models.backends.sklearn import SklearnEstimator\nfrom delira.data_loading import DataManager\nfrom delira.data_loading.sampler import RandomSamplerWithReplacement, \\\n    RandomSamplerNoReplacement\nfrom delira.training.callbacks.logging_callback import DefaultLoggingCallback\nimport os\nimport logging\nimport numpy as np\nfrom tqdm.auto import tqdm\nfrom functools import partial\n\nlogger = logging.getLogger(__name__)\n\n\nclass SklearnEstimatorTrainer(BaseNetworkTrainer):\n    \"\"\"\n    Train and Validate a ``sklearn`` estimator\n\n    See Also\n    --------\n    :class:`SkLearnEstimator`\n\n    \"\"\"\n\n    def __init__(self,\n                 estimator: SklearnEstimator,\n                 save_path: str,\n                 key_mapping,\n                 metrics=None,\n                 save_freq=1,\n                 logging_type=\"tensorboardx\",\n                 logging_kwargs=None,\n                 logging_callback_cls=DefaultLoggingCallback,\n                 logging_frequencies=None,\n                 logging_reduce_types=None,\n                 fold=0,\n                 callbacks=None,\n                 start_epoch=1,\n                 metric_keys=None,\n                 convert_batch_to_npy_fn=convert_to_numpy,\n                 val_freq=1,\n                 ** kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        estimator : :class:`SklearnEstimator`\n            the estimator to train\n        save_path : str\n            path to save networks to\n        key_mapping : dict\n            a dictionary containing the mapping from the ``data_dict`` to\n            the actual model's inputs.\n            E.g. 
if a model accepts one input named 'x' and the data_dict\n            contains one entry named 'data' this argument would have to\n            be ``{'x': 'data'}``\n        metrics : dict, optional\n            metrics, which will be evaluated during train and validation phase\n            (should work on numpy arrays)\n        save_freq : int\n            integer specifying how often to save the current model's state.\n            The state is saved every ``save_freq`` epochs\n        logging_type : str or callable\n            the type of logging. If string: it must be one of\n            [\"visdom\", \"tensorboardx\"]\n            If callable: it must be a logging handler class\n        logging_kwargs : dict\n            dictionary containing all logging keyword arguments\n        logging_callback_cls : class\n            the callback class to create and register for logging\n        logging_frequencies : int or dict\n            specifies how often to log for each key.\n            If int: integer will be applied to all valid keys\n            if dict: should contain a frequency per valid key. Missing keys\n                will be filled with a frequency of 1 (log every time)\n            None is equal to empty dict here.\n        logging_reduce_types : str or FunctionType or dict\n            if str:\n                specifies the reduction type to use. Valid types are\n                'last' | 'first' | 'mean' | 'median' | 'max' | 'min'.\n                The given type will be mapped to all valid keys.\n            if FunctionType:\n                specifies the actual reduction function. Will be applied\n                for all keys.\n            if dict: should contain pairs of valid logging keys and either\n                str or FunctionType. 
Specifies the logging value per key.\n                Missing keys will be filled with a default value of 'last'.\n                Valid types for strings are\n                'last' | 'first' | 'mean' | 'median' | 'max' | 'min'.\n        fold : int\n            current cross validation fold (0 per default)\n        callbacks : list\n            initial callbacks to register\n        start_epoch : int\n            epoch to start training at\n        metric_keys : dict\n            dict specifying which batch_dict entry to use for which metric as\n            target; default: None, which will result in key \"label\" for all\n            metrics\n        convert_batch_to_npy_fn : type, optional\n            function converting a batch-tensor to numpy, per default this is\n            a function, returning the inputs without changing anything\n        val_freq : int\n            validation frequency specifying how often to validate the trained\n            model (a value of 1 denotes validating every epoch,\n            a value of 2 denotes validating every second epoch etc.);\n            defaults to 1\n        **kwargs :\n            additional keyword arguments\n\n        \"\"\"\n        # prevent mutable defaults\n        if callbacks is None:\n            callbacks = []\n        if logging_kwargs is None:\n            logging_kwargs = {}\n        if metrics is None:\n            metrics = {}\n\n        super().__init__(network=estimator,\n                         save_path=save_path,\n                         losses={},\n                         optimizer_cls=None,\n                         optimizer_params={},\n                         metrics=metrics,\n                         lr_scheduler_cls=None,\n                         lr_scheduler_params={},\n                         gpu_ids=[],\n                         save_freq=save_freq,\n                         optim_fn=create_optims_default,\n                         key_mapping=key_mapping,\n
                         logging_type=logging_type,\n                         logging_kwargs=logging_kwargs,\n                         logging_callback_cls=logging_callback_cls,\n                         logging_frequencies=logging_frequencies,\n                         logging_reduce_types=logging_reduce_types,\n                         fold=fold,\n                         callbacks=callbacks,\n                         start_epoch=start_epoch,\n                         metric_keys=metric_keys,\n                         convert_batch_to_npy_fn=convert_batch_to_npy_fn,\n                         val_freq=val_freq,\n                         **kwargs\n                         )\n\n        self._setup(estimator,\n                    key_mapping, convert_batch_to_npy_fn, callbacks)\n\n        for key, val in kwargs.items():\n            setattr(self, key, val)\n\n    def _setup(self, estimator, key_mapping, convert_batch_to_npy_fn,\n               callbacks):\n        \"\"\"\n        Defines the Trainer's setup\n\n        Parameters\n        ----------\n        estimator : :class:`SklearnEstimator`\n            the estimator to train\n        key_mapping : dict\n            a dictionary containing the mapping from the ``data_dict`` to\n            the actual model's inputs.\n            E.g. 
if a model accepts one input named 'x' and the data_dict\n            contains one entry named 'data' this argument would have to\n            be ``{'x': 'data'}``\n        convert_batch_to_npy_fn : type\n            function converting a batch-tensor to numpy\n        callbacks : list\n            initial callbacks to register\n\n        \"\"\"\n\n        self.optimizers = create_optims_default()\n\n        super()._setup(estimator, None, {},\n                       [], key_mapping, convert_batch_to_npy_fn,\n                       estimator.prepare_batch, callbacks)\n\n        # Load latest epoch file if available\n        if os.path.isdir(self.save_path):\n            # check all files in directory starting with \"checkpoint\" and\n            # not ending with \"_best.pkl\"\n            latest_state_path, latest_epoch = self._search_for_prev_state(\n                self.save_path)\n\n            # if list is not empty: load previous state\n            if latest_state_path is not None:\n                self.update_state(latest_state_path)\n\n                self.start_epoch = latest_epoch\n\n        self.use_gpu = False\n        self.input_device = \"cpu\"\n        self.output_device = \"cpu\"\n\n        self._prepare_batch = partial(\n            self._prepare_batch, input_device=self.input_device,\n            output_device=self.output_device)\n\n    def _at_training_begin(self, *args, **kwargs):\n        \"\"\"\n        Defines behaviour at beginning of training\n\n        Parameters\n        ----------\n        *args :\n            positional arguments\n        **kwargs :\n            keyword arguments\n\n        \"\"\"\n        for cbck in self._callbacks:\n            self._update_state(cbck.at_training_begin(self, *args, **kwargs))\n\n        self.save_state(os.path.join(\n            self.save_path, \"checkpoint_epoch_%d\" % self.start_epoch),\n            self.start_epoch)\n\n    def _at_training_end(self, *args, **kwargs):\n        \"\"\"\n        
Defines behaviour at end of training: Loads best model if\n        available\n\n        Returns\n        -------\n        :class:`SklearnEstimator`\n            best network\n\n        \"\"\"\n        if os.path.isfile(os.path.join(self.save_path,\n                                       'checkpoint_best.pkl')):\n\n            # load best model and return it\n            self.update_state(os.path.join(self.save_path,\n                                           'checkpoint_best.pkl'))\n\n        return super()._at_training_end(*args, **kwargs)\n\n    def _at_epoch_end(self, metrics_val, val_score_key, epoch, is_best,\n                      **kwargs):\n        \"\"\"\n        Defines behaviour at the end of each epoch: Executes all callbacks'\n        `at_epoch_end` method and saves current state if necessary\n\n        Parameters\n        ----------\n        metrics_val : dict\n            validation metrics\n        val_score_key : str\n            validation score key\n        epoch : int\n            current epoch\n        is_best : bool\n            whether current model is best one so far\n        **kwargs :\n            keyword arguments\n\n        \"\"\"\n\n        for cb in self._callbacks:\n            self._update_state(cb.at_epoch_end(self,\n                                               val_metrics=metrics_val,\n                                               val_score_key=val_score_key,\n                                               curr_epoch=epoch))\n\n        if epoch % self.save_freq == 0:\n            self.save_state(os.path.join(self.save_path,\n                                         \"checkpoint_epoch_%d.pkl\"\n                                         % epoch),\n                            epoch)\n\n        if is_best:\n            self.save_state(os.path.join(self.save_path,\n                                         \"checkpoint_best.pkl\"),\n                            epoch)\n\n 
   def _get_classes_if_necessary(self, dmgr: DataManager, verbose,\n                                  label_key=None):\n        \"\"\"\n        Checks if available classes have to be collected before starting\n        the training to dynamically build the estimator (not all batches\n        contain all classes) and collects them if necessary\n\n        Parameters\n        ----------\n        dmgr : :class:`DataManager`\n            the datamanager to collect the classes from\n        verbose : bool\n            verbosity\n        label_key : str or None\n            the key corresponding to the target value inside the data dict\n\n        \"\"\"\n\n        if label_key is None or not hasattr(self.module, \"classes\"):\n            return\n        dset = dmgr.dataset\n\n        if verbose:\n            iterable = tqdm(enumerate(dset), unit=' sample', total=len(\n                dset), desc=\"Creating unique targets to estimate \" \"classes\")\n\n        else:\n            iterable = enumerate(dset)\n\n        unique_targets = []\n\n        # iterate over dataset\n        for sample_idx, sample in iterable:\n            item = sample[label_key]\n            if item not in unique_targets:\n\n                # convert item if necessary\n                if np.isscalar(item):\n                    item = np.array([item])\n                unique_targets.append(item)\n\n        # sort and concatenate items and feed variable inside the module\n        unique_targets = np.concatenate(list(sorted(unique_targets)))\n        self.module.classes = unique_targets\n\n    def train(self, num_epochs, datamgr_train, datamgr_valid=None,\n              val_score_key=None, val_score_mode='highest',\n              reduce_mode='mean', verbose=True, label_key=\"label\"):\n        \"\"\"\n        Defines a routine to train a specified number of epochs\n\n        Parameters\n        ----------\n        num_epochs : int\n            number of epochs to train\n        datamgr_train : 
DataManager\n            the datamanager holding the train data\n        datamgr_valid : DataManager\n            the datamanager holding the validation data (default: None)\n        val_score_key : str\n            the key specifying which metric to use for validation\n            (default: None)\n        val_score_mode : str\n            key specifying what kind of validation score is best\n        reduce_mode : str\n            'mean' | 'sum' | 'first_only'\n        verbose : bool\n            whether to show progress bars or not\n        label_key : str or None\n            key specifying the value inside the batch dict to use for\n            class collection if necessary\n\n        \"\"\"\n        if self.module.iterative_training:\n\n            # estimate classes from validation data if available,\n            # otherwise from training data\n            if datamgr_valid is not None:\n                self._get_classes_if_necessary(datamgr_valid, verbose,\n                                               label_key)\n            else:\n                self._get_classes_if_necessary(datamgr_train, verbose,\n                                               label_key)\n        else:\n            # Setting batchsize to length of dataset and replacing a random\n            # sampler with replacement by a random sampler without\n            # replacement ensures that each sample is present in each\n            # batch and only one batch is sampled per epoch\n            datamgr_train.batchsize = len(datamgr_train.dataset)\n            if issubclass(datamgr_train.sampler_cls,\n                          RandomSamplerWithReplacement):\n                datamgr_train.sampler_cls = RandomSamplerNoReplacement\n\n            # additionally setting the number of epochs to train ensures\n            # that only one epoch consisting of one batch (which holds the\n            # whole dataset) is used for training\n            if 
num_epochs > 1:\n\n                logging.info(\n                    \"An epoch number greater than 1 is given, \"\n                    \"but the current module does not support \"\n                    \"iterative training. Falling back to usual \"\n                    \"dataset fitting. For huge datasets, this \"\n                    \"might easily result in out of memory errors!\")\n\n                num_epochs = 1\n\n        return super().train(num_epochs, datamgr_train, datamgr_valid,\n                             val_score_key, val_score_mode, reduce_mode,\n                             verbose)\n\n    def save_state(self, file_name, epoch, **kwargs):\n        \"\"\"\n        Saves the current state via\n        :func:`delira.io.sklearn.save_checkpoint`\n\n        Parameters\n        ----------\n        file_name : str\n            filename to save the state to\n        epoch : int\n            current epoch (will be saved for mapping back)\n        **kwargs :\n            keyword arguments\n\n        \"\"\"\n        if not file_name.endswith(\".pkl\"):\n            file_name = file_name + \".pkl\"\n        save_checkpoint(file_name, self.module, epoch, **kwargs)\n\n    @staticmethod\n    def load_state(file_name, *args, **kwargs):\n        \"\"\"\n        Loads the new state from file via\n        :func:`delira.io.sklearn.load_checkpoint`\n\n        Parameters\n        ----------\n        file_name : str\n            the file to load the state from\n        **kwargs : keyword arguments\n\n        Returns\n        -------\n        dict\n            new state\n\n        \"\"\"\n\n        if not file_name.endswith(\".pkl\"):\n            file_name = file_name + \".pkl\"\n\n        return load_checkpoint(file_name, **kwargs)\n\n    def _update_state(self, new_state):\n        \"\"\"\n        Update the state from a given new state\n\n        Parameters\n        ----------\n        new_state : dict\n            
new state to update internal state from\n\n        Returns\n        -------\n        :class:`SklearnEstimatorTrainer`\n            the trainer with a modified state\n\n        \"\"\"\n\n        if \"model\" in new_state:\n            self.module = new_state.pop(\"model\")\n\n        if \"epoch\" in new_state:\n            self.start_epoch = new_state.pop(\"epoch\")\n\n        return super()._update_state(new_state)\n\n    @staticmethod\n    def _search_for_prev_state(path, extensions=None):\n        \"\"\"\n        Helper function to search in a given path for previous epoch states\n        (indicated by extensions)\n\n        Parameters\n        ----------\n        path : str\n            the path to search in\n        extensions : list\n            list of strings containing valid file extensions for checkpoint\n            files\n\n        Returns\n        -------\n        str\n            the file containing the latest checkpoint (if available)\n        None\n            if no latest checkpoint was found\n        int\n            the latest epoch (1 if no checkpoint was found)\n\n        \"\"\"\n        if extensions is None:\n            extensions = [\".pkl\"]\n        return BaseNetworkTrainer._search_for_prev_state(path, extensions)\n\n    @staticmethod\n    def calc_metrics(batch, metrics: dict = None, metric_keys=None):\n        if metrics is None:\n            metrics = {}\n\n        if metric_keys is None:\n            metric_keys = {k: (\"pred\", \"y\") for k in metrics.keys()}\n\n        return BaseNetworkTrainer.calc_metrics(batch, metrics, metric_keys)\n"
  },
  {
    "path": "delira/training/backends/sklearn/utils.py",
    "content": "def create_optims_default(*args, **kwargs):\n    \"\"\"\n    Function returning an empty optimizer dict\n\n    Parameters\n    ----------\n    *args :\n        arbitrary positional arguments (ignored; only provided for api\n        conformity)\n    **kwargs :\n        arbitrary keyword arguments (ignored; only provided for api conformity)\n\n    Returns\n    -------\n    dict\n        empty dictionary\n\n    \"\"\"\n    return {}\n"
  },
  {
    "path": "delira/training/backends/tf_eager/__init__.py",
    "content": "from delira import get_backends as _get_backends\n\nif \"TF\" in _get_backends():\n    from delira.training.backends.tf_eager.experiment import TfEagerExperiment\n    from delira.training.backends.tf_eager.trainer import TfEagerNetworkTrainer\n    from delira.training.backends.tf_eager.utils import convert_to_numpy \\\n        as convert_tfeager_to_numpy\n    from delira.training.backends.tf_eager.utils import create_optims_default \\\n        as create_tfeager_optims_default\n"
  },
  {
    "path": "delira/training/backends/tf_eager/experiment.py",
    "content": "import typing\nfrom functools import partial\n\nimport tensorflow as tf\n\nfrom delira.data_loading import DataManager\nfrom delira.models.backends.tf_eager import AbstractTfEagerNetwork\n\nfrom delira.training.base_experiment import BaseExperiment\nfrom delira.utils import DeliraConfig\n\nfrom delira.training.backends.tf_eager.trainer import TfEagerNetworkTrainer\nfrom delira.training.backends.tf_eager.utils import create_optims_default\nfrom delira.training.backends.tf_eager.utils import convert_to_numpy\n\n\nclass TfEagerExperiment(BaseExperiment):\n    def __init__(self,\n                 config: typing.Union[str, DeliraConfig],\n                 model_cls: AbstractTfEagerNetwork,\n                 n_epochs=None,\n                 name=None,\n                 save_path=None,\n                 key_mapping=None,\n                 val_score_key=None,\n                 optim_builder=create_optims_default,\n                 checkpoint_freq=1,\n                 trainer_cls=TfEagerNetworkTrainer,\n                 **kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        config : :class:`DeliraConfig` or str\n            the training config, if string is passed,\n            it is treated as a path to a file, where the\n            config is loaded from\n        model_cls : Subclass of :class:`AbstractTfEagerNetwork`\n            the class implementing the model to train\n        n_epochs : int or None\n            the number of epochs to train, if None: can be specified later\n            during actual training\n        name : str or None\n            the Experiment's name\n        save_path : str or None\n            the path to save the results and checkpoints to.\n            if None: Current working directory will be used\n        key_mapping : dict\n            mapping between data_dict and model inputs (necessary for\n            prediction with :class:`Predictor`-API), if no keymapping is\n            given, a default 
key_mapping of {\"x\": \"data\"} will be used\n            here\n        val_score_key : str or None\n            key defining which metric to use for validation (determining\n            best model and scheduling lr); if None: No validation-based\n            operations will be done (model might still get validated,\n            but validation metrics can only be logged and not used further)\n        optim_builder : function\n            Function returning a dict of backend-specific optimizers.\n            defaults to :func:`create_optims_default`\n        checkpoint_freq : int\n            frequency of saving checkpoints (1 denotes saving every epoch,\n            2 denotes saving every second epoch etc.); default: 1\n        trainer_cls : subclass of :class:`TfEagerNetworkTrainer`\n            the trainer class to use for training the model, defaults to\n            :class:`TfNetworkTrainer`\n        **kwargs :\n            additional keyword arguments\n\n        \"\"\"\n\n        if key_mapping is None:\n            key_mapping = {\"x\": \"data\"}\n        super().__init__(config=config, model_cls=model_cls,\n                         n_epochs=n_epochs, name=name, save_path=save_path,\n                         key_mapping=key_mapping,\n                         val_score_key=val_score_key,\n                         optim_builder=optim_builder,\n                         checkpoint_freq=checkpoint_freq,\n                         trainer_cls=trainer_cls,\n                         **kwargs)\n\n    def kfold(self, data: DataManager, metrics: dict, num_epochs=None,\n              num_splits=None, shuffle=False, random_seed=None,\n              split_type=\"random\", val_split=0.2, label_key=\"label\",\n              train_kwargs: dict = None, test_kwargs: dict = None,\n              metric_keys: dict = None, config=None, verbose=False,\n              **kwargs):\n        \"\"\"\n        Performs a k-Fold cross-validation\n\n        Parameters\n        ----------\n      
  data : :class:`DataManager`\n            the data to use for training(, validation) and testing. Will be\n            split based on ``split_type`` and ``val_split``\n        metrics : dict\n            dictionary containing the metrics to evaluate during k-fold\n        num_epochs : int or None\n            number of epochs to train (if not given, will either be\n            extracted from ``config``, ``self.config`` or ``self.n_epochs``)\n        num_splits : int or None\n            the number of splits to extract from ``data``.\n            If None: uses a default of 10\n        shuffle : bool\n            whether to shuffle the data before splitting or not\n            (implemented by index-shuffling rather than actual\n            data-shuffling to retain potentially lazy behavior of datasets)\n        random_seed : int or None\n            seed to seed numpy, the splitting functions and the used\n            backend-framework\n        split_type : str\n            must be one of ['random', 'stratified']\n            if 'random': uses random data splitting\n            if 'stratified': uses stratified data splitting. Stratification\n            will be based on ``label_key``\n        val_split : float or None\n            the fraction of the train data to use as validation set.\n            If None: No validation will be done during training; only\n            testing for each fold after the training is complete\n        label_key : str\n            the label to use for stratification. Will be ignored unless\n            ``split_type`` is 'stratified'. Default: 'label'\n        train_kwargs : dict or None\n            kwargs to update the behavior of the :class:`DataManager`\n            containing the train data. 
If None: empty dict will be passed\n        metric_keys : dict of tuples\n            the batch_dict keys to use for each metric to calculate.\n            Should contain a value for each key in ``metrics``.\n            If no values are given for a key, per default ``pred`` and\n            ``label`` will be used for metric calculation\n        test_kwargs : dict or None\n            kwargs to update the behavior of the :class:`DataManager`\n            containing the test and validation data.\n            If None: empty dict will be passed\n        config : :class:`DeliraConfig` or None\n            the training and model parameters\n            (will be merged with ``self.config``)\n        verbose : bool\n            verbosity\n        **kwargs :\n            additional keyword arguments\n\n        Returns\n        -------\n        dict\n            all predictions from all folds\n        dict\n            all metric values from all folds\n\n        Raises\n        ------\n        ValueError\n            if ``split_type`` is neither 'random', nor 'stratified'\n\n        See Also\n        --------\n\n        * :class:`sklearn.model_selection.KFold`\n        and :class:`sklearn.model_selection.ShuffleSplit`\n        for random data-splitting\n\n        * :class:`sklearn.model_selection.StratifiedKFold`\n        and :class:`sklearn.model_selection.StratifiedShuffleSplit`\n        for stratified data-splitting\n\n        * :meth:`DataManager.update_from_state_dict` for updating the\n        data managers by kwargs\n\n        * :meth:`BaseExperiment.run` for the training\n\n        * :meth:`BaseExperiment.test` for the testing\n\n        Notes\n        -----\n        using stratified splits may be slow during split-calculation, since\n        each item must be loaded once to obtain the labels necessary for\n        stratification.\n\n        \"\"\"\n\n        # seed tf backend\n        if random_seed is not None:\n            tf.set_random_seed(random_seed)\n\n      
  return super().kfold(\n            data=data,\n            metrics=metrics,\n            num_epochs=num_epochs,\n            num_splits=num_splits,\n            shuffle=shuffle,\n            random_seed=random_seed,\n            split_type=split_type,\n            val_split=val_split,\n            label_key=label_key,\n            train_kwargs=train_kwargs,\n            test_kwargs=test_kwargs,\n            metric_keys=metric_keys,\n            config=config,\n            verbose=verbose,\n            **kwargs)\n\n    def test(self, network, test_data: DataManager,\n             metrics: dict, metric_keys=None,\n             verbose=False, prepare_batch=None,\n             convert_fn=None, **kwargs):\n        \"\"\"\n        Setup and run testing on a given network\n\n        Parameters\n        ----------\n        network : :class:`AbstractNetwork`\n            the (trained) network to test\n        test_data : :class:`DataManager`\n            the data to use for testing\n        metrics : dict\n            the metrics to calculate\n        metric_keys : dict of tuples\n            the batch_dict keys to use for each metric to calculate.\n            Should contain a value for each key in ``metrics``.\n            If no values are given for a key, per default ``pred`` and\n            ``label`` will be used for metric calculation\n        verbose : bool\n            verbosity of the test process\n        prepare_batch : function\n            function to convert a batch-dict to a format accepted by the\n            model. This conversion typically includes dtype-conversion,\n            reshaping, wrapping to backend-specific tensors and\n            pushing to correct devices. 
If not further specified uses the\n            ``network``'s ``prepare_batch`` with CPU devices\n        convert_fn : function\n            function to convert a batch of tensors to numpy\n            if not specified defaults to\n            :func:`convert_to_numpy`\n        **kwargs :\n            additional keyword arguments\n\n        Returns\n        -------\n        dict\n            all predictions obtained by feeding the ``test_data`` through\n            the ``network``\n        dict\n            all metrics calculated upon the ``test_data`` and the obtained\n            predictions\n\n        \"\"\"\n        # specify convert_fn to correct backend function\n        if convert_fn is None:\n            convert_fn = convert_to_numpy\n\n        if prepare_batch is None:\n            prepare_batch = partial(\n                network.prepare_batch,\n                input_device=\"/cpu:0\",\n                output_device=\"/cpu:0\")\n\n        return super().test(network=network, test_data=test_data,\n                            metrics=metrics, metric_keys=metric_keys,\n                            verbose=verbose, prepare_batch=prepare_batch,\n                            convert_fn=convert_fn, **kwargs)\n\n    def setup(self, config, training=True, **kwargs):\n        \"\"\"\n        Defines the setup behavior (model, trainer etc.) 
for training and\n        testing case\n\n        Parameters\n        ----------\n        config : :class:`DeliraConfig`\n            the parameters to use for setup\n        training : bool\n            whether to setup for training case or for testing case\n        **kwargs :\n            additional keyword arguments\n\n        Returns\n        -------\n        :class:`BaseNetworkTrainer`\n            the created trainer (if ``training=True``)\n        :class:`Predictor`\n            the created predictor (if ``training=False``)\n\n        See Also\n        --------\n\n        * :meth:`BaseExperiment._setup_training` for training setup\n\n        * :meth:`BaseExperiment._setup_test` for test setup\n\n        \"\"\"\n        tf.reset_default_graph()\n        return super().setup(config=config, training=training,\n                             **kwargs)\n"
  },
  {
    "path": "delira/training/backends/tf_eager/trainer.py",
    "content": "from delira.training.backends.tf_eager.utils import create_optims_default\nfrom delira.training.backends.tf_eager.utils import convert_to_numpy\nfrom delira.training.base_trainer import BaseNetworkTrainer\nfrom delira.io.tf import save_checkpoint_eager, load_checkpoint_eager\nfrom delira.models.backends.tf_eager import AbstractTfEagerNetwork, \\\n    DataParallelTfEagerNetwork\n\nfrom delira.training.callbacks.logging_callback import DefaultLoggingCallback\nimport logging\nimport os\nfrom functools import partial\n\nimport tensorflow as tf\n\nlogger = logging.getLogger(__name__)\n\n\nclass TfEagerNetworkTrainer(BaseNetworkTrainer):\n    def __init__(self,\n                 network: AbstractTfEagerNetwork,\n                 save_path: str,\n                 key_mapping: dict,\n                 losses: dict,\n                 optimizer_cls,\n                 optimizer_params=None,\n                 metrics=None,\n                 lr_scheduler_cls=None,\n                 lr_scheduler_params=None,\n                 gpu_ids=None,\n                 save_freq=1,\n                 optim_fn=create_optims_default,\n                 logging_type=\"tensorboardx\",\n                 logging_kwargs=None,\n                 logging_callback_cls=DefaultLoggingCallback,\n                 logging_frequencies=None,\n                 logging_reduce_types=None,\n                 fold=0,\n                 callbacks=None,\n                 start_epoch=1,\n                 metric_keys=None,\n                 convert_batch_to_npy_fn=convert_to_numpy,\n                 val_freq=1,\n                 **kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        network : :class:`AbstractTfEagerNetwork`\n            the network to train\n        save_path : str\n            path to save networks to\n        key_mapping : dict\n            a dictionary containing the mapping from the ``data_dict`` to\n            the actual model's inputs.\n            E.g. 
if a model accepts one input named 'x' and the data_dict\n            contains one entry named 'data' this argument would have to\n            be ``{'x': 'data'}``\n        losses : dict\n            dictionary containing the training losses\n        optimizer_cls : subclass of tf.train.Optimizer\n            optimizer class implementing the optimization algorithm of choice\n        optimizer_params : dict\n            keyword arguments passed to optimizer during construction\n        metrics : dict, optional\n            metrics, which will be evaluated during train and validation phase\n            (should work on numpy arrays)\n        lr_scheduler_cls : Any\n            learning rate schedule class: must implement step() method\n        lr_scheduler_params : dict\n            keyword arguments passed to lr scheduler during construction\n        gpu_ids : list\n            list containing ids of GPUs to use; if empty: use cpu instead\n        save_freq : int\n            integer specifying how often to save the current model's state.\n            State is saved every save_freq epochs\n        optim_fn : function\n            creates a dictionary containing all necessary optimizers\n        logging_type : str or callable\n            the type of logging. If string: it must be one of\n            [\"visdom\", \"tensorboardx\"]\n            If callable: it must be a logging handler class\n        logging_kwargs : dict\n            dictionary containing all logging keyword arguments\n        logging_callback_cls : class\n            the callback class to create and register for logging\n        logging_frequencies : int or dict\n            specifies how often to log for each key.\n            If int: integer will be applied to all valid keys\n            if dict: should contain a frequency per valid key. 
Missing keys\n                will be filled with a frequency of 1 (log every time)\n            None is equal to empty dict here.\n        logging_reduce_types : str, FunctionType or dict\n            if str:\n                specifies the reduction type to use. Valid types are\n                'last' | 'first' | 'mean' | 'median' | 'max' | 'min'.\n                The given type will be mapped to all valid keys.\n            if FunctionType:\n                specifies the actual reduction function. Will be applied\n                for all keys.\n            if dict: should contain pairs of valid logging keys and either\n                str or FunctionType. Specifies the logging value per key.\n                Missing keys will be filled with a default value of 'last'.\n                Valid types for strings are\n                'last' | 'first' | 'mean' | 'median' | 'max' | 'min'.\n        fold : int\n            current cross validation fold (0 per default)\n        callbacks : list\n            initial callbacks to register\n        start_epoch : int\n            epoch to start training at\n        metric_keys : dict\n            dict specifying which batch_dict entry to use for which metric as\n            target; default: None, which will result in key \"label\" for all\n            metrics\n        convert_batch_to_npy_fn : type, optional\n            function converting a batch-tensor to numpy,\n            defaults to :func:`convert_to_numpy`\n        val_freq : int\n            validation frequency specifying how often to validate the trained\n            model (a value of 1 denotes validating every epoch,\n            a value of 2 denotes validating every second epoch etc.);\n            defaults to 1\n        **kwargs :\n            Additional keyword arguments\n\n        \"\"\"\n\n        # prevent mutable default arguments\n        if logging_kwargs is None:\n            logging_kwargs = {}\n        if callbacks is None:\n            callbacks 
= []\n        if gpu_ids is None:\n            gpu_ids = []\n        if lr_scheduler_params is None:\n            lr_scheduler_params = {}\n        if metrics is None:\n            metrics = {}\n        if optimizer_params is None:\n            optimizer_params = {}\n\n        # check if eager execution is enabled\n        assert tf.executing_eagerly()\n\n        super().__init__(network=network,\n                         save_path=save_path,\n                         losses=losses,\n                         optimizer_cls=optimizer_cls,\n                         optimizer_params=optimizer_params,\n                         metrics=metrics,\n                         lr_scheduler_cls=lr_scheduler_cls,\n                         lr_scheduler_params=lr_scheduler_params,\n                         gpu_ids=gpu_ids,\n                         save_freq=save_freq,\n                         optim_fn=optim_fn,\n                         key_mapping=key_mapping,\n                         logging_type=logging_type,\n                         logging_kwargs=logging_kwargs,\n                         fold=fold,\n                         callbacks=callbacks,\n                         start_epoch=start_epoch,\n                         metric_keys=metric_keys,\n                         convert_batch_to_npy_fn=convert_batch_to_npy_fn,\n                         val_freq=val_freq,\n                         logging_callback_cls=logging_callback_cls,\n                         logging_frequencies=logging_frequencies,\n                         logging_reduce_types=logging_reduce_types,\n                         **kwargs\n                         )\n\n        self._setup(network, optim_fn, optimizer_cls, optimizer_params,\n                    lr_scheduler_cls, lr_scheduler_params,\n                    key_mapping, convert_batch_to_npy_fn, gpu_ids, callbacks)\n\n        for key, val in kwargs.items():\n            setattr(self, key, val)\n\n    def _setup(self, network, optim_fn, optimizer_cls, 
optimizer_params,\n               lr_scheduler_cls, lr_scheduler_params, key_mapping,\n               convert_batch_to_npy_fn, gpu_ids, callbacks):\n        \"\"\"\n        Defines the Trainer's Setup\n\n        Parameters\n        ----------\n        network : instance of :class:`AbstractTfEagerNetwork`\n            the network to train\n        optim_fn : function\n            creates a dictionary containing all necessary optimizers\n        optimizer_cls : subclass of tf.train.Optimizer\n            optimizer class implementing the optimization algorithm of choice\n        optimizer_params : dict\n            keyword arguments passed to optimizer during construction\n        lr_scheduler_cls : Any\n            learning rate schedule class: must implement step() method\n        lr_scheduler_params : dict\n            keyword arguments passed to lr scheduler during construction\n        key_mapping : dict\n            mapping from the ``data_dict`` to the actual model inputs\n        convert_batch_to_npy_fn : type, optional\n            function converting a batch-tensor to numpy, per default this is\n            the identity function\n        gpu_ids : list\n            list containing ids of GPUs to use; if empty: use cpu instead\n        callbacks : list\n            initial callbacks to register\n\n        Raises\n        ------\n        RuntimeError\n            if multiple GPU ids are passed\n        \"\"\"\n\n        if gpu_ids and tf.test.is_gpu_available():\n            self.use_gpu = True\n            if len(gpu_ids) > 1:\n                raise RuntimeError(\"Multiple GPUs not yet supported\")\n                # logger.warning(\n                #     \"multi-GPU training not yet tested!\")\n\n                # network = DataParallelTfEagerNetwork(network, gpu_ids)\n                #\n                # self.input_device = \"/cpu:0\"\n                # self.output_device = \"/cpu:0\"\n            else:\n                self.input_device = \"/gpu:%d\" % gpu_ids[0]\n                self.output_device = \"/gpu:%d\" % gpu_ids[0]\n        else:\n            self.use_gpu = False\n            self.input_device = \"/cpu:0\"\n            
self.output_device = \"/cpu:0\"\n\n        self.optimizers = optim_fn(optimizer_cls, **optimizer_params)\n\n        super()._setup(network, lr_scheduler_cls, lr_scheduler_params, gpu_ids,\n                       key_mapping, convert_batch_to_npy_fn,\n                       network.prepare_batch, callbacks)\n        self._prepare_batch = partial(self._prepare_batch,\n                                      input_device=self.input_device,\n                                      output_device=self.output_device)\n\n        # Load latest epoch file if available\n        if os.path.isdir(self.save_path):\n            # check all files in directory starting with \"checkpoint\" and\n            # not ending with \"_best.meta\"\n            latest_state_path, latest_epoch = self._search_for_prev_state(\n                self.save_path\n            )\n\n            if latest_state_path is not None:\n                logger.info(\"Attempting to load state from previous \"\n                            \"training from %s\" % latest_state_path)\n\n                self.update_state(latest_state_path)\n                self.start_epoch = latest_epoch\n\n    def _at_training_end(self, *args, **kwargs):\n        \"\"\"\n        Defines Behaviour at end of training: Loads best model if available\n\n        Returns\n        -------\n        :class:`AbstractTfEagerNetwork`\n            best network\n\n        \"\"\"\n        if os.path.isfile(os.path.join(self.save_path,\n                                       'checkpoint_best.meta')):\n\n            # load best model and return it.\n            self.update_state(os.path.join(self.save_path,\n                                           'checkpoint_best')\n                              )\n\n        return super()._at_training_end(*args, **kwargs)\n\n    def _train_single_epoch(self, batchgen, epoch, verbose=False):\n        \"\"\"\n        Trains the network a single epoch\n\n        Parameters\n        ----------\n        batchgen : 
MultiThreadedAugmenter\n            Generator yielding the training batches\n        epoch : int\n            current epoch\n        verbose : bool\n            whether to show a progress-bar or not, default: False\n\n        \"\"\"\n        self.module.trainable = True\n\n        return super()._train_single_epoch(batchgen, epoch, verbose=verbose)\n\n    def predict_data_mgr(self, datamgr, batchsize=None, metrics=None,\n                         metric_keys=None, verbose=False, **kwargs):\n        \"\"\"\n        Defines a routine to predict data obtained from a batchgenerator\n\n        Parameters\n        ----------\n        datamgr : :class:`DataManager`\n            Manager producing a generator holding the batches\n        batchsize : int\n            Artificial batchsize (sampling will be done with batchsize\n            1 and sampled data will be stacked to match the artificial\n            batchsize); default: None\n        metrics : dict\n            the metrics to calculate\n        metric_keys : dict\n            the ``batch_dict`` items to use for metric calculation\n        verbose : bool\n            whether to show a progress-bar or not, default: False\n        **kwargs :\n            additional keyword arguments\n\n        \"\"\"\n        if metrics is None:\n            metrics = {}\n        self.module.trainable = False\n\n        return super().predict_data_mgr(datamgr, batchsize, metrics,\n                                        metric_keys, verbose=verbose, **kwargs)\n\n    def save_state(self, file_name, *args, **kwargs):\n        \"\"\"\n        saves the current state via :func:`delira.io.tf.save_checkpoint_eager`\n\n        Parameters\n        ----------\n        file_name : str\n            filename to save the state to\n        \"\"\"\n        save_checkpoint_eager(file_name, self.module, self.optimizers,\n                              *args, **kwargs)\n\n    def load_state(self, file_name, *args, **kwargs):\n        \"\"\"\n        Loads the new state from file via\n        :func:`delira.io.tf.load_checkpoint_eager`\n\n        
Parameters\n        ----------\n        file_name : str\n            the file to load the state from\n\n        Returns\n        -------\n        dict\n            the loaded state\n\n        \"\"\"\n        return load_checkpoint_eager(\n            file_name, self.module, self.optimizers)\n\n    @staticmethod\n    def _search_for_prev_state(path, extensions=None):\n        \"\"\"\n        Helper function to search in a given path for previous epoch states\n        (indicated by extensions)\n\n        Parameters\n        ----------\n        path : str\n            the path to search in\n        extensions : list\n            list of strings containing valid file extensions for checkpoint\n            files\n\n        Returns\n        -------\n        str\n            the file containing the latest checkpoint (if available)\n        None\n            if no latest checkpoint was found\n        int\n            the latest epoch (1 if no checkpoint was found)\n\n        \"\"\"\n        if extensions is None:\n            extensions = [\".meta\"]\n        return BaseNetworkTrainer._search_for_prev_state(path, extensions)\n"
  },
  {
    "path": "delira/training/backends/tf_eager/utils.py",
    "content": "import tensorflow as tf\n\nfrom delira.training.utils import convert_to_numpy_identity, \\\n    recursively_convert_elements\n\n\ndef _single_element_tensor_conversion(element):\n    return element.numpy()\n\n\ndef convert_to_numpy(*args, **kwargs):\n    \"\"\"\n    Converts all tf tensors in args and kwargs to numpy arrays\n\n    Parameters\n    ----------\n    *args :\n        positional arguments of arbitrary number and type\n    **kwargs :\n        keyword arguments of arbitrary number and type\n\n    Returns\n    -------\n    list\n        converted positional arguments\n    dict\n        converted keyword arguments\n    \"\"\"\n    args = recursively_convert_elements(args, tf.Tensor,\n                                        _single_element_tensor_conversion)\n\n    kwargs = recursively_convert_elements(kwargs, tf.Tensor,\n                                          _single_element_tensor_conversion)\n\n    return convert_to_numpy_identity(*args, **kwargs)\n\n\ndef create_optims_default(optim_cls, **optim_params):\n    \"\"\"\n    Function to create an optimizer dictionary\n    (in this case only one optimizer)\n\n    Parameters\n    ----------\n    optim_cls :\n        Class implementing an optimization algorithm\n    **optim_params :\n        Additional keyword arguments (passed to the optimizer class)\n\n    Returns\n    -------\n    dict\n        dictionary containing all created optimizers\n    \"\"\"\n    return {\"default\": optim_cls(**optim_params)}\n"
  },
  {
    "path": "delira/training/backends/tf_graph/__init__.py",
    "content": "from delira import get_backends as _get_backends\n\nif \"TF\" in _get_backends():\n    from delira.training.backends.tf_graph.experiment import TfGraphExperiment\n    from delira.training.backends.tf_graph.trainer import TfGraphNetworkTrainer\n    from delira.training.backends.tf_graph.utils import \\\n        initialize_uninitialized\n"
  },
  {
    "path": "delira/training/backends/tf_graph/experiment.py",
    "content": "import typing\nfrom functools import partial\n\nimport tensorflow as tf\n\nfrom delira.models.backends.tf_graph import AbstractTfGraphNetwork\nfrom delira.data_loading import DataManager\n\nfrom delira.utils import DeliraConfig\nfrom delira.training.backends.tf_eager.experiment import TfEagerExperiment\nfrom delira.training.backends.tf_eager.utils import create_optims_default\n\nfrom delira.training.backends.tf_graph.trainer import TfGraphNetworkTrainer\nfrom delira.training.backends.tf_graph.utils import initialize_uninitialized\n\n\nclass TfGraphExperiment(TfEagerExperiment):\n    def __init__(self,\n                 config: typing.Union[str, DeliraConfig],\n                 model_cls: AbstractTfGraphNetwork,\n                 n_epochs=None,\n                 name=None,\n                 save_path=None,\n                 key_mapping=None,\n                 val_score_key=None,\n                 optim_builder=create_optims_default,\n                 checkpoint_freq=1,\n                 trainer_cls=TfGraphNetworkTrainer,\n                 **kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        config : :class:`DeliraConfig` or str\n            the training config, if string is passed,\n            it is treated as a path to a file, where the\n            config is loaded from\n        model_cls : Subclass of :class:`AbstractTfGraphNetwork`\n            the class implementing the model to train\n        n_epochs : int or None\n            the number of epochs to train, if None: can be specified later\n            during actual training\n        name : str or None\n            the Experiment's name\n        save_path : str or None\n            the path to save the results and checkpoints to.\n            if None: Current working directory will be used\n        key_mapping : dict\n            mapping between data_dict and model inputs (necessary for\n            prediction with :class:`Predictor`-API), if no keymapping is\n         
   given, a default key_mapping of {\"data\": \"data\"} will be used\n            here\n        val_score_key : str or None\n            key defining which metric to use for validation (determining\n            best model and scheduling lr); if None: No validation-based\n            operations will be done (model might still get validated,\n            but validation metrics can only be logged and not used further)\n        optim_builder : function\n            Function returning a dict of backend-specific optimizers.\n            defaults to :func:`create_optims_default`\n        checkpoint_freq : int\n            frequency of saving checkpoints (1 denotes saving every epoch,\n            2 denotes saving every second epoch etc.); default: 1\n        trainer_cls : subclass of :class:`TfGraphNetworkTrainer`\n            the trainer class to use for training the model, defaults to\n            :class:`TfGraphNetworkTrainer`\n        **kwargs :\n            additional keyword arguments\n\n        \"\"\"\n\n        if key_mapping is None:\n            key_mapping = {\"data\": \"data\"}\n\n        super().__init__(\n            config=config,\n            model_cls=model_cls,\n            n_epochs=n_epochs,\n            name=name,\n            save_path=save_path,\n            key_mapping=key_mapping,\n            val_score_key=val_score_key,\n            optim_builder=optim_builder,\n            checkpoint_freq=checkpoint_freq,\n            trainer_cls=trainer_cls,\n            **kwargs)\n\n    def test(self, network, test_data: DataManager,\n             metrics: dict, metric_keys=None,\n             verbose=False, prepare_batch=None,\n             convert_fn=None, **kwargs):\n        \"\"\"\n        Setup and run testing on a given network\n\n        Parameters\n        ----------\n        network : :class:`AbstractNetwork`\n            the (trained) network to test\n        test_data : :class:`DataManager`\n            the data to use for testing\n        
metrics : dict\n            the metrics to calculate\n        metric_keys : dict of tuples\n            the batch_dict keys to use for each metric to calculate.\n            Should contain a value for each key in ``metrics``.\n            If no values are given for a key, per default ``pred`` and\n            ``label`` will be used for metric calculation\n        verbose : bool\n            verbosity of the test process\n        prepare_batch : function\n            function to convert a batch-dict to a format accepted by the\n            model. This conversion typically includes dtype-conversion,\n            reshaping, wrapping to backend-specific tensors and\n            pushing to correct devices. If not further specified uses the\n            ``network``'s ``prepare_batch`` with CPU devices\n        convert_fn : function\n            function to convert a batch of tensors to numpy\n            if not specified defaults to\n            :func:`convert_to_numpy`\n        **kwargs :\n            additional keyword arguments\n\n        Returns\n        -------\n        dict\n            all predictions obtained by feeding the ``test_data`` through\n            the ``network``\n        dict\n            all metrics calculated upon the ``test_data`` and the obtained\n            predictions\n\n        \"\"\"\n\n        initialize_uninitialized(network._sess)\n\n        if prepare_batch is None:\n            prepare_batch = partial(network.prepare_batch,\n                                    input_device=None,\n                                    output_device=None)\n\n        return super().test(network=network, test_data=test_data,\n                            metrics=metrics, metric_keys=metric_keys,\n                            verbose=verbose, prepare_batch=prepare_batch,\n                            convert_fn=convert_fn, **kwargs)\n"
  },
  {
    "path": "delira/training/backends/tf_graph/trainer.py",
    "content": "from delira.training.backends.tf_graph.utils import initialize_uninitialized\nfrom delira.training.backends.tf_eager.utils import create_optims_default\nfrom delira.training.utils import convert_to_numpy_identity \\\n    as convert_to_numpy\nfrom delira.training.base_trainer import BaseNetworkTrainer\nfrom delira.training.callbacks.logging_callback import DefaultLoggingCallback\nfrom delira.io.tf import load_checkpoint, save_checkpoint\nfrom delira.models.backends.tf_graph import AbstractTfGraphNetwork\nfrom delira.data_loading import DataManager\nimport os\nimport logging\n\nfrom tensorflow import executing_eagerly\n\nfrom batchgenerators.dataloading import MultiThreadedAugmenter\n\nlogger = logging.getLogger(__name__)\n\n\nclass TfGraphNetworkTrainer(BaseNetworkTrainer):\n    \"\"\"\n    Train and Validate a Network\n\n    See Also\n    --------\n    :class:`AbstractNetwork`\n\n    \"\"\"\n\n    def __init__(self,\n                 network: AbstractTfGraphNetwork,\n                 save_path: str,\n                 key_mapping: dict,\n                 losses: dict,\n                 optimizer_cls,\n                 optimizer_params=None,\n                 metrics=None,\n                 lr_scheduler_cls=None,\n                 lr_scheduler_params=None,\n                 gpu_ids=None,\n                 save_freq=1,\n                 optim_fn=create_optims_default,\n                 logging_type=\"tensorboardx\",\n                 logging_kwargs=None,\n                 logging_callback_cls=DefaultLoggingCallback,\n                 logging_frequencies=None,\n                 logging_reduce_types=None,\n                 fold=0,\n                 callbacks=None,\n                 start_epoch=1,\n                 metric_keys=None,\n                 convert_batch_to_npy_fn=convert_to_numpy,\n                 val_freq=1,\n                 **kwargs\n                 ):\n        \"\"\"\n\n        Parameters\n        ----------\n        network : 
:class:`AbstractTfGraphNetwork`\n            the network to train\n        save_path : str\n            path to save networks to\n        key_mapping : dict\n            a dictionary containing the mapping from the ``data_dict`` to\n            the actual model's inputs.\n            E.g. if a model accepts one input named 'x' and the data_dict\n            contains one entry named 'data' this argument would have to\n            be ``{'x': 'data'}``\n        losses : dict\n            dictionary containing the training losses\n        optimizer_cls : subclass of tf.train.Optimizer\n            optimizer class implementing the optimization algorithm of choice\n        optimizer_params : dict\n            keyword arguments passed to optimizer during construction\n        metrics : dict, optional\n            metrics, which will be evaluated during train and validation phase\n            (should work on numpy arrays)\n        lr_scheduler_cls : Any\n            learning rate schedule class: must implement step() method\n        lr_scheduler_params : dict\n            keyword arguments passed to lr scheduler during construction\n        gpu_ids : list\n            list containing ids of GPUs to use; if empty: use cpu instead\n        save_freq : int\n            integer specifying how often to save the current model's state.\n            State is saved every save_freq epochs\n        optim_fn : function\n            creates a dictionary containing all necessary optimizers\n        logging_type : str or callable\n            the type of logging. 
If string: it must be one of\n            [\"visdom\", \"tensorboardx\"]\n            If callable: it must be a logging handler class\n        logging_kwargs : dict\n            dictionary containing all logging keyword arguments\n        logging_callback_cls : class\n            the callback class to create and register for logging\n        logging_frequencies : int or dict\n            specifies how often to log for each key.\n            If int: integer will be applied to all valid keys\n            if dict: should contain a frequency per valid key. Missing keys\n                will be filled with a frequency of 1 (log every time)\n            None is equal to empty dict here.\n        logging_reduce_types : str of FunctionType or dict\n            if str:\n                specifies the reduction type to use. Valid types are\n                'last' | 'first' | 'mean' | 'median' | 'max' | 'min'.\n                The given type will be mapped to all valid keys.\n            if FunctionType:\n                specifies the actual reduction function. Will be applied\n                for all keys.\n            if dict: should contain pairs of valid logging keys and either\n                str or FunctionType. 
Specifies the logging value per key.\n                Missing keys will be filled with a default value of 'last'.\n                Valid types for strings are\n                'last' | 'first' | 'mean' | 'median' | 'max' | 'min'.\n        fold : int\n            current cross validation fold (0 per default)\n        callbacks : list\n            initial callbacks to register\n        start_epoch : int\n            epoch to start training at\n        metric_keys : dict\n            dict specifying which batch_dict entry to use for which metric as\n            target; default: None, which will result in key \"label\" for all\n            metrics\n        convert_batch_to_npy_fn : type, optional\n            function converting a batch-tensor to numpy, per default this is\n            the identity function\n        val_freq : int\n            validation frequency specifying how often to validate the trained\n            model (a value of 1 denotes validating every epoch,\n            a value of 2 denotes validating every second epoch etc.);\n            defaults to 1\n        **kwargs :\n            Additional keyword arguments\n\n        \"\"\"\n        assert not executing_eagerly(), \\\n            \"The tf_graph backend requires eager execution to be disabled\"\n\n        if optimizer_params is None:\n            optimizer_params = {}\n        if metrics is None:\n            metrics = {}\n        if lr_scheduler_params is None:\n            lr_scheduler_params = {}\n        if gpu_ids is None:\n            gpu_ids = []\n        if logging_kwargs is None:\n            logging_kwargs = {}\n        if callbacks is None:\n            callbacks = []\n\n        super().__init__(network=network,\n                         save_path=save_path,\n                         losses=losses,\n                         optimizer_cls=optimizer_cls,\n                         optimizer_params=optimizer_params,\n                         metrics=metrics,\n                         lr_scheduler_cls=lr_scheduler_cls,\n                         
lr_scheduler_params=lr_scheduler_params,\n                         gpu_ids=gpu_ids,\n                         save_freq=save_freq,\n                         optim_fn=optim_fn,\n                         key_mapping=key_mapping,\n                         logging_type=logging_type,\n                         logging_kwargs=logging_kwargs,\n                         logging_callback_cls=logging_callback_cls,\n                         logging_frequencies=logging_frequencies,\n                         logging_reduce_types=logging_reduce_types,\n                         fold=fold,\n                         callbacks=callbacks,\n                         start_epoch=start_epoch,\n                         metric_keys=metric_keys,\n                         convert_batch_to_npy_fn=convert_batch_to_npy_fn,\n                         val_freq=val_freq,\n                         **kwargs\n                         )\n\n        self._setup(network, optim_fn, optimizer_cls, optimizer_params,\n                    lr_scheduler_cls, lr_scheduler_params,\n                    key_mapping, convert_batch_to_npy_fn, gpu_ids, callbacks)\n\n        for key, val in kwargs.items():\n            setattr(self, key, val)\n\n    def _setup(self, network, optim_fn, optimizer_cls, optimizer_params,\n               lr_scheduler_cls, lr_scheduler_params, key_mapping,\n               convert_batch_to_npy_fn, gpu_ids, callbacks):\n        \"\"\"\n        Defines the Trainer's Setup\n\n        Parameters\n        ----------\n        network : instance of :class:`AbstractTfGraphNetwork`\n            the network to train\n        optim_fn : function\n            creates a dictionary containing all necessary optimizers\n        optimizer_cls : subclass of tf.train.Optimizer\n            optimizer class implementing the optimization algorithm of choice\n        optimizer_params : dict\n            keyword arguments passed to optimizer during construction\n        lr_scheduler_cls : Any\n            learning rate schedule class: must implement step() method\n        lr_scheduler_params : 
dict\n            keyword arguments passed to lr scheduler during construction\n        key_mapping : dict\n            mapping from the ``data_dict`` to the actual model's inputs\n        convert_batch_to_npy_fn : type, optional\n            function converting a batch-tensor to numpy, per default this is\n            the identity function\n        gpu_ids : list\n            list containing ids of GPUs to use; if empty: use cpu instead\n        callbacks : list\n            initial callbacks to register\n\n        Raises\n        ------\n        RuntimeError\n            if multiple GPU ids are passed\n        \"\"\"\n\n        # TODO: implement multi-GPU and single GPU training with help of\n        #  keras multi-gpu model\n        #  note: might be bugged in combination with sess.run\n        #  https://github.com/tensorflow/tensorflow/issues/21788\n\n        # if gpu_ids and tf.test.is_gpu_available():\n        #     assert len(gpu_ids) <= len(get_available_gpus()), \"more GPUs\n        #     specified than available\"\n        #     self.use_gpu = True\n        #     if len(gpu_ids) > 1:\n        #         logger.warning(\n        #             \"multi-GPU training not yet tested!\")\n        #\n        #         network.model = tf.keras.utils.multi_gpu_model(\n        #                                 network.model,\n        #                                 len(gpu_ids),\n        #                                 cpu_merge=True,\n        #                                 cpu_relocation=False)\n        #     else:\n        #         network.models = tf.keras.models.clone_model(network.model)\n        # else:\n        #     self.use_gpu = False\n        #\n        if len(gpu_ids) > 1:\n            raise RuntimeError(\"Multiple GPUs not yet supported\")\n\n        self.optimizers = optim_fn(optimizer_cls, **optimizer_params)\n\n        super()._setup(network, lr_scheduler_cls, lr_scheduler_params, gpu_ids,\n                       key_mapping, convert_batch_to_npy_fn, lambda x: x,\n                       callbacks)\n\n        self.use_gpu = True\n\n        
self.module._add_losses(self.losses)\n        self.module._add_optims(self.optimizers)\n        # check for uninitialized variables\n        initialize_uninitialized(self.module._sess)\n\n        # Load latest epoch file if available\n        if os.path.isdir(self.save_path):\n            latest_state_path, latest_epoch = self._search_for_prev_state(\n                self.save_path)\n\n            if latest_state_path is not None:\n\n                # if pth file does not exist, load pt file instead\n                if not os.path.isfile(latest_state_path):\n                    latest_state_path = latest_state_path[:-1]\n\n                logger.info(\"Attempting to load state from previous \"\n                            \"training from %s\" % latest_state_path)\n\n                self.update_state(latest_state_path)\n                self.start_epoch = latest_epoch\n\n    def _at_training_end(self, *args, **kwargs):\n        \"\"\"\n        Defines Behaviour at end of training: Loads best model if available\n\n        Returns\n        -------\n        :class:`AbstractTfNetwork`\n            best network\n\n        \"\"\"\n\n        if os.path.isfile(os.path.join(self.save_path,\n                                       'checkpoint_best.meta')):\n\n            # load best model and return it.\n\n            self.update_state(os.path.join(self.save_path,\n                                           'checkpoint_best')\n                              )\n\n        return super()._at_training_end(*args, **kwargs)\n\n    def _train_single_epoch(self, dmgr_train: DataManager, epoch,\n                            verbose=False):\n        \"\"\"\n        Trains the network a single epoch\n\n        Parameters\n        ----------\n        dmgr_train : :class:`DataManager`\n            Datamanager to create the data generator\n        epoch : int\n            current epoch\n        verbose : bool\n            whether to show a progress-bar or not, default: False\n\n        \"\"\"\n        self.module.training = True\n\n        return super()._train_single_epoch(dmgr_train, 
epoch, verbose=verbose)\n\n    def predict_data_mgr(self, datamgr, batch_size=None, metrics=None,\n                         metric_keys=None, verbose=False, **kwargs):\n        \"\"\"\n        Defines a routine to predict data obtained from a batchgenerator\n\n        Parameters\n        ----------\n        datamgr : :class:`DataManager`\n            Manager producing a generator holding the batches\n        batch_size : int\n            Artificial batchsize (sampling will be done with batchsize\n            1 and sampled data will be stacked to match the artificial\n            batchsize) (default: None)\n        metrics : dict\n            the metrics to calculate\n        metric_keys : dict\n            the ``batch_dict`` items to use for metric calculation\n        verbose : bool\n            whether to show a progress-bar or not, default: False\n        **kwargs :\n            additional keyword arguments\n\n        \"\"\"\n        if metrics is None:\n            metrics = {}\n\n        self.module.training = False\n\n        return super().predict_data_mgr(datamgr, batch_size, metrics,\n                                        metric_keys, verbose=verbose,\n                                        **kwargs)\n\n    def save_state(self, file_name, *args, **kwargs):\n        \"\"\"\n        saves the current state via :func:`delira.io.tf.save_checkpoint`\n\n        Parameters\n        ----------\n        file_name : str\n            filename to save the state to\n        \"\"\"\n        save_checkpoint(file_name, self.module)\n\n    def load_state(self, file_name, *args, **kwargs):\n        \"\"\"\n        Loads the new state from file via :func:`delira.io.tf.load_checkpoint`\n\n        Parameters\n        ----------\n        file_name : str\n            the file to load the state from\n\n        Returns\n        -------\n        Any\n            the loaded state\n\n        \"\"\"\n        return load_checkpoint(file_name, self.module)\n\n    @staticmethod\n    def _search_for_prev_state(path, extensions=None):\n        \"\"\"\n        Helper function to 
search in a given path for previous epoch states\n        (indicated by extensions)\n\n        Parameters\n        ----------\n        path : str\n            the path to search in\n        extensions : list\n            list of strings containing valid file extensions for checkpoint\n            files\n\n        Returns\n        -------\n        str\n            the file containing the latest checkpoint (if available)\n        None\n            if no latest checkpoint was found\n        int\n            the latest epoch (1 if no checkpoint was found)\n\n        \"\"\"\n        if extensions is None:\n            extensions = [\".meta\"]\n        return BaseNetworkTrainer._search_for_prev_state(path, extensions)\n"
  },
  {
    "path": "delira/training/backends/tf_graph/utils.py",
    "content": "import tensorflow as tf\n\n\ndef initialize_uninitialized(sess):\n    \"\"\"\n    Function to initialize only uninitialized variables in a session graph\n\n    Parameters\n    ----------\n    sess : tf.Session()\n\n    \"\"\"\n\n    global_vars = tf.global_variables()\n    is_not_initialized = sess.run(\n        [tf.is_variable_initialized(var) for var in global_vars])\n\n    not_initialized_vars = [v for (v, f) in zip(\n        global_vars, is_not_initialized) if not f]\n\n    if not_initialized_vars:\n        sess.run(tf.variables_initializer(not_initialized_vars))\n"
  },
  {
    "path": "delira/training/backends/torch/__init__.py",
    "content": "from delira import get_backends as _get_backends\n\nif \"TORCH\" in _get_backends():\n    from delira.training.backends.torch.trainer import PyTorchNetworkTrainer\n    from delira.training.backends.torch.experiment import PyTorchExperiment\n    from delira.training.backends.torch.utils import create_optims_default \\\n        as create_pytorch_optims_default\n    from delira.training.backends.torch.utils import convert_to_numpy \\\n        as convert_torch_to_numpy\n"
  },
  {
    "path": "delira/training/backends/torch/experiment.py",
    "content": "from functools import partial\nimport typing\n\nimport torch\n\nfrom delira.models.backends.torch import AbstractPyTorchNetwork\nfrom delira.data_loading import DataManager\n\nfrom delira.training.base_experiment import BaseExperiment\nfrom delira.utils import DeliraConfig\n\nfrom delira.training.backends.torch.trainer import PyTorchNetworkTrainer\nfrom delira.training.backends.torch.utils import create_optims_default\nfrom delira.training.backends.torch.utils import convert_to_numpy\n\n\nclass PyTorchExperiment(BaseExperiment):\n    def __init__(self,\n                 config: typing.Union[str, DeliraConfig],\n                 model_cls: AbstractPyTorchNetwork,\n                 n_epochs=None,\n                 name=None,\n                 save_path=None,\n                 key_mapping=None,\n                 val_score_key=None,\n                 optim_builder=create_optims_default,\n                 checkpoint_freq=1,\n                 trainer_cls=PyTorchNetworkTrainer,\n                 **kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        config : :class:`DeliraConfig` or str\n            the training config, if string is passed,\n            it is treated as a path to a file, where the\n            config is loaded from\n        model_cls : Subclass of :class:`AbstractPyTorchNetwork`\n            the class implementing the model to train\n        n_epochs : int or None\n            the number of epochs to train, if None: can be specified later\n            during actual training\n        name : str or None\n            the Experiment's name\n        save_path : str or None\n            the path to save the results and checkpoints to.\n            if None: Current working directory will be used\n        key_mapping : dict\n            mapping between data_dict and model inputs (necessary for\n            prediction with :class:`Predictor`-API), if no keymapping is\n            given, a default key_mapping of {\"x\": 
\"data\"} will be used here\n        val_score_key : str or None\n            key defining which metric to use for validation (determining\n            best model and scheduling lr); if None: No validation-based\n            operations will be done (model might still get validated,\n            but validation metrics can only be logged and not used further)\n        optim_builder : function\n            Function returning a dict of backend-specific optimizers.\n            defaults to :func:`create_optims_default_pytorch`\n        checkpoint_freq : int\n            frequency of saving checkpoints (1 denotes saving every epoch,\n            2 denotes saving every second epoch etc.); default: 1\n        trainer_cls : subclass of :class:`PyTorchNetworkTrainer`\n            the trainer class to use for training the model, defaults to\n            :class:`PyTorchNetworkTrainer`\n        **kwargs :\n            additional keyword arguments\n\n        \"\"\"\n\n        if key_mapping is None:\n            key_mapping = {\"x\": \"data\"}\n        super().__init__(config=config, model_cls=model_cls,\n                         n_epochs=n_epochs, name=name, save_path=save_path,\n                         key_mapping=key_mapping,\n                         val_score_key=val_score_key,\n                         optim_builder=optim_builder,\n                         checkpoint_freq=checkpoint_freq,\n                         trainer_cls=trainer_cls,\n                         **kwargs)\n\n    def kfold(self, data: DataManager, metrics: dict, num_epochs=None,\n              num_splits=None, shuffle=False, random_seed=None,\n              split_type=\"random\", val_split=0.2, label_key=\"label\",\n              train_kwargs: dict = None, test_kwargs: dict = None,\n              metric_keys: dict = None, config=None, verbose=False,\n              **kwargs):\n        \"\"\"\n        Performs a k-Fold cross-validation\n\n        Parameters\n        ----------\n        data : 
:class:`DataManager`\n            the data to use for training(, validation) and testing. Will be\n            split based on ``split_type`` and ``val_split``\n        metrics : dict\n            dictionary containing the metrics to evaluate during k-fold\n        num_epochs : int or None\n            number of epochs to train (if not given, will either be\n            extracted from ``config``, ``self.config`` or ``self.n_epochs``)\n        num_splits : int or None\n            the number of splits to extract from ``data``.\n            If None: uses a default of 10\n        shuffle : bool\n            whether to shuffle the data before splitting or not\n            (implemented by index-shuffling rather than actual\n            data-shuffling to retain potentially lazy-behavior of datasets)\n        random_seed : int or None\n            seed to seed numpy, the splitting functions and the used\n            backend-framework\n        split_type : str\n            must be one of ['random', 'stratified']\n            if 'random': uses random data splitting\n            if 'stratified': uses stratified data splitting. Stratification\n            will be based on ``label_key``\n        val_split : float or None\n            the fraction of the train data to use as validation set.\n            If None: No validation will be done during training; only\n            testing for each fold after the training is complete\n        label_key : str\n            the label to use for stratification. Will be ignored unless\n            ``split_type`` is 'stratified'. Default: 'label'\n        train_kwargs : dict or None\n            kwargs to update the behavior of the :class:`DataManager`\n            containing the train data. 
If None: empty dict will be passed\n        test_kwargs : dict or None\n            kwargs to update the behavior of the :class:`DataManager`\n            containing the test and validation data.\n            If None: empty dict will be passed\n        metric_keys : dict of tuples\n            the batch_dict keys to use for each metric to calculate.\n            Should contain a value for each key in ``metrics``.\n            If no values are given for a key, per default ``pred`` and\n            ``label`` will be used for metric calculation\n        config : :class:`DeliraConfig` or None\n            the training and model parameters\n            (will be merged with ``self.config``)\n        verbose : bool\n            verbosity\n        **kwargs :\n            additional keyword arguments\n\n        Returns\n        -------\n        dict\n            all predictions from all folds\n        dict\n            all metric values from all folds\n\n        Raises\n        ------\n        ValueError\n            if ``split_type`` is neither 'random', nor 'stratified'\n\n        See Also\n        --------\n\n        * :class:`sklearn.model_selection.KFold`\n        and :class:`sklearn.model_selection.ShuffleSplit`\n        for random data-splitting\n\n        * :class:`sklearn.model_selection.StratifiedKFold`\n        and :class:`sklearn.model_selection.StratifiedShuffleSplit`\n        for stratified data-splitting\n\n        * :meth:`DataManager.update_from_state_dict` for updating the\n        data managers by kwargs\n\n        * :meth:`BaseExperiment.run` for the training\n\n        * :meth:`BaseExperiment.test` for the testing\n\n        Notes\n        -----\n        using stratified splits may be slow during split-calculation, since\n        each item must be loaded once to obtain the labels necessary for\n        stratification.\n\n        \"\"\"\n\n        # seed torch backend\n        if random_seed is not None:\n            torch.manual_seed(random_seed)\n\n       
 return super().kfold(\n            data=data,\n            metrics=metrics,\n            num_epochs=num_epochs,\n            num_splits=num_splits,\n            shuffle=shuffle,\n            random_seed=random_seed,\n            split_type=split_type,\n            val_split=val_split,\n            label_key=label_key,\n            train_kwargs=train_kwargs,\n            test_kwargs=test_kwargs,\n            metric_keys=metric_keys,\n            config=config,\n            verbose=verbose,\n            **kwargs)\n\n    def test(self, network, test_data: DataManager,\n             metrics: dict, metric_keys=None,\n             verbose=False, prepare_batch=None,\n             convert_fn=None, **kwargs):\n        \"\"\"\n        Setup and run testing on a given network\n\n        Parameters\n        ----------\n        network : :class:`AbstractNetwork`\n            the (trained) network to test\n        test_data : :class:`DataManager`\n            the data to use for testing\n        metrics : dict\n            the metrics to calculate\n        metric_keys : dict of tuples\n            the batch_dict keys to use for each metric to calculate.\n            Should contain a value for each key in ``metrics``.\n            If no values are given for a key, per default ``pred`` and\n            ``label`` will be used for metric calculation\n        verbose : bool\n            verbosity of the test process\n        prepare_batch : function\n            function to convert a batch-dict to a format accepted by the\n            model. This conversion typically includes dtype-conversion,\n            reshaping, wrapping to backend-specific tensors and\n            pushing to correct devices. 
If not further specified uses the\n            ``network``'s ``prepare_batch`` with the ``network``'s device\n        convert_fn : function\n            function to convert a batch of tensors to numpy\n            if not specified defaults to\n            :func:`convert_to_numpy`\n        **kwargs :\n            additional keyword arguments\n\n        Returns\n        -------\n        dict\n            all predictions obtained by feeding the ``test_data`` through\n            the ``network``\n        dict\n            all metrics calculated upon the ``test_data`` and the obtained\n            predictions\n\n        \"\"\"\n\n        # use backend-specific and model-specific prepare_batch fn\n        # (runs on same device as passed network per default)\n\n        device = next(network.parameters()).device\n        if prepare_batch is None:\n            prepare_batch = partial(network.prepare_batch,\n                                    input_device=device,\n                                    output_device=device)\n\n        # switch to backend-specific convert function\n        if convert_fn is None:\n            convert_fn = convert_to_numpy\n\n        return super().test(network=network, test_data=test_data,\n                            metrics=metrics, metric_keys=metric_keys,\n                            verbose=verbose, prepare_batch=prepare_batch,\n                            convert_fn=convert_fn, **kwargs)\n"
  },
  {
    "path": "delira/training/backends/torch/trainer.py",
    "content": "import logging\nimport os\nfrom functools import partial\nimport warnings\n\nimport torch\nfrom batchgenerators.dataloading import MultiThreadedAugmenter\n\nfrom delira.io.torch import load_checkpoint_torch, save_checkpoint_torch\nfrom delira.models.backends.torch import AbstractPyTorchNetwork, \\\n    DataParallelPyTorchNetwork\n\nfrom delira.training.base_trainer import BaseNetworkTrainer\n\nfrom delira.training.backends.torch.utils import create_optims_default\nfrom delira.training.backends.torch.utils import convert_to_numpy\nfrom delira.training.callbacks.logging_callback import DefaultLoggingCallback\n\n\nlogger = logging.getLogger(__name__)\n\n\nclass PyTorchNetworkTrainer(BaseNetworkTrainer):\n    \"\"\"\n    Train and Validate a Network\n\n    See Also\n    --------\n    :class:`AbstractNetwork`\n\n    \"\"\"\n\n    def __init__(self,\n                 network: AbstractPyTorchNetwork,\n                 save_path: str,\n                 key_mapping,\n                 losses=None,\n                 optimizer_cls=None,\n                 optimizer_params=None,\n                 metrics=None,\n                 lr_scheduler_cls=None,\n                 lr_scheduler_params=None,\n                 gpu_ids=None,\n                 save_freq=1,\n                 optim_fn=create_optims_default,\n                 logging_type=\"tensorboardx\",\n                 logging_kwargs=None,\n                 logging_callback_cls=DefaultLoggingCallback,\n                 logging_frequencies=None,\n                 logging_reduce_types=None,\n                 fold=0,\n                 callbacks=None,\n                 start_epoch=1,\n                 metric_keys=None,\n                 convert_batch_to_npy_fn=convert_to_numpy,\n                 mixed_precision=False,\n\n                 mixed_precision_kwargs={\"opt_level\": \"O1\",\n                                         \"cast_model_type\": None,\n                                         
\"patch_torch_functions\": None,\n                                         \"master_weights\": None,\n                                         \"loss_scale\": None,\n                                         \"cast_model_outputs\": None,\n                                         \"num_losses\": 1,\n                                         \"verbosity\": 1},\n                 val_freq=1,\n                 ** kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        network : :class:`AbstractPyTorchNetwork`\n            the network to train\n        save_path : str\n            path to save networks to\n        key_mapping : dict\n            a dictionary containing the mapping from the ``data_dict`` to\n            the actual model's inputs.\n            E.g. if a model accepts one input named 'x' and the data_dict\n            contains one entry named 'data' this argument would have to\n            be ``{'x': 'data'}``\n        losses : dict\n            dictionary containing the training losses\n        optimizer_cls : subclass of tf.train.Optimizer\n            optimizer class implementing the optimization algorithm of\n            choice\n        optimizer_params : dict\n            keyword arguments passed to optimizer during construction\n        metrics : dict, optional\n            metrics, which will be evaluated during train and validation phase\n            (should work on numpy arrays)\n        lr_scheduler_cls : Any\n            learning rate schedule class: must implement step() method\n        lr_scheduler_params : dict\n            keyword arguments passed to lr scheduler during construction\n        gpu_ids : list\n            list containing ids of GPUs to use; if empty: use cpu instead\n        save_freq : int\n            integer specifying how often to save the current model's state.\n            State is saved every state_freq epochs\n        optim_fn : function\n            creates a dictionary containing all necessary 
optimizers\n        logging_type : str or callable\n            the type of logging. If string: it must be one of\n            [\"visdom\", \"tensorboardx\"]\n            If callable: it must be a logging handler class\n        logging_kwargs : dict\n            dictionary containing all logging keyword arguments\n        logging_callback_cls : class\n            the callback class to create and register for logging\n        logging_frequencies : int or dict\n            specifies how often to log for each key.\n            If int: integer will be applied to all valid keys\n            if dict: should contain a frequency per valid key. Missing keys\n                will be filled with a frequency of 1 (log every time)\n            None is equal to empty dict here.\n        logging_reduce_types : str of FunctionType or dict\n            if str:\n                specifies the reduction type to use. Valid types are\n                'last' | 'first' | 'mean' | 'median' | 'max' | 'min'.\n                The given type will be mapped to all valid keys.\n            if FunctionType:\n                specifies the actual reduction function. Will be applied\n                for all keys.\n            if dict: should contain pairs of valid logging keys and either\n                str or FunctionType. 
Specifies the logging value per key.\n                Missing keys will be filled with a default value of 'last'.\n                Valid types for strings are\n                'last' | 'first' | 'mean' | 'median' | 'max' | 'min'.\n        fold : int\n            current cross validation fold (0 per default)\n        callbacks : list\n            initial callbacks to register\n        start_epoch : int\n            epoch to start training at\n        metric_keys : dict\n            dict specifying which batch_dict entry to use for which metric\n            as target; default: None, which will result in key \"label\"\n            for all metrics\n        convert_batch_to_npy_fn : type, optional\n            function converting a batch-tensor to numpy, per default this\n            is a function, which detaches the tensor, moves it to cpu and\n            then calls ``.numpy()`` on it\n        mixed_precision : bool\n            whether to use mixed precision or not (False per default)\n        mixed_precision_kwargs : dict\n            additional keyword arguments for mixed precision\n            from apex.amp.frontend:\n                opt_level : str\n                    Pure or mixed precision optimization level. 
Accepted\n                    values are \"O0\", \"O1\", \"O2\", and \"O3\":\n                        O0:  Pure FP32 training.\n                        O1:  Insert automatic casts around Pytorch\n                            functions and Tensor methods.\n                        O2:  FP16 training with FP32 batchnorm and FP32\n                            master weights\n                        O3:  Pure FP16 training.\n\n                cast_model_type : :class:`torch.dtype`\n                    Optional property override for model dtype;\n                    default: None\n                patch_torch_functions : bool\n                    Optional property override.\n                keep_batchnorm_fp32 : bool or str\n                    Optional property override.  If passed as a string,\n                    must be the string \"True\" or \"False\".\n                master_weights : bool\n                    Optional property override; whether to create master\n                    weights or not\n                loss_scale : float or str\n                    Optional property override.  If passed as a string,\n                    must be a string representing a number, e.g., \"128.0\",\n                    or the string \"dynamic\".\n                cast_model_outputs : :class:`torch.dtype`\n                    Option to ensure that the outputs of your model(s)\n                    are always cast to a particular type regardless\n                    of ``opt_level``.\n                num_losses : int\n                    Option to tell Amp in advance how many losses/backward\n                    passes you plan to use.  When used in conjunction with\n                    the ``loss_id`` argument to ``amp.scale_loss``, enables\n                    Amp to use a different loss scale per loss/backward\n                    pass, which can improve stability. 
See\n                    \"Multiple models/optimizers/losses\" under\n                    \"Advanced Amp Usage\" for examples.  If ``num_losses``\n                    is left to 1, Amp will still support multiple\n                    losses/backward passes, but use a single global\n                    loss scale for all of them; default: 1\n                verbosity : int\n                    Set to 0 to suppress Amp-related output; default: 1\n        val_freq : int\n            validation frequency specifying how often to validate the\n            trained model (a value of 1 denotes validating every epoch,\n            a value of 2 denotes validating every second epoch etc.);\n            defaults to 1\n        **kwargs :\n            additional keyword arguments\n\n        \"\"\"\n\n        if callbacks is None:\n            callbacks = []\n        if logging_kwargs is None:\n            logging_kwargs = {}\n        if gpu_ids is None:\n            gpu_ids = []\n        if lr_scheduler_params is None:\n            lr_scheduler_params = {}\n        if metrics is None:\n            metrics = {}\n        if optimizer_params is None:\n            optimizer_params = {}\n\n        super().__init__(network=network,\n                         save_path=save_path,\n                         losses=losses,\n                         optimizer_cls=optimizer_cls,\n                         optimizer_params=optimizer_params,\n                         metrics=metrics,\n                         lr_scheduler_cls=lr_scheduler_cls,\n                         lr_scheduler_params=lr_scheduler_params,\n                         gpu_ids=gpu_ids,\n                         save_freq=save_freq,\n                         optim_fn=optim_fn,\n                         key_mapping=key_mapping,\n                         logging_type=logging_type,\n                         logging_kwargs=logging_kwargs,\n                         logging_callback_cls=logging_callback_cls,\n                         
logging_frequencies=logging_frequencies,\n                         logging_reduce_types=logging_reduce_types,\n                         fold=fold,\n                         callbacks=callbacks,\n                         start_epoch=start_epoch,\n                         metric_keys=metric_keys,\n                         convert_batch_to_npy_fn=convert_batch_to_npy_fn,\n                         val_freq=val_freq,\n                         **kwargs\n                         )\n\n        self._setup(network, optim_fn, optimizer_cls, optimizer_params,\n                    lr_scheduler_cls, lr_scheduler_params, gpu_ids,\n                    key_mapping, convert_batch_to_npy_fn,\n                    mixed_precision, mixed_precision_kwargs, callbacks)\n\n        for key, val in kwargs.items():\n            setattr(self, key, val)\n\n    def _setup(self, network, optim_fn, optimizer_cls, optimizer_params,\n               lr_scheduler_cls, lr_scheduler_params, gpu_ids,\n               key_mapping, convert_batch_to_npy_fn, mixed_precision,\n               mixed_precision_kwargs, callbacks):\n        \"\"\"\n        Defines the Trainer's Setup\n\n        Parameters\n        ----------\n        network : :class:`AbstractPyTorchNetwork`\n            the network to train\n        optim_fn : function\n            creates a dictionary containing all necessary optimizers\n        optimizer_cls : subclass of torch.optim.Optimizer\n            optimizer class implementing the optimization algorithm of\n            choice\n        optimizer_params : dict\n            keyword arguments passed to optimizer during construction\n        lr_scheduler_cls : Any\n            learning rate schedule class: must implement step() method\n        lr_scheduler_params : dict\n            keyword arguments passed to lr scheduler during construction\n        gpu_ids : list\n            list containing ids of GPUs to use; if empty: use cpu instead\n        convert_batch_to_npy_fn : type\n            function converting a batch-tensor to numpy\n        mixed_precision : 
bool\n            whether to use mixed precision or not (False per default)\n        mixed_precision_kwargs : dict\n            additional keyword arguments for mixed precision\n        callbacks : list\n            initial callbacks to register\n\n        \"\"\"\n\n        self.optimizers = optim_fn(network, optimizer_cls,\n                                   **optimizer_params)\n\n        super()._setup(network, lr_scheduler_cls, lr_scheduler_params,\n                       gpu_ids, key_mapping, convert_batch_to_npy_fn,\n                       network.prepare_batch, callbacks)\n\n        # Load latest epoch file if available\n        if os.path.isdir(self.save_path):\n            latest_state_path, latest_epoch = self._search_for_prev_state(\n                self.save_path)\n\n            if latest_state_path is not None:\n\n                # if pth file does not exist, load pt file instead\n                if not os.path.isfile(latest_state_path):\n                    latest_state_path = latest_state_path[:-1]\n\n                logger.info(\"Attempting to load state from previous \"\n                            \"training from %s\" % latest_state_path)\n                try:\n                    self.update_state(latest_state_path)\n                except KeyError:\n                    logger.warning(\"Previous state could not be loaded, \"\n                                   \"although it exists. Training will be \"\n                                   \"restarted\")\n\n                self.start_epoch = latest_epoch\n\n        if gpu_ids and torch.cuda.is_available():\n            self.use_gpu = True\n            if (len(gpu_ids) > 1) and (torch.cuda.device_count() > 1):\n                # use GPU 0 as default input GPU\n                self.input_device = torch.device(\"cuda:%d\" % gpu_ids[0])\n\n                # Train on multiple GPUs and use GPU 1 as output device\n                self.module = DataParallelPyTorchNetwork(self.module.to(\n                    
self.input_device),\n                    device_ids=gpu_ids,\n                    output_device=gpu_ids[1])\n\n                # use GPU 1 as default output GPU for balanced GPU usage\n                self.output_device = torch.device(\"cuda:%d\" % gpu_ids[1])\n            else:\n                # use the only available GPU as input device\n                self.input_device = torch.device(\"cuda:%d\" % gpu_ids[0])\n                self.module = self.module.to(self.input_device)\n\n                # use GPU 0 as output device\n                self.output_device = torch.device(\"cuda:%d\" % gpu_ids[0])\n        else:\n            self.use_gpu = False\n            self.input_device = torch.device(\"cpu\")\n            self.output_device = torch.device(\"cpu\")\n            self.module = self.module.to(self.input_device)\n\n        self._prepare_batch = partial(\n            self._prepare_batch, input_device=self.input_device,\n            output_device=self.output_device)\n\n        try:\n            # use apex for mixed precision if installed\n            from apex import amp\n\n            # extract optimizers and corresponding keys\n            # (in case dict is not ordered)\n            _optim_keys = list(self.optimizers.keys())\n            _optims = list(self.optimizers[k] for k in _optim_keys)\n\n            # wrap model and register optimizers for mixed precision\n            self.module, _optims = amp.initialize(self.module, _optims,\n                                                  mixed_precision,\n                                                  **mixed_precision_kwargs)\n            for k, v in zip(_optim_keys, _optims):\n                self.optimizers[k] = v\n\n        except (ImportError, RuntimeError) as e:\n            warnings.warn(\n                \"Either APEX can't be imported correctly or a value \"\n                \"mismatch occurred. Switching to default FP32 \"\n                \"training instead. 
The following exception occurred:\"\n                \"\\n%s\" %\n                str(e))\n\n    def _at_training_begin(self, *args, **kwargs):\n        \"\"\"\n        Defines the behaviour at beginning of the training\n\n        Parameters\n        ----------\n        *args :\n            positional arguments\n        **kwargs :\n            keyword arguments\n        \"\"\"\n        for cbck in self._callbacks:\n            self._update_state(cbck.at_training_begin(self, *args, **kwargs))\n\n        self.save_state(os.path.join(self.save_path, \"checkpoint_epoch_%d\"\n                                     % self.start_epoch), self.start_epoch)\n\n    def _at_training_end(self, *args, **kwargs):\n        \"\"\"\n        Defines behaviour at end of training: Loads best model if\n        available\n\n        Returns\n        -------\n        :class:`AbstractPyTorchNetwork`\n            best network\n\n        \"\"\"\n        if os.path.isfile(os.path.join(self.save_path,\n                                       'checkpoint_best.pt')):\n\n            # load best model and return it\n            self.update_state(os.path.join(self.save_path,\n                                           'checkpoint_best.pt'))\n\n        return super()._at_training_end(*args, **kwargs)\n\n    def _at_epoch_end(self, metrics_val, val_score_key, epoch, is_best,\n                      **kwargs):\n        \"\"\"\n        Defines behaviour at the end of each epoch:\n        Executes all callbacks' `at_epoch_end` methods and saves the\n        current state if necessary\n\n        Parameters\n        ----------\n        metrics_val : dict\n            validation metrics\n        val_score_key : str\n            validation score key\n        epoch : int\n            current epoch\n        is_best : bool\n            whether current model is best one so far\n        **kwargs :\n            keyword arguments\n\n        \"\"\"\n\n        for 
cb in self._callbacks:\n            self._update_state(\n                cb.at_epoch_end(\n                    self,\n                    val_metrics=metrics_val,\n                    val_score_key=val_score_key,\n                    curr_epoch=epoch))\n\n        if epoch % self.save_freq == 0:\n            self.save_state(os.path.join(self.save_path,\n                                         \"checkpoint_epoch_%d.pt\" % epoch),\n                            epoch)\n\n        if is_best:\n            self.save_state(os.path.join(self.save_path,\n                                         \"checkpoint_best.pt\"),\n                            epoch)\n\n    def _train_single_epoch(self, batchgen: MultiThreadedAugmenter, epoch,\n                            verbose=False):\n        \"\"\"\n        Trains the network a single epoch\n\n        Parameters\n        ----------\n        batchgen : MultiThreadedAugmenter\n            Generator yielding the training batches\n        epoch : int\n            current epoch\n\n        \"\"\"\n\n        self.module.train()\n\n        return super()._train_single_epoch(batchgen, epoch,\n                                           verbose=verbose)\n\n    def predict_data_mgr(self, datamgr, batchsize=None, metrics=None,\n                         metric_keys=None, verbose=False, **kwargs):\n        \"\"\"\n        Defines a routine to predict data obtained from a batchgenerator\n\n        Parameters\n        ----------\n        datamgr : :class:`DataManager`\n            Manager producing a generator holding the batches\n        batchsize : int\n            Artificial batchsize (sampling will be done with batchsize\n            1 and sampled data will be stacked to match the artificial\n            batchsize)(default: None)\n        metrics : dict\n            the metrics to calculate\n        metric_keys : dict\n            the ``batch_dict`` items to use for metric calculation\n        verbose : bool\n            whether to show a 
progress-bar or not, default: False\n        **kwargs :\n            additional keyword arguments\n\n        Returns\n        -------\n        dict\n            predictions\n        dict\n            calculated metrics\n\n        \"\"\"\n        self.module.eval()\n\n        if metrics is None:\n            metrics = {}\n\n        return super().predict_data_mgr(datamgr, batchsize, metrics,\n                                        metric_keys, verbose, **kwargs)\n\n    def save_state(self, file_name, epoch, **kwargs):\n        \"\"\"\n        saves the current state via :func:`delira.io.torch.save_checkpoint`\n\n        Parameters\n        ----------\n        file_name : str\n            filename to save the state to\n        epoch : int\n            current epoch (will be saved for mapping back)\n        **kwargs :\n            keyword arguments\n\n        \"\"\"\n        if not (file_name.endswith(\".pth\") or file_name.endswith(\".pt\")):\n            file_name = file_name + \".pt\"\n        save_checkpoint_torch(file_name, self.module, self.optimizers, epoch,\n                              **kwargs)\n\n    @staticmethod\n    def load_state(file_name, **kwargs):\n        \"\"\"\n        Loads the new state from file via\n        :func:`delira.io.torch.load_checkpoint`\n\n        Parameters\n        ----------\n        file_name : str\n            the file to load the state from\n        **kwargs : keyword arguments\n\n        Returns\n        -------\n        dict\n            new state\n\n        \"\"\"\n\n        if not (file_name.endswith(\".pth\") or file_name.endswith(\".pt\")):\n            file_name = file_name + \".pt\"\n\n        return load_checkpoint_torch(file_name, **kwargs)\n\n    def _update_state(self, new_state):\n        \"\"\"\n        Update the state from a given new state\n\n        Parameters\n        ----------\n        new_state : dict\n            new state to update internal state from\n\n        Returns\n        -------\n        
:class:`PyTorchNetworkTrainer`\n            the trainer with a modified state\n\n        \"\"\"\n\n        if \"model\" in new_state:\n            self.module.load_state_dict(new_state.pop(\"model\"))\n\n        if \"optimizer\" in new_state and new_state[\"optimizer\"]:\n            optim_state = new_state.pop(\"optimizer\")\n            for key in self.optimizers.keys():\n                self.optimizers[key].load_state_dict(\n                    optim_state[key])\n\n        if \"epoch\" in new_state:\n            self.start_epoch = new_state.pop(\"epoch\")\n\n        return super()._update_state(new_state)\n\n    @staticmethod\n    def _search_for_prev_state(path, extensions=None):\n        \"\"\"\n        Helper function to search in a given path for previous epoch states\n        (indicated by extensions)\n\n        Parameters\n        ----------\n        path : str\n            the path to search in\n        extensions : list\n            list of strings containing valid file extensions for checkpoint\n            files\n\n        Returns\n        -------\n        str\n            the file containing the latest checkpoint (if available)\n        None\n            if no latest checkpoint was found\n        int\n            the latest epoch (1 if no checkpoint was found)\n\n        \"\"\"\n        if extensions is None:\n            extensions = [\".pt\", \".pth\"]\n        return BaseNetworkTrainer._search_for_prev_state(path, extensions)\n"
  },
  {
    "path": "delira/training/backends/torch/utils.py",
    "content": "import torch\n\nfrom delira.utils.decorators import dtype_func\nfrom delira.training.utils import convert_to_numpy_identity\nfrom delira.training.utils import recursively_convert_elements\n\n\n@dtype_func(torch.nn.Module)\ndef create_optims_default(model, optim_cls, **optim_params):\n    \"\"\"\n    Function to create an optimizer dictionary\n    (in this case only one optimizer for the whole network)\n\n    Parameters\n    ----------\n    model : :class:`AbstractPyTorchNetwork`\n        model whose parameters should be updated by the optimizer\n    optim_cls :\n        Class implementing an optimization algorithm\n    **optim_params :\n        Additional keyword arguments (passed to the optimizer class)\n\n    Returns\n    -------\n    dict\n        dictionary containing all created optimizers\n    \"\"\"\n    return {\"default\": optim_cls(model.parameters(), **optim_params)}\n\n\ndef _single_element_tensor_conversion(element):\n    return element.cpu().detach().numpy()\n\n\ndef convert_to_numpy(*args, **kwargs):\n    \"\"\"\n    Converts all :class:`torch.Tensor` in args and kwargs to numpy arrays\n\n    Parameters\n    ----------\n    *args :\n        positional arguments of arbitrary number and type\n    **kwargs :\n        keyword arguments of arbitrary number and type\n\n    Returns\n    -------\n    list\n        converted positional arguments\n    dict\n        converted keyword arguments\n\n    \"\"\"\n    args = recursively_convert_elements(args, torch.Tensor,\n                                        _single_element_tensor_conversion)\n\n    kwargs = recursively_convert_elements(kwargs, torch.Tensor,\n                                          _single_element_tensor_conversion)\n\n    return convert_to_numpy_identity(*args, **kwargs)\n"
  },
  {
    "path": "delira/training/backends/torchscript/__init__.py",
    "content": "from delira import get_backends as _get_backends\n\nif \"TORCH\" in _get_backends():\n    from delira.training.backends.torchscript.experiment import \\\n        TorchScriptExperiment\n    from delira.training.backends.torchscript.trainer import \\\n        TorchScriptNetworkTrainer\n"
  },
  {
    "path": "delira/training/backends/torchscript/experiment.py",
    "content": "import typing\n\nfrom delira.models.backends.torchscript import AbstractTorchScriptNetwork\n\nfrom delira.utils import DeliraConfig\nfrom delira.training.backends.torch.experiment import PyTorchExperiment\nfrom delira.training.backends.torch.utils import create_optims_default\n\nfrom delira.training.backends.torchscript.trainer import \\\n    TorchScriptNetworkTrainer\n\n\nclass TorchScriptExperiment(PyTorchExperiment):\n    def __init__(self,\n                 config: typing.Union[str, DeliraConfig],\n                 model_cls: AbstractTorchScriptNetwork,\n                 n_epochs=None,\n                 name=None,\n                 save_path=None,\n                 key_mapping=None,\n                 val_score_key=None,\n                 optim_builder=create_optims_default,\n                 checkpoint_freq=1,\n                 trainer_cls=TorchScriptNetworkTrainer,\n                 **kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        config : :class:`DeliraConfig` or str\n            the training config, if string is passed,\n            it is treated as a path to a file, where the\n            config is loaded from\n        model_cls : Subclass of :class:`AbstractTorchScriptNetwork`\n            the class implementing the model to train\n        n_epochs : int or None\n            the number of epochs to train, if None: can be specified later\n            during actual training\n        name : str or None\n            the Experiment's name\n        save_path : str or None\n            the path to save the results and checkpoints to.\n            if None: Current working directory will be used\n        key_mapping : dict\n            mapping between data_dict and model inputs (necessary for\n            prediction with :class:`Predictor`-API), if no keymapping is\n            given, a default key_mapping of {\"x\": \"data\"} will be used here\n        val_score_key : str or None\n            key defining which metric 
to use for validation (determining\n            best model and scheduling lr); if None: No validation-based\n            operations will be done (model might still get validated,\n            but validation metrics can only be logged and not used further)\n        optim_builder : function\n            Function returning a dict of backend-specific optimizers.\n            Defaults to :func:`create_optims_default`\n        checkpoint_freq : int\n            frequency of saving checkpoints (1 denotes saving every epoch,\n            2 denotes saving every second epoch etc.); default: 1\n        trainer_cls : subclass of :class:`TorchScriptNetworkTrainer`\n            the trainer class to use for training the model, defaults to\n            :class:`TorchScriptNetworkTrainer`\n        **kwargs :\n            additional keyword arguments\n\n        \"\"\"\n        super().__init__(config=config, model_cls=model_cls,\n                         n_epochs=n_epochs, name=name, save_path=save_path,\n                         key_mapping=key_mapping,\n                         val_score_key=val_score_key,\n                         optim_builder=optim_builder,\n                         checkpoint_freq=checkpoint_freq,\n                         trainer_cls=trainer_cls,\n                         **kwargs)\n"
  },
  {
    "path": "delira/training/backends/torchscript/trainer.py",
    "content": "import logging\n\nfrom delira.io.torch import load_checkpoint_torchscript, \\\n    save_checkpoint_torchscript\nfrom delira.models.backends.torchscript import AbstractTorchScriptNetwork\n\nfrom delira.training.base_trainer import BaseNetworkTrainer\nfrom delira.training.backends.torch.trainer import PyTorchNetworkTrainer\n\nfrom delira.training.backends.torch.utils import convert_to_numpy\nfrom delira.training.backends.torch.utils import create_optims_default\n\nfrom delira.training.callbacks.logging_callback import DefaultLoggingCallback\n\n\nlogger = logging.getLogger(__name__)\n\n\nclass TorchScriptNetworkTrainer(PyTorchNetworkTrainer):\n    def __init__(self,\n                 network: AbstractTorchScriptNetwork,\n                 save_path: str,\n                 key_mapping,\n                 losses=None,\n                 optimizer_cls=None,\n                 optimizer_params=None,\n                 metrics=None,\n                 lr_scheduler_cls=None,\n                 lr_scheduler_params=None,\n                 gpu_ids=None,\n                 save_freq=1,\n                 optim_fn=create_optims_default,\n                 logging_type=\"tensorboardx\",\n                 logging_kwargs=None,\n                 fold=0,\n                 callbacks=None,\n                 start_epoch=1,\n                 metric_keys=None,\n                 convert_batch_to_npy_fn=convert_to_numpy,\n                 criterions=None,\n                 val_freq=1,\n                 logging_callback_cls=DefaultLoggingCallback,\n                 logging_frequencies=None,\n                 logging_reduce_types=None,\n                 **kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        network : :class:`AbstractTorchScriptNetwork`\n            the network to train\n        save_path : str\n            path to save networks to\n        key_mapping : dict\n            a dictionary containing the mapping from the ``data_dict`` to\n            the 
actual model's inputs.\n            E.g. if a model accepts one input named 'x' and the data_dict\n            contains one entry named 'data' this argument would have to\n            be ``{'x': 'data'}``\n        losses : dict\n            dictionary containing the training losses\n        optimizer_cls : subclass of torch.optim.Optimizer\n            optimizer class implementing the optimization algorithm of\n            choice\n        optimizer_params : dict\n            keyword arguments passed to optimizer during construction\n        metrics : dict, optional\n            metrics, which will be evaluated during train and validation phase\n            (should work on numpy arrays)\n        lr_scheduler_cls : Any\n            learning rate schedule class: must implement step() method\n        lr_scheduler_params : dict\n            keyword arguments passed to lr scheduler during construction\n        gpu_ids : list\n            list containing ids of GPUs to use; if empty: use cpu instead\n            Currently ``torch.jit`` only supports single GPU-Training,\n            thus only the first GPU will be used if multiple GPUs are\n            passed\n        save_freq : int\n            integer specifying how often to save the current model's state.\n            State is saved every save_freq epochs\n        optim_fn : function\n            creates a dictionary containing all necessary optimizers\n        logging_type : str or callable\n            the type of logging. 
If string: it must be one of\n            [\"visdom\", \"tensorboardx\"]\n            If callable: it must be a logging handler class\n        logging_kwargs : dict\n            dictionary containing all logging keyword arguments\n        logging_callback_cls : class\n            the callback class to create and register for logging\n        logging_frequencies : int or dict\n            specifies how often to log for each key.\n            If int: integer will be applied to all valid keys\n            if dict: should contain a frequency per valid key. Missing keys\n                will be filled with a frequency of 1 (log every time)\n            None is equal to empty dict here.\n        logging_reduce_types : str or FunctionType or dict\n            if str:\n                specifies the reduction type to use. Valid types are\n                'last' | 'first' | 'mean' | 'median' | 'max' | 'min'.\n                The given type will be mapped to all valid keys.\n            if FunctionType:\n                specifies the actual reduction function. Will be applied\n                for all keys.\n            if dict: should contain pairs of valid logging keys and either\n                str or FunctionType. 
Specifies the logging value per key.\n                Missing keys will be filled with a default value of 'last'.\n                Valid types for strings are\n                'last' | 'first' | 'mean' | 'median' | 'max' | 'min'.\n        fold : int\n            current cross validation fold (0 per default)\n        callbacks : list\n            initial callbacks to register\n        start_epoch : int\n            epoch to start training at\n        metric_keys : dict\n            dict specifying which batch_dict entry to use for which metric\n            as target; default: None, which will result in key \"label\" for\n            all metrics\n        convert_batch_to_npy_fn : type, optional\n            function converting a batch-tensor to numpy, per default this\n            is a function, which detaches the tensor, moves it to cpu and\n            then calls ``.numpy()`` on it\n        val_freq : int\n            validation frequency specifying how often to validate the\n            trained model (a value of 1 denotes validating every epoch,\n            a value of 2 denotes validating every second epoch etc.);\n            defaults to 1\n        **kwargs :\n            additional keyword arguments\n\n        \"\"\"\n\n        if callbacks is None:\n            callbacks = []\n        if logging_kwargs is None:\n            logging_kwargs = {}\n        if gpu_ids is None:\n            gpu_ids = []\n        if lr_scheduler_params is None:\n            lr_scheduler_params = {}\n        if metrics is None:\n            metrics = {}\n        if optimizer_params is None:\n            optimizer_params = {}\n\n        if len(gpu_ids) > 1:\n            # only use first GPU due to\n            # https://github.com/pytorch/pytorch/issues/15421\n            
gpu_ids = [gpu_ids[0]]\n            logger.warning(\"Multiple GPUs specified. Torch JIT currently \"\n                           \"supports only single-GPU training. \"\n                           \"Switching to use only the first GPU \"\n                           \"for now...\")\n\n        super().__init__(network=network,\n                         save_path=save_path,\n                         losses=losses,\n                         optimizer_cls=optimizer_cls,\n                         optimizer_params=optimizer_params,\n                         metrics=metrics,\n                         lr_scheduler_cls=lr_scheduler_cls,\n                         lr_scheduler_params=lr_scheduler_params,\n                         gpu_ids=gpu_ids,\n                         save_freq=save_freq,\n                         optim_fn=optim_fn,\n                         key_mapping=key_mapping,\n                         logging_type=logging_type,\n                         logging_kwargs=logging_kwargs,\n                         logging_callback_cls=logging_callback_cls,\n                         logging_frequencies=logging_frequencies,\n                         logging_reduce_types=logging_reduce_types,\n                         fold=fold,\n                         callbacks=callbacks,\n                         start_epoch=start_epoch,\n                         metric_keys=metric_keys,\n                         convert_batch_to_npy_fn=convert_batch_to_npy_fn,\n                         val_freq=val_freq,\n                         mixed_precision=False,\n                         mixed_precision_kwargs={},\n                         **kwargs\n                         )\n\n    def save_state(self, file_name, epoch, **kwargs):\n        \"\"\"\n        saves the current state via\n        :func:`delira.io.torch.save_checkpoint_torchscript`\n\n        Parameters\n        ----------\n        file_name : str\n            filename to save the state to\n        epoch : int\n            current epoch 
(will be saved for mapping back)\n        **kwargs :\n            keyword arguments\n\n        \"\"\"\n        if file_name.endswith(\".ptj\"):\n            file_name = file_name.rsplit(\".\", 1)[0]\n\n        save_checkpoint_torchscript(file_name, self.module,\n                                    self.optimizers, **kwargs)\n\n    @staticmethod\n    def load_state(file_name, **kwargs):\n        \"\"\"\n        Loads the new state from file via\n        :func:`delira.io.torch.load_checkpoint_torchscript`\n\n        Parameters\n        ----------\n        file_name : str\n            the file to load the state from\n        **kwargs : keyword arguments\n\n        Returns\n        -------\n        dict\n            new state\n\n        \"\"\"\n        return load_checkpoint_torchscript(file_name, **kwargs)\n\n    def _update_state(self, new_state):\n        \"\"\"\n        Update the state from a given new state\n\n        Parameters\n        ----------\n        new_state : dict\n            new state to update internal state from\n\n        Returns\n        -------\n        :class:`TorchScriptNetworkTrainer`\n            the trainer with a modified state\n\n        \"\"\"\n        if \"model\" in new_state:\n            self.module = new_state.pop(\"model\").to(self.input_device)\n\n        return super()._update_state(new_state)\n\n    @staticmethod\n    def _search_for_prev_state(path, extensions=None):\n        \"\"\"\n        Helper function to search in a given path for previous epoch states\n        (indicated by extensions)\n\n        Parameters\n        ----------\n        path : str\n            the path to search in\n        extensions : list\n            list of strings containing valid file extensions for checkpoint\n            files\n\n        Returns\n        -------\n        str\n            the file containing the latest checkpoint (if available)\n        None\n            if no latest checkpoint was found\n        int\n            the latest epoch (1 if no 
checkpoint was found)\n\n        \"\"\"\n        if extensions is None:\n            extensions = [\".model.ptj\"]\n        return BaseNetworkTrainer._search_for_prev_state(path, extensions)\n"
  },
  {
    "path": "delira/training/base_experiment.py",
    "content": "import typing\nimport logging\nimport pickle\nimport os\nfrom datetime import datetime\nimport warnings\n\nimport copy\n\nimport numpy as np\nfrom sklearn.model_selection import KFold, StratifiedKFold, \\\n    StratifiedShuffleSplit, ShuffleSplit\n\nfrom delira import get_backends\n\nfrom delira.data_loading import DataManager\nfrom delira.models import AbstractNetwork\n\nfrom delira.utils import DeliraConfig\nfrom delira.training.base_trainer import BaseNetworkTrainer\nfrom delira.training.predictor import Predictor\n\nlogger = logging.getLogger(__name__)\n\n\nclass BaseExperiment(object):\n    \"\"\"\n    Baseclass for Experiments.\n\n    Implements:\n\n    * Setup-Behavior for Models, Trainers and Predictors (depending on train\n        and test case)\n\n    * The K-Fold logic (including stratified and random splitting)\n\n    * Argument Handling\n\n    \"\"\"\n\n    def __init__(self,\n                 config: typing.Union[str, DeliraConfig],\n                 model_cls: AbstractNetwork,\n                 n_epochs=None,\n                 name=None,\n                 save_path=None,\n                 key_mapping=None,\n                 val_score_key=None,\n                 optim_builder=None,\n                 checkpoint_freq=1,\n                 trainer_cls=BaseNetworkTrainer,\n                 predictor_cls=Predictor,\n                 **kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        config : :class:`DeliraConfig` or str\n            the training parameters, if string is passed,\n            it is treated as a path to a file, where the\n            config is loaded from\n        model_cls : Subclass of :class:`AbstractNetwork`\n            the class implementing the model to train\n        n_epochs : int or None\n            the number of epochs to train, if None: can be specified later\n            during actual training\n        name : str or None\n            the Experiment's name\n        save_path : str or 
None\n            the path to save the results and checkpoints to.\n            if None: Current working directory will be used\n        key_mapping : dict\n            mapping between data_dict and model inputs (necessary for\n            prediction with :class:`Predictor`-API)\n        val_score_key : str or None\n            key defining which metric to use for validation (determining best\n            model and scheduling lr); if None: No validation-based operations\n            will be done (model might still get validated, but validation\n            metrics can only be logged and not used further)\n        optim_builder : function\n            Function returning a dict of backend-specific optimizers\n        checkpoint_freq : int\n            frequency of saving checkpoints (1 denotes saving every epoch,\n            2 denotes saving every second epoch etc.); default: 1\n        trainer_cls : subclass of :class:`BaseNetworkTrainer`\n            the trainer class to use for training the model\n        predictor_cls : subclass of :class:`Predictor`\n            the predictor class to use for testing the model\n        **kwargs :\n            additional keyword arguments\n\n        \"\"\"\n\n        # config could also be a path to a file containing the config\n        if isinstance(config, str):\n            config = DeliraConfig.create_from_file(config)\n\n        if n_epochs is None:\n            n_epochs = config.nested_get(\"n_epochs\",\n                                         config.nested_get(\"num_epochs\"))\n\n        self.n_epochs = n_epochs\n\n        if name is None:\n            name = \"UnnamedExperiment\"\n        self.name = name\n\n        if save_path is None:\n            save_path = os.path.abspath(\".\")\n\n        self.save_path = os.path.join(save_path, name,\n                                      str(datetime.now().strftime(\n                                          \"%y-%m-%d_%H-%M-%S\")))\n\n        if os.path.isdir(self.save_path):\n            logger.warning(\"Save Path %s already exists\", self.save_path)\n\n        os.makedirs(self.save_path, exist_ok=True)\n\n        self.trainer_cls = trainer_cls\n        self.predictor_cls = predictor_cls\n\n        if val_score_key is None:\n            warnings.warn(\"No 'val_score_key' is given. This disables the \"\n                          \"automatic selection of the best model\",\n                          UserWarning)\n\n        self.val_score_key = val_score_key\n\n        assert key_mapping is not None\n        self.key_mapping = key_mapping\n\n        self.config = config\n        self.model_cls = model_cls\n\n        self._optim_builder = optim_builder\n        self.checkpoint_freq = checkpoint_freq\n\n        self._run = 0\n\n        self.kwargs = kwargs\n\n    def setup(self, config, training=True, **kwargs):\n        \"\"\"\n        Defines the setup behavior (model, trainer etc.) for the training and\n        testing case\n\n        Parameters\n        ----------\n        config : :class:`DeliraConfig`\n            the config to use for setup\n        training : bool\n            whether to set up the training case or the testing case\n        **kwargs :\n            additional keyword arguments\n\n        Returns\n        -------\n        :class:`BaseNetworkTrainer`\n            the created trainer (if ``training=True``)\n        :class:`Predictor`\n            the created predictor (if ``training=False``)\n\n        See Also\n        --------\n\n        * :meth:`BaseExperiment._setup_training` for training setup\n\n        * :meth:`BaseExperiment._setup_test` for test setup\n\n        \"\"\"\n        if training:\n            return self._setup_training(config, **kwargs)\n\n        return self._setup_test(config, **kwargs)\n\n    def _setup_training(self, config, **kwargs):\n        \"\"\"\n        Handles the setup for the training case\n\n        Parameters\n        ----------\n        config : :class:`DeliraConfig`\n          
  the config containing the model and training kwargs\n        **kwargs :\n            additional keyword arguments\n\n        Returns\n        -------\n        :class:`BaseNetworkTrainer`\n            the created trainer\n\n        \"\"\"\n        model_kwargs = config.model_params\n        model_kwargs = {**model_kwargs[\"variable\"], **model_kwargs[\"fixed\"]}\n\n        model = self.model_cls(**model_kwargs)\n\n        training_params = config.training_params\n\n        losses = training_params.nested_get(\"losses\")\n        optimizer_cls = training_params.nested_get(\"optimizer_cls\")\n        optimizer_params = training_params.nested_get(\"optimizer_params\")\n        train_metrics = training_params.nested_get(\"train_metrics\", {})\n        lr_scheduler_cls = training_params.nested_get(\"lr_sched_cls\", None)\n        lr_scheduler_params = training_params.nested_get(\"lr_sched_params\",\n                                                         {})\n\n        metrics = training_params.nested_get(\"metrics\", {})\n\n        # ToDo: remove after next release\n        val_metrics = config.nested_get(\"val_metrics\", {})\n        train_metrics = config.nested_get(\"train_metrics\", {})\n\n        if val_metrics or train_metrics:\n            warnings.warn(\"'val_metrics' and 'train_metrics' are deprecated. 
\"\n                          \"Please use the combined 'metrics' instead!\",\n                          DeprecationWarning)\n            metrics.update(val_metrics)\n            metrics.update(train_metrics)\n\n        # necessary for resuming training from a given path\n        save_path = kwargs.pop(\"save_path\", os.path.join(\n            self.save_path,\n            \"checkpoints\",\n            \"run_%02d\" % self._run))\n\n        return self.trainer_cls(\n            network=model,\n            save_path=save_path,\n            losses=losses,\n            key_mapping=self.key_mapping,\n            optimizer_cls=optimizer_cls,\n            optimizer_params=optimizer_params,\n            train_metrics=train_metrics,\n            metrics=metrics,\n            lr_scheduler_cls=lr_scheduler_cls,\n            lr_scheduler_params=lr_scheduler_params,\n            optim_fn=self._optim_builder,\n            save_freq=self.checkpoint_freq,\n            **kwargs\n        )\n\n    def _setup_test(self, config, model, convert_batch_to_npy_fn,\n                    prepare_batch_fn, **kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        config : :class:`DeliraConfig`\n            the parameters containing the model and training kwargs\n            (ignored here, just passed for subclassing and unified API)\n        model : :class:`AbstractNetwork`\n            the model to test\n        convert_batch_to_npy_fn : function\n            function to convert a batch of tensors to numpy\n        prepare_batch_fn : function\n            function to convert a batch-dict to a format accepted by the model.\n            This conversion typically includes dtype-conversion, reshaping,\n            wrapping to backend-specific tensors and pushing to correct devices\n        **kwargs :\n            additional keyword arguments\n\n        Returns\n        -------\n        :class:`Predictor`\n            the created predictor\n\n        \"\"\"\n        predictor = 
self.predictor_cls(\n            model=model, key_mapping=self.key_mapping,\n            convert_batch_to_npy_fn=convert_batch_to_npy_fn,\n            prepare_batch_fn=prepare_batch_fn, **kwargs)\n        return predictor\n\n    def run(self, train_data: DataManager,\n            val_data: DataManager = None,\n            config: DeliraConfig = None, **kwargs):\n        \"\"\"\n        Setup and run training\n\n        Parameters\n        ----------\n        train_data : :class:`DataManager`\n            the data to use for training\n        val_data : :class:`DataManager` or None\n            the data to use for validation (no validation is done\n            if passing None); default: None\n        config : :class:`DeliraConfig` or None\n            the config to use for training and model instantiation\n            (will be merged with ``self.config``)\n        **kwargs :\n            additional keyword arguments\n\n        Returns\n        -------\n        :class:`AbstractNetwork`\n            The trained network returned by the trainer (usually best network)\n\n        See Also\n        --------\n        :class:`BaseNetworkTrainer` for training itself\n\n        \"\"\"\n\n        config = self._resolve_params(config)\n        kwargs = self._resolve_kwargs(kwargs)\n\n        training_params = config.training_params\n\n        trainer = self.setup(config, training=True, **kwargs)\n\n        self._run += 1\n\n        num_epochs = kwargs.get(\"num_epochs\", training_params.nested_get(\n            \"num_epochs\", self.n_epochs))\n\n        if num_epochs is None:\n            num_epochs = self.n_epochs\n\n        return trainer.train(num_epochs, train_data, val_data,\n                             self.val_score_key, kwargs.get(\"val_score_mode\",\n                                                            \"lowest\"))\n\n    def resume(self, save_path: str, train_data: DataManager,\n               val_data: DataManager = None,\n               config: DeliraConfig = 
None, **kwargs):\n        \"\"\"\n        Resumes a previous training by passing an explicit ``save_path``\n        instead of generating a new one\n\n        Parameters\n        ----------\n        save_path : str\n            path to previous training\n        train_data : :class:`DataManager`\n            the data to use for training\n        val_data : :class:`DataManager` or None\n            the data to use for validation (no validation is done\n            if passing None); default: None\n        config : :class:`DeliraConfig` or None\n            the config to use for training and model instantiation\n            (will be merged with ``self.config``)\n        **kwargs :\n            additional keyword arguments\n\n        Returns\n        -------\n        :class:`AbstractNetwork`\n            The trained network returned by the trainer (usually best network)\n\n        See Also\n        --------\n        :class:`BaseNetworkTrainer` for training itself\n\n        \"\"\"\n        return self.run(\n            train_data=train_data,\n            val_data=val_data,\n            config=config,\n            save_path=save_path,\n            **kwargs)\n\n    def test(self, network, test_data: DataManager,\n             metrics: dict, metric_keys=None,\n             verbose=False, prepare_batch=None,\n             convert_fn=lambda *x, **y: (x, y), **kwargs):\n        \"\"\"\n        Setup and run testing on a given network\n\n        Parameters\n        ----------\n        network : :class:`AbstractNetwork`\n            the (trained) network to test\n        test_data : :class:`DataManager`\n            the data to use for testing\n        metrics : dict\n            the metrics to calculate\n        metric_keys : dict of tuples\n            the batch_dict keys to use for each metric to calculate.\n            Should contain a value for each key in ``metrics``.\n            If no values are given for a key, per default ``pred`` and\n            ``label`` will be 
used for metric calculation\n        verbose : bool\n            verbosity of the test process\n        prepare_batch : function\n            function to convert a batch-dict to a format accepted by the model.\n            This conversion typically includes dtype-conversion, reshaping,\n            wrapping to backend-specific tensors and pushing to correct devices\n        convert_fn : function\n            function to convert a batch of tensors to numpy\n        **kwargs :\n            additional keyword arguments\n\n        Returns\n        -------\n        dict\n            all predictions obtained by feeding the ``test_data`` through the\n            ``network``\n        dict\n            all metrics calculated upon the ``test_data`` and the obtained\n            predictions\n\n        \"\"\"\n\n        kwargs = self._resolve_kwargs(kwargs)\n\n        predictor = self.setup(None, training=False, model=network,\n                               convert_batch_to_npy_fn=convert_fn,\n                               prepare_batch_fn=prepare_batch, **kwargs)\n\n        # return first item of generator\n        return next(predictor.predict_data_mgr_cache_all(test_data, 1, metrics,\n                                                         metric_keys, verbose))\n\n    def kfold(self, data: DataManager, metrics: dict, num_epochs=None,\n              num_splits=None, shuffle=False, random_seed=None,\n              split_type=\"random\", val_split=0.2, label_key=\"label\",\n              train_kwargs: dict = None, metric_keys: dict = None,\n              test_kwargs: dict = None, config=None, verbose=False, **kwargs):\n        \"\"\"\n        Performs a k-Fold cross-validation\n\n        Parameters\n        ----------\n        data : :class:`DataManager`\n            the data to use for training(, validation) and testing. 
Will be\n            split based on ``split_type`` and ``val_split``\n        metrics : dict\n            dictionary containing the metrics to evaluate during k-fold\n        num_epochs : int or None\n            number of epochs to train (if not given, will either be extracted\n            from ``config``, ``self.config`` or ``self.n_epochs``)\n        num_splits : int or None\n            the number of splits to extract from ``data``.\n            If None: uses a default of 10\n        shuffle : bool\n            whether to shuffle the data before splitting or not (implemented by\n            index-shuffling rather than actual data-shuffling to retain\n            potentially lazy-behavior of datasets)\n        random_seed : int or None\n            seed to seed numpy, the splitting functions and the used\n            backend-framework\n        split_type : str\n            must be one of ['random', 'stratified']\n            if 'random': uses random data splitting\n            if 'stratified': uses stratified data splitting. Stratification\n            will be based on ``label_key``\n        val_split : float or None\n            the fraction of the train data to use as validation set. If None:\n            No validation will be done during training; only testing for each\n            fold after the training is complete\n        label_key : str\n            the label to use for stratification. Will be ignored unless\n            ``split_type`` is 'stratified'. Default: 'label'\n        train_kwargs : dict or None\n            kwargs to update the behavior of the :class:`DataManager`\n            containing the train data. 
If None: empty dict will be passed\n        metric_keys : dict of tuples\n            the batch_dict keys to use for each metric to calculate.\n            Should contain a value for each key in ``metrics``.\n            If no values are given for a key, per default ``pred`` and\n            ``label`` will be used for metric calculation\n        test_kwargs : dict or None\n            kwargs to update the behavior of the :class:`DataManager`\n            containing the test and validation data.\n            If None: empty dict will be passed\n        config : :class:`DeliraConfig` or None\n            the training and model parameters\n            (will be merged with ``self.config``)\n        verbose : bool\n            verbosity\n        **kwargs :\n            additional keyword arguments\n\n        Returns\n        -------\n        dict\n            all predictions from all folds\n        dict\n            all metric values from all folds\n\n        Raises\n        ------\n        ValueError\n            if ``split_type`` is neither 'random', nor 'stratified'\n\n        See Also\n        --------\n\n        * :class:`sklearn.model_selection.KFold`\n        and :class:`sklearn.model_selection.ShuffleSplit`\n        for random data-splitting\n\n        * :class:`sklearn.model_selection.StratifiedKFold`\n        and :class:`sklearn.model_selection.StratifiedShuffleSplit`\n        for stratified data-splitting\n\n        * :meth:`DataManager.update_from_state_dict` for updating the\n        data managers by kwargs\n\n        * :meth:`BaseExperiment.run` for the training\n\n        * :meth:`BaseExperiment.test` for the testing\n\n        Notes\n        -----\n        Using stratified splits may be slow during split-calculation, since\n        each item must be loaded once to obtain the labels necessary for\n        stratification.\n\n        \"\"\"\n\n        # set number of splits if not specified\n        if num_splits is None:\n            num_splits = 10\n            logger.warning(\"num_splits not defined, using default value of \"\n                           \"10 splits instead\")\n\n        metrics_test, outputs = {}, {}\n        split_idxs = list(range(len(data.dataset)))\n\n        if train_kwargs is None:\n            train_kwargs = {}\n        if test_kwargs is None:\n            test_kwargs = {}\n\n        # switch between different kfold types\n        if split_type == \"random\":\n            split_cls = KFold\n            val_split_cls = ShuffleSplit\n            # split_labels are ignored for random splitting; setting them to\n            # split_idxs just ensures the same length\n            split_labels = split_idxs\n        elif split_type == \"stratified\":\n            split_cls = StratifiedKFold\n            val_split_cls = StratifiedShuffleSplit\n            # iterate over dataset to get labels for stratified splitting\n            split_labels = [data.dataset[_idx][label_key]\n                            for _idx in split_idxs]\n        else:\n            raise ValueError(\"split_type must be one of \"\n                             \"['random', 'stratified'], but got: %s\"\n                             % str(split_type))\n\n        fold = split_cls(n_splits=num_splits, shuffle=shuffle,\n                         random_state=random_seed)\n\n        if random_seed is not None:\n            np.random.seed(random_seed)\n\n        # iterate over folds\n        for idx, (train_idxs, test_idxs) in enumerate(\n                fold.split(split_idxs, split_labels)):\n\n            # extract data from single manager\n            train_data = data.get_subset(train_idxs)\n            test_data = data.get_subset(test_idxs)\n\n            train_data.update_state_from_dict(copy.deepcopy(train_kwargs))\n            test_data.update_state_from_dict(copy.deepcopy(test_kwargs))\n\n            val_data = None\n            if val_split is not None:\n                if split_type == \"random\":\n                    # 
split_labels are ignored for random splitting, set them\n                    # to split_idxs just ensures same length\n                    train_labels = train_idxs\n                elif split_type == \"stratified\":\n                    # iterate over dataset to get labels for stratified\n                    # splitting\n                    train_labels = [train_data.dataset[_idx][label_key]\n                                    for _idx in train_idxs]\n                else:\n                    raise ValueError(\"split_type must be one of \"\n                                     \"['random', 'stratified'], but got: %s\"\n                                     % str(split_type))\n\n                _val_split = val_split_cls(n_splits=1, test_size=val_split,\n                                           random_state=random_seed)\n\n                for _train_idxs, _val_idxs in _val_split.split(train_idxs,\n                                                               train_labels):\n                    val_data = train_data.get_subset(_val_idxs)\n                    val_data.update_state_from_dict(copy.deepcopy(test_kwargs))\n\n                    train_data = train_data.get_subset(_train_idxs)\n\n            model = self.run(train_data=train_data, val_data=val_data,\n                             config=config, num_epochs=num_epochs, fold=idx,\n                             **kwargs)\n\n            _outputs, _metrics_test = self.test(model, test_data,\n                                                metrics=metrics,\n                                                metric_keys=metric_keys,\n                                                verbose=verbose)\n\n            outputs[str(idx)] = _outputs\n            metrics_test[str(idx)] = _metrics_test\n\n        return outputs, metrics_test\n\n    def __str__(self):\n        \"\"\"\n        Converts :class:`BaseExperiment` to string representation\n\n        Returns\n        -------\n        str\n            representation 
of class\n\n        \"\"\"\n        s = \"Experiment:\\n\"\n        for k, v in vars(self).items():\n            s += \"\\t{} = {}\\n\".format(k, v)\n        return s\n\n    def __call__(self, *args, **kwargs):\n        \"\"\"\n        Call :meth:`BaseExperiment.run`\n\n        Parameters\n        ----------\n        *args :\n            positional arguments\n        **kwargs :\n            keyword arguments\n\n        Returns\n        -------\n        :class:`BaseNetworkTrainer`\n            trainer of trained network\n\n        \"\"\"\n        return self.run(*args, **kwargs)\n\n    def save(self):\n        \"\"\"\n        Saves the whole experiment\n\n        \"\"\"\n        with open(os.path.join(self.save_path, \"experiment.delira.pkl\"),\n                  \"wb\") as f:\n            pickle.dump(self, f)\n\n        self.config.dump(os.path.join(self.save_path, \"parameters\"))\n\n    @staticmethod\n    def load(file_name):\n        \"\"\"\n        Loads a whole experiment\n\n        Parameters\n        ----------\n        file_name : str\n            file_name to load the experiment from\n\n        \"\"\"\n        with open(file_name, \"rb\") as f:\n            return pickle.load(f)\n\n    def _resolve_params(self, config: typing.Union[DeliraConfig, None]):\n        \"\"\"\n        Merges the given config with ``self.config``.\n        If the same argument is given in both configs,\n        the one from the currently given config is used here\n\n        Parameters\n        ----------\n        config : :class:`DeliraConfig` or None\n            the parameters to merge with ``self.config``\n\n        Returns\n        -------\n        :class:`DeliraConfig`\n            the merged config instance\n\n        \"\"\"\n        if config is None:\n            config = DeliraConfig()\n\n        if hasattr(self, \"config\") and isinstance(self.config, DeliraConfig):\n            _config = config\n            config = copy.deepcopy(self.config)\n            config.update(_config, overwrite=True)\n\n        return config\n\n    def _resolve_kwargs(self, kwargs: typing.Union[dict, None]):\n        \"\"\"\n        Merges the given kwargs with ``self.kwargs``.\n        If the same argument is present in both, the one from the given\n        kwargs will be used here\n\n        Parameters\n        ----------\n        kwargs : dict\n            the given kwargs to merge with self.kwargs\n\n        Returns\n        -------\n        dict\n            merged kwargs\n\n        \"\"\"\n\n        if kwargs is None:\n            kwargs = {}\n\n        if hasattr(self, \"kwargs\") and isinstance(self.kwargs, dict):\n            _kwargs = kwargs\n            kwargs = dict(self.kwargs)\n            kwargs.update(_kwargs)\n\n        return kwargs\n\n    def __getstate__(self):\n        return vars(self)\n\n    def __setstate__(self, state):\n        vars(self).update(state)\n"
  },
  {
    "path": "delira/training/base_trainer.py",
    "content": "import logging\nimport os\nimport pickle\nimport typing\nimport warnings\n\nfrom delira.utils.config import LookupConfig\n\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom .callbacks import AbstractCallback, DefaultLoggingCallback\nfrom .predictor import Predictor\nfrom ..data_loading import Augmenter, DataManager\nfrom ..models import AbstractNetwork\nfrom ..logging import register_logger, make_logger\n\nlogger = logging.getLogger(__name__)\n\n\nclass BaseNetworkTrainer(Predictor):\n    \"\"\"\n    Defines a Base API and basic functions for Network Trainers\n\n    See Also\n    --------\n    :class:`PyTorchNetworkTrainer`\n    :class:`TfNetworkTrainer`\n\n    \"\"\"\n\n    __KEYS_TO_GUARD = [\"use_gpu\",\n                       \"input_device\",\n                       \"output_device\",\n                       \"_callbacks\"]\n\n    def __init__(self,\n                 network: AbstractNetwork,\n                 save_path: str,\n                 losses: dict,\n                 optimizer_cls: type,\n                 optimizer_params: dict,\n                 metrics: dict,\n                 lr_scheduler_cls: type,\n                 lr_scheduler_params: dict,\n                 gpu_ids: typing.List[int],\n                 save_freq: int,\n                 optim_fn,\n                 key_mapping: dict,\n                 logging_type: str,\n                 logging_kwargs: dict,\n                 logging_callback_cls=DefaultLoggingCallback,\n                 logging_frequencies=None,\n                 logging_reduce_types=None,\n                 fold: int = 0,\n                 callbacks: typing.List[AbstractCallback] = None,\n                 start_epoch=1,\n                 metric_keys=None,\n                 convert_batch_to_npy_fn=lambda x: x,\n                 val_freq=1,\n                 **kwargs\n                 ):\n        \"\"\"\n\n        Parameters\n        ----------\n        network : :class:`AbstractNetwork`\n            the network 
to train\n        save_path : str\n            path to save networks to\n        losses : dict\n            dictionary containing the training losses\n        optimizer_cls : type\n            optimizer class implementing the optimization algorithm of choice\n        optimizer_params : dict\n            keyword arguments passed to optimizer during construction\n        metrics : dict, optional\n            metrics, which will be evaluated during train and validation phase\n            (should work on numpy arrays)\n        lr_scheduler_cls : Any\n            learning rate schedule class: must implement step() method\n        lr_scheduler_params : dict\n            keyword arguments passed to lr scheduler during construction\n        gpu_ids : list\n            list containing ids of GPUs to use; if empty: use CPU instead\n        save_freq : int\n            integer specifying how often to save the current model's state.\n            State is saved every save_freq epochs\n        optim_fn : function\n            creates a dictionary containing all necessary optimizers\n        key_mapping : dict\n            a dictionary containing the mapping from the ``data_dict`` to\n            the actual model's inputs.\n            E.g. if a model accepts one input named 'x' and the data_dict\n            contains one entry named 'data' this argument would have to\n            be ``{'x': 'data'}``\n        logging_type : str or callable\n            the type of logging. 
If string: it must be one of\n            [\"visdom\", \"tensorboardx\"]\n            If callable: it must be a logging handler backend class\n        logging_kwargs : dict\n            dictionary containing all logging keyword arguments\n        logging_callback_cls : class\n            the callback class to create and register for logging\n        logging_frequencies : int or dict\n            specifies how often to log for each key.\n            If int: the integer will be applied to all valid keys\n            If dict: should contain a frequency per valid key. Missing keys\n            will be filled with a frequency of 1 (log every time).\n            None is equal to an empty dict here.\n        logging_reduce_types : str, FunctionType or dict\n            if str:\n                specifies the reduction type to use. Valid types are\n                'last' | 'first' | 'mean' | 'median' | 'max' | 'min'.\n                The given type will be mapped to all valid keys.\n            if FunctionType:\n                specifies the actual reduction function. Will be applied\n                for all keys.\n            if dict: should contain pairs of valid logging keys and either\n                str or FunctionType. 
Specifies the logging value per key.\n                Missing keys will be filled with a default value of 'last'.\n                Valid types for strings are\n                'last' | 'first' | 'mean' | 'median' | 'max' | 'min'.\n        fold : int\n            current cross validation fold (0 per default)\n        callbacks : list\n            initial callbacks to register\n        start_epoch : int\n            epoch to start training at\n        metric_keys : dict\n            the batch_dict keys to use for each metric to calculate.\n            Should contain a value for each key in ``metrics``.\n            If no values are given for a key, per default ``pred`` and\n            ``label`` will be used for metric calculation\n        convert_batch_to_npy_fn : type, optional\n            function converting a batch-tensor to numpy, per default this is\n            the identity function\n        val_freq : int\n            validation frequency specifying how often to validate the trained\n            model (a value of 1 denotes validating every epoch,\n            a value of 2 denotes validating every second epoch etc.);\n            defaults to 1\n        **kwargs :\n            Additional keyword arguments\n\n        \"\"\"\n\n        # explicitly do not call self._setup here to reuse the __init__ of the\n        # abstract class. 
self._setup has to be called in subclass\n        if callbacks is None:\n            callbacks = []\n\n        # check argument types (pair each instance with its argument name so\n        # the error message works for arbitrary values, not only classes)\n        for name, instance, cls_type in zip(\n                [\"network\", \"save_path\", \"losses\", \"optimizer_params\",\n                 \"metrics\", \"lr_scheduler_params\", \"gpu_ids\"],\n                [network, save_path, losses, optimizer_params, metrics,\n                 lr_scheduler_params, gpu_ids],\n                [AbstractNetwork, str, dict, dict, dict, dict, list]):\n            if not isinstance(instance, cls_type):\n                raise TypeError(\"%s should be of type %s, but is of type %s\"\n                                % (name, cls_type.__name__,\n                                   type(instance).__name__))\n\n        if os.path.isdir(save_path):\n            logger.warning(\n                \"Save Path already exists. Saved Models may be overwritten\")\n        else:\n            os.makedirs(save_path)\n\n        self._fold = fold\n        self.start_epoch = start_epoch\n        self.save_path = save_path\n        self.losses = losses\n        self.metrics = metrics\n        self.stop_training = False\n        self.save_freq = save_freq\n        self.metric_keys = metric_keys\n\n        self._tqdm_desc = \"Validate\"\n        self.val_freq = val_freq\n        self._global_iter_num = 1\n        self._logging_setup_kwargs = {\n            \"logging_type\": logging_type,\n            \"logging_kwargs\": logging_kwargs,\n            \"logging_callback_cls\": logging_callback_cls,\n            \"logging_frequencies\": logging_frequencies,\n            \"reduce_types\": logging_reduce_types}\n\n    def _setup(self, network, lr_scheduler_cls, lr_scheduler_params, gpu_ids,\n               key_mapping, convert_batch_to_npy_fn, prepare_batch_fn,\n               callbacks):\n\n        super()._setup(network, key_mapping, convert_batch_to_npy_fn,\n                       prepare_batch_fn, callbacks)\n\n        self._reinitialize_logging(**self._logging_setup_kwargs)\n\n        self.closure_fn = network.closure\n\n        # optimizers 
must exist before calling _setup()\n        if lr_scheduler_cls is not None:\n            for key, optim in self.optimizers.items():\n                if not issubclass(lr_scheduler_cls, AbstractCallback):\n                    logger.warning(\"lr_scheduler_cls is not a callback.\")\n                self.register_callback(lr_scheduler_cls(optim,\n                                                        **lr_scheduler_params))\n\n        if gpu_ids:\n            self.use_gpu = True\n        else:\n            self.use_gpu = False\n\n    def _at_training_begin(self, *args, **kwargs):\n        \"\"\"\n        Defines the behaviour at the beginning of the training\n\n        Parameters\n        ----------\n        *args :\n            positional arguments\n        **kwargs :\n            keyword arguments\n\n        \"\"\"\n        for cbck in self._callbacks:\n            self._update_state(cbck.at_training_begin(self, *args, **kwargs))\n\n        self.save_state(os.path.join(self.save_path, \"checkpoint_epoch_%d\"\n                                     % self.start_epoch))\n\n    def _at_training_end(self, *args, **kwargs):\n        \"\"\"\n        Defines the behaviour at the end of the training\n\n        Parameters\n        ----------\n        *args :\n            positional arguments\n        **kwargs :\n            keyword arguments\n\n        Returns\n        -------\n        :class:`AbstractNetwork`\n            the network with the loaded state\n\n        \"\"\"\n        for cbck in self._callbacks:\n            self._update_state(cbck.at_training_end(self, *args, **kwargs))\n\n        return self.module\n\n    def _at_epoch_begin(self, val_score_key, epoch, num_epochs,\n                        **kwargs):\n        \"\"\"\n        Defines behaviour at beginning of each epoch: Executes all callbacks'\n        `at_epoch_begin` method\n\n        Parameters\n        ----------\n        val_score_key : str\n            validation score key\n        epoch : int\n            
current epoch\n        num_epochs : int\n            total number of epochs\n        **kwargs :\n            keyword arguments\n\n        \"\"\"\n        # execute all callbacks\n        for cb in self._callbacks:\n            self._update_state(cb.at_epoch_begin(self, val_metrics={},\n                                                 val_score_key=val_score_key,\n                                                 curr_epoch=epoch))\n\n    def _at_epoch_end(self, metrics_val, val_score_key, epoch, is_best,\n                      **kwargs):\n        \"\"\"\n        Defines behaviour at the end of each epoch: Executes all callbacks'\n        `at_epoch_end` method and saves current state if necessary\n\n        Parameters\n        ----------\n        metrics_val : dict\n            validation metrics\n        val_score_key : str\n            validation score key\n        epoch : int\n            current epoch\n        is_best : bool\n            whether the current state is the best one so far\n        **kwargs :\n            keyword arguments\n\n        \"\"\"\n\n        for cb in self._callbacks:\n            self._update_state(cb.at_epoch_end(self, val_metrics=metrics_val,\n                                               val_score_key=val_score_key,\n                                               curr_epoch=epoch))\n\n        if epoch % self.save_freq == 0:\n            self.save_state(os.path.join(self.save_path,\n                                         \"checkpoint_epoch_%d\" % epoch))\n\n        if is_best:\n            self.save_state(os.path.join(self.save_path,\n                                         \"checkpoint_best\"))\n\n    def _at_iter_begin(self, iter_num, epoch=0, **kwargs):\n        \"\"\"\n        Defines the behavior executed at an iteration's begin\n\n        Parameters\n        ----------\n        iter_num : int\n            number of current iter\n        epoch : int\n            number of current epoch\n        **kwargs :\n            additional keyword arguments 
(forwarded to callback calls)\n\n        \"\"\"\n        for cb in self._callbacks:\n            self._update_state(cb.at_iter_begin(\n                self, iter_num=iter_num,\n                curr_epoch=epoch,\n                global_iter_num=self._global_iter_num,\n                train=True,\n                **kwargs,\n            ))\n\n    def _at_iter_end(self, iter_num, data_dict, metrics, epoch=0, **kwargs):\n        \"\"\"\n        Defines the behavior executed at an iteration's end\n\n        Parameters\n        ----------\n        iter_num : int\n            number of current iter\n        data_dict : dict\n            dictionary holding input data and predictions\n        metrics : dict\n            calculated metrics\n        epoch : int\n            number of current epoch\n        **kwargs :\n            additional keyword arguments (forwarded to callback calls)\n\n        \"\"\"\n\n        for cb in self._callbacks:\n            self._update_state(cb.at_iter_end(\n                self, iter_num=iter_num,\n                data_dict=data_dict,\n                metrics=metrics,\n                curr_epoch=epoch,\n                global_iter_num=self._global_iter_num,\n                train=True,\n                **kwargs,\n            ))\n\n        self._global_iter_num += 1\n\n    def _train_single_epoch(self, dmgr_train: DataManager, epoch,\n                            verbose=False):\n        \"\"\"\n        Trains the network a single epoch\n\n        Parameters\n        ----------\n        dmgr_train : :class:`DataManager`\n            Datamanager to create the data generator\n        epoch : int\n            current epoch\n        verbose : bool\n            whether to show a progress bar\n\n        Returns\n        -------\n        dict\n            calculated metrics (one list entry per batch and key)\n        dict\n            calculated losses (one list entry per batch and key)\n\n        \"\"\"\n\n        metrics, losses = [], []\n\n        batchgen = dmgr_train.get_batchgen(seed=epoch)\n\n        n_batches = dmgr_train.n_batches\n        if verbose:\n            iterable = tqdm(\n                enumerate(batchgen),\n                unit=' batch',\n                total=n_batches,\n                
desc='Epoch %d' %\n                     epoch)\n        else:\n            iterable = enumerate(batchgen)\n\n        for iter_num, batch in iterable:\n            self._at_iter_begin(epoch=epoch, iter_num=iter_num)\n\n            data_dict = self._prepare_batch(batch)\n\n            _losses, _preds = self.closure_fn(self.module, data_dict,\n                                              optimizers=self.optimizers,\n                                              losses=self.losses,\n                                              fold=self.fold,\n                                              iter_num=iter_num)\n\n            data_dict = self._convert_to_npy_fn(**data_dict)[1]\n            _preds = self._convert_to_npy_fn(**_preds)[1]\n\n            _metrics = self.calc_metrics(\n                LookupConfig(**data_dict, **_preds),\n                self.metrics,\n                self.metric_keys)\n\n            metrics.append(_metrics)\n            losses.append(_losses)\n\n            self._at_iter_end(epoch=epoch, iter_num=iter_num,\n                              data_dict={**batch, **_preds},\n                              metrics={**_metrics, **_losses},\n                              )\n\n        total_losses, total_metrics = {}, {}\n\n        for _metrics in metrics:\n            for key, val in _metrics.items():\n                if key in total_metrics:\n                    total_metrics[key].append(val)\n                else:\n                    total_metrics[key] = [val]\n\n        for _losses in losses:\n            for key, val in _losses.items():\n                if key in total_losses:\n                    total_losses[key].append(val)\n                else:\n                    total_losses[key] = [val]\n\n        return total_metrics, total_losses\n\n    def train(self, num_epochs, datamgr_train, datamgr_valid=None,\n              val_score_key=None, val_score_mode='highest', reduce_mode='mean',\n              verbose=True):\n        \"\"\"\n        
Defines a routine to train a specified number of epochs\n\n        Parameters\n        ----------\n        num_epochs : int\n            number of epochs to train\n        datamgr_train : DataManager\n            the datamanager holding the train data\n        datamgr_valid : DataManager\n            the datamanager holding the validation data (default: None)\n        val_score_key : str\n            the key specifying which metric to use for validation\n            (default: None)\n        val_score_mode : str\n            key specifying what kind of validation score is best;\n            must be one of ['highest', 'lowest']\n        reduce_mode : str\n            'mean' | 'sum' | 'first_only' | 'last_only'\n        verbose : bool\n            whether to show progress bars or not\n\n        \"\"\"\n        self._at_training_begin()\n\n        if val_score_mode == 'highest':\n            best_val_score = 0\n        elif val_score_mode == 'lowest':\n            best_val_score = float('inf')\n        else:\n            best_val_score = None\n\n        is_best = False\n        new_val_score = best_val_score\n\n        if reduce_mode == 'mean':\n            def reduce_fn(batch):\n                return np.mean(batch)\n        elif reduce_mode == 'sum':\n            def reduce_fn(batch):\n                return np.sum(batch)\n        elif reduce_mode == 'first_only':\n            def reduce_fn(batch):\n                return batch[0]\n        elif reduce_mode == 'last_only':\n            def reduce_fn(batch):\n                return batch[-1]\n        else:\n            raise ValueError(\"No valid reduce mode given\")\n\n        for epoch in range(self.start_epoch, num_epochs + 1):\n\n            self._at_epoch_begin(val_score_key, epoch,\n                                 num_epochs)\n\n            # train single network epoch\n            train_metrics, train_losses = self._train_single_epoch(\n                datamgr_train, epoch, verbose=verbose)\n\n            total_metrics = {\n                **train_metrics,\n                
**train_losses}\n\n            # validate network\n            if datamgr_valid is not None and (epoch % self.val_freq == 0):\n                # next must be called here because self.predict_data_mgr\n                # returns a generator (of size 1) and we want to get the\n                # first (and only) item\n                val_metrics = next(\n                    self.predict_data_mgr_cache_metrics_only(\n                        datamgr_valid, datamgr_valid.batch_size,\n                        metrics=self.metrics,\n                        metric_keys=self.metric_keys,\n                        verbose=verbose))\n\n                val_metrics = {\"val_\" + k: v\n                               for k, v in val_metrics.items()}\n\n                total_metrics.update(val_metrics)\n            _, total_metrics = self._convert_to_npy_fn(**total_metrics)\n\n            for k, v in total_metrics.items():\n                total_metrics[k] = reduce_fn(v)\n\n            # check if metric became better\n            if val_score_key is not None:\n                if val_score_key not in total_metrics:\n                    if \"val_\" + val_score_key not in total_metrics:\n                        warnings.warn(\"val_score_key '%s' not a valid key \"\n                                      \"for validation metrics\" %\n                                      str(val_score_key), UserWarning)\n\n                        new_val_score = best_val_score\n\n                    else:\n                        new_val_score = \\\n                            total_metrics[\"val_\" + val_score_key]\n                        val_score_key = \"val_\" + val_score_key\n                else:\n                    new_val_score = total_metrics.get(val_score_key)\n\n            if new_val_score != best_val_score:\n                is_best = self._is_better_val_scores(\n                    best_val_score, new_val_score, val_score_mode)\n\n                # set best_val_score to new_val_score if 
is_best\n                if is_best:\n                    best_val_score = new_val_score\n\n                if is_best and verbose:\n                    logger.info(\"New Best Value at Epoch %03d : %03.3f\" %\n                                (epoch, best_val_score))\n\n            self._at_epoch_end(total_metrics, val_score_key, epoch,\n                               is_best)\n\n            is_best = False\n\n            # stop training (might be caused by early stopping)\n            if self.stop_training:\n                break\n\n        return self._at_training_end()\n\n    @property\n    def fold(self):\n        \"\"\"\n        Get current fold\n\n        Returns\n        -------\n        int\n            current fold\n\n        \"\"\"\n        return self._fold\n\n    @fold.setter\n    def fold(self, fold):\n        \"\"\"\n        Set the current fold\n\n        Parameters\n        ----------\n        fold : int\n            new fold\n\n        Raises\n        ------\n        ValueError\n            if `fold` is not convertible to :obj:`int`\n\n        \"\"\"\n        try:\n            self._fold = int(fold)\n\n        except ValueError as e:\n            logger.error(e)\n            raise e\n\n    def register_callback(self, callback: AbstractCallback):\n        \"\"\"\n        Register Callback to Trainer\n\n        Parameters\n        ----------\n        callback : :class:`AbstractCallback`\n            the callback to register\n\n        Raises\n        ------\n        AssertionError\n            `callback` is not an instance of :class:`AbstractCallback` and\n            does not provide both methods\n            ['at_training_begin', 'at_training_end']\n\n        \"\"\"\n        assertion_str = \"Given callback is not valid; Must be instance of \" \\\n                        \"AbstractCallback or provide functions \" \\\n                        \"'at_training_begin' and 'at_training_end'\"\n\n        instance_check = isinstance(callback, AbstractCallback)\n        
attr_check_begin_train = hasattr(callback, \"at_training_begin\")\n        attr_check_end_train = hasattr(callback, \"at_training_end\")\n        attr_check_both_train = attr_check_begin_train and attr_check_end_train\n\n        assert instance_check or attr_check_both_train, assertion_str\n\n        super().register_callback(callback)\n\n    def save_state(self, file_name, *args, **kwargs):\n        \"\"\"\n        saves the current state\n\n        Parameters\n        ----------\n        file_name : str\n            filename to save the state to\n        *args :\n            positional arguments\n        **kwargs :\n            keyword arguments\n\n        \"\"\"\n        with open(file_name, \"wb\") as f:\n            pickle.dump(vars(self), f, *args, **kwargs)\n\n    @staticmethod\n    def load_state(file_name, *args, **kwargs):\n        \"\"\"\n        Loads the new state from file\n\n        Parameters\n        ----------\n        file_name : str\n            the file to load the state from\n        *args :\n            positional arguments\n        **kwargs : keyword arguments\n\n        Returns\n        -------\n        dict\n            new state\n\n        \"\"\"\n        with open(file_name, \"rb\") as f:\n            new_state = pickle.load(f, *args, **kwargs)\n\n        return new_state\n\n    def _update_state(self, new_state):\n        \"\"\"\n        Update the state from a given new state\n\n        Parameters\n        ----------\n        new_state : dict\n            new state to update internal state from\n\n        Returns\n        -------\n        :class:`BaseNetworkTrainer`\n            the trainer with a modified state\n\n        \"\"\"\n        for key, val in new_state.items():\n            if key.startswith(\"__\") and key.endswith(\"__\"):\n                continue\n\n            try:\n                setattr(self, key, val)\n\n            except PermissionError:\n                logger.error(\"Trying to overwrite attribute %s of \"\n     
                        \"NetworkTrainer, which is not allowed!\" % key)\n\n        return self\n\n    def update_state(self, file_name, *args, **kwargs):\n        \"\"\"\n        Update internal state from a loaded state\n\n        Parameters\n        ----------\n        file_name : str\n            file containing the new state to load\n        *args :\n            positional arguments\n        **kwargs :\n            keyword arguments\n\n        Returns\n        -------\n        :class:`BaseNetworkTrainer`\n            the trainer with a modified state\n\n        \"\"\"\n        return self._update_state(self.load_state(file_name, *args, **kwargs))\n\n    @staticmethod\n    def _is_better_val_scores(old_val_score, new_val_score,\n                              mode='highest'):\n        \"\"\"\n        Check whether the new val score is better than the old one\n        with respect to the optimization goal\n\n        Parameters\n        ----------\n        old_val_score :\n            old validation score\n        new_val_score :\n            new validation score\n        mode : str\n            String to specify whether a higher or lower validation score is\n            optimal; must be in ['highest', 'lowest']\n\n        Returns\n        -------\n        bool\n            True if new score is better, False otherwise\n        \"\"\"\n\n        assert mode in ['highest', 'lowest'], \"Invalid Comparison Mode\"\n\n        if mode == 'highest':\n            return new_val_score > old_val_score\n        elif mode == 'lowest':\n            return new_val_score < old_val_score\n\n    @property\n    def name(self):\n        return os.path.basename(os.path.dirname(os.path.dirname(\n            os.path.dirname(self.save_path))))\n\n    def _reinitialize_logging(self, logging_type, logging_kwargs: dict,\n                              logging_callback_cls, logging_frequencies,\n                              reduce_types):\n        \"\"\"\n\n        Parameters\n        ----------\n   
     logging_type : str or callable\n            the type of logging. If string: it must be one of\n            [\"visdom\", \"tensorboardx\"]\n            If callable: it must be a logging handler backend class\n        logging_kwargs : dict\n            dictionary containing all logging keyword arguments\n        logging_callback_cls : class\n            the callback class to create and register for logging\n        logging_frequencies : int or dict\n                specifies how often to log for each key.\n                If int: integer will be applied to all valid keys\n                if dict: should contain a frequency per valid key. Missing keys\n                will be filled with a frequency of 1 (log every time)\n                None is equal to empty dict here.\n        reduce_types : str, FunctionType or dict\n            Values are logged in each iteration. This argument specifies\n            how to reduce them to a single value if a logging_frequency\n            besides 1 is passed\n\n            if str:\n                specifies the reduction type to use. Valid types are\n                'last' | 'first' | 'mean' | 'max' | 'min'.\n                The given type will be mapped to all valid keys.\n            if FunctionType:\n                specifies the actual reduction function. Will be applied\n                for all keys.\n            if dict: should contain pairs of valid logging keys and either\n                str or FunctionType. 
Specifies the reduction per key.\n                Missing keys will be filled with a default value of 'last'.\n                Valid types for strings are\n                'last' | 'first' | 'mean' | 'max' | 'min'.\n\n        \"\"\"\n\n        from delira.logging import TensorboardBackend, VisdomBackend, \\\n            BaseBackend\n\n        if isinstance(logging_type, str):\n            if logging_type.lower() == \"visdom\":\n                backend_cls = VisdomBackend\n\n            elif logging_type.lower() == \"tensorboardx\":\n                backend_cls = TensorboardBackend\n\n            else:\n                raise ValueError(\"Invalid Logging Type\")\n\n        elif issubclass(logging_type, BaseBackend):\n            backend_cls = logging_type\n\n        else:\n            raise ValueError(\"Invalid logging_type passed\")\n\n        _logging_kwargs = {}\n\n        if backend_cls == VisdomBackend:\n            _logging_kwargs.update({\"exp_name\": \"main\",\n                                    \"level\": 0})\n        elif backend_cls == TensorboardBackend:\n            _logging_kwargs.update(\n                {\n                    \"logdir\":\n                        os.path.join(os.path.dirname(\n                            os.path.dirname(self.save_path)),\n                            \"logs\", \"run_%02d\" % self.fold),\n                    \"level\": 0})\n\n        _logging_kwargs.update(logging_kwargs)\n\n        if \"exp_name\" in _logging_kwargs.keys():\n            _logging_kwargs[\"exp_name\"] = _logging_kwargs[\"exp_name\"] + \\\n                \"_%02d\" % self.fold\n\n        # remove prior Trixihandlers and reinitialize it with given logging\n        # type\n        # This facilitates visualization of multiple splits/folds inside one\n        # tensorboard-instance by means of\n        # different tf.Summary.FileWriters()\n\n        level = _logging_kwargs.pop(\"level\")\n\n        logger = backend_cls(_logging_kwargs)\n\n        
self.register_callback(\n            logging_callback_cls(\n                logger, level=level,\n                logging_frequencies=logging_frequencies,\n                reduce_types=reduce_types))\n\n        register_logger(self._callbacks[-1]._logger, self.name)\n\n    @staticmethod\n    def _search_for_prev_state(path, extensions=None):\n        \"\"\"\n        Helper function to search in a given path for previous epoch states\n        (indicated by extensions)\n\n        Parameters\n        ----------\n        path : str\n            the path to search in\n        extensions : list\n            list of strings containing valid file extensions for checkpoint\n            files\n\n        Returns\n        -------\n        str\n            the file containing the latest checkpoint (if available)\n        None\n            if no latest checkpoint was found\n        int\n            the latest epoch (1 if no checkpoint was found)\n\n        \"\"\"\n        if extensions is None:\n            extensions = []\n        files = []\n        for file in os.listdir(path):\n            for ext in extensions:\n                if not ext.startswith(\".\"):\n                    ext = \".\" + ext\n\n                if not file.endswith(ext):\n                    continue\n\n                if not file.startswith(\"checkpoint\"):\n                    continue\n\n                if file.endswith(\"_best\" + ext):\n                    continue\n\n                files.append(file)\n                break\n\n        if files:\n            latest_epoch = max([\n                int(x.rsplit(\"_\", 1)[-1].split(\".\", 1)[0])\n                for x in files])\n\n            latest_state_filename = [x for x in files\n                                     if x.startswith(\"checkpoint_epoch_%d\"\n                                                     % latest_epoch)][0]\n            latest_state_path = os.path.join(path, latest_state_filename)\n            return latest_state_path, 
latest_epoch\n\n        return None, 1\n\n    def register_callback(self, callback: AbstractCallback):\n        \"\"\"\n        Registers the passed callback to the trainer,\n        after checking it is really a valid callback\n\n        Parameters\n        ----------\n        callback : AbstractCallback\n            the potential callback to register\n\n        Raises\n        ------\n        AssertionError\n            `callback` is not an instance of :class:`AbstractCallback`\n            and does not provide the methods `at_epoch_begin` and\n            `at_epoch_end`\n\n        \"\"\"\n        has_all_attrs = True\n        for attr in (\"epoch\",):\n            has_all_attrs = has_all_attrs and hasattr(callback,\n                                                      \"at_%s_begin\" % attr)\n            has_all_attrs = has_all_attrs and hasattr(callback,\n                                                      \"at_%s_end\" % attr)\n\n        assert has_all_attrs, \"Given callback is not valid; Must be \" \\\n                              \"instance of AbstractCallback or provide \" \\\n                              \"functions 'at_epoch_begin' and 'at_epoch_end'\"\n        super().register_callback(callback)\n"
  },
  {
    "path": "delira/training/callbacks/__init__.py",
    "content": "from delira import get_backends\n\nfrom delira.training.callbacks.logging_callback import DefaultLoggingCallback\nfrom delira.training.callbacks.abstract_callback import AbstractCallback\nfrom delira.training.callbacks.early_stopping import EarlyStopping\n\nif \"TORCH\" in get_backends():\n    from delira.training.callbacks.pytorch_schedulers import \\\n        DefaultPyTorchSchedulerCallback\n    from delira.training.callbacks.pytorch_schedulers import \\\n        CosineAnnealingLRCallback as CosineAnnealingLRCallbackPyTorch\n    from delira.training.callbacks.pytorch_schedulers import \\\n        ExponentialLRCallback as ExponentialLRCallbackPyTorch\n\n    from delira.training.callbacks.pytorch_schedulers import \\\n        LambdaLRCallback as LambdaLRCallbackPyTorch\n    from delira.training.callbacks.pytorch_schedulers import \\\n        MultiStepLRCallback as MultiStepLRCallbackPyTorch\n    from delira.training.callbacks.pytorch_schedulers import \\\n        ReduceLROnPlateauCallback as ReduceLROnPlateauCallbackPyTorch\n    from delira.training.callbacks.pytorch_schedulers import StepLRCallback \\\n        as StepLRCallbackPyTorch\n    from delira.training.callbacks.pytorch_schedulers import \\\n        OneCycleLRCallback as OneCycleLRCallbackPyTorch\n"
  },
  {
    "path": "delira/training/callbacks/abstract_callback.py",
    "content": "class AbstractCallback(object):\n    \"\"\"\n    Implements abstract callback interface.\n    All callbacks should be derived from this class\n\n    See Also\n    --------\n    :class:`AbstractNetworkTrainer`\n\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        *args :\n            positional arguments\n        **kwargs :\n            keyword arguments\n\n        \"\"\"\n        super().__init__(*args, **kwargs)\n\n    def at_epoch_begin(self, trainer, *args, **kwargs):\n        \"\"\"\n        Function which will be executed at begin of each epoch\n\n        Parameters\n        ----------\n        trainer : :class:`AbstractNetworkTrainer`\n        **kwargs :\n            additional keyword arguments\n\n        Returns\n        -------\n        dict\n            modified trainer attributes, where the name must correspond to the\n            trainer's attribute name\n\n        Notes\n        -----\n        The basetrainer calls the callbacks with the following additional\n        arguments: `val_metrics`(dict), `val_score_key`(str), `curr_epoch`(int)\n        \"\"\"\n        return {}\n\n    def at_epoch_end(self, trainer, *args, **kwargs):\n        \"\"\"\n        Function which will be executed at end of each epoch\n\n        Parameters\n        ----------\n        trainer : :class:`AbstractNetworkTrainer`\n        **kwargs :\n            additional keyword arguments\n\n        Returns\n        -------\n        dict\n            modified trainer attributes, where the name must correspond to the\n            trainer's attribute name\n\n        Notes\n        -----\n        The basetrainer calls the callbacks with the following additional\n        arguments: `val_metrics`(dict), `val_score_key`(str), `curr_epoch`(int)\n        \"\"\"\n        return {}\n\n    def at_training_begin(self, trainer, *args, **kwargs):\n        \"\"\"\n        Function which will be executed at begin 
of training\n\n        Parameters\n        ----------\n        trainer : :class:`AbstractNetworkTrainer`\n        **kwargs :\n            additional keyword arguments\n\n        Returns\n        -------\n        dict\n            modified trainer attributes, where the name must correspond to the\n            trainer's attribute name\n        \"\"\"\n        return {}\n\n    def at_training_end(self, trainer, *args, **kwargs):\n        \"\"\"\n        Function which will be executed at the end of training\n\n        Parameters\n        ----------\n        trainer : :class:`AbstractNetworkTrainer`\n        **kwargs :\n            additional keyword arguments\n\n        Returns\n        -------\n        dict\n            modified trainer attributes, where the name must correspond to the\n            trainer's attribute name\n\n        \"\"\"\n        return {}\n\n    def at_iter_begin(self, trainer, *args, **kwargs):\n        \"\"\"\n        Function which will be executed at the beginning of each iteration\n\n        Parameters\n        ----------\n        trainer : :class:`AbstractNetworkTrainer`\n        **kwargs :\n            additional keyword arguments\n\n        Returns\n        -------\n        dict\n            modified trainer attributes, where the name must correspond to the\n            trainer's attribute name\n\n        Notes\n        -----\n        The predictor calls the callbacks with the following additional\n        arguments: `iter_num`(int), `train`(bool)\n\n        The basetrainer adds the following arguments (in addition to the\n        predictor's): `curr_epoch`(int), `global_iter_num`(int)\n\n        \"\"\"\n        return {}\n\n    def at_iter_end(self, trainer, *args, **kwargs):\n        \"\"\"\n        Function which will be executed at the end of each iteration\n\n        Parameters\n        ----------\n        trainer : :class:`AbstractNetworkTrainer`\n        **kwargs :\n            additional keyword arguments\n\n        Returns\n        -------\n        dict\n            modified trainer attributes, where the name must correspond to the\n            trainer's attribute name\n\n        Notes\n        -----\n        The predictor calls the callbacks with the following additional\n        arguments: `iter_num`(int), `metrics`(dict),\n        `data_dict`(dict, contains prediction and input data),\n        `train`(bool)\n\n        The basetrainer adds the following arguments (in addition to the\n        predictor's): `curr_epoch`(int), `global_iter_num`(int)\n\n        \"\"\"\n        return {}\n"
  },
  {
    "path": "delira/training/callbacks/early_stopping.py",
"content": "from delira.training.callbacks.abstract_callback import AbstractCallback\n\n\nclass EarlyStopping(AbstractCallback):\n    \"\"\"\n    Implements Early Stopping as callback\n\n    See Also\n    --------\n    :class:`AbstractCallback`\n\n    \"\"\"\n\n    def __init__(self, monitor_key,\n                 min_delta=0,\n                 patience=0,\n                 mode='min'):\n        \"\"\"\n\n        Parameters\n        ----------\n        monitor_key : str\n            the validation key to monitor\n        min_delta : float or int\n            the minimum difference between the best metric value so far and\n            the current one\n        patience : int\n            number of epochs to wait before stopping training\n        mode : str (default: 'min')\n            defines the optimum for the monitored value\n\n        \"\"\"\n\n        super().__init__()\n\n        self.monitor_key = monitor_key\n        self.min_delta = min_delta\n        self.patience = patience\n        self.mode = mode\n\n        if 'min' == mode:\n            self.best_metric = float('inf')\n        elif 'max' == mode:\n            self.best_metric = - float('inf')\n        else:\n            raise ValueError(\"Unknown compare mode: Got %s, but expected one \"\n                             \"of ['min', 'max']\" % mode)\n        self.epochs_waited = 0\n\n    def _is_better(self, metric):\n        \"\"\"\n        Helper function to decide whether the current metric is better than\n        the best metric so far\n\n        Parameters\n        ----------\n        metric :\n            current metric value\n\n        Returns\n        -------\n        bool\n            whether this metric is the new best metric or not\n\n        \"\"\"\n        if 'min' == self.mode:\n            return metric < (self.best_metric - self.min_delta)\n        else:\n            return metric > (self.best_metric + self.min_delta)\n\n    def at_epoch_end(self, trainer, **kwargs):\n        \"\"\"\n        Actual early stopping: checks at the end of each epoch whether the\n        monitored metric is a new best; if it has not improved for\n        `self.patience` epochs, training is stopped\n\n        Parameters\n        ----------\n        trainer : :class:`AbstractNetworkTrainer`\n            the trainer whose arguments can be modified\n        **kwargs :\n            additional keyword arguments\n\n        Returns\n        -------\n        dict\n            dict containing the ``stop_training`` flag\n\n        \"\"\"\n        metric = kwargs.get(\"val_metrics\", {})[self.monitor_key]\n\n        if self._is_better(metric):\n            self.best_metric = metric\n            self.epochs_waited = 0\n        else:\n            self.epochs_waited += 1\n\n        if self.epochs_waited >= self.patience:\n            stop_training = True\n        else:\n            stop_training = False\n\n        return {\"stop_training\": stop_training}\n"
  },
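The patience bookkeeping fixed in `early_stopping.py` can be exercised without the rest of the framework. The sketch below is a hypothetical, dependency-free reimplementation of the same logic (`SimpleEarlyStopper` is not part of delira), useful for sanity-checking the reset/increment behaviour:

```python
class SimpleEarlyStopper:
    """Minimal sketch of the patience logic used by EarlyStopping."""

    def __init__(self, min_delta=0.0, patience=0, mode='min'):
        if mode not in ('min', 'max'):
            raise ValueError("mode must be 'min' or 'max'")
        self.min_delta = min_delta
        self.patience = patience
        self.mode = mode
        self.best = float('inf') if mode == 'min' else -float('inf')
        self.waited = 0

    def step(self, metric):
        """Feed one epoch's metric; returns True when training should stop."""
        if self.mode == 'min':
            improved = metric < self.best - self.min_delta
        else:
            improved = metric > self.best + self.min_delta
        if improved:
            # new best: remember it and reset the patience counter
            self.best = metric
            self.waited = 0
        else:
            self.waited += 1
        return self.waited >= self.patience


stopper = SimpleEarlyStopper(patience=2)
# val losses: improves twice, then stalls for two epochs -> stop on the 4th
decisions = [stopper.step(v) for v in (1.0, 0.8, 0.9, 0.85)]
```

With `patience=2`, the two non-improving epochs (0.9, 0.85) exhaust the patience and the fourth call returns `True`.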
  {
    "path": "delira/training/callbacks/logging_callback.py",
"content": "from delira.training.callbacks.abstract_callback import AbstractCallback\nfrom delira.logging import make_logger, BaseBackend\nimport logging\n\n\nclass DefaultLoggingCallback(AbstractCallback):\n    \"\"\"\n    A default logging callback which logs only the metrics; should be\n    subclassed for additional logging\n    \"\"\"\n\n    def __init__(self, backend: BaseBackend, max_queue_size: int = None,\n                 logging_frequencies=None, reduce_types=None,\n                 level=logging.NOTSET):\n        \"\"\"\n\n        Parameters\n        ----------\n        backend : :class:`delira.logging.base_backend.BaseBackend`\n            the logging backend\n        max_queue_size : int\n            the maximum queue size\n        logging_frequencies : int or dict\n                specifies how often to log for each key.\n                If int: the integer will be applied to all valid keys.\n                If dict: should contain a frequency per valid key. Missing keys\n                will be filled with a frequency of 1 (log every time).\n                None is equal to an empty dict here.\n        reduce_types : str or FunctionType or dict\n            if str:\n                specifies the reduction type to use. Valid types are\n                'last' | 'first' | 'mean' | 'max' | 'min'.\n                The given type will be mapped to all valid keys.\n            if FunctionType:\n                specifies the actual reduction function. Will be applied\n                for all keys.\n            if dict: should contain pairs of valid logging keys and either\n                str or FunctionType. 
Specifies the logging value per key.\n                Missing keys will be filled with a default value of 'last'.\n                Valid types for strings are\n                'last' | 'first' | 'mean' | 'max' | 'min'.\n        level : int\n            the logging level for python's internal logging module\n\n        \"\"\"\n        super().__init__()\n\n        self._logger = make_logger(backend=backend,\n                                   max_queue_size=max_queue_size,\n                                   logging_frequencies=logging_frequencies,\n                                   reduce_types=reduce_types, level=level)\n\n    def at_iter_end(self, trainer, iter_num=None, data_dict=None, train=False,\n                    **kwargs):\n        \"\"\"\n        Function logging the metrics at the end of each iteration\n\n        Parameters\n        ----------\n        trainer : :class:`BaseNetworkTrainer`\n            the current trainer object (unused in this callback)\n        iter_num : int\n            number of the current iteration inside the current epoch\n            (unused in this callback)\n        data_dict : dict\n            the current data dict (including predictions)\n        train: bool\n            whether the callback is called during training (by the trainer)\n            or during validation/prediction (by the predictor)\n        **kwargs :\n            additional keyword arguments\n\n        Returns\n        -------\n        dict\n            empty dict, because no state should be updated\n        \"\"\"\n\n        metrics = kwargs.get(\"metrics\", {})\n\n        for k, v in metrics.items():\n            self._logger.log({\"scalar\": {\"tag\": self.create_tag(k, train),\n                                         \"scalar_value\": v}})\n\n        return {}\n\n    @staticmethod\n    def create_tag(tag: str, train: bool):\n        \"\"\"Append the '_val' suffix for non-training (validation) tags\"\"\"\n        if not train:\n            tag = tag + \"_val\"\n        return tag\n"
  },
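The `reduce_types` argument of `DefaultLoggingCallback` accepts a string, a function, or a per-key dict. As an illustrative sketch of the assumed resolution semantics (the helper names `REDUCTIONS` and `resolve_reduce_types` are hypothetical, not delira's internal API, which lives inside `make_logger`):

```python
# Map the documented string reduce types to concrete reduction functions.
REDUCTIONS = {
    'last': lambda values: values[-1],
    'first': lambda values: values[0],
    'mean': lambda values: sum(values) / len(values),
    'max': max,
    'min': min,
}


def resolve_reduce_types(reduce_types, valid_keys):
    """Expand a str/callable/dict spec into {key: callable}, default 'last'."""
    if callable(reduce_types):
        # one function applied to every key
        return {k: reduce_types for k in valid_keys}
    if isinstance(reduce_types, str):
        # one named reduction applied to every key
        return {k: REDUCTIONS[reduce_types] for k in valid_keys}
    reduce_types = reduce_types or {}  # None behaves like an empty dict
    resolved = {}
    for k in valid_keys:
        spec = reduce_types.get(k, 'last')  # missing keys default to 'last'
        resolved[k] = spec if callable(spec) else REDUCTIONS[spec]
    return resolved


fns = resolve_reduce_types({'loss': 'mean'}, ['loss', 'acc'])
```

Here `fns['loss']` averages the buffered values while `fns['acc']` falls back to the `'last'` default.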
  {
    "path": "delira/training/callbacks/pytorch_schedulers.py",
"content": "from delira import get_backends\nfrom delira.training.callbacks.abstract_callback import AbstractCallback\n\nif 'TORCH' in get_backends():\n    from torch.optim.lr_scheduler import ReduceLROnPlateau, \\\n        CosineAnnealingLR, ExponentialLR, LambdaLR, MultiStepLR, StepLR, \\\n        OneCycleLR\n\n    class DefaultPyTorchSchedulerCallback(AbstractCallback):\n        \"\"\"\n        Implements a callback whose `at_epoch_end` function is suitable for\n        most schedulers\n\n        \"\"\"\n\n        def __init__(self, *args, **kwargs):\n            \"\"\"\n\n            Parameters\n            ----------\n            *args :\n                Arbitrary Positional Arguments\n            **kwargs :\n                Arbitrary Keyword Arguments\n\n            \"\"\"\n            super().__init__()\n\n            self.scheduler = None\n\n        def at_epoch_end(self, trainer, **kwargs):\n            \"\"\"\n            Executes a single scheduling step\n\n            Parameters\n            ----------\n            trainer : :class:`PyTorchNetworkTrainer`\n                the trainer class, which can be changed\n            **kwargs :\n                additional keyword arguments\n\n            Returns\n            -------\n            dict\n                modified trainer attributes, where the name must correspond\n                to the trainer's attribute name\n\n            \"\"\"\n            self.scheduler.step(epoch=kwargs.get(\"curr_epoch\", None))\n            return {}\n\n    class OneCycleLRCallback(DefaultPyTorchSchedulerCallback):\n        \"\"\"\n        Wraps PyTorch's `OneCycleLR` Scheduler as Callback\n\n        \"\"\"\n\n        def __init__(\n                self,\n                optimizer,\n                max_lr,\n                total_steps=None,\n                epochs=None,\n                steps_per_epoch=None,\n                pct_start=0.3,\n                anneal_strategy='cos',\n                cycle_momentum=True,\n                base_momentum=0.85,\n                max_momentum=0.95,\n                div_factor=25.0,\n                final_div_factor=10000.0,\n                last_epoch=-1):\n            \"\"\"\n\n            Parameters\n            ----------\n            optimizer (Optimizer): Wrapped optimizer.\n            max_lr (float or list): Upper learning rate boundaries in the cycle\n                for each parameter group.\n            total_steps (int): The total number of steps in the cycle. Note\n                that if a value is not provided here, then it must be inferred\n                by providing a value for epochs and steps_per_epoch.\n                Default: None\n            epochs (int): The number of epochs to train for. This is used along\n                with steps_per_epoch in order to infer the total number of\n                steps in the cycle if a value for total_steps is not provided.\n                Default: None\n            steps_per_epoch (int): The number of steps per epoch to train for.\n                This is used along with epochs in order to infer the total\n                number of steps in the cycle if a value for total_steps is\n                not provided.\n                Default: None\n            pct_start (float): The percentage of the cycle (in number of steps)\n                spent increasing the learning rate.\n                Default: 0.3\n            anneal_strategy (str): {'cos', 'linear'}\n                Specifies the annealing strategy.\n                Default: 'cos'\n            cycle_momentum (bool): If ``True``, momentum is cycled inversely\n                to learning rate between 'base_momentum' and 'max_momentum'.\n                Default: True\n            base_momentum (float or list): Lower momentum boundaries in the\n                cycle for each parameter group. 
Note that momentum is cycled\n                inversely to learning rate; at the peak of a cycle, momentum is\n                'base_momentum' and learning rate is 'max_lr'.\n                Default: 0.85\n            max_momentum (float or list): Upper momentum boundaries in the\n                cycle for each parameter group. Functionally,\n                it defines the cycle amplitude (max_momentum - base_momentum).\n                Note that momentum is cycled inversely\n                to learning rate; at the start of a cycle, momentum is\n                'max_momentum' and learning rate is 'base_lr'\n                Default: 0.95\n            div_factor (float): Determines the initial learning rate via\n                initial_lr = max_lr/div_factor\n                Default: 25\n            final_div_factor (float): Determines the minimum learning rate via\n                min_lr = initial_lr/final_div_factor\n                Default: 1e4\n            last_epoch (int): The index of the last batch. This parameter is\n                used when resuming a training job. 
Since `step()` should be\n                invoked after each batch instead of after each epoch, this\n                number represents the total number of *batches* computed,\n                not the total number of epochs computed.\n                When last_epoch=-1, the schedule is started from the\n                beginning.\n                Default: -1\n            \"\"\"\n            super().__init__()\n            self.scheduler = OneCycleLR(\n                optimizer,\n                max_lr,\n                total_steps,\n                epochs,\n                steps_per_epoch,\n                pct_start,\n                anneal_strategy,\n                cycle_momentum,\n                base_momentum,\n                max_momentum,\n                div_factor,\n                final_div_factor,\n                last_epoch)\n\n        def at_iter_begin(self, trainer, train,\n                          **kwargs):\n            \"\"\"\n            Executes a single scheduling step\n\n            Parameters\n            ----------\n            trainer : :class:`PyTorchNetworkTrainer`\n                the trainer class, which can be changed\n            kwargs :\n                additional keyword arguments\n\n            Returns\n            -------\n            :class:`PyTorchNetworkTrainer`\n                modified trainer\n\n            \"\"\"\n            if train:\n                self.scheduler.step()\n\n            return {}\n\n        def at_epoch_end(self, trainer, **kwargs):\n            return {}\n\n    class ReduceLROnPlateauCallback(DefaultPyTorchSchedulerCallback):\n        \"\"\"\n        Wraps PyTorch's `ReduceLROnPlateau` Scheduler as Callback\n\n        \"\"\"\n\n        def __init__(self, optimizer, mode='min', factor=0.1, patience=10,\n                     verbose=False, threshold=1e-4, threshold_mode='rel',\n                     cooldown=0, min_lr=0, eps=1e-8):\n            \"\"\"\n\n            Parameters\n            ----------\n      
      optimizer : Optimizer\n                Wrapped optimizer.\n            mode : str\n                One of `min`, `max`. In `min` mode, lr will\n                be reduced when the quantity monitored has stopped\n                decreasing; in `max` mode it will be reduced when the\n                quantity monitored has stopped increasing. Default: 'min'.\n            factor : float\n                Factor by which the learning rate will be\n                reduced. new_lr = lr * factor. Default: 0.1.\n            patience : int\n                Number of epochs with no improvement after\n                which learning rate will be reduced. For example, if\n                `patience = 2`, then we will ignore the first 2 epochs\n                with no improvement, and will only decrease the LR after the\n                3rd epoch if the loss still hasn't improved then.\n                Default: 10.\n            verbose : bool\n                If ``True``, prints a message to stdout for\n                each update. Default: ``False``.\n            threshold : float\n                Threshold for measuring the new optimum,\n                to only focus on significant changes. Default: 1e-4.\n            threshold_mode : string\n                One of `rel`, `abs`. In `rel` mode,\n                dynamic_threshold = best * ( 1 + threshold ) in 'max'\n                mode or best * ( 1 - threshold ) in `min` mode.\n                In `abs` mode, dynamic_threshold = best + threshold in\n                `max` mode or best - threshold in `min` mode. Default: 'rel'.\n            cooldown : int\n                Number of epochs to wait before resuming\n                normal operation after lr has been reduced. Default: 0.\n            min_lr : float or list\n                A scalar or a list of scalars. A\n                lower bound on the learning rate of all param groups\n                or each group respectively. 
Default: 0.\n            eps : float\n                Minimal decay applied to lr. If the difference\n                between new and old lr is smaller than eps, the update is\n                ignored. Default: 1e-8\n\n            \"\"\"\n            super().__init__()\n            self.scheduler = ReduceLROnPlateau(\n                optimizer,\n                mode,\n                factor,\n                patience,\n                verbose,\n                threshold,\n                threshold_mode,\n                cooldown,\n                min_lr,\n                eps)\n\n        def at_epoch_end(self, trainer,\n                         **kwargs):\n            \"\"\"\n            Executes a single scheduling step\n\n            Parameters\n            ----------\n            trainer : :class:`PyTorchNetworkTrainer`\n                the trainer class, which can be changed\n            kwargs :\n                additional keyword arguments\n\n            Returns\n            -------\n            :class:`PyTorchNetworkTrainer`\n                modified trainer\n\n            \"\"\"\n            val_metrics = kwargs.get(\"val_metrics\", {})\n\n            val_score_key = kwargs.get(\"val_score_key\", None)\n\n            metrics = val_metrics.get(val_score_key)\n\n            self.scheduler.step(metrics=metrics)\n\n            return {}\n\n    class CosineAnnealingLRCallback(DefaultPyTorchSchedulerCallback):\n        \"\"\"\n        Wraps PyTorch's `CosineAnnealingLR` Scheduler as callback\n\n        \"\"\"\n\n        def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1):\n            \"\"\"\n\n            Parameters\n            ----------\n            optimizer : optimizer\n                Wrapped optimizer.\n            T_max : int\n                Maximum number of iterations.\n            eta_min : float\n                Minimum learning rate. Default: 0.\n            last_epoch : int\n                The index of last epoch. 
Default: -1.\n\n            \"\"\"\n            super().__init__()\n\n            self.scheduler = CosineAnnealingLR(optimizer, T_max, eta_min,\n                                               last_epoch)\n\n    class ExponentialLRCallback(DefaultPyTorchSchedulerCallback):\n        \"\"\"\n        Wraps PyTorch's `ExponentialLR` Scheduler as callback\n\n        \"\"\"\n\n        def __init__(self, optimizer, gamma, last_epoch=-1):\n            \"\"\"\n\n            Parameters\n            ----------\n            optimizer : Optimizer\n                Wrapped optimizer.\n            gamma : float\n                Multiplicative factor of learning rate decay.\n            last_epoch : int\n                The index of last epoch. Default: -1.\n\n            \"\"\"\n            super().__init__()\n\n            self.scheduler = ExponentialLR(optimizer, gamma, last_epoch)\n\n    class LambdaLRCallback(DefaultPyTorchSchedulerCallback):\n        \"\"\"\n        Wraps PyTorch's `LambdaLR` Scheduler as callback\n\n        \"\"\"\n\n        def __init__(self, optimizer, lr_lambda, last_epoch=-1):\n            \"\"\"\n\n            Parameters\n            ----------\n            optimizer : Optimizer\n                Wrapped optimizer.\n            lr_lambda : function or list\n                A function which computes a multiplicative\n                factor given an integer parameter epoch, or a list of such\n                functions, one for each group in optimizer.param_groups.\n            last_epoch : int\n                The index of last epoch. 
Default: -1.\n\n            \"\"\"\n            super().__init__()\n\n            self.scheduler = LambdaLR(optimizer, lr_lambda, last_epoch)\n\n    class MultiStepLRCallback(DefaultPyTorchSchedulerCallback):\n        \"\"\"\n        Wraps PyTorch's `MultiStepLR` Scheduler as callback\n\n        \"\"\"\n\n        def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1):\n            \"\"\"\n\n            Parameters\n            ----------\n            optimizer : Optimizer\n                Wrapped optimizer.\n            milestones : list\n                List of epoch indices. Must be increasing.\n            gamma : float\n                Multiplicative factor of learning rate decay.\n                Default: 0.1.\n            last_epoch : int\n                The index of last epoch. Default: -1.\n\n            \"\"\"\n            super().__init__()\n\n            self.scheduler = MultiStepLR(\n                optimizer, milestones, gamma, last_epoch)\n\n    class StepLRCallback(DefaultPyTorchSchedulerCallback):\n        \"\"\"\n        Wraps PyTorch's `StepLR` Scheduler as callback\n\n        \"\"\"\n\n        def __init__(self, optimizer, step_size, gamma=0.1, last_epoch=-1):\n            \"\"\"\n\n            Parameters\n            ----------\n            optimizer : Optimizer\n                Wrapped optimizer.\n            step_size : int\n                Period of learning rate decay.\n            gamma :float\n                Multiplicative factor of learning rate decay.\n                Default: 0.1.\n            last_epoch : int\n                The index of last epoch. Default: -1\n\n            \"\"\"\n            super().__init__()\n\n            self.scheduler = StepLR(optimizer, step_size, gamma, last_epoch)\n"
  },
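Several of the schedulers wrapped in `pytorch_schedulers.py` follow simple closed-form decay rules. As a plain-Python illustration of those rules (the helper names `step_lr` and `multi_step_lr` are hypothetical; this is neither delira nor PyTorch code), StepLR yields `lr = base_lr * gamma ** (epoch // step_size)`, and MultiStepLR decays once per milestone already passed:

```python
def step_lr(base_lr, epoch, step_size, gamma=0.1):
    """Closed-form StepLR: decay lr by gamma every step_size epochs."""
    return base_lr * gamma ** (epoch // step_size)


def multi_step_lr(base_lr, epoch, milestones, gamma=0.1):
    """Closed-form MultiStepLR: decay once per milestone already passed."""
    passed = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** passed


# with step_size=2 the lr drops by 10x at epochs 2 and 4
lrs = [step_lr(0.1, e, step_size=2) for e in range(5)]
```

The actual PyTorch schedulers additionally track internal state so that `step()` can be called incrementally, which is exactly what the callbacks above delegate to.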
  {
    "path": "delira/training/losses.py",
    "content": "from delira import get_backends\n\nif \"TORCH\" in get_backends():\n    import torch\n    import torch.nn.functional as F\n\n    class BCEFocalLossPyTorch(torch.nn.Module):\n        \"\"\"\n        Focal loss for binary case without(!) logit\n\n        \"\"\"\n\n        def __init__(self, alpha=None, gamma=2, reduction='elementwise_mean'):\n            \"\"\"\n            Implements Focal Loss for binary class case\n\n            Parameters\n            ----------\n            alpha : float\n                alpha has to be in range [0,1], assigns class weight\n            gamma : float\n                focusing parameter\n            reduction : str\n                Specifies the reduction to apply to the output: ‘none’ |\n                ‘elementwise_mean’ | ‘sum’.\n                    ‘none’: no reduction will be\n                        applied,\n                    ‘elementwise_mean’: the sum of the output will be divided\n                        by the number of elements in the output,\n                    ‘sum’: the output will be summed\n            (further information about parameters above can be found in pytorch\n            documentation)\n\n            Returns\n            -------\n            torch.Tensor\n                loss value\n\n            \"\"\"\n            super().__init__()\n            self.alpha = alpha\n            self.gamma = gamma\n            self.reduction = reduction\n\n        def forward(self, p, t):\n            bce_loss = F.binary_cross_entropy(p, t, reduction='none')\n\n            if self.alpha is not None:\n                # create weights for alpha\n                alpha_weight = torch.ones(t.shape, device=p.device) * \\\n                    self.alpha\n                alpha_weight = torch.where(torch.eq(t, 1.),\n                                           alpha_weight, 1 - alpha_weight)\n            else:\n                alpha_weight = torch.Tensor([1]).to(p.device)\n\n            # create weights for 
focal loss\n            focal_weight = 1 - torch.where(torch.eq(t, 1.), p, 1 - p)\n            focal_weight.pow_(self.gamma)\n            focal_weight = focal_weight.to(p.device)\n\n            # compute loss\n            focal_loss = focal_weight * alpha_weight * bce_loss\n\n            if self.reduction == 'elementwise_mean':\n                return torch.mean(focal_loss)\n            if self.reduction == 'none':\n                return focal_loss\n            if self.reduction == 'sum':\n                return torch.sum(focal_loss)\n            raise ValueError('Reduction parameter unknown.')\n\n    class BCEFocalLossLogitPyTorch(torch.nn.Module):\n        \"\"\"\n        Focal loss for binary case WITH logit\n\n        \"\"\"\n\n        def __init__(self, alpha=None, gamma=2, reduction='elementwise_mean'):\n            \"\"\"\n            Implements Focal Loss for binary class case\n\n            Parameters\n            ----------\n            alpha : float\n                alpha has to be in range [0,1], assigns class weight\n            gamma : float\n                focusing parameter\n            reduction : str\n                Specifies the reduction to apply to the output: ‘none’ |\n                ‘elementwise_mean’ | ‘sum’.\n                    ‘none’: no reduction will be applied,\n                    ‘elementwise_mean’: the sum of the output will be divided\n                        by the number of elements in the output,\n                    ‘sum’: the output will be summed\n            (further information about parameters above can be found in pytorch\n            documentation)\n\n            Returns\n            -------\n            torch.Tensor\n                loss value\n\n            \"\"\"\n            super().__init__()\n            self.alpha = alpha\n            self.gamma = gamma\n            self.reduction = reduction\n\n        def forward(self, p, t):\n            bce_loss = F.binary_cross_entropy_with_logits(\n                p, t, 
reduction='none')\n\n            p = torch.sigmoid(p)\n\n            if self.alpha is not None:\n                # create weights for alpha\n                alpha_weight = torch.ones(t.shape, device=p.device) * \\\n                    self.alpha\n                alpha_weight = torch.where(torch.eq(t, 1.),\n                                           alpha_weight, 1 - alpha_weight)\n            else:\n                alpha_weight = torch.Tensor([1]).to(p.device)\n\n            # create weights for focal loss\n            focal_weight = 1 - torch.where(torch.eq(t, 1.), p, 1 - p)\n            focal_weight.pow_(self.gamma)\n            focal_weight = focal_weight.to(p.device)\n\n            # compute loss\n            focal_loss = focal_weight * alpha_weight * bce_loss\n\n            if self.reduction == 'elementwise_mean':\n                return torch.mean(focal_loss)\n            if self.reduction == 'none':\n                return focal_loss\n            if self.reduction == 'sum':\n                return torch.sum(focal_loss)\n            raise ValueError('Reduction parameter unknown.')\n"
  },
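For a single probability `p` and binary target `t`, the loss computed element-wise in `losses.py` reduces to `FL = -alpha_t * (1 - p_t)^gamma * log(p_t)`, where `p_t` is the probability assigned to the true class. A dependency-free numeric sketch of this formula (the function `binary_focal_loss` is illustrative, not part of the module):

```python
import math


def binary_focal_loss(p, t, alpha=None, gamma=2.0):
    """Scalar focal loss for a probability p in (0, 1) and target t in {0, 1}.

    Mirrors BCEFocalLossPyTorch element-wise:
    bce * focal_weight * alpha_weight.
    """
    p_t = p if t == 1 else 1.0 - p           # probability of the true class
    alpha_t = 1.0
    if alpha is not None:
        alpha_t = alpha if t == 1 else 1.0 - alpha
    bce = -math.log(p_t)                     # binary cross-entropy term
    return alpha_t * (1.0 - p_t) ** gamma * bce


# gamma=0 recovers plain BCE; for gamma > 0, confident correct predictions
# are down-weighted, focusing the loss on hard examples
plain = binary_focal_loss(0.9, 1, gamma=0.0)
focal = binary_focal_loss(0.9, 1, gamma=2.0)
```

With `p = 0.9` and `t = 1`, the focal weight `(1 - 0.9)^2 = 0.01` shrinks the plain BCE term by two orders of magnitude.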
  {
    "path": "delira/training/metrics.py",
"content": "\nfrom sklearn.metrics import accuracy_score, balanced_accuracy_score, \\\n    f1_score, fbeta_score, hamming_loss, jaccard_similarity_score, log_loss, \\\n    matthews_corrcoef, precision_score, recall_score, zero_one_loss, \\\n    roc_auc_score\nfrom sklearn.preprocessing import label_binarize\n\nimport numpy as np\n\n\nclass SklearnClassificationMetric(object):\n    def __init__(self, score_fn, gt_logits=False, pred_logits=True, **kwargs):\n        \"\"\"\n        Wraps a score function as a metric\n\n        Parameters\n        ----------\n        score_fn : function\n            function which should be wrapped\n        gt_logits : bool\n            whether given ``y_true`` are logits or not\n        pred_logits : bool\n            whether given ``y_pred`` are logits or not\n        **kwargs:\n            variable number of keyword arguments passed to ``score_fn``\n        \"\"\"\n        self._score_fn = score_fn\n        self._gt_logits = gt_logits\n        self._pred_logits = pred_logits\n        self.kwargs = kwargs\n\n    def __call__(self, y_true, y_pred, **kwargs):\n        \"\"\"\n        Compute metric with score_fn\n\n        Parameters\n        ----------\n        y_true: np.ndarray\n            ground truth data\n        y_pred: np.ndarray\n            predictions of network\n        kwargs:\n            variable number of keyword arguments passed to score_fn\n\n        Returns\n        -------\n        float\n            result from score function\n\n        \"\"\"\n\n        if self._gt_logits:\n            y_true = np.argmax(y_true, axis=-1)\n\n        if self._pred_logits:\n            y_pred = np.argmax(y_pred, axis=-1)\n\n        return self._score_fn(y_true=y_true, y_pred=y_pred,\n                              **kwargs, **self.kwargs)\n\n\nclass SklearnAccuracyScore(SklearnClassificationMetric):\n    \"\"\"\n    Accuracy Metric\n    \"\"\"\n\n    def __init__(self, gt_logits=False, pred_logits=True, **kwargs):\n        
super().__init__(accuracy_score, gt_logits, pred_logits, **kwargs)\n\n\nclass SklearnBalancedAccuracyScore(SklearnClassificationMetric):\n    \"\"\"\n    Balanced Accuracy Metric\n    \"\"\"\n\n    def __init__(self, gt_logits=False, pred_logits=True, **kwargs):\n        super().__init__(balanced_accuracy_score, gt_logits, pred_logits,\n                         **kwargs)\n\n\nclass SklearnF1Score(SklearnClassificationMetric):\n    \"\"\"\n    F1 Score\n    \"\"\"\n\n    def __init__(self, gt_logits=False, pred_logits=True, **kwargs):\n        super().__init__(f1_score, gt_logits, pred_logits, **kwargs)\n\n\nclass SklearnFBetaScore(SklearnClassificationMetric):\n    \"\"\"\n    F-Beta Score (Generalized F1)\n    \"\"\"\n\n    def __init__(self, gt_logits=False, pred_logits=True, **kwargs):\n        super().__init__(fbeta_score, gt_logits, pred_logits, **kwargs)\n\n\nclass SklearnHammingLoss(SklearnClassificationMetric):\n    \"\"\"\n    Hamming Loss\n    \"\"\"\n\n    def __init__(self, gt_logits=False, pred_logits=True, **kwargs):\n        super().__init__(hamming_loss, gt_logits, pred_logits, **kwargs)\n\n\nclass SklearnJaccardSimilarityScore(SklearnClassificationMetric):\n    \"\"\"\n    Jaccard Similarity Score\n    \"\"\"\n\n    def __init__(self, gt_logits=False, pred_logits=True, **kwargs):\n        super().__init__(jaccard_similarity_score, gt_logits, pred_logits,\n                         **kwargs)\n\n\nclass SklearnLogLoss(SklearnClassificationMetric):\n    \"\"\"\n    Log Loss (NLL)\n    \"\"\"\n\n    def __init__(self, gt_logits=False, pred_logits=True, **kwargs):\n        super().__init__(log_loss, gt_logits, pred_logits, **kwargs)\n\n\nclass SklearnMatthewsCorrCoeff(SklearnClassificationMetric):\n    \"\"\"\n    Matthews Correlation Coefficient\n    \"\"\"\n\n    def __init__(self, gt_logits=False, pred_logits=True, **kwargs):\n        super().__init__(matthews_corrcoef, gt_logits, pred_logits, **kwargs)\n\n\nclass 
SklearnPrecisionScore(SklearnClassificationMetric):\n    \"\"\"\n    Precision Score\n    \"\"\"\n\n    def __init__(self, gt_logits=False, pred_logits=True, **kwargs):\n        super().__init__(precision_score, gt_logits, pred_logits, **kwargs)\n\n\nclass SklearnRecallScore(SklearnClassificationMetric):\n    \"\"\"\n    Recall Score\n    \"\"\"\n\n    def __init__(self, gt_logits=False, pred_logits=True, **kwargs):\n        super().__init__(recall_score, gt_logits, pred_logits, **kwargs)\n\n\nclass SklearnZeroOneLoss(SklearnClassificationMetric):\n    \"\"\"\n    Zero One Loss\n    \"\"\"\n\n    def __init__(self, gt_logits=False, pred_logits=True, **kwargs):\n        super().__init__(zero_one_loss, gt_logits, pred_logits, **kwargs)\n\n\nclass AurocMetric(object):\n    def __init__(self, classes=(0, 1), **kwargs):\n        \"\"\"\n        Implements the auroc metric for binary and multi class classification\n\n        Parameters\n        ----------\n        classes: array-like\n            uniquely holds the label for each class.\n        kwargs:\n            variable number of keyword arguments passed to roc_auc_score\n\n        Raises\n        ------\n        ValueError\n            if not at least two classes are provided\n        \"\"\"\n        self.classes = classes\n        self.kwargs = kwargs\n        if len(self.classes) < 2:\n            raise ValueError(\"At least 2 classes must exist for \"\n                             \"classification. Only classes {} were passed to \"\n                             \"AurocMetric.\".format(classes))\n\n    def __call__(self, y_true, y_pred, **kwargs):\n        \"\"\"\n        Compute auroc\n\n        Parameters\n        ----------\n        y_true: np.ndarray\n            ground truth data with shape (N)\n        y_pred: np.ndarray\n            predictions of network in numpy format with shape (N, nclasses)\n        kwargs:\n            variable number of keyword arguments passed to roc_auc_score\n\n        Returns\n        -------\n        float\n            computed auroc score\n\n        Raises\n        ------\n        ValueError\n            if two classes are given and the predictions contain more than two\n            classes\n        \"\"\"\n        # binary classification\n        if len(self.classes) == 2:\n            # single output unit (e.g. sigmoid)\n            if len(y_pred.shape) == 1 or y_pred.shape[1] == 1:\n                return roc_auc_score(y_true, y_pred.ravel(), **kwargs,\n                                     **self.kwargs)\n            # output of two units (e.g. softmax)\n            elif y_pred.shape[1] == 2:\n                return roc_auc_score(y_true, y_pred[:, 1], **kwargs,\n                                     **self.kwargs)\n            else:\n                raise ValueError(\"Cannot compute auroc metric for binary \"\n                                 \"classes with {} predicted \"\n                                 \"classes.\".format(y_pred.shape[1]))\n\n        # classification with multiple classes\n        if len(self.classes) > 2:\n            y_true_bin = label_binarize(y_true, classes=self.classes)\n            return roc_auc_score(y_true_bin, y_pred, **kwargs, **self.kwargs)\n"
  },
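For intuition about what `AurocMetric` delegates to `roc_auc_score`, binary AUROC can be computed with plain numpy via the rank-sum (Mann-Whitney U) statistic. The sketch below is illustrative only and not part of delira; `auroc_rank` is a hypothetical helper name, and ties in the scores are broken arbitrarily rather than averaged.

```python
import numpy as np


def auroc_rank(y_true, y_score):
    """Binary AUROC via the Mann-Whitney U statistic.

    Equivalent to the probability that a random positive sample is
    scored higher than a random negative one (ties broken arbitrarily).
    """
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)

    # 1-based ranks of the scores
    order = np.argsort(y_score)
    ranks = np.empty(len(y_score), dtype=float)
    ranks[order] = np.arange(1, len(y_score) + 1)

    n_pos = int(y_true.sum())
    n_neg = len(y_true) - n_pos

    # U statistic: rank-sum of positives minus its minimum possible value
    u = ranks[y_true == 1].sum() - n_pos * (n_pos + 1) / 2
    return u / (n_pos * n_neg)
```

On the classic toy example `y_true=[0, 0, 1, 1]`, `y_score=[0.1, 0.4, 0.35, 0.8]`, this yields 0.75, matching `sklearn.metrics.roc_auc_score`.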
  {
    "path": "delira/training/predictor.py",
    "content": "import logging\nimport gc\n\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom delira.data_loading import DataManager\nfrom delira.training.utils import convert_to_numpy_identity\nfrom ..utils.config import LookupConfig\n\nfrom delira.training.callbacks import AbstractCallback\n\nlogger = logging.getLogger(__name__)\n\n\nclass Predictor(object):\n    \"\"\"\n    Defines an API for Predictions from a Network\n\n    See Also\n    --------\n    :class:`PyTorchNetworkTrainer`\n\n    \"\"\"\n\n    # static variable to prevent certain attributes from overwriting\n    __KEYS_TO_GUARD = []\n\n    def __init__(\n            self, model, key_mapping: dict,\n            convert_batch_to_npy_fn=convert_to_numpy_identity,\n            prepare_batch_fn=lambda x: x,\n            callbacks=None, **kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        model : :class:`AbstractNetwork`\n            the model to predict from\n        key_mapping : dict\n            a dictionary containing the mapping from the ``data_dict`` to\n            the actual model's inputs.\n            E.g. 
if a model accepts one input named 'x' and the data_dict\n            contains one entry named 'data' this argument would have to\n            be ``{'x': 'data'}``\n        convert_batch_to_npy_fn : type, optional\n            a callable function to convert tensors in positional and keyword\n            arguments to numpy; default: identity function\n        prepare_batch_fn : type, optional\n            function converting a batch-tensor to the framework specific\n            tensor-type and pushing it to correct device, default: identity\n            function\n        callbacks : list\n            initial callbacks to register\n        **kwargs :\n            additional keyword arguments\n\n        \"\"\"\n        if callbacks is None:\n            callbacks = []\n\n        self._setup(model, key_mapping, convert_batch_to_npy_fn,\n                    prepare_batch_fn, callbacks, **kwargs)\n\n        self._tqdm_desc = \"Test\"\n\n    def _setup(self, network, key_mapping, convert_batch_args_kwargs_to_npy_fn,\n               prepare_batch_fn, callbacks, **kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        network : :class:`AbstractNetwork`\n            the network to predict from\n        key_mapping : dict\n            a dictionary containing the mapping from the ``data_dict`` to\n            the actual model's inputs.\n            E.g. 
if a model accepts one input named 'x' and the data_dict\n            contains one entry named 'data' this argument would have to\n            be ``{'x': 'data'}``\n        convert_batch_args_kwargs_to_npy_fn : type\n            a callable function to convert tensors in positional and keyword\n            arguments to numpy\n        prepare_batch_fn : (dict, str, str) -> dict\n            function converting a batch-tensor to the framework specific\n            tensor-type and pushing it to correct device, default: identity\n            function\n        callbacks : list\n            initial callbacks to register\n\n        \"\"\"\n\n        self.module = network\n        self.key_mapping = key_mapping\n        self._convert_to_npy_fn = convert_batch_args_kwargs_to_npy_fn\n        self._prepare_batch = prepare_batch_fn\n        self._callbacks = []\n\n        for cb in callbacks:\n            self.register_callback(cb)\n\n    def __call__(self, data: dict, **kwargs):\n        \"\"\"\n        Method to call the class.\n        Returns the predictions corresponding to the given data\n        obtained by the model\n\n        Parameters\n        ----------\n        data : dict\n            batch dictionary\n\n        Returns\n        -------\n        dict\n            predicted data\n        \"\"\"\n        return self.predict(data, **kwargs)\n\n    def predict(self, data: dict, already_prepared=False, **kwargs):\n        \"\"\"\n        Predict a single batch.\n        Returns the predictions corresponding to the given data\n        obtained by the model\n\n        Parameters\n        ----------\n        data : dict\n            batch dictionary\n        already_prepared : bool\n            if True, the `prepare_batch` function won't be called on the data\n            anymore\n        **kwargs :\n            keyword arguments (directly passed to ``prepare_batch``)\n\n        Returns\n        -------\n        dict\n            predicted data\n\n        \"\"\"\n        if not already_prepared:\n            data = self._prepare_batch(data, **kwargs)\n\n        mapped_data = {\n            k: data[v] for k, v in self.key_mapping.items()}\n\n        pred = self.module(\n            **mapped_data\n        )\n\n        # converts positional arguments and keyword arguments,\n        # but returns only keyword arguments, since positional\n        # arguments are not given.\n        return self._convert_to_npy_fn(\n            **pred\n        )[1]\n\n    def _at_iter_begin(self, iter_num, **kwargs):\n        \"\"\"\n        Function defining the behavior executed at the beginning of each\n        iteration\n\n        Parameters\n        ----------\n        iter_num : int\n            the number of the current iteration\n        **kwargs :\n            additional keyword arguments (forwarded to callbacks call)\n\n        Returns\n        -------\n        dict\n            combined dicts returned by the callbacks\n\n        \"\"\"\n        return_dict = {}\n        for cb in self._callbacks:\n            return_dict.update(cb.at_iter_begin(self,\n                                                iter_num=iter_num,\n                                                train=False,\n                                                **kwargs))\n\n        return return_dict\n\n    def _at_iter_end(self, iter_num, data_dict, metrics, **kwargs):\n        \"\"\"\n        Function defining the behavior executed at the end of each iteration\n\n        Parameters\n        ----------\n        iter_num : int\n            the number of the current iteration\n        data_dict : dict\n            dictionary holding input data and predictions\n        metrics: dict\n            calculated metrics\n        **kwargs :\n            additional keyword arguments (forwarded to callbacks call)\n\n        Returns\n        -------\n        dict\n            combined dicts returned by the callbacks\n\n        \"\"\"\n        return_dict = {}\n        for cb in self._callbacks:\n            return_dict.update(cb.at_iter_end(self,\n                                              iter_num=iter_num,\n                                              data_dict=data_dict,\n                                              metrics=metrics,\n                                              train=False,\n                                              **kwargs))\n\n        return return_dict\n\n    def predict_data_mgr(\n            self,\n            datamgr: DataManager,\n            batchsize=None,\n            metrics=None,\n            metric_keys=None,\n            verbose=False,\n            **kwargs):\n        \"\"\"\n        Defines a routine to predict data obtained from a batchgenerator\n        without explicitly caching anything\n\n        Parameters\n        ----------\n        datamgr : :class:`DataManager`\n            Manager producing a generator holding the batches\n        batchsize : int\n            Artificial batchsize (sampling will be done with batchsize\n            1 and sampled data will be stacked to match the artificial\n            batchsize) (default: None)\n        metrics : dict\n            the metrics to calculate\n        metric_keys : dict\n            the ``batch_dict`` items to use for metric calculation\n        verbose : bool\n            whether to show a progress-bar or not, default: False\n        kwargs :\n            keyword arguments passed to :func:`prepare_batch_fn`\n\n        Yields\n        ------\n        dict\n            a dictionary containing all predictions of the current batch\n        dict\n            a dictionary containing all metrics of the current batch\n\n        \"\"\"\n        if metrics is None:\n            metrics = {}\n        orig_num_aug_processes = datamgr.n_process_augmentation\n        orig_batch_size = datamgr.batch_size\n\n        if batchsize is None:\n            batchsize = orig_batch_size\n\n        datamgr.batch_size = 1\n\n        batchgen = datamgr.get_batchgen()\n\n        n_batches = 
datamgr.n_batches\n\n        if verbose:\n            iterable = tqdm(enumerate(batchgen), unit=' sample',\n                            total=n_batches, desc=self._tqdm_desc)\n\n        else:\n            iterable = enumerate(batchgen)\n\n        batch_list = []\n\n        for i, batch in iterable:\n            self._at_iter_begin(iter_num=i)\n\n            if not batch_list and (n_batches - i) < batchsize:\n                batchsize = n_batches - i\n                logger.debug(\"Set batchsize down to %d to avoid cutting \"\n                             \"off the last batches\" % batchsize)\n\n            batch_list.append(batch)\n\n            # if the queue is full, process it:\n            if batchsize is None or len(batch_list) >= batchsize:\n\n                batch_dict = {}\n                for _batch in batch_list:\n                    for key, val in _batch.items():\n                        if key in batch_dict.keys():\n                            batch_dict[key].append(val)\n                        else:\n                            batch_dict[key] = [val]\n\n                for key, val_list in batch_dict.items():\n                    batch_dict[key] = np.concatenate(val_list)\n\n                batch_dict = self._prepare_batch(batch_dict)\n                preds = self.predict(batch_dict, already_prepared=True,\n                                     **kwargs)\n\n                # convert batch_dict back to numpy (self.predict may convert\n                # it to backend-specific tensor type) - no-op if already numpy\n                batch_dict = self._convert_to_npy_fn(**batch_dict)[1]\n\n                preds_batch = LookupConfig()\n                # explicitly free memory of old lookup config\n                gc.collect()\n                preds_batch.update(batch_dict)\n                preds_batch.update(preds)\n\n                # calculate metrics for predicted batch\n                _metric_vals = self.calc_metrics(preds_batch,\n               
                                  metrics=metrics,\n                                                 metric_keys=metric_keys)\n\n                self._at_iter_end(data_dict={**batch_dict, **preds_batch},\n                                  metrics={\"val_\" + k: v\n                                           for k, v in _metric_vals.items()},\n                                  iter_num=i)\n\n                yield preds, _metric_vals\n\n                batch_list = []\n\n        datamgr.batch_size = orig_batch_size\n        datamgr.n_process_augmentation = orig_num_aug_processes\n\n        return\n\n    def predict_data_mgr_cache_metrics_only(self, datamgr, batchsize=None,\n                                            metrics=None, metric_keys=None,\n                                            verbose=False, **kwargs):\n        \"\"\"\n        Defines a routine to predict data obtained from a batchgenerator and\n        caches the metrics\n\n        Parameters\n        ----------\n        datamgr : :class:`DataManager`\n            Manager producing a generator holding the batches\n        batchsize : int\n            Artificial batchsize (sampling will be done with batchsize\n            1 and sampled data will be stacked to match the artificial\n            batchsize) (default: None)\n        metrics : dict\n            the metrics to calculate\n        metric_keys : dict\n            the ``batch_dict`` items to use for metric calculation\n        verbose : bool\n            whether to show a progress-bar or not, default: False\n        kwargs :\n            keyword arguments passed to :func:`prepare_batch_fn`\n\n        Yields\n        ------\n        dict\n            a dictionary containing all validation metrics (maybe empty)\n\n        Notes\n        -----\n        This function stores each prediction temporarily for metric\n        calculation; this typically results in much lower memory\n        consumption than 
:meth:`Predictor.predict_data_mgr_cache_all`,\n        but still caches the metrics. If this is not desired, it is recommended\n        to use :meth:`Predictor.predict_data_mgr` and iterate over the\n        generator as this only produces per-batch metrics and predictions and\n        does not cache anything by default\n\n        \"\"\"\n        if metrics is None:\n            metrics = {}\n        yield from self.predict_data_mgr_cache(datamgr=datamgr,\n                                               batchsize=batchsize,\n                                               metrics=metrics,\n                                               metric_keys=metric_keys,\n                                               verbose=verbose,\n                                               cache_preds=False, **kwargs)\n\n        return\n\n    def predict_data_mgr_cache_all(self, datamgr, batchsize=None, metrics=None,\n                                   metric_keys=None, verbose=False, **kwargs):\n        \"\"\"\n        Defines a routine to predict data obtained from a batchgenerator and\n        caches all predictions and metrics (yields them in dicts)\n\n        Parameters\n        ----------\n        datamgr : :class:`DataManager`\n            Manager producing a generator holding the batches\n        batchsize : int\n            Artificial batchsize (sampling will be done with batchsize\n            1 and sampled data will be stacked to match the artificial\n            batchsize)(default: None)\n        metrics : dict\n            the metrics to calculate\n        metric_keys : dict\n            the ``batch_dict`` items to use for metric calculation\n        verbose : bool\n            whether to show a progress-bar or not, default: False\n        kwargs :\n            keyword arguments passed to :func:`prepare_batch_fn`\n\n        Yields\n        ------\n        dict\n            a dictionary containing all predictions;\n        dict\n            a dictionary containing all 
validation metrics (maybe empty)\n\n        Warnings\n        --------\n        Since this function caches all predictions and metrics, this may result\n        in huge memory consumption. If you are running out of memory, please\n        have a look at :meth:`Predictor.predict_data_mgr_cache_metrics_only`\n        or :meth:`Predictor.predict_data_mgr`\n\n        \"\"\"\n        if metrics is None:\n            metrics = {}\n        yield from self.predict_data_mgr_cache(datamgr=datamgr,\n                                               batchsize=batchsize,\n                                               metrics=metrics,\n                                               metric_keys=metric_keys,\n                                               verbose=verbose,\n                                               cache_preds=True, **kwargs)\n\n        return\n\n    def predict_data_mgr_cache(self, datamgr, batchsize=None, metrics=None,\n                               metric_keys=None, verbose=False,\n                               cache_preds=False, **kwargs):\n        \"\"\"\n        Defines a routine to predict data obtained from a batchgenerator and\n        caches all predictions and metrics (yields them in dicts)\n\n        Parameters\n        ----------\n        datamgr : :class:`DataManager`\n            Manager producing a generator holding the batches\n        batchsize : int\n            Artificial batchsize (sampling will be done with batchsize\n            1 and sampled data will be stacked to match the artificial\n            batchsize)(default: None)\n        metrics : dict\n            the metrics to calculate\n        metric_keys : dict\n            the ``batch_dict`` items to use for metric calculation\n        verbose : bool\n            whether to show a progress-bar or not, default: False\n        cache_preds : bool\n            whether to also cache predictions\n        kwargs :\n            keyword arguments passed to :func:`prepare_batch_fn`\n\n        
Yields\n        ------\n        dict\n            a dictionary containing all predictions (only yielded if\n            ``cache_preds=True``)\n        dict\n            a dictionary containing all validation metrics (maybe empty)\n\n        Warnings\n        --------\n        Since this function caches all metrics and may additionally cache all\n        predictions (based on the argument ``cache_preds``), this may result\n        in huge memory consumption. If you are running out of memory, please\n        have a look at :meth:`Predictor.predict_data_mgr_cache_metrics_only`\n        or :meth:`Predictor.predict_data_mgr` or consider setting\n        ``cache_preds`` to ``False`` (if not done already)\n\n        \"\"\"\n\n        if metrics is None:\n            metrics = {}\n\n        predictions_all, metric_vals = [], {k: [] for k in metrics.keys()}\n\n        for preds, _metric_vals in self.predict_data_mgr(\n                datamgr=datamgr,\n                batchsize=batchsize,\n                metrics=metrics,\n                metric_keys=metric_keys,\n                verbose=verbose,\n                **kwargs):\n\n            if cache_preds:\n                predictions_all.append(preds)\n            for k, v in _metric_vals.items():\n                metric_vals[k].append(v)\n\n        if cache_preds:\n            # convert predictions from list of dicts to dict of lists\n            new_predictions_all = {}\n\n            # recursively convert all nested dicts\n            for preds in predictions_all:\n                new_predictions_all = self.__convert_dict(preds,\n                                                          new_predictions_all)\n\n            # concatenate lists to single arrays\n            preds_all = self.__concatenate_dict_items(new_predictions_all)\n        else:\n            preds_all = {}\n\n        for k, v in metric_vals.items():\n            metric_vals[k] = np.array(v)\n\n        if cache_preds:\n            yield preds_all, metric_vals\n        else:\n            yield metric_vals\n\n        return\n\n    @staticmethod\n    def __convert_dict(old_dict, new_dict):\n        \"\"\"\n        Function to recursively convert dicts\n\n        Parameters\n        ----------\n        old_dict : dict\n            the old nested dict\n        new_dict : dict\n            the new nested dict\n\n        Returns\n        -------\n        dict\n            the updated new nested dict\n        \"\"\"\n        for k, v in old_dict.items():\n\n            # apply same function again on item if item is dict\n            if isinstance(v, dict):\n                if k not in new_dict:\n                    new_dict[k] = {}\n\n                new_dict[k] = Predictor.__convert_dict(v, new_dict[k])\n\n            else:\n\n                # check if v is scalar and convert to npy-array if\n                # necessary.\n                # Otherwise concatenation might fail\n                if np.isscalar(v):\n                    v = np.array(v)\n\n                # check for zero-sized arrays and reshape if necessary.\n                # Otherwise concatenation might fail\n                if v.shape == ():\n                    v = v.reshape(1)\n                if k in new_dict:\n                    new_dict[k].append(v)\n                else:\n                    new_dict[k] = [v]\n\n        return new_dict\n\n    @staticmethod\n    def __concatenate_dict_items(dict_like: dict):\n        \"\"\"\n        Function to recursively concatenate dict-items\n\n        Parameters\n        ----------\n        dict_like : dict\n            the (nested) dict, whose items should be concatenated\n\n        Returns\n        -------\n        dict\n            the dict with all items concatenated\n\n        \"\"\"\n        for k, v in dict_like.items():\n            if isinstance(v, dict):\n                v = Predictor.__concatenate_dict_items(v)\n            else:\n                v = np.concatenate(v)\n\n            dict_like[k] = v\n\n        return dict_like\n\n    def __setattr__(self, key, value):\n        \"\"\"\n  
      Set attributes and guard specific attributes after they have been set\n        once\n\n        Parameters\n        ----------\n        key : str\n            the attributes name\n        value : Any\n            the value to set\n\n        Raises\n        ------\n        PermissionError\n            If attribute which should be set is guarded\n\n        \"\"\"\n\n        # check if key has been set once\n        if key in self.__KEYS_TO_GUARD and hasattr(self, key):\n            raise PermissionError(\"%s should not be overwritten after \"\n                                  \"it has been set once\" % key)\n        else:\n            super().__setattr__(key, value)\n\n    @staticmethod\n    def calc_metrics(batch: LookupConfig, metrics=None, metric_keys=None):\n        \"\"\"\n        Compute metrics\n\n        Parameters\n        ----------\n        batch: LookupConfig\n            dictionary containing the whole batch\n            (including predictions)\n        metrics: dict\n            dict with metrics\n        metric_keys : dict\n            dict of tuples which contains hashables for specifying the items\n            to use for calculating the respective metric.\n            If not specified for a metric, the keys \"pred\" and \"label\"\n            are used per default\n\n        Returns\n        -------\n        dict\n            dict with metric results\n        \"\"\"\n        if metrics is None:\n            metrics = {}\n        if metric_keys is None:\n            metric_keys = {k: (\"label\", \"pred\") for k in metrics.keys()}\n\n        return {key: metric_fn(*[batch.nested_get(k)\n                                 for k in metric_keys[key]])\n                for key, metric_fn in metrics.items()}\n\n    def register_callback(self, callback: AbstractCallback):\n        \"\"\"\n        Register Callback to Trainer\n\n        Parameters\n        ----------\n        callback : :class:`AbstractCallback`\n            the callback to register\n\n    
    Raises\n        ------\n        AssertionError\n            `callback` is not an instance of :class:`AbstractCallback` and has\n            not both methods ['at_iter_begin', 'at_iter_end']\n\n        \"\"\"\n        assertion_str = \"Given callback is not valid; Must be instance of \" \\\n                        \"AbstractCallback or provide functions \" \\\n                        \"'at_iter_begin' and 'at_iter_end'\"\n        instance_check = isinstance(callback, AbstractCallback)\n        attr_check_begin = hasattr(callback, \"at_iter_begin\")\n        attr_check_end = hasattr(callback, \"at_iter_end\")\n        attr_check_both = attr_check_begin and attr_check_end\n\n        assert instance_check or attr_check_both, assertion_str\n\n        self._callbacks.append(callback)\n"
  },
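The core of `Predictor.predict` is the `key_mapping` remap followed by the metric dispatch of `calc_metrics`. Below is a minimal, framework-free sketch of that pattern; the function and variable names (`predict_batch`, `calc_metrics`, the toy `model`) are illustrative, not delira API:

```python
import numpy as np


def predict_batch(model_fn, key_mapping, data):
    # remap data_dict entries to the model's input names,
    # mirroring the dict comprehension in Predictor.predict
    mapped = {k: data[v] for k, v in key_mapping.items()}
    return model_fn(**mapped)


def calc_metrics(batch, metrics, metric_keys=None):
    # default every metric to the ("label", "pred") items,
    # mirroring Predictor.calc_metrics
    if metric_keys is None:
        metric_keys = {k: ("label", "pred") for k in metrics}
    return {name: fn(*[batch[k] for k in metric_keys[name]])
            for name, fn in metrics.items()}


# toy "network": doubles its single input 'x'
model = lambda x: {"pred": x * 2}

batch = {"data": np.array([1.0, 2.0]), "label": np.array([2.0, 4.0])}
preds = predict_batch(model, {"x": "data"}, batch)
scores = calc_metrics({**batch, **preds},
                      {"mae": lambda y, p: float(np.abs(y - p).mean())})
```

Here `{"x": "data"}` routes the `data` entry to the model input `x`, and the merged `{**batch, **preds}` dict plays the role of the `LookupConfig` built in `predict_data_mgr`.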
  {
    "path": "delira/training/utils.py",
    "content": "import collections.abc\nimport numpy as np\n\n\ndef recursively_convert_elements(element, check_type, conversion_fn):\n    \"\"\"\n    Function to recursively convert all elements\n\n    Parameters\n    ----------\n    element : Any\n        the element to convert\n    check_type : Any\n        if ``element`` is of type ``check_type``, the conversion function will\n        be applied to it\n    conversion_fn : Any\n        the function to apply to ``element`` if it is of type ``check_type``\n\n    Returns\n    -------\n    Any\n        the converted element\n\n    \"\"\"\n\n    # convert element with conversion_fn\n    if isinstance(element, check_type):\n        return conversion_fn(element)\n\n    # return strings and arrays as is\n    elif isinstance(element, (str, np.ndarray)):\n        return element\n\n    # recursively convert all keys and values of mapping and convert result\n    # back to original mapping type\n    # must be checked before iterable since most mappings are also an iterable\n    elif isinstance(element, collections.abc.Mapping):\n        element = type(element)({\n            recursively_convert_elements(k, check_type, conversion_fn):\n                recursively_convert_elements(v, check_type, conversion_fn)\n            for k, v in element.items()\n        })\n        return element\n\n    # recursively convert all items of iterable and convert result back to\n    # original iterable type\n    elif isinstance(element, collections.abc.Iterable):\n        element = type(element)([recursively_convert_elements(x,\n                                                              check_type,\n                                                              conversion_fn)\n                                 for x in element])\n        return element\n\n    # none of the previous cases is suitable for the element -> return as is\n    return element\n\n\ndef _correct_zero_shape(arg):\n    \"\"\"\n    Corrects the shape of a numpy array to be at least 1d 
and returns the\n    argument as is otherwise\n\n    Parameters\n    ----------\n    arg : Any\n        the argument which must be corrected in its shape if it's\n        zero-dimensional\n\n    Returns\n    -------\n    Any\n        argument (shape corrected if necessary)\n    \"\"\"\n    if arg.shape == ():\n        arg = arg.reshape(1)\n\n    return arg\n\n\ndef convert_to_numpy_identity(*args, **kwargs):\n    \"\"\"\n    Corrects the shape of all zero-dimensional numpy arrays to be at least 1d\n\n    Parameters\n    ----------\n    *args :\n        positional arguments of potential arrays to be corrected\n    **kwargs :\n        keyword arguments of potential arrays to be corrected\n\n    Returns\n    -------\n    tuple\n        the corrected positional and keyword arguments\n\n    \"\"\"\n    args = recursively_convert_elements(args, np.ndarray, _correct_zero_shape)\n\n    kwargs = recursively_convert_elements(kwargs, np.ndarray,\n                                          _correct_zero_shape)\n\n    return args, kwargs\n"
  },
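To show how `recursively_convert_elements` walks nested containers, here is a standalone re-implementation of the same logic (using `collections.abc`, so the example runs without delira), together with a typical call that promotes scalars to 1d arrays the way the zero-shape correction does:

```python
import collections.abc
import numpy as np


def recursively_convert_elements(element, check_type, conversion_fn):
    # apply the conversion to matching elements
    if isinstance(element, check_type):
        return conversion_fn(element)
    # strings and arrays pass through untouched
    if isinstance(element, (str, np.ndarray)):
        return element
    # mappings: convert keys and values, preserve the mapping type
    if isinstance(element, collections.abc.Mapping):
        return type(element)({
            recursively_convert_elements(k, check_type, conversion_fn):
                recursively_convert_elements(v, check_type, conversion_fn)
            for k, v in element.items()})
    # other iterables: convert items, preserve the iterable type
    if isinstance(element, collections.abc.Iterable):
        return type(element)(
            [recursively_convert_elements(x, check_type, conversion_fn)
             for x in element])
    # nothing matched -> return as is
    return element


# promote every int inside a nested structure to a 1d numpy array
nested = {"a": 1, "b": [2, {"c": 3}]}
converted = recursively_convert_elements(nested, int,
                                         lambda x: np.array(x).reshape(1))
```

The mapping branch is checked before the iterable branch because most mappings are also iterable; checking in the other order would silently drop the values.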
  {
    "path": "delira/utils/__init__.py",
    "content": "from delira.utils.config import DeliraConfig, Config\nfrom delira.utils.path import subdirs\nfrom delira.utils.time import now\n"
  },
  {
    "path": "delira/utils/codecs.py",
    "content": "import importlib\nimport types\nimport collections.abc\nimport inspect\nimport numpy as np\nimport logging\nimport typing\nfrom functools import partial\n\n\nclass Encoder:\n    \"\"\"\n    Encode arbitrary objects. The encoded object consists of dicts,\n    lists, ints, floats and strings.\n    \"\"\"\n\n    def __call__(self, obj) -> typing.Any:\n        \"\"\"\n        Encode arbitrary objects as dicts, str, int, float, list\n\n        Parameters\n        ----------\n        obj : Any\n            object to be encoded\n\n        Returns\n        -------\n        Any\n            encoded object\n        \"\"\"\n        return self.encode(obj)\n\n    def encode(self, obj) -> typing.Any:\n        \"\"\"\n        Encode arbitrary objects as dicts, str, int, float, list\n\n        Parameters\n        ----------\n        obj : Any\n            object to be encoded\n\n        Returns\n        -------\n        Any\n            encoded object\n        \"\"\"\n        # use type() to check for dict and list because type() does not\n        # consider subtypes which is the desired behaviour in this case\n        if isinstance(obj, (str, int, float)):\n            # end recursion\n            return obj\n        elif obj is None:\n            return obj\n        elif type(obj) == dict:\n            # encode keys and values recursively\n            return self._encode_dict(obj)\n        elif type(obj) == list:\n            # encode items recursively\n            return self._encode_list(obj)\n        elif isinstance(obj, np.ndarray):\n            return self._encode_array(obj)\n        elif isinstance(obj, collections.abc.Mapping):\n            return self._encode_mapping(obj)\n        elif isinstance(obj, collections.abc.Iterable):\n            return self._encode_iterable(obj)\n        elif isinstance(obj, types.ModuleType):\n            return self._encode_module(obj)\n        elif inspect.isclass(obj) or isinstance(obj, type):\n            # use both ways to determine classes 
here\n            # (the second uglier one serves as fallback here in case inspect\n            # does not cover all cases)\n            return self._encode_type(obj)\n        elif isinstance(obj, (types.BuiltinFunctionType, types.FunctionType)):\n            return self._encode_function(obj)\n        else:\n            return self._encode_class(obj)\n\n    def _encode_list(self, obj) -> list:\n        \"\"\"\n        Encode list\n\n        Parameters\n        ----------\n        obj : list\n            list to be encoded\n\n        Returns\n        -------\n        list\n            list with encoded internal items\n        \"\"\"\n        return [self.encode(i) for i in obj]\n\n    def _encode_dict(self, obj) -> dict:\n        \"\"\"\n        Encode dict\n\n        Parameters\n        ----------\n        obj : dict\n            dict to be encoded\n\n        Returns\n        -------\n        dict\n            dict with encoded internal items\n        \"\"\"\n        return {self.encode(_key):\n                self.encode(_item) for _key, _item in obj.items()}\n\n    def _encode_array(self, obj) -> dict:\n        \"\"\"\n        Encode array\n\n        Parameters\n        ----------\n        obj : :class:`np.ndarray`\n            object to be encoded\n\n        Returns\n        -------\n        dict\n            array encoded as a list inside a dict\n        \"\"\"\n        # # if numpy array: add explicit array specifier\n        # use tolist instead of tostring here (even though this requires\n        # additional encoding steps and increases memory usage), since tolist\n        # retains the shape and tostring doesn't\n        return {\"__array__\": self.encode(obj.tolist())}\n\n    def _encode_mapping(self, obj) -> dict:\n        \"\"\"\n        Encode mapping\n\n        Parameters\n        ----------\n        obj : collections.Mapping\n            object to be encoded\n\n        Returns\n        -------\n        dict\n            mapping encoded as a dict with 
original data and type\n        \"\"\"\n        # encode via encoding the type and the mapping converted to dict\n        # separately and add a conversion specifier\n        convert_repr = {\n            \"type\": self.encode(type(obj)),\n            \"repr\": self.encode(dict(obj)),\n        }\n        return {\"__convert__\": convert_repr}\n\n    def _encode_iterable(self, obj) -> dict:\n        \"\"\"\n        Encode iterable\n\n        Parameters\n        ----------\n        obj : collections.Iterable\n            object to be encoded\n\n        Returns\n        -------\n        dict\n            iterable encoded as a dict with original data and type\n        \"\"\"\n        # encode via encoding the type and the iterable converted to list\n        # separately and add a conversion specifier\n        convert_repr = {\n            \"type\": self.encode(type(obj)),\n            \"repr\": self.encode(list(obj)),\n        }\n        return {\"__convert__\": convert_repr}\n\n    def _encode_module(self, obj) -> dict:\n        \"\"\"\n        Encode module\n\n        Parameters\n        ----------\n        obj : types.ModuleType\n            module to be encoded\n\n        Returns\n        -------\n        dict\n            module encoded as a dict\n        \"\"\"\n        # encode via the module's import name (module objects have no\n        # __module__ attribute themselves)\n        return {\"__module__\": obj.__name__}\n\n    def _encode_type(self, obj) -> dict:\n        \"\"\"\n        Encode class or type\n\n        Parameters\n        ----------\n        obj :\n            class/type to be encoded\n\n        Returns\n        -------\n        dict\n            class/type encoded as a dict\n        \"\"\"\n        type_repr = {\n            \"module\": self.encode(obj.__module__),\n            \"name\": self.encode(obj.__name__),\n        }\n        return {\"__type__\": type_repr}\n\n    def _encode_function(self, obj) -> dict:\n        \"\"\"\n        Encode function\n\n        Parameters\n        ----------\n        obj 
:\n            function to be encoded\n\n        Returns\n        -------\n        dict\n            function encoded as a dict\n        \"\"\"\n        function_repr = {\n            \"module\": self.encode(obj.__module__),\n            \"name\": self.encode(obj.__name__),\n        }\n        return {\"__function__\": function_repr}\n\n    def _encode_class(self, obj) -> dict:\n        \"\"\"\n        Encode arbitrary object\n\n        Parameters\n        ----------\n        obj :\n             arbitrary object to be encoded\n\n        Returns\n        -------\n        dict\n             arbitrary object encoded as a dict\n        \"\"\"\n        try:\n            class_repr = {\n                \"type\": self.encode(type(obj)),\n                \"dict\": self.encode(obj.__dict__)\n            }\n            return {\"__class__\": class_repr}\n        except Exception as e:\n            logging.error(e)\n\n\nclass Decoder:\n    \"\"\"\n    Decode arbitrary objects which were encoded by :class:`Encoder`.\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self._decode_mapping = {\n            \"__array__\": self._decode_array,\n            \"__convert__\": self._decode_convert,\n            \"__module__\": self._decode_module,\n            \"__type__\": self._decode_type,\n            \"__function__\": self._decode_function,\n            \"__class__\": self._decode_class,\n            \"__classargs__\": self._decode_classargs,\n            \"__functionargs__\": self._decode_functionargs\n        }\n\n    def __call__(self, obj) -> typing.Any:\n        \"\"\"\n        Decode object\n\n        Parameters\n        ----------\n        obj : Any\n            object to be decoded\n\n        Returns\n        -------\n        Any\n            decoded object\n        \"\"\"\n        return self.decode(obj)\n\n    def decode(self, obj) -> typing.Any:\n        \"\"\"\n        Decode object\n\n        Parameters\n        ----------\n        obj : Any\n 
           object to be decoded\n\n        Returns\n        -------\n        Any\n            decoded object\n        \"\"\"\n        if isinstance(obj, (str, int, float)):\n            return obj\n        elif isinstance(obj, dict):\n            return self._decode_dict(obj)\n        elif isinstance(obj, list):\n            return self._decode_list(obj)\n        else:\n            return obj\n\n    def _decode_dict(self, obj) -> dict:\n        \"\"\"\n        Decode dict with respect to unique identifier keys.\n\n        Parameters\n        ----------\n        obj : dict\n            dict to be decoded\n\n        Returns\n        -------\n        dict\n            decoded dict\n        \"\"\"\n        for key in obj.keys():\n            if key in self._decode_mapping:\n                return self._decode_mapping[key](obj[key])\n            else:\n                obj[key] = self.decode(obj[key])\n        return obj\n\n    def _decode_list(self, obj) -> list:\n        \"\"\"\n        Decode list\n\n        Parameters\n        ----------\n        obj : list\n            list to be decoded\n\n        Returns\n        -------\n        Any\n            decoded list\n        \"\"\"\n        return [self.decode(_i) for _i in obj]\n\n    def _decode_array(self, obj) -> np.ndarray:\n        \"\"\"\n        Decode np.ndarray\n\n        Parameters\n        ----------\n        obj : :class:`np.ndarray`\n            array to be decoded\n\n        Returns\n        -------\n        :class:`np.ndarray`\n            decoded array\n        \"\"\"\n        return np.array(self.decode(obj))\n\n    def _decode_convert(self, obj: dict) -> typing.Union[\n            typing.Iterable, typing.Mapping]:\n        \"\"\"\n        Decode mappings and iterables\n\n        Parameters\n        ----------\n        obj : dict\n            dict to be decoded\n\n        Returns\n        -------\n        typing.Union[typing.Iterable, typing.Mapping]\n            decoded object\n        \"\"\"\n        
# decode items in dict representation\n        convert_repr = self.decode(obj)\n        # create new object\n        return convert_repr[\"type\"](convert_repr[\"repr\"])\n\n    def _decode_module(self, obj: dict) -> types.ModuleType:\n        \"\"\"\n        Decode module\n\n        Parameters\n        ----------\n        obj : dict\n            dict to be decoded\n\n        Returns\n        -------\n        ModuleType\n            decoded module\n        \"\"\"\n        return importlib.import_module(self.decode(obj))\n\n    def _decode_type(self, obj) -> typing.Any:\n        \"\"\"\n        Decode type\n\n        Parameters\n        ----------\n        obj : dict\n            dict to be decoded\n\n        Returns\n        -------\n        Any\n            decoded type\n        \"\"\"\n        # decode items in dict representation\n        type_repr = self.decode(obj)\n        return getattr(importlib.import_module(type_repr[\"module\"]),\n                       type_repr[\"name\"])\n\n    def _decode_function(self, obj: dict) -> typing.Union[\n            types.FunctionType, types.BuiltinFunctionType]:\n        \"\"\"\n        Decode function\n\n        Parameters\n        ----------\n        obj : dict\n            dict to be decoded\n\n        Returns\n        -------\n        typing.Union[types.FunctionType, types.BuiltinFunctionType]\n            decoded function\n        \"\"\"\n        # decode items in dict representation\n        function_repr = self.decode(obj)\n        return getattr(importlib.import_module(function_repr[\"module\"]),\n                       function_repr[\"name\"])\n\n    def _decode_class(self, obj: dict) -> typing.Any:\n        \"\"\"\n        Decode arbitrary object\n\n        Parameters\n        ----------\n        obj : dict\n            dict to be decoded\n\n        Returns\n        -------\n        Any\n            decoded object\n        \"\"\"\n        class_repr = self.decode(obj)\n        cls_type = class_repr[\"type\"]\n   
     cls_dict = class_repr[\"dict\"]\n\n        # need to create a temporary type here (basically a raw object,\n        # since using object directly raises\n        # \"TypeError: __class__ assignment only supported for heap types\n        # or ModuleType subclasses\").\n        # This kind of class re-creation only seems to be possible if the\n        # intermediate class was created in Python (which is not true for\n        # the object type, since that is part of Python's C core)\n        tmp_cls = type(\"__tmp\", (), {})\n        # create instance of temporary class\n        tmp_instance = tmp_cls()\n        # change class type\n        tmp_instance.__class__ = self.decode(cls_type)\n        # update attributes of class\n        tmp_instance.__dict__.update(self.decode(cls_dict))\n        return tmp_instance\n\n    def _decode_classargs(self, obj: dict) -> typing.Any:\n        \"\"\"\n        Create an object from specified class and arguments\n\n        Parameters\n        ----------\n        obj : dict\n            dictionary which represents the object. Must include `module` and\n            `name`. 
Can optionally include `args` and `kwargs`.\n\n        Returns\n        -------\n        Any\n            decoded object\n\n        Raises\n        ------\n        TypeError\n            arguments and name must be encoded as a dict\n        \"\"\"\n        classargs = self.decode(obj)\n\n        if not isinstance(classargs, dict):\n            raise TypeError(\"Arguments for classargs must be defined as dict.\")\n\n        obj_cls = getattr(importlib.import_module(classargs[\"module\"]),\n                          classargs[\"name\"])\n        args = classargs.get(\"args\", [])\n        kwargs = classargs.get(\"kwargs\", {})\n        return obj_cls(*args, **kwargs)\n\n    def _decode_functionargs(self, obj: dict) -> typing.Any:\n        \"\"\"\n        Create a function from specified function and arguments\n\n        Parameters\n        ----------\n        obj : dict\n            dictionary which represents the function. Must include `module`\n            and `name`. Can optionally include `args` and `kwargs` which are\n            passed via `functools.partial`.\n\n        Returns\n        -------\n        Any\n            decoded function\n\n        Raises\n        ------\n        TypeError\n            arguments and name must be encoded as a dict\n        \"\"\"\n        functionargs = self.decode(obj)\n\n        if not isinstance(functionargs, dict):\n            raise TypeError(\n                \"Arguments for functionargs must be defined as dict.\")\n\n        fn = getattr(importlib.import_module(functionargs[\"module\"]),\n                     functionargs[\"name\"])\n        args = functionargs.get(\"args\", [])\n        kwargs = functionargs.get(\"kwargs\", {})\n        # bind args/kwargs individually instead of passing the containers\n        # themselves as two positional arguments\n        return partial(fn, *args, **kwargs)\n"
  },
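The `Encoder`/`Decoder` pair in `delira/utils/codecs.py` serializes arbitrary objects into plain dicts tagged with specifier keys (`__array__`, `__convert__`, `__type__`, `__function__`, `__class__`, ...) and reconstructs them via `importlib`. A minimal standalone sketch of that tagging scheme for types, assuming nothing beyond the stdlib (the helper names `encode_type`/`decode_type` are illustrative, not delira's API):

```python
import importlib
from collections import OrderedDict


def encode_type(obj: type) -> dict:
    # mirror Encoder._encode_type: store module and name under "__type__"
    return {"__type__": {"module": obj.__module__, "name": obj.__name__}}


def decode_type(obj: dict) -> type:
    # mirror Decoder._decode_type: re-import the module and fetch the name
    type_repr = obj["__type__"]
    return getattr(importlib.import_module(type_repr["module"]),
                   type_repr["name"])


encoded = encode_type(OrderedDict)
assert encoded == {"__type__": {"module": "collections",
                                "name": "OrderedDict"}}
# the round trip yields the identical class object
assert decode_type(encoded) is OrderedDict
```

Since only the module path and name are stored, the scheme survives a YAML dump/load cycle, which is exactly why `Config.dump`/`Config.load` below can persist classes and functions.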
  {
    "path": "delira/utils/config.py",
    "content": "import copy\nfrom delira._version import get_versions\nfrom delira.utils.time import now\nfrom nested_lookup import nested_lookup\nimport warnings\nfrom .codecs import Encoder, Decoder\n\nimport yaml\nimport argparse\nimport sys\nimport collections\nimport inspect\n\n\ndef non_string_warning(func):\n    def warning_wrapper(config, key, *args, **kwargs):\n        \"\"\"\n        Emit a warning if non-string keys are used\n\n        Parameters\n        ----------\n        config: :class:`Config`\n            decorated function receives :param:`self` as first argument\n        key : immutable type\n            key which is checked\n\n        Returns\n        -------\n        callable\n            original function with arguments\n        \"\"\"\n        if not isinstance(key, str):\n            warnings.warn(\"The key {} is not a string, but a {}. \"\n                          \"This may lead to unwanted behavior!\".format(\n                              key, type(key)), RuntimeWarning)\n\n        return func(config, key, *args, **kwargs)\n\n    return warning_wrapper\n\n\nclass Config(dict):\n    \"\"\"\n    Baseclass to create a config which holds arbitrary data\n    \"\"\"\n\n    def __init__(self, dict_like=None, **kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        dict_like : dict, optional\n            dict like object to initialize config, by default None\n        kwargs:\n            additional arguments added to the config\n\n        Warnings\n        --------\n        It is recommended to only use strings as keys inside the config.\n        Because of the shortened access to nested keys the types of the\n        keys are lost.\n\n        Examples\n        --------\n        Create simple configuration with nested keys\n        >>> from delira.utils import Config\n        >>> cf = Config()\n        >>> # automatically generates new nested dictionaries\n        >>> cf['first_level.second_level.third_level'] = 1\n        >>> 
# short form access\n        >>> print(cf['first_level.second_level.third_level'])\n        >>> # traditional access\n        >>> print(cf['first_level']['second_level']['third_level'])\n        >>> # entries can also be accessed with dot operator\n        >>> print(cf.first_level.second_level.third_level)\n        \"\"\"\n\n        super().__init__()\n        self.__dict__ = self\n        if dict_like is not None:\n            self.update(dict_like)\n        self.update(kwargs)\n\n    @non_string_warning\n    def __setattr__(self, key, value):\n        \"\"\"\n        Set attribute in config\n\n        Parameters\n        ----------\n        key : str\n            attribute name\n        value : any\n            attribute value\n\n        \"\"\"\n        super().__setattr__(key, self._to_config(value))\n\n    @non_string_warning\n    def __setitem__(self, key, value):\n        \"\"\"\n        Set items inside dict. Supports setting of nested entries by\n        separating the individual keys with a '.'.\n\n        Parameters\n        ----------\n        key : str\n            key for new value\n        value : any\n            new value\n        \"\"\"\n        if not isinstance(key, str) or '.' 
not in key:\n            super().__setitem__(key, value)\n        else:\n            keys = key.split(\".\")\n            final_key = keys.pop(-1)\n            final_dict = self._traverse_keys(keys, create=True)\n            final_dict._set_internal_item(final_key, value)\n\n    def _traverse_keys(self, keys, create=False):\n        \"\"\"\n        Internal helper to traverse through nested dicts\n        (iterative implementation to avoid problems with python stack)\n\n        Parameters\n        ----------\n        keys : iterable of str\n            iterable with keys which should be traversed\n        create : bool, optional\n            creates new empty configs for non-existent keys, by default False\n\n        Returns\n        -------\n        Any\n            value defined by the traversed keys\n        \"\"\"\n        current_level = self\n        for k in keys:\n            if k not in current_level:\n                if create:\n                    current_level[k] = self._create_internal_dict()\n                else:\n                    raise KeyError(\n                        \"{} was not found in internal dict.\".format(k))\n            # traverse to needed dict\n            current_level = current_level[k]\n        return current_level\n\n    def _set_internal_item(self, key, item, deepcopy=False):\n        \"\"\"\n        Set internal item\n\n        Parameters\n        ----------\n        key : str\n            key where new item should be assigned\n        item : Any\n            item which should be assigned\n        deepcopy : bool, optional\n            if enabled the item is copied to the config, by default False\n        \"\"\"\n        config_item = self._to_config(item)\n        if deepcopy:\n            self[key] = copy.deepcopy(config_item)\n        else:\n            self[key] = config_item\n\n    @classmethod\n    def _to_config(cls, item):\n        \"\"\"\n        Convert items to config if they are a 
dict like object\n        but not already a config\n\n        Parameters\n        ----------\n        item : Any\n            item which is converted\n\n        Returns\n        -------\n        Any\n            returns a config if item is dict like, otherwise the item is\n            returned unchanged\n        \"\"\"\n        if isinstance(item, dict) and not isinstance(item, cls):\n            # convert dict to config for additional functionality\n            return cls._create_internal_dict(item)\n        else:\n            return item\n\n    @staticmethod\n    def _create_internal_dict(*args, **kwargs):\n        \"\"\"\n        Defines how internal dicts should be created. Can be overwritten\n        by subclasses to easily change the internal dict type\n\n        Returns\n        -------\n        :class:`Config`\n            new config\n        \"\"\"\n        return Config(*args, **kwargs)\n\n    @non_string_warning\n    def __getitem__(self, key):\n        \"\"\"\n        Get single item\n\n        Parameters\n        ----------\n        key : str\n            key to desired item\n\n        Returns\n        -------\n        Any\n            value inside dict\n        \"\"\"\n        if not isinstance(key, str) or '.' 
not in key:\n            try:\n                return super().__getitem__(int(key))\n            except (KeyError, ValueError):\n                return super().__getitem__(key)\n        else:\n            return self._traverse_keys(key.split(\".\"), create=False)\n\n    @non_string_warning\n    def __contains__(self, key):\n        \"\"\"\n        Check if key is in config\n        (also works for nested dicts with short form)\n\n        Parameters\n        ----------\n        key : str\n            key for desired value\n\n        Returns\n        -------\n        bool\n            true if key is in config\n        \"\"\"\n        contain = True\n        try:\n            self[key]\n        except KeyError:\n            contain = False\n        return contain\n\n    def update(self, update_dict, deepcopy=False, overwrite=False):\n        \"\"\"\n        Update internal dicts with dict like object\n\n        Parameters\n        ----------\n        update_dict : dictlike\n            values which should be added to config\n        deepcopy : bool, optional\n            copies values from :param:`update_dict`, by default False\n        overwrite : bool, optional\n            overwrite existing values inside config, by default False\n\n        Raises\n        ------\n        ValueError\n            if overwrite is not enabled and `update_dict` contains keys\n            which already exist in the config\n        \"\"\"\n        for key, item in update_dict.items():\n            # update items individually\n            self._update(key, item, deepcopy=deepcopy, overwrite=overwrite)\n\n    def _update(self, key, item, deepcopy=False, overwrite=False):\n        \"\"\"\n        Helper function for update\n\n        Parameters\n        ----------\n        key : str\n            key where new item should be assigned\n        item : Any\n            item which should be assigned\n        deepcopy : bool, optional\n            copies :param:`item`, by default False\n        overwrite : bool, 
optional\n            overwrite existing values inside config, by default False\n        \"\"\"\n        if isinstance(item, dict):\n            # update nested dicts\n            if key not in self:\n                self[key] = self._create_internal_dict({})\n            self[key].update(item, deepcopy=deepcopy, overwrite=overwrite)\n        else:\n            # check for overwrite\n            self._raise_overwrite(key, overwrite=overwrite)\n            # set item\n            self._set_internal_item(key, item, deepcopy=deepcopy)\n\n    def _raise_overwrite(self, key, overwrite):\n        \"\"\"\n        Checks if a ValueError should be raised\n\n        Parameters\n        ----------\n        key : str\n            key which needs to be checked\n        overwrite : bool\n            if overwrite is enabled no ValueError is raised even if the key\n            already exists\n\n        Raises\n        ------\n        ValueError\n            raised if overwrite is not enabled and key already exists\n        \"\"\"\n        if key in self and not overwrite:\n            raise ValueError(\"{} already in config. 
Can \"\n                             \"not overwrite value.\".format(key))\n\n    def dump(self, path, formatter=yaml.dump, encoder_cls=Encoder, **kwargs):\n        \"\"\"\n        Save config to a file and add time stamp to config\n\n        Parameters\n        ----------\n        path : str\n            path where config is saved\n        formatter : callable, optional\n            defines the format how the config is saved, by default yaml.dump\n        encoder_cls : :class:`Encoder`, optional\n            transforms config to a format which can be formatted by the\n            :param:`formatter`, by default Encoder\n        kwargs:\n            additional keyword arguments passed to :param:`formatter`\n        \"\"\"\n        self._timestamp = now()\n        encoded_self = encoder_cls().encode(self)\n        with open(path, \"w\") as f:\n            formatter(encoded_self, f, **kwargs)\n\n    def dumps(self, formatter=yaml.dump, encoder_cls=Encoder, **kwargs):\n        \"\"\"\n        Create a loadable string representation from the config and\n        add time stamp to config\n\n        Parameters\n        ----------\n        formatter : callable, optional\n            defines the format how the config is saved, by default yaml.dump\n        encoder_cls : :class:`Encoder`, optional\n            transforms config to a format which can be formatted by the\n            :param:`formatter`, by default Encoder\n        kwargs:\n            additional keyword arguments passed to :param:`formatter`\n\n        Returns\n        -------\n        str\n            string representation of the encoded config\n        \"\"\"\n        self._timestamp = now()\n        encoded_self = encoder_cls().encode(self)\n        return formatter(encoded_self, **kwargs)\n\n    def load(self, path, formatter=yaml.load, decoder_cls=Decoder, **kwargs):\n        \"\"\"\n        Update config from a file\n\n        Parameters\n        ----------\n        path : str\n            path to file\n        formatter : callable, optional\n            defines the format how the config is loaded, by default 
yaml.load\n        decoder_cls : :class:`Decoder`, optional\n            transforms the loaded representation back into config values,\n            by default Decoder\n        kwargs:\n            additional keyword arguments passed to :param:`formatter`\n        \"\"\"\n        with open(path, \"r\") as f:\n            decoded_format = formatter(f, **kwargs)\n        decoded_format = decoder_cls().decode(decoded_format)\n        self.update(decoded_format, overwrite=True)\n\n    def loads(self, data, formatter=yaml.load, decoder_cls=Decoder, **kwargs):\n        \"\"\"\n        Update config from a string\n\n        Parameters\n        ----------\n        data: str\n            string representation of config\n        formatter : callable, optional\n            defines the format how the config is loaded, by default yaml.load\n        decoder_cls : :class:`Decoder`, optional\n            transforms the loaded representation back into config values,\n            by default Decoder\n        kwargs:\n            additional keyword arguments passed to :param:`formatter`\n        \"\"\"\n        decoded_format = formatter(data, **kwargs)\n        decoded_format = decoder_cls().decode(decoded_format)\n        self.update(decoded_format, overwrite=True)\n\n    @classmethod\n    def create_from_dict(cls, value, deepcopy=False):\n        \"\"\"\n        Create config from dict like object\n\n        Parameters\n        ----------\n        value : dict like\n            dict like object used to create new config\n        deepcopy : bool, optional\n            if enabled, copies values from origin, by default False\n\n        Returns\n        -------\n        :class:`Config`\n            new config\n\n        Raises\n        ------\n        TypeError\n            raised if :param:`value` is not a dict (or a subclass of dict)\n        \"\"\"\n        if not isinstance(value, dict):\n            raise TypeError(\"Value must be an instance of 
dict but type {} \"\n                            \"was found.\".format(type(value)))\n        config = cls()\n        config.update(value, deepcopy=deepcopy)\n        return config\n\n    @classmethod\n    def create_from_argparse(cls, value, deepcopy=False, **kwargs):\n        \"\"\"\n        Create config from argument parser\n\n        Parameters\n        ----------\n        value : argument parser or namespace\n            if value is an argument parser, the arguments are first parsed\n            and then a new config with the values is created;\n            if value is a Namespace the new config is created immediately\n        deepcopy : bool, optional\n            if enabled, copies values from origin, by default False\n\n        Returns\n        -------\n        :class:`Config`\n            new config\n\n        Raises\n        ------\n        TypeError\n            if value is not an instance of :class:`ArgumentParser`\n            or :class:`Namespace`\n        \"\"\"\n        if isinstance(value, argparse.ArgumentParser):\n            args_parsed = value.parse_args(**kwargs)\n            return cls.create_from_argparse(args_parsed, deepcopy=deepcopy)\n        elif isinstance(value, argparse.Namespace):\n            return cls.create_from_dict(vars(value), deepcopy=deepcopy)\n        else:\n            raise TypeError(\"Type of args not supported.\")\n\n    @classmethod\n    def create_from_file(cls, path, formatter=yaml.load, decoder_cls=Decoder,\n                         **kwargs):\n        \"\"\"\n        Create config from a file\n\n        Parameters\n        ----------\n        path : str\n            path to file\n        formatter : callable, optional\n            defines the format how the config is loaded, by default yaml.load\n        decoder_cls : :class:`Decoder`, optional\n            transforms the loaded representation back into config values,\n            by default Decoder\n        kwargs:\n            additional keyword 
arguments passed to :param:`formatter`\n\n        Returns\n        -------\n        :class:`Config`\n            new config\n        \"\"\"\n        config = cls()\n        config.load(path, formatter=formatter, decoder_cls=decoder_cls,\n                    **kwargs)\n        return config\n\n    @classmethod\n    def create_from_str(cls, data, formatter=yaml.load, decoder_cls=Decoder,\n                        **kwargs):\n        \"\"\"\n        Create config from a string\n\n        Parameters\n        ----------\n        data: str\n            string representation of config\n        formatter : callable, optional\n            defines the format how the config is loaded, by default yaml.load\n        decoder_cls : :class:`Decoder`, optional\n            transforms the loaded representation back into config values,\n            by default Decoder\n        kwargs:\n            additional keyword arguments passed to :param:`formatter`\n\n        Returns\n        -------\n        :class:`Config`\n            new config\n        \"\"\"\n        config = cls()\n        config.loads(data, formatter=formatter, decoder_cls=decoder_cls,\n                     **kwargs)\n        return config\n\n    def create_argparser(self):\n        '''\n        Creates an argparser for all values in the config\n        Following the pattern: `--training.learning_rate 1234`\n\n        Returns\n        -------\n        argparse.ArgumentParser\n            parser for all variables in the config\n        '''\n        parser = argparse.ArgumentParser(allow_abbrev=False)\n\n        def add_val(dict_like, prefix=''):\n            for key, val in dict_like.items():\n                name = \"--{}\".format(prefix + key)\n                if val is None:\n                    parser.add_argument(name)\n                else:\n                    if isinstance(val, int):\n                        parser.add_argument(name, type=type(val))\n                    elif isinstance(val, 
collections.Mapping):\n                        add_val(val, prefix=key + '.')\n                    elif isinstance(val, collections.Iterable):\n                        if len(val) > 0 and type(val[0]) != type:\n                            parser.add_argument(name, type=type(val[0]))\n                        else:\n                            parser.add_argument(name)\n                    elif inspect.isclass(val):\n                        # issubclass would raise a TypeError for non-class\n                        # values, so only inspect.isclass is checked here\n                        parser.add_argument(name, type=val)\n                    else:\n                        parser.add_argument(name, type=type(val))\n\n        add_val(self)\n        return parser\n\n    @staticmethod\n    def _add_unknown_args(unknown_args):\n        '''\n        Can add unknown args as returned by argparse's method\n        `parse_known_args`.\n\n        Parameters\n        ----------\n        unknown_args : list\n            list of unknown args\n\n        Returns\n        -------\n        Config\n            a config of the parsed args\n        '''\n        # first element in the list must be a key\n        if not isinstance(unknown_args[0], str):\n            unknown_args = [str(arg) for arg in unknown_args]\n        if not unknown_args[0].startswith('--'):\n            raise ValueError(\"First unknown argument must be a '--key'.\")\n\n        args = Config()\n        # take first key\n        key = unknown_args[0][2:]\n        idx, done, val = 1, False, []\n        while not done:\n            try:\n                item = unknown_args[idx]\n            except IndexError:\n                # no more arguments: flush the pending key below\n                item = None\n                done = True\n            if done or item.startswith('--'):\n                # save key with its value\n                if len(val) == 0:\n                    # key is used as flag\n                    args[key] = True\n                elif len(val) == 1:\n                    args[key] = val[0]\n                else:\n                    args[key] = val\n                # new key and flush data\n                if not done:\n                    key = item[2:]\n                val = []\n            else:\n  
              val.append(item)\n            idx += 1\n        return args\n\n    def update_from_argparse(self, parser=None, add_unknown_items=False):\n        '''\n        Updates the config with all values from the command line.\n        Following the pattern: `--training.learning_rate 1234`\n\n        Raises\n        ------\n        TypeError\n            raised if another datatype than currently in the config is parsed\n        Returns\n        -------\n        dict\n            dictionary containing only updated arguments\n        '''\n\n        if len(sys.argv) > 1:\n            if not parser:\n                parser = self.create_argparser()\n\n            params, unknown = parser.parse_known_args()\n            params = vars(params)\n            if unknown and not add_unknown_items:\n                warnings.warn(\n                    \"Called with unknown arguments: {} \"\n                    \"They will not be stored if you do not set \"\n                    \"`add_unknown_items` to true.\".format(unknown),\n                    RuntimeWarning)\n\n            new_params = Config()\n            for key, val in params.items():\n                if val is None:\n                    continue\n                new_params[key] = val\n\n            # update dict\n            self.update(new_params, overwrite=True)\n            if add_unknown_items:\n                additional_params = self._add_unknown_args(unknown)\n                self.update(additional_params)\n                new_params.update(additional_params)\n            return new_params\n\n\nclass LookupConfig(Config):\n    \"\"\"\n    Helper class to have nested lookups in all subdicts of Config\n    \"\"\"\n\n    @staticmethod\n    def _create_internal_dict(*args, **kwargs):\n        \"\"\"\n        Defines how internal dicts should be created. 
Can be used to easily\n        overwrite subclasses\n\n        Returns\n        -------\n        :class:`LookupConfig`\n            new config\n        \"\"\"\n        return LookupConfig(*args, **kwargs)\n\n    @non_string_warning\n    def __contains__(self, key):\n        \"\"\"\n        Check if key is in config\n        (also works for nested dicts with short form)\n\n        Parameters\n        ----------\n        key : str\n            key for desired value\n\n        Returns\n        -------\n        bool\n            true if key is in config\n        \"\"\"\n        contain = True\n        try:\n            self.nested_get(key, allow_multiple=True)\n        except KeyError:\n            contain = False\n        return contain\n\n    def nested_get(self, key, *args, allow_multiple=False, **kwargs):\n        \"\"\"\n        Returns all occurances of :param:`key` in :param:`self` and subdicts\n\n        Parameters\n        ----------\n        key : str\n            the key to search for\n        *args :\n            positional arguments to provide default value\n        allow_multiple: bool\n            allow multiple results\n        **kwargs :\n            keyword arguments to provide default value\n\n        Raises\n        ------\n        KeyError\n            Multiple Values are found for key and :param:`allow_multiple` is\n            False (unclear which value should be returned)\n            OR\n            No Value was found for key and no default value was given\n\n        Returns\n        -------\n        Any\n            value corresponding to key (or default if value was not found)\n\n        \"\"\"\n\n        if \".\" in key:\n            return self[key]\n        results = nested_lookup(key, self)\n        if len(results) > 1:\n            if allow_multiple:\n                return results\n            else:\n                raise KeyError(\"Multiple Values found for key %s\" % key)\n        elif len(results) == 0:\n            if \"default\" in 
kwargs:\n                return kwargs[\"default\"]\n            elif args:\n                return args[0]\n            else:\n                raise KeyError(\"No Value found for key %s\" % key)\n        else:\n            return results[0]\n\n\nclass DeliraConfig(LookupConfig):\n    \"\"\"\n    Configure experiment for delira. Contains variables for model and training\n    which can be either fixed or variable (for hyperparameter search)\n    \"\"\"\n\n    def __init__(self, dict_like=None, fixed_model=None, fixed_training=None,\n                 variable_model=None, variable_training=None, **kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        dict_like : dict, optional\n            dict like object containing values for config, by default None.\n        fixed_model : dict, optional\n            fixed parameters for model, by default None.\n        fixed_training : dict, optional\n            fixed parameters for training, by default None.\n        variable_model : dict, optional\n            variable parameters for model, by default None.\n        variable_training : dict, optional\n            variable parameters for training, by default None.\n        kwargs:\n            additional arguments added to the config\n        \"\"\"\n        super().__init__(dict_like=dict_like, **kwargs)\n        self._update(\"fixed_model\", self.generate_dict(fixed_model),\n                     overwrite=True)\n        self._update(\"fixed_training\", self.generate_dict(fixed_training),\n                     overwrite=True)\n        self._update(\"variable_model\", self.generate_dict(variable_model),\n                     overwrite=True)\n        self._update(\n            \"variable_training\",\n            self.generate_dict(variable_training),\n            overwrite=True)\n        self._version = get_versions()[\"version\"]\n\n    @staticmethod\n    def generate_dict(value):\n        \"\"\"\n        If value is None, an empty dict will be created\n\n        Parameters\n        ----------\n        value : Any\n            checked value\n\n        Returns\n        -------\n        Any\n            empty dict if value is None, otherwise dict(value) is returned\n        \"\"\"\n        if value is None:\n            return {}\n        else:\n            return dict(value)\n\n    @property\n    def params(self):\n        \"\"\"\n        Returns a :class:`LookupConfig` with all model and training parameters\n\n        Returns\n        -------\n        :class:`LookupConfig`\n            config with model and training parameters\n        \"\"\"\n        return LookupConfig(fixed_model=self.fixed_model,\n                            fixed_training=self.fixed_training,\n                            variable_model=self.variable_model,\n                            variable_training=self.variable_training)\n\n    @property\n    def variable_params(self):\n        \"\"\"\n        Returns a :class:`LookupConfig` with all variable parameters\n\n        Returns\n        -------\n        :class:`LookupConfig`\n            config with variable parameters\n        \"\"\"\n        return LookupConfig(model=self.variable_model,\n                            training=self.variable_training)\n\n    @variable_params.setter\n    def variable_params(self, new_params: dict):\n        \"\"\"\n        Update variable parameters from dict like object\n\n        Raises\n        ------\n        TypeError\n            raised if :param:`new_params` is not a dict (or a subclass of dict)\n        \"\"\"\n        if not isinstance(new_params, dict):\n            raise TypeError(\"new_params must be an instance of dict but \"\n                            \"type {} was found.\".format(type(new_params)))\n\n        # create empty dict\n        if \"model\" not in new_params:\n            new_params[\"model\"] = {}\n\n        # create empty dict\n        if \"training\" not in new_params:\n            new_params[\"training\"] = {}\n\n        self.variable_model = 
new_params[\"model\"]\n        self.variable_training = new_params[\"training\"]\n\n    @property\n    def fixed_params(self):\n        \"\"\"\n        Returns a :class:`LookupConfig` with all fixed parameters\n\n        Returns\n        -------\n        :class:`LookupConfig`\n            config with fixed parameters\n        \"\"\"\n        return LookupConfig(model=self.fixed_model,\n                            training=self.fixed_training)\n\n    @fixed_params.setter\n    def fixed_params(self, new_params: dict):\n        \"\"\"\n        Update fixed parameters from dict like object\n\n        Raises\n        ------\n        TypeError\n            raised if :param:`new_params` is not a dict (or a subclass of dict)\n        \"\"\"\n        if not isinstance(new_params, dict):\n            raise TypeError(\"new_params must be an instance of dict but \"\n                            \"type {} was found.\".format(type(new_params)))\n        # create empty dict\n        if \"model\" not in new_params:\n            new_params[\"model\"] = {}\n\n        # create empty dict\n        if \"training\" not in new_params:\n            new_params[\"training\"] = {}\n\n        self.fixed_model = new_params[\"model\"]\n        self.fixed_training = new_params[\"training\"]\n\n    @property\n    def model_params(self):\n        \"\"\"\n        Returns a :class:`LookupConfig` with all model parameters\n\n        Returns\n        -------\n        :class:`LookupConfig`\n            config with model parameters\n        \"\"\"\n        return LookupConfig(variable=self.variable_model,\n                            fixed=self.fixed_model)\n\n    @model_params.setter\n    def model_params(self, new_params: dict):\n        \"\"\"\n        Update model parameters from dict like object\n\n        Raises\n        ------\n        TypeError\n            raised if :param:`new_params` is not a dict (or a subclass of dict)\n        \"\"\"\n        if not isinstance(new_params, dict):\n           
 raise TypeError(\"new_params must be an instance of dict but \"\n                            \"type {} was found.\".format(type(new_params)))\n        # create empty dict\n        if \"fixed\" not in new_params:\n            new_params[\"fixed\"] = {}\n\n        # create empty dict\n        if \"variable\" not in new_params:\n            new_params[\"variable\"] = {}\n\n        self.fixed_model = new_params[\"fixed\"]\n        self.variable_model = new_params[\"variable\"]\n\n    @property\n    def training_params(self):\n        \"\"\"\n        Returns a :class:`LookupConfig` with all training parameters\n\n        Returns\n        -------\n        :class:`LookupConfig`\n            config with training parameters\n        \"\"\"\n        return LookupConfig(variable=self.variable_training,\n                            fixed=self.fixed_training)\n\n    @training_params.setter\n    def training_params(self, new_params: dict):\n        \"\"\"\n        Update training parameters from dict like object\n\n        Raises\n        ------\n        TypeError\n            raised if :param:`new_params` is not a dict (or a subclass of dict)\n        \"\"\"\n        if not isinstance(new_params, dict):\n            raise TypeError(\"new_params must be an instance of dict but \"\n                            \"type {} was found.\".format(type(new_params)))\n        # create empty dict\n        if \"fixed\" not in new_params:\n            new_params[\"fixed\"] = {}\n\n        # create empty dict\n        if \"variable\" not in new_params:\n            new_params[\"variable\"] = {}\n\n        self.fixed_training = new_params[\"fixed\"]\n        self.variable_training = new_params[\"variable\"]\n\n    def log_as_string(self, full_config=False, **kwargs):\n        \"\"\"\n        Log current config as a string\n\n        Parameters\n        ----------\n        full_config : bool, optional\n            if enabled the complete Config is logged, by default False.\n            
Otherwise only model and training parameters will be logged.\n        kwargs:\n            keyword arguments passed to `self.dumps` method to create string\n            representation\n\n        Returns\n        -------\n        str\n            string representation used for logging\n        \"\"\"\n        from delira.logging import log\n\n        if full_config:\n            str_repr = self.dumps(**kwargs)\n        else:\n            str_repr = self.params.dumps(**kwargs)\n        log({'text': {\"text_string\": str_repr, \"tag\": \"DeliraConfig\"}})\n        return str_repr\n"
  },
  {
    "path": "delira/utils/context_managers.py",
    "content": "from delira import get_current_debug_mode, set_debug_mode\n\n\nclass DebugMode(object):\n    \"\"\"\n    Context Manager to set a specific debug mode for the code inside the\n    defined context (and reverting to previous mode afterwards)\n\n    \"\"\"\n\n    def __init__(self, mode):\n        \"\"\"\n\n        Parameters\n        ----------\n        mode : bool\n            the debug mode; if ``True`` disables all multiprocessing\n        \"\"\"\n        self._mode = mode\n\n    def _switch_to_new_mode(self):\n        \"\"\"\n        helper function to switch to the new debug mode\n        (and saving the previous one in ``self._mode``)\n\n        \"\"\"\n        prev_mode = get_current_debug_mode()\n        set_debug_mode(self._mode)\n        self._mode = prev_mode\n\n    def __enter__(self):\n        \"\"\"\n        Sets the specified debug mode on entering the context\n        \"\"\"\n        self._switch_to_new_mode()\n\n    def __exit__(self, *args, **kwargs):\n        \"\"\"\n        Resets the previous debug mode on exiting the context\n\n        Parameters\n        ----------\n        *args :\n            arbitrary positional arguments\n            (ignored here, just needed for compatibility with other context\n            managers)\n        **kwargs :\n            arbitrary keyword arguments\n            (ignored here, just needed for compatibility with other context\n            managers)\n\n        \"\"\"\n        self._switch_to_new_mode()\n\n\nclass DebugEnabled(DebugMode):\n    \"\"\"\n    Context Manager to enable the debug mode for the wrapped context\n\n    \"\"\"\n\n    def __init__(self):\n        super().__init__(True)\n\n\nclass DebugDisabled(DebugMode):\n    \"\"\"\n    Context Manager to disable the debug mode for the wrapped context\n    \"\"\"\n\n    def __init__(self):\n        super().__init__(False)\n"
  },
  {
    "path": "delira/utils/decorators.py",
    "content": "import warnings\nfrom functools import wraps\n\nimport numpy as np\n\nfrom delira import get_backends\n\n\ndef dtype_func(class_object):\n    \"\"\"\n    Decorator to Check whether the first argument of the decorated function is\n    of a certain type\n\n    Parameters\n    ----------\n    class_object : Any\n        type the first function argument should have\n\n    Returns\n    -------\n    Wrapped Function\n\n    Raises\n    ------\n    AssertionError\n        First argument of decorated function is not of given type\n\n    \"\"\"\n    def instance_checker(func):\n        @wraps(func)\n        def func_wrapper(checked_object, *args, **kwargs):\n            assertion_str = \"Argument 1 is not of type %s but of type %s\" % \\\n                            (class_object.__name__,\n                             checked_object.__class__.__name__)\n\n            assert isinstance(checked_object, class_object), assertion_str\n            return func(checked_object, *args, **kwargs)\n        return func_wrapper\n    return instance_checker\n\n\ndef classtype_func(class_object):\n    \"\"\"\n    Decorator to Check whether the first argument of the decorated function is\n    a subclass of a certain type\n\n    Parameters\n    ----------\n    class_object : Any\n        type the first function argument should be subclassed from\n\n    Returns\n    -------\n    Wrapped Function\n\n    Raises\n    ------\n    AssertionError\n        First argument of decorated function is not a subclass of given type\n\n    \"\"\"\n    def subclass_checker(func):\n        @wraps(func)\n        def func_wrapper(checked_object, *args, **kwargs):\n            assertion_str = \"Argument 1 is not subclass of %s but of type %s\" \\\n                            % (class_object.__name__, checked_object.__name__)\n\n            assert issubclass(checked_object, class_object), assertion_str\n            return func(checked_object, *args, **kwargs)\n        return func_wrapper\n    
return subclass_checker\n\n\ndef make_deprecated(new_func):\n    \"\"\"\n    Decorator which emits a DeprecationWarning for the decorated object\n\n    Parameters\n    ----------\n    new_func : Any\n        new function which should be used instead of the decorated one\n\n    Returns\n    -------\n    Wrapped Function\n\n    Warns\n    -----\n    DeprecationWarning\n\n    \"\"\"\n    def deprecation(func):\n        @wraps(func)\n        def func_wrapper(*args, **kwargs):\n            if not isinstance(new_func, str):\n                new_func_name = new_func.__name__\n            else:\n                new_func_name = new_func\n\n            if func.__name__ == '__init__':\n                # derive the class name from the qualified name\n                old_func_name = func.__qualname__.rsplit('.', 1)[0]\n            else:\n                old_func_name = func.__name__\n            warnings.warn(DeprecationWarning(\"%s is deprecated in favor of %s\"\n                                             \" and will be removed at next \"\n                                             \"release\" % (old_func_name,\n                                                          new_func_name)))\n            return func(*args, **kwargs)\n\n        return func_wrapper\n    return deprecation\n\n\nnumpy_array_func = dtype_func(np.ndarray)\n\n\nif \"TORCH\" in get_backends():\n    import torch\n    torch_tensor_func = dtype_func(torch.Tensor)\n    torch_module_func = dtype_func(torch.nn.Module)\n"
  },
  {
    "path": "delira/utils/dict_reductions.py",
    "content": "from collections import MutableMapping\nfrom typing import Union, Dict, Callable\nimport numpy as np\n\n\n# Reduction Functions\ndef reduce_last(items: list) -> Union[float, int, np.ndarray]:\n    \"\"\"\n    Reduction Function returning the last element\n\n    Parameters\n    ----------\n    items : list\n        the items to reduce\n\n    Returns\n    -------\n    float, int or :class:`numpy.ndarray`\n        reduced items\n\n    \"\"\"\n    return items[-1]\n\n\ndef reduce_first(items: list) -> Union[float, int, np.ndarray]:\n    \"\"\"\n    Reduction Function returning the first element\n\n    Parameters\n    ----------\n    items : list\n        the items to reduce\n\n    Returns\n    -------\n    float, int or :class:`numpy.ndarray`\n        reduced items\n\n    \"\"\"\n    return items[0]\n\n\ndef reduce_mean(items: list) -> Union[float, int, np.ndarray]:\n    \"\"\"\n    Reduction Function returning the mean element\n\n    Parameters\n    ----------\n    items : list\n        the items to reduce\n\n    Returns\n    -------\n    float, int or :class:`numpy.ndarray`\n        reduced items\n\n    \"\"\"\n    return np.mean(items)\n\n\ndef reduce_median(items: list) -> Union[float, int, np.ndarray]:\n    \"\"\"\n    Reduction Function returning the median element\n\n    Parameters\n    ----------\n    items : list\n        the items to reduce\n\n    Returns\n    -------\n    float, int or :class:`numpy.ndarray`\n        reduced items\n\n    \"\"\"\n    return np.median(items)\n\n\ndef reduce_max(items: list) -> Union[float, int, np.ndarray]:\n    \"\"\"\n    Reduction Function returning the max element\n\n    Parameters\n    ----------\n    items : list\n        the items to reduce\n\n    Returns\n    -------\n    float, int or :class:`numpy.ndarray`\n        reduced items\n\n    \"\"\"\n    return np.max(items)\n\n\ndef reduce_min(items: list) -> Union[float, int, np.ndarray]:\n    \"\"\"\n    Reduction Function returning the min element\n\n    
Parameters\n    ----------\n    items : list\n        the items to reduce\n\n    Returns\n    -------\n    float, int or :class:`numpy.ndarray`\n        reduced items\n\n    \"\"\"\n    return np.min(items)\n\n\ndef flatten_dict(d: dict, parent_key: str = '', sep: str = '.') -> dict:\n    \"\"\"\n    Flattens a dictionary by concatenating all keys for subdicts with the\n    current key separated by :param:`sep`\n\n    Parameters\n    ----------\n    d : dict\n        the dictionary to flatten\n    parent_key : str\n        the key of the parent dict (usually empty when called by user)\n    sep : str\n        the separator to separate the key from the subdict's key\n\n    Returns\n    -------\n    dict\n        the flattened dictionary\n\n    \"\"\"\n    items = []\n    for k, v in d.items():\n        new_key = parent_key + sep + k if parent_key else k\n        if isinstance(v, MutableMapping):\n            items.extend(flatten_dict(v, new_key, sep=sep).items())\n        else:\n            items.append((new_key, v))\n\n    return type(d)(items)\n\n\ndef unflatten_dict(dictionary: dict, sep: str = \".\") -> dict:\n    \"\"\"\n    Unflattens a dict, where keys and the keys from their subdicts are\n    separated by :param:`sep`\n\n    Parameters\n    ----------\n    dictionary : dict\n        the dictionary to unflatten\n    sep : str\n        the separation string\n\n    Returns\n    -------\n    dict\n        the unflattened dictionary\n\n    \"\"\"\n    return_dict = {}\n    for key, value in dictionary.items():\n        parts = key.split(sep)\n        d = return_dict\n        for part in parts[:-1]:\n            if part not in d:\n                d[part] = dict()\n            d = d[part]\n        d[parts[-1]] = value\n    return return_dict\n\n\ndef reduce_dict(items: list, reduce_fn) -> dict:\n    \"\"\"\n    A function to reduce all entries inside a dict\n\n    Parameters\n    ----------\n    items : list\n        a list of dicts to reduce\n    reduce_fn : FunctionType\n        a function to apply to all non-equal iterables\n\n    Returns\n    -------\n    dict\n        the reduced dict\n\n    \"\"\"\n\n    result_dict = {}\n    # assuming the type of all items is the same for all queued logging dicts\n    # and all dicts have the same keys\n\n    flattened_dicts = [flatten_dict(_tmp, sep=\".\") for _tmp in items]\n\n    # from list of dicts to dict of lists:\n    for d in flattened_dicts:\n        for k, v in d.items():\n            try:\n                result_dict[k].append(v)\n            except KeyError:\n                result_dict[k] = [v]\n\n    for k, v in result_dict.items():\n        # check if all items are equal\n        equals = [_v == v[0] for _v in v[1:]]\n        for idx, equality in enumerate(equals):\n            if isinstance(equality, np.ndarray):\n                equals[idx] = equality.all()\n        if all(equals):\n            # use first item since they are equal\n            result_dict[k] = v[0]\n        else:\n            # apply reduce function\n            result_dict[k] = reduce_fn(v)\n\n    # unflatten reduced dict\n    return unflatten_dict(result_dict, sep=\".\")\n\n\n# string mapping for reduction functions\n_REDUCTION_FUNCTIONS = {\n    \"last\": reduce_last,\n    \"first\": reduce_first,\n    \"mean\": reduce_mean,\n    \"median\": reduce_median,\n    \"max\": reduce_max,\n    \"min\": reduce_min\n}\n\n\ndef possible_reductions() -> tuple:\n    \"\"\"\n    Function returning a tuple containing all valid reduction strings\n\n    Returns\n    -------\n    tuple\n        a tuple containing all valid reduction strings\n    \"\"\"\n    return tuple(_REDUCTION_FUNCTIONS.keys())\n\n\ndef get_reduction(reduce_type: str) -> Callable:\n    \"\"\"\n    A getter function to get a specified reduction function by its\n    specifier string\n\n    Parameters\n    ----------\n    reduce_type : str\n        the reduction type\n\n    Returns\n    -------\n    Callable\n        the actual reduction function\n\n    \"\"\"\n    return 
_REDUCTION_FUNCTIONS[reduce_type]\n"
  },
  {
    "path": "delira/utils/messenger.py",
    "content": "import logging\nimport warnings\nfrom abc import ABC, abstractmethod\n\nfrom delira.training import BaseExperiment\nfrom delira.training.callbacks import AbstractCallback\n\n\nclass BaseMessenger(ABC):\n    \"\"\"\n    Wrap arbitrary experiments and connect its functions to a\n    notification service.\n    \"\"\"\n\n    def __init__(self, experiment: BaseExperiment, notify_epochs: int = None):\n        \"\"\"\n\n        Parameters\n        ----------\n        experiment : :class:`BaseExperiment`\n            instance of current experiment\n        notify_epochs : int\n            Activates notifications about finished epochs with frequency\n            `notify_epochs`.\n        \"\"\"\n        super().__init__()\n        self._experiment = experiment\n        self._notify_epochs = notify_epochs\n\n    @abstractmethod\n    def emit_message(self, msg: str) -> dict:\n        \"\"\"\n        Emit message.\n        Implement this method in base class to create new notification\n        services.\n\n        Parameters\n        ----------\n        msg : str\n            message which should be emitted\n\n        Returns\n        -------\n        dict\n            dict with additional information from message\n        \"\"\"\n        raise NotImplementedError\n\n    def __getattr__(self, attr):\n        \"\"\"\n        If wrapper does not implement attribute, return attribute of wrapped\n        object\n\n        Parameters\n        ----------\n        attr : str\n            name of attribute\n\n        Returns\n        -------\n        Any\n            attribute\n        \"\"\"\n        # NOTE do note use hasattr, it goes into infinite recursion\n        if attr in self.__dict__:\n            # this object has it\n            return getattr(self, attr)\n        return getattr(self._experiment, attr)\n\n    def run(self, *args, **kwargs):\n        \"\"\"\n        Wrapper for run function. 
Notifies experiment start, fail, complete.\n\n        Parameters\n        ----------\n        args :\n            positional arguments passed to experiment.\n        kwargs :\n            additional keyword arguments passed to experiment.\n\n        Returns\n        -------\n        Any\n            result of experiment\n        \"\"\"\n        if self._notify_epochs is not None:\n            callbacks = list(kwargs.pop(\"callbacks\", []))\n            callbacks.append(MessengerEpochCallback(self._notify_epochs,\n                                                    self))\n            kwargs[\"callbacks\"] = callbacks\n\n        msg = str(self._experiment.name) + \" : Training started.\"\n        self.emit_message(msg)\n\n        try:\n            out = self._experiment.run(*args, **kwargs)\n        except Exception as e:\n            msg = \\\n                str(self._experiment.name) + \" : Training failed. \\n\" + str(e)\n            self.emit_message(msg)\n            raise\n\n        msg = str(self._experiment.name) + \" : Training completed.\"\n        self.emit_message(msg)\n        return out\n\n    def resume(self, *args, **kwargs):\n        \"\"\"\n        Wrapper for resume function. 
Notifies experiment start, fail, complete.\n\n        Parameters\n        ----------\n        args :\n            positional arguments passed to experiment.\n        kwargs :\n            additional keyword arguments passed to experiment.\n\n        Returns\n        -------\n        Any\n            result of experiment\n        \"\"\"\n        if self._notify_epochs is not None:\n            callbacks = kwargs.pop(\"callbacks\", [])\n            callbacks.append(MessengerEpochCallback(self._notify_epochs,\n                                                    self))\n            kwargs[\"callbacks\"] = callbacks\n\n        msg = str(self._experiment.name) + \" : Resume started.\"\n        self.emit_message(msg)\n\n        try:\n            out = self._experiment.resume(*args, **kwargs)\n        except Exception as e:\n            msg = str(self._experiment.name) + \" : Resume failed. \\n\" + str(e)\n            self.emit_message(msg)\n            raise e\n\n        msg = str(self._experiment.name) + \" : Resume ended.\"\n        self.emit_message(msg)\n        return out\n\n    def test(self, *args, **kwargs):\n        \"\"\"\n        Wrapper for test function. Notifies experiment start, fail, complete.\n\n        Parameters\n        ----------\n        args :\n            positional arguments passed to experiment.\n        kwargs :\n            additional keyword arguments passed to experiment.\n\n        Returns\n        -------\n        Any\n            result of experiment\n        \"\"\"\n        msg = str(self._experiment.name) + \" : Test started.\"\n        self.emit_message(msg)\n\n        try:\n            out = self._experiment.test(*args, **kwargs)\n        except Exception as e:\n            msg = str(self._experiment.name) + \" : Test failed. 
\\n\" + str(e)\n            self.emit_message(msg)\n            raise e\n\n        msg = str(self._experiment.name) + \" : Test completed.\"\n        self.emit_message(msg)\n        return out\n\n    def kfold(self, *args, **kwargs):\n        \"\"\"\n        Wrapper for kfold function. Notifies experiment start, fail, complete,\n        end of fold.\n\n        Parameters\n        ----------\n        args :\n            positional arguments passed to experiment.\n        kwargs :\n            additional keyword arguments passed to experiment.\n\n        Returns\n        -------\n        Any\n            result of experiment\n        \"\"\"\n        # append own callback for fold messages\n        callbacks = kwargs.pop(\"callbacks\", [])\n        callbacks.append(MessengerFoldCallback(self))\n\n        # append own callback for epoch messages\n        if self._notify_epochs is not None:\n            callbacks.append(MessengerEpochCallback(self._notify_epochs,\n                                                    self))\n\n        kwargs[\"callbacks\"] = callbacks\n\n        msg = str(self._experiment.name) + \" : Kfold started.\"\n        self.emit_message(msg)\n\n        # execute k-fold\n        try:\n            out = self._experiment.kfold(*args, **kwargs)\n        except Exception as e:\n            msg = str(self._experiment.name) + \" : Kfold failed. 
\\n\" + str(e)\n            self.emit_message(msg)\n            raise e\n\n        msg = str(self._experiment.name) + \" : Kfold completed.\"\n        self.emit_message(msg)\n\n        return out\n\n\nclass MessengerEpochCallback(AbstractCallback):\n    \"\"\"\n    Callback for \"Epoch X trained\" message\n\n    See Also\n    --------\n    :class:`BaseMessenger`\n    \"\"\"\n\n    def __init__(self, n_epochs: int, messenger: BaseMessenger):\n        \"\"\"\n\n        Parameters\n        ----------\n        n_epochs : int\n            notification frequency\n        messenger : :class:`BaseMessenger`\n            instance of a experiment with messanger to emit message\n        \"\"\"\n        super().__init__()\n        self._n_epochs = n_epochs\n        self._messenger = messenger\n\n    def at_epoch_end(self, trainer, **kwargs) -> dict:\n        \"\"\"\n        Call at end of epoch\n\n        Parameters\n        ----------\n        trainer : :class:`BaseTrainer`\n            instance of trainer\n        kwargs :\n            additional keyword arguments. 
Must contain ``curr_epoch``.\n\n        Returns\n        -------\n        dict\n            empty dict\n        \"\"\"\n        curr_epoch = kwargs.pop(\"curr_epoch\")\n        trained_epochs = curr_epoch - trainer.start_epoch\n        if trained_epochs % self._n_epochs == 0:\n            msg = \"Epoch \" + str(curr_epoch) + \" trained.\"\n            self._messenger.emit_message(msg)\n        return {}\n\n\nclass MessengerFoldCallback(AbstractCallback):\n    \"\"\"\n    Callback for \"Fold X completed\" in slack\n\n    See Also\n    --------\n    :class:`BaseMessenger`\n    \"\"\"\n\n    def __init__(self, messenger: BaseMessenger):\n        \"\"\"\n\n        Parameters\n        ----------\n        messenger : :class:`BaseMessenger`\n            instance of a experiment with messanger to emit message\n        \"\"\"\n        super().__init__()\n        self._messenger = messenger\n\n    def at_training_begin(self, trainer, **kwargs) -> dict:\n        \"\"\"\n        End of training callback\n\n        Parameters\n        ----------\n        trainer : :class:`BaseTrainer`\n            instance of trainer\n        kwargs :\n            additional keyword arguments (not used)\n\n        Returns\n        -------\n        dict\n            empty dict\n        \"\"\"\n        msg = \"Fold \" + str(trainer.fold) + \" started.\"\n        self._messenger.emit_message(msg)\n        return {}\n\n    def at_training_end(self, trainer, **kwargs) -> dict:\n        \"\"\"\n        End of training callback\n\n        Parameters\n        ----------\n        trainer : :class:`BaseTrainer`\n            instance of trainer\n        kwargs :\n            additional keyword arguments (not used)\n\n        Returns\n        -------\n        dict\n            empty dict\n        \"\"\"\n        msg = \"Fold \" + str(trainer.fold) + \" completed.\"\n        self._messenger.emit_message(msg)\n        return {}\n\n\nclass SlackMessenger(BaseMessenger):\n    \"\"\"\n    Wrap arbitrary 
experiments and connect their functions to slack\n    notifications\n\n    .. note:: `token` can be either your personal user token or a token\n              from an artificial bot. To create your own bot you can\n              visit https://api.slack.com/ and click 'Your Apps' at the\n              top-right corner (you may need to create your own workspace\n              where you can install your bot).\n\n    .. warning:: Slack messenger has `slackclient` as a dependency which\n                 is not included in the requirements!\n    \"\"\"\n\n    def __init__(self, experiment: BaseExperiment, token: str,\n                 channel: str, notify_epochs: int = None, **kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        experiment : :class:`BaseExperiment`\n            instance of current experiment\n        token : str\n            User or Bot token from slack\n        channel : str\n            channel id (destination of messages)\n        notify_epochs : int\n            Activates notifications about finished epochs with frequency\n            `notify_epochs`.\n        kwargs :\n            additional keyword arguments passed to :class:`SlackClient`\n\n        Raises\n        ------\n        ImportError\n            if `slackclient` is not installed\n\n        See Also\n        --------\n        :class:`BaseMessenger`\n        \"\"\"\n        super().__init__(experiment, notify_epochs=notify_epochs)\n\n        # switch between different versions (with changed imports)\n        try:\n            from slackclient import SlackClient\n            self._client = SlackClient(token, **kwargs)\n            self._version = 1\n        except ImportError as e:\n            try:\n                from slack import WebClient\n                self._client = WebClient(token=token, **kwargs)\n                self._version = 2\n            except ImportError as e:\n                warnings.warn(\n                    \"Could not import `slackclient`. This package is not \"\n                    \"included in the default dependencies of delira!\")\n                raise e\n        assert self._version in [1, 2], \"Only version 1 and 2 supported\"\n\n        self._channel = channel\n        self._ts = None  # Set to None for initial message\n\n        # initial slack message\n        msg = \"Created new experiment: \" + str(self._experiment.name)\n        resp = self.emit_message(msg)\n\n        if self._version == 1:\n            # old api\n            self._ts = resp['ts'] if 'ts' in resp else None\n        elif self._version == 2:\n            # new api\n            self._ts = resp.data['ts'] if hasattr(resp, 'data') else None\n\n    def emit_message(self, msg, **kwargs):\n        \"\"\"\n        Emit message (if possible, the current thread is used)\n\n        Parameters\n        ----------\n        msg : str\n            message which should be emitted\n        kwargs:\n            additional keyword arguments passed to slack api calls\n\n        Returns\n        -------\n        dict\n            dict with additional information from message\n\n        Raises\n        ------\n        ValueError\n            unknown `self._version`\n        \"\"\"\n        # use thread of current post if possible\n        if self._ts is not None and 'thread_ts' not in kwargs:\n            kwargs['thread_ts'] = self._ts\n\n        if self._version == 1:\n            resp = self._emit_message_v1(msg, **kwargs)\n        elif self._version == 2:\n            resp = self._emit_message_v2(msg, **kwargs)\n        else:\n            raise ValueError(\"Unknown version detected!\")\n        return resp\n\n    def _emit_message_v1(self, msg, **kwargs) -> dict:\n        \"\"\"\n        Emit message with old slack api\n\n        Parameters\n        ----------\n        msg : str\n            message which should be emitted\n        kwargs:\n            additional keyword arguments passed to slack api calls\n\n        Returns\n        -------\n        dict\n            representation dict of message\n        \"\"\"\n        resp = self._client.api_call(\n            \"chat.postMessage\",\n            channel=self._channel,\n            text=msg,\n            **kwargs,\n        )\n\n        if not resp[\"ok\"]:\n            logging.error(\"Slack message was not emitted correctly!\"\n                          \" \\n {}\".format(msg))\n        return resp\n\n    def _emit_message_v2(self, msg, **kwargs):\n        \"\"\"\n        Emit message with new slack api\n\n        Parameters\n        ----------\n        msg : str\n            message which should be emitted\n        kwargs:\n            additional keyword arguments passed to slack api calls\n\n        Returns\n        -------\n        :class:`slack.web.slack_response.SlackResponse`\n            slack api response\n        \"\"\"\n        resp = self._client.chat_postMessage(channel=self._channel,\n                                             text=msg,\n                                             **kwargs,\n                                             )\n        if not resp.data[\"ok\"]:\n            logging.error(\"Slack message was not emitted correctly!\"\n                          \" \\n {}\".format(msg))\n        return resp\n"
  },
  {
    "path": "delira/utils/path.py",
    "content": "import os\n\n\ndef subdirs(d):\n    \"\"\"For a given directory, return a list of all subdirectories (full paths)\n\n    Parameters\n    ----------\n    d : string\n        given root directory\n\n    Returns\n    -------\n    list\n        list of strings of all subdirectories\n    \"\"\"\n    return sorted([os.path.join(d, name) for name in os.listdir(d)\n                   if os.path.isdir(os.path.join(d, name))])\n"
  },
  {
    "path": "delira/utils/time.py",
    "content": "import datetime\n\n\ndef now():\n    \"\"\"Return current time as YYYY-MM-DD_HH-MM-SS\n\n    Returns\n    -------\n    string\n        current time\n    \"\"\"\n\n    return datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')\n"
  },
  {
    "path": "docker/Dockerfile",
    "content": "FROM nvidia/cuda:9.2-base-ubuntu18.04\n\nRUN apt-get update && apt-get install -y \\\n    curl \\\n    ca-certificates \\\n    sudo \\\n    git \\\n    bzip2 \\\n    libx11-6 \\\n    build-essential \\\n    fonts-roboto \\\n    && rm -rf /var/lib/apt/lists/*\n\nRUN useradd --create-home --shell /bin/bash containeruser\nUSER containeruser\nWORKDIR /home/containeruser\n\nRUN curl -o ~/miniconda.sh -O  https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh  && \\\n    chmod +x ~/miniconda.sh && \\\n    ~/miniconda.sh -b -p /home/containeruser/conda && \\\n    rm ~/miniconda.sh && \\\n    /home/containeruser/conda/bin/conda clean -ya\nENV PATH /home/containeruser/conda/bin:$PATH\nRUN conda install python=3.7\nRUN pip install --upgrade pip\nRUN git clone https://github.com/justusschock/delira.git && \\\n    pip install pip wheel && \\\n    pip install -r delira/requirements.txt && \\\n    pip install -r delira/requirements_extra_torch.txt && \\\n    pip install delira/\nENV PYTHONPATH /home/containeruser/delira:$PYTHONPATH\nCMD [\"/bin/bash\"]\n"
  },
  {
    "path": "docs/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line.\nSPHINXOPTS    =\nSPHINXBUILD   = sphinx-build\nSPHINXPROJ    = delira\nSOURCEDIR     = .\nBUILDDIR      = _build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)"
  },
  {
    "path": "docs/_api/_build/delira/backend_resolution.rst",
    "content": ".. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira._backends\n\nBackend Resolution\n==================\n\nThese functions are used to determine the installed backends and to update\nthe created config file. They also need to be used to guard backend-specific\ncode when writing code with several backends in one file, like this:\n\n``if \"YOUR_BACKEND\" in delira.get_backends():``\n\n:hidden:`get_backends`\n~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: get_backends\n"
  },
  {
    "path": "docs/_api/_build/delira/class_hierarchy.rst",
    "content": "Class Hierarchy Diagrams\n========================\n\n.. contents::\n\n* `Coarse <../../../_static/class_hierarchy/delira_coarse.png>`_\n\n* `Fine <../../../_static/class_hierarchy/delira_fine.png>`_\n"
  },
  {
    "path": "docs/_api/_build/delira/data_loading/arbitrary_data.rst",
    "content": "Arbitrary Data\n--------------\n\nThe following classes are implemented to work with any kind of data. You can\nuse any framework you want to load your data, but each returned sample should\nbe a :obj:`dict` of ``numpy.ndarray`` objects.\n\n.. toctree::\n    :maxdepth: 5\n\n    Dataset <dataset>\n    Dataloader <dataloader>\n    Datamanager <datamanager>\n    Utils <utils>"
  },
  {
    "path": "docs/_api/_build/delira/data_loading/data_loading.rst",
    "content": "Data Loading\n============\n\nThis module provides utilities to load data.\n\n.. toctree::\n\n    Arbitrary Data <arbitrary_data>\n    Nii <nii>\n    Sampler <sampler>\n\n\n"
  },
  {
    "path": "docs/_api/_build/delira/data_loading/dataloader.rst",
    "content": ".. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.data_loading\n\nDataloader\n**********\n\nThe DataLoader wraps a dataset and combines it with a sampler\n(see below) to assemble single samples into whole batches.\n\nToDo: add flow chart diagram\n\n:hidden:`DataLoader`\n~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: DataLoader\n    :members:\n    :undoc-members:\n    :show-inheritance:"
  },
  {
    "path": "docs/_api/_build/delira/data_loading/datamanager.rst",
    "content": ".. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.data_loading.data_manager\n\n\nDatamanager\n***********\n\nThe datamanager wraps a dataloader and combines it with augmentations\nand multiprocessing.\n\n:hidden:`DataManager`\n~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: DataManager\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`Augmenter`\n~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: Augmenter\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/_api/_build/delira/data_loading/dataset.rst",
    "content": ".. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.data_loading\n\nDatasets\n********\n\nThe Dataset is the most basic class and implements the loading of your\ndataset elements.\nYou can either load your data in a lazy way, i.e. loading it just at the\nmoment it is needed, or preload and cache it.\n\nDatasets can be indexed by integers and return single samples.\n\nTo implement custom datasets you should derive from\n:class:`AbstractDataset`.\n\n\n:hidden:`AbstractDataset`\n~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: AbstractDataset\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`BaseLazyDataset`\n~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: BaseLazyDataset\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`BaseCacheDataset`\n~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: BaseCacheDataset\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`BaseExtendCacheDataset`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: BaseExtendCacheDataset\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`ConcatDataset`\n~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: ConcatDataset\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`BlankDataset`\n~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: BlankDataset\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`Nii3DLazyDataset`\n~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: Nii3DLazyDataset\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`Nii3DCacheDataset`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: Nii3DCacheDataset\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`TorchvisionClassificationDataset`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: TorchvisionClassificationDataset\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/_api/_build/delira/data_loading/nii.rst",
    "content": ".. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.data_loading.nii\n\nNii-Data\n--------\n\nSince ``delira`` aims to provide dataloading tools for medical data (which is\noften stored in Nii files), the following classes and functions provide a\nbasic way to load data from Nii files:\n\n:hidden:`load_nii`\n~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: load_nii\n\n:hidden:`BaseLabelGenerator`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: BaseLabelGenerator\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`load_sample_nii`\n~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: load_sample_nii\n\n"
  },
  {
    "path": "docs/_api/_build/delira/data_loading/sampler.rst",
    "content": ".. role:: hidden\n    :class: hidden-section\n\n\nSampler\n-------\n\nSamplers define the way of iterating over the dataset and returning samples.\n\n.. currentmodule:: delira.data_loading.sampler\n\n:hidden:`AbstractSampler`\n~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: AbstractSampler\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`LambdaSampler`\n~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: LambdaSampler\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`RandomSampler`\n~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: RandomSampler\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`PrevalenceRandomSampler`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: PrevalenceRandomSampler\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`StoppingPrevalenceRandomSampler`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: StoppingPrevalenceRandomSampler\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`SequentialSampler`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: SequentialSampler\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`PrevalenceSequentialSampler`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: PrevalenceSequentialSampler\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`StoppingPrevalenceSequentialSampler`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: StoppingPrevalenceSequentialSampler\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`WeightedRandomSampler`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: WeightedRandomSampler\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/_api/_build/delira/data_loading/utils.rst",
    "content": ".. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.data_loading.load_utils\n\nUtils\n*****\n\n:hidden:`norm_range`\n~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: norm_range\n\n:hidden:`norm_zero_mean_unit_std`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: norm_zero_mean_unit_std\n\n:hidden:`is_valid_image_file`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: is_valid_image_file\n\n:hidden:`default_load_fn_2d`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: default_load_fn_2d\n\n:hidden:`LoadSample`\n~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: LoadSample\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/_api/_build/delira/debug_mode.rst",
    "content": ".. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira._debug_mode\n\nDebug Mode\n==========\n\nDelira now contains a fully-fledged `Debug` mode, which disables all kinds of multiprocessing.\n\n:hidden:`get_current_debug_mode`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: get_current_debug_mode\n\n:hidden:`switch_debug_mode`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: switch_debug_mode\n\n:hidden:`set_debug_mode`\n~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: set_debug_mode\n"
  },
  {
    "path": "docs/_api/_build/delira/delira.io.rst",
    "content": ".. role:: hidden\n    :class: hidden-section\n\nIO\n==\n\n.. currentmodule:: delira.io\n\n:hidden:`torch_load_checkpoint`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: torch_load_checkpoint\n\n:hidden:`torch_save_checkpoint`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: torch_save_checkpoint\n\n:hidden:`torchscript_load_checkpoint`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: torchscript_load_checkpoint\n\n:hidden:`torchscript_save_checkpoint`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: torchscript_save_checkpoint\n\n:hidden:`tf_load_checkpoint`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: tf_load_checkpoint\n\n:hidden:`tf_save_checkpoint`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: tf_save_checkpoint\n\n:hidden:`tf_eager_load_checkpoint`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: tf_eager_load_checkpoint\n\n:hidden:`tf_eager_save_checkpoint`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: tf_eager_save_checkpoint\n\n:hidden:`chainer_load_checkpoint`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: chainer_load_checkpoint\n\n:hidden:`chainer_save_checkpoint`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: chainer_save_checkpoint\n\n:hidden:`sklearn_load_checkpoint`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: sklearn_load_checkpoint\n\n:hidden:`sklearn_save_checkpoint`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: sklearn_save_checkpoint\n"
  },
  {
    "path": "docs/_api/_build/delira/delira.rst",
    "content": "Delira\n======\n\n.. toctree::\n    :maxdepth: 10\n    :glob:\n\n    Data Loading <data_loading/data_loading>\n    IO <delira.io>\n    Logging <logging/logging>\n    Models <models/models>\n    Training <training/training>\n    Utilities <delira.utils>\n    Backend Resolution <backend_resolution>\n    Debug Mode <debug_mode>\n\n    Class Hierarchy Diagrams <class_hierarchy>"
  },
  {
    "path": "docs/_api/_build/delira/delira.utils.rst",
    "content": "Utils\n=====\n\nThis package provides utility functions such as image operations, various\ndecorators, path operations and time operations.\n\n.. automodule:: delira.utils.context_managers\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n.. automodule:: delira.utils.decorators\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n.. automodule:: delira.utils.imageops\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n.. automodule:: delira.utils.path\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n.. automodule:: delira.utils.time\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n.. automodule:: delira.utils.config\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n.. automodule:: delira.utils.messenger\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/_api/_build/delira/logging/backends.rst",
    "content": ".. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.logging.base_backend\n\n:hidden:`BaseBackend`\n~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: BaseBackend\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n.. currentmodule:: delira.logging.writer_backend\n\n:hidden:`WriterBackend`\n~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: WriterBackend\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n.. currentmodule:: delira.logging.tensorboard_backend\n\n:hidden:`TensorboardBackend`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: TensorboardBackend\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n.. currentmodule:: delira.logging.visdom_backend\n\n:hidden:`VisdomBackend`\n~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: VisdomBackend\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/_api/_build/delira/logging/base_logger.rst",
    "content": ".. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.logging.base_logger\n\n:hidden:`Logger`\n~~~~~~~~~~~~~~~~\n\n.. autoclass:: Logger\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n\n:hidden:`SingleThreadedLogger`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: SingleThreadedLogger\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`make_logger`\n~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: make_logger\n"
  },
  {
    "path": "docs/_api/_build/delira/logging/handlers.rst",
    "content": ".. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.logging\n\n:hidden:`MultiStreamHandler`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: MultiStreamHandler\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`TrixiHandler`\n~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: TrixiHandler\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/_api/_build/delira/logging/logging.rst",
    "content": "Logging\n=======\n\nThe logging module provides the utilities for logging arbitrary values to\ndifferent backends and a logger registry.\n\n.. toctree::\n\n    Logger <base_logger>\n    Logging Backends <backends>\n    Logging Context <logging_context>\n    Registry <registry>\n\n\n"
  },
  {
    "path": "docs/_api/_build/delira/logging/logging_context.py",
    "content": "from delira.logging.registry import logger_exists, register_logger, \\\n    unregister_logger, log as _log\nfrom delira.logging.base_logger import make_logger\n\nlog = _log\n\n\nclass LoggingContext(object):\n    \"\"\"\n    Contextmanager to set a new logging context\n    \"\"\"\n\n    def __init__(\n            self,\n            name,\n            initialize_if_missing=False,\n            destroy_on_exit=None,\n            **kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        name : str\n            the name of the logger to use\n        initialize_if_missing : bool\n            whether to create a logger if it does not yet exist\n        destroy_on_exit : bool\n            whether to destroy the logger on exit; If None, the logger will\n            only be destroyed, if it was created here\n        **kwargs:\n            additional keyword arguments to create a logger if necessary\n\n        Raises\n        ------\n        ValueError\n            if the logger does not exist already and shall not be created\n        \"\"\"\n\n        # Logger does exist already\n        if logger_exists(name):\n            self._name = name\n            if destroy_on_exit is None:\n                destroy_on_exit = False\n\n        # logger will be created\n        elif initialize_if_missing:\n            register_logger(make_logger(**kwargs), name)\n            if destroy_on_exit is None:\n                destroy_on_exit = True\n            self._name = name\n\n        # logger does not exist and shall not be created\n        else:\n            raise ValueError(\"No valid logger for name %s and \"\n                             \"'initialize_if_missing' is False\" % name)\n\n        self._destroy_on_exit = destroy_on_exit\n\n    def __enter__(self):\n        \"\"\"\n        Function to be executed during entrance;\n        Resets the logging context\n\n        Returns\n        -------\n        :class:`LoggingContext`\n            self\n   
     \"\"\"\n        global log\n        log = self.log\n        return self\n\n    def __exit__(self, *args):\n        \"\"\"\n        Function to be called during exiting the context manager;\n        Destroys the logger if necessary and resets the old logging context\n\n        Parameters\n        ----------\n        *args\n            Positional arguments to be compatible with other context managers\n\n        Returns\n        -------\n\n        \"\"\"\n        if self._destroy_on_exit:\n            _logger = unregister_logger(self._name)\n            del _logger\n\n        global log\n        log = _log\n\n    def log(self, msg: dict):\n        \"\"\"\n        Main logging function; decides whether to log with the assigned\n        backend or python's internal module\n\n        Parameters\n        ----------\n        msg : dict\n            the message to log; Should be a dict, where the keys indicate the\n            logging function to execute, and the corresponding value holds\n            the arguments necessary to execute this function\n        \"\"\"\n\n        _log(msg, self._name)\n\n    def __call__(self, log_message: dict):\n        \"\"\"\n        Makes the class callable and forwards the call to\n        :meth:`delira.logging.base_logger.Logger.log`\n\n        Parameters\n        ----------\n        log_message : dict\n            the logging message to log\n\n        Returns\n        -------\n        Any\n            the return value obtained by\n            :meth:`LoggingContext.log`\n\n        \"\"\"\n        return self.log(log_message)\n"
  },
  {
    "path": "docs/_api/_build/delira/logging/logging_context.rst",
    "content": ".. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.logging.logging_context\n\n:hidden:`LoggingContext`\n~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: LoggingContext\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n"
  },
  {
    "path": "docs/_api/_build/delira/logging/registry.py",
    "content": "from delira.logging.base_logger import Logger\nfrom collections import OrderedDict\n\n# Registry dict containing all registered available Loggers\n# Use OrderedDict here to use the last registered logger if no name was given\n_AVAILABLE_LOGGERS = OrderedDict()\n\n\ndef log(msg: dict, name=None):\n    \"\"\"\n    Global logging function\n\n    Parameters\n    ----------\n    msg : dict\n        the message to log; Should be a dict, where the keys indicate the\n        logging function to execute, and the corresponding value holds\n        the arguments necessary to execute this function\n    name : str\n        the name of the logger to use;\n        if None: the last registered logger will be used\n\n    Raises\n    ------\n    AssertionError\n        if the logger with the specified name does not exist\n    AssertionError\n        if the returned object is not a logger\n\n    Returns\n    -------\n    Any\n        the value obtained by the logger's ``log`` function\n\n    \"\"\"\n\n    # use the last registered logger if no name is given\n    if name is None:\n        name = get_available_loggers()[-1]\n\n    assert logger_exists(name)\n    _logger = get_logger(name)\n\n    assert isinstance(_logger, Logger)\n\n    return _logger.log(msg)\n\n\ndef logger_exists(name: str):\n    \"\"\"\n    Check if a logger exists\n\n    Parameters\n    ----------\n    name : str\n        the name to check the existence for\n\n    Returns\n    -------\n    bool\n        whether a logger with the given name exists\n\n    \"\"\"\n    return name in _AVAILABLE_LOGGERS\n\n\ndef register_logger(logger: Logger, name: str, overwrite=False):\n    \"\"\"\n    Register a new logger to the Registry\n\n    Parameters\n    ----------\n    logger : :class:`delira.logging.base_logger.Logger`\n        the logger to register\n    name : str\n        the corresponding name, to register the logger at\n    overwrite : bool\n        whether or not to overwrite existing loggers if necessary\n\n    Returns\n    
-------\n    :class:`delira.logging.base_logger.Logger`\n        the registered logger object\n\n    \"\"\"\n\n    if not logger_exists(name) or overwrite:\n        _AVAILABLE_LOGGERS[name] = logger\n\n    return get_logger(name)\n\n\ndef unregister_logger(name: str):\n    \"\"\"\n    Unregisters a logger from the registry\n\n    Parameters\n    ----------\n    name : str\n        the name of the logger to unregister\n\n    Returns\n    -------\n    :class:`delira.logging.base_logger.Logger`\n        the registered logger object\n    \"\"\"\n    return _AVAILABLE_LOGGERS.pop(name)\n\n\ndef get_logger(name):\n    \"\"\"\n    Returns a logger from the registry\n\n    Parameters\n    ----------\n    name : str\n        the name indicating the logger to return\n\n    Returns\n    -------\n    :class:`delira.logging.base_logger.Logger`\n        the specified logger object\n\n    \"\"\"\n    return _AVAILABLE_LOGGERS[name]\n\n\ndef get_available_loggers():\n    \"\"\"\n    Gets names for all registered loggers\n\n    Returns\n    -------\n    tuple\n        a tuple of strings specifying the names of all registered loggers\n\n    \"\"\"\n    return tuple(_AVAILABLE_LOGGERS.keys())\n"
  },
  {
    "path": "docs/_api/_build/delira/logging/registry.rst",
    "content": ".. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.logging.registry\n\n:hidden:`log`\n~~~~~~~~~~~~~\n\n.. autofunction:: log\n\n:hidden:`logger_exists`\n~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: logger_exists\n\n:hidden:`register_logger`\n~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: register_logger\n\n:hidden:`unregister_logger`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: unregister_logger\n\n:hidden:`get_logger`\n~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: get_logger\n\n:hidden:`get_available_loggers`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: get_available_loggers\n"
  },
  {
    "path": "docs/_api/_build/delira/logging/tensorboard_backend.py",
    "content": "import tensorboardX\nfrom threading import Event\nfrom queue import Queue\n\nfrom delira.logging.writer_backend import WriterLoggingBackend\n\n\nclass TensorboardBackend(WriterLoggingBackend):\n    \"\"\"\n    A Tensorboard logging backend\n    \"\"\"\n\n    def __init__(self, writer_kwargs=None,\n                 abort_event: Event = None, queue: Queue = None):\n        \"\"\"\n\n        Parameters\n        ----------\n        writer_kwargs : dict\n            arguments to initialize a writer\n        abort_event : :class:`threading.Event`\n            the abortion event\n        queue : :class:`queue.Queue`\n            the queue holding all logging tasks\n        \"\"\"\n\n        if writer_kwargs is None:\n            writer_kwargs = {}\n\n        super().__init__(tensorboardX.SummaryWriter, writer_kwargs,\n                         abort_event, queue)\n\n    def _call_exec_fn(self, exec_fn, args):\n        \"\"\"\n        Helper Function calling the actual mapped function and flushing\n        results to the writer afterwards\n\n        Parameters\n        ----------\n        exec_fn : function\n            the function which will execute the actual logging\n        args : iterable (listlike) or mapping (dictlike)\n            the arguments passed to the ``exec_fn``\n\n        Returns\n        -------\n        Any\n            the return value obtained by the ``exec_fn``\n\n        \"\"\"\n        ret_val = super()._call_exec_fn(exec_fn, args)\n\n        self._writer.file_writer.flush()\n\n        return ret_val\n\n    def __del__(self):\n        \"\"\"\n        Function to be executed at deletion;\n        Flushes all unsaved changes\n\n        \"\"\"\n        self._writer.file_writer.flush()\n\n    def _graph_pytorch(self, model, input_to_model=None, verbose=False,\n                       **kwargs):\n        \"\"\"\n        Function to log a PyTorch graph\n\n        Parameters\n        ----------\n        model : 
:class:`AbstractPyTorchNetwork`\n            the model, whose graph shall be logged\n        input_to_model : :class:`torch.Tensor`\n            the input to the model; necessary for graph traversal\n        verbose : bool\n            verbosity option\n        **kwargs :\n            additional keyword arguments\n\n        \"\"\"\n        converted_args, converted_kwargs = self.convert_to_npy(\n            model=model, input_to_model=input_to_model,\n            verbose=verbose, **kwargs)\n\n        self._writer.add_graph(*converted_args, **converted_kwargs)\n\n    def _graph_tf(self, graph, run_metadata=None):\n        \"\"\"\n        Function to log a TensorFlow Graph\n\n        Parameters\n        ----------\n        graph : :class:`tensorflow.Graph` or :class:`tensorflow.GraphDef`\n        run_metadata :\n            the run metadata\n\n        Raises\n        ------\n        TypeError\n            if given graph cannot be converted to graphdef\n\n        \"\"\"\n        import tensorflow as tf\n        from tensorboardX.proto.event_pb2 import Event, TaggedRunMetadata\n\n        # convert to graphdef\n        if isinstance(graph, tf.Graph):\n            graphdef = graph.as_graph_def()\n        elif isinstance(graph, tf.GraphDef):\n            graphdef = graph\n        elif hasattr(graph, \"SerializeToString\"):\n            graphdef = graph\n        else:\n            raise TypeError(\"Invalid type given for graph: %s\" %\n                            graph.__class__.__name__)\n\n        if run_metadata:\n            run_metadata = TaggedRunMetadata(\n                tag='step1', run_metadata=run_metadata.SerializeToString())\n\n        self._writer._get_file_writer().add_event(\n            Event(\n                graph_def=graphdef.SerializeToString(),\n                tagged_run_metadata=run_metadata))\n\n    def _graph_onnx(self, prototxt):\n        \"\"\"\n        Function to log a ONNX graph to file\n\n        Parameters\n        ----------\n        
prototxt : str\n            filepath to a given prototxt file containing an ONNX graph\n\n        \"\"\"\n        converted_args, converted_kwargs = self.convert_to_npy(\n            prototxt=prototxt)\n        self._writer.add_onnx_graph(*converted_args, **converted_kwargs)\n\n    def _embedding(self, mat, metadata=None, label_img=None, global_step=None,\n                   tag='default', metadata_header=None):\n        \"\"\"\n        Function to create an embedding of given data\n\n        Parameters\n        ----------\n        mat : array-like\n            an arraylike object, which can be converted to a numpy array;\n            holds the actual embedding value\n        metadata :\n            the embeddings metadata\n        label_img : array-like\n            an arraylike object, which can be converted to a numpy array;\n            holds the label image\n        global_step : int\n            the global step\n        tag : str\n            the tag to store the embedding at\n        metadata_header :\n            the metadata header\n\n        \"\"\"\n        converted_args, converted_kwargs = self.convert_to_npy(\n            mat=mat, metadata=metadata, label_img=label_img,\n            global_step=global_step\n        )\n        self._writer.add_embedding(*converted_args, **converted_kwargs)\n\n    def _scalars(self, main_tag: str, tag_scalar_dict: dict, global_step=None,\n                 walltime=None, sep=\"/\"):\n        \"\"\"\n        Function to log multiple scalars at once. 
In contrast to the base\n        function, this is done sequentially rather than in parallel to avoid\n        creating new event files\n\n        Parameters\n        ----------\n        main_tag : str\n            the main tag, will be combined with the subtags inside the\n            ``tag_scalar_dict``\n        tag_scalar_dict : dict\n            dictionary of (key, scalar) pairs\n        global_step : int\n            the global step\n        walltime :\n            the overall time\n        sep : str\n            the character separating the main tag and subtag in the final tag\n\n        \"\"\"\n\n        # log scalars sequentially\n        for key, val in tag_scalar_dict.items():\n            # combine tags\n            new_tag = main_tag + sep + key\n            self._scalar(new_tag, val, global_step=global_step,\n                         walltime=walltime)\n\n    @property\n    def name(self):\n        return \"TensorFlow Backend\"\n"
  },
  {
    "path": "docs/_api/_build/delira/logging/visdom_backend.py",
    "content": "import tensorboardX\nfrom threading import Event\nfrom queue import Queue\n\nfrom delira.logging.writer_backend import WriterLoggingBackend\n\n\nclass VisdomBackend(WriterLoggingBackend):\n    \"\"\"\n    A Visdom Logging backend\n    \"\"\"\n\n    def __init__(self, writer_kwargs: dict = None,\n                 abort_event: Event = None, queue: Queue = None):\n        \"\"\"\n\n        Parameters\n        ----------\n        writer_kwargs : dict\n            arguments to initialize a writer\n        abort_event : :class:`threading.Event`\n            the abortion event\n        queue : :class:`queue.Queue`\n            the queue holding all logging tasks\n        \"\"\"\n\n        if writer_kwargs is None:\n            writer_kwargs = {}\n\n        super().__init__(\n            tensorboardX.visdom_writer.VisdomWriter,\n            writer_kwargs,\n            abort_event,\n            queue)\n\n    @property\n    def name(self):\n        return \"VisdomBackend\"\n"
  },
  {
    "path": "docs/_api/_build/delira/logging/writer_backend.py",
"content": "\nfrom delira.logging.base_backend import BaseBackend\nfrom queue import Queue\nfrom threading import Event\n\n\nclass WriterLoggingBackend(BaseBackend):\n    \"\"\"\n    A Basic Writer Backend for an unspecified writer class\n    \"\"\"\n\n    def __init__(self, writer_cls, writer_kwargs: dict,\n                 abort_event: Event = None, queue: Queue = None):\n        super().__init__(abort_event, queue)\n\n        self._writer = writer_cls(**writer_kwargs)\n\n    @staticmethod\n    def convert_to_npy(*args, **kwargs):\n        \"\"\"\n        Function to convert all positional args and keyword args to numpy\n        (returns the identity by default, but can be overridden in subclasses\n        to log more complex types)\n\n        Parameters\n        ----------\n        *args :\n            positional arguments of arbitrary number and type\n        **kwargs :\n            keyword arguments of arbitrary number and type\n\n        Returns\n        -------\n        tuple\n            converted positional arguments\n        dict\n            converted keyword arguments\n        \"\"\"\n        return args, kwargs\n\n    def _image(self, tag, img_tensor, global_step=None, walltime=None,\n               dataformats='CHW'):\n        \"\"\"\n        Function to log a single image\n\n        Parameters\n        ----------\n        tag : str\n            the tag to store the image at\n        img_tensor : array-like\n            an array-like object containing the actual image; must be\n            convertible to numpy\n        global_step : int\n            the global step\n        walltime :\n            the overall time\n        dataformats : str\n            string specifying the image format\n\n        \"\"\"\n        converted_args, converted_kwargs = self.convert_to_npy(\n            tag=tag, img_tensor=img_tensor, global_step=global_step,\n            walltime=walltime, dataformats=dataformats)\n\n        self._writer.add_image(*converted_args, 
**converted_kwargs)\n\n    def _images(self, tag, img_tensor, global_step=None, walltime=None,\n                dataformats='NCHW'):\n        \"\"\"\n        Function to log multiple images\n\n        Parameters\n        ----------\n        tag : str\n            the tag to store the images at\n        img_tensor : array-like\n            an array-like object containing the actual images; must be\n            convertible to numpy\n        global_step : int\n            the global step\n        walltime :\n            the overall time\n        dataformats : str\n            string specifying the image format\n\n        \"\"\"\n        converted_args, converted_kwargs = self.convert_to_npy(\n            tag=tag, img_tensor=img_tensor, global_step=global_step,\n            walltime=walltime, dataformats=dataformats)\n\n        self._writer.add_images(*converted_args, **converted_kwargs)\n\n    def _image_with_boxes(self, tag, img_tensor, box_tensor, global_step=None,\n                          walltime=None, dataformats='CHW', **kwargs):\n        \"\"\"\n        Function to log a single image with bounding boxes\n\n        Parameters\n        ----------\n        tag : str\n            the tag to store the image at\n        img_tensor : array-like\n            an array-like object containing the actual image; must be\n            convertible to numpy\n        box_tensor : array-like\n            an array-like object containing the actual bounding boxes in xyxy\n            format; must be convertible to numpy\n        global_step : int\n            the global step\n        walltime :\n            the overall time\n        dataformats : str\n            string specifying the image format\n        **kwargs :\n            additional keyword arguments\n\n        \"\"\"\n        converted_args, converted_kwargs = self.convert_to_npy(\n            tag=tag, img_tensor=img_tensor, box_tensor=box_tensor,\n            global_step=global_step, walltime=walltime,\n            
dataformats=dataformats, **kwargs)\n\n        self._writer.add_image_with_boxes(*converted_args, **converted_kwargs)\n\n    def _scalar(self, tag, scalar_value, global_step=None, walltime=None):\n        \"\"\"\n        Function to log a single scalar value\n\n        Parameters\n        ----------\n        tag : str\n            the tag to store the scalar value at\n        scalar_value : int or float\n            the scalar value to log\n        global_step : int\n            the global step\n        walltime :\n            the overall time\n\n        \"\"\"\n        converted_args, converted_kwargs = self.convert_to_npy(\n            tag=tag, scalar_value=scalar_value, global_step=global_step,\n            walltime=walltime)\n        self._writer.add_scalar(*converted_args, **converted_kwargs)\n\n    def _scalars(self, main_tag, tag_scalar_dict, global_step=None,\n                 walltime=None):\n        \"\"\"\n        Function to log multiple scalars\n\n        Parameters\n        ----------\n        main_tag : str\n            the main tag to store the scalars at\n        tag_scalar_dict : dict\n            a dictionary containing tags as keys and the corresponding scalar\n            values\n        global_step : int\n            the global step\n        walltime :\n            the overall time\n\n        \"\"\"\n        converted_args, converted_kwargs = self.convert_to_npy(\n            main_tag=main_tag, tag_scalar_dict=tag_scalar_dict,\n            global_step=global_step, walltime=walltime)\n\n        self._writer.add_scalars(*converted_args, **converted_kwargs)\n\n    def _histogram(self, tag, values, global_step=None, bins='tensorflow',\n                   walltime=None):\n        \"\"\"\n        Function to create and log a histogram out of given values\n\n        Parameters\n        ----------\n        tag : str\n            the tag to store the histogram at\n        values : arraylike\n            an arraylike object containing the raw data to create a 
histogram\n            from; must be convertible to numpy\n        global_step : int\n            the global step\n        bins : str\n            string indicating the bins format\n        walltime :\n            the overall time\n\n        \"\"\"\n        converted_args, converted_kwargs = self.convert_to_npy(\n            tag=tag, values=values, global_step=global_step, bins=bins,\n            walltime=walltime)\n        self._writer.add_histogram(*converted_args, **converted_kwargs)\n\n    def _figure(self, tag, figure, global_step=None, close=True,\n                walltime=None):\n        \"\"\"\n        Function to log a ``matplotlib.pyplot`` figure\n\n        Parameters\n        ----------\n        tag : str\n            the tag to store the figure at\n        figure : :class:`matplotlib.pyplot.Figure`\n            the figure to log\n        global_step : int\n            the global step\n        close : bool\n            whether to close the figure after pushing it\n        walltime :\n            the overall time\n\n        \"\"\"\n        converted_args, converted_kwargs = self.convert_to_npy(\n            tag=tag, figure=figure, global_step=global_step, close=close,\n            walltime=walltime)\n        self._writer.add_figure(*converted_args, **converted_kwargs)\n\n    def _audio(self, tag, snd_tensor, global_step=None, sample_rate=44100,\n               walltime=None):\n        \"\"\"\n        Function to log a single audio signal\n\n        Parameters\n        ----------\n        tag : str\n            the tag to store the sound signal at\n        snd_tensor : arraylike\n            arraylike object containing the sound signal;\n            must be convertible to numpy\n        global_step : int\n            the global step\n        sample_rate : int\n            the sampling rate for the sound signal\n        walltime :\n            the overall time\n\n        \"\"\"\n        converted_args, converted_kwargs = self.convert_to_npy(\n            tag=tag, snd_tensor=snd_tensor, 
global_step=global_step,\n            sample_rate=sample_rate, walltime=walltime\n        )\n        self._writer.add_audio(*converted_args, **converted_kwargs)\n\n    def _text(self, tag, text_string, global_step=None, walltime=None):\n        \"\"\"\n        Function to log a single string as text\n\n        Parameters\n        ----------\n        tag : str\n            the tag to store the text at\n        text_string : str\n            the text string to log\n        global_step : int\n            the global step\n        walltime :\n            the overall time\n\n        \"\"\"\n        converted_args, converted_kwargs = self.convert_to_npy(\n            tag=tag, text_string=text_string, global_step=global_step,\n            walltime=walltime)\n        self._writer.add_text(*converted_args, **converted_kwargs)\n\n    def _pr_curve(self, tag, labels, predictions, global_step=None,\n                  num_thresholds=127, weights=None, walltime=None):\n        \"\"\"\n        Function to create and log a PR curve out of given predictions and\n        labels\n\n        Parameters\n        ----------\n        tag : str\n            the tag to store the curve at\n        labels : arraylike\n            arraylike object containing the groundtruth data; must be\n            convertible to numpy\n        predictions : arraylike\n            arraylike object containing the predictions; must be convertible\n            to numpy\n        global_step : int\n            the global step\n        num_thresholds : int\n            number of thresholds to apply for PR calculation\n        weights : arraylike\n            arraylike object containing sample weights, must be convertible to\n            numpy\n        walltime :\n            the overall time\n\n        \"\"\"\n        converted_args, converted_kwargs = self.convert_to_npy(\n            tag=tag, labels=labels, predictions=predictions,\n            global_step=global_step, num_thresholds=num_thresholds,\n            
weights=weights, walltime=walltime)\n        self._writer.add_pr_curve(*converted_args, **converted_kwargs)\n\n    def _video(self, tag, vid_tensor, global_step=None, fps=4, walltime=None):\n        \"\"\"\n        Function to log a single video\n\n        Parameters\n        ----------\n        tag : str\n            the tag to store the video at\n        vid_tensor : arraylike\n            arraylike object containing the video frames; must be convertible\n            to numpy\n        global_step : int\n            the global step\n        fps : int\n            frames per second to display\n        walltime :\n            the overall time\n\n        \"\"\"\n        converted_args, converted_kwargs = self.convert_to_npy(\n            tag=tag, vid_tensor=vid_tensor, global_step=global_step, fps=fps,\n            walltime=walltime)\n        self._writer.add_video(*converted_args, **converted_kwargs)\n\n    @property\n    def name(self):\n        return \"WriterBackend\"\n"
  },
  {
    "path": "docs/_api/_build/delira/models/chainer.rst",
"content": ".. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.models.backends.chainer\n\nChainer\n.......\n\n\n:hidden:`AbstractChainerNetwork`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: AbstractChainerNetwork\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`DataParallelChainerNetwork`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: DataParallelChainerNetwork\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`DataParallelChainerOptimizer`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: DataParallelChainerOptimizer\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`DataParallelOptimizerUpdateModelParameters`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: DataParallelOptimizerUpdateModelParameters\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`ParallelOptimizerCumulateGradientsHook`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: ParallelOptimizerCumulateGradientsHook\n    :members:\n    :undoc-members:\n    :show-inheritance:"
  },
  {
    "path": "docs/_api/_build/delira/models/models.rst",
"content": ".. role:: hidden\n    :class: hidden-section\n\nModels\n======\n\n``delira`` comes with its own model-structure tree - with\n:class:`AbstractNetwork` at its root - and integrates\nseveral backends deeply into its structure.\n\n.. currentmodule:: delira.models\n\n:hidden:`AbstractNetwork`\n~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: AbstractNetwork(type)\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\nBackends\n--------\n\n.. toctree::\n\n    Chainer <chainer>\n    SciKit-Learn <sklearn>\n    TensorFlow Eager Execution <tfeager>\n    TensorFlow Graph Execution <tfgraph>\n    PyTorch <torch>\n    TorchScript <torchscript>\n"
  },
  {
    "path": "docs/_api/_build/delira/models/sklearn.rst",
    "content": ".. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.models.backends.sklearn\n\nSciKit-Learn\n............\n\n:hidden:`SklearnEstimator`\n~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: SklearnEstimator\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/_api/_build/delira/models/tfeager.rst",
    "content": ".. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.models.backends.tf_eager\n\nTensorFlow Eager Execution\n..........................\n\n:hidden:`AbstractTfEagerNetwork`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: AbstractTfEagerNetwork\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`DataParallelTfEagerNetwork`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: DataParallelTfEagerNetwork\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/_api/_build/delira/models/tfgraph.rst",
    "content": ".. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.models.backends.tf_graph\n\nTensorFlow Graph Execution\n..........................\n\n:hidden:`AbstractTfGraphNetwork`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: AbstractTfGraphNetwork\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/_api/_build/delira/models/torch.rst",
    "content": ".. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.models.backends.torch\n\nPyTorch\n.......\n\n:hidden:`AbstractPyTorchNetwork`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: AbstractPyTorchNetwork\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`DataParallelPyTorchNetwork`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: DataParallelPyTorchNetwork\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`scale_loss`\n~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: scale_loss\n"
  },
  {
    "path": "docs/_api/_build/delira/models/torchscript.rst",
    "content": ".. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.models.backends.torchscript\n\nTorchScript\n...........\n\n:hidden:`AbstractTorchScriptNetwork`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: AbstractTorchScriptNetwork\n    :members:\n    :undoc-members:\n    :show-inheritance:"
  },
  {
    "path": "docs/_api/_build/delira/training/backends/backends.rst",
"content": "Backends\n========\n\nThe following section contains all backends which are implemented,\ndeveloped and maintained for usage with ``delira``.\n\nA single backend usually contains at least a trainer, an experiment and some models (which are encapsulated in the\n`models <../../models/models>`_ section).\n\n.. toctree::\n\n    Chainer <chainer>\n    SciKit-Learn <sklearn>\n    TensorFlow Eager Execution <tfeager>\n    TensorFlow Graph Execution <tfgraph>\n    PyTorch <torch>\n    TorchScript <torchscript>"
  },
  {
    "path": "docs/_api/_build/delira/training/backends/chainer.rst",
    "content": "Chainer\n.......\n\n.. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.training.backends.chainer\n\n:hidden:`ChainerNetworkTrainer`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: ChainerNetworkTrainer\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`ChainerExperiment`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: ChainerExperiment\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`convert_chainer_to_numpy`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: convert_chainer_to_numpy\n\n:hidden:`create_chainer_optims_default`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: create_chainer_optims_default"
  },
  {
    "path": "docs/_api/_build/delira/training/backends/sklearn.rst",
    "content": "SciKit-Learn\n............\n\n.. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.training.backends.sklearn\n\n:hidden:`SklearnEstimatorTrainer`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: SklearnEstimatorTrainer\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`SklearnExperiment`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: SklearnExperiment\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`create_sklearn_optims_default`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: create_sklearn_optims_default\n"
  },
  {
    "path": "docs/_api/_build/delira/training/backends/tfeager.rst",
    "content": "TensorFlow Eager Execution\n..........................\n\n.. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.training.backends.tf_eager\n\n:hidden:`TfEagerNetworkTrainer`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: TfEagerNetworkTrainer\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`TfEagerExperiment`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: TfEagerExperiment\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`create_tfeager_optims_default`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: create_tfeager_optims_default\n\n:hidden:`convert_tfeager_to_numpy`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: convert_tfeager_to_numpy"
  },
  {
    "path": "docs/_api/_build/delira/training/backends/tfgraph.rst",
    "content": "TensorFlow Graph Execution\n..........................\n\n.. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.training.backends.tf_graph\n\n:hidden:`TfGraphNetworkTrainer`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: TfGraphNetworkTrainer\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`TfGraphExperiment`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: TfGraphExperiment\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`initialize_uninitialized`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: initialize_uninitialized\n"
  },
  {
    "path": "docs/_api/_build/delira/training/backends/torch.rst",
    "content": "PyTorch\n.......\n\n.. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.training.backends.torch\n\n:hidden:`PyTorchNetworkTrainer`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: PyTorchNetworkTrainer\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`PyTorchExperiment`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: PyTorchExperiment\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`create_pytorch_optims_default`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: create_pytorch_optims_default\n\n:hidden:`convert_torch_to_numpy`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: convert_torch_to_numpy\n"
  },
  {
    "path": "docs/_api/_build/delira/training/backends/torchscript.rst",
    "content": "TorchScript\n...........\n\n.. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.training.backends.torchscript\n\n:hidden:`TorchScriptNetworkTrainer`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: TorchScriptNetworkTrainer\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`TorchScriptExperiment`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: TorchScriptExperiment\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/_api/_build/delira/training/callbacks.rst",
"content": ".. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.training.callbacks\n\nCallbacks\n=========\n\nCallbacks are essential to provide a uniform API for tasks like early stopping\netc.\nThe PyTorch learning rate schedulers are also implemented as callbacks.\nEvery callback should be derived from :class:`AbstractCallback` and must\nprovide the methods ``at_epoch_begin``\nand ``at_epoch_end``.\n\n:hidden:`AbstractCallback`\n~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: AbstractCallback\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`EarlyStopping`\n~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: EarlyStopping\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`DefaultPyTorchSchedulerCallback`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: DefaultPyTorchSchedulerCallback\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n.. currentmodule:: delira.training.callbacks.pytorch_schedulers\n\n:hidden:`CosineAnnealingLRCallback`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: CosineAnnealingLRCallback\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`ExponentialLRCallback`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: ExponentialLRCallback\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`LambdaLRCallback`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: LambdaLRCallback\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`MultiStepLRCallback`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: MultiStepLRCallback\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`ReduceLROnPlateauCallback`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: ReduceLROnPlateauCallback\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`StepLRCallback`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. 
autoclass:: StepLRCallback\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n.. currentmodule:: delira.training.callbacks\n\n:hidden:`CosineAnnealingLRCallbackPyTorch`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: CosineAnnealingLRCallbackPyTorch\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`ExponentialLRCallbackPyTorch`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: ExponentialLRCallbackPyTorch\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`LambdaLRCallbackPyTorch`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: LambdaLRCallbackPyTorch\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`MultiStepLRCallbackPyTorch`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: MultiStepLRCallbackPyTorch\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`ReduceLROnPlateauCallbackPyTorch`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: ReduceLROnPlateauCallbackPyTorch\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`StepLRCallbackPyTorch`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: StepLRCallbackPyTorch\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n"
  },
  {
    "path": "docs/_api/_build/delira/training/experiment.rst",
"content": ".. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.training\n\nExperiments\n===========\n\nExperiments are the outermost classes to control your training; they wrap your\nNetworkTrainer and provide utilities for\ncross-validation. More experiments can be found in the sections for the\nspecific backends.\n\n:hidden:`BaseExperiment`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: BaseExperiment\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n"
  },
  {
    "path": "docs/_api/_build/delira/training/losses.rst",
    "content": ".. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.training.losses\n\nCustom Loss Functions\n=====================\n\n:hidden:`BCEFocalLossPyTorch`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: BCEFocalLossPyTorch\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`BCEFocalLossLogitPyTorch`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: BCEFocalLossLogitPyTorch\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/_api/_build/delira/training/metrics.rst",
    "content": ".. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.training.metrics\n\nMetrics\n=======\n\n:hidden:`SklearnClassificationMetric`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: SklearnClassificationMetric\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`SklearnAccuracyScore`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: SklearnAccuracyScore\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`SklearnBalancedAccuracyScore`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: SklearnBalancedAccuracyScore\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`SklearnF1Score`\n~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: SklearnF1Score\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`SklearnFBetaScore`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: SklearnFBetaScore\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`SklearnHammingLoss`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: SklearnHammingLoss\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`SklearnJaccardSimilarityScore`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: SklearnJaccardSimilarityScore\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`SklearnLogLoss`\n~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: SklearnLogLoss\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`SklearnMatthewsCorrCoeff`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: SklearnMatthewsCorrCoeff\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`SklearnPrecisionScore`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: SklearnPrecisionScore\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`SklearnRecallScore`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. 
autoclass:: SklearnRecallScore\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`SklearnZeroOneLoss`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: SklearnZeroOneLoss\n    :members:\n    :undoc-members:\n    :show-inheritance:\n\n:hidden:`AurocMetric`\n~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: AurocMetric\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/_api/_build/delira/training/parameters.rst",
    "content": ".. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.training\n\nParameters\n===============\n\n:hidden:`Parameters`\n~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: Parameters\n    :members:\n    :undoc-members:\n    :show-inheritance:"
  },
  {
    "path": "docs/_api/_build/delira/training/predictor.rst",
"content": ".. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.training\n\n\nPredictor\n=========\n\nThe predictor implements the basic prediction and metric calculation routines\nand can be subclassed for special routines.\nIt is also the base class of all the trainers, which extend its functionality\nwith training routines.\n\n:hidden:`Predictor`\n~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: Predictor\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/_api/_build/delira/training/trainer.rst",
"content": ".. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.training\n\n\nNetworkTrainer\n==============\nThe network trainer implements the actual training routine and can be\nsubclassed for special routines. More specific trainers can be found in the\nbackend-specific sections.\n\n:hidden:`BaseNetworkTrainer`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autoclass:: BaseNetworkTrainer\n    :members:\n    :undoc-members:\n    :show-inheritance:\n"
  },
  {
    "path": "docs/_api/_build/delira/training/training.rst",
    "content": "Training\n========\nThe training subpackage implements Callbacks, a class for Hyperparameters,\ntraining routines and wrapping experiments.\n\n.. toctree::\n\n    Parameters <parameters>\n    Network Trainer <trainer>\n    Predictor <predictor>\n    Experiment <experiment>\n    Backends <backends/backends>\n    Callbacks <callbacks>\n    Losses <losses>\n    Metrics <metrics>\n    Utilities <utils>\n"
  },
  {
    "path": "docs/_api/_build/delira/training/utils.rst",
"content": ".. role:: hidden\n    :class: hidden-section\n\n.. currentmodule:: delira.training.utils\n\n\nUtilities\n.........\n\n:hidden:`recursively_convert_elements`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: recursively_convert_elements\n\n:hidden:`convert_to_numpy_identity`\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n.. autofunction:: convert_to_numpy_identity\n"
  },
  {
    "path": "docs/_api/_build/modules.rst",
    "content": "API Documentation\n=================\n\n.. toctree::\n   :maxdepth: 10\n\n   delira/delira\n"
  },
  {
    "path": "docs/classification_pytorch.rst",
"content": "\nClassification with Delira - A very short introduction\n======================================================\n\n*Author: Justus Schock*\n\n*Date: 04.12.2018*\n\nThis example shows how to set up a basic classification PyTorch\nexperiment and Visdom logging environment.\n\nLet's first set up the essential hyperparameters. We will use\n``delira``'s ``Parameters``-class for this:\n\n.. code:: ipython3\n\n    logger = None\n    import torch\n    from delira.training import Parameters\n    params = Parameters(fixed_params={\n        \"model\": {\n            \"in_channels\": 1, \n            \"n_outputs\": 10\n        },\n        \"training\": {\n            \"batch_size\": 64, # batchsize to use\n            \"num_epochs\": 10, # number of epochs to train\n            \"optimizer_cls\": torch.optim.Adam, # optimization algorithm to use\n            \"optimizer_params\": {'lr': 1e-3}, # initialization parameters for this algorithm\n            \"losses\": {\"CE\": torch.nn.CrossEntropyLoss()}, # the loss function\n            \"lr_sched_cls\": None,  # the learning rate scheduling algorithm to use\n            \"lr_sched_params\": {}, # the corresponding initialization parameters\n            \"metrics\": {} # and some evaluation metrics\n        }\n    }) \n\nSince we did not specify any metric, only the ``CrossEntropyLoss`` will\nbe calculated for each batch. As we have a classification task, this\nshould be sufficient. We will train our network with a batch size of 64,\nusing ``Adam`` as the optimizer of choice.\n\nLogging and Visualization\n-------------------------\n\nTo get a visualization of our results, we should monitor them somehow.\nFor logging we will use ``Visdom``. To start a visdom server you need to\nexecute the following command inside an environment which has visdom\ninstalled:\n\n.. 
code:: shell\n\n    visdom -port=9999\n\nThis will start a visdom server on port 9999 of your machine; now we\ncan configure our logging environment. To view your results, you\ncan open http://localhost:9999 in your browser.\n\n.. code:: ipython3\n\n    from trixi.logger import PytorchVisdomLogger\n    from delira.logging import TrixiHandler\n    import logging\n    \n    logger_kwargs = {\n        'name': 'ClassificationExampleLogger', # name of our logging environment\n        'port': 9999 # port on which our visdom server is alive\n    }\n    \n    logger_cls = PytorchVisdomLogger\n    \n    # configure logging module (and root logger)\n    logging.basicConfig(level=logging.INFO,\n                        handlers=[TrixiHandler(logger_cls, **logger_kwargs)])\n    \n    \n    # derive the logger from the root logger\n    # (don't do `logger = logging.Logger(\"...\")`, since this would create a new\n    # logger which is unrelated to the root logger)\n    logger = logging.getLogger(\"Test Logger\")\n    \n\nSince a single visdom server can run multiple environments, we need to\nspecify a (unique) name for our environment and need to tell the logger\non which port it can find the visdom server.\n\nData Preparation\n----------------\n\nLoading\n~~~~~~~\n\nNext we will create a small train and validation set (based on\n``torchvision`` MNIST):\n\n.. 
code:: ipython3\n\n    from delira.data_loading import TorchvisionClassificationDataset\n    \n    dataset_train = TorchvisionClassificationDataset(\"mnist\", # which dataset to use\n                                                     train=True, # use trainset\n                                                     img_shape=(224, 224) # resample to 224 x 224 pixels\n                                                    )\n    dataset_val = TorchvisionClassificationDataset(\"mnist\", \n                                                   train=False,\n                                                   img_shape=(224, 224)\n                                                  )\n\nAugmentation\n~~~~~~~~~~~~\n\nFor Data-Augmentation we will apply a few transformations:\n\n.. code:: ipython3\n\n    from batchgenerators.transforms import RandomCropTransform, \\\n                                            ContrastAugmentationTransform, Compose\n    from batchgenerators.transforms.spatial_transforms import ResizeTransform\n    from batchgenerators.transforms.sample_normalization_transforms import MeanStdNormalizationTransform\n    \n    transforms = Compose([\n        RandomCropTransform(200), # Perform Random Crops of Size 200 x 200 pixels\n        ResizeTransform(224), # Resample these crops back to 224 x 224 pixels\n        ContrastAugmentationTransform(), # randomly adjust contrast\n        MeanStdNormalizationTransform(mean=[0.5], std=[0.5])]) \n    \n    \n\nWith these transformations we can now wrap our datasets into\ndatamanagers:\n\n.. 
code:: ipython3\n\n    from delira.data_loading import DataManager, SequentialSampler, RandomSampler\n    \n    manager_train = DataManager(dataset_train, params.nested_get(\"batch_size\"),\n                                    transforms=transforms,\n                                    sampler_cls=RandomSampler,\n                                    n_process_augmentation=4)\n    \n    manager_val = DataManager(dataset_val, params.nested_get(\"batch_size\"),\n                                  transforms=transforms,\n                                  sampler_cls=SequentialSampler,\n                                  n_process_augmentation=4)\n    \n\nTraining\n--------\n\nAfter we have done that, we can finally specify our experiment and run\nit. We will therefore use the already implemented\n``ClassificationNetworkBasePyTorch``, which is essentially a ResNet-18:\n\n.. code:: ipython3\n\n    import warnings\n    warnings.simplefilter(\"ignore\", UserWarning) # ignore UserWarnings raised by dependency code\n    warnings.simplefilter(\"ignore\", FutureWarning) # ignore FutureWarnings raised by dependency code\n    \n    \n    from delira.training import PyTorchExperiment\n    from delira.training.train_utils import create_optims_default_pytorch\n    from delira.models.classification import ClassificationNetworkBasePyTorch\n    \n    if logger is not None:\n        logger.info(\"Init Experiment\")\n    experiment = PyTorchExperiment(params, ClassificationNetworkBasePyTorch,\n                                   name=\"ClassificationExample\",\n                                   save_path=\"./tmp/delira_Experiments\",\n                                   optim_builder=create_optims_default_pytorch,\n                                   gpu_ids=[0])\n    experiment.save()\n    \n    model = experiment.run(manager_train, manager_val)\n\nCongratulations, you have now trained your first classification model\nusing ``delira``. We will now predict a few samples from the test set to\nshow 
that the network's predictions are valid:\n\n.. code:: ipython3\n\n    import numpy as np\n    from tqdm.auto import tqdm # utility for progress bars\n    \n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\") # set device (use GPU if available)\n    model = model.to(device) # push model to device\n    preds, labels = [], []\n    \n    with torch.no_grad():\n        for i in tqdm(range(len(dataset_val))):\n            img = dataset_val[i][\"data\"] # get image from current batch\n            img_tensor = torch.from_numpy(img).unsqueeze(0).to(device).to(torch.float) # create a tensor from image, push it to device and add batch dimension\n            pred_tensor = model(img_tensor) # feed it through the network\n            pred = pred_tensor.argmax(1).item() # get index with maximum class confidence\n            label = np.asscalar(dataset_val[i][\"label\"]) # get label from batch\n            if i % 1000 == 0:\n                print(\"Prediction: %d \\t label: %d\" % (pred, label)) # print result\n            preds.append(pred)\n            labels.append(label)\n            \n    # calculate accuracy\n    accuracy = (np.asarray(preds) == np.asarray(labels)).sum() / len(preds)\n    print(\"Accuracy: %.3f\" % accuracy)\n\nSee Also\n--------\n\nFor a more detailed explanation, have a look at:\n\n* `the introduction tutorial <tutorial_delira.ipynb>`__\n* `the 2d segmentation example <segmentation_2d_pytorch.ipynb>`__\n* `the 3d segmentation example <segmentation_3d_pytorch.ipynb>`__\n* `the generative adversarial example <gan_pytorch.ipynb>`__\n"
  },
  {
    "path": "docs/conda.yml",
    "content": "name: delira-docs\ndependencies:\n  - python=3.7\n  - pip:\n    - sphinx==1.8.4\n    - sphinx-rtd-theme\n"
  },
  {
    "path": "docs/conf.py",
    "content": "# -*- coding: utf-8 -*-\n#\n# Configuration file for the Sphinx documentation builder.\n#\n# This file does only contain a selection of the most common options. For a\n# full list see the documentation:\n# http://www.sphinx-doc.org/en/master/config\n\n# -- Path setup --------------------------------------------------------------\n\n# If extensions (or modules to document with autodoc) are in another directory,\n# add these directories to sys.path here. If the directory is relative to the\n# documentation root, use os.path.abspath to make it absolute, like shown here.\n#\n\nfrom delira._version import get_versions\nimport os\nimport sys\nimport re\n\n# source code directory, relative to this file, for sphinx-build\nsys.path.insert(0, os.path.join(os.path.dirname(__file__), os.path.pardir))\n\n# -- Project information -----------------------------------------------------\n\nproject = 'delira'\ncopyright = '2019, Justus Schock, Michael Baumgartner, Oliver Rippel, Christoph Haarburger'\nauthor = 'Justus Schock, Michael Baumgartner, Oliver Rippel, Christoph Haarburger'\n\n\ndef read_file(file):\n    with open(file) as f:\n        content = f.read()\n    return content\n\n\nwhole_version = get_versions()[\"version\"]\n# The short X.Y version\nversion = whole_version.split(\"+\", 1)[0]\n# The full version, including alpha/beta/rc tags\nrelease = whole_version  # delira.__version__\n\n\n# -- General configuration ---------------------------------------------------\n\n# If your documentation needs a minimal Sphinx version, state it here.\n#\n# needs_sphinx = '1.0'\n\n# Add any Sphinx extension module names here, as strings. 
They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\n# ones.\nextensions = [\n    'sphinx.ext.autodoc',\n    'sphinx.ext.autosummary',\n    'sphinx.ext.intersphinx',\n    'sphinx.ext.ifconfig',\n    'sphinx.ext.viewcode',\n    'sphinx.ext.githubpages',\n    'sphinx.ext.napoleon',\n    'sphinx.ext.inheritance_diagram',\n    'sphinx.ext.autosectionlabel',\n]\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = ['_templates']\n\n# The suffix(es) of source filenames.\n# You can specify multiple suffix as a list of string:\n#\n# source_suffix = ['.rst', '.md']\nsource_suffix = '.rst'\n\n# The master toctree document.\nmaster_doc = 'index'\n\n# The language for content autogenerated by Sphinx. Refer to documentation\n# for a list of supported languages.\n#\n# This is also used if you do content translation via gettext catalogs.\n# Usually you set \"language\" from the command line for these cases.\nlanguage = None\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\n# This pattern also affects html_static_path and html_extra_path .\nexclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']\n\n# The name of the Pygments (syntax highlighting) style to use.\npygments_style = 'sphinx'\n\n\n# -- Options for HTML output -------------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  See the documentation for\n# a list of builtin themes.\n#\nhtml_theme = 'sphinx_rtd_theme'\n\n# Theme options are theme-specific and customize the look and feel of a theme\n# further.  For a list of options available for each theme, see the\n# documentation.\n#\nhtml_theme_options = {\n    \"collapse_navigation\": False,\n    \"logo_only\": True\n}\n\nhtml_logo = \"_static/logo/delira.svg\"\n\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. 
They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\nhtml_static_path = ['_static']\n\n# Custom sidebar templates, must be a dictionary that maps document names\n# to template names.\n#\n# The default sidebars (for documents that don't match any pattern) are\n# defined by theme itself.  Builtin themes are using these templates by\n# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',\n# 'searchbox.html']``.\n#\n# html_sidebars = {}\nhtml_sidebars = {\n    '**': [\n        'relations.html',  # needs 'show_related': True theme option to display\n        'searchbox.html',\n        'localtoc.html',\n        'sourcelink.html',\n    ]\n}\n\n\n# -- Options for HTMLHelp output ---------------------------------------------\n\n# Output file base name for HTML help builder.\nhtmlhelp_basename = 'deliradoc'\n\n\n# -- Options for LaTeX output ------------------------------------------------\n\nlatex_elements = {\n    # The paper size ('letterpaper' or 'a4paper').\n    #\n    # 'papersize': 'letterpaper',\n\n    # The font size ('10pt', '11pt' or '12pt').\n    #\n    # 'pointsize': '10pt',\n\n    # Additional stuff for the LaTeX preamble.\n    #\n    # 'preamble': '',\n\n    # Latex figure (float) alignment\n    #\n    # 'figure_align': 'htbp',\n}\n\n# Grouping the document tree into LaTeX files. List of tuples\n# (source start file, target name, title,\n#  author, documentclass [howto, manual, or own class]).\nlatex_documents = [\n    (master_doc, 'delira.tex', 'delira Documentation',\n     author, 'manual'),\n]\n\n\n# -- Options for manual page output ------------------------------------------\n\n# One entry per manual page. 
List of tuples\n# (source start file, name, description, authors, manual section).\nman_pages = [\n    (master_doc, 'delira', 'delira Documentation',\n     [author], 1)\n]\n\n\n# -- Options for Texinfo output ----------------------------------------------\n\n# Grouping the document tree into Texinfo files. List of tuples\n# (source start file, target name, title, author,\n#  dir menu entry, description, category)\ntexinfo_documents = [\n    (master_doc, 'delira', 'delira Documentation',\n     author, 'delira', 'One line description of project.',\n     'Miscellaneous'),\n]\n\n\n# -- Extension configuration -------------------------------------------------\n\n# -- Options for intersphinx extension ---------------------------------------\n\n# Example configuration for intersphinx: refer to the Python standard library.\nintersphinx_mapping = {\n    'https://docs.python.org/': None,\n    'trixi': (\n        'https://trixi.readthedocs.io/en/latest/',\n        None),\n    'torch': (\n        'https://pytorch.org/docs/stable/',\n        None),\n    'tensorflow': (\n        'https://www.tensorflow.org/api_docs/python/',\n        None),\n    'chainer': (\n        'https://docs.chainer.org/en/stable/',\n        None),\n    'sklearn': (\n        'https://scikit-learn.org/stable/documentation/',\n        None),\n    'numpy': (\n        'https://docs.scipy.org/doc/numpy/reference/',\n        None),\n    'scipy': (\n        'https://docs.scipy.org/doc/scipy/reference/',\n        None)\n}\n\n# -- Options for todo extension ----------------------------------------------\n\n# If true, `todo` and `todoList` produce output, else they produce nothing.\ntodo_include_todos = True\nautoclass_content = 'both'\nadd_module_names = False\n\nautodoc_default_flags = ['members',\n                         'undoc-members',\n                         'private-members',\n                         'inherited-members',\n                         'show-inheritance']\n\nautodoc_inherit_docstrings = 
True\n\nautodoc_mock_imports = [\n    \"numpy\",\n    \"torchvision\",\n    \"torch\",\n    \"skimage\",\n    \"sklearn\",\n    \"jupyter\",\n    \"flake8\",\n    \"pytest-cov\",\n    \"autopep8\",\n    \"ipython\",\n    \"joblib\",\n    \"pillow\",\n    \"SimpleITK\",\n    \"pylint\",\n    \"tqdm\",\n    \"visdom\",\n    \"pyyaml\",\n    \"trixi\",\n    \"batchgenerators\",\n    \"psutil\",\n    \"nested_lookup\",\n    \"colorlover\",\n    \"flask\",\n    \"graphviz\",\n    \"matplotlib\",\n    \"seaborn\",\n    \"scipy\",\n    \"scipy.ndimage\",\n    \"telegram\",\n    \"portalocker\",\n    \"plotly\",\n    \"PIL\",\n    \"umap\",\n    \"tensorflow\",\n    \"yaml\",\n    \"chainer\"\n]\n"
  },
  {
    "path": "docs/custom_backend.rst",
    "content": "\nHow To: Integrate your own Computation Backend\n==============================================\n\n*Author: Justus Schock*\n\n*Date: 15.05.2019*\n\nThis howto will take you on a trip through the ``delira`` internals,\nwhile we will see, how to add a custom computation backend on the\nexamplaric case of the ``torch.jit`` or ``TorchScript`` backend\n\nModel Definitions\n-----------------\n\nIn order to implement a network, we will first have to define the\nnetwork itself. In ``delira`` there is a single backend-specific\nimplementation of an abstract network class for each of the backends.\nThese interface classes are all based on the ``AbstractNetwork``-class,\ndefining the major API.\n\nSo let's start having a look at this class to see, what we will have to\nimplement for our own backend.\n\nOf course we will have to implement an ``__init__`` defining our class.\nThe ``__init__`` of ``AbstractNetwork`` (which should be called during\nour the ``__init__`` of our baseclass) accepts a number of kwargs and\nsimply registers them to be ``init_kwargs``, so there is nothing we have\nto take care of.\n\nThe next function to inspect is the ``__call__`` function, which makes\nthe class callable and the docstrings indicate, that it should take care\nof our model's forward-pass.\n\nAfter the ``__call__`` we now have the ``closure`` function, which\ndefines a single training step (including, but not limited to,\nforward-pass, calculation of losses and train-metrics, backward-pass and\noptimization).\n\nThe last method to implement is the ``prepare_batch`` function which\nconverts the input to a suitable format and the correct data-type and\ndevice.\n\nTorchScript Limitations\n~~~~~~~~~~~~~~~~~~~~~~~\n\nSince we want to implement an abstract network class for this specific\nbackend, we should have a look on how to generally implement models in\nthis backend.\n\nAccording the the `PyTorch\ndocs <https://pytorch.org/docs/stable/jit.html>`__ this works 
as\nfollows:\n\n    You can write TorchScript code directly using Python syntax. You do\n    this using the ``torch.jit.script`` decorator (for functions) or\n    ``torch.jit.script_method`` decorator (for methods) on subclasses of\n    ``ScriptModule``. With this decorator the body of the annotated\n    function is directly translated into TorchScript. TorchScript itself\n    is a subset of the Python language, so not all features in Python\n    work, but we provide enough functionality to compute on tensors and\n    do control-dependent operations.\n\nSince our use-case is to implement the interface class for networks, we\nwant to subclass ``torch.jit.ScriptModule``, implement its ``forward``\nand use the ``torch.jit.script_method`` decorator on it.\n\nThe example given in the very same docs for this case is:\n\n.. code:: ipython3\n\n    import torch\n    class MyScriptModule(torch.jit.ScriptModule):\n        def __init__(self, N, M):\n            super().__init__()\n            self.weight = torch.nn.Parameter(torch.rand(N, M))\n    \n        @torch.jit.script_method\n        def forward(self, input):\n            return self.weight.mv(input)\n        \n    my_script_module = MyScriptModule(5, 3)\n    input_tensor = torch.rand(3)\n    my_script_module(input_tensor)\n\n\n\n\n.. 
parsed-literal::\n\n    tensor([0.4997, 0.2955, 0.1588, 0.1873, 0.4753], grad_fn=<MvBackward>)\n\n\n\nMerging TorchScript into our Abstract Class\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\nThis little example shows us a few things we have to do for a\nsuccessful definition of our base class:\n\n**1.)** Our class has to subclass both the ``AbstractNetwork`` and the\n``torch.jit.ScriptModule`` classes.\n\n**2.)** We need to implement a ``forward`` method, which takes care of\nthe forward-pass (as its name indicates).\n\n**3.)** We don't have to take care of the backward-pass, thanks to\n``PyTorch``'s and ``TorchScript``'s autograd, which is a framework for\nautomatic differentiation.\n\n**4.)** Since ``torch.jit.ScriptModule`` is callable (as seen in the\nexample), it already implements a ``__call__`` method and we may simply\nuse this one.\n\n**5.)** The ``closure`` is completely network-dependent and thus has to\nremain an abstract method here.\n\n**6.)** The ``prepare_batch`` function also depends on the combination\nof network, inputs and loss functions to use, but we can at least give a\nprototype of such a function, which handles the devices correctly and\nconverts everything to ``float``.\n\nActual Implementation\n~~~~~~~~~~~~~~~~~~~~~\n\nNow let's start with the actual implementation, going one function at a\ntime and keeping in mind the things we just discovered.\n\nClass Signature and ``__init__``-Method\n^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n\nTo subclass both base classes, we cannot use the simple ``super().__init__``\napproach, because we have to init both parent classes, so we do\n\n.. 
code:: python\n\n\n        class AbstractTorchScriptNetwork(AbstractNetwork, torch.jit.ScriptModule):\n\n            @abc.abstractmethod\n            def __init__(self, optimize=True, **kwargs):\n                \"\"\"\n\n                Parameters\n                ----------\n                optimize : bool\n                    whether to optimize the network graph or not; default: True\n                **kwargs :\n                    additional keyword arguments (passed to :class:`AbstractNetwork`)\n                \"\"\"\n                torch.jit.ScriptModule.__init__(self, optimize=optimize)\n                AbstractNetwork.__init__(self, **kwargs)\n                \n\ninstead. This ensures that all parent classes are initialized correctly.\n\n``__call__``-Method\n^^^^^^^^^^^^^^^^^^^\n\nAs mentioned above, the ``__call__`` method is very easy to implement,\nbecause we can simply use the implementation of our ``TorchScript`` base\nclass like this:\n\n.. code:: python\n\n\n        def __call__(self, *args, **kwargs):\n            \"\"\"\n            Calls Forward method\n\n            Parameters\n            ----------\n            *args :\n                positional arguments (passed to `forward`)\n            **kwargs :\n                keyword arguments (passed to `forward`)\n\n            Returns\n            -------\n            Any\n                result: module results of arbitrary type and number\n\n            \"\"\"\n            return torch.jit.ScriptModule.__call__(self, *args, **kwargs)\n            \n\nThis also ensures that we can pass an arbitrary number of positional\nand keyword arguments of arbitrary types to it (which are all passed to\nthe ``forward``-function). 
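The explicit parent-``__init__`` calls and the delegated ``__call__`` shown above can be sketched with plain Python stand-ins (hypothetical class names, no ``torch`` required) to see why the pattern works:

```python
import abc


class AbstractNetworkStandIn(abc.ABC):
    """Stand-in for delira's AbstractNetwork: registers kwargs as init_kwargs."""

    def __init__(self, **kwargs):
        self.init_kwargs = kwargs


class ScriptModuleStandIn:
    """Stand-in for torch.jit.ScriptModule: callable, dispatches to forward."""

    def __init__(self, optimize=True):
        self.optimize = optimize

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class DemoNetwork(AbstractNetworkStandIn, ScriptModuleStandIn):
    def __init__(self, optimize=True, **kwargs):
        # init each parent explicitly instead of relying on super()
        ScriptModuleStandIn.__init__(self, optimize=optimize)
        AbstractNetworkStandIn.__init__(self, **kwargs)

    def __call__(self, *args, **kwargs):
        # delegate to the parent's __call__, which runs the forward-pass
        return ScriptModuleStandIn.__call__(self, *args, **kwargs)

    def forward(self, x):
        return 2 * x


net = DemoNetwork(in_channels=1)
assert net(21) == 42                          # forward-pass via __call__
assert net.init_kwargs == {"in_channels": 1}  # registered by the abstract base
```

Replacing the stand-ins with the real ``AbstractNetwork`` and ``torch.jit.ScriptModule`` keeps the structure identical.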
The advantage over directly calling the\n``forward`` method here is that the ``ScriptModule.__call__`` already\ndoes the handling of\n`forward-pre-hooks <https://pytorch.org/docs/stable/nn.html#torch.nn.Module.register_forward_pre_hook>`__,\n`forward-hooks <https://pytorch.org/docs/stable/nn.html#torch.nn.Module.register_forward_hook>`__\nand\n`backward-hooks <https://pytorch.org/docs/stable/nn.html#torch.nn.Module.register_backward_hook>`__.\n\n``closure``-Method\n^^^^^^^^^^^^^^^^^^\n\nSince this method is highly model-dependent, we simply don't implement\nit here, which forces the user to implement it (since it is marked as an\n``abstractmethod`` in ``AbstractNetwork``).\n\n``prepare_batch``-Method\n^^^^^^^^^^^^^^^^^^^^^^^^\n\nThe above-mentioned prototype, which pushes everything to the correct\ndevice and converts it to ``float``, looks like this:\n\n.. code:: python\n\n\n        @staticmethod\n        def prepare_batch(batch: dict, input_device, output_device):\n            \"\"\"\n            Helper Function to prepare Network Inputs and Labels (convert them to\n            correct type and shape and push them to correct devices)\n\n            Parameters\n            ----------\n            batch : dict\n                dictionary containing all the data\n            input_device : torch.device\n                device for network inputs\n            output_device : torch.device\n                device for network outputs\n\n            Returns\n            -------\n            dict\n                dictionary containing data in correct type and shape and on correct\n                device\n\n            \"\"\"\n            return_dict = {\"data\": torch.from_numpy(batch.pop(\"data\")).to(\n                input_device).to(torch.float)}\n\n            for key, vals in batch.items():\n                return_dict[key] = torch.from_numpy(vals).to(output_device).to(\n                    torch.float)\n\n            return return_dict\n\nSince we don't want to use any of the 
model's attributes here (and for\nconformity with the ``AbstractNetwork`` class), this method is defined\nas a ``staticmethod``, meaning it is not bound to a specific instance.\nThe ``closure`` method has to be a ``staticmethod`` too.\n\n``forward``-Method\n^^^^^^^^^^^^^^^^^^\n\nThe only thing left now is the ``forward`` method, which is internally\ncalled by ``ScriptModule.__call__``. The bad news is: we currently can't\nimplement it. Subclassing a ``ScriptModule`` to overwrite a function\ndecorated with ``torch.jit.script_method`` is not (yet) supported, but\nwill be soon, once `this\nPR <https://github.com/pytorch/pytorch/pull/20503>`__ is merged and\nreleased.\n\nFor now, you simply have to implement this method in your own network\ndespite the lack of an abstract interface method.\n\nPutting it all together\n^^^^^^^^^^^^^^^^^^^^^^^\n\nIf we combine all the function implementations into one class, it looks\nlike this:\n\n.. code:: python\n\n\n        class AbstractTorchScriptNetwork(AbstractNetwork, torch.jit.ScriptModule):\n\n            \"\"\"\n            Abstract Interface Class for TorchScript Networks. 
For more information\n            have a look at https://pytorch.org/docs/stable/jit.html#torchscript\n\n            Warnings\n            --------\n            In addition to the here defined API, a forward function must be\n            implemented and decorated with ``@torch.jit.script_method``\n\n            \"\"\"\n            @abc.abstractmethod\n            def __init__(self, optimize=True, **kwargs):\n                \"\"\"\n\n                Parameters\n                ----------\n                optimize : bool\n                    whether to optimize the network graph or not; default: True\n                **kwargs :\n                    additional keyword arguments (passed to :class:`AbstractNetwork`)\n                \"\"\"\n                torch.jit.ScriptModule.__init__(self, optimize=optimize)\n                AbstractNetwork.__init__(self, **kwargs)\n\n            def __call__(self, *args, **kwargs):\n                \"\"\"\n                Calls Forward method\n\n                Parameters\n                ----------\n                *args :\n                    positional arguments (passed to `forward`)\n                **kwargs :\n                    keyword arguments (passed to `forward`)\n\n                Returns\n                -------\n                Any\n                    result: module results of arbitrary type and number\n\n                \"\"\"\n                return torch.jit.ScriptModule.__call__(self, *args, **kwargs)\n\n            @staticmethod\n            def prepare_batch(batch: dict, input_device, output_device):\n                \"\"\"\n                Helper Function to prepare Network Inputs and Labels (convert them to\n                correct type and shape and push them to correct devices)\n\n                Parameters\n                ----------\n                batch : dict\n                    dictionary containing all the data\n                input_device : torch.device\n                    device for network 
inputs\n                output_device : torch.device\n                    device for network outputs\n\n                Returns\n                -------\n                dict\n                    dictionary containing data in correct type and shape and on correct\n                    device\n\n                \"\"\"\n                return_dict = {\"data\": torch.from_numpy(batch.pop(\"data\")).to(\n                    input_device).to(torch.float)}\n\n                for key, vals in batch.items():\n                    return_dict[key] = torch.from_numpy(vals).to(output_device).to(\n                        torch.float)\n\n                return return_dict\n            \n\nSaving and loading\n------------------\n\nNow that we have the ability to implement ``delira``-suitable\nTorchScript models, we want to store them on disk and load them again,\nso that we don't have to retrain them every time we want to use them.\nThese I/O functions are usually located in ``delira.io``.\n\nSaving\n~~~~~~\n\nOur saving function utilizes multiple functions: ``torch.jit.save`` to\nsimply save the model (including its graph) and the\n``save_checkpoint_torch`` function implemented for the ``PyTorch``\nbackend to store the trainer state, since ``TorchScript`` allows us to\nuse plain ``PyTorch`` optimizers.\n\nThe implementation of the function looks like this:\n\n.. code:: python\n\n\n        def save_checkpoint_torchscript(file: str, model=None, optimizers={},\n                                        epoch=None, **kwargs):\n            \"\"\"\n            Save current checkpoint to two different files:\n                1.) ``file + \"_model.ptj\"``: Will include the state of the model\n                    (including the graph; in contrast to\n                    :func:`save_checkpoint`)\n                2.) 
``file + \"_trainer_state.pt\"``: Will include the states of all\n                    optimizers and the current epoch (if given)\n\n            Parameters\n            ----------\n            file : str\n                filepath the model should be saved to\n            model : AbstractPyTorchJITNetwork or None\n                the model which should be saved\n                if None: empty dict will be saved as state dict\n            optimizers : dict\n                dictionary containing all optimizers\n            epoch : int\n                current epoch (will also be pickled)\n\n            \"\"\"\n\n            # remove file extension if given\n            if any([file.endswith(ext) for ext in [\".pth\", \".pt\", \".ptj\"]]):\n                file = file.rsplit(\".\", 1)[0]\n\n            if isinstance(model, AbstractPyTorchJITNetwork):\n                torch.jit.save(model, file + \"_model.ptj\")\n\n            if optimizers or epoch is not None:\n                save_checkpoint_torch(file + \"_trainer_state.pt\", None,\n                                optimizers=optimizers, epoch=epoch, **kwargs)\n                \n\nLoading\n~~~~~~~\n\nTo load a model, which has been saved to disk by this function we have\nto revert each part of it. We do this by using ``torch.jit.load`` for\nthe model (and the graph) and ``load_checkpoint_torch`` by the\n``PyTorch`` backend. The actual implementation is given here:\n\n.. 
code:: python\n\n\n        def load_checkpoint_torchscript(file: str, **kwargs):\n            \"\"\"\n            Loads a saved checkpoint consisting of 2 files\n            (see :func:`save_checkpoint_torchscript` for details)\n\n            Parameters\n            ----------\n            file : str\n                filepath to a file containing a saved model\n            **kwargs:\n                Additional keyword arguments (passed to torch.load)\n                Especially \"map_location\" is important to change the device the\n                state_dict should be loaded to\n\n            Returns\n            -------\n            OrderedDict\n                checkpoint state_dict\n\n            \"\"\"\n            # remove file extensions\n            if any([file.endswith(ext) for ext in [\".pth\", \".pt\", \".ptj\"]]):\n                file = file.rsplit(\".\", 1)[0]\n\n            # load model\n            if os.path.isfile(file + \".ptj\"):\n                model_file = file + \".ptj\"\n            elif os.path.isfile(file + \"_model.ptj\"):\n                model_file = file + \"_model.ptj\"\n            else:\n                raise ValueError(\"No Model File found for %s\" % file)\n\n            # load trainer state (if possible)\n            trainer_file = model_file.replace(\"_model.ptj\", \"_trainer_state.pt\")\n            if os.path.isfile(trainer_file):\n                trainer_state = load_checkpoint_torch(trainer_file, **kwargs)\n\n            else:\n                trainer_state = {\"optimizer\": {},\n                                 \"epoch\": None}\n\n            trainer_state.update({\"model\": torch.jit.load(model_file)})\n\n            return trainer_state\n        \n\nA Trainer to train\n------------------\n\nNow that we can define and save/load our models, we want to train them.\nLuckily ``delira`` has already implemented a very modular\nbackend-agnostic trainer (the ``BaseNetworkTrainer``) and built upon\nthis a ``PyTorchNetworkTrainer``. 
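This layering (a backend-agnostic base that owns the training loop, with backend-specific subclasses overriding only the setup) can be sketched with plain Python stand-ins; the names are hypothetical, not the actual ``delira`` classes:

```python
class BaseTrainerStandIn:
    """Backend-agnostic part: the training loop itself lives here."""

    def __init__(self, save_path):
        self.save_path = save_path

    def _setup(self):
        # backend-specific initialization, provided by subclasses
        raise NotImplementedError

    def train(self, num_epochs):
        self._setup()
        return ["epoch %d done" % e for e in range(1, num_epochs + 1)]


class PyTorchTrainerStandIn(BaseTrainerStandIn):
    """Backend-specific part: overrides only what differs per backend."""

    def __init__(self, save_path, gpu_ids=()):
        super().__init__(save_path)
        self.gpu_ids = list(gpu_ids)

    def _setup(self):
        # e.g. push the model to a device, build optimizers (omitted here)
        self.device = "cuda:%d" % self.gpu_ids[0] if self.gpu_ids else "cpu"


trainer = PyTorchTrainerStandIn("./tmp", gpu_ids=[0])
assert trainer.train(2) == ["epoch 1 done", "epoch 2 done"]
assert trainer.device == "cuda:0"
```

Extending the ``PyTorchNetworkTrainer`` for TorchScript follows the same idea: reuse the loop and override only the pieces that differ.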
Since the training process in PyTorch\nand TorchScript is nearly the same, we can just extend the\n``PyTorchNetworkTrainer``. Usually one would have to extend the\n``BaseNetworkTrainer`` to provide some backend-specific functions (like\nnecessary initializations, optimizer setup, seeding etc.). To see how\nthis is done, you could have a look at either the\n``PyTorchNetworkTrainer`` or the ``TfNetworkTrainer`` for tensorflow,\nwhich both follow this principle. Usually the only parts to change\ncompletely are the loading/saving behavior and the ``_setup``\nfunction, which defines the backend-specific initialization. Some other\nfunctions may have to be extended (by implementing the extension and\ncalling the parent class's function).\n\nThings to change:\n~~~~~~~~~~~~~~~~~\n\nBy subclassing the ``PyTorchNetworkTrainer`` we have to change the\nfollowing things:\n\n-  The trainer's default arguments\n\n-  The behavior for trying to resume a previous training\n\n-  The saving, loading and updating behavior\n\nWe will address these one by one:\n\nThe Default Arguments\n^^^^^^^^^^^^^^^^^^^^^\n\nWe want to use ``AbstractTorchScriptNetwork``\\ s instead of\n``AbstractPyTorchNetwork``\\ s here, and we have to change the behavior\nwhen passing multiple GPUs, because Multi-GPU training is currently not\nsupported by ``TorchScript``.\n\nTo do this, we implement the function ``__init__``, apply our changes\nand forward them to the call of the base class's ``__init__``\nlike this (docstrings omitted for brevity):\n\n.. 
code:: python\n\n\n    class TorchScriptNetworkTrainer(PyTorchNetworkTrainer):\n            def __init__(self,\n                         network: AbstractTorchScriptNetwork,\n                         save_path: str,\n                         key_mapping,\n                         losses=None,\n                         optimizer_cls=None,\n                         optimizer_params={},\n                         train_metrics={},\n                         val_metrics={},\n                         lr_scheduler_cls=None,\n                         lr_scheduler_params={},\n                         gpu_ids=[],\n                         save_freq=1,\n                         optim_fn=create_optims_default,\n                         logging_type=\"tensorboardx\",\n                         logging_kwargs={},\n                         fold=0,\n                         callbacks=[],\n                         start_epoch=1,\n                         metric_keys=None,\n                         convert_batch_to_npy_fn=convert_torch_tensor_to_npy,\n                         criterions=None,\n                         val_freq=1,\n                         **kwargs):\n                \n                if len(gpu_ids) > 1:\n                    # only use first GPU due to\n                    # https://github.com/pytorch/pytorch/issues/15421\n                    gpu_ids = [gpu_ids[0]]\n                    logging.warning(\"Multiple GPUs specified. Torch JIT currently \"\n                                    \"supports only single-GPU training. 
\"\n                                    \"Switching to use only the first GPU for now...\")\n\n                super().__init__(network=network, save_path=save_path,\n                                 key_mapping=key_mapping, losses=losses,\n                                 optimizer_cls=optimizer_cls,\n                                 optimizer_params=optimizer_params,\n                                 train_metrics=train_metrics,\n                                 val_metrics=val_metrics,\n                                 lr_scheduler_cls=lr_scheduler_cls,\n                                 lr_scheduler_params=lr_scheduler_params,\n                                 gpu_ids=gpu_ids, save_freq=save_freq,\n                                 optim_fn=optim_fn, logging_type=logging_type,\n                                 logging_kwargs=logging_kwargs, fold=fold,\n                                 callbacks=callbacks,\n                                 start_epoch=start_epoch, metric_keys=metric_keys,\n                                 convert_batch_to_npy_fn=convert_batch_to_npy_fn,\n                                 mixed_precision=False, mixed_precision_kwargs={},\n                                 criterions=criterions, val_freq=val_freq, **kwargs\n                                 )\n                \n\nResuming Training\n^^^^^^^^^^^^^^^^^\n\nFor resuming the training, we have to completely change the\n``try_resume_training`` function and cannot reuse the parent's\nimplementation of it. Thus, we don't call\n``super().try_resume_training`` here, but completely reimplement it from\nscratch:\n\n.. 
code:: python\n\n\n        def try_resume_training(self):\n            \"\"\"\n            Load the latest state of a previous training if possible\n\n            \"\"\"\n            # Load latest epoch file if available\n            if os.path.isdir(self.save_path):\n                # check all files in directory starting with \"checkpoint\" and\n                # not ending with \"_best.ptj\"\n                files = [x for x in os.listdir(self.save_path)\n                         if os.path.isfile(os.path.join(self.save_path, x))\n                         and x.startswith(\"checkpoint\")\n                         and not x.endswith(\"_best.ptj\")\n                         ]\n\n                # if list is not empty: load previous state\n                if files:\n\n                    latest_epoch = max([\n                        int(x.rsplit(\"_\", 1)[-1].rsplit(\".\", 1)[0])\n                        for x in files])\n\n                    latest_state_path = os.path.join(self.save_path,\n                                                     \"checkpoint_epoch_%d.ptj\"\n                                                     % latest_epoch)\n\n                    # if the .ptj file does not exist, load the .pt file instead\n                    if not os.path.isfile(latest_state_path):\n                        latest_state_path = latest_state_path[:-1]\n\n                    logger.info(\"Attempting to load state from previous \\\n                                training from %s\" % latest_state_path)\n                    try:\n                        self.update_state(latest_state_path)\n                    except KeyError:\n                        logger.warning(\"Previous State could not be loaded, \\\n                                        although it exists. Training will be \\\n                                        restarted\")\n\nSaving and Loading\n^^^^^^^^^^^^^^^^^^\n\nNow we need to change the saving and loading behavior. 
As always we try\nto reuse as much code as possible to avoid code duplication.\n\nSaving\n''''''\n\nTo save the current training state, we simply call the\n``save_checkpoint_torchscript`` function:\n\n.. code:: python\n\n\n        def save_state(self, file_name, epoch, **kwargs):\n            \"\"\"\n            saves the current state via\n            :func:`delira.io.torch.save_checkpoint_jit`\n\n            Parameters\n            ----------\n            file_name : str\n                filename to save the state to\n            epoch : int\n                current epoch (will be saved for mapping back)\n            **kwargs :\n                keyword arguments\n\n            \"\"\"\n            if file_name.endswith(\".pt\") or file_name.endswith(\".pth\"):\n                file_name = file_name.rsplit(\".\", 1)[0]\n\n            save_checkpoint_torchscript(file_name, self.module, self.optimizers,\n                                        **kwargs)\n            \n\nLoading\n'''''''\n\nTo load the training state, we simply return the state loaded by\n``load_checkpoint_torchscript``. Since we don't use any arguments of the\ntrainer itself here, the function is a ``staticmethod``:\n\n.. 
code:: python\n\n\n        @staticmethod\n        def load_state(file_name, **kwargs):\n            \"\"\"\n            Loads the new state from file via\n            :func:`delira.io.torch.load_checkpoint_jit`\n\n            Parameters\n            ----------\n            file_name : str\n                the file to load the state from\n            **kwargs : keyword arguments\n\n            Returns\n            -------\n            dict\n                new state\n\n            \"\"\"\n            return load_checkpoint_torchscript(file_name, **kwargs)\n        \n\nUpdating\n''''''''\n\nAfter we have loaded the new state, we need to update the trainer's\ninternal state with it.\n\nWe do this by directly assigning the model here (since the graph was\nstored/loaded too) instead of only updating the state\\_dict and calling\nthe parent class's method afterwards:\n\n.. code:: python\n\n\n        def _update_state(self, new_state):\n            \"\"\"\n            Update the state from a given new state\n\n            Parameters\n            ----------\n            new_state : dict\n                new state to update internal state from\n\n            Returns\n            -------\n            :class:`PyTorchNetworkJITTrainer`\n                the trainer with a modified state\n\n            \"\"\"\n            if \"model\" in new_state:\n                self.module = new_state.pop(\"model\").to(self.input_device)\n\n            return super()._update_state(new_state)\n\nA Whole Trainer\n~~~~~~~~~~~~~~~\n\nAfter combining all the changes above, we finally get our new trainer\nas:\n\n.. 
code:: python\n\n\n        class TorchScriptNetworkTrainer(PyTorchNetworkTrainer):\n            def __init__(self,\n                         network: AbstractTorchScriptNetwork,\n                         save_path: str,\n                         key_mapping,\n                         losses=None,\n                         optimizer_cls=None,\n                         optimizer_params={},\n                         train_metrics={},\n                         val_metrics={},\n                         lr_scheduler_cls=None,\n                         lr_scheduler_params={},\n                         gpu_ids=[],\n                         save_freq=1,\n                         optim_fn=create_optims_default,\n                         logging_type=\"tensorboardx\",\n                         logging_kwargs={},\n                         fold=0,\n                         callbacks=[],\n                         start_epoch=1,\n                         metric_keys=None,\n                         convert_batch_to_npy_fn=convert_torch_tensor_to_npy,\n                         criterions=None,\n                         val_freq=1,\n                         **kwargs):\n                \"\"\"\n\n                Parameters\n                ----------\n                network : :class:`AbstractPyTorchJITNetwork`\n                    the network to train\n                save_path : str\n                    path to save networks to\n                key_mapping : dict\n                    a dictionary containing the mapping from the ``data_dict`` to\n                    the actual model's inputs.\n                    E.g. 
if a model accepts one input named 'x' and the data_dict\n                    contains one entry named 'data' this argument would have to\n                    be ``{'x': 'data'}``\n                losses : dict\n                    dictionary containing the training losses\n                optimizer_cls : subclass of torch.optim.Optimizer\n                    optimizer class implementing the optimization algorithm of\n                    choice\n                optimizer_params : dict\n                    keyword arguments passed to optimizer during construction\n                train_metrics : dict, optional\n                    metrics, which will be evaluated during train phase\n                    (should work on framework's tensor types)\n                val_metrics : dict, optional\n                    metrics, which will be evaluated during test phase\n                    (should work on numpy arrays)\n                lr_scheduler_cls : Any\n                    learning rate schedule class: must implement step() method\n                lr_scheduler_params : dict\n                    keyword arguments passed to lr scheduler during construction\n                gpu_ids : list\n                    list containing ids of GPUs to use; if empty: use cpu instead\n                    Currently ``torch.jit`` only supports single GPU-Training,\n                    thus only the first GPU will be used if multiple GPUs are passed\n                save_freq : int\n                    integer specifying how often to save the current model's state.\n                    State is saved every save_freq epochs\n                optim_fn : function\n                    creates a dictionary containing all necessary optimizers\n                logging_type : str or callable\n                    the type of logging. 
If string: it must be one of\n                    [\"visdom\", \"tensorboardx\"]\n                    If callable: it must be a logging handler class\n                logging_kwargs : dict\n                    dictionary containing all logging keyword arguments\n                fold : int\n                    current cross validation fold (0 per default)\n                callbacks : list\n                    initial callbacks to register\n                start_epoch : int\n                    epoch to start training at\n                metric_keys : dict\n                    dict specifying which batch_dict entry to use for which metric as\n                    target; default: None, which will result in key \"label\" for all\n                    metrics\n                convert_batch_to_npy_fn : type, optional\n                    function converting a batch-tensor to numpy, per default this is\n                    a function, which detaches the tensor, moves it to cpu and then\n                    calls ``.numpy()`` on it\n                mixed_precision : bool\n                    whether to use mixed precision or not (False per default)\n                mixed_precision_kwargs : dict\n                    additional keyword arguments for mixed precision\n                val_freq : int\n                    validation frequency specifying how often to validate the trained\n                    model (a value of 1 denotes validating every epoch,\n                    a value of 2 denotes validating every second epoch etc.);\n                    defaults to 1\n                **kwargs :\n                    additional keyword arguments\n\n                \"\"\"\n\n                if len(gpu_ids) > 1:\n                    # only use first GPU due to\n                    # https://github.com/pytorch/pytorch/issues/15421\n                    gpu_ids = [gpu_ids[0]]\n                    logging.warning(\"Multiple GPUs specified. 
Torch JIT currently \"\n                                    \"supports only single-GPU training. \"\n                                    \"Switching to use only the first GPU for now...\")\n\n                super().__init__(network=network, save_path=save_path,\n                                 key_mapping=key_mapping, losses=losses,\n                                 optimizer_cls=optimizer_cls,\n                                 optimizer_params=optimizer_params,\n                                 train_metrics=train_metrics,\n                                 val_metrics=val_metrics,\n                                 lr_scheduler_cls=lr_scheduler_cls,\n                                 lr_scheduler_params=lr_scheduler_params,\n                                 gpu_ids=gpu_ids, save_freq=save_freq,\n                                 optim_fn=optim_fn, logging_type=logging_type,\n                                 logging_kwargs=logging_kwargs, fold=fold,\n                                 callbacks=callbacks,\n                                 start_epoch=start_epoch, metric_keys=metric_keys,\n                                 convert_batch_to_npy_fn=convert_batch_to_npy_fn,\n                                 mixed_precision=False, mixed_precision_kwargs={},\n                                 criterions=criterions, val_freq=val_freq, **kwargs\n                                 )\n\n            def try_resume_training(self):\n                \"\"\"\n                Load the latest state of a previous training if possible\n\n                \"\"\"\n                # Load latest epoch file if available\n                if os.path.isdir(self.save_path):\n                    # check all files in directory starting with \"checkpoint\" and\n                    # not ending with \"_best.pth\"\n                    files = [x for x in os.listdir(self.save_path)\n                             if os.path.isfile(os.path.join(self.save_path, x))\n                             and 
x.startswith(\"checkpoint\")\n                             and not x.endswith(\"_best.ptj\")\n                             ]\n\n                    # if list is not empty: load previous state\n                    if files:\n\n                        latest_epoch = max([\n                            int(x.rsplit(\"_\", 1)[-1].rsplit(\".\", 1)[0])\n                            for x in files])\n\n                        latest_state_path = os.path.join(self.save_path,\n                                                         \"checkpoint_epoch_%d.ptj\"\n                                                         % latest_epoch)\n\n                        # if pth file does not exist, load pt file instead\n                        if not os.path.isfile(latest_state_path):\n                            latest_state_path = latest_state_path[:-1]\n\n                        logger.info(\"Attempting to load state from previous \\\n                                    training from %s\" % latest_state_path)\n                        try:\n                            self.update_state(latest_state_path)\n                        except KeyError:\n                            logger.warning(\"Previous State could not be loaded, \\\n                                            although it exists.Training will be \\\n                                            restarted\")\n\n            def save_state(self, file_name, epoch, **kwargs):\n                \"\"\"\n                saves the current state via\n                :func:`delira.io.torch.save_checkpoint_jit`\n\n                Parameters\n                ----------\n                file_name : str\n                    filename to save the state to\n                epoch : int\n                    current epoch (will be saved for mapping back)\n                **kwargs :\n                    keyword arguments\n\n                \"\"\"\n                if file_name.endswith(\".pt\") or file_name.endswith(\".pth\"):\n                   
 file_name = file_name.rsplit(\".\", 1)[0]\n\n                save_checkpoint_torchscript(file_name, self.module, self.optimizers,\n                                            **kwargs)\n\n            @staticmethod\n            def load_state(file_name, **kwargs):\n                \"\"\"\n                Loads the new state from file via\n                :func:`delira.io.torch.load_checkpoint:jit`\n\n                Parameters\n                ----------\n                file_name : str\n                    the file to load the state from\n                **kwargs : keyword arguments\n\n                Returns\n                -------\n                dict\n                    new state\n\n                \"\"\"\n                return load_checkpoint_torchscript(file_name, **kwargs)\n\n            def _update_state(self, new_state):\n                \"\"\"\n                Update the state from a given new state\n\n                Parameters\n                ----------\n                new_state : dict\n                    new state to update internal state from\n\n                Returns\n                -------\n                :class:`PyTorchNetworkJITTrainer`\n                    the trainer with a modified state\n\n                \"\"\"\n                if \"model\" in new_state:\n                    self.module = new_state.pop(\"model\").to(self.input_device)\n\n                return super()._update_state(new_state)\n            \n\nWrapping it all in an Experiment\n--------------------------------\n\nTo have access to methods like a K-Fold (and the not yet finished)\nhyperparameter tuning, we need to wrap the trainer in an Experiment. 
We\nwill use the same approach as we did for implementing the trainer:\nextending an already provided class.\n\nThis time we extend the ``PyTorchExperiment``, which itself extends the\n``BaseExperiment`` with some backend-specific defaults, types and seeds.\n\nOur whole class definition just changes the default arguments of the\n``PyTorchExperiment`` and thus we only have to implement its\n``__init__``:\n\n.. code:: python\n\n\n    class TorchScriptExperiment(PyTorchExperiment):\n        def __init__(self,\n                     params: typing.Union[str, Parameters],\n                     model_cls: AbstractTorchScriptNetwork, # not AbstractPyTorchNetwork anymore\n                     n_epochs=None,\n                     name=None,\n                     save_path=None,\n                     key_mapping=None,\n                     val_score_key=None,\n                     optim_builder=create_optims_default_pytorch,\n                     checkpoint_freq=1,\n                     trainer_cls=TorchScriptNetworkTrainer, # not PyTorchNetworkTrainer anymore\n                     **kwargs):\n            \"\"\"\n\n            Parameters\n            ----------\n            params : :class:`Parameters` or str\n                the training parameters, if string is passed,\n                it is treated as a path to a pickle file, where the\n                parameters are loaded from\n            model_cls : Subclass of :class:`AbstractTorchScriptNetwork`\n                the class implementing the model to train\n            n_epochs : int or None\n                the number of epochs to train, if None: can be specified later\n                during actual training\n            name : str or None\n                the Experiment's name\n            save_path : str or None\n                the path to save the results and checkpoints to.\n                if None: Current working directory will be used\n            key_mapping : dict\n                mapping between data_dict and 
model inputs (necessary for\n                prediction with :class:`Predictor`-API), if no keymapping is\n                given, a default key_mapping of {\"x\": \"data\"} will be used here\n            val_score_key : str or None\n                key defining which metric to use for validation (determining\n                best model and scheduling lr); if None: No validation-based\n                operations will be done (model might still get validated,\n                but validation metrics can only be logged and not used further)\n            optim_builder : function\n                Function returning a dict of backend-specific optimizers.\n                defaults to :func:`create_optims_default_pytorch`\n            checkpoint_freq : int\n                frequency of saving checkpoints (1 denotes saving every epoch,\n                2 denotes saving every second epoch etc.); default: 1\n            trainer_cls : subclass of :class:`TorchScriptNetworkTrainer`\n                the trainer class to use for training the model, defaults to\n                :class:`TorchScriptNetworkTrainer`\n            **kwargs :\n                additional keyword arguments\n\n            \"\"\"\n            super().__init__(params=params, model_cls=model_cls,\n                             n_epochs=n_epochs, name=name, save_path=save_path,\n                             key_mapping=key_mapping,\n                             val_score_key=val_score_key,\n                             optim_builder=optim_builder,\n                             checkpoint_freq=checkpoint_freq,\n                             trainer_cls=trainer_cls,\n                             **kwargs)\n            \n\nTesting it\n----------\n\nNow that we finished the implementation of the backend (which is the\noutermost wrapper; Congratulations!), we can just test it. We'll use a\nvery simple network and test it with dummy data. 
We also only test the\n``run`` and ``test`` functionality of our experiment, since everything\nelse is just used for setting up the internal state or a composition of\nthese two methods and is already tested. Now, let's define our\ndataset, instantiate it three times (for training, validation and\ntesting) and wrap each of them into a ``DataManager``:\n\n.. code:: ipython3\n\n    import numpy as np\n    \n    from delira.data_loading import AbstractDataset\n    from delira.data_loading import DataManager\n    \n    \n    class DummyDataset(AbstractDataset):\n        def __init__(self, length):\n            super().__init__(None, None)\n            self.length = length\n    \n        def __getitem__(self, index):\n            return {\"data\": np.random.rand(32),\n                    \"label\": np.random.randint(0, 1, 1)}\n    \n        def __len__(self):\n            return self.length\n    \n        def get_sample_from_index(self, index):\n            return self.__getitem__(index)\n        \n    dset_train = DummyDataset(500)\n    dset_val = DummyDataset(50)\n    dset_test = DummyDataset(10)\n    \n    # training, validation and testing with\n    # a batchsize of 16, 1 loading thread and no transformations.\n    dmgr_train = DataManager(dset_train, 16, 1, None)\n    dmgr_val = DataManager(dset_val, 16, 1, None)\n    dmgr_test = DataManager(dset_test, 16, 1, None)\n\nNow that we have created three datasets, we need to define our small\ndummy network. We do this by subclassing\n``delira.models.AbstractTorchScriptNetwork`` (which is exactly the\nimplementation given above, but we need to use the internal one, because\nthere are some typechecks against this one).\n\n.. 
code:: ipython3\n\n    from delira.models import AbstractTorchScriptNetwork\n    import torch\n    \n    \n    class DummyNetworkTorchScript(AbstractTorchScriptNetwork):\n        __constants__ = [\"module\"]\n    \n        def __init__(self):\n            super().__init__()\n            self.module = self._build_model(32, 1)\n    \n        @torch.jit.script_method\n        def forward(self, x):\n            return {\"pred\": self.module(x)}\n    \n        @staticmethod\n        def prepare_batch(batch_dict, input_device, output_device):\n            return {\"data\": torch.from_numpy(batch_dict[\"data\"]\n                                             ).to(input_device,\n                                                  torch.float),\n                    \"label\": torch.from_numpy(batch_dict[\"label\"]\n                                              ).to(output_device,\n                                                   torch.float)}\n    \n        @staticmethod\n        def closure(model: AbstractTorchScriptNetwork, data_dict: dict,\n                    optimizers: dict, losses={}, metrics={},\n                    fold=0, **kwargs):\n            \"\"\"\n            closure method to do a single backpropagation step\n    \n    \n            Parameters\n            ----------\n            model : \n                trainable model\n            data_dict : dict\n                dictionary containing the data\n            optimizers : dict\n                dictionary of optimizers to optimize model's parameters\n            losses : dict\n                dict holding the losses to calculate errors\n                (gradients from different losses will be accumulated)\n            metrics : dict\n                dict holding the metrics to calculate\n            fold : int\n                Current Fold in Crossvalidation (default: 0)\n            **kwargs:\n                additional keyword arguments\n    \n            Returns\n            -------\n            dict\n     
           Metric values (with same keys as input dict metrics)\n            dict\n                Loss values (with same keys as input dict losses)\n            list\n                Arbitrary number of predictions as torch.Tensor\n    \n            Raises\n            ------\n            AssertionError\n                if optimizers or losses are empty or the optimizers are not\n                specified\n    \n            \"\"\"\n    \n            assert (optimizers and losses) or not optimizers, \\\n                \"Criterion dict cannot be emtpy, if optimizers are passed\"\n    \n            loss_vals = {}\n            metric_vals = {}\n            total_loss = 0\n    \n            # choose suitable context manager:\n            if optimizers:\n                context_man = torch.enable_grad\n    \n            else:\n                context_man = torch.no_grad\n    \n            with context_man():\n    \n                inputs = data_dict.pop(\"data\")\n                preds = model(inputs)\n    \n                if data_dict:\n    \n                    for key, crit_fn in losses.items():\n                        _loss_val = crit_fn(preds[\"pred\"], *data_dict.values())\n                        loss_vals[key] = _loss_val.item()\n                        total_loss += _loss_val\n    \n                    with torch.no_grad():\n                        for key, metric_fn in metrics.items():\n                            metric_vals[key] = metric_fn(\n                                preds[\"pred\"], *data_dict.values()).item()\n    \n            if optimizers:\n                optimizers['default'].zero_grad()\n                # perform loss scaling via apex if half precision is enabled\n                with optimizers[\"default\"].scale_loss(total_loss) as scaled_loss:\n                    scaled_loss.backward()\n                optimizers['default'].step()\n    \n            else:\n    \n                # add prefix \"val\" in validation mode\n                
eval_loss_vals, eval_metrics_vals = {}, {}\n                for key in loss_vals.keys():\n                    eval_loss_vals[\"val_\" + str(key)] = loss_vals[key]\n    \n                for key in metric_vals:\n                    eval_metrics_vals[\"val_\" + str(key)] = metric_vals[key]\n    \n                loss_vals = eval_loss_vals\n                metric_vals = eval_metrics_vals\n    \n            return metric_vals, loss_vals, {k: v.detach()\n                                            for k, v in preds.items()}\n    \n        @staticmethod\n        def _build_model(in_channels, n_outputs):\n            return torch.nn.Sequential(\n                torch.nn.Linear(in_channels, 64),\n                torch.nn.ReLU(),\n                torch.nn.Linear(64, n_outputs)\n            )\n\nNow that we have defined our model, let's test whether we can really\nforward some tensors through it. We will just use some random\n``torch.Tensors`` (created by ``torch.rand``). Since our model accepts\n1d inputs of length 32, we need to pass 2d tensors to it (the additional\ndimension is the batch-dimension).\n\n.. code:: ipython3\n\n    input_tensor_single = torch.rand(1, 32) # use a single-sample batch (batchsize=1) here\n    input_tensor_batched = torch.rand(4, 32) # use a batch with batchsize 4 here\n    \n    # create model instance\n    model = DummyNetworkTorchScript()\n    \n    outputs = {\"single\": model(input_tensor_single)[\"pred\"], \"batched\": model(input_tensor_batched)[\"pred\"]}\n    outputs\n\n\n\n\n.. parsed-literal::\n\n    {'single': tensor([[-0.1934]], grad_fn=<DifferentiableGraphBackward>),\n     'batched': tensor([[-0.0525],\n             [-0.0884],\n             [-0.1492],\n             [-0.0431]], grad_fn=<DifferentiableGraphBackward>)}\n\n\n\n.. 
code:: ipython3\n\n    from sklearn.metrics import mean_absolute_error\n    from delira.training.callbacks import ReduceLROnPlateauCallbackPyTorch\n    from delira.training import Parameters\n    params = Parameters(fixed_params={\n                        \"model\": {},\n                        \"training\": {\n                            \"losses\": {\"CE\": torch.nn.BCEWithLogitsLoss()},\n                            \"optimizer_cls\": torch.optim.Adam,\n                            \"optimizer_params\": {\"lr\": 1e-3},\n                            \"num_epochs\": 2,\n                            \"val_metrics\": {\"mae\": mean_absolute_error},\n                            \"lr_sched_cls\": ReduceLROnPlateauCallbackPyTorch,\n                            \"lr_sched_params\": {\"mode\": \"min\"}\n                        }\n                    }\n              )\n    \n    from delira.training import TorchScriptExperiment\n    \n    exp = TorchScriptExperiment(params, DummyNetworkTorchScript,\n                                key_mapping={\"x\": \"data\"},\n                                val_score_key=\"mae\",\n                                val_score_mode=\"min\")\n    \n    trained_model = exp.run(dmgr_train, dmgr_val)\n    exp.test(trained_model, dmgr_test, params.nested_get(\"val_metrics\"))\n\nCongratulations! You have implemented your first fully-workable\n``delira`` backend. It wasn't that hard, was it?\n\nBefore you start implementing backends for all the other frameworks out\nthere, let me just give you some advice:\n\n-  You should test everything you implement or extend\n\n-  Make sure to keep your backend's specification in mind\n\n-  Always follow the API of already existing backends. 
If this is not\n   possible, test this extensively\n\n-  If you extend another backend (like we did here by extending the\n   ``PyTorch`` backend for ``TorchScript``), make sure that the\n   \"base backend\" is always installed (ideally they can only be\n   installed together)\n\n-  If you have questions regarding the implementation, don't hesitate to\n   contact us.\n"
  },
  {
    "path": "docs/gan_pytorch.rst",
    "content": "\nGenerative Adversarial Nets with Delira - A very short introduction\n===================================================================\n\n*Author: Justus Schock*\n\n*Date: 04.12.2018*\n\nThis Example shows how to set up a basic GAN PyTorch experiment and\nVisdom Logging Environment.\n\nHyperParameters\n---------------\n\nLet's first setup the essential hyperparameters. We will use\n``delira``'s ``Parameters``-class for this:\n\n.. code:: ipython3\n\n    logger = None\n    import torch\n    from delira.training import Parameters\n    params = Parameters(fixed_params={\n        \"model\": {\n            \"n_channels\": 1, \n            \"noise_length\": 10\n        },\n        \"training\": {\n            \"batch_size\": 64, # batchsize to use\n            \"num_epochs\": 10, # number of epochs to train\n            \"optimizer_cls\": torch.optim.Adam, # optimization algorithm to use\n            \"optimizer_params\": {'lr': 1e-3}, # initialization parameters for this algorithm\n            \"losses\": {\"L1\": torch.nn.L1Loss()}, # the loss function\n            \"lr_sched_cls\": None,  # the learning rate scheduling algorithm to use\n            \"lr_sched_params\": {}, # the corresponding initialization parameters\n            \"metrics\": {} # and some evaluation metrics\n        }\n    }) \n\nSince we specified ``torch.nn.L1Loss`` as criterion and\n``torch.nn.MSELoss`` as metric, they will be both calculated for each\nbatch, but only the criterion will be used for backpropagation. Since we\nhave a simple generative task, this should be sufficient. We will train\nour network with a batchsize of 64 by using ``Adam`` as optimizer of\nchoice.\n\nLogging and Visualization\n-------------------------\n\nTo get a visualization of our results, we should monitor them somehow.\nFor logging we will use ``Visdom``. To start a visdom server you need to\nexecute the following command inside an environment which has visdom\ninstalled:\n\n.. 
code:: shell\n\n    visdom -port=9999\n\nThis will start a visdom server on port 9999 of your machine and now we\ncan start to configure our logging environment. To view your results you\ncan open http://localhost:9999 in your browser.\n\n.. code:: ipython3\n\n    from trixi.logger import PytorchVisdomLogger\n    from delira.logging import TrixiHandler\n    import logging\n    \n    logger_kwargs = {\n        'name': 'GANExampleLogger', # name of our logging environment\n        'port': 9999 # port on which our visdom server is alive\n    }\n    \n    logger_cls = PytorchVisdomLogger\n    \n    # configure logging module (and root logger)\n    logging.basicConfig(level=logging.INFO,\n                        handlers=[TrixiHandler(logger_cls, **logger_kwargs)])\n    \n    \n    # derive logger from root logger\n    # (don't do `logger = logging.Logger(\"...\")` since this will create a new\n    # logger which is unrelated to the root logger)\n    logger = logging.getLogger(\"Test Logger\")\n    \n\nSince a single visdom server can run multiple environments, we need to\nspecify a (unique) name for our environment and need to tell the logger\non which port it can find the visdom server.\n\nData Preparation\n----------------\n\nLoading\n~~~~~~~\n\nNext we will create a small train and validation set (based on\n``torchvision`` MNIST):\n\n.. 
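A short aside on the comment in the logging setup above: handlers passed to ``logging.basicConfig`` are attached to the root logger, and only loggers obtained via ``logging.getLogger`` propagate their records up to it. A minimal, stdlib-only sketch of the difference (the handler class here is just for demonstration):

```python
import logging

received = []

class ListHandler(logging.Handler):
    # collect messages in a list instead of sending them anywhere
    def emit(self, record):
        received.append(record.getMessage())

root = logging.getLogger()
root.setLevel(logging.INFO)
handler = ListHandler()
root.addHandler(handler)

connected = logging.getLogger("Test Logger")  # child of the root logger
detached = logging.Logger("Test Logger")      # NOT part of the logger hierarchy

connected.info("visible to the root handlers")
detached.info("never reaches the root handlers")

print(received)  # ['visible to the root handlers']
root.removeHandler(handler)
```

This is exactly why the snippet above derives its logger from the root logger instead of instantiating ``logging.Logger`` directly.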
code:: ipython3\n\n    from delira.data_loading import TorchvisionClassificationDataset\n    \n    dataset_train = TorchvisionClassificationDataset(\"mnist\", # which dataset to use\n                                                     train=True, # use trainset\n                                                     img_shape=(224, 224) # resample to 224 x 224 pixels\n                                                    )\n    dataset_val = TorchvisionClassificationDataset(\"mnist\", \n                                                   train=False,\n                                                   img_shape=(224, 224)\n                                                  )\n\nAugmentation\n~~~~~~~~~~~~\n\nFor Data-Augmentation we will apply a few transformations:\n\n.. code:: ipython3\n\n    from batchgenerators.transforms import RandomCropTransform, \\\n                                            ContrastAugmentationTransform, Compose\n    from batchgenerators.transforms.spatial_transforms import ResizeTransform\n    from batchgenerators.transforms.sample_normalization_transforms import MeanStdNormalizationTransform\n    \n    transforms = Compose([\n        RandomCropTransform(200), # Perform Random Crops of Size 200 x 200 pixels\n        ResizeTransform(224), # Resample these crops back to 224 x 224 pixels\n        ContrastAugmentationTransform(), # randomly adjust contrast\n        MeanStdNormalizationTransform(mean=[0.5], std=[0.5])]) \n    \n    \n\nWith these transformations we can now wrap our datasets into\ndatamanagers:\n\n.. 
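``Compose`` simply chains the given transforms, feeding each transform's output dict into the next one. The core idea in isolation (an illustrative stand-in, not the actual ``batchgenerators`` implementation):

```python
class Compose:
    # apply a sequence of transforms in order (illustrative stand-in)
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, **data_dict):
        for transform in self.transforms:
            data_dict = transform(**data_dict)
        return data_dict

# toy transforms operating on a dict with a "data" key, like batchgenerators does
scale = lambda data, **kw: {"data": [x * 2 for x in data], **kw}
shift = lambda data, **kw: {"data": [x + 1 for x in data], **kw}

pipeline = Compose([scale, shift])
print(pipeline(data=[1, 2, 3]))  # {'data': [3, 5, 7]}
```

The real transforms operate on batched numpy arrays under keys like ``data`` and ``label``, but the chaining mechanism is the same.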
code:: ipython3\n\n    from delira.data_loading import DataManager, SequentialSampler, RandomSampler\n    \n    manager_train = DataManager(dataset_train, params.nested_get(\"batch_size\"),\n                                    transforms=transforms,\n                                    sampler_cls=RandomSampler,\n                                    n_process_augmentation=4)\n    \n    manager_val = DataManager(dataset_val, params.nested_get(\"batch_size\"),\n                                  transforms=transforms,\n                                  sampler_cls=SequentialSampler,\n                                  n_process_augmentation=4)\n    \n\nTraining\n--------\n\nAfter we have done that, we can finally specify our experiment and run\nit. We will therefore use the already implemented\n``GenerativeAdversarialNetworkBasePyTorch`` which is basically a vanilla\nDCGAN:\n\n.. code:: ipython3\n\n    import warnings\n    warnings.simplefilter(\"ignore\", UserWarning) # ignore UserWarnings raised by dependency code\n    warnings.simplefilter(\"ignore\", FutureWarning) # ignore FutureWarnings raised by dependency code\n    \n    \n    from delira.training import PyTorchExperiment\n    from delira.training.train_utils import create_optims_gan_default_pytorch\n    from delira.models.gan import GenerativeAdversarialNetworkBasePyTorch\n    \n    if logger is not None:\n        logger.info(\"Init Experiment\")\n    experiment = PyTorchExperiment(params, GenerativeAdversarialNetworkBasePyTorch,\n                                   name=\"GANExample\",\n                                   save_path=\"./tmp/delira_Experiments\",\n                                   optim_builder=create_optims_gan_default_pytorch,\n                                   gpu_ids=[0])\n    experiment.save()\n    \n    model = experiment.run(manager_train, manager_val)\n\nCongratulations, you have now trained your first Generative Adversarial\nModel using ``delira``.\n\nSee Also\n--------\n\nFor a more 
detailed explanation have a look at:\n\n* `the introduction tutorial <tutorial_delira.ipynb>`__\n* `the 2d segmentation example <segmentation_2d_pytorch.ipynb>`__\n* `the 3d segmentation example <segmentation_3d_pytorch.ipynb>`__\n* `the classification example <classification_pytorch.ipynb>`__\n"
  },
  {
    "path": "docs/getting_started.rst",
    "content": "Getting started\n===============\n\nBackends\n--------\n\nBefore installing ``delira``, you have to choose a suitable backend.\n``delira`` handles backends as optional dependencies and tries to escape all uses of a not-installed backend.\n\nThe currently supported backends are:\n\n* `torch <https://pytorch.org>`_ (recommended, since it is the most tested backend): Suffix ``torch``\n\n  .. note::\n    ``delira`` supports mixed-precision training via `apex <https://github.com/NVIDIA/apex>`_, but ``apex`` must be installed separately\n   \n* `torchscript <https://pytorch.org/docs/stable/jit.html>`_ : Suffix ``torchscript``\n\n  .. note::\n    ``delira`` with ``torchscript`` backend dies currently not support Multi-GPU training.\n    \n* `tensorflow eager execution <https://tensorflow.org>`_: Suffix ``tensorflow``\n\n  .. note::\n    ``delira`` with ``tensorflow eager`` backend dies currently not support Multi-GPU training.\n\n* `tensorflow graph mode <https://tensorflow.org>`_: Suffix ``tensorflow``\n\n  .. note::\n    ``delira`` with ``tensorflow graph`` backend dies currently not support Multi-GPU training.\n\n* `chainer <https://chainer.org>`_: Suffix ``chainer``\n\n* `scikit-learn <https://scikit-learn.org/stable/>`_: No Suffix\n\n* None: No Suffix\n\n* All (installs all registered backends and their dependencies; not recommended, since this will install many large packages): Suffix ``full``\n\n.. note::\n  Depending on the backend, some functionalities may not be available for you. If you want to ensure, you can use each functionality, please use the ``full`` option, since it installs all backends\n  \n.. 
note:: \n  If you want to add a backend like `CNTK <https://www.microsoft.com/en-us/cognitive-toolkit/>`_, `MXNET <https://mxnet.apache.org/>`_ or something similar, please open an issue for that and we will guide you during that process (don't worry, it is not much effort at all).\n\nInstallation\n------------\n\n=================== =================================== ================================================================================================= ======================================================================================================================\nBackend             Binary Installation                 Source Installation                                                                               Notes\n=================== =================================== ================================================================================================= ======================================================================================================================\nNone                ``pip install delira``              ``pip install git+https://github.com/delira-dev/delira.git``                                      Training not possible if backend is not installed separately\n`torch`_            ``pip install delira[torch]``       ``git clone https://github.com/delira-dev/delira.git && cd delira && pip install .[torch]``       ``delira`` with ``torch`` backend supports mixed-precision training via `NVIDIA/apex`_ (must be installed separately).\n`torchscript`_      ``pip install delira[torchscript]`` ``git clone https://github.com/delira-dev/delira.git && cd delira && pip install .[torchscript]`` The ``torchscript`` backend currently supports only single-GPU-training\n`tensorflow eager`_ ``pip install delira[tensorflow]``  ``git clone https://github.com/delira-dev/delira.git && cd delira && pip install .[tensorflow]``  the ``tensorflow`` backend is still very experimental and lacks some `features`_\n`tensorflow graph`_ 
``pip install delira[tensorflow]``  ``git clone https://github.com/delira-dev/delira.git && cd delira && pip install .[tensorflow]``  the ``tensorflow`` backend is still very experimental and lacks some `features`_\n`scikit-learn`_     ``pip install delira``              ``pip install git+https://github.com/delira-dev/delira.git``                                      /\n`chainer`_          ``pip install delira[chainer]``     ``git clone https://github.com/delira-dev/delira.git && cd delira && pip install .[chainer]``     /\nFull                ``pip install delira[full]``        ``git clone https://github.com/delira-dev/delira.git && cd delira && pip install .[full]``        All backends will be installed\n=================== =================================== ================================================================================================= ======================================================================================================================\n\n.. _torch: https://pytorch.org\n.. _NVIDIA/apex: https://github.com/NVIDIA/apex.git\n.. _torchscript: https://pytorch.org/docs/stable/jit.html\n.. _tensorflow eager: https://www.tensorflow.org/\n.. _features: https://github.com/delira-dev/delira/issues/47\n.. _tensorflow graph: https://www.tensorflow.org/\n.. _scikit-learn: https://scikit-learn.org/stable/\n.. _chainer: https://chainer.org/"
  },
  {
    "path": "docs/index.rst",
    "content": ".. delira documentation master file, created by\n   sphinx-quickstart on Sat Dec  1 20:56:35 2018.\n   You can adapt this file completely to your liking, but it should at least\n   contain the root `toctree` directive.\n\n=====================================================================\ndelira - A Backend Agnostic High Level Deep Learning Library\n=====================================================================\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Getting Started\n\n   getting_started\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Tutorials:\n\n   tutorial_delira\n   classification_pytorch\n   gan_pytorch\n   segmentation_2d_pytorch\n   segmentation_3d_pytorch\n   custom_backend\n\n\n.. toctree::\n   :maxdepth: 10\n   :titlesonly:\n   :caption: API Documentation:\n\n   _api/_build/modules\n\n   GitHub <https://github.com/delira-dev/delira>\n\n\n\nIndices and tables\n==================\n\n* :ref:`genindex`\n* :ref:`modindex`\n* :ref:`search`\n"
  },
  {
    "path": "docs/requirements.txt",
    "content": "sphinx==1.8.4\nsphinx-rtd-theme\n"
  },
  {
    "path": "docs/segmentation_2d_pytorch.rst",
    "content": "\nSegmentation in 2D using U-Nets with Delira - A very short introduction\n=======================================================================\n\n*Author: Justus Schock, Alexander Moriz*\n\n*Date: 17.12.2018*\n\nThis Example shows how use the U-Net implementation in Delira with\nPyTorch.\n\nLet's first setup the essential hyperparameters. We will use\n``delira``'s ``Parameters``-class for this:\n\n.. code:: ipython3\n\n    logger = None\n    import torch\n    from delira.training import Parameters\n    params = Parameters(fixed_params={\n        \"model\": {\n            \"in_channels\": 1, \n            \"num_classes\": 4\n        },\n        \"training\": {\n            \"batch_size\": 64, # batchsize to use\n            \"num_epochs\": 10, # number of epochs to train\n            \"optimizer_cls\": torch.optim.Adam, # optimization algorithm to use\n            \"optimizer_params\": {'lr': 1e-3}, # initialization parameters for this algorithm\n            \"losses\": {\"CE\": torch.nn.CrossEntropyLoss()}, # the loss function\n            \"lr_sched_cls\": None,  # the learning rate scheduling algorithm to use\n            \"lr_sched_params\": {}, # the corresponding initialization parameters\n            \"metrics\": {} # and some evaluation metrics\n        }\n    }) \n\nSince we did not specify any metric, only the ``CrossEntropyLoss`` will\nbe calculated for each batch. Since we have a classification task, this\nshould be sufficient. We will train our network with a batchsize of 64\nby using ``Adam`` as optimizer of choice.\n\nLogging and Visualization\n-------------------------\n\nTo get a visualization of our results, we should monitor them somehow.\nFor logging we will use ``Visdom``. To start a visdom server you need to\nexecute the following command inside an environment which has visdom\ninstalled:\n\n.. 
code:: shell\n\n    visdom -port=9999\n\nThis will start a visdom server on port 9999 of your machine and now we\ncan start to configure our logging environment. To view your results you\ncan open http://localhost:9999 in your browser.\n\n.. code:: ipython3\n\n    from trixi.logger import PytorchVisdomLogger\n    from delira.logging import TrixiHandler\n    import logging\n    \n    logger_kwargs = {\n        'name': 'ClassificationExampleLogger', # name of our logging environment\n        'port': 9999 # port on which our visdom server is alive\n    }\n    \n    logger_cls = PytorchVisdomLogger\n    \n    # configure logging module (and root logger)\n    logging.basicConfig(level=logging.INFO,\n                        handlers=[TrixiHandler(logger_cls, **logger_kwargs)])\n    \n    \n    # derive logger from root logger\n    # (don't do `logger = logging.Logger(\"...\")` since this will create a new\n    # logger which is unrelated to the root logger)\n    logger = logging.getLogger(\"Test Logger\")\n    \n\nSince a single visdom server can run multiple environments, we need to\nspecify a (unique) name for our environment and need to tell the logger\non which port it can find the visdom server.\n\nData Preparation\n----------------\n\nLoading\n~~~~~~~\n\nNext we will create a small train and validation set (in this case they\nwill be the same to show the overfitting capability of the UNet).\n\nOur data is a brain MR-image thankfully provided by the\n`FSL <https://fsl.fmrib.ox.ac.uk/fsl/fslwiki>`__ in their\n`introduction <http://www.fmrib.ox.ac.uk/primers/intro_primer/ExBox3/IntroBox3.html>`__.\n\nWe first download the data and extract the T1 image and the\ncorresponding segmentation:\n\n.. 
code:: ipython3\n\n    from io import BytesIO\n    from zipfile import ZipFile\n    from urllib.request import urlopen\n    \n    resp = urlopen(\"http://www.fmrib.ox.ac.uk/primers/intro_primer/ExBox3/ExBox3.zip\")\n    zipfile = ZipFile(BytesIO(resp.read()))\n    #zipfile_list = zipfile.namelist()\n    #print(zipfile_list)\n    img_file = zipfile.extract(\"ExBox3/T1_brain.nii.gz\")\n    mask_file = zipfile.extract(\"ExBox3/T1_brain_seg.nii.gz\")\n\nNow, we load the image and the mask (they are both 3D), convert them to\na 32-bit floating point numpy array and ensure they have the same shape\n(i.e. that for each voxel in the image, there is a voxel in the mask):\n\n.. code:: ipython3\n\n    import SimpleITK as sitk\n    import numpy as np\n    \n    # load image and mask\n    img = sitk.GetArrayFromImage(sitk.ReadImage(img_file))\n    img = img.astype(np.float32)\n    mask = sitk.GetArrayFromImage(sitk.ReadImage(mask_file))\n    mask = mask.astype(np.float32)\n    \n    assert mask.shape == img.shape\n    print(img.shape)\n\nBy querying the unique values in the mask, we get the following:\n\n.. code:: ipython3\n\n    np.unique(mask)\n\nThis means there are 4 classes (background and 3 types of tissue) in\nour sample.\n\nSince we want to do a 2D segmentation, we extract a single slice out of\nthe image and the mask (we choose slice 100 here) and plot it:\n\n.. code:: ipython3\n\n    import matplotlib.pyplot as plt\n    \n    # load single slice\n    img_slice = img[:, :, 100]\n    mask_slice = mask[:, :, 100]\n    \n    # plot slices\n    plt.figure(1, figsize=(15,10))\n    plt.subplot(121)\n    plt.imshow(img_slice, cmap=\"gray\")\n    plt.colorbar(fraction=0.046, pad=0.04)\n    plt.subplot(122)\n    plt.imshow(mask_slice, cmap=\"gray\")\n    plt.colorbar(fraction=0.046, pad=0.04)\n    plt.show()\n    \n\nTo load the data, we have to use a ``Dataset``. 
The following defines a\nvery simple dataset, accepting an image slice, a mask slice and the\nnumber of samples. It always returns the same sample until\n``num_samples`` samples have been returned.\n\n.. code:: ipython3\n\n    from delira.data_loading import AbstractDataset\n    \n    class CustomDataset(AbstractDataset):\n        def __init__(self, img, mask, num_samples=1000):\n            super().__init__(None, None, None, None)\n            self.data = {\"data\": img.reshape(1, *img.shape), \"label\": mask.reshape(1, *mask.shape)}\n            self.num_samples = num_samples\n            \n        def __getitem__(self, index):\n            return self.data\n        \n        def __len__(self):\n            return self.num_samples\n\nNow, we can finally instantiate our datasets:\n\n.. code:: ipython3\n\n    dataset_train = CustomDataset(img_slice, mask_slice, num_samples=10000)\n    dataset_val = CustomDataset(img_slice, mask_slice, num_samples=1)\n\nAugmentation\n~~~~~~~~~~~~\n\nFor Data-Augmentation we will apply a few transformations:\n\n.. 
code:: ipython3\n\n    from batchgenerators.transforms import RandomCropTransform, \\\n                                            ContrastAugmentationTransform, Compose\n    from batchgenerators.transforms.spatial_transforms import ResizeTransform\n    from batchgenerators.transforms.sample_normalization_transforms import MeanStdNormalizationTransform\n    \n    transforms = Compose([\n        RandomCropTransform(150, label_key=\"label\"), # Perform Random Crops of Size 150 x 150 pixels\n        ResizeTransform(224, label_key=\"label\"), # Resample these crops back to 224 x 224 pixels\n        ContrastAugmentationTransform(), # randomly adjust contrast\n        MeanStdNormalizationTransform(mean=[img_slice.mean()], std=[img_slice.std()])]) # use concrete values since we only have one sample (have to estimate it over whole dataset otherwise)\n\nWith these transformations we can now wrap our datasets into\ndatamanagers:\n\n.. code:: ipython3\n\n    from delira.data_loading import DataManager, SequentialSampler, RandomSampler\n    \n    manager_train = DataManager(dataset_train, params.nested_get(\"batch_size\"),\n                                    transforms=transforms,\n                                    sampler_cls=RandomSampler,\n                                    n_process_augmentation=4)\n    \n    manager_val = DataManager(dataset_val, params.nested_get(\"batch_size\"),\n                                  transforms=transforms,\n                                  sampler_cls=SequentialSampler,\n                                  n_process_augmentation=4)\n\nTraining\n--------\n\nAfter we have done that, we can finally specify our experiment and run\nit. We will therefore use the already implemented ``UNet2dPyTorch``:\n\n.. 
code:: ipython3\n\n    import warnings\n    warnings.simplefilter(\"ignore\", UserWarning) # ignore UserWarnings raised by dependency code\n    warnings.simplefilter(\"ignore\", FutureWarning) # ignore FutureWarnings raised by dependency code\n    \n    \n    from delira.training import PyTorchExperiment\n    from delira.training.train_utils import create_optims_default_pytorch\n    from delira.models.segmentation import UNet2dPyTorch\n    \n    if logger is not None:\n        logger.info(\"Init Experiment\")\n    experiment = PyTorchExperiment(params, UNet2dPyTorch,\n                                   name=\"Segmentation2dExample\",\n                                   save_path=\"./tmp/delira_Experiments\",\n                                   optim_builder=create_optims_default_pytorch,\n                                   gpu_ids=[0], mixed_precision=True)\n    experiment.save()\n    \n    model = experiment.run(manager_train, manager_val)\n\n\nSee Also\n--------\n\nFor a more detailed explanation have a look at:\n\n* `the introduction tutorial <tutorial_delira.ipynb>`__\n* `the classification example <classification_pytorch.ipynb>`__\n* `the 3d segmentation example <segmentation_3d_pytorch.ipynb>`__\n* `the generative adversarial example <gan_pytorch.ipynb>`__\n"
  },
  {
    "path": "docs/segmentation_3d_pytorch.rst",
    "content": "\nSegmentation in 3D using U-Nets with Delira - A very short introduction\n=======================================================================\n\n*Author: Justus Schock, Alexander Moriz*\n\n*Date: 17.12.2018*\n\nThis Example shows how use the U-Net implementation in Delira with\nPyTorch.\n\nLet's first setup the essential hyperparameters. We will use\n``delira``'s ``Parameters``-class for this:\n\n.. code:: ipython3\n\n    logger = None\n    import torch\n    from delira.training import Parameters\n    params = Parameters(fixed_params={\n        \"model\": {\n            \"in_channels\": 1, \n            \"num_classes\": 4\n        },\n        \"training\": {\n            \"batch_size\": 64, # batchsize to use\n            \"num_epochs\": 10, # number of epochs to train\n            \"optimizer_cls\": torch.optim.Adam, # optimization algorithm to use\n            \"optimizer_params\": {'lr': 1e-3}, # initialization parameters for this algorithm\n            \"losses\": {\"CE\": torch.nn.CrossEntropyLoss()}, # the loss function\n            \"lr_sched_cls\": None,  # the learning rate scheduling algorithm to use\n            \"lr_sched_params\": {}, # the corresponding initialization parameters\n            \"metrics\": {} # and some evaluation metrics\n        }\n    }) \n\nSince we did not specify any metric, only the ``CrossEntropyLoss`` will\nbe calculated for each batch. Since we have a classification task, this\nshould be sufficient. We will train our network with a batchsize of 64\nby using ``Adam`` as optimizer of choice.\n\nLogging and Visualization\n-------------------------\n\nTo get a visualization of our results, we should monitor them somehow.\nFor logging we will use ``Visdom``. To start a visdom server you need to\nexecute the following command inside an environment which has visdom\ninstalled:\n\n.. 
code:: shell\n\n    visdom -port=9999\n\nThis will start a visdom server on port 9999 of your machine and now we\ncan start to configure our logging environment. To view your results you\ncan open http://localhost:9999 in your browser.\n\n.. code:: ipython3\n\n    from trixi.logger import PytorchVisdomLogger\n    from delira.logging import TrixiHandler\n    import logging\n    \n    logger_kwargs = {\n        'name': 'ClassificationExampleLogger', # name of our logging environment\n        'port': 9999 # port on which our visdom server is alive\n    }\n    \n    logger_cls = PytorchVisdomLogger\n    \n    # configure logging module (and root logger)\n    logging.basicConfig(level=logging.INFO,\n                        handlers=[TrixiHandler(logger_cls, **logger_kwargs)])\n    \n    \n    # derive logger from root logger\n    # (don't do `logger = logging.Logger(\"...\")` since this will create a new\n    # logger which is unrelated to the root logger)\n    logger = logging.getLogger(\"Test Logger\")\n    \n\nSince a single visdom server can run multiple environments, we need to\nspecify a (unique) name for our environment and need to tell the logger\non which port it can find the visdom server.\n\nData Preparation\n----------------\n\nLoading\n~~~~~~~\n\nNext we will create a small train and validation set (in this case they\nwill be the same to show the overfitting capability of the UNet).\n\nOur data is a brain MR-image thankfully provided by the\n`FSL <https://fsl.fmrib.ox.ac.uk/fsl/fslwiki>`__ in their\n`introduction <http://www.fmrib.ox.ac.uk/primers/intro_primer/ExBox3/IntroBox3.html>`__.\n\nWe first download the data and extract the T1 image and the\ncorresponding segmentation:\n\n.. 
code:: ipython3\n\n    from io import BytesIO\n    from zipfile import ZipFile\n    from urllib.request import urlopen\n    \n    resp = urlopen(\"http://www.fmrib.ox.ac.uk/primers/intro_primer/ExBox3/ExBox3.zip\")\n    zipfile = ZipFile(BytesIO(resp.read()))\n    #zipfile_list = zipfile.namelist()\n    #print(zipfile_list)\n    img_file = zipfile.extract(\"ExBox3/T1_brain.nii.gz\")\n    mask_file = zipfile.extract(\"ExBox3/T1_brain_seg.nii.gz\")\n\nNow, we load the image and the mask (they are both 3D), convert them to\na 32-bit floating point numpy array and ensure they have the same shape\n(i.e. that for each voxel in the image, there is a voxel in the mask):\n\n.. code:: ipython3\n\n    import SimpleITK as sitk\n    import numpy as np\n    \n    # load image and mask\n    img = sitk.GetArrayFromImage(sitk.ReadImage(img_file))\n    img = img.astype(np.float32)\n    mask = sitk.GetArrayFromImage(sitk.ReadImage(mask_file))\n    mask = mask.astype(np.float32)\n    \n    assert mask.shape == img.shape\n    print(img.shape)\n\nBy querying the unique values in the mask, we get the following:\n\n.. code:: ipython3\n\n    np.unique(mask)\n\nThis means there are 4 classes (background and 3 types of tissue) in\nour sample.\n\nTo load the data, we have to use a ``Dataset``. The following defines a\nvery simple dataset, accepting an image slice, a mask slice and the\nnumber of samples. It always returns the same sample until\n``num_samples`` samples have been returned.\n\n.. 
code:: ipython3\n\n    from delira.data_loading import AbstractDataset\n    \n    class CustomDataset(AbstractDataset):\n        def __init__(self, img, mask, num_samples=1000):\n            super().__init__(None, None, None, None)\n            self.data = {\"data\": img.reshape(1, *img.shape), \"label\": mask.reshape(1, *mask.shape)}\n            self.num_samples = num_samples\n            \n        def __getitem__(self, index):\n            return self.data\n        \n        def __len__(self):\n            return self.num_samples\n\nNow, we can finally instantiate our datasets:\n\n.. code:: ipython3\n\n    dataset_train = CustomDataset(img, mask, num_samples=10000)\n    dataset_val = CustomDataset(img, mask, num_samples=1)\n\nAugmentation\n~~~~~~~~~~~~\n\nFor Data-Augmentation we will apply a few transformations:\n\n.. code:: ipython3\n\n    from batchgenerators.transforms import ContrastAugmentationTransform, Compose\n    from batchgenerators.transforms.sample_normalization_transforms import MeanStdNormalizationTransform\n    \n    transforms = Compose([\n        ContrastAugmentationTransform(), # randomly adjust contrast\n        MeanStdNormalizationTransform(mean=[img.mean()], std=[img.std()])]) # use concrete values since we only have one sample (have to estimate it over whole dataset otherwise)\n\nWith these transformations we can now wrap our datasets into\ndatamanagers:\n\n.. 
code:: ipython3\n\n    from delira.data_loading import DataManager, SequentialSampler, RandomSampler\n    \n    manager_train = DataManager(dataset_train, params.nested_get(\"batch_size\"),\n                                    transforms=transforms,\n                                    sampler_cls=RandomSampler,\n                                    n_process_augmentation=4)\n    \n    manager_val = DataManager(dataset_val, params.nested_get(\"batch_size\"),\n                                  transforms=transforms,\n                                  sampler_cls=SequentialSampler,\n                                  n_process_augmentation=4)\n\nTraining\n--------\n\nAfter we have done that, we can finally specify our experiment and run\nit. We will therefore use the already implemented ``UNet3dPyTorch``:\n\n.. code:: ipython3\n\n    import warnings\n    warnings.simplefilter(\"ignore\", UserWarning) # ignore UserWarnings raised by dependency code\n    warnings.simplefilter(\"ignore\", FutureWarning) # ignore FutureWarnings raised by dependency code\n    \n    \n    from delira.training import PyTorchExperiment\n    from delira.training.train_utils import create_optims_default_pytorch\n    from delira.models.segmentation import UNet3dPyTorch\n    \n    if logger:\n        logger.info(\"Init Experiment\")\n    experiment = PyTorchExperiment(params, UNet3dPyTorch,\n                                   name=\"Segmentation3dExample\",\n                                   save_path=\"./tmp/delira_Experiments\",\n                                   optim_builder=create_optims_default_pytorch,\n                                   gpu_ids=[0], mixed_precision=True)\n    experiment.save()\n    \n    model = experiment.run(manager_train, manager_val)\n\nSee Also\n--------\n\nFor a more detailed explanation have a look at:\n\n* `the introduction tutorial <tutorial_delira.ipynb>`__\n* `the classification example <classification_pytorch.ipynb>`__\n* `the 2d segmentation example <segmentation_2d_pytorch.ipynb>`__\n* `the generative adversarial example <gan_pytorch.ipynb>`__\n"
  },
  {
    "path": "docs/tutorial_delira.rst",
"content": "\nDelira Introduction\n===================\n\n*Last updated: 09.05.2019*\n\nAuthors: Justus Schock, Christoph Haarburger\n\nLoading Data\n------------\n\nTo train your network you first need to load your training data (and\nprobably also your validation data). This chapter will therefore deal\nwith ``delira``'s capabilities to load your data (and apply some\naugmentation).\n\nThe Dataset\n~~~~~~~~~~~\n\nThere are mainly two ways to load your data: Lazy or non-lazy. Loading\nin a lazy way means that you load the data just in time and keep the\nused memory to a bare minimum. This has, however, the disadvantage that\nyour loading function could be a bottleneck since all postponed\noperations may have to wait until the needed data samples are loaded. In\na non-lazy way, one would preload all data to RAM before starting any\nother operations. This has the advantage that there cannot be a loading\nbottleneck during later operations. This advantage comes at the cost of\nhigher memory usage and a (possibly) huge latency at the beginning of\neach experiment. Both ways to load your data are implemented in\n``delira`` and they are named ``BaseLazyDataset``\ and\n``BaseCacheDataset``. In the following steps you will only see the\n``BaseLazyDataset`` since exchanging them is trivial. All Datasets\n(including the ones you might want to create yourself later) must be\nderived from ``delira.data_loading.AbstractDataset`` to ensure a minimum\ncommon API.\n\nThe dataset's ``__init__`` has the following signature:\n\n.. code:: python\n\n    def __init__(self, data_path, load_fn, **load_kwargs):\n\nThis means you have to pass the path to the directory containing your\ndata (``data_path``) and a function to load a single sample of your data\n(``load_fn``). To get a single sample of your dataset after creating it,\nyou can index it like this: ``dataset[0]``. Additionally you can iterate\nover your dataset just like over any other ``python`` iterator via\n\n.. 
code:: python\n\n    for sample in dataset:\n        # do your stuff here\n\nor enumerate it via\n\n.. code:: python\n\n    for idx, sample in enumerate(dataset):\n        # do your stuff here\n\nThe missing argument ``**load_kwargs`` accepts an arbitrary amount of\nadditional keyword arguments which are directly passed to your loading\nfunction.\n\nAn example of how loading your data may look is given below:\n\n.. code:: python\n\n    from delira.data_loading import BaseLazyDataset, default_load_fn_2d\n    dataset_train = BaseLazyDataset(\"/images/datasets/external/mnist/train\",\n                                    default_load_fn_2d, img_shape=(224, 224))\n\nIn this case all data lying in ``/images/datasets/external/mnist/train``\nis loaded by ``default_load_fn_2d``. The files containing the data must\nbe PNG-files, while the groundtruth is defined in TXT-files. The\n``default_load_fn_2d`` needs the additional argument ``img_shape`` which\nis passed as keyword argument via ``**load_kwargs``.\n\n    **Note:** for reproducibility we decided to use some wrapped PyTorch\n    datasets for this introduction.\n\nNow, let's just initialize our trainset:\n\n.. code:: ipython3\n\n    from delira.data_loading import TorchvisionClassificationDataset\n    dataset_train = TorchvisionClassificationDataset(\"mnist\", train=True,\n                                                     img_shape=(224, 224))\n\nGetting a single sample of your dataset with ``dataset_train[0]`` will\nproduce:\n\n.. code:: ipython3\n\n    dataset_train[0]\n\nwhich means that our data is stored in a dictionary containing the keys\n``data`` and ``label``, each of them holding the corresponding numpy\narrays. The dataloading works purely on ``numpy`` and is thus backend\nagnostic. 
It does not matter in which format or with which library you\nload/preprocess your data, but at the end it must be converted to numpy\narrays. For validation purposes another dataset could be created with the\ntest data like this:\n\n.. code:: ipython3\n\n    dataset_val = TorchvisionClassificationDataset(\"mnist\", train=False,\n                                                   img_shape=(224, 224))\n\nThe Dataloader\n~~~~~~~~~~~~~~\n\nThe Dataloader wraps your dataset to provide the ability to load whole\nbatches with an abstract interface. To create a dataloader, one would\nhave to pass the following arguments to its ``__init__``: the\npreviously created ``dataset``. Additionally, it is possible to pass the\n``batch_size`` defining the number of samples per batch, the total\nnumber of batches (``num_batches``), which will be the number of samples\nin your dataset divided by the batch size by default, a random\n``seed``\ for always getting the same behaviour of random number\ngenerators and a ```sampler`` <>`__ defining your sampling strategy.\nThis would create a dataloader for your ``dataset_train``:\n\n.. code:: ipython3\n\n    from delira.data_loading import DataLoader\n    \n    batch_size = 32\n    \n    loader_train = DataLoader(dataset_train, batch_size)\n\nSince the batch\_size has been set to 32, the loader will load 32\nsamples as one batch.\n\nEven though it would be possible to train your network with an instance\nof ``DataLoader``, ``delira`` also offers a different approach that\ncovers multithreaded data loading and augmentation:\n\nThe Datamanager\n~~~~~~~~~~~~~~~\n\nThe data manager is implemented as\n``delira.data_loading.DataManager`` and wraps a ``DataLoader``. It\nalso encapsulates augmentations. Looking at the\n``DataManager``'s signature, it becomes obvious that it accepts the\nsame arguments as the ```DataLoader`` <#The-Dataloader>`__. You can\neither pass a ``dataset`` or a combination of path, dataset class and\nload function. 
Additionally, you can pass a custom dataloader class if\nnecessary and a sampler class to choose a sampling algorithm.\n\nThe parameter ``transforms`` accepts augmentation transformations as\nimplemented in ``batchgenerators``. Augmentation is applied on the fly\nusing ``n_process_augmentation`` threads.\n\nAll in all the DataManager is the recommended way to generate batches\nfrom your dataset.\n\nThe following example shows how to create a data manager instance:\n\n.. code:: ipython3\n\n    from delira.data_loading import DataManager\n    from batchgenerators.transforms.abstract_transforms import Compose\n    from batchgenerators.transforms.sample_normalization_transforms import MeanStdNormalizationTransform\n    \n    batchsize = 64\n    transforms = Compose([MeanStdNormalizationTransform(mean=1*[0], std=1*[1])])\n    \n    data_manager_train = DataManager(dataset_train,  # dataset to use\n                                        batchsize,  # batchsize\n                                        n_process_augmentation=1,  # number of augmentation processes\n                                        transforms=transforms)  # augmentation transforms\n    \n\nThe approach to initialize a DataManager from a datapath takes more\narguments since, in contrast to initialization from a dataset, it needs all\nthe arguments which are necessary to internally create a dataset.\n\nSince we want to validate our model we have to create a second manager\ncontaining our ``dataset_val``:\n\n.. code:: ipython3\n\n    data_manager_val = DataManager(dataset_val,\n                                        batchsize, \n                                        n_process_augmentation=1, \n                                        transforms=transforms)\n\nThat's it - we just finished loading our data!\n\nIterating over a DataManager is possible in simple loops:\n\n.. 
code:: ipython3\n\n    from tqdm.auto import tqdm # utility for progress bars\n    \n    # create actual batch generator from DataManager\n    batchgen = data_manager_val.get_batchgen()\n    \n    for data in tqdm(batchgen):\n        pass # here you can access the data of the current batch\n\nSampler\n~~~~~~~\n\nIn the previous section, samplers have already been mentioned but not yet\nexplained. A sampler implements an algorithm defining how a batch should be\nassembled from single samples in a dataset. ``delira`` provides the\nfollowing sampler classes in its subpackage\n``delira.data_loading.sampler``:\n\n-  ``AbstractSampler``\n-  ``SequentialSampler``\n-  ``PrevalenceSequentialSampler``\n-  ``RandomSampler``\n-  ``PrevalenceRandomSampler``\n-  ``WeightedRandomSampler``\n-  ``LambdaSampler``\n\nThe ``AbstractSampler`` implements no sampling algorithm but defines a\nsampling API and thus all custom samplers must inherit from this class.\nThe ``SequentialSampler`` builds batches by just iterating over the\nsamples' indices in a sequential way. Analogously, the\n``RandomSampler`` builds batches by randomly drawing the samples'\nindices with replacement. If the class each sample belongs to is known\nfor each sample at the beginning, the ``PrevalenceSequentialSampler``\nand the ``PrevalenceRandomSampler`` perform a per-class sequential or\nrandom sampling and build each batch with exactly the same number of\nsamples from each class. 
The ``WeightedRandomSampler``\ accepts custom\nweights to give specific samples a higher probability during random\nsampling than others.\n\nThe ``LambdaSampler`` is a wrapper for a custom sampling function, which\ncan be passed to the wrapper during its initialization, to ensure API\nconformity.\n\nIt can be passed to the DataLoader or DataManager as class (argument\n``sampler_cls``) or as instance (argument ``sampler``).\n\nModels\n------\n\nSince the purpose of this framework is to use machine learning\nalgorithms, there has to be a way to define them. Defining models is\nstraightforward. ``delira`` provides a class\n``delira.models.AbstractNetwork``. *All models must inherit from this\nclass*.\n\nTo inherit from this class, four functions must be implemented in the\nsubclass:\n\n-  ``__init__``\n-  ``closure``\n-  ``prepare_batch``\n-  ``__call__``\n\n``__init__``\n~~~~~~~~~~~~\n\nThe ``__init__``\ function is the class's constructor. In our case it\nbuilds the entire model (maybe using some helper functions). If writing\nyour own custom model, you have to override this method.\n\n    **Note:** If you want the best experience for saving your model and\n    completely recreating it during the loading process you need to take\n    care of a few things: \* if using ``torchvision.models`` to build\n    your model, always import it with\n    ``from torchvision import models as t_models`` \* register all\n    arguments in your custom ``__init__`` in the abstract class. An\n    init\_prototype could look like this:\n\n.. 
code:: python\n\n    def __init__(self, in_channels: int, n_outputs: int, **kwargs):\n        \"\"\"\n\n        Parameters\n        ----------\n        in_channels: int\n            number of input_channels\n        n_outputs: int\n            number of outputs (usually same as number of classes)\n        \"\"\"\n        # register params by passing them as kwargs to parent class __init__\n        # only params registered like this will be saved!\n        super().__init__(in_channels=in_channels,\n                         n_outputs=n_outputs,\n                         **kwargs)\n\n``closure``\n~~~~~~~~~~~\n\nThe ``closure``\ function defines one batch iteration to train the\nnetwork. This function is needed for the framework to provide a generic\ntrainer function which works with all kinds of networks and loss\nfunctions.\n\nThe closure function must implement all steps from forwarding, over loss\ncalculation, metric calculation, logging (for which\n``delira.logging_handlers`` provides some extensions for Python's logging\nmodule), to the actual backpropagation.\n\nDuring evaluation it is called with an empty optimizer-dict and should thus\nwork with optional optimizers.\n\n``prepare_batch``\n~~~~~~~~~~~~~~~~~\n\nThe ``prepare_batch``\ function defines the transformation from loaded\ndata to match the network's input and output shape and pushes everything\nto the right device.\n\nAbstract Networks for specific Backends\n---------------------------------------\n\nPyTorch\n~~~~~~~\n\nAt the time of writing, PyTorch is the only backend which is supported,\nbut other backends are planned. 
In PyTorch every network should be\nimplemented as a subclass of ``torch.nn.Module``, which also provides a\n``__call__`` method.\n\nThis results in slightly different requirements for PyTorch networks:\ninstead of implementing a ``__call__`` method, we simply call the\n``torch.nn.Module.__call__`` and therefore have to implement the\n``forward`` method, which defines the module's behaviour and is\ninternally called by ``torch.nn.Module.__call__`` (among other stuff).\nTo give a default behaviour suiting most cases and not have to care\nabout internals, ``delira`` provides the ``AbstractPyTorchNetwork``\nwhich is a more specific case of the ``AbstractNetwork`` for PyTorch\nmodules.\n\n``forward``\n^^^^^^^^^^^\n\nThe ``forward`` function defines what has to be done to forward your\ninput through your network and must return a dictionary. Assuming your\nnetwork has three convolutional layers stored in ``self.conv1``,\n``self.conv2`` and ``self.conv3`` and a ReLU stored in ``self.relu``, a\nsimple ``forward`` function could look like this:\n\n.. code:: python\n\n    def forward(self, input_batch: torch.Tensor):\n        out_1 = self.relu(self.conv1(input_batch))\n        out_2 = self.relu(self.conv2(out_1))\n        out_3 = self.conv3(out_2)\n        \n        return {\"pred\": out_3}\n\n``prepare_batch``\n^^^^^^^^^^^^^^^^^\n\nThe default ``prepare_batch`` function for PyTorch networks looks like\nthis:\n\n.. 
code:: python\n\n        @staticmethod\n        def prepare_batch(batch: dict, input_device, output_device):\n            \"\"\"\n            Helper Function to prepare Network Inputs and Labels (convert them to\n            correct type and shape and push them to correct devices)\n\n            Parameters\n            ----------\n            batch : dict\n                dictionary containing all the data\n            input_device : torch.device\n                device for network inputs\n            output_device : torch.device\n                device for network outputs\n\n            Returns\n            -------\n            dict\n                dictionary containing data in correct type and shape and on correct\n                device\n\n            \"\"\"\n            return_dict = {\"data\": torch.from_numpy(batch.pop(\"data\")).to(\n                input_device)}\n\n            for key, vals in batch.items():\n                return_dict[key] = torch.from_numpy(vals).to(output_device)\n\n            return return_dict\n\nand can be customized by subclassing the ``AbstractPyTorchNetwork``.\n\n``closure example``\n^^^^^^^^^^^^^^^^^^^\n\nA simple closure function for a PyTorch module could look like this:\n\n.. 
code:: python\n\n        @staticmethod\n        def closure(model: AbstractPyTorchNetwork, data_dict: dict,\n                    optimizers: dict, criterions={}, metrics={},\n                    fold=0, **kwargs):\n            \"\"\"\n            closure method to do a single backpropagation step\n\n            Parameters\n            ----------\n            model : :class:`ClassificationNetworkBasePyTorch`\n                trainable model\n            data_dict : dict\n                dictionary containing the data\n            optimizers : dict\n                dictionary of optimizers to optimize model's parameters\n            criterions : dict\n                dict holding the criterions to calculate errors\n                (gradients from different criterions will be accumulated)\n            metrics : dict\n                dict holding the metrics to calculate\n            fold : int\n                Current Fold in Crossvalidation (default: 0)\n            **kwargs:\n                additional keyword arguments\n\n            Returns\n            -------\n            dict\n                Metric values (with same keys as input dict metrics)\n            dict\n                Loss values (with same keys as input dict criterions)\n            list\n                Arbitrary number of predictions as torch.Tensor\n\n            Raises\n            ------\n            AssertionError\n                if optimizers or criterions are empty or the optimizers are not\n                specified\n\n            \"\"\"\n\n            assert (optimizers and criterions) or not optimizers, \\\n                \"Criterion dict cannot be empty if optimizers are passed\"\n\n            loss_vals = {}\n            metric_vals = {}\n            total_loss = 0\n\n            # choose suitable context manager:\n            if optimizers:\n                context_man = torch.enable_grad\n\n            else:\n                context_man = torch.no_grad\n\n            with 
context_man():\n\n                inputs = data_dict.pop(\"data\")\n                # obtain outputs from network\n                preds = model(inputs)[\"pred\"]\n\n                if data_dict:\n\n                    for key, crit_fn in criterions.items():\n                        _loss_val = crit_fn(preds, *data_dict.values())\n                        loss_vals[key] = _loss_val.detach()\n                        total_loss += _loss_val\n\n                    with torch.no_grad():\n                        for key, metric_fn in metrics.items():\n                            metric_vals[key] = metric_fn(\n                                preds, *data_dict.values())\n\n            if optimizers:\n                optimizers['default'].zero_grad()\n                total_loss.backward()\n                optimizers['default'].step()\n\n            else:\n\n                # add prefix \"val\" in validation mode\n                eval_loss_vals, eval_metrics_vals = {}, {}\n                for key in loss_vals.keys():\n                    eval_loss_vals[\"val_\" + str(key)] = loss_vals[key]\n\n                for key in metric_vals:\n                    eval_metrics_vals[\"val_\" + str(key)] = metric_vals[key]\n\n                loss_vals = eval_loss_vals\n                metric_vals = eval_metrics_vals\n\n            for key, val in {**metric_vals, **loss_vals}.items():\n                logging.info({\"value\": {\"value\": val.item(), \"name\": key,\n                                        \"env_appendix\": \"_%02d\" % fold\n                                        }})\n\n            logging.info({'image_grid': {\"images\": inputs, \"name\": \"input_images\",\n                                         \"env_appendix\": \"_%02d\" % fold}})\n\n            return metric_vals, loss_vals, preds\n\n    **Note:** This closure is taken from the\n    ``delira.models.classification.ClassificationNetworkBasePyTorch``\n\nOther examples\n~~~~~~~~~~~~~~\n\nIn ``delira.models`` you can find 
exemplary implementations of\ngenerative adversarial networks, classification and regression\napproaches or segmentation networks.\n\nTraining\n--------\n\nParameters\n~~~~~~~~~~\n\nTraining parameters (often called hyperparameters) can be defined in the\n``delira.training.Parameters`` class.\n\nThe class accepts the parameters ``batch_size`` and ``num_epochs`` to\ndefine the batchsize and the number of epochs to train, the parameters\n``optimizer_cls`` and ``optimizer_params`` to create an optimizer for\ntraining, the parameter ``criterions`` to specify the training\ncriterions (whose gradients will be accumulated by default), the\nparameters ``lr_sched_cls`` and ``lr_sched_params`` to define the\nlearning rate scheduling and the parameter ``metrics`` to specify\nevaluation metrics.\n\nAdditionally, it is possible to pass an arbitrary number of keyword\narguments to the class.\n\nIt is good practice to create a ``Parameters`` object at the beginning\nand then use it for creating other objects which are needed for\ntraining, since you can use the class's attributes and changes in\nhyperparameters only have to be done once:\n\n.. 
code:: ipython3\n\n    import torch\n    from delira.training import Parameters\n    from delira.data_loading import RandomSampler, SequentialSampler\n    \n    params = Parameters(fixed_params={\n        \"model\": {},\n        \"training\": {\n            \"batch_size\": 64, # batchsize to use\n            \"num_epochs\": 2, # number of epochs to train\n            \"optimizer_cls\": torch.optim.Adam, # optimization algorithm to use\n            \"optimizer_params\": {'lr': 1e-3}, # initialization parameters for this algorithm\n            \"criterions\": {\"CE\": torch.nn.CrossEntropyLoss()}, # the loss function\n            \"lr_sched_cls\": None,  # the learning rate scheduling algorithm to use\n            \"lr_sched_params\": {}, # the corresponding initialization parameters\n            \"metrics\": {} # and some evaluation metrics\n        }\n    }) \n    \n    # recreating the data managers with the batchsize of the params object\n    manager_train = DataManager(dataset_train, params.nested_get(\"batch_size\"), 1,\n                                    transforms=None, sampler_cls=RandomSampler,\n                                    n_process_loading=4)\n    manager_val = DataManager(dataset_val, params.nested_get(\"batch_size\"), 3,\n                                  transforms=None, sampler_cls=SequentialSampler,\n                                  n_process_loading=4)\n    \n\nTrainer\n~~~~~~~\n\nThe ``delira.training.NetworkTrainer`` class provides functions to train\na single network by passing attributes from your parameter object, a\n``save_freq`` to specify how often your model should be saved\n(``save_freq=1`` indicates every epoch, ``save_freq=2`` every second\nepoch etc.) and ``gpu_ids``. If you don't pass any ids at all, your\nnetwork will be trained on CPU (and probably take a lot of time). 
If you\nspecify 1 id, the network will be trained on the GPU with the\ncorresponding index and if you pass multiple ``gpu_ids`` your network\nwill be trained on multiple GPUs in parallel.\n\n    **Note:** The GPU indices are referring to the devices listed in\n    ``CUDA_VISIBLE_DEVICES``. E.g. if ``CUDA_VISIBLE_DEVICES`` lists GPUs\n    3, 4, 5 then gpu\_id 0 will be the index for GPU 3 etc.\n\n    **Note:** training on multiple GPUs is not recommended for easy and\n    small networks, since for these networks the synchronization\n    overhead is far greater than the parallelization benefit.\n\nTraining your network might look like this:\n\n.. code:: ipython3\n\n    from delira.training import PyTorchNetworkTrainer\n    from delira.models.classification import ClassificationNetworkBasePyTorch\n    \n    # path where checkpoints should be saved\n    save_path = \"./results/checkpoints\"\n    \n    model = ClassificationNetworkBasePyTorch(in_channels=1, n_outputs=10)\n    \n    trainer = PyTorchNetworkTrainer(network=model,\n                                    save_path=save_path,\n                                    criterions=params.nested_get(\"criterions\"),\n                                    optimizer_cls=params.nested_get(\"optimizer_cls\"),\n                                    optimizer_params=params.nested_get(\"optimizer_params\"),\n                                    metrics=params.nested_get(\"metrics\"),\n                                    lr_scheduler_cls=params.nested_get(\"lr_sched_cls\"),\n                                    lr_scheduler_params=params.nested_get(\"lr_sched_params\"),\n                                    gpu_ids=[0]\n                            )\n    \n    #trainer.train(params.nested_get(\"num_epochs\"), manager_train, manager_val)\n    \n\nExperiment\n~~~~~~~~~~\n\nThe ``delira.training.AbstractExperiment`` class needs an experiment\nname, a path to save its results to, a parameter object, a model class\nand the keyword 
arguments to create an instance of this class. It\nprovides methods to perform a single training and also a method for\nrunning a kfold-cross validation. In order to create it, you must choose\nthe ``PyTorchExperiment``, which is basically just a subclass of the\n``AbstractExperiment`` to provide a general setup for PyTorch modules.\nRunning an experiment could look like this:\n\n.. code:: ipython3\n\n    from delira.training import PyTorchExperiment\n    from delira.training.train_utils import create_optims_default_pytorch\n    \n    # Add model parameters to Parameter class\n    params.fixed.model = {\"in_channels\": 1, \"n_outputs\": 10}\n    \n    experiment = PyTorchExperiment(params=params, \n                                   model_cls=ClassificationNetworkBasePyTorch,\n                                   name=\"TestExperiment\", \n                                   save_path=\"./results\",\n                                   optim_builder=create_optims_default_pytorch,\n                                   gpu_ids=[0])\n    \n    experiment.run(manager_train, manager_val)\n\nAn ``Experiment`` is the most abstract (and recommended) way to define,\ntrain and validate your network.\n\nLogging\n-------\n\nThe previous class and function definitions used Python's ``logging``\nlibrary. As extensions for this library ``delira`` provides a package\n(``delira.logging``) containing handlers to realize different logging\nmethods.\n\nTo use these handlers simply add them to your logger like this:\n\n.. 
code:: python\n\n    logger.addHandler(logging.StreamHandler())\n\nNowadays, delira mainly relies on\n`trixi <https://github.com/MIC-DKFZ/trixi/>`__ for logging and provides\nonly a ``MultiStreamHandler`` and a ``TrixiHandler``, which is a binding\nto ``trixi``'s loggers and integrates them into the Python ``logging``\nmodule.\n\n``MultiStreamHandler``\n~~~~~~~~~~~~~~~~~~~~~~\n\nThe ``MultiStreamHandler`` accepts an arbitrary number of streams during\ninitialization and writes the message to all of its streams during\nlogging.\n\nLogging with ``Visdom`` - The ``trixi`` Loggers\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n\n```Visdom`` <https://github.com/facebookresearch/visdom>`__ is a tool\ndesigned to visualize your logs. To use this tool you need to open a\nport on the machine you want to train on via\n``visdom -port YOUR_PORTNUMBER``. Afterwards just add the handler of your\nchoice to the logger. For more detailed information and customization\nhave a look at `this <https://github.com/facebookresearch/visdom>`__\nwebsite.\n\nLogging the scalar tensors containing ``1``, ``2``, ``3``, ``4`` (at the\nbeginning; will increase to show epochwise logging) with the\ncorresponding keys ``\"one\"``, ``\"two\"``, ``\"three\"``, ``\"four\"`` and two\nrandom images with the keys ``\"prediction\"`` and ``\"groundtruth\"`` would\nlook like this:\n\n.. 
code:: ipython3\n\n    NUM_ITERS = 4\n    \n    # import logging handler and logging module\n    from delira.logging import TrixiHandler\n    from trixi.logger import PytorchVisdomLogger\n    import logging\n    \n    # arguments for the trixi logger\n    logger_kwargs = {\n        'name': 'test_env', # name of logging environment\n        'port': 9999 # visdom port to connect to\n    }\n    logger_cls = PytorchVisdomLogger\n    \n    # configure logging module (and root logger)\n    logging.basicConfig(level=logging.INFO,\n                        handlers=[TrixiHandler(logger_cls, **logger_kwargs)])\n    # derive logger from root logger\n    # (don't do `logger = logging.Logger(\"...\")` since this will create a new\n    # logger which is unrelated to the root logger)\n    logger = logging.getLogger(\"Test Logger\")\n    \n    # create dict containing the scalar numbers as torch.Tensor\n    scalars = {\"one\": torch.Tensor([1]),\n               \"two\": torch.Tensor([2]),\n               \"three\": torch.Tensor([3]),\n               \"four\": torch.Tensor([4])}\n    \n    # create dict containing the images as torch.Tensor\n    # pytorch awaits tensor dimensionality of \n    # batchsize x image channels x height x width\n    images = {\"prediction\": torch.rand(1, 3, 224, 224),\n              \"groundtruth\": torch.rand(1, 3, 224, 224)}\n    \n    # Simulate 4 Epochs\n    for i in range(4*NUM_ITERS): \n        logger.info({\"image_grid\": {\"images\": images[\"prediction\"], \"name\": \"predictions\"}})\n        \n        for key, val_tensor in scalars.items():\n            logger.info({\"value\": {\"value\": val_tensor.item(), \"name\": key}})\n            scalars[key] += 1\n\nMore Examples\n-------------\n\nMore Examples can be found in \* `the classification\nexample <classification_pytorch.ipynb,>`__ \* `the 2d segmentation\nexample <segmentation_2d_pytorch.ipynb,>`__ \* `the 3d segmentation\nexample <segmentation_3d_pytorch.ipynb,>`__ \* `the 
generative\nadversarial example <gan_pytorch.ipynb,>`__\n"
  },
  {
    "path": "notebooks/classification_examples/chainer.ipynb",
    "content": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"# Classification with Delira and Chainer - A very short introduction\\n\",\n        \"*Author: Justus Schock* \\n\",\n        \"\\n\",\n        \"*Date: 31.07.2019*\\n\",\n        \"\\n\",\n        \"This Example shows how to set up a basic classification model and experiment using Chainer.\\n\",\n        \"\\n\",\n        \"Let\\u0027s first setup the essential hyperparameters. We will use `delira`\\u0027s `Parameters`-class for this:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 1,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": false\n        }\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stderr\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\pywt\\\\_utils.py:6: DeprecationWarning: Using or importing the ABCs from \\u0027collections\\u0027 instead of from \\u0027collections.abc\\u0027 is deprecated, and in 3.8 it will stop working\\n\",\n            \"  from collections import Iterable\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\google\\\\protobuf\\\\descriptor.py:47: DeprecationWarning: Using or importing the ABCs from \\u0027collections\\u0027 instead of from \\u0027collections.abc\\u0027 is deprecated, and in 3.8 it will stop working\\n\",\n            \"  from google.protobuf.pyext import _message\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\util\\\\nest.py:1286: DeprecationWarning: Using or importing the ABCs from \\u0027collections\\u0027 instead of from 
\\u0027collections.abc\\u0027 is deprecated, and in 3.8 it will stop working\\n\",\n            \"  _pywrap_tensorflow.RegisterType(\\\"Mapping\\\", _collections.Mapping)\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\util\\\\nest.py:1287: DeprecationWarning: Using or importing the ABCs from \\u0027collections\\u0027 instead of from \\u0027collections.abc\\u0027 is deprecated, and in 3.8 it will stop working\\n\",\n            \"  _pywrap_tensorflow.RegisterType(\\\"Sequence\\\", _collections.Sequence)\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\framework\\\\dtypes.py:516: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_qint8 \\u003d np.dtype([(\\\"qint8\\\", np.int8, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\framework\\\\dtypes.py:517: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_quint8 \\u003d np.dtype([(\\\"quint8\\\", np.uint8, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\framework\\\\dtypes.py:518: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_qint16 \\u003d np.dtype([(\\\"qint16\\\", np.int16, 1)])\\n\",\n            
\"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\framework\\\\dtypes.py:519: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_quint16 \\u003d np.dtype([(\\\"quint16\\\", np.uint16, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\framework\\\\dtypes.py:520: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_qint32 \\u003d np.dtype([(\\\"qint32\\\", np.int32, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\framework\\\\dtypes.py:525: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  np_resource \\u003d np.dtype([(\\\"resource\\\", np.ubyte, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\training\\\\tracking\\\\object_identity.py:61: DeprecationWarning: Using or importing the ABCs from \\u0027collections\\u0027 instead of from \\u0027collections.abc\\u0027 is deprecated, and in 3.8 it will stop working\\n\",\n            \"  class ObjectIdentityDictionary(collections.MutableMapping):\\n\",\n            
\"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\training\\\\tracking\\\\object_identity.py:112: DeprecationWarning: Using or importing the ABCs from \\u0027collections\\u0027 instead of from \\u0027collections.abc\\u0027 is deprecated, and in 3.8 it will stop working\\n\",\n            \"  class ObjectIdentitySet(collections.MutableSet):\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorboard\\\\compat\\\\tensorflow_stub\\\\dtypes.py:541: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_qint8 \\u003d np.dtype([(\\\"qint8\\\", np.int8, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorboard\\\\compat\\\\tensorflow_stub\\\\dtypes.py:542: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_quint8 \\u003d np.dtype([(\\\"quint8\\\", np.uint8, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorboard\\\\compat\\\\tensorflow_stub\\\\dtypes.py:543: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_qint16 \\u003d np.dtype([(\\\"qint16\\\", np.int16, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorboard\\\\compat\\\\tensorflow_stub\\\\dtypes.py:544: 
FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_quint16 \\u003d np.dtype([(\\\"quint16\\\", np.uint16, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorboard\\\\compat\\\\tensorflow_stub\\\\dtypes.py:545: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_qint32 \\u003d np.dtype([(\\\"qint32\\\", np.int32, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorboard\\\\compat\\\\tensorflow_stub\\\\dtypes.py:550: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  np_resource \\u003d np.dtype([(\\\"resource\\\", np.ubyte, 1)])\\n\",\n            \"WARNING: Logging before flag parsing goes to stderr.\\n\",\n            \"W0731 14:01:15.852783 27416 deprecation_wrapper.py:119] From c:\\\\users\\\\jsc7rng\\\\downloads\\\\delira\\\\delira\\\\models\\\\backends\\\\tf_eager\\\\abstract_network.py:113: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\\n\",\n            \"\\n\",\n            \"W0731 14:01:15.869738 27416 deprecation_wrapper.py:119] From c:\\\\users\\\\jsc7rng\\\\downloads\\\\delira\\\\delira\\\\models\\\\backends\\\\tf_graph\\\\abstract_network.py:20: The name tf.Session is deprecated. 
Please use tf.compat.v1.Session instead.\\n\",\n            \"\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"logger \\u003d None\\n\",\n        \"import chainer\\n\",\n        \"from delira.training import Parameters\\n\",\n        \"params \\u003d Parameters(fixed_params\\u003d{\\n\",\n        \"    \\\"model\\\": {\\n\",\n        \"        \\\"in_channels\\\": 1, \\n\",\n        \"        \\\"n_outputs\\\": 10\\n\",\n        \"    },\\n\",\n        \"    \\\"training\\\": {\\n\",\n        \"        \\\"batch_size\\\": 64, # batchsize to use\\n\",\n        \"        \\\"num_epochs\\\": 10, # number of epochs to train\\n\",\n        \"        \\\"optimizer_cls\\\": chainer.optimizers.Adam, # optimization algorithm to use\\n\",\n        \"        \\\"optimizer_params\\\": {\\u0027lr\\u0027: 1e-3}, # initialization parameters for this algorithm\\n\",\n        \"        \\\"losses\\\": {\\\"L1\\\": chainer.functions.mean_absolute_error}, # the loss function\\n\",\n        \"        \\\"lr_sched_cls\\\": None,  # the learning rate scheduling algorithm to use\\n\",\n        \"        \\\"lr_sched_params\\\": {}, # the corresponding initialization parameters\\n\",\n        \"        \\\"metrics\\\": {} # and some evaluation metrics\\n\",\n        \"    }\\n\",\n        \"}) \"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"Since we did not specify any metric, only the `CrossEntropyLoss` will be calculated for each batch. Since we have a classification task, this should be sufficient. We will train our network with a batchsize of 64 by using `Adam` as optimizer of choice.\\n\",\n        \"\\n\",\n        \"## Logging and Visualization\\n\",\n        \"To get a visualization of our results, we should monitor them somehow. For logging we will use `Tensorboard`. 
Per default the logging directory will be the same as our experiment directory.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"\\n\",\n        \"## Data Preparation\\n\",\n        \"### Loading\\n\",\n        \"Next we will create some fake data. For this we use the `ClassificationFakeData`-Dataset, which is already implemented in `deliravision`. To avoid getting the exact same data from both datasets, we use a random offset.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": false\n        }\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"from deliravision.data.fakedata import ClassificationFakeData\\n\",\n        \"dataset_train \\u003d ClassificationFakeData(num_samples\\u003d10000, \\n\",\n        \"                                       img_size\\u003d(3, 32, 32), \\n\",\n        \"                                       num_classes\\u003d10)\\n\",\n        \"dataset_val \\u003d ClassificationFakeData(num_samples\\u003d1000, \\n\",\n        \"                                     img_size\\u003d(3, 32, 32), \\n\",\n        \"                                     num_classes\\u003d10,\\n\",\n        \"                                     rng_offset\\u003d10001\\n\",\n        \"                                     )\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"### Augmentation\\n\",\n        \"For Data-Augmentation we will apply a few transformations:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": false\n        }\n      },\n      \"outputs\": [],\n      \"source\": \"from batchgenerators.transforms import 
RandomCropTransform, \\\\\\n                                        ContrastAugmentationTransform, Compose\\nfrom batchgenerators.transforms.spatial_transforms import ResizeTransform\\nfrom batchgenerators.transforms.sample_normalization_transforms import MeanStdNormalizationTransform\\n\\ntransforms \\u003d Compose([\\n    RandomCropTransform(24), # Perform Random Crops of Size 24 x 24 pixels\\n    ResizeTransform(32), # Resample these crops back to 32 x 32 pixels\\n    ContrastAugmentationTransform(), # randomly adjust contrast\\n    MeanStdNormalizationTransform(mean\\u003d[0.5], std\\u003d[0.5])]) \\n\\n\"\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"With these transformations we can now wrap our datasets into datamanagers:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": false\n        }\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"from delira.data_loading import DataManager, SequentialSampler, RandomSampler\\n\",\n        \"\\n\",\n        \"manager_train \\u003d DataManager(dataset_train, params.nested_get(\\\"batch_size\\\"),\\n\",\n        \"                                transforms\\u003dtransforms,\\n\",\n        \"                                sampler_cls\\u003dRandomSampler,\\n\",\n        \"                                n_process_augmentation\\u003d4)\\n\",\n        \"\\n\",\n        \"manager_val \\u003d DataManager(dataset_val, params.nested_get(\\\"batch_size\\\"),\\n\",\n        \"                              transforms\\u003dtransforms,\\n\",\n        \"                              sampler_cls\\u003dSequentialSampler,\\n\",\n        \"                              n_process_augmentation\\u003d4)\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n 
     },\n      \"source\": [\n        \"## Model\\n\",\n        \"\\n\",\n        \"After we have done that, we can specify our model: We will use a smaller version of a [VGG-Network](https://arxiv.org/pdf/1409.1556.pdf) in this case. We use additional convolutions to reduce the feature dimensionality and fewer units in the linear layer to save memory (we only have to deal with 10 classes, not the 1000 ImageNet classes).\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"from delira.models import AbstractChainerNetwork\\n\",\n        \"import chainer\\n\",\n        \"import numpy as np\\n\",\n        \"from functools import partial\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"class SmallVGGChainer(AbstractChainerNetwork):\\n\",\n        \"    def __init__(self, in_channels, num_classes):\\n\",\n        \"        super().__init__()\\n\",\n        \"        \\n\",\n        \"        self.model \\u003d chainer.Sequential(\\n\",\n        \"            chainer.links.Convolution2D(in_channels, 64, 3, pad\\u003d1), # 32 x 32\\n\",\n        \"            chainer.functions.relu,\\n\",\n        \"            partial(chainer.functions.max_pooling_2d, ksize\\u003d2), # 16 x 16\\n\",\n        \"            chainer.links.Convolution2D(64, 128, 3, pad\\u003d1),\\n\",\n        \"            chainer.functions.relu,\\n\",\n        \"            partial(chainer.functions.max_pooling_2d, ksize\\u003d2), # 8 x 8\\n\",\n        \"            chainer.links.Convolution2D(128, 256, 3), # 6 x 6\\n\",\n        \"            chainer.functions.relu,\\n\",\n        \"            partial(chainer.functions.max_pooling_2d, ksize\\u003d2), # 3 x 3\\n\",\n        \"            chainer.links.Convolution2D(256, 512, 3), # 1 x 1\\n\",\n        \"            chainer.links.Linear(1*1*512, num_classes) # Linear flattens its input implicitly\\n\",\n        \"        )\\n\",\n        \"        \\n\",\n        \"    def forward(self, x):\\n\",\n        \"        return {\\\"pred\\\": self.model(x)}\\n\",\n        \"    \\n\",\n        \"    @staticmethod\\n\",\n        \"    def prepare_batch(data_dict, input_device, output_device):\\n\",\n        \"        # cast images to float32, labels to int32 and wrap everything as variables\\n\",\n        \"        new_batch \\u003d {}\\n\",\n        \"        for k, v in data_dict.items():\\n\",\n        \"            if k \\u003d\\u003d \\\"data\\\":\\n\",\n        \"                var \\u003d chainer.as_variable(v.astype(np.float32))\\n\",\n        \"                device \\u003d input_device\\n\",\n        \"            else:\\n\",\n        \"                var \\u003d chainer.as_variable(v.astype(np.int32))\\n\",\n        \"                device \\u003d output_device\\n\",\n        \"\\n\",\n        \"            # makes modification inplace!\\n\",\n        \"            var.to_device(device)\\n\",\n        \"            new_batch[k] \\u003d var\\n\",\n        \"\\n\",\n        \"        return new_batch\\n\",\n        \"    \\n\",\n        \"    @staticmethod\\n\",\n        \"    def closure(model, data_dict: dict, optimizers: dict, losses: dict,\\n\",\n        \"                fold\\u003d0, **kwargs):\\n\",\n        \"\\n\",\n        \"        loss_vals \\u003d {}\\n\",\n        \"        total_loss \\u003d 0\\n\",\n        \"\\n\",\n        \"        with chainer.using_config(\\\"train\\\", True):\\n\",\n        \"            inputs \\u003d data_dict[\\\"data\\\"]\\n\",\n        \"            preds \\u003d model(inputs)\\n\",\n        \"\\n\",\n        \"            for key, crit_fn in losses.items():\\n\",\n        \"                _loss_val \\u003d crit_fn(preds[\\\"pred\\\"], data_dict[\\\"label\\\"])\\n\",\n        \"                loss_vals[key] \\u003d float(_loss_val.array)\\n\",\n        \"                total_loss +\\u003d _loss_val\\n\",\n        \"\\n\",\n        \"        model.cleargrads()\\n\",\n        \"        total_loss.backward()\\n\",\n        \"        optimizers[\\u0027default\\u0027].update()\\n\",\n        \"\\n\",\n        \"        # unchain detaches the predictions from the graph (inplace)\\n\",\n        \"        for v in preds.values():\\n\",\n        \"            v.unchain()\\n\",\n        \"        return loss_vals, preds\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"So let\\u0027s revisit what we have just done.\\n\",\n        \"\\n\",\n        \"In `delira` all networks must be derived from `delira.models.AbstractNetwork`. For each backend there is a class derived from this class, handling some backend-specific function calls and registrations. For the `Chainer` backend this class is `AbstractChainerNetwork` and all Chainer networks should be derived from it.\\n\",\n        \"\\n\",\n        \"First we defined the network itself (the part simply concatenating the layers into a sequential model). Next, we defined the logic to apply when we want to predict from the model (the `forward` method).\\n\",\n        \"\\n\",\n        \"So far this was plain `Chainer`. The `prepare_batch` function is not plain Chainer anymore, but allows us to ensure the data is in the correct shape, has the correct data type and lies on the correct device. The function above closely follows the default implementation in `AbstractChainerNetwork` and is re-implemented here for the sake of completeness (labels are additionally cast to `int32`, as required by `softmax_cross_entropy`).\\n\",\n        \"\\n\",\n        \"Same goes for the `closure` function. This function defines the update rule for our parameters (and how to calculate the losses). 
These functions are good to go for many simple networks but can be overwritten for customization when training more complex networks.\n\",\n        \"\\n\",\n        \"\\n\",\n        \"## Training\\n\",\n        \"Now that we have defined our network, we can finally specify our experiment and run it.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": true\n        }\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import warnings\\n\",\n        \"warnings.simplefilter(\\\"ignore\\\", UserWarning) # ignore UserWarnings raised by dependency code\\n\",\n        \"warnings.simplefilter(\\\"ignore\\\", FutureWarning) # ignore FutureWarnings raised by dependency code\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"from delira.training import ChainerExperiment\\n\",\n        \"\\n\",\n        \"if logger is not None:\\n\",\n        \"    logger.info(\\\"Init Experiment\\\")\\n\",\n        \"experiment \\u003d ChainerExperiment(params, SmallVGGChainer,\\n\",\n        \"                               name\\u003d\\\"ClassificationExample\\\",\\n\",\n        \"                               save_path\\u003d\\\"./tmp/delira_Experiments\\\",\\n\",\n        \"                               key_mapping\\u003d{\\\"x\\\": \\\"data\\\"},\\n\",\n        \"                               gpu_ids\\u003d[0])\\n\",\n        \"experiment.save()\\n\",\n        \"\\n\",\n        \"model \\u003d experiment.run(manager_train, manager_val)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"Congratulations, you have now trained your first classification model using `delira`. We will now predict a few samples from the validation set to show that the network\\u0027s predictions are valid (for now, this is done manually, but we also have a `Predictor` class to automate stuff 
like this):\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import numpy as np\\n\",\n        \"from tqdm.auto import tqdm # utility for progress bars\\n\",\n        \"\\n\",\n        \"device \\u003d \\\"@numpy\\\" # chainer device specifier for cpu/numpy\\n\",\n        \"model.to_device(device) # push model to device (inplace)\\n\",\n        \"preds, labels \\u003d [], []\\n\",\n        \"\\n\",\n        \"with chainer.using_config(\\\"train\\\", False), chainer.no_backprop_mode():\\n\",\n        \"    for i in tqdm(range(len(dataset_val))):\\n\",\n        \"        img \\u003d dataset_val[i][\\\"data\\\"] # get image from current sample\\n\",\n        \"        img_var \\u003d chainer.as_variable(img[None].astype(np.float32)) # wrap image as variable and add batch dimension\\n\",\n        \"        pred_var \\u003d model(img_var)[\\\"pred\\\"] # feed it through the network\\n\",\n        \"        pred \\u003d int(pred_var.array.argmax(axis\\u003d1)) # get index with maximum class confidence\\n\",\n        \"        label \\u003d int(np.asarray(dataset_val[i][\\\"label\\\"]).squeeze()) # get label from sample\\n\",\n        \"        if i % 100 \\u003d\\u003d 0:\\n\",\n        \"            print(\\\"Prediction: %d \\\\t label: %d\\\" % (pred, label)) # print result\\n\",\n        \"        preds.append(pred)\\n\",\n        \"        labels.append(label)\\n\",\n        \"        \\n\",\n        \"# calculate accuracy\\n\",\n        \"accuracy \\u003d (np.asarray(preds) \\u003d\\u003d np.asarray(labels)).sum() / len(preds)\\n\",\n        \"print(\\\"Accuracy: %.3f\\\" % accuracy)\"\n      ]\n    }\n  ],\n  \"metadata\": {\n    \"kernelspec\": {\n      \"display_name\": \"Python 3\",\n      \"language\": \"python\",\n      \"name\": \"python3\"\n    },\n    \"language_info\": {\n      \"codemirror_mode\": {\n        \"name\": \"ipython\",\n        \"version\": 3\n      },\n      \"file_extension\": 
\".py\",\n      \"mimetype\": \"text/x-python\",\n      \"name\": \"python\",\n      \"nbconvert_exporter\": \"python\",\n      \"pygments_lexer\": \"ipython3\",\n      \"version\": \"3.7.3\"\n    }\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 2\n}"
  },
  {
    "path": "notebooks/classification_examples/pytorch.ipynb",
    "content": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"# Classification with Delira and PyTorch - A very short introduction\\n\",\n        \"*Author: Justus Schock* \\n\",\n        \"\\n\",\n        \"*Date: 31.07.2019*\\n\",\n        \"\\n\",\n        \"This Example shows how to set up a basic classification model and experiment using PyTorch.\\n\",\n        \"\\n\",\n        \"Let\\u0027s first setup the essential hyperparameters. We will use `delira`\\u0027s `Parameters`-class for this:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": false\n        }\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"logger \\u003d None\\n\",\n        \"import torch\\n\",\n        \"from delira.training import Parameters\\n\",\n        \"params \\u003d Parameters(fixed_params\\u003d{\\n\",\n        \"    \\\"model\\\": {\\n\",\n        \"        \\\"in_channels\\\": 1, \\n\",\n        \"        \\\"n_outputs\\\": 10\\n\",\n        \"    },\\n\",\n        \"    \\\"training\\\": {\\n\",\n        \"        \\\"batch_size\\\": 64, # batchsize to use\\n\",\n        \"        \\\"num_epochs\\\": 10, # number of epochs to train\\n\",\n        \"        \\\"optimizer_cls\\\": torch.optim.Adam, # optimization algorithm to use\\n\",\n        \"        \\\"optimizer_params\\\": {\\u0027lr\\u0027: 1e-3}, # initialization parameters for this algorithm\\n\",\n        \"        \\\"losses\\\": {\\\"CE\\\": torch.nn.CrossEntropyLoss()}, # the loss function\\n\",\n        \"        \\\"lr_sched_cls\\\": None,  # the learning rate scheduling algorithm to use\\n\",\n        \"        \\\"lr_sched_params\\\": {}, # the corresponding initialization parameters\\n\",\n        \"        \\\"metrics\\\": {} # and some evaluation metrics\\n\",\n        \"    
}\\n\",\n        \"}) \"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"Since we did not specify any metric, only the `CrossEntropyLoss` will be calculated for each batch. Since we have a classification task, this should be sufficient. We will train our network with a batchsize of 64 by using `Adam` as optimizer of choice.\\n\",\n        \"\\n\",\n        \"## Logging and Visualization\\n\",\n        \"To get a visualization of our results, we should monitor them somehow. For logging we will use `Tensorboard`. Per default the logging directory will be the same as our experiment directory.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"\\n\",\n        \"## Data Preparation\\n\",\n        \"### Loading\\n\",\n        \"Next we will create some fake data. For this we use the `ClassificationFakeData`-Dataset, which is already implemented in `deliravision`. 
To avoid getting the exact same data from both datasets, we use a random offset.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": false\n        }\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"from deliravision.data.fakedata import ClassificationFakeData\\n\",\n        \"dataset_train \\u003d ClassificationFakeData(num_samples\\u003d10000, \\n\",\n        \"                                       img_size\\u003d(3, 32, 32), \\n\",\n        \"                                       num_classes\\u003d10)\\n\",\n        \"dataset_val \\u003d ClassificationFakeData(num_samples\\u003d1000, \\n\",\n        \"                                     img_size\\u003d(3, 32, 32), \\n\",\n        \"                                     num_classes\\u003d10,\\n\",\n        \"                                     rng_offset\\u003d10001\\n\",\n        \"                                     )\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"### Augmentation\\n\",\n        \"For Data-Augmentation we will apply a few transformations:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": false\n        }\n      },\n      \"outputs\": [],\n      \"source\": \"from batchgenerators.transforms import RandomCropTransform, \\\\\\n                                        ContrastAugmentationTransform, Compose\\nfrom batchgenerators.transforms.spatial_transforms import ResizeTransform\\nfrom batchgenerators.transforms.sample_normalization_transforms import MeanStdNormalizationTransform\\n\\ntransforms \\u003d Compose([\\n    RandomCropTransform(24), # Perform Random Crops of Size 24 x 24 pixels\\n    ResizeTransform(32), # Resample these crops back 
to 32 x 32 pixels\\n    ContrastAugmentationTransform(), # randomly adjust contrast\\n    MeanStdNormalizationTransform(mean\\u003d[0.5], std\\u003d[0.5])]) \\n\\n\"\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"With these transformations we can now wrap our datasets into datamanagers:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": false\n        }\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"from delira.data_loading import DataManager, SequentialSampler, RandomSampler\\n\",\n        \"\\n\",\n        \"manager_train \\u003d DataManager(dataset_train, params.nested_get(\\\"batch_size\\\"),\\n\",\n        \"                                transforms\\u003dtransforms,\\n\",\n        \"                                sampler_cls\\u003dRandomSampler,\\n\",\n        \"                                n_process_augmentation\\u003d4)\\n\",\n        \"\\n\",\n        \"manager_val \\u003d DataManager(dataset_val, params.nested_get(\\\"batch_size\\\"),\\n\",\n        \"                              transforms\\u003dtransforms,\\n\",\n        \"                              sampler_cls\\u003dSequentialSampler,\\n\",\n        \"                              n_process_augmentation\\u003d4)\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"## Model\\n\",\n        \"\\n\",\n        \"After we have done that, we can specify our model: We will use a smaller version of a [VGG-Network](https://arxiv.org/pdf/1409.1556.pdf) in this case. 
We use additional convolutions to reduce the feature dimensionality and fewer units in the linear layer to save memory (we only have to deal with 10 classes, not the 1000 ImageNet classes).\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"from delira.models import AbstractPyTorchNetwork\\n\",\n        \"import torch\\n\",\n        \"\\n\",\n        \"class Flatten(torch.nn.Module):\\n\",\n        \"        \\n\",\n        \"    def forward(self, x):\\n\",\n        \"        return x.view(x.size(0), -1)\\n\",\n        \"\\n\",\n        \"class SmallVGGPyTorch(AbstractPyTorchNetwork):\\n\",\n        \"    def __init__(self, in_channels, num_classes):\\n\",\n        \"        super().__init__()\\n\",\n        \"        \\n\",\n        \"        self.model \\u003d torch.nn.Sequential(\\n\",\n        \"            torch.nn.Conv2d(in_channels, 64, 3, padding\\u003d1), # 32 x 32\\n\",\n        \"            torch.nn.ReLU(),\\n\",\n        \"            torch.nn.MaxPool2d(2), # 16 x 16\\n\",\n        \"            torch.nn.Conv2d(64, 128, 3, padding\\u003d1),\\n\",\n        \"            torch.nn.ReLU(),\\n\",\n        \"            torch.nn.MaxPool2d(2), # 8 x 8\\n\",\n        \"            torch.nn.Conv2d(128, 256, 3, padding\\u003d1),\\n\",\n        \"            torch.nn.ReLU(),\\n\",\n        \"            torch.nn.MaxPool2d(2), # 4 x 4\\n\",\n        \"            torch.nn.Conv2d(256, 512, 3, padding\\u003d1),\\n\",\n        \"            torch.nn.ReLU(),\\n\",\n        \"            torch.nn.MaxPool2d(2), # 2 x 2\\n\",\n        \"            torch.nn.Conv2d(512, 512, 3, padding\\u003d1),\\n\",\n        \"            torch.nn.ReLU(),\\n\",\n        \"            torch.nn.MaxPool2d(2), # 1 x 1\\n\",\n        \"            Flatten(),\\n\",\n        \"            torch.nn.Linear(1*1*512, 
num_classes),\\n\",\n        \"        )\\n\",\n        \"        \\n\",\n        \"    def forward(self, x: torch.Tensor):\\n\",\n        \"        return {\\\"pred\\\": self.model(x)}\\n\",\n        \"    \\n\",\n        \"    @staticmethod\\n\",\n        \"    def prepare_batch(data_dict, input_device, output_device):\\n\",\n        \"        return_dict \\u003d {\\\"data\\\": torch.from_numpy(batch[\\\"data\\\"]).to(\\n\",\n        \"            input_device).to(torch.float)}\\n\",\n        \"\\n\",\n        \"        for key, vals in batch.items():\\n\",\n        \"            if key \\u003d\\u003d \\\"data\\\": \\n\",\n        \"                continue\\n\",\n        \"            return_dict[key] \\u003d torch.from_numpy(vals).to(output_device).to(\\n\",\n        \"                torch.float)\\n\",\n        \"\\n\",\n        \"        return return_dict\\n\",\n        \"    \\n\",\n        \"    @staticmethod\\n\",\n        \"    def closure(model, data_dict: dict, optimizers: dict, losses: dict,\\n\",\n        \"                fold\\u003d0, **kwargs):\\n\",\n        \"\\n\",\n        \"        loss_vals \\u003d {}\\n\",\n        \"        total_loss \\u003d 0\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"        # predict\\n\",\n        \"        inputs \\u003d data_dict.pop(\\\"data\\\")\\n\",\n        \"        preds \\u003d model(inputs)\\n\",\n        \"\\n\",\n        \"        # calculate losses\\n\",\n        \"        for key, crit_fn in losses.items():\\n\",\n        \"            _loss_val \\u003d crit_fn(preds[\\\"pred\\\"], data_dict[\\\"label\\\"])\\n\",\n        \"            loss_vals[key] \\u003d _loss_val.item()\\n\",\n        \"            total_loss +\\u003d _loss_val\\n\",\n        \"\\n\",\n        \"        optimizers[\\u0027default\\u0027].zero_grad()\\n\",\n        \"        # perform loss scaling via apex if half precision is enabled\\n\",\n        \"        with scale_loss(total_loss, optimizers[\\\"default\\\"]) as 
scaled_loss:\\n\",\n        \"            scaled_loss.backward()\\n\",\n        \"        optimizers[\\u0027default\\u0027].step()\\n\",\n        \"\\n\",\n        \"        return loss_vals, {k: v.detach()\\n\",\n        \"                                for k, v in preds.items()}\\n\",\n        \"    \\n\",\n        \"    \"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"So let\\u0027s evisit, what we have just done.\\n\",\n        \"\\n\",\n        \"In `delira` all networks must be derived from `delira.models.AbstractNetwork`. For each backend there is a class derived from this class, handling some backend-specific function calls and registrations. For the `PyTorch` Backend this class is `AbstractPyTorchNetwork` and all PyTorch Networks should be derived from it.\\n\",\n        \"\\n\",\n        \"First we defined the network itself (this is the part simply concatenating the layers into a sequential model). Next, we defined the logic to apply, when we want to predict from the model (this is the `forward` method).\\n\",\n        \"\\n\",\n        \"So far this was plain `PyTorch`. The `prepare_batch` function is not plain PyTorch anymore, but allows us to ensure the data is in the correct shape, has the correct data-type and lies on the correct device. The function above is the standard `prepare_batch` function, which is also implemented in the `AbstractPyTorchNetwork` and just re-implemented here for the sake of completeness.\\n\",\n        \"\\n\",\n        \"Same goes for the `closure` function. This function defines the update rule for our parameters (and how to calculate the losses). 
These functions are good to go for many simple networks but can be overridden for customization when training more complex networks.\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"## Training\\n\",\n        \"Now that we have defined our network, we can finally specify our experiment and run it.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": true\n        }\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import warnings\\n\",\n        \"warnings.simplefilter(\\\"ignore\\\", UserWarning) # ignore UserWarnings raised by dependency code\\n\",\n        \"warnings.simplefilter(\\\"ignore\\\", FutureWarning) # ignore FutureWarnings raised by dependency code\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"from delira.training import PyTorchExperiment\\n\",\n        \"from delira.training.train_utils import create_optims_default_pytorch\\n\",\n        \"\\n\",\n        \"if logger is not None:\\n\",\n        \"    logger.info(\\\"Init Experiment\\\")\\n\",\n        \"experiment \\u003d PyTorchExperiment(params, SmallVGGPyTorch,\\n\",\n        \"                               name\\u003d\\\"ClassificationExample\\\",\\n\",\n        \"                               save_path\\u003d\\\"./tmp/delira_Experiments\\\",\\n\",\n        \"                               optim_builder\\u003dcreate_optims_default_pytorch,\\n\",\n        \"                               key_mapping\\u003d{\\\"x\\\": \\\"data\\\"},\\n\",\n        \"                               gpu_ids\\u003d[0])\\n\",\n        \"experiment.save()\\n\",\n        \"\\n\",\n        \"model \\u003d experiment.run(manager_train, manager_val)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"Congratulations, you have now trained your first classification model using `delira`. We will now predict a few samples from the validation set to show that the network\\u0027s predictions are valid (for now, this is done manually, but we also have a `Predictor` class to automate stuff like this):\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import numpy as np\\n\",\n        \"from tqdm.auto import tqdm # utility for progress bars\\n\",\n        \"\\n\",\n        \"device \\u003d torch.device(\\\"cuda\\\" if torch.cuda.is_available() else \\\"cpu\\\") # set device (use GPU if available)\\n\",\n        \"model \\u003d model.to(device) # push model to device\\n\",\n        \"preds, labels \\u003d [], []\\n\",\n        \"\\n\",\n        \"with torch.no_grad():\\n\",\n        \"    for i in tqdm(range(len(dataset_val))):\\n\",\n        \"        img \\u003d dataset_val[i][\\\"data\\\"] # get image from current batch\\n\",\n        \"        img_tensor \\u003d torch.from_numpy(img).unsqueeze(0).to(device).to(torch.float) # create a tensor from image, push it to device and add batch dimension\\n\",\n        \"        pred_tensor \\u003d model(img_tensor)[\\\"pred\\\"] # feed it through the network (forward returns a dict)\\n\",\n        \"        pred \\u003d pred_tensor.argmax(1).item() # get index with maximum class confidence\\n\",\n        \"        label \\u003d np.asscalar(dataset_val[i][\\\"label\\\"]) # get label from batch\\n\",\n        \"        if i % 1000 \\u003d\\u003d 0:\\n\",\n        \"            print(\\\"Prediction: %d \\\\t label: %d\\\" % (pred, label)) # print result\\n\",\n        \"        preds.append(pred)\\n\",\n        \"        labels.append(label)\\n\",\n        \"        \\n\",\n        \"# calculate accuracy\\n\",\n        \"accuracy \\u003d (np.asarray(preds) \\u003d\\u003d np.asarray(labels)).sum() / len(preds)\\n\",\n        \"print(\\\"Accuracy: %.3f\\\" % accuracy)\"\n      ]\n    }\n  ],\n  \"metadata\": {\n   
 \"kernelspec\": {\n      \"display_name\": \"Python 3\",\n      \"language\": \"python\",\n      \"name\": \"python3\"\n    },\n    \"language_info\": {\n      \"codemirror_mode\": {\n        \"name\": \"ipython\",\n        \"version\": 3\n      },\n      \"file_extension\": \".py\",\n      \"mimetype\": \"text/x-python\",\n      \"name\": \"python\",\n      \"nbconvert_exporter\": \"python\",\n      \"pygments_lexer\": \"ipython3\",\n      \"version\": \"3.7.3\"\n    }\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 2\n}"
  },
  {
    "path": "notebooks/classification_examples/sklearn.ipynb",
    "content": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"# Classification with Delira and SciKit-Learn - A very short introduction\\n\",\n        \"*Author: Justus Schock* \\n\",\n        \"\\n\",\n        \"*Date: 31.07.2019*\\n\",\n        \"\\n\",\n        \"This Example shows how to set up a basic classification model and experiment using SciKit-Learn.\\n\",\n        \"\\n\",\n        \"Let\\u0027s first setup the essential hyperparameters. We will use `delira`\\u0027s `Parameters`-class for this:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": false\n        }\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stderr\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\pywt\\\\_utils.py:6: DeprecationWarning: Using or importing the ABCs from \\u0027collections\\u0027 instead of from \\u0027collections.abc\\u0027 is deprecated, and in 3.8 it will stop working\\n\",\n            \"  from collections import Iterable\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\google\\\\protobuf\\\\descriptor.py:47: DeprecationWarning: Using or importing the ABCs from \\u0027collections\\u0027 instead of from \\u0027collections.abc\\u0027 is deprecated, and in 3.8 it will stop working\\n\",\n            \"  from google.protobuf.pyext import _message\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\util\\\\nest.py:1286: DeprecationWarning: Using or importing the ABCs from \\u0027collections\\u0027 instead of from 
\\u0027collections.abc\\u0027 is deprecated, and in 3.8 it will stop working\\n\",\n            \"  _pywrap_tensorflow.RegisterType(\\\"Mapping\\\", _collections.Mapping)\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\util\\\\nest.py:1287: DeprecationWarning: Using or importing the ABCs from \\u0027collections\\u0027 instead of from \\u0027collections.abc\\u0027 is deprecated, and in 3.8 it will stop working\\n\",\n            \"  _pywrap_tensorflow.RegisterType(\\\"Sequence\\\", _collections.Sequence)\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\framework\\\\dtypes.py:516: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_qint8 \\u003d np.dtype([(\\\"qint8\\\", np.int8, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\framework\\\\dtypes.py:517: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_quint8 \\u003d np.dtype([(\\\"quint8\\\", np.uint8, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\framework\\\\dtypes.py:518: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_qint16 \\u003d np.dtype([(\\\"qint16\\\", np.int16, 1)])\\n\",\n            
\"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\framework\\\\dtypes.py:519: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_quint16 \\u003d np.dtype([(\\\"quint16\\\", np.uint16, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\framework\\\\dtypes.py:520: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_qint32 \\u003d np.dtype([(\\\"qint32\\\", np.int32, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\framework\\\\dtypes.py:525: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  np_resource \\u003d np.dtype([(\\\"resource\\\", np.ubyte, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\training\\\\tracking\\\\object_identity.py:61: DeprecationWarning: Using or importing the ABCs from \\u0027collections\\u0027 instead of from \\u0027collections.abc\\u0027 is deprecated, and in 3.8 it will stop working\\n\",\n            \"  class ObjectIdentityDictionary(collections.MutableMapping):\\n\",\n            
\"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\training\\\\tracking\\\\object_identity.py:112: DeprecationWarning: Using or importing the ABCs from \\u0027collections\\u0027 instead of from \\u0027collections.abc\\u0027 is deprecated, and in 3.8 it will stop working\\n\",\n            \"  class ObjectIdentitySet(collections.MutableSet):\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"logger \\u003d None\\n\",\n        \"from delira.training import Parameters\\n\",\n        \"import sklearn\\n\",\n        \"params \\u003d Parameters(fixed_params\\u003d{\\n\",\n        \"    \\\"model\\\": {},\\n\",\n        \"    \\\"training\\\": {\\n\",\n        \"        \\\"batch_size\\\": 64, # batchsize to use\\n\",\n        \"        \\\"num_epochs\\\": 10, # number of epochs to train\\n\",\n        \"        \\\"optimizer_cls\\\": None, # optimization algorithm to use\\n\",\n        \"        \\\"optimizer_params\\\": {}, # initialization parameters for this algorithm\\n\",\n        \"        \\\"losses\\\": {}, # the loss function\\n\",\n        \"        \\\"lr_sched_cls\\\": None,  # the learning rate scheduling algorithm to use\\n\",\n        \"        \\\"lr_sched_params\\\": {}, # the corresponding initialization parameters\\n\",\n        \"        \\\"metrics\\\": {\\\"mae\\\": mean_absolute_error} # and some evaluation metrics\\n\",\n        \"    }\\n\",\n        \"}) \"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"Since we did not specify any metric, only the `CrossEntropyLoss` will be calculated for each batch. Since we have a classification task, this should be sufficient. 
We will train our network with a batchsize of 64 by using `Adam` as optimizer of choice.\\n\",\n        \"\\n\",\n        \"## Logging and Visualization\\n\",\n        \"To get a visualization of our results, we should monitor them somehow. For logging we will use `Tensorboard`. Per default the logging directory will be the same as our experiment directory.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"\\n\",\n        \"## Data Preparation\\n\",\n        \"### Loading\\n\",\n        \"Next we will create some fake data. For this we use the `ClassificationFakeData`-Dataset, which is already implemented in `deliravision`. To avoid getting the exact same data from both datasets, we use a random offset.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": false\n        }\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"from deliravision.data.fakedata import ClassificationFakeData\\n\",\n        \"dataset_train \\u003d ClassificationFakeData(num_samples\\u003d10000, \\n\",\n        \"                                       img_size\\u003d(3, 32, 32), \\n\",\n        \"                                       num_classes\\u003d10)\\n\",\n        \"dataset_val \\u003d ClassificationFakeData(num_samples\\u003d1000, \\n\",\n        \"                                     img_size\\u003d(3, 32, 32), \\n\",\n        \"                                     num_classes\\u003d10,\\n\",\n        \"                                     rng_offset\\u003d10001\\n\",\n        \"                                     )\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"### Augmentation\\n\",\n        \"For Data-Augmentation we will apply a few 
transformations:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": false\n        }\n      },\n      \"outputs\": [],\n      \"source\": \"from batchgenerators.transforms import RandomCropTransform, \\\\\\n                                        ContrastAugmentationTransform, Compose\\nfrom batchgenerators.transforms.spatial_transforms import ResizeTransform\\nfrom batchgenerators.transforms.sample_normalization_transforms import MeanStdNormalizationTransform\\n\\ntransforms \\u003d Compose([\\n    RandomCropTransform(24), # Perform Random Crops of Size 24 x 24 pixels\\n    ResizeTransform(32), # Resample these crops back to 32 x 32 pixels\\n    ContrastAugmentationTransform(), # randomly adjust contrast\\n    MeanStdNormalizationTransform(mean\\u003d[0.5], std\\u003d[0.5])]) \\n\\n\"\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"With these transformations we can now wrap our datasets into datamanagers:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": false\n        }\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"from delira.data_loading import DataManager, SequentialSampler, RandomSampler\\n\",\n        \"\\n\",\n        \"manager_train \\u003d DataManager(dataset_train, params.nested_get(\\\"batch_size\\\"),\\n\",\n        \"                                transforms\\u003dtransforms,\\n\",\n        \"                                sampler_cls\\u003dRandomSampler,\\n\",\n        \"                                n_process_augmentation\\u003d4)\\n\",\n        \"\\n\",\n        \"manager_val \\u003d DataManager(dataset_val, params.nested_get(\\\"batch_size\\\"),\\n\",\n        \"                              
transforms\\u003dtransforms,\\n\",\n        \"                              sampler_cls\\u003dSequentialSampler,\\n\",\n        \"                              n_process_augmentation\\u003d4)\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"## Model\\n\",\n        \"\\n\",\n        \"After we have done that, we can specify our model: We will use a very simple MultiLayer Perceptron here. \\n\",\n        \"In contrast to other backends, we don\\u0027t need to provide a custom implementation of our model, but we can simply use it as-is. It will be automatically wrapped by `SklearnEstimator`, which can be subclassed for more advanced usage.\\n\",\n        \"\\n\",\n        \"## Training\\n\",\n        \"Now that we have defined our network, we can finally specify our experiment and run it.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": true\n        }\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import warnings\\n\",\n        \"warnings.simplefilter(\\\"ignore\\\", UserWarning) # ignore UserWarnings raised by dependency code\\n\",\n        \"warnings.simplefilter(\\\"ignore\\\", FutureWarning) # ignore FutureWarnings raised by dependency code\\n\",\n        \"from sklearn.neural_network import MLPClassifier\\n\",\n        \"\\n\",\n        \"from delira.training import SklearnExperiment\\n\",\n        \"\\n\",\n        \"if logger is not None:\\n\",\n        \"    logger.info(\\\"Init Experiment\\\")\\n\",\n        \"experiment \\u003d SklearnExperiment(params, MLPClassifier,\\n\",\n        \"                               name\\u003d\\\"ClassificationExample\\\",\\n\",\n        \"                               save_path\\u003d\\\"./tmp/delira_Experiments\\\",\\n\",\n        \"                               key_mapping\\u003d{\\\"X\\\": \\\"X\\\"})\\n\",\n        \"experiment.save()\\n\",\n        \"\\n\",\n        \"model \\u003d experiment.run(manager_train, manager_val)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"Congratulations, you have now trained your first classification model using `delira`. We will now predict a few samples from the validation set to show that the model\\u0027s predictions are valid (for now, this is done manually, but we also have a `Predictor` class to automate stuff like this):\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import numpy as np\\n\",\n        \"from tqdm.auto import tqdm # utility for progress bars\\n\",\n        \"\\n\",\n        \"preds, labels \\u003d [], []\\n\",\n        \"\\n\",\n        \"for i in tqdm(range(len(dataset_val))):\\n\",\n        \"    img \\u003d dataset_val[i][\\\"data\\\"] # get image from current batch\\n\",\n        \"    img_batch \\u003d img.reshape(1, -1).astype(np.float32) # flatten the image and add a batch dimension\\n\",\n        \"    pred_batch \\u003d model(img_batch) # feed it through the model\\n\",\n        \"    pred \\u003d pred_batch.argmax(1).item() # get index with maximum class confidence\\n\",\n        \"    label \\u003d np.asscalar(dataset_val[i][\\\"label\\\"]) # get label from batch\\n\",\n        \"    if i % 1000 \\u003d\\u003d 0:\\n\",\n        \"        print(\\\"Prediction: %d \\\\t label: %d\\\" % (pred, label)) # print result\\n\",\n        \"    preds.append(pred)\\n\",\n        \"    labels.append(label)\\n\",\n        \"\\n\",\n        \"# calculate 
accuracy\\n\",\n        \"accuracy \\u003d (np.asarray(preds) \\u003d\\u003d np.asarray(labels)).sum() / len(preds)\\n\",\n        \"print(\\\"Accuracy: %.3f\\\" % accuracy)\"\n      ]\n    }\n  ],\n  \"metadata\": {\n    \"kernelspec\": {\n      \"display_name\": \"Python 3\",\n      \"language\": \"python\",\n      \"name\": \"python3\"\n    },\n    \"language_info\": {\n      \"codemirror_mode\": {\n        \"name\": \"ipython\",\n        \"version\": 3\n      },\n      \"file_extension\": \".py\",\n      \"mimetype\": \"text/x-python\",\n      \"name\": \"python\",\n      \"nbconvert_exporter\": \"python\",\n      \"pygments_lexer\": \"ipython3\",\n      \"version\": \"3.7.3\"\n    }\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 2\n}"
  },
  {
    "path": "notebooks/classification_examples/tf_eager.ipynb",
"content": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"# Classification with Delira and TensorFlow Eager Execution - A very short introduction\\n\",\n        \"*Author: Justus Schock* \\n\",\n        \"\\n\",\n        \"*Date: 31.07.2019*\\n\",\n        \"\\n\",\n        \"This Example shows how to set up a basic classification model and experiment using TensorFlow\\u0027s Eager Execution Mode.\\n\",\n        \"\\n\",\n        \"Let\\u0027s first set up the essential hyperparameters. We will use `delira`\\u0027s `Parameters`-class for this:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 1,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": false\n        }\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stderr\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\framework\\\\dtypes.py:516: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_qint8 \\u003d np.dtype([(\\\"qint8\\\", np.int8, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\framework\\\\dtypes.py:517: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_quint8 \\u003d np.dtype([(\\\"quint8\\\", np.uint8, 1)])\\n\",\n            
\"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\framework\\\\dtypes.py:518: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_qint16 \\u003d np.dtype([(\\\"qint16\\\", np.int16, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\framework\\\\dtypes.py:519: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_quint16 \\u003d np.dtype([(\\\"quint16\\\", np.uint16, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\framework\\\\dtypes.py:520: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_qint32 \\u003d np.dtype([(\\\"qint32\\\", np.int32, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\framework\\\\dtypes.py:525: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  np_resource \\u003d np.dtype([(\\\"resource\\\", np.ubyte, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorboard\\\\compat\\\\tensorflow_stub\\\\dtypes.py:541: 
FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_qint8 \\u003d np.dtype([(\\\"qint8\\\", np.int8, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorboard\\\\compat\\\\tensorflow_stub\\\\dtypes.py:542: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_quint8 \\u003d np.dtype([(\\\"quint8\\\", np.uint8, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorboard\\\\compat\\\\tensorflow_stub\\\\dtypes.py:543: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_qint16 \\u003d np.dtype([(\\\"qint16\\\", np.int16, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorboard\\\\compat\\\\tensorflow_stub\\\\dtypes.py:544: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_quint16 \\u003d np.dtype([(\\\"quint16\\\", np.uint16, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorboard\\\\compat\\\\tensorflow_stub\\\\dtypes.py:545: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) 
/ \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_qint32 \\u003d np.dtype([(\\\"qint32\\\", np.int32, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorboard\\\\compat\\\\tensorflow_stub\\\\dtypes.py:550: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  np_resource \\u003d np.dtype([(\\\"resource\\\", np.ubyte, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\pywt\\\\_utils.py:6: DeprecationWarning: Using or importing the ABCs from \\u0027collections\\u0027 instead of from \\u0027collections.abc\\u0027 is deprecated, and in 3.8 it will stop working\\n\",\n            \"  from collections import Iterable\\n\",\n            \"WARNING: Logging before flag parsing goes to stderr.\\n\",\n            \"W0731 13:38:30.713174 21496 deprecation_wrapper.py:119] From c:\\\\users\\\\jsc7rng\\\\downloads\\\\delira\\\\delira\\\\models\\\\backends\\\\tf_eager\\\\abstract_network.py:113: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\\n\",\n            \"\\n\",\n            \"W0731 13:38:30.727135 21496 deprecation_wrapper.py:119] From c:\\\\users\\\\jsc7rng\\\\downloads\\\\delira\\\\delira\\\\models\\\\backends\\\\tf_graph\\\\abstract_network.py:20: The name tf.Session is deprecated. 
Please use tf.compat.v1.Session instead.\\n\",\n            \"\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"logger \\u003d None\\n\",\n        \"import tensorflow as tf\\n\",\n        \"tf.enable_eager_execution()\\n\",\n        \"from delira.training import Parameters\\n\",\n        \"params \\u003d Parameters(fixed_params\\u003d{\\n\",\n        \"    \\\"model\\\": {\\n\",\n        \"        \\\"in_channels\\\": 1, \\n\",\n        \"        \\\"n_outputs\\\": 10\\n\",\n        \"    },\\n\",\n        \"    \\\"training\\\": {\\n\",\n        \"        \\\"batch_size\\\": 64, # batchsize to use\\n\",\n        \"        \\\"num_epochs\\\": 10, # number of epochs to train\\n\",\n        \"        \\\"optimizer_cls\\\": tf.train.AdamOptimizer, # optimization algorithm to use\\n\",\n        \"        \\\"optimizer_params\\\": {\\u0027learning_rate\\u0027: 1e-3}, # initialization parameters for this algorithm\\n\",\n        \"        \\\"losses\\\": {\\\"L1\\\": tf.losses.absolute_difference}, # the loss function\\n\",\n        \"        \\\"lr_sched_cls\\\": None,  # the learning rate scheduling algorithm to use\\n\",\n        \"        \\\"lr_sched_params\\\": {}, # the corresponding initialization parameters\\n\",\n        \"        \\\"metrics\\\": {} # and some evaluation metrics\\n\",\n        \"    }\\n\",\n        \"}) \"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"Since we did not specify any metric, only the `L1-Loss` will be calculated for each batch. Since this is just a toy example, this should be sufficient. We will train our network with a batchsize of 64 by using `Adam` as optimizer of choice.\\n\",\n        \"\\n\",\n        \"## Logging and Visualization\\n\",\n        \"To get a visualization of our results, we should monitor them somehow. For logging we will use `Tensorboard`. 
Per default the logging directory will be the same as our experiment directory.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"\\n\",\n        \"## Data Preparation\\n\",\n        \"### Loading\\n\",\n        \"Next we will create some fake data. For this we use the `ClassificationFakeData`-Dataset, which is already implemented in `deliravision`. To avoid getting the exact same data from both datasets, we use a random offset.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 2,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": false\n        }\n      },\n      \"outputs\": [\n        {\n          \"ename\": \"ModuleNotFoundError\",\n          \"evalue\": \"No module named \\u0027deliravision\\u0027\",\n          \"output_type\": \"error\",\n          \"traceback\": [\n            \"\\u001b[1;31m---------------------------------------------------------------------------\\u001b[0m\",\n            \"\\u001b[1;31mModuleNotFoundError\\u001b[0m                       Traceback (most recent call last)\",\n            \"\\u001b[1;32m\\u003cipython-input-2-c638229a3dc2\\u003e\\u001b[0m in \\u001b[0;36m\\u003cmodule\\u003e\\u001b[1;34m\\u001b[0m\\n\\u001b[1;32m----\\u003e 1\\u001b[1;33m \\u001b[1;32mfrom\\u001b[0m \\u001b[0mdeliravision\\u001b[0m\\u001b[1;33m.\\u001b[0m\\u001b[0mdata\\u001b[0m\\u001b[1;33m.\\u001b[0m\\u001b[0mfakedata\\u001b[0m \\u001b[1;32mimport\\u001b[0m \\u001b[0mClassificationFakeData\\u001b[0m\\u001b[1;33m\\u001b[0m\\u001b[1;33m\\u001b[0m\\u001b[0m\\n\\u001b[0m\\u001b[0;32m      2\\u001b[0m dataset_train \\u003d ClassificationFakeData(num_samples\\u003d10000, \\n\\u001b[0;32m      3\\u001b[0m                                        \\u001b[0mimg_size\\u001b[0m\\u001b[1;33m\\u003d\\u001b[0m\\u001b[1;33m(\\u001b[0m\\u001b[1;36m3\\u001b[0m\\u001b[1;33m,\\u001b[0m 
\\u001b[1;36m224\\u001b[0m\\u001b[1;33m,\\u001b[0m \\u001b[1;36m224\\u001b[0m\\u001b[1;33m)\\u001b[0m\\u001b[1;33m,\\u001b[0m\\u001b[1;33m\\u001b[0m\\u001b[1;33m\\u001b[0m\\u001b[0m\\n\\u001b[0;32m      4\\u001b[0m                                        num_classes\\u003d10)\\n\\u001b[0;32m      5\\u001b[0m dataset_val \\u003d ClassificationFakeData(num_samples\\u003d1000, \\n\",\n            \"\\u001b[1;31mModuleNotFoundError\\u001b[0m: No module named \\u0027deliravision\\u0027\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"from deliravision.data.fakedata import ClassificationFakeData\\n\",\n        \"dataset_train \\u003d ClassificationFakeData(num_samples\\u003d10000, \\n\",\n        \"                                       img_size\\u003d(3, 32, 32), \\n\",\n        \"                                       num_classes\\u003d10)\\n\",\n        \"dataset_val \\u003d ClassificationFakeData(num_samples\\u003d1000, \\n\",\n        \"                                     img_size\\u003d(3, 32, 32), \\n\",\n        \"                                     num_classes\\u003d10,\\n\",\n        \"                                     rng_offset\\u003d10001\\n\",\n        \"                                     )\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"### Augmentation\\n\",\n        \"For data augmentation we will apply a few transformations:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": false\n        }\n      },\n      \"outputs\": [],\n      \"source\": \"from batchgenerators.transforms import RandomCropTransform, \\\\\\n                                        ContrastAugmentationTransform, Compose\\nfrom batchgenerators.transforms.spatial_transforms import ResizeTransform\\nfrom 
batchgenerators.transforms.sample_normalization_transforms import MeanStdNormalizationTransform\\n\\ntransforms \\u003d Compose([\\n    RandomCropTransform(24), # Perform Random Crops of Size 24 x 24 pixels\\n    ResizeTransform(32), # Resample these crops back to 32 x 32 pixels\\n    ContrastAugmentationTransform(), # randomly adjust contrast\\n    MeanStdNormalizationTransform(mean\\u003d[0.5], std\\u003d[0.5])]) \\n\\n\"\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"With these transformations we can now wrap our datasets into datamanagers:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": false\n        }\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"from delira.data_loading import DataManager, SequentialSampler, RandomSampler\\n\",\n        \"\\n\",\n        \"manager_train \\u003d DataManager(dataset_train, params.nested_get(\\\"batch_size\\\"),\\n\",\n        \"                                transforms\\u003dtransforms,\\n\",\n        \"                                sampler_cls\\u003dRandomSampler,\\n\",\n        \"                                n_process_augmentation\\u003d4)\\n\",\n        \"\\n\",\n        \"manager_val \\u003d DataManager(dataset_val, params.nested_get(\\\"batch_size\\\"),\\n\",\n        \"                              transforms\\u003dtransforms,\\n\",\n        \"                              sampler_cls\\u003dSequentialSampler,\\n\",\n        \"                              n_process_augmentation\\u003d4)\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"## Model\\n\",\n        \"\\n\",\n        \"After we have done that, we can specify our model: We will use a smaller version of a 
[VGG-Network](https://arxiv.org/pdf/1409.1556.pdf) in this case. We will use more convolutions to reduce the feature dimensionality and reduce the number of units in the linear layers to save memory (and we only have to deal with 10 classes, not the 1000 ImageNet classes).\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 3,\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"from delira.models import AbstractTfEagerNetwork\\n\",\n        \"import tensorflow as tf\\n\",\n        \"import numpy as np\\n\",\n        \"\\n\",\n        \"class SmallVGGTfEager(AbstractTfEagerNetwork):\\n\",\n        \"    def __init__(self, in_channels, num_classes, data_format\\u003d\\\"channels_last\\\"):\\n\",\n        \"        if data_format \\u003d\\u003d \\\"channels_last\\\":\\n\",\n        \"            input_shape \\u003d (32, 32, 3)\\n\",\n        \"        else:\\n\",\n        \"            input_shape \\u003d (3, 32, 32)\\n\",\n        \"        super().__init__(data_format\\u003ddata_format)\\n\",\n        \"        \\n\",\n        \"        # in_channels is given implicitly via input_shape\\n\",\n        \"        self.model \\u003d tf.keras.models.Sequential([\\n\",\n        \"            tf.keras.layers.Conv2D(64, 3, padding\\u003d\\\"same\\\", input_shape\\u003dinput_shape), # 32 x 32\\n\",\n        \"            tf.keras.layers.ReLU(),\\n\",\n        \"            tf.keras.layers.MaxPooling2D(2), # 16 x 16\\n\",\n        \"            tf.keras.layers.Conv2D(128, 3, padding\\u003d\\\"same\\\"),\\n\",\n        \"            tf.keras.layers.ReLU(),\\n\",\n        \"            tf.keras.layers.MaxPooling2D(2), # 8 x 8\\n\",\n        \"            tf.keras.layers.Conv2D(256, 3, padding\\u003d\\\"same\\\"),\\n\",\n        \"            tf.keras.layers.ReLU(),\\n\",\n        \"            tf.keras.layers.MaxPooling2D(2), # 4 x 4\\n\",\n        \"            tf.keras.layers.Conv2D(512, 3, padding\\u003d\\\"same\\\"),\\n\",\n        \"            tf.keras.layers.ReLU(),\\n\",\n        \"            tf.keras.layers.MaxPooling2D(2), # 2 x 2\\n\",\n        \"            tf.keras.layers.Conv2D(512, 3, padding\\u003d\\\"same\\\"),\\n\",\n        \"            tf.keras.layers.ReLU(),\\n\",\n        \"            tf.keras.layers.MaxPooling2D(2), # 1 x 1\\n\",\n        \"            tf.keras.layers.Flatten(),\\n\",\n        \"            tf.keras.layers.Dense(num_classes),\\n\",\n        \"        ])\\n\",\n        \"        \\n\",\n        \"    def call(self, x: tf.Tensor):\\n\",\n        \"        return {\\\"pred\\\": self.model(x)}\\n\",\n        \"    \\n\",\n        \"    @staticmethod\\n\",\n        \"    def prepare_batch(data_dict, input_device, output_device):\\n\",\n        \"        with tf.device(input_device):\\n\",\n        \"            return_dict \\u003d {\\\"data\\\": tf.convert_to_tensor(\\n\",\n        \"                data_dict[\\\"data\\\"].astype(np.float32))}\\n\",\n        \"        \\n\",\n        \"        with tf.device(output_device):\\n\",\n        \"            for key, vals in data_dict.items():\\n\",\n        \"                if key \\u003d\\u003d \\\"data\\\": \\n\",\n        \"                    continue\\n\",\n        \"                return_dict[key] \\u003d tf.convert_to_tensor(\\n\",\n        \"                    vals.astype(np.float32))\\n\",\n        \"\\n\",\n        \"        return return_dict\\n\",\n        \"    \\n\",\n        \"    @staticmethod\\n\",\n        \"    def closure(model, data_dict: dict, optimizers: dict, losses: dict,\\n\",\n        \"                fold\\u003d0, **kwargs):\\n\",\n        \"\\n\",\n        \"        loss_vals \\u003d {}\\n\",\n        \"\\n\",\n        \"        # calculate loss with graph created by gradient taping\\n\",\n        \"        with tf.GradientTape() as tape:\\n\",\n        \"            preds \\u003d model(data_dict[\\\"data\\\"])\\n\",\n        \"            total_loss \\u003d None\\n\",\n        \"            for k, loss_fn in 
losses.items():\\n\",\n        \"                # tf.losses functions expect (labels, predictions)\\n\",\n        \"                _loss_val \\u003d loss_fn(data_dict[\\\"label\\\"],\\n\",\n        \"                                    preds[\\\"pred\\\"])\\n\",\n        \"                loss_vals[k] \\u003d _loss_val.numpy()\\n\",\n        \"                if total_loss is None:\\n\",\n        \"                    total_loss \\u003d _loss_val\\n\",\n        \"                else:\\n\",\n        \"                    total_loss +\\u003d _loss_val\\n\",\n        \"                    \\n\",\n        \"        # apply the parameter update with the optimizers built by the optim_builder\\n\",\n        \"        grads \\u003d tape.gradient(total_loss, model.trainable_variables)\\n\",\n        \"        for optim in optimizers.values():\\n\",\n        \"            optim.apply_gradients(zip(grads, model.trainable_variables))\\n\",\n        \"\\n\",\n        \"        return loss_vals, preds\\n\",\n        \"    \\n\",\n        \"    \"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"So let\\u0027s revisit what we have just done.\\n\",\n        \"\\n\",\n        \"In `delira` all networks must be derived from `delira.models.AbstractNetwork`. For each backend there is a class derived from this class, handling some backend-specific function calls and registrations. For the `Tensorflow Eager` backend this class is `AbstractTfEagerNetwork` and all TensorFlow Eager Execution networks should be derived from it.\\n\",\n        \"\\n\",\n        \"First we defined the network itself (this is the part simply concatenating the layers into a sequential model). Next, we defined the logic to apply when we want to predict from the model (this is the `call` method).\\n\",\n        \"\\n\",\n        \"So far this was plain `TensorFlow`. The `prepare_batch` function is not plain TF anymore, but allows us to ensure the data is in the correct shape, has the correct data type and lies on the correct device. The function above is the standard `prepare_batch` function, which is also implemented in `AbstractTfEagerNetwork` and just re-implemented here for the sake of completeness.\\n\",\n        \"\\n\",\n        \"Same goes for the `closure` function. 
This function defines the update rule for our parameters (and how to calculate the losses). These functions are good to go for many simple networks but can be overridden for customization when training more complex networks.\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"## Training\\n\",\n        \"Now that we have defined our network, we can finally specify our experiment and run it.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": true\n        }\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import warnings\\n\",\n        \"warnings.simplefilter(\\\"ignore\\\", UserWarning) # ignore UserWarnings raised by dependency code\\n\",\n        \"warnings.simplefilter(\\\"ignore\\\", FutureWarning) # ignore FutureWarnings raised by dependency code\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"from delira.training import TfEagerExperiment\\n\",\n        \"from delira.training.train_utils import create_tf_eager_optims_default\\n\",\n        \"\\n\",\n        \"if logger is not None:\\n\",\n        \"    logger.info(\\\"Init Experiment\\\")\\n\",\n        \"experiment \\u003d TfEagerExperiment(params, SmallVGGTfEager,\\n\",\n        \"                               name\\u003d\\\"ClassificationExample\\\",\\n\",\n        \"                               save_path\\u003d\\\"./tmp/delira_Experiments\\\",\\n\",\n        \"                               optim_builder\\u003dcreate_tf_eager_optims_default,\\n\",\n        \"                               key_mapping\\u003d{\\\"x\\\": \\\"data\\\"},\\n\",\n        \"                               gpu_ids\\u003d[0])\\n\",\n        \"experiment.save()\\n\",\n        \"\\n\",\n        \"model \\u003d experiment.run(manager_train, manager_val)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n    
    \"Congratulations, you have now trained your first Classification Model using `delira`, we will now predict a few samples from the testset to show, that the networks predictions are valid (for now, this is done manually, but we also have a `Predictor` class to automate stuff like this):\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import numpy as np\\n\",\n        \"from tqdm.auto import tqdm # utility for progress bars\\n\",\n        \"import tensorflow as tf\\n\",\n        \"\\n\",\n        \"device \\u003d \\\"/cpu:0\\\"\\n\",\n        \"preds, labels \\u003d [], []\\n\",\n        \"\\n\",\n        \"with tf.device(device):\\n\",\n        \"    for i in tqdm(range(len(dataset_val))):\\n\",\n        \"        img \\u003d dataset_val[i][\\\"data\\\"] # get image from current batch\\n\",\n        \"        img_tensor \\u003d tf.convert_to_tensor(img[None, ...].astype(np.float)) # create a tensor from image, push it to device and add batch dimension\\n\",\n        \"        pred_tensor \\u003d model(img_tensor) # feed it through the network\\n\",\n        \"        pred \\u003d pred_tensor.argmax(1).item() # get index with maximum class confidence\\n\",\n        \"        label \\u003d np.asscalar(dataset_val[i][\\\"label\\\"]) # get label from batch\\n\",\n        \"        if i % 1000 \\u003d\\u003d 0:\\n\",\n        \"            print(\\\"Prediction: %d \\\\t label: %d\\\" % (pred, label)) # print result\\n\",\n        \"        preds.append(pred)\\n\",\n        \"        labels.append(label)\\n\",\n        \"\\n\",\n        \"# calculate accuracy\\n\",\n        \"accuracy \\u003d (np.asarray(preds) \\u003d\\u003d np.asarray(labels)).sum() / len(preds)\\n\",\n        \"print(\\\"Accuracy: %.3f\\\" % accuracy)\"\n      ]\n    }\n  ],\n  \"metadata\": {\n    \"kernelspec\": {\n      \"display_name\": 
\"Python 3\",\n      \"language\": \"python\",\n      \"name\": \"python3\"\n    },\n    \"language_info\": {\n      \"codemirror_mode\": {\n        \"name\": \"ipython\",\n        \"version\": 3\n      },\n      \"file_extension\": \".py\",\n      \"mimetype\": \"text/x-python\",\n      \"name\": \"python\",\n      \"nbconvert_exporter\": \"python\",\n      \"pygments_lexer\": \"ipython3\",\n      \"version\": \"3.7.3\"\n    }\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 2\n}"
  },
  {
    "path": "notebooks/classification_examples/tf_graph.ipynb",
    "content": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"# Classification with Delira and TensorFlow Graph Execution- A very short introduction\\n\",\n        \"*Author: Justus Schock* \\n\",\n        \"\\n\",\n        \"*Date: 31.07.2019*\\n\",\n        \"\\n\",\n        \"This Example shows how to set up a basic classification model and experiment using TensorFlow\\u0027s Graph Execution Mode.\\n\",\n        \"\\n\",\n        \"Let\\u0027s first setup the essential hyperparameters. We will use `delira`\\u0027s `Parameters`-class for this:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 1,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": false\n        }\n      },\n      \"outputs\": [\n        {\n          \"name\": \"stderr\",\n          \"output_type\": \"stream\",\n          \"text\": [\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\framework\\\\dtypes.py:516: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_qint8 \\u003d np.dtype([(\\\"qint8\\\", np.int8, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\framework\\\\dtypes.py:517: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_quint8 \\u003d np.dtype([(\\\"quint8\\\", np.uint8, 1)])\\n\",\n            
\"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\framework\\\\dtypes.py:518: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_qint16 \\u003d np.dtype([(\\\"qint16\\\", np.int16, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\framework\\\\dtypes.py:519: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_quint16 \\u003d np.dtype([(\\\"quint16\\\", np.uint16, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\framework\\\\dtypes.py:520: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_qint32 \\u003d np.dtype([(\\\"qint32\\\", np.int32, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorflow\\\\python\\\\framework\\\\dtypes.py:525: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  np_resource \\u003d np.dtype([(\\\"resource\\\", np.ubyte, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorboard\\\\compat\\\\tensorflow_stub\\\\dtypes.py:541: 
FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_qint8 \\u003d np.dtype([(\\\"qint8\\\", np.int8, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorboard\\\\compat\\\\tensorflow_stub\\\\dtypes.py:542: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_quint8 \\u003d np.dtype([(\\\"quint8\\\", np.uint8, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorboard\\\\compat\\\\tensorflow_stub\\\\dtypes.py:543: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_qint16 \\u003d np.dtype([(\\\"qint16\\\", np.int16, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorboard\\\\compat\\\\tensorflow_stub\\\\dtypes.py:544: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_quint16 \\u003d np.dtype([(\\\"quint16\\\", np.uint16, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorboard\\\\compat\\\\tensorflow_stub\\\\dtypes.py:545: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) 
/ \\u0027(1,)type\\u0027.\\n\",\n            \"  _np_qint32 \\u003d np.dtype([(\\\"qint32\\\", np.int32, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\tensorboard\\\\compat\\\\tensorflow_stub\\\\dtypes.py:550: FutureWarning: Passing (type, 1) or \\u00271type\\u0027 as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / \\u0027(1,)type\\u0027.\\n\",\n            \"  np_resource \\u003d np.dtype([(\\\"resource\\\", np.ubyte, 1)])\\n\",\n            \"c:\\\\users\\\\jsc7rng\\\\appdata\\\\local\\\\conda\\\\conda\\\\envs\\\\delira-dev\\\\lib\\\\site-packages\\\\pywt\\\\_utils.py:6: DeprecationWarning: Using or importing the ABCs from \\u0027collections\\u0027 instead of from \\u0027collections.abc\\u0027 is deprecated, and in 3.8 it will stop working\\n\",\n            \"  from collections import Iterable\\n\",\n            \"WARNING: Logging before flag parsing goes to stderr.\\n\",\n            \"W0731 13:38:30.713174 21496 deprecation_wrapper.py:119] From c:\\\\users\\\\jsc7rng\\\\downloads\\\\delira\\\\delira\\\\models\\\\backends\\\\tf_eager\\\\abstract_network.py:113: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\\n\",\n            \"\\n\",\n            \"W0731 13:38:30.727135 21496 deprecation_wrapper.py:119] From c:\\\\users\\\\jsc7rng\\\\downloads\\\\delira\\\\delira\\\\models\\\\backends\\\\tf_graph\\\\abstract_network.py:20: The name tf.Session is deprecated. 
Please use tf.compat.v1.Session instead.\\n\",\n            \"\\n\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"logger \\u003d None\\n\",\n        \"import tensorflow as tf\\n\",\n        \"tf.compat.v1.disable_eager_execution()\\n\",\n        \"from delira.training import Parameters\\n\",\n        \"params \\u003d Parameters(fixed_params\\u003d{\\n\",\n        \"    \\\"model\\\": {\\n\",\n        \"        \\\"in_channels\\\": 1, \\n\",\n        \"        \\\"n_outputs\\\": 10\\n\",\n        \"    },\\n\",\n        \"    \\\"training\\\": {\\n\",\n        \"        \\\"batch_size\\\": 64, # batch size to use\\n\",\n        \"        \\\"num_epochs\\\": 10, # number of epochs to train\\n\",\n        \"        \\\"optimizer_cls\\\": tf.train.AdamOptimizer, # optimization algorithm to use\\n\",\n        \"        \\\"optimizer_params\\\": {\\u0027learning_rate\\u0027: 1e-3}, # initialization parameters for this algorithm\\n\",\n        \"        \\\"losses\\\": {\\\"L1\\\": tf.losses.absolute_difference}, # the loss function\\n\",\n        \"        \\\"lr_sched_cls\\\": None,  # the learning rate scheduling algorithm to use\\n\",\n        \"        \\\"lr_sched_params\\\": {}, # the corresponding initialization parameters\\n\",\n        \"        \\\"metrics\\\": {} # and some evaluation metrics\\n\",\n        \"    }\\n\",\n        \"})\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"Since we did not specify any metrics, only the `L1-Loss` will be calculated for each batch; for this toy example, that is sufficient. We will train our network with a batch size of 64, using `Adam` as our optimizer.\\n\",\n        \"\\n\",\n        \"## Logging and Visualization\\n\",\n        \"To get a visualization of our results, we should monitor them somehow. For logging we will use `Tensorboard`. 
Per default the logging directory will be the same as our experiment directory.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"\\n\",\n        \"## Data Preparation\\n\",\n        \"### Loading\\n\",\n        \"Next we will create some fake data. For this we use the `ClassificationFakeData`-Dataset, which is already implemented in `deliravision`. To avoid getting the exact same data from both datasets, we use a random offset.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 2,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": false\n        }\n      },\n      \"outputs\": [\n        {\n          \"ename\": \"ModuleNotFoundError\",\n          \"evalue\": \"No module named \\u0027deliravision\\u0027\",\n          \"output_type\": \"error\",\n          \"traceback\": [\n            \"\\u001b[1;31m---------------------------------------------------------------------------\\u001b[0m\",\n            \"\\u001b[1;31mModuleNotFoundError\\u001b[0m                       Traceback (most recent call last)\",\n            \"\\u001b[1;32m\\u003cipython-input-2-c638229a3dc2\\u003e\\u001b[0m in \\u001b[0;36m\\u003cmodule\\u003e\\u001b[1;34m\\u001b[0m\\n\\u001b[1;32m----\\u003e 1\\u001b[1;33m \\u001b[1;32mfrom\\u001b[0m \\u001b[0mdeliravision\\u001b[0m\\u001b[1;33m.\\u001b[0m\\u001b[0mdata\\u001b[0m\\u001b[1;33m.\\u001b[0m\\u001b[0mfakedata\\u001b[0m \\u001b[1;32mimport\\u001b[0m \\u001b[0mClassificationFakeData\\u001b[0m\\u001b[1;33m\\u001b[0m\\u001b[1;33m\\u001b[0m\\u001b[0m\\n\\u001b[0m\\u001b[0;32m      2\\u001b[0m dataset_train \\u003d ClassificationFakeData(num_samples\\u003d10000, \\n\\u001b[0;32m      3\\u001b[0m                                        \\u001b[0mimg_size\\u001b[0m\\u001b[1;33m\\u003d\\u001b[0m\\u001b[1;33m(\\u001b[0m\\u001b[1;36m3\\u001b[0m\\u001b[1;33m,\\u001b[0m 
\\u001b[1;36m224\\u001b[0m\\u001b[1;33m,\\u001b[0m \\u001b[1;36m224\\u001b[0m\\u001b[1;33m)\\u001b[0m\\u001b[1;33m,\\u001b[0m\\u001b[1;33m\\u001b[0m\\u001b[1;33m\\u001b[0m\\u001b[0m\\n\\u001b[0;32m      4\\u001b[0m                                        num_classes\\u003d10)\\n\\u001b[0;32m      5\\u001b[0m dataset_val \\u003d ClassificationFakeData(num_samples\\u003d1000, \\n\",\n            \"\\u001b[1;31mModuleNotFoundError\\u001b[0m: No module named \\u0027deliravision\\u0027\"\n          ]\n        }\n      ],\n      \"source\": [\n        \"from deliravision.data.fakedata import ClassificationFakeData\\n\",\n        \"dataset_train \\u003d ClassificationFakeData(num_samples\\u003d10000, \\n\",\n        \"                                       img_size\\u003d(3, 32, 32), \\n\",\n        \"                                       num_classes\\u003d10)\\n\",\n        \"dataset_val \\u003d ClassificationFakeData(num_samples\\u003d1000, \\n\",\n        \"                                     img_size\\u003d(3, 32, 32), \\n\",\n        \"                                     num_classes\\u003d10,\\n\",\n        \"                                     rng_offset\\u003d10001\\n\",\n        \"                                     )\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"### Augmentation\\n\",\n        \"For Data-Augmentation we will apply a few transformations:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": false\n        }\n      },\n      \"outputs\": [],\n      \"source\": \"from batchgenerators.transforms import RandomCropTransform, \\\\\\n                                        ContrastAugmentationTransform, Compose\\nfrom batchgenerators.transforms.spatial_transforms import ResizeTransform\\nfrom 
batchgenerators.transforms.sample_normalization_transforms import MeanStdNormalizationTransform\\n\\ntransforms \\u003d Compose([\\n    RandomCropTransform(24), # Perform Random Crops of Size 24 x 24 pixels\\n    ResizeTransform(32), # Resample these crops back to 32 x 32 pixels\\n    ContrastAugmentationTransform(), # randomly adjust contrast\\n    MeanStdNormalizationTransform(mean\\u003d[0.5], std\\u003d[0.5])]) \\n\\n\"\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"With these transformations we can now wrap our datasets into datamanagers:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": false\n        }\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"from delira.data_loading import DataManager, SequentialSampler, RandomSampler\\n\",\n        \"\\n\",\n        \"manager_train \\u003d DataManager(dataset_train, params.nested_get(\\\"batch_size\\\"),\\n\",\n        \"                                transforms\\u003dtransforms,\\n\",\n        \"                                sampler_cls\\u003dRandomSampler,\\n\",\n        \"                                n_process_augmentation\\u003d4)\\n\",\n        \"\\n\",\n        \"manager_val \\u003d DataManager(dataset_val, params.nested_get(\\\"batch_size\\\"),\\n\",\n        \"                              transforms\\u003dtransforms,\\n\",\n        \"                              sampler_cls\\u003dSequentialSampler,\\n\",\n        \"                              n_process_augmentation\\u003d4)\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"## Model\\n\",\n        \"\\n\",\n        \"After we have done that, we can specify our model: We will use a smaller version of a 
[VGG-Network](https://arxiv.org/pdf/1409.1556.pdf) in this case. We will use more convolutions to reduce the feature dimensionality and reduce the number of units in the linear layers to save memory (and we only have to deal with 10 classes, not the 1000 ImageNet classes).\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 3,\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"from delira.models import AbstractTfGraphNetwork\\n\",\n        \"import tensorflow as tf\\n\",\n        \"import numpy as np\\n\",\n        \"\\n\",\n        \"class SmallVGGTfEager(AbstractTfGraphNetwork):\\n\",\n        \"    def __init__(self, in_channels, num_classes, data_format\\u003d\\\"channels_last\\\"):\\n\",\n        \"        if data_format \\u003d\\u003d \\\"channels_last\\\":\\n\",\n        \"            input_shape \\u003d (32, 32, 3)\\n\",\n        \"        else:\\n\",\n        \"            input_shape \\u003d (3, 32, 32)\\n\",\n        \"        super().__init__()\\n\",\n        \"        \\n\",\n        \"        # in_channels is given implicitly via input_shape\\n\",\n        \"        self.model \\u003d tf.keras.models.Sequential([\\n\",\n        \"            tf.keras.layers.Conv2D(64, 3, padding\\u003d\\\"same\\\", input_shape\\u003dinput_shape), # 32 x 32\\n\",\n        \"            tf.keras.layers.ReLU(),\\n\",\n        \"            tf.keras.layers.MaxPooling2D(2), # 16 x 16\\n\",\n        \"            tf.keras.layers.Conv2D(128, 3, padding\\u003d\\\"same\\\"),\\n\",\n        \"            tf.keras.layers.ReLU(),\\n\",\n        \"            tf.keras.layers.MaxPooling2D(2), # 8 x 8\\n\",\n        \"            tf.keras.layers.Conv2D(256, 3, padding\\u003d\\\"same\\\"),\\n\",\n        \"            tf.keras.layers.ReLU(),\\n\",\n        \"            tf.keras.layers.MaxPooling2D(2), # 4 x 4\\n\",\n        \"            tf.keras.layers.Conv2D(512, 3, padding\\u003d\\\"same\\\"),\\n\",\n        \"            tf.keras.layers.ReLU(),\\n\",\n        \"            tf.keras.layers.MaxPooling2D(2), # 2 x 2\\n\",\n        \"            tf.keras.layers.Conv2D(512, 3, padding\\u003d\\\"same\\\"),\\n\",\n        \"            tf.keras.layers.ReLU(),\\n\",\n        \"            tf.keras.layers.MaxPooling2D(2), # 1 x 1\\n\",\n        \"            tf.keras.layers.Flatten(),\\n\",\n        \"            tf.keras.layers.Dense(num_classes),\\n\",\n        \"        ])\\n\",\n        \"        \\n\",\n        \"        # create computation graph\\n\",\n        \"        data \\u003d tf.placeholder(shape\\u003d[None, *input_shape], dtype\\u003dtf.float32)\\n\",\n        \"        labels \\u003d tf.placeholder_with_default(\\n\",\n        \"                tf.zeros([tf.shape(data)[0], 1]), shape\\u003d[None, 1])\\n\",\n        \"\\n\",\n        \"        preds_train \\u003d self.model(data)\\n\",\n        \"        preds_eval \\u003d self.model(data)\\n\",\n        \"\\n\",\n        \"        self.inputs[\\\"data\\\"] \\u003d data\\n\",\n        \"        self.inputs[\\\"label\\\"] \\u003d labels\\n\",\n        \"        self.outputs_train[\\\"pred\\\"] \\u003d preds_train\\n\",\n        \"        self.outputs_eval[\\\"pred\\\"] \\u003d preds_eval\\n\",\n        \"        \\n\",\n        \"    @staticmethod\\n\",\n        \"    def prepare_batch(data_dict, input_device, output_device):\\n\",\n        \"        with tf.device(input_device):\\n\",\n        \"            return_dict \\u003d {\\\"data\\\": tf.convert_to_tensor(\\n\",\n        \"                data_dict[\\\"data\\\"].astype(np.float32))}\\n\",\n        \"        \\n\",\n        \"        with tf.device(output_device):\\n\",\n        \"            for key, vals in data_dict.items():\\n\",\n        \"                if key \\u003d\\u003d \\\"data\\\": \\n\",\n        \"                    continue\\n\",\n        \"                return_dict[key] \\u003d tf.convert_to_tensor(\\n\",\n        \"                    vals.astype(np.float32))\\n\",\n        \"\\n\",\n        \"        return return_dict\\n\",\n        
\"    \\n\",\n        \"    @staticmethod\\n\",\n        \"    def closure(model, data_dict: dict, optimizers: dict, losses: dict,\\n\",\n        \"                fold\\u003d0, **kwargs):\\n\",\n        \"\\n\",\n        \"        outputs \\u003d model.run(data\\u003dinputs, label\\u003ddata_dict[\\u0027label\\u0027])\\n\",\n        \"        preds \\u003d outputs[\\u0027pred\\u0027]\\n\",\n        \"        loss_vals \\u003d outputs[\\u0027losses\\u0027]\\n\",\n        \"        \\n\",\n        \"        return loss_vals, preds\\n\",\n        \"    \\n\",\n        \"    \"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"So let\\u0027s evisit, what we have just done.\\n\",\n        \"\\n\",\n        \"In `delira` all networks must be derived from `delira.models.AbstractNetwork`. For each backend there is a class derived from this class, handling some backend-specific function calls and registrations. For the `Tensorflow Graph` Backend this class is `AbstractTfGraphNetwork` and all TensorFlow Eager Execution Networks should be derived from it.\\n\",\n        \"\\n\",\n        \"First we defined the network itself (this is the part simply concatenating the layers into a sequential model). Next, we defined the logic to apply, when we want to predict from the model (this is the `call` method).\\n\",\n        \"\\n\",\n        \"So far this was plain `TensorFlow`. The `prepare_batch` function is not plain TF anymore, but allows us to ensure the data is in the correct shape, has the correct data-type and lies on the correct device. The function above is the standard `prepare_batch` function, which is also implemented in the `AbstractTfGraphNetwork` and just re-implemented here for the sake of completeness.\\n\",\n        \"\\n\",\n        \"Same goes for the `closure` function. This function defines the update rule for our parameters (and how to calculate the losses). 
These functions are good to go for many simple networks but can be overwritten for customization when training more complex networks.\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"## Training\\n\",\n        \"Now that we have defined our network, we can finally specify our experiment and run it.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": true\n        }\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import warnings\\n\",\n        \"warnings.simplefilter(\\\"ignore\\\", UserWarning) # ignore UserWarnings raised by dependency code\\n\",\n        \"warnings.simplefilter(\\\"ignore\\\", FutureWarning) # ignore FutureWarnings raised by dependency code\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"from delira.training import TfGraphExperiment\\n\",\n        \"\\n\",\n        \"if logger is not None:\\n\",\n        \"    logger.info(\\\"Init Experiment\\\")\\n\",\n        \"experiment \\u003d TfGraphExperiment(params, SmallVGGTfGraph,\\n\",\n        \"                               name\\u003d\\\"ClassificationExample\\\",\\n\",\n        \"                               save_path\\u003d\\\"./tmp/delira_Experiments\\\",\\n\",\n        \"                               key_mapping\\u003d{\\\"x\\\": \\\"data\\\"},\\n\",\n        \"                               gpu_ids\\u003d[0])\\n\",\n        \"experiment.save()\\n\",\n        \"\\n\",\n        \"model \\u003d experiment.run(manager_train, manager_val)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"Congratulations, you have now trained your first classification model using `delira`. We will now predict a few samples from the test set to show that the network\\u0027s predictions are valid (for now, this is done manually, but we also have a `Predictor` class to automate stuff 
like this):\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import numpy as np\\n\",\n        \"from tqdm.auto import tqdm # utility for progress bars\\n\",\n        \"import tensorflow as tf\\n\",\n        \"\\n\",\n        \"device \\u003d \\\"/cpu:0\\\"\\n\",\n        \"preds, labels \\u003d [], []\\n\",\n        \"\\n\",\n        \"with tf.device(device):\\n\",\n        \"    for i in tqdm(range(len(dataset_val))):\\n\",\n        \"        img \\u003d dataset_val[i][\\\"data\\\"] # get image from current batch\\n\",\n        \"        pred_batch \\u003d model.run(data\\u003dimg[None, ...].astype(np.float32)) # add a batch dimension and feed the image through the network\\n\",\n        \"        pred \\u003d int(np.argmax(pred_batch[\\\"pred\\\"])) # get index with maximum class confidence\\n\",\n        \"        label \\u003d np.asscalar(dataset_val[i][\\\"label\\\"]) # get label from batch\\n\",\n        \"        if i % 1000 \\u003d\\u003d 0:\\n\",\n        \"            print(\\\"Prediction: %d \\\\t label: %d\\\" % (pred, label)) # print result\\n\",\n        \"        preds.append(pred)\\n\",\n        \"        labels.append(label)\\n\",\n        \"\\n\",\n        \"# calculate accuracy\\n\",\n        \"accuracy \\u003d (np.asarray(preds) \\u003d\\u003d np.asarray(labels)).sum() / len(preds)\\n\",\n        \"print(\\\"Accuracy: %.3f\\\" % accuracy)\"\n      ]\n    }\n  ],\n  \"metadata\": {\n    \"kernelspec\": {\n      \"display_name\": \"Python 3\",\n      \"language\": \"python\",\n      \"name\": \"python3\"\n    },\n    \"language_info\": {\n      \"codemirror_mode\": {\n        \"name\": \"ipython\",\n        \"version\": 3\n      },\n      \"file_extension\": \".py\",\n      \"mimetype\": 
\"text/x-python\",\n      \"name\": \"python\",\n      \"nbconvert_exporter\": \"python\",\n      \"pygments_lexer\": \"ipython3\",\n      \"version\": \"3.7.3\"\n    }\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 2\n}"
  },
  {
    "path": "notebooks/classification_examples/torchscript.ipynb",
    "content": "{\n  \"cells\": [\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"# Classification with Delira and TorchScript - A very short introduction\\n\",\n        \"*Author: Justus Schock* \\n\",\n        \"\\n\",\n        \"*Date: 04.12.2018*\\n\",\n        \"\\n\",\n        \"This Example shows how to set up a basic classification `TorchScript` model and experiment.\\n\",\n        \"`TorchScript` is basically `PyTorch` with a static computation graph. Thus, we require only minor changes compared to the `PyTorch`-example. These changes will be highlighted.\\n\",\n        \"\\n\",\n        \"Let\\u0027s first setup the essential hyperparameters. We will use `delira`\\u0027s `Parameters`-class for this:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": false\n        }\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"logger \\u003d None\\n\",\n        \"import torch\\n\",\n        \"from delira.training import Parameters\\n\",\n        \"params \\u003d Parameters(fixed_params\\u003d{\\n\",\n        \"    \\\"model\\\": {\\n\",\n        \"        \\\"in_channels\\\": 1, \\n\",\n        \"        \\\"n_outputs\\\": 10\\n\",\n        \"    },\\n\",\n        \"    \\\"training\\\": {\\n\",\n        \"        \\\"batch_size\\\": 64, # batchsize to use\\n\",\n        \"        \\\"num_epochs\\\": 10, # number of epochs to train\\n\",\n        \"        \\\"optimizer_cls\\\": torch.optim.Adam, # optimization algorithm to use\\n\",\n        \"        \\\"optimizer_params\\\": {\\u0027lr\\u0027: 1e-3}, # initialization parameters for this algorithm\\n\",\n        \"        \\\"losses\\\": {\\\"CE\\\": torch.nn.CrossEntropyLoss()}, # the loss function\\n\",\n        \"        \\\"lr_sched_cls\\\": None,  # the learning rate scheduling algorithm to 
use\\n\",\n        \"        \\\"lr_sched_params\\\": {}, # the corresponding initialization parameters\\n\",\n        \"        \\\"metrics\\\": {} # and some evaluation metrics\\n\",\n        \"    }\\n\",\n        \"}) \"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"Since we did not specify any metric, only the `CrossEntropyLoss` will be calculated for each batch. Since we have a classification task, this should be sufficient. We will train our network with a batchsize of 64 by using `Adam` as optimizer of choice.\\n\",\n        \"\\n\",\n        \"## Logging and Visualization\\n\",\n        \"To get a visualization of our results, we should monitor them somehow. For logging we will use `Tensorboard`. Per default the logging directory will be the same as our experiment directory.\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"\\n\",\n        \"## Data Preparation\\n\",\n        \"### Loading\\n\",\n        \"Next we will create some fake data. For this we use the `ClassificationFakeData`-Dataset, which is already implemented in `deliravision`. 
To avoid getting the exact same data from both datasets, we use a random offset.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": false\n        }\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"from deliravision.data.fakedata import ClassificationFakeData\\n\",\n        \"dataset_train \\u003d ClassificationFakeData(num_samples\\u003d10000, \\n\",\n        \"                                       img_size\\u003d(3, 32, 32), \\n\",\n        \"                                       num_classes\\u003d10)\\n\",\n        \"dataset_val \\u003d ClassificationFakeData(num_samples\\u003d1000, \\n\",\n        \"                                     img_size\\u003d(3, 32, 32), \\n\",\n        \"                                     num_classes\\u003d10,\\n\",\n        \"                                     rng_offset\\u003d10001\\n\",\n        \"                                     )\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"### Augmentation\\n\",\n        \"For Data-Augmentation we will apply a few transformations:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": false\n        }\n      },\n      \"outputs\": [],\n      \"source\": \"from batchgenerators.transforms import RandomCropTransform, \\\\\\n                                        ContrastAugmentationTransform, Compose\\nfrom batchgenerators.transforms.spatial_transforms import ResizeTransform\\nfrom batchgenerators.transforms.sample_normalization_transforms import MeanStdNormalizationTransform\\n\\ntransforms \\u003d Compose([\\n    RandomCropTransform(24), # Perform Random Crops of Size 24 x 24 pixels\\n    ResizeTransform(32), # Resample these crops back 
to 32 x 32 pixels\\n    ContrastAugmentationTransform(), # randomly adjust contrast\\n    MeanStdNormalizationTransform(mean\\u003d[0.5], std\\u003d[0.5])]) \\n\\n\"\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"With these transformations we can now wrap our datasets into datamanagers:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": false\n        }\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"from delira.data_loading import DataManager, SequentialSampler, RandomSampler\\n\",\n        \"\\n\",\n        \"manager_train \\u003d DataManager(dataset_train, params.nested_get(\\\"batch_size\\\"),\\n\",\n        \"                                transforms\\u003dtransforms,\\n\",\n        \"                                sampler_cls\\u003dRandomSampler,\\n\",\n        \"                                n_process_augmentation\\u003d4)\\n\",\n        \"\\n\",\n        \"manager_val \\u003d DataManager(dataset_val, params.nested_get(\\\"batch_size\\\"),\\n\",\n        \"                              transforms\\u003dtransforms,\\n\",\n        \"                              sampler_cls\\u003dSequentialSampler,\\n\",\n        \"                              n_process_augmentation\\u003d4)\\n\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"## Model\\n\",\n        \"\\n\",\n        \"After we have done that, we can specify our model: We will use a smaller version of a [VGG11](https://arxiv.org/pdf/1409.1556.pdf) in this case. 
We will use more convolutions to reduce the feature dimensionality and reduce the number of units in the linear layers to save memory (and we only have to deal with 10 classes, not the 1000 ImageNet classes).\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": 2,\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"from delira.models import AbstractTorchScriptNetwork\\n\",\n        \"import torch\\n\",\n        \"\\n\",\n        \"class Flatten(torch.nn.Module):\\n\",\n        \"        \\n\",\n        \"    def forward(self, x):\\n\",\n        \"        return x.view(x.size(0), -1)\\n\",\n        \"\\n\",\n        \"class VGG11TorchScript(AbstractTorchScriptNetwork):\\n\",\n        \"    def __init__(self, in_channels, num_classes):\\n\",\n        \"        super().__init__()\\n\",\n        \"        \\n\",\n        \"        self.model \\u003d torch.nn.Sequential(\\n\",\n        \"            torch.nn.Conv2d(in_channels, 64, 3, padding\\u003d1), # 32 x 32\\n\",\n        \"            torch.nn.ReLU(),\\n\",\n        \"            torch.nn.MaxPool2d(2), # 16 x 16\\n\",\n        \"            torch.nn.Conv2d(64, 128, 3, padding\\u003d1),\\n\",\n        \"            torch.nn.ReLU(),\\n\",\n        \"            torch.nn.MaxPool2d(2), # 8 x 8\\n\",\n        \"            torch.nn.Conv2d(128, 256, 3, padding\\u003d1),\\n\",\n        \"            torch.nn.ReLU(),\\n\",\n        \"            torch.nn.MaxPool2d(2), # 4 x 4\\n\",\n        \"            torch.nn.Conv2d(256, 512, 3, padding\\u003d1),\\n\",\n        \"            torch.nn.ReLU(),\\n\",\n        \"            torch.nn.MaxPool2d(2), # 2 x 2\\n\",\n        \"            torch.nn.Conv2d(512, 512, 3, padding\\u003d1),\\n\",\n        \"            torch.nn.ReLU(),\\n\",\n        \"            torch.nn.MaxPool2d(2), # 1 x 1\\n\",\n        \"            Flatten(),\\n\",\n        \"            
torch.nn.Linear(1*1*512, num_classes),\\n\",\n        \"        )\\n\",\n        \"        \\n\",\n        \"    @torch.jit.script_method\\n\",\n        \"    def forward(self, x: torch.Tensor):\\n\",\n        \"        return {\\\"pred\\\": self.model(x)}\\n\",\n        \"    \\n\",\n        \"    @staticmethod\\n\",\n        \"    def prepare_batch(batch, input_device, output_device):\\n\",\n        \"        return_dict \\u003d {\\\"data\\\": torch.from_numpy(batch[\\\"data\\\"]).to(\\n\",\n        \"            input_device).to(torch.float)}\\n\",\n        \"\\n\",\n        \"        for key, vals in batch.items():\\n\",\n        \"            if key \\u003d\\u003d \\\"data\\\": \\n\",\n        \"                continue\\n\",\n        \"            return_dict[key] \\u003d torch.from_numpy(vals).to(output_device).to(\\n\",\n        \"                torch.float)\\n\",\n        \"\\n\",\n        \"        return return_dict\\n\",\n        \"    \\n\",\n        \"    @staticmethod\\n\",\n        \"    def closure(model, data_dict: dict, optimizers: dict, losses: dict,\\n\",\n        \"                fold\\u003d0, **kwargs):\\n\",\n        \"\\n\",\n        \"        loss_vals \\u003d {}\\n\",\n        \"        total_loss \\u003d 0\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"        # predict\\n\",\n        \"        inputs \\u003d data_dict[\\\"data\\\"]\\n\",\n        \"        preds \\u003d model(inputs)\\n\",\n        \"\\n\",\n        \"        # calculate losses\\n\",\n        \"        for key, crit_fn in losses.items():\\n\",\n        \"            _loss_val \\u003d crit_fn(preds[\\\"pred\\\"], data_dict[\\\"label\\\"])\\n\",\n        \"            loss_vals[key] \\u003d _loss_val.item()\\n\",\n        \"            total_loss +\\u003d _loss_val\\n\",\n        \"\\n\",\n        \"        optimizers[\\u0027default\\u0027].zero_grad()\\n\",\n        \"        total_loss.backward()\\n\",\n        \"        
optimizers[\\u0027default\\u0027].step()\\n\",\n        \"\\n\",\n        \"        return loss_vals, {k: v.detach()\\n\",\n        \"                                for k, v in preds.items()}\\n\",\n        \"    \\n\",\n        \"    \"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"source\": [\n        \"So let\\u0027s revisit what we have just done.\\n\",\n        \"\\n\",\n        \"In `delira` all networks must be derived from `delira.models.AbstractNetwork`. For each backend there is a class derived from this class, handling some backend-specific function calls and registrations. For the `TorchScript` backend this class is `AbstractTorchScriptNetwork` and all TorchScript networks should be derived from it.\\n\",\n        \"\\n\",\n        \"\\u003e **Note:** This is different from `PyTorch`, where the base class has to be `AbstractPyTorchNetwork`.\\n\",\n        \"\\n\",\n        \"First we defined the network itself (this is the part simply concatenating the layers into a sequential model). Next, we defined the logic to apply when we want to predict from the model (this is the `forward` method).\\n\",\n        \"\\n\",\n        \"\\u003e **Note:** In `TorchScript` all methods adding operations to the computation graph must be decorated with `torch.jit.script_method`. See [here](https://pytorch.org/docs/stable/jit.html#creating-torchscript-code) for more details\\n\",\n        \"\\n\",\n        \"So far this was plain `TorchScript`. The `prepare_batch` function is not plain TorchScript anymore, but allows us to ensure the data is in the correct shape, has the correct data-type and lies on the correct device. The function above is the standard `prepare_batch` function, which is also implemented in the `AbstractTorchScriptNetwork` and just re-implemented here for the sake of completeness.\\n\",\n        \"\\n\",\n        \"Same goes for the `closure` function. 
This function defines the update rule for our parameters (and how to calculate the losses). These functions are good to go for many simple networks but can be overwritten for customization when training more complex networks.\\n\",\n        \"\\n\",\n        \"## Training\\n\",\n        \"Now that we have defined our network, we can finally specify our experiment and run it.\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {\n          \"is_executing\": true\n        }\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import warnings\\n\",\n        \"warnings.simplefilter(\\\"ignore\\\", UserWarning) # ignore UserWarnings raised by dependency code\\n\",\n        \"warnings.simplefilter(\\\"ignore\\\", FutureWarning) # ignore FutureWarnings raised by dependency code\\n\",\n        \"\\n\",\n        \"\\n\",\n        \"from delira.training import TorchScriptExperiment\\n\",\n        \"from delira.training.train_utils import create_optims_default_pytorch\\n\",\n        \"\\n\",\n        \"if logger is not None:\\n\",\n        \"    logger.info(\\\"Init Experiment\\\")\\n\",\n        \"experiment \\u003d TorchScriptExperiment(params, VGG11TorchScript,\\n\",\n        \"                                   name\\u003d\\\"ClassificationExample\\\",\\n\",\n        \"                                   save_path\\u003d\\\"./tmp/delira_Experiments\\\",\\n\",\n        \"                                   optim_builder\\u003dcreate_optims_default_pytorch,\\n\",\n        \"                                   key_mapping\\u003d{\\\"x\\\": \\\"data\\\"},\\n\",\n        \"                                   gpu_ids\\u003d[0])\\n\",\n        \"experiment.save()\\n\",\n        \"\\n\",\n        \"model \\u003d experiment.run(manager_train, manager_val)\"\n      ]\n    },\n    {\n      \"cell_type\": \"markdown\",\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      
\"source\": [\n        \"Congratulations, you have now trained your first classification model using `delira`. We will now predict a few samples from the test set to show that the network\\u0027s predictions are valid:\"\n      ]\n    },\n    {\n      \"cell_type\": \"code\",\n      \"execution_count\": null,\n      \"metadata\": {\n        \"pycharm\": {}\n      },\n      \"outputs\": [],\n      \"source\": [\n        \"import numpy as np\\n\",\n        \"from tqdm.auto import tqdm # utility for progress bars\\n\",\n        \"\\n\",\n        \"device \\u003d torch.device(\\\"cuda\\\" if torch.cuda.is_available() else \\\"cpu\\\") # set device (use GPU if available)\\n\",\n        \"model \\u003d model.to(device) # push model to device\\n\",\n        \"preds, labels \\u003d [], []\\n\",\n        \"\\n\",\n        \"with torch.no_grad():\\n\",\n        \"    for i in tqdm(range(len(dataset_val))):\\n\",\n        \"        img \\u003d dataset_val[i][\\\"data\\\"] # get image from current batch\\n\",\n        \"        img_tensor \\u003d torch.from_numpy(img).unsqueeze(0).to(device).to(torch.float) # create a tensor from image, push it to device and add batch dimension\\n\",\n        \"        pred_tensor \\u003d model(img_tensor) # feed it through the network\\n\",\n        \"        pred \\u003d pred_tensor[\\\"pred\\\"].argmax(1).item() # get index with maximum class confidence\\n\",\n        \"        label \\u003d np.asscalar(dataset_val[i][\\\"label\\\"]) # get label from batch\\n\",\n        \"        if i % 1000 \\u003d\\u003d 0:\\n\",\n        \"            print(\\\"Prediction: %d \\\\t label: %d\\\" % (pred, label)) # print result\\n\",\n        \"        preds.append(pred)\\n\",\n        \"        labels.append(label)\\n\",\n        \"        \\n\",\n        \"# calculate accuracy\\n\",\n        \"accuracy \\u003d (np.asarray(preds) \\u003d\\u003d np.asarray(labels)).sum() / len(preds)\\n\",\n        \"print(\\\"Accuracy: %.3f\\\" % accuracy)\"\n      ]\n    }\n  ],\n  
\"metadata\": {\n    \"kernelspec\": {\n      \"display_name\": \"Python 3\",\n      \"language\": \"python\",\n      \"name\": \"python3\"\n    },\n    \"language_info\": {\n      \"codemirror_mode\": {\n        \"name\": \"ipython\",\n        \"version\": 3\n      },\n      \"file_extension\": \".py\",\n      \"mimetype\": \"text/x-python\",\n      \"name\": \"python\",\n      \"nbconvert_exporter\": \"python\",\n      \"pygments_lexer\": \"ipython3\",\n      \"version\": \"3.7.3\"\n    }\n  },\n  \"nbformat\": 4,\n  \"nbformat_minor\": 2\n}"
  },
  {
    "path": "notebooks/custom_backend.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"# How To: Integrate your own Computation Backend\\n\",\n    \"\\n\",\n    \"*Author: Justus Schock*\\n\",\n    \"\\n\",\n    \"*Date: 15.05.2019*\\n\",\n    \"\\n\",\n    \"This howto takes you on a trip through the `delira` internals and shows how to add a custom computation backend, using the `torch.jit` (`TorchScript`) backend as an example.\\n\",\n    \"\\n\",\n    \"## Model Definitions\\n\",\n    \"In order to implement a network, we will first have to define the network itself. In `delira` there is a single backend-specific implementation of an abstract network class for each of the backends. These interface classes are all based on the `AbstractNetwork`-class, defining the major API.\\n\",\n    \"\\n\",\n    \"So let's have a look at this class to see what we will have to implement for our own backend.\\n\",\n    \"\\n\",\n    \"Of course we will have to implement an `__init__` defining our class. 
The `__init__` of `AbstractNetwork` (which should be called during the `__init__` of our subclass) accepts a number of kwargs and simply registers them as `init_kwargs`, so there is nothing we have to take care of.\\n\",\n    \"\\n\",\n    \"The next function to inspect is the `__call__` function, which makes the class callable; the docstrings indicate that it should take care of our model's forward-pass.\\n\",\n    \"\\n\",\n    \"After the `__call__` we now have the `closure` function, which defines a single training step (including, but not limited to, forward-pass, calculation of losses and train-metrics, backward-pass and optimization).\\n\",\n    \"\\n\",\n    \"The last method to implement is the `prepare_batch` function, which converts the input to a suitable format and the correct data-type and device.\\n\",\n    \"\\n\",\n    \"### TorchScript Limitations\\n\",\n    \"Since we want to implement an abstract network class for this specific backend, we should have a look at how to generally implement models in this backend.\\n\",\n    \"\\n\",\n    \"According to the [PyTorch docs](https://pytorch.org/docs/stable/jit.html) this works as follows:\\n\",\n    \"\\n\",\n    \"> You can write TorchScript code directly using Python syntax. You do this using the `torch.jit.script` decorator (for functions) or `torch.jit.script_method` decorator (for methods) on subclasses of `ScriptModule`. With this decorator the body of the annotated function is directly translated into TorchScript. 
TorchScript itself is a subset of the Python language, so not all features in Python work, but we provide enough functionality to compute on tensors and do control-dependent operations.\\n\",\n    \"\\n\",\n    \"Since our use-case is to implement the interface class for networks, we want to subclass `torch.jit.ScriptModule`, implement its `forward` and use the `torch.jit.script_method` decorator on it.\\n\",\n    \"\\n\",\n    \"The example given in the very same docs for this case is:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"tensor([0.4997, 0.2955, 0.1588, 0.1873, 0.4753], grad_fn=<MvBackward>)\"\n      ]\n     },\n     \"execution_count\": 1,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"import torch\\n\",\n    \"class MyScriptModule(torch.jit.ScriptModule):\\n\",\n    \"    def __init__(self, N, M):\\n\",\n    \"        super().__init__()\\n\",\n    \"        self.weight = torch.nn.Parameter(torch.rand(N, M))\\n\",\n    \"\\n\",\n    \"    @torch.jit.script_method\\n\",\n    \"    def forward(self, input):\\n\",\n    \"        return self.weight.mv(input)\\n\",\n    \"    \\n\",\n    \"my_script_module = MyScriptModule(5, 3)\\n\",\n    \"input_tensor = torch.rand(3)\\n\",\n    \"my_script_module(input_tensor)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"### Merging TorchScript into our Abstract Class\\n\",\n    \"\\n\",\n    \"This little example gives us a few things we have to do for a successful definition of our base class:\\n\",\n    \"\\n\",\n    \"**1.)** Our class has to subclass both the `AbstractNetwork` and the `torch.jit.ScriptModule` classes.\\n\",\n    \"\\n\",\n    \"**2.)** We need to implement a `forward` method, which takes care of the forward-pass (as its name 
indicates).\\n\",\n    \"\\n\",\n    \"**3.)** We don't have to take care of the backward-pass, thanks to `PyTorch`'s and `TorchScript`'s AutoGrad (a framework for automatic differentiation).\\n\",\n    \"\\n\",\n    \"**4.)** Since `torch.jit.ScriptModule` is callable (seen in the example), it already implements a `__call__` method and we may simply use this one.\\n\",\n    \"\\n\",\n    \"**5.)** The `closure` is completely network-dependent and thus has to remain an abstract method here.\\n\",\n    \"\\n\",\n    \"**6.)** The `prepare_batch` function also depends on the combination of network, inputs and loss functions to use, but we can at least give a prototype of such a function, which handles the devices correctly and converts everything to `float`.\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"### Actual Implementation\\n\",\n    \"\\n\",\n    \"Now let's start with the actual implementation, going through the functions one by one and keeping in mind what we just discovered.\\n\",\n    \"\\n\",\n    \"#### Class Signature and `__init__`-Method\\n\",\n    \"To subclass both base classes, we cannot use the simple `super().__init__` approach, because we have to init both parent classes, so we do \\n\",\n    \"\\n\",\n    \"\\n\",\n    \"```python\\n\",\n    \"\\n\",\n    \"    class AbstractTorchScriptNetwork(AbstractNetwork, torch.jit.ScriptModule):\\n\",\n    \"\\n\",\n    \"        @abc.abstractmethod\\n\",\n    \"        def __init__(self, optimize=True, **kwargs):\\n\",\n    \"            \\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"            Parameters\\n\",\n    \"            ----------\\n\",\n    \"            optimize : bool\\n\",\n    \"                whether to optimize the network graph or not; default: True\\n\",\n    \"            **kwargs :\\n\",\n    \"                additional keyword arguments (passed to :class:`AbstractNetwork`)\\n\",\n    \"            \\\"\\\"\\\"\\n\",\n    \"            torch.jit.ScriptModule.__init__(self, 
optimize=optimize)\\n\",\n    \"            AbstractNetwork.__init__(self, **kwargs)\\n\",\n    \"            \\n\",\n    \"```\\n\",\n    \"instead. This ensures all parent classes to be initialized correctly.\\n\",\n    \"\\n\",\n    \"#### `__call__`-Method\\n\",\n    \"As mentioned above, the `__call__` method is very easy to implement, because we can simply use the implementation of our `TorchScript` base class like this:\\n\",\n    \"\\n\",\n    \"```python\\n\",\n    \"\\n\",\n    \"    def __call__(self, *args, **kwargs):\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        Calls Forward method\\n\",\n    \"\\n\",\n    \"        Parameters\\n\",\n    \"        ----------\\n\",\n    \"        *args :\\n\",\n    \"            positional arguments (passed to `forward`)\\n\",\n    \"        **kwargs :\\n\",\n    \"            keyword arguments (passed to `forward`)\\n\",\n    \"\\n\",\n    \"        Returns\\n\",\n    \"        -------\\n\",\n    \"        Any\\n\",\n    \"            result: module results of arbitrary type and number\\n\",\n    \"\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        return torch.jit.ScriptModule.__call__(self, *args, **kwargs)\\n\",\n    \"        \\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"This also ensures, that we can pass an arbitrary number or positional and keyword arguments of arbitrary types to it (which are all passed to the `forward`-function). 
The advantage over directly calling the `forward` method here, is that the `ScriptModule.__call__` already does the handling of [forward-pre-hooks](https://pytorch.org/docs/stable/nn.html#torch.nn.Module.register_forward_pre_hook), [forward-hooks](https://pytorch.org/docs/stable/nn.html#torch.nn.Module.register_forward_hook) and [backward-hooks](https://pytorch.org/docs/stable/nn.html#torch.nn.Module.register_backward_hook).\\n\",\n    \"\\n\",\n    \"#### `closure`-Method\\n\",\n    \"Since this method is highly model-dependant, we just don't implement it, which forces the user to implement it (since it is marked as an `abstractmethod` in `AbstractExperiment`).\\n\",\n    \"\\n\",\n    \"#### `prepare_batch`-Method\\n\",\n    \"The above mentioned prototype of pushing everything to the correct device and convert it to float looks like this:\\n\",\n    \"\\n\",\n    \"```python\\n\",\n    \"\\n\",\n    \"    @staticmethod\\n\",\n    \"    def prepare_batch(batch: dict, input_device, output_device):\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        Helper Function to prepare Network Inputs and Labels (convert them to\\n\",\n    \"        correct type and shape and push them to correct devices)\\n\",\n    \"\\n\",\n    \"        Parameters\\n\",\n    \"        ----------\\n\",\n    \"        batch : dict\\n\",\n    \"            dictionary containing all the data\\n\",\n    \"        input_device : torch.device\\n\",\n    \"            device for network inputs\\n\",\n    \"        output_device : torch.device\\n\",\n    \"            device for network outputs\\n\",\n    \"\\n\",\n    \"        Returns\\n\",\n    \"        -------\\n\",\n    \"        dict\\n\",\n    \"            dictionary containing data in correct type and shape and on correct\\n\",\n    \"            device\\n\",\n    \"\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        return_dict = {\\\"data\\\": torch.from_numpy(batch.pop(\\\"data\\\")).to(\\n\",\n    \"            
input_device).to(torch.float)}\\n\",\n    \"\\n\",\n    \"        for key, vals in batch.items():\\n\",\n    \"            return_dict[key] = torch.from_numpy(vals).to(output_device).to(\\n\",\n    \"                torch.float)\\n\",\n    \"\\n\",\n    \"        return return_dict\\n\",\n    \"\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"Since we don't want to use any of the model's attributes here (and for conformity with the `AbstractNetwork` class), this method is defined as `staticmethod`, meaning it is class-bound, not instance-bound. The `closure` method has to be a `staticmethod` too.\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"#### `forward`-Method\\n\",\n    \"The only thing left now, is the `forward` method, which is internally called by `ScriptModule.__call__`. The bad news is: We currently can't implement it. Subclassing a `ScriptModule` to overwrite a function decorated with `torch.jit.script_method` is not (yet) supported, but will be soon, once [this PR](https://github.com/pytorch/pytorch/pull/20503) is merged and released.\\n\",\n    \"\\n\",\n    \"For now: you simply have to implement this method in your own network despite the missing of an abstract interface-method.\\n\",\n    \"\\n\",\n    \"#### Putting it all together\\n\",\n    \"If we combine all the function implementations to one class, it looks like this:\\n\",\n    \"\\n\",\n    \"```python\\n\",\n    \"\\n\",\n    \"    class AbstractTorchScriptNetwork(AbstractNetwork, torch.jit.ScriptModule):\\n\",\n    \"\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        Abstract Interface Class for TorchScript Networks. 
For more information\\n\",\n    \"        have a look at https://pytorch.org/docs/stable/jit.html#torchscript\\n\",\n    \"\\n\",\n    \"        Warnings\\n\",\n    \"        --------\\n\",\n    \"        In addition to the here defined API, a forward function must be\\n\",\n    \"        implemented and decorated with ``@torch.jit.script_method``\\n\",\n    \"\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        @abc.abstractmethod\\n\",\n    \"        def __init__(self, optimize=True, **kwargs):\\n\",\n    \"            \\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"            Parameters\\n\",\n    \"            ----------\\n\",\n    \"            optimize : bool\\n\",\n    \"                whether to optimize the network graph or not; default: True\\n\",\n    \"            **kwargs :\\n\",\n    \"                additional keyword arguments (passed to :class:`AbstractNetwork`)\\n\",\n    \"            \\\"\\\"\\\"\\n\",\n    \"            torch.jit.ScriptModule.__init__(self, optimize=optimize)\\n\",\n    \"            AbstractNetwork.__init__(self, **kwargs)\\n\",\n    \"\\n\",\n    \"        def __call__(self, *args, **kwargs):\\n\",\n    \"            \\\"\\\"\\\"\\n\",\n    \"            Calls Forward method\\n\",\n    \"\\n\",\n    \"            Parameters\\n\",\n    \"            ----------\\n\",\n    \"            *args :\\n\",\n    \"                positional arguments (passed to `forward`)\\n\",\n    \"            **kwargs :\\n\",\n    \"                keyword arguments (passed to `forward`)\\n\",\n    \"\\n\",\n    \"            Returns\\n\",\n    \"            -------\\n\",\n    \"            Any\\n\",\n    \"                result: module results of arbitrary type and number\\n\",\n    \"\\n\",\n    \"            \\\"\\\"\\\"\\n\",\n    \"            return torch.jit.ScriptModule.__call__(self, *args, **kwargs)\\n\",\n    \"\\n\",\n    \"        @staticmethod\\n\",\n    \"        def prepare_batch(batch: dict, input_device, output_device):\\n\",\n    
\"            \\\"\\\"\\\"\\n\",\n    \"            Helper Function to prepare Network Inputs and Labels (convert them to\\n\",\n    \"            correct type and shape and push them to correct devices)\\n\",\n    \"\\n\",\n    \"            Parameters\\n\",\n    \"            ----------\\n\",\n    \"            batch : dict\\n\",\n    \"                dictionary containing all the data\\n\",\n    \"            input_device : torch.device\\n\",\n    \"                device for network inputs\\n\",\n    \"            output_device : torch.device\\n\",\n    \"                device for network outputs\\n\",\n    \"\\n\",\n    \"            Returns\\n\",\n    \"            -------\\n\",\n    \"            dict\\n\",\n    \"                dictionary containing data in correct type and shape and on correct\\n\",\n    \"                device\\n\",\n    \"\\n\",\n    \"            \\\"\\\"\\\"\\n\",\n    \"            return_dict = {\\\"data\\\": torch.from_numpy(batch.pop(\\\"data\\\")).to(\\n\",\n    \"                input_device).to(torch.float)}\\n\",\n    \"\\n\",\n    \"            for key, vals in batch.items():\\n\",\n    \"                return_dict[key] = torch.from_numpy(vals).to(output_device).to(\\n\",\n    \"                    torch.float)\\n\",\n    \"\\n\",\n    \"            return return_dict\\n\",\n    \"        \\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"## Saving and loading\\n\",\n    \"Now that we have the ability to implement `delira`-suitable TorchScript models, we want to store them on disk and load them again, so that we don't have to retrain them every time we want to use them. These I/O functions are usually located in `delira.io`. 
\\n\",\n    \"\\n\",\n    \"### Saving\\n\",\n    \"Our saving function utilizes multiple functions: `torch.jit.save` to simply save the model (including it's graph) and the `save_checkpoint_torch` function implemented for the `PyTorch` backend to store the trainer state, since `TorchScript` allows us to use plain `PyTorch` optimizers.\\n\",\n    \"\\n\",\n    \"The implementation of the function looks like this:\\n\",\n    \"\\n\",\n    \"```python\\n\",\n    \"\\n\",\n    \"    def save_checkpoint_torchscript(file: str, model=None, optimizers={},\\n\",\n    \"                                    epoch=None, **kwargs):\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        Save current checkpoint to two different files:\\n\",\n    \"            1.) ``file + \\\"_model.ptj\\\"``: Will include the state of the model\\n\",\n    \"                (including the graph; this is the opposite to\\n\",\n    \"                :func:`save_checkpoint`)\\n\",\n    \"            2.) ``file + \\\"_trainer_state.pt\\\"``: Will include the states of all\\n\",\n    \"                optimizers and the current epoch (if given)\\n\",\n    \"\\n\",\n    \"        Parameters\\n\",\n    \"        ----------\\n\",\n    \"        file : str\\n\",\n    \"            filepath the model should be saved to\\n\",\n    \"        model : AbstractPyTorchJITNetwork or None\\n\",\n    \"            the model which should be saved\\n\",\n    \"            if None: empty dict will be saved as state dict\\n\",\n    \"        optimizers : dict\\n\",\n    \"            dictionary containing all optimizers\\n\",\n    \"        epoch : int\\n\",\n    \"            current epoch (will also be pickled)\\n\",\n    \"\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"        # remove file extension if given\\n\",\n    \"        if any([file.endswith(ext) for ext in [\\\".pth\\\", \\\".pt\\\", \\\".ptj\\\"]]):\\n\",\n    \"            file = file.rsplit(\\\".\\\", 1)[0]\\n\",\n    \"\\n\",\n    \"    
    if isinstance(model, AbstractPyTorchJITNetwork):\\n\",\n    \"            torch.jit.save(model, file + \\\"_model.ptj\\\")\\n\",\n    \"\\n\",\n    \"        if optimizers or epoch is not None:\\n\",\n    \"            save_checkpoint_torch(file + \\\"_trainer_state.pt\\\", None,\\n\",\n    \"                            optimizers=optimizers, epoch=epoch, **kwargs)\\n\",\n    \"            \\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"### Loading\\n\",\n    \"To load a model, which has been saved to disk by this function we have to revert each part of it. We do this by using `torch.jit.load` for the model (and the graph) and `load_checkpoint_torch` by the `PyTorch` backend.\\n\",\n    \"The actual implementation is given here:\\n\",\n    \"\\n\",\n    \"```python\\n\",\n    \"\\n\",\n    \"    def load_checkpoint_torchscript(file: str, **kwargs):\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        Loads a saved checkpoint consisting of 2 files\\n\",\n    \"        (see :func:`save_checkpoint_jit` for details)\\n\",\n    \"\\n\",\n    \"        Parameters\\n\",\n    \"        ----------\\n\",\n    \"        file : str\\n\",\n    \"            filepath to a file containing a saved model\\n\",\n    \"        **kwargs:\\n\",\n    \"            Additional keyword arguments (passed to torch.load)\\n\",\n    \"            Especially \\\"map_location\\\" is important to change the device the\\n\",\n    \"            state_dict should be loaded to\\n\",\n    \"\\n\",\n    \"        Returns\\n\",\n    \"        -------\\n\",\n    \"        OrderedDict\\n\",\n    \"            checkpoint state_dict\\n\",\n    \"\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        # remove file extensions\\n\",\n    \"        if any([file.endswith(ext) for ext in [\\\".pth\\\", \\\".pt\\\", \\\".ptj\\\"]]):\\n\",\n    \"            file = file.rsplit(\\\".\\\", 1)[0]\\n\",\n    \"\\n\",\n    \"        # load model\\n\",\n    \"        if os.path.isfile(file + \\\".ptj\\\"):\\n\",\n    
\"            model_file = file\\n\",\n    \"        elif os.path.isfile(file + \\\"_model.ptj\\\"):\\n\",\n    \"            model_file = file + \\\"_model.ptj\\\"\\n\",\n    \"        else:\\n\",\n    \"            raise ValueError(\\\"No Model File found for %s\\\" % file)\\n\",\n    \"\\n\",\n    \"        # load trainer state (if possible)\\n\",\n    \"        trainer_file = model_file.replace(\\\"_model.ptj\\\", \\\"_trainer_state.pt\\\")\\n\",\n    \"        if os.path.isfile(trainer_file):\\n\",\n    \"            trainer_state = load_checkpoint_torch(trainer_file, **kwargs)\\n\",\n    \"\\n\",\n    \"        else:\\n\",\n    \"            trainer_state = {\\\"optimizer\\\": {},\\n\",\n    \"                             \\\"epoch\\\": None}\\n\",\n    \"\\n\",\n    \"        trainer_state.update({\\\"model\\\": torch.jit.load(model_file)})\\n\",\n    \"\\n\",\n    \"        return trainer_state\\n\",\n    \"    \\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"## A Trainer to train\\n\",\n    \"Now, that we can define and save/load our models, we want to train them. Luckily `delira` has already implemented a very modular backend-agnostic trainer (the `BaseNetworkTrainer`) and build upon this a `PyTorchNetworkTrainer`. Since the training process in PyTorch and TorchScript is nearly the same, we can just extend the `PyTorchNetworkTrainer`. Usually one would have to extend the `BaseNetworkTrainer` to provide some backend specific functions (like necessary initializations, optimizer setup, seeding etc.). To see how this is done, you could either have a look at the `PyTorchNetworkTrainer` or the `TfNetworkTrainer` for tensorflow, which are both following this principle. Usually the only stuff to completely change is the loading/saving behavior and the `_setup` function, which defines the backend-specific initialization. 
Some other functions may have to be extended (by implementing the extension and calling the parent-classes function).\\n\",\n    \"\\n\",\n    \"### Things to change:\\n\",\n    \"\\n\",\n    \"By Subclassing the `PyTorchNetworkTrainer` we have to change the following things:\\n\",\n    \"\\n\",\n    \"* The trainer's default arguments\\n\",\n    \"\\n\",\n    \"* The behavior for trying to resume a previous training\\n\",\n    \"\\n\",\n    \"* The saving, loading and updating behavior\\n\",\n    \"\\n\",\n    \"We will access this one by one:\\n\",\n    \"\\n\",\n    \"#### The Default Arguments\\n\",\n    \"\\n\",\n    \"We want to use `AbstractTorchScriptNetwork`s instead of `AbstractPyTorchNetwork`s here and we have to change the behavior if passing multiple GPUs, because currently Multi-GPU training is not supported by `TorchScript`.\\n\",\n    \"\\n\",\n    \"To do this: we implement the functions `__init__`, apply our changes and forward these changes to the call of the base-classes `__init__` like this (omitted docstrings for the sake of shortness):\\n\",\n    \"\\n\",\n    \"```python\\n\",\n    \"\\n\",\n    \"class TorchScriptNetworkTrainer(PyTorchNetworkTrainer):\\n\",\n    \"        def __init__(self,\\n\",\n    \"                     network: AbstractTorchScriptNetwork,\\n\",\n    \"                     save_path: str,\\n\",\n    \"                     key_mapping,\\n\",\n    \"                     losses=None,\\n\",\n    \"                     optimizer_cls=None,\\n\",\n    \"                     optimizer_params={},\\n\",\n    \"                     train_metrics={},\\n\",\n    \"                     val_metrics={},\\n\",\n    \"                     lr_scheduler_cls=None,\\n\",\n    \"                     lr_scheduler_params={},\\n\",\n    \"                     gpu_ids=[],\\n\",\n    \"                     save_freq=1,\\n\",\n    \"                     optim_fn=create_optims_default,\\n\",\n    \"                     
logging_type=\\\"tensorboardx\\\",\\n\",\n    \"                     logging_kwargs={},\\n\",\n    \"                     fold=0,\\n\",\n    \"                     callbacks=[],\\n\",\n    \"                     start_epoch=1,\\n\",\n    \"                     metric_keys=None,\\n\",\n    \"                     convert_batch_to_npy_fn=convert_torch_tensor_to_npy,\\n\",\n    \"                     criterions=None,\\n\",\n    \"                     val_freq=1,\\n\",\n    \"                     **kwargs):\\n\",\n    \"            \\n\",\n    \"            if len(gpu_ids) > 1:\\n\",\n    \"                # only use first GPU due to\\n\",\n    \"                # https://github.com/pytorch/pytorch/issues/15421\\n\",\n    \"                gpu_ids = [gpu_ids[0]]\\n\",\n    \"                logging.warning(\\\"Multiple GPUs specified. Torch JIT currently \\\"\\n\",\n    \"                                \\\"supports only single-GPU training. \\\"\\n\",\n    \"                                \\\"Switching to use only the first GPU for now...\\\")\\n\",\n    \"\\n\",\n    \"            super().__init__(network=network, save_path=save_path,\\n\",\n    \"                             key_mapping=key_mapping, losses=losses,\\n\",\n    \"                             optimizer_cls=optimizer_cls,\\n\",\n    \"                             optimizer_params=optimizer_params,\\n\",\n    \"                             train_metrics=train_metrics,\\n\",\n    \"                             val_metrics=val_metrics,\\n\",\n    \"                             lr_scheduler_cls=lr_scheduler_cls,\\n\",\n    \"                             lr_scheduler_params=lr_scheduler_params,\\n\",\n    \"                             gpu_ids=gpu_ids, save_freq=save_freq,\\n\",\n    \"                             optim_fn=optim_fn, logging_type=logging_type,\\n\",\n    \"                             logging_kwargs=logging_kwargs, fold=fold,\\n\",\n    \"                             
callbacks=callbacks,\\n\",\n    \"                             start_epoch=start_epoch, metric_keys=metric_keys,\\n\",\n    \"                             convert_batch_to_npy_fn=convert_batch_to_npy_fn,\\n\",\n    \"                             mixed_precision=False, mixed_precision_kwargs={},\\n\",\n    \"                             criterions=criterions, val_freq=val_freq, **kwargs\\n\",\n    \"                             )\\n\",\n    \"            \\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"#### Resuming Training\\n\",\n    \"\\n\",\n    \"For resuming the training, we have to completely change the `try_resume_training` function and cannot reuse the parent's implementation of it. Thus, we don't call `super().try_resume_training` here, but completely reimplement it from scratch:\\n\",\n    \"\\n\",\n    \"```python\\n\",\n    \"\\n\",\n    \"    def try_resume_training(self):\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        Load the latest state of a previous training if possible\\n\",\n    \"\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        # Load latest epoch file if available\\n\",\n    \"        if os.path.isdir(self.save_path):\\n\",\n    \"            # check all files in directory starting with \\\"checkpoint\\\" and\\n\",\n    \"            # not ending with \\\"_best.pth\\\"\\n\",\n    \"            files = [x for x in os.listdir(self.save_path)\\n\",\n    \"                     if os.path.isfile(os.path.join(self.save_path, x))\\n\",\n    \"                     and x.startswith(\\\"checkpoint\\\")\\n\",\n    \"                     and not x.endswith(\\\"_best.ptj\\\")\\n\",\n    \"                     ]\\n\",\n    \"\\n\",\n    \"            # if list is not empty: load previous state\\n\",\n    \"            if files:\\n\",\n    \"\\n\",\n    \"                latest_epoch = max([\\n\",\n    \"                    int(x.rsplit(\\\"_\\\", 1)[-1].rsplit(\\\".\\\", 1)[0])\\n\",\n    \"                    for x in files])\\n\",\n    
\"\\n\",\n    \"                latest_state_path = os.path.join(self.save_path,\\n\",\n    \"                                                 \\\"checkpoint_epoch_%d.ptj\\\"\\n\",\n    \"                                                 % latest_epoch)\\n\",\n    \"\\n\",\n    \"                # if pth file does not exist, load pt file instead\\n\",\n    \"                if not os.path.isfile(latest_state_path):\\n\",\n    \"                    latest_state_path = latest_state_path[:-1]\\n\",\n    \"\\n\",\n    \"                logger.info(\\\"Attempting to load state from previous \\\\\\n\",\n    \"                            training from %s\\\" % latest_state_path)\\n\",\n    \"                try:\\n\",\n    \"                    self.update_state(latest_state_path)\\n\",\n    \"                except KeyError:\\n\",\n    \"                    logger.warning(\\\"Previous State could not be loaded, \\\\\\n\",\n    \"                                    although it exists.Training will be \\\\\\n\",\n    \"                                    restarted\\\")\\n\",\n    \"\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"#### Saving and Loading\\n\",\n    \"Now we need to change the saving and loading behavior. 
As always we try to reuse as much code as possible to avoid code duplication.\\n\",\n    \"\\n\",\n    \"##### Saving\\n\",\n    \"To save the current training state, we simply call the `save_checkpoint_torchscript` function:\\n\",\n    \"\\n\",\n    \"```python\\n\",\n    \"\\n\",\n    \"    def save_state(self, file_name, epoch, **kwargs):\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        saves the current state via\\n\",\n    \"        :func:`delira.io.torch.save_checkpoint_jit`\\n\",\n    \"\\n\",\n    \"        Parameters\\n\",\n    \"        ----------\\n\",\n    \"        file_name : str\\n\",\n    \"            filename to save the state to\\n\",\n    \"        epoch : int\\n\",\n    \"            current epoch (will be saved for mapping back)\\n\",\n    \"        **kwargs :\\n\",\n    \"            keyword arguments\\n\",\n    \"\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        if file_name.endswith(\\\".pt\\\") or file_name.endswith(\\\".pth\\\"):\\n\",\n    \"            file_name = file_name.rsplit(\\\".\\\", 1)[0]\\n\",\n    \"\\n\",\n    \"        save_checkpoint_torchscript(file_name, self.module, self.optimizers,\\n\",\n    \"                                    **kwargs)\\n\",\n    \"        \\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"##### Loading\\n\",\n    \"\\n\",\n    \"To load the training state, we simply return the state loaded by `load_checkpoint_torchscript`.\\n\",\n    \"Since we don't use any arguments of the trainer itself here, the function is a `staticmethod`:\\n\",\n    \"\\n\",\n    \"```python\\n\",\n    \"\\n\",\n    \"    @staticmethod\\n\",\n    \"    def load_state(file_name, **kwargs):\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        Loads the new state from file via\\n\",\n    \"        :func:`delira.io.torch.load_checkpoint:jit`\\n\",\n    \"\\n\",\n    \"        Parameters\\n\",\n    \"        ----------\\n\",\n    \"        file_name : str\\n\",\n    \"            the file to load the state from\\n\",\n   
 \"        **kwargs : keyword arguments\\n\",\n    \"\\n\",\n    \"        Returns\\n\",\n    \"        -------\\n\",\n    \"        dict\\n\",\n    \"            new state\\n\",\n    \"\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        return load_checkpoint_torchscript(file_name, **kwargs)\\n\",\n    \"    \\n\",\n    \"```\\n\",\n    \"    \\n\",\n    \"##### Updating\\n\",\n    \"\\n\",\n    \"After we loaded the new state, we need to update the trainer's internal state by this new state.\\n\",\n    \"\\n\",\n    \"We do this by directly assigning the model here (since the graph was stored/loaded too) instead of only updating the state_dict and calling the parent-classes method afterwards:\\n\",\n    \"    \\n\",\n    \"```python\\n\",\n    \"\\n\",\n    \"    def _update_state(self, new_state):\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        Update the state from a given new state\\n\",\n    \"\\n\",\n    \"        Parameters\\n\",\n    \"        ----------\\n\",\n    \"        new_state : dict\\n\",\n    \"            new state to update internal state from\\n\",\n    \"\\n\",\n    \"        Returns\\n\",\n    \"        -------\\n\",\n    \"        :class:`PyTorchNetworkJITTrainer`\\n\",\n    \"            the trainer with a modified state\\n\",\n    \"\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        if \\\"model\\\" in new_state:\\n\",\n    \"            self.module = new_state.pop(\\\"model\\\").to(self.input_device)\\n\",\n    \"\\n\",\n    \"        return super()._update_state(new_state)\\n\",\n    \"\\n\",\n    \"```\\n\",\n    \" \\n\",\n    \"### A Whole Trainer\\n\",\n    \" \\n\",\n    \"After combining all the changes above, we finally get our new trainer as:\\n\",\n    \" \\n\",\n    \"```python\\n\",\n    \"\\n\",\n    \"    class TorchScriptNetworkTrainer(PyTorchNetworkTrainer):\\n\",\n    \"        def __init__(self,\\n\",\n    \"                     network: AbstractTorchScriptNetwork,\\n\",\n    \"                     
save_path: str,\\n\",\n    \"                     key_mapping,\\n\",\n    \"                     losses=None,\\n\",\n    \"                     optimizer_cls=None,\\n\",\n    \"                     optimizer_params={},\\n\",\n    \"                     train_metrics={},\\n\",\n    \"                     val_metrics={},\\n\",\n    \"                     lr_scheduler_cls=None,\\n\",\n    \"                     lr_scheduler_params={},\\n\",\n    \"                     gpu_ids=[],\\n\",\n    \"                     save_freq=1,\\n\",\n    \"                     optim_fn=create_optims_default,\\n\",\n    \"                     logging_type=\\\"tensorboardx\\\",\\n\",\n    \"                     logging_kwargs={},\\n\",\n    \"                     fold=0,\\n\",\n    \"                     callbacks=[],\\n\",\n    \"                     start_epoch=1,\\n\",\n    \"                     metric_keys=None,\\n\",\n    \"                     convert_batch_to_npy_fn=convert_torch_tensor_to_npy,\\n\",\n    \"                     criterions=None,\\n\",\n    \"                     val_freq=1,\\n\",\n    \"                     **kwargs):\\n\",\n    \"            \\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"            Parameters\\n\",\n    \"            ----------\\n\",\n    \"            network : :class:`AbstractPyTorchJITNetwork`\\n\",\n    \"                the network to train\\n\",\n    \"            save_path : str\\n\",\n    \"                path to save networks to\\n\",\n    \"            key_mapping : dict\\n\",\n    \"                a dictionary containing the mapping from the ``data_dict`` to\\n\",\n    \"                the actual model's inputs.\\n\",\n    \"                E.g. 
if a model accepts one input named 'x' and the data_dict\\n\",\n    \"                contains one entry named 'data' this argument would have to\\n\",\n    \"                be ``{'x': 'data'}``\\n\",
\n    \"            losses : dict\\n\",\n    \"                dictionary containing the training losses\\n\",
\n    \"            optimizer_cls : subclass of torch.optim.Optimizer\\n\",\n    \"                optimizer class implementing the optimization algorithm of\\n\",\n    \"                choice\\n\",
\n    \"            optimizer_params : dict\\n\",\n    \"                keyword arguments passed to optimizer during construction\\n\",
\n    \"            train_metrics : dict, optional\\n\",\n    \"                metrics, which will be evaluated during train phase\\n\",\n    \"                (should work on framework's tensor types)\\n\",
\n    \"            val_metrics : dict, optional\\n\",\n    \"                metrics, which will be evaluated during test phase\\n\",\n    \"                (should work on numpy arrays)\\n\",
\n    \"            lr_scheduler_cls : Any\\n\",\n    \"                learning rate schedule class: must implement step() method\\n\",
\n    \"            lr_scheduler_params : dict\\n\",\n    \"                keyword arguments passed to lr scheduler during construction\\n\",
\n    \"            gpu_ids : list\\n\",\n    \"                list containing ids of GPUs to use; if empty: use cpu instead\\n\",\n    \"                Currently ``torch.jit`` only supports single GPU-Training,\\n\",\n    \"                thus only the first GPU will be used if multiple GPUs are passed\\n\",
\n    \"            save_freq : int\\n\",\n    \"                integer specifying how often to save the current model's state.\\n\",\n    \"                State is saved every state_freq epochs\\n\",
\n    \"            optim_fn : function\\n\",\n    \"                creates a dictionary containing all necessary optimizers\\n\",\n    \"            
logging_type : str or callable\\n\",\n    \"                the type of logging. If string: it must be one of\\n\",\n    \"                [\\\"visdom\\\", \\\"tensorboardx\\\"]\\n\",\n    \"                If callable: it must be a logging handler class\\n\",
\n    \"            logging_kwargs : dict\\n\",\n    \"                dictionary containing all logging keyword arguments\\n\",
\n    \"            fold : int\\n\",\n    \"                current cross validation fold (0 per default)\\n\",
\n    \"            callbacks : list\\n\",\n    \"                initial callbacks to register\\n\",
\n    \"            start_epoch : int\\n\",\n    \"                epoch to start training at\\n\",
\n    \"            metric_keys : dict\\n\",\n    \"                dict specifying which batch_dict entry to use for which metric as\\n\",\n    \"                target; default: None, which will result in key \\\"label\\\" for all\\n\",\n    \"                metrics\\n\",
\n    \"            convert_batch_to_npy_fn : type, optional\\n\",\n    \"                function converting a batch-tensor to numpy, per default this is\\n\",\n    \"                a function, which detaches the tensor, moves it to cpu and then\\n\",\n    \"                calls ``.numpy()`` on it\\n\",
\n    \"            mixed_precision : bool\\n\",\n    \"                whether to use mixed precision or not (False per default)\\n\",
\n    \"            mixed_precision_kwargs : dict\\n\",\n    \"                additional keyword arguments for mixed precision\\n\",
\n    \"            val_freq : int\\n\",\n    \"                validation frequency specifying how often to validate the trained\\n\",\n    \"                model (a value of 1 denotes validating every epoch,\\n\",\n    \"                a value of 2 denotes validating every second epoch etc.);\\n\",\n    \"                defaults to 1\\n\",
\n    \"            **kwargs :\\n\",\n    \"                additional keyword arguments\\n\",\n    
\"\\n\",\n    \"            \\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"            if len(gpu_ids) > 1:\\n\",\n    \"                # only use first GPU due to\\n\",\n    \"                # https://github.com/pytorch/pytorch/issues/15421\\n\",\n    \"                gpu_ids = [gpu_ids[0]]\\n\",\n    \"                logging.warning(\\\"Multiple GPUs specified. Torch JIT currently \\\"\\n\",\n    \"                                \\\"supports only single-GPU training. \\\"\\n\",\n    \"                                \\\"Switching to use only the first GPU for now...\\\")\\n\",\n    \"\\n\",\n    \"            super().__init__(network=network, save_path=save_path,\\n\",\n    \"                             key_mapping=key_mapping, losses=losses,\\n\",\n    \"                             optimizer_cls=optimizer_cls,\\n\",\n    \"                             optimizer_params=optimizer_params,\\n\",\n    \"                             train_metrics=train_metrics,\\n\",\n    \"                             val_metrics=val_metrics,\\n\",\n    \"                             lr_scheduler_cls=lr_scheduler_cls,\\n\",\n    \"                             lr_scheduler_params=lr_scheduler_params,\\n\",\n    \"                             gpu_ids=gpu_ids, save_freq=save_freq,\\n\",\n    \"                             optim_fn=optim_fn, logging_type=logging_type,\\n\",\n    \"                             logging_kwargs=logging_kwargs, fold=fold,\\n\",\n    \"                             callbacks=callbacks,\\n\",\n    \"                             start_epoch=start_epoch, metric_keys=metric_keys,\\n\",\n    \"                             convert_batch_to_npy_fn=convert_batch_to_npy_fn,\\n\",\n    \"                             mixed_precision=False, mixed_precision_kwargs={},\\n\",\n    \"                             criterions=criterions, val_freq=val_freq, **kwargs\\n\",\n    \"                             )\\n\",\n    \"\\n\",\n    \"        def 
try_resume_training(self):\\n\",\n    \"            \\\"\\\"\\\"\\n\",\n    \"            Load the latest state of a previous training if possible\\n\",\n    \"\\n\",\n    \"            \\\"\\\"\\\"\\n\",\n    \"            # Load latest epoch file if available\\n\",\n    \"            if os.path.isdir(self.save_path):\\n\",\n    \"                # check all files in directory starting with \\\"checkpoint\\\" and\\n\",\n    \"                # not ending with \\\"_best.ptj\\\"\\n\",\n    \"                files = [x for x in os.listdir(self.save_path)\\n\",\n    \"                         if os.path.isfile(os.path.join(self.save_path, x))\\n\",\n    \"                         and x.startswith(\\\"checkpoint\\\")\\n\",\n    \"                         and not x.endswith(\\\"_best.ptj\\\")\\n\",\n    \"                         ]\\n\",\n    \"\\n\",\n    \"                # if list is not empty: load previous state\\n\",\n    \"                if files:\\n\",\n    \"\\n\",\n    \"                    latest_epoch = max([\\n\",\n    \"                        int(x.rsplit(\\\"_\\\", 1)[-1].rsplit(\\\".\\\", 1)[0])\\n\",\n    \"                        for x in files])\\n\",\n    \"\\n\",\n    \"                    latest_state_path = os.path.join(self.save_path,\\n\",\n    \"                                                     \\\"checkpoint_epoch_%d.ptj\\\"\\n\",\n    \"                                                     % latest_epoch)\\n\",\n    \"\\n\",\n    \"                    # if ptj file does not exist, load pt file instead\\n\",\n    \"                    if not os.path.isfile(latest_state_path):\\n\",\n    \"                        latest_state_path = latest_state_path[:-1]\\n\",\n    \"\\n\",\n    \"                    logger.info(\\\"Attempting to load state from previous \\\\\\n\",\n    \"                                training from %s\\\" % latest_state_path)\\n\",\n    \"                    try:\\n\",\n    \"                        
self.update_state(latest_state_path)\\n\",\n    \"                    except KeyError:\\n\",\n    \"                        logger.warning(\\\"Previous State could not be loaded, \\\\\\n\",\n    \"                                        although it exists. Training will be \\\\\\n\",\n    \"                                        restarted\\\")\\n\",\n    \"\\n\",\n    \"        def save_state(self, file_name, epoch, **kwargs):\\n\",\n    \"            \\\"\\\"\\\"\\n\",\n    \"            saves the current state via\\n\",\n    \"            :func:`delira.io.torch.save_checkpoint_torchscript`\\n\",\n    \"\\n\",\n    \"            Parameters\\n\",\n    \"            ----------\\n\",\n    \"            file_name : str\\n\",\n    \"                filename to save the state to\\n\",\n    \"            epoch : int\\n\",\n    \"                current epoch (will be saved for mapping back)\\n\",\n    \"            **kwargs :\\n\",\n    \"                keyword arguments\\n\",\n    \"\\n\",\n    \"            \\\"\\\"\\\"\\n\",\n    \"            if file_name.endswith(\\\".pt\\\") or file_name.endswith(\\\".pth\\\"):\\n\",\n    \"                file_name = file_name.rsplit(\\\".\\\", 1)[0]\\n\",\n    \"\\n\",\n    \"            save_checkpoint_torchscript(file_name, self.module, self.optimizers,\\n\",\n    \"                                        **kwargs)\\n\",\n    \"\\n\",\n    \"        @staticmethod\\n\",\n    \"        def load_state(file_name, **kwargs):\\n\",\n    \"            \\\"\\\"\\\"\\n\",\n    \"            Loads the new state from file via\\n\",\n    \"            :func:`delira.io.torch.load_checkpoint_torchscript`\\n\",\n    \"\\n\",\n    \"            Parameters\\n\",\n    \"            ----------\\n\",\n    \"            file_name : str\\n\",\n    \"                the file to load the state from\\n\",\n    \"            **kwargs : keyword arguments\\n\",\n    \"\\n\",\n    \"            Returns\\n\",\n    \"            -------\\n\",\n    \"            
dict\\n\",\n    \"                new state\\n\",\n    \"\\n\",\n    \"            \\\"\\\"\\\"\\n\",\n    \"            return load_checkpoint_torchscript(file_name, **kwargs)\\n\",\n    \"\\n\",\n    \"        def _update_state(self, new_state):\\n\",\n    \"            \\\"\\\"\\\"\\n\",\n    \"            Update the state from a given new state\\n\",\n    \"\\n\",\n    \"            Parameters\\n\",\n    \"            ----------\\n\",\n    \"            new_state : dict\\n\",\n    \"                new state to update internal state from\\n\",\n    \"\\n\",\n    \"            Returns\\n\",\n    \"            -------\\n\",\n    \"            :class:`PyTorchNetworkJITTrainer`\\n\",\n    \"                the trainer with a modified state\\n\",\n    \"\\n\",\n    \"            \\\"\\\"\\\"\\n\",\n    \"            if \\\"model\\\" in new_state:\\n\",\n    \"                self.module = new_state.pop(\\\"model\\\").to(self.input_device)\\n\",\n    \"\\n\",\n    \"            return super()._update_state(new_state)\\n\",\n    \"        \\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"## Wrapping it all in an Experiment\\n\",\n    \"To have access to methods like a K-Fold (and the not yet finished) hyperparameter tuning, we need to wrap the trainer in an Experiment. 
We will use the same approach as we did for implementing the trainer: extending an already provided class.\\n\",\n    \"\\n\",\n    \"This time we extend the `PyTorchExperiment` which itself extends the `BaseExperiment` by some backend-specific defaults, types and seeds.\\n\",\n    \"\\n\",\n    \"Our whole class definition just changes the default arguments of the `PyTorchExperiment` and thus we only have to implement its `__init__`:\\n\",\n    \"\\n\",\n    \"```python\\n\",\n    \"\\n\",\n    \"class TorchScriptExperiment(PyTorchExperiment):\\n\",\n    \"    def __init__(self,\\n\",\n    \"                 params: typing.Union[str, Parameters],\\n\",\n    \"                 model_cls: AbstractTorchScriptNetwork, # not AbstractPyTorchNetwork anymore\\n\",\n    \"                 n_epochs=None,\\n\",\n    \"                 name=None,\\n\",\n    \"                 save_path=None,\\n\",\n    \"                 key_mapping=None,\\n\",\n    \"                 val_score_key=None,\\n\",\n    \"                 optim_builder=create_optims_default_pytorch,\\n\",\n    \"                 checkpoint_freq=1,\\n\",\n    \"                 trainer_cls=TorchScriptNetworkTrainer, # not PyTorchNetworkTrainer anymore\\n\",\n    \"                 **kwargs):\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"        Parameters\\n\",\n    \"        ----------\\n\",\n    \"        params : :class:`Parameters` or str\\n\",\n    \"            the training parameters, if string is passed,\\n\",\n    \"            it is treated as a path to a pickle file, where the\\n\",\n    \"            parameters are loaded from\\n\",\n    \"        model_cls : Subclass of :class:`AbstractTorchScriptNetwork`\\n\",\n    \"            the class implementing the model to train\\n\",\n    \"        n_epochs : int or None\\n\",\n    \"            the number of epochs to train, if None: can be specified later\\n\",\n    \"            during actual training\\n\",\n    \"        name : str or 
None\\n\",\n    \"            the Experiment's name\\n\",\n    \"        save_path : str or None\\n\",\n    \"            the path to save the results and checkpoints to.\\n\",\n    \"            if None: Current working directory will be used\\n\",\n    \"        key_mapping : dict\\n\",\n    \"            mapping between data_dict and model inputs (necessary for\\n\",\n    \"            prediction with :class:`Predictor`-API), if no keymapping is\\n\",\n    \"            given, a default key_mapping of {\\\"x\\\": \\\"data\\\"} will be used here\\n\",\n    \"        val_score_key : str or None\\n\",\n    \"            key defining which metric to use for validation (determining\\n\",\n    \"            best model and scheduling lr); if None: No validation-based\\n\",\n    \"            operations will be done (model might still get validated,\\n\",\n    \"            but validation metrics can only be logged and not used further)\\n\",\n    \"        optim_builder : function\\n\",\n    \"            Function returning a dict of backend-specific optimizers.\\n\",\n    \"            defaults to :func:`create_optims_default_pytorch`\\n\",\n    \"        checkpoint_freq : int\\n\",\n    \"            frequency of saving checkpoints (1 denotes saving every epoch,\\n\",\n    \"            2 denotes saving every second epoch etc.); default: 1\\n\",\n    \"        trainer_cls : subclass of :class:`TorchScriptNetworkTrainer`\\n\",\n    \"            the trainer class to use for training the model, defaults to\\n\",\n    \"            :class:`TorchScriptNetworkTrainer`\\n\",\n    \"        **kwargs :\\n\",\n    \"            additional keyword arguments\\n\",\n    \"\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        super().__init__(params=params, model_cls=model_cls,\\n\",\n    \"                         n_epochs=n_epochs, name=name, save_path=save_path,\\n\",\n    \"                         key_mapping=key_mapping,\\n\",\n    \"                         
val_score_key=val_score_key,\\n\",\n    \"                         optim_builder=optim_builder,\\n\",\n    \"                         checkpoint_freq=checkpoint_freq,\\n\",\n    \"                         trainer_cls=trainer_cls,\\n\",\n    \"                         **kwargs)\\n\",\n    \"        \\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"## Testing it\\n\",\n    \"Now that we have finished implementing the backend (which is the outermost wrapper; congratulations!), we can test it. We'll use a very simple network and test it with dummy data. We also only test the `run` and `test` functionality of our experiment, since everything else either just sets up the internal state or is a composition of these two methods and is already tested.\\n\",\n    \"Now let's define our dataset, instantiate it three times (for training, validation and testing) and wrap each instance into a `DataManager`:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import numpy as np\\n\",\n    \"\\n\",\n    \"from delira.data_loading import AbstractDataset\\n\",\n    \"from delira.data_loading import DataManager\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"class DummyDataset(AbstractDataset):\\n\",\n    \"    def __init__(self, length):\\n\",\n    \"        super().__init__(None, None)\\n\",\n    \"        self.length = length\\n\",\n    \"\\n\",\n    \"    def __getitem__(self, index):\\n\",\n    \"        return {\\\"data\\\": np.random.rand(32),\\n\",\n    \"                \\\"label\\\": np.random.randint(0, 1, 1)}\\n\",\n    \"\\n\",\n    \"    def __len__(self):\\n\",\n    \"        return self.length\\n\",\n    \"\\n\",\n    \"    def get_sample_from_index(self, index):\\n\",\n    \"        return self.__getitem__(index)\\n\",\n    \"    \\n\",\n    \"dset_train = DummyDataset(500)\\n\",\n    \"dset_val = DummyDataset(50)\\n\",\n    \"dset_test = DummyDataset(10)\\n\",\n    \"\\n\",\n    \"# training, validation and testing 
with \\n\",\n    \"#a batchsize of 16, 1 loading thread and no transformations.\\n\",\n    \"dmgr_train = DataManager(dset_train, 16, 1, None)\\n\",\n    \"dmgr_val = DataManager(dset_val, 16, 1, None)\\n\",\n    \"dmgr_test = DataManager(dset_test, 16, 1, None)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Now, that we have created three datasets, we need to define our small dummy network. We do this by subclassing `delira.models.AbstractTorchScriptNetwork` (which is the exactly implementation given above, be we need to use the internal one, because there are some typechecks against this one).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from delira.models import AbstractTorchScriptNetwork\\n\",\n    \"import torch\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"class DummyNetworkTorchScript(AbstractTorchScriptNetwork):\\n\",\n    \"    __constants__ = [\\\"module\\\"]\\n\",\n    \"\\n\",\n    \"    def __init__(self):\\n\",\n    \"        super().__init__()\\n\",\n    \"        self.module = self._build_model(32, 1)\\n\",\n    \"\\n\",\n    \"    @torch.jit.script_method\\n\",\n    \"    def forward(self, x):\\n\",\n    \"        return {\\\"pred\\\": self.module(x)}\\n\",\n    \"\\n\",\n    \"    @staticmethod\\n\",\n    \"    def prepare_batch(batch_dict, input_device, output_device):\\n\",\n    \"        return {\\\"data\\\": torch.from_numpy(batch_dict[\\\"data\\\"]\\n\",\n    \"                                         ).to(input_device,\\n\",\n    \"                                              torch.float),\\n\",\n    \"                \\\"label\\\": torch.from_numpy(batch_dict[\\\"label\\\"]\\n\",\n    \"                                          ).to(output_device,\\n\",\n    \"                                               torch.float)}\\n\",\n    \"\\n\",\n    \"    @staticmethod\\n\",\n    \"    def 
closure(model: AbstractTorchScriptNetwork, data_dict: dict,\\n\",\n    \"                optimizers: dict, losses={}, metrics={},\\n\",\n    \"                fold=0, **kwargs):\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        closure method to do a single backpropagation step\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"        Parameters\\n\",\n    \"        ----------\\n\",\n    \"        model : :class:`AbstractTorchScriptNetwork`\\n\",\n    \"            trainable model\\n\",\n    \"        data_dict : dict\\n\",\n    \"            dictionary containing the data\\n\",\n    \"        optimizers : dict\\n\",\n    \"            dictionary of optimizers to optimize model's parameters\\n\",\n    \"        losses : dict\\n\",\n    \"            dict holding the losses to calculate errors\\n\",\n    \"            (gradients from different losses will be accumulated)\\n\",\n    \"        metrics : dict\\n\",\n    \"            dict holding the metrics to calculate\\n\",\n    \"        fold : int\\n\",\n    \"            Current Fold in Crossvalidation (default: 0)\\n\",\n    \"        **kwargs:\\n\",\n    \"            additional keyword arguments\\n\",\n    \"\\n\",\n    \"        Returns\\n\",\n    \"        -------\\n\",\n    \"        dict\\n\",\n    \"            Metric values (with same keys as input dict metrics)\\n\",\n    \"        dict\\n\",\n    \"            Loss values (with same keys as input dict losses)\\n\",\n    \"        list\\n\",\n    \"            Arbitrary number of predictions as torch.Tensor\\n\",\n    \"\\n\",\n    \"        Raises\\n\",\n    \"        ------\\n\",\n    \"        AssertionError\\n\",\n    \"            if optimizers or losses are empty or the optimizers are not\\n\",\n    \"            specified\\n\",\n    \"\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"        assert (optimizers and losses) or not optimizers, \\\\\\n\",\n    \"            \\\"Criterion dict cannot be empty if optimizers are passed\\\"\\n\",\n    \"\\n\",\n    \"        
loss_vals = {}\\n\",\n    \"        metric_vals = {}\\n\",\n    \"        total_loss = 0\\n\",\n    \"\\n\",\n    \"        # choose suitable context manager:\\n\",\n    \"        if optimizers:\\n\",\n    \"            context_man = torch.enable_grad\\n\",\n    \"\\n\",\n    \"        else:\\n\",\n    \"            context_man = torch.no_grad\\n\",\n    \"\\n\",\n    \"        with context_man():\\n\",\n    \"\\n\",\n    \"            inputs = data_dict.pop(\\\"data\\\")\\n\",\n    \"            preds = model(inputs)\\n\",\n    \"\\n\",\n    \"            if data_dict:\\n\",\n    \"\\n\",\n    \"                for key, crit_fn in losses.items():\\n\",\n    \"                    _loss_val = crit_fn(preds[\\\"pred\\\"], *data_dict.values())\\n\",\n    \"                    loss_vals[key] = _loss_val.item()\\n\",\n    \"                    total_loss += _loss_val\\n\",\n    \"\\n\",\n    \"                with torch.no_grad():\\n\",\n    \"                    for key, metric_fn in metrics.items():\\n\",\n    \"                        metric_vals[key] = metric_fn(\\n\",\n    \"                            preds[\\\"pred\\\"], *data_dict.values()).item()\\n\",\n    \"\\n\",\n    \"        if optimizers:\\n\",\n    \"            optimizers['default'].zero_grad()\\n\",\n    \"            # perform loss scaling via apex if half precision is enabled\\n\",\n    \"            with optimizers[\\\"default\\\"].scale_loss(total_loss) as scaled_loss:\\n\",\n    \"                scaled_loss.backward()\\n\",\n    \"            optimizers['default'].step()\\n\",\n    \"\\n\",\n    \"        else:\\n\",\n    \"\\n\",\n    \"            # add prefix \\\"val\\\" in validation mode\\n\",\n    \"            eval_loss_vals, eval_metrics_vals = {}, {}\\n\",\n    \"            for key in loss_vals.keys():\\n\",\n    \"                eval_loss_vals[\\\"val_\\\" + str(key)] = loss_vals[key]\\n\",\n    \"\\n\",\n    \"            for key in metric_vals:\\n\",\n    \"                
eval_metrics_vals[\\\"val_\\\" + str(key)] = metric_vals[key]\\n\",\n    \"\\n\",\n    \"            loss_vals = eval_loss_vals\\n\",\n    \"            metric_vals = eval_metrics_vals\\n\",\n    \"\\n\",\n    \"        return metric_vals, loss_vals, {k: v.detach()\\n\",\n    \"                                        for k, v in preds.items()}\\n\",\n    \"\\n\",\n    \"    @staticmethod\\n\",\n    \"    def _build_model(in_channels, n_outputs):\\n\",\n    \"        return torch.nn.Sequential(\\n\",\n    \"            torch.nn.Linear(in_channels, 64),\\n\",\n    \"            torch.nn.ReLU(),\\n\",\n    \"            torch.nn.Linear(64, n_outputs)\\n\",\n    \"        )\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Now, that we defined our model, let's just test, if we really can forward some tensors through it. We will just use some random `torch.Tensors` (created by `torch.rand`). Since our model accepts 1d inputs of length 32, we need to pass 2d tensors to it (the additional dimension is the batch-dimension).\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"{'single': tensor([[-0.1934]], grad_fn=<DifferentiableGraphBackward>),\\n\",\n       \" 'batched': tensor([[-0.0525],\\n\",\n       \"         [-0.0884],\\n\",\n       \"         [-0.1492],\\n\",\n       \"         [-0.0431]], grad_fn=<DifferentiableGraphBackward>)}\"\n      ]\n     },\n     \"execution_count\": 4,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"input_tensor_single = torch.rand(1, 32) # use a single-sample batch (batchsize=1) here\\n\",\n    \"input_tensor_batched = torch.rand(4, 32) # use a batch with batchsize 4 here\\n\",\n    \"\\n\",\n    \"# create model instance\\n\",\n    \"model = DummyNetworkTorchScript()\\n\",\n    \"\\n\",\n    \"outputs = 
{\\\"single\\\": model(input_tensor_single)[\\\"pred\\\"], \\\"batched\\\": model(input_tensor_batched)[\\\"pred\\\"]}\\n\",\n    \"outputs\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from sklearn.metrics import mean_absolute_error\\n\",\n    \"from delira.training.callbacks import ReduceLROnPlateauCallbackPyTorch\\n\",\n    \"from delira.training import Parameters\\n\",\n    \"params = Parameters(fixed_params={\\n\",\n    \"                    \\\"model\\\": {},\\n\",\n    \"                    \\\"training\\\": {\\n\",\n    \"                        \\\"losses\\\": {\\\"CE\\\": torch.nn.BCEWithLogitsLoss()},\\n\",\n    \"                        \\\"optimizer_cls\\\": torch.optim.Adam,\\n\",\n    \"                        \\\"optimizer_params\\\": {\\\"lr\\\": 1e-3},\\n\",\n    \"                        \\\"num_epochs\\\": 2,\\n\",\n    \"                        \\\"val_metrics\\\": {\\\"mae\\\": mean_absolute_error},\\n\",\n    \"                        \\\"lr_sched_cls\\\": ReduceLROnPlateauCallbackPyTorch,\\n\",\n    \"                        \\\"lr_sched_params\\\": {\\\"mode\\\": \\\"min\\\"}\\n\",\n    \"                    }\\n\",\n    \"                }\\n\",\n    \"          )\\n\",\n    \"\\n\",\n    \"from delira.training import TorchScriptExperiment\\n\",\n    \"\\n\",\n    \"exp = TorchScriptExperiment(params, DummyNetworkTorchScript,\\n\",\n    \"                            key_mapping={\\\"x\\\": \\\"data\\\"},\\n\",\n    \"                            val_score_key=\\\"mae\\\",\\n\",\n    \"                            val_score_mode=\\\"min\\\")\\n\",\n    \"\\n\",\n    \"trained_model = exp.run(dmgr_train, dmgr_val)\\n\",\n    \"exp.test(trained_model, dmgr_test, params.nested_get(\\\"val_metrics\\\"))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {},\n   \"source\": [\n    \"Congratulations. 
You have implemented your first fully working `delira`-Backend. Wasn't that hard, was it?\\n\",\n    \"\\n\",\n    \"Before you start implementing backends for all the other frameworks out there, let me give you some advice:\\n\",\n    \"\\n\",\n    \"* You should test everything you implement or extend\\n\",\n    \"\\n\",\n    \"* Make sure to keep your backend specification in mind\\n\",\n    \"\\n\",\n    \"* Always follow the API of already existing backends. If this is not possible, test your implementation extensively\\n\",\n    \"\\n\",\n    \"* If you extend another backend (like we did here; we extended the `PyTorch`-backend for `TorchScript`), make sure that the \\\"base-backend\\\" is always installed (best if they can only be installed together)\\n\",\n    \"\\n\",\n    \"* If you have questions regarding the implementation, don't hesitate to contact us.\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.7.3\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "notebooks/gan_pytorch.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"# Generative Adversarial Nets with Delira - A very short introduction\\n\",\n    \"*Author: Justus Schock* \\n\",\n    \"\\n\",\n    \"*Date: 04.12.2018*\\n\",\n    \"\\n\",\n    \"This Example shows how to set up a basic GAN PyTorch experiment and\\n\",\n    \"Visdom Logging Environment.\\n\",\n    \"\\n\",\n    \"## HyperParameters\\n\",\n    \"Let's first setup the essential hyperparameters. We will use `delira`'s `Parameters`-class for this:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"logger = None\\n\",\n    \"import torch\\n\",\n    \"from delira.training import Parameters\\n\",\n    \"params = Parameters(fixed_params={\\n\",\n    \"    \\\"model\\\": {\\n\",\n    \"        \\\"n_channels\\\": 1, \\n\",\n    \"        \\\"noise_length\\\": 10\\n\",\n    \"    },\\n\",\n    \"    \\\"training\\\": {\\n\",\n    \"        \\\"batch_size\\\": 64, # batchsize to use\\n\",\n    \"        \\\"num_epochs\\\": 10, # number of epochs to train\\n\",\n    \"        \\\"optimizer_cls\\\": torch.optim.Adam, # optimization algorithm to use\\n\",\n    \"        \\\"optimizer_params\\\": {'lr': 1e-3}, # initialization parameters for this algorithm\\n\",\n    \"        \\\"losses\\\": {\\\"L1\\\": torch.nn.L1Loss()}, # the loss function\\n\",\n    \"        \\\"lr_sched_cls\\\": None,  # the learning rate scheduling algorithm to use\\n\",\n    \"        \\\"lr_sched_params\\\": {}, # the corresponding initialization parameters\\n\",\n    \"        \\\"metrics\\\": {} # and some evaluation metrics\\n\",\n    \"    }\\n\",\n    \"}) \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"Since we specified `torch.nn.L1Loss` as criterion and 
no metric, only the criterion will be calculated for each batch and used for backpropagation. Since we have a simple generative task, this should be sufficient. We will train our network with a batchsize of 64, using `Adam` as the optimizer of choice.\\n\",\n    \"\\n\",\n    \"## Logging and Visualization\\n\",\n    \"To get a visualization of our results, we should monitor them somehow. For logging we will use `Visdom`. To start a visdom server you need to execute the following command inside an environment which has visdom installed: \\n\",\n    \"```shell\\n\",\n    \"visdom -port=9999\\n\",\n    \"```\\n\",\n    \"This will start a visdom server on port 9999 of your machine and now we can start to configure our logging environment. To view your results you can open [http://localhost:9999](http://localhost:9999) in your browser.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from trixi.logger import PytorchVisdomLogger\\n\",\n    \"from delira.logging import TrixiHandler\\n\",\n    \"import logging\\n\",\n    \"\\n\",\n    \"logger_kwargs = {\\n\",\n    \"    'name': 'GANExampleLogger', # name of our logging environment\\n\",\n    \"    'port': 9999 # port on which our visdom server is alive\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"logger_cls = PytorchVisdomLogger\\n\",\n    \"\\n\",\n    \"# configure logging module (and root logger)\\n\",\n    \"logging.basicConfig(level=logging.INFO,\\n\",\n    \"                    handlers=[TrixiHandler(logger_cls, **logger_kwargs)])\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# derive logger from root logger\\n\",\n    \"# (don't do `logger = logging.Logger(\\\"...\\\")` since this will create a new\\n\",\n    \"# logger which is unrelated to the root logger)\\n\",\n    \"logger = logging.getLogger(\\\"Test Logger\\\")\\n\"\n   ]\n  },\n  {\n   \"cell_type\": 
\"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"Since a single visdom server can run multiple environments, we need to specify a (unique) name for our environment and need to tell the logger, on which port it can find the visdom server.\\n\",\n    \"\\n\",\n    \"## Data Preparation\\n\",\n    \"### Loading\\n\",\n    \"Next we will create a small train and validation set (based on `torchvision` MNIST):\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from delira.data_loading import TorchvisionClassificationDataset\\n\",\n    \"\\n\",\n    \"dataset_train = TorchvisionClassificationDataset(\\\"mnist\\\", # which dataset to use\\n\",\n    \"                                                 train=True, # use trainset\\n\",\n    \"                                                 img_shape=(224, 224) # resample to 224 x 224 pixels\\n\",\n    \"                                                )\\n\",\n    \"dataset_val = TorchvisionClassificationDataset(\\\"mnist\\\", \\n\",\n    \"                                               train=False,\\n\",\n    \"                                               img_shape=(224, 224)\\n\",\n    \"                                              )\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"### Augmentation\\n\",\n    \"For Data-Augmentation we will apply a few transformations:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from batchgenerators.transforms import RandomCropTransform, \\\\\\n\",\n    \"                                        ContrastAugmentationTransform, Compose\\n\",\n    \"from batchgenerators.transforms.spatial_transforms import ResizeTransform\\n\",\n    
\"from batchgenerators.transforms.sample_normalization_transforms import MeanStdNormalizationTransform\\n\",\n    \"\\n\",\n    \"transforms = Compose([\\n\",\n    \"    RandomCropTransform(200), # Perform Random Crops of Size 200 x 200 pixels\\n\",\n    \"    ResizeTransform(224), # Resample these crops back to 224 x 224 pixels\\n\",\n    \"    ContrastAugmentationTransform(), # randomly adjust contrast\\n\",\n    \"    MeanStdNormalizationTransform(mean=[0.5], std=[0.5])]) \\n\",\n    \"\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"With these transformations we can now wrap our datasets into datamanagers:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from delira.data_loading import DataManager, SequentialSampler, RandomSampler\\n\",\n    \"\\n\",\n    \"manager_train = DataManager(dataset_train, params.nested_get(\\\"batch_size\\\"),\\n\",\n    \"                                transforms=transforms,\\n\",\n    \"                                sampler_cls=RandomSampler,\\n\",\n    \"                                n_process_augmentation=4)\\n\",\n    \"\\n\",\n    \"manager_val = DataManager(dataset_val, params.nested_get(\\\"batch_size\\\"),\\n\",\n    \"                              transforms=transforms,\\n\",\n    \"                              sampler_cls=SequentialSampler,\\n\",\n    \"                              n_process_augmentation=4)\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"## Training\\n\",\n    \"\\n\",\n    \"After we have done that, we can finally specify our experiment and run it. 
We will therefore use the already implemented `GenerativeAdversarialNetworkBasePyTorch`, which is basically a vanilla DCGAN:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import warnings\\n\",\n    \"warnings.simplefilter(\\\"ignore\\\", UserWarning) # ignore UserWarnings raised by dependency code\\n\",\n    \"warnings.simplefilter(\\\"ignore\\\", FutureWarning) # ignore FutureWarnings raised by dependency code\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"from delira.training import PyTorchExperiment\\n\",\n    \"from delira.training.train_utils import create_optims_gan_default_pytorch\\n\",\n    \"from delira.models.gan import GenerativeAdversarialNetworkBasePyTorch\\n\",\n    \"\\n\",\n    \"if logger is not None:\\n\",\n    \"    logger.info(\\\"Init Experiment\\\")\\n\",\n    \"experiment = PyTorchExperiment(params, GenerativeAdversarialNetworkBasePyTorch,\\n\",\n    \"                               name=\\\"GANExample\\\",\\n\",\n    \"                               save_path=\\\"./tmp/delira_Experiments\\\",\\n\",\n    \"                               optim_builder=create_optims_gan_default_pytorch,\\n\",\n    \"                               gpu_ids=[0])\\n\",\n    \"experiment.save()\\n\",\n    \"\\n\",\n    \"model = experiment.run(manager_train, manager_val)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"Congratulations, you have now trained your first Generative Adversarial Model using `delira`.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"## See Also\\n\",\n    \"For a more detailed explanation have a look at \\n\",\n    \"* [the introduction tutorial](tutorial_delira.ipynb, \\\"Introduction\\\")\\n\",\n    \"* [the 2d segmentation example](segmentation_2d_pytorch.ipynb, 
\\\"Segmentation 2D\\\")\\n\",\n    \"* [the 3d segmentation example](segmentation_3d_pytorch.ipynb, \\\"Segmentation 3D\\\")\\n\",\n    \"* [the classification example](classification_pytorch.ipynb, \\\"GAN\\\")\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.7.3\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "notebooks/segmentation_2d_pytorch.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"# Segmentation in 2D using U-Nets with Delira - A very short introduction\\n\",\n    \"\\n\",\n    \"*Author: Justus Schock, Alexander Moriz* \\n\",\n    \"\\n\",\n    \"*Date: 17.12.2018*\\n\",\n    \" \\n\",\n    \"This example shows how to use the U-Net implementation in Delira with PyTorch.\\n\",\n    \"\\n\",\n    \"Let's first set up the essential hyperparameters. We will use `delira`'s `Parameters`-class for this:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"logger = None\\n\",\n    \"import torch\\n\",\n    \"from delira.training import Parameters\\n\",\n    \"params = Parameters(fixed_params={\\n\",\n    \"    \\\"model\\\": {\\n\",\n    \"        \\\"in_channels\\\": 1, \\n\",\n    \"        \\\"num_classes\\\": 4\\n\",\n    \"    },\\n\",\n    \"    \\\"training\\\": {\\n\",\n    \"        \\\"batch_size\\\": 64, # batchsize to use\\n\",\n    \"        \\\"num_epochs\\\": 10, # number of epochs to train\\n\",\n    \"        \\\"optimizer_cls\\\": torch.optim.Adam, # optimization algorithm to use\\n\",\n    \"        \\\"optimizer_params\\\": {'lr': 1e-3}, # initialization parameters for this algorithm\\n\",\n    \"        \\\"losses\\\": {\\\"CE\\\": torch.nn.CrossEntropyLoss()}, # the loss function\\n\",\n    \"        \\\"lr_sched_cls\\\": None,  # the learning rate scheduling algorithm to use\\n\",\n    \"        \\\"lr_sched_params\\\": {}, # the corresponding initialization parameters\\n\",\n    \"        \\\"metrics\\\": {} # and some evaluation metrics\\n\",\n    \"    }\\n\",\n    \"}) \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"Since we did not specify any metric, only the `CrossEntropyLoss` will be 
calculated for each batch. For our segmentation task, this should be sufficient. We will train our network with a batch size of 64, using `Adam` as the optimizer of choice.\\n\",\n    \"\\n\",\n    \"## Logging and Visualization\\n\",\n    \"To get a visualization of our results, we should monitor them somehow. For logging we will use `Visdom`. To start a visdom server you need to execute the following command inside an environment which has visdom installed: \\n\",\n    \"```shell\\n\",\n    \"visdom -port=9999\\n\",\n    \"```\\n\",\n    \"This will start a visdom server on port 9999 of your machine and now we can start to configure our logging environment. To view your results you can open [http://localhost:9999](http://localhost:9999) in your browser.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from trixi.logger import PytorchVisdomLogger\\n\",\n    \"from delira.logging import TrixiHandler\\n\",\n    \"import logging\\n\",\n    \"\\n\",\n    \"logger_kwargs = {\\n\",\n    \"    'name': 'ClassificationExampleLogger', # name of our logging environment\\n\",\n    \"    'port': 9999 # port on which our visdom server is alive\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"logger_cls = PytorchVisdomLogger\\n\",\n    \"\\n\",\n    \"# configure logging module (and root logger)\\n\",\n    \"logging.basicConfig(level=logging.INFO,\\n\",\n    \"                    handlers=[TrixiHandler(logger_cls, **logger_kwargs)])\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# derive logger from root logger\\n\",\n    \"# (don't do `logger = logging.Logger(\\\"...\\\")` since this will create a new\\n\",\n    \"# logger which is unrelated to the root logger)\\n\",\n    \"logger = logging.getLogger(\\\"Test Logger\\\")\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"Since a single 
visdom server can run multiple environments, we need to specify a (unique) name for our environment and need to tell the logger on which port it can find the visdom server.\\n\",\n    \"\\n\",\n    \"## Data Preparation\\n\",\n    \"### Loading\\n\",\n    \"Next we will create a small train and validation set (in this case they will be the same to show the overfitting capability of the UNet).\\n\",\n    \"\\n\",\n    \"Our data is a brain MR-image thankfully provided by the [FSL](https://fsl.fmrib.ox.ac.uk/fsl/fslwiki) in their [introduction](http://www.fmrib.ox.ac.uk/primers/intro_primer/ExBox3/IntroBox3.html).\\n\",\n    \"\\n\",\n    \"We first download the data and extract the T1 image and the corresponding segmentation:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from io import BytesIO\\n\",\n    \"from zipfile import ZipFile\\n\",\n    \"from urllib.request import urlopen\\n\",\n    \"\\n\",\n    \"resp = urlopen(\\\"http://www.fmrib.ox.ac.uk/primers/intro_primer/ExBox3/ExBox3.zip\\\")\\n\",\n    \"zipfile = ZipFile(BytesIO(resp.read()))\\n\",\n    \"#zipfile_list = zipfile.namelist()\\n\",\n    \"#print(zipfile_list)\\n\",\n    \"img_file = zipfile.extract(\\\"ExBox3/T1_brain.nii.gz\\\")\\n\",\n    \"mask_file = zipfile.extract(\\\"ExBox3/T1_brain_seg.nii.gz\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"Now, we load the image and the mask (they are both 3D), convert them to 32-bit floating point numpy arrays and ensure they have the same shape (i.e. 
that for each voxel in the image, there is a voxel in the mask):\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import SimpleITK as sitk\\n\",\n    \"import numpy as np\\n\",\n    \"\\n\",\n    \"# load image and mask\\n\",\n    \"img = sitk.GetArrayFromImage(sitk.ReadImage(img_file))\\n\",\n    \"img = img.astype(np.float32)\\n\",\n    \"mask = sitk.GetArrayFromImage(sitk.ReadImage(mask_file))\\n\",\n    \"mask = mask.astype(np.float32)\\n\",\n    \"\\n\",\n    \"assert mask.shape == img.shape\\n\",\n    \"print(img.shape)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"By querying the unique values in the mask, we get the following:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"np.unique(mask)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"This means there are 4 classes (background and 3 types of tissue) in our sample.\\n\",\n    \"\\n\",\n    \"Since we want to do a 2D segmentation, we extract a single slice out of the image and the mask (we choose slice 100 here) and plot it:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import matplotlib.pyplot as plt\\n\",\n    \"\\n\",\n    \"# load single slice\\n\",\n    \"img_slice = img[:, :, 100]\\n\",\n    \"mask_slice = mask[:, :, 100]\\n\",\n    \"\\n\",\n    \"# plot slices\\n\",\n    \"plt.figure(1, figsize=(15,10))\\n\",\n    \"plt.subplot(121)\\n\",\n    \"plt.imshow(img_slice, cmap=\\\"gray\\\")\\n\",\n    \"plt.colorbar(fraction=0.046, pad=0.04)\\n\",\n    
\"plt.subplot(122)\\n\",\n    \"plt.imshow(mask_slice, cmap=\\\"gray\\\")\\n\",\n    \"plt.colorbar(fraction=0.046, pad=0.04)\\n\",\n    \"plt.show()\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"To load the data, we have to use a `Dataset`. The following defines a very simple dataset, accepting an image slice, a mask slice and the number of samples. It always returns the same sample until `num_samples` samples have been returned.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from delira.data_loading import AbstractDataset\\n\",\n    \"\\n\",\n    \"class CustomDataset(AbstractDataset):\\n\",\n    \"    def __init__(self, img, mask, num_samples=1000):\\n\",\n    \"        super().__init__(None, None, None, None)\\n\",\n    \"        self.data = {\\\"data\\\": img.reshape(1, *img.shape), \\\"label\\\": mask.reshape(1, *mask.shape)}\\n\",\n    \"        self.num_samples = num_samples\\n\",\n    \"        \\n\",\n    \"    def __getitem__(self, index):\\n\",\n    \"        return self.data\\n\",\n    \"    \\n\",\n    \"    def __len__(self):\\n\",\n    \"        return self.num_samples\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"Now, we can finally instantiate our datasets:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"dataset_train = CustomDataset(img_slice, mask_slice, num_samples=10000)\\n\",\n    \"dataset_val = CustomDataset(img_slice, mask_slice, num_samples=1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"### Augmentation\\n\",\n    \"For Data-Augmentation we will apply a 
few transformations:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from batchgenerators.transforms import RandomCropTransform, \\\\\\n\",\n    \"                                        ContrastAugmentationTransform, Compose\\n\",\n    \"from batchgenerators.transforms.spatial_transforms import ResizeTransform\\n\",\n    \"from batchgenerators.transforms.sample_normalization_transforms import MeanStdNormalizationTransform\\n\",\n    \"\\n\",\n    \"transforms = Compose([\\n\",\n    \"    RandomCropTransform(150, label_key=\\\"label\\\"), # Perform Random Crops of Size 150 x 150 pixels\\n\",\n    \"    ResizeTransform(224, label_key=\\\"label\\\"), # Resample these crops back to 224 x 224 pixels\\n\",\n    \"    ContrastAugmentationTransform(), # randomly adjust contrast\\n\",\n    \"    MeanStdNormalizationTransform(mean=[img_slice.mean()], std=[img_slice.std()])]) # use concrete values since we only have one sample (have to estimate it over whole dataset otherwise)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"With these transformations we can now wrap our datasets into datamanagers:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from delira.data_loading import DataManager, SequentialSampler, RandomSampler\\n\",\n    \"\\n\",\n    \"manager_train = DataManager(dataset_train, params.nested_get(\\\"batch_size\\\"),\\n\",\n    \"                                transforms=transforms,\\n\",\n    \"                                sampler_cls=RandomSampler,\\n\",\n    \"                                n_process_augmentation=4)\\n\",\n    \"\\n\",\n    \"manager_val = DataManager(dataset_val, params.nested_get(\\\"batch_size\\\"),\\n\",\n    \"    
                          transforms=transforms,\\n\",\n    \"                              sampler_cls=SequentialSampler,\\n\",\n    \"                              n_process_augmentation=4)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"## Training\\n\",\n    \"\\n\",\n    \"After we have done that, we can finally specify our experiment and run it. We will therefore use the already implemented `UNet2dPyTorch`:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import warnings\\n\",\n    \"warnings.simplefilter(\\\"ignore\\\", UserWarning) # ignore UserWarnings raised by dependency code\\n\",\n    \"warnings.simplefilter(\\\"ignore\\\", FutureWarning) # ignore FutureWarnings raised by dependency code\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"from delira.training import PyTorchExperiment\\n\",\n    \"from delira.training.train_utils import create_optims_default_pytorch\\n\",\n    \"from delira.models.segmentation import UNet2dPyTorch\\n\",\n    \"\\n\",\n    \"if logger is not None:\\n\",\n    \"    logger.info(\\\"Init Experiment\\\")\\n\",\n    \"experiment = PyTorchExperiment(params, UNet2dPyTorch,\\n\",\n    \"                               name=\\\"Segmentation2dExample\\\",\\n\",\n    \"                               save_path=\\\"./tmp/delira_Experiments\\\",\\n\",\n    \"                               optim_builder=create_optims_default_pytorch,\\n\",\n    \"                               gpu_ids=[0], mixed_precision=True)\\n\",\n    \"experiment.save()\\n\",\n    \"\\n\",\n    \"model = experiment.run(manager_train, manager_val)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": []\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    
\"pycharm\": {}\n   },\n   \"source\": [\n    \"## See Also\\n\",\n    \"For a more detailed explanation have a look at \\n\",\n    \"* [the introduction tutorial](tutorial_delira.ipynb, \\\"Introduction\\\")\\n\",\n    \"* [the classification example](classification_pytorch.ipynb, \\\"Classification\\\")\\n\",\n    \"* [the 3d segmentation example](segmentation_3d_pytorch.ipynb, \\\"Segmentation 3D\\\")\\n\",\n    \"* [the generative adversarial example](gan_pytorch.ipynb, \\\"GAN\\\")\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.7.3\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "notebooks/segmentation_3d_pytorch.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"# Segmentation in 3D using U-Nets with Delira - A very short introduction\\n\",\n    \"\\n\",\n    \"*Author: Justus Schock, Alexander Moriz* \\n\",\n    \"\\n\",\n    \"*Date: 17.12.2018*\\n\",\n    \" \\n\",\n    \"This example shows how to use the U-Net implementation in Delira with PyTorch.\\n\",\n    \"\\n\",\n    \"Let's first set up the essential hyperparameters. We will use `delira`'s `Parameters`-class for this:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"logger = None\\n\",\n    \"import torch\\n\",\n    \"from delira.training import Parameters\\n\",\n    \"params = Parameters(fixed_params={\\n\",\n    \"    \\\"model\\\": {\\n\",\n    \"        \\\"in_channels\\\": 1, \\n\",\n    \"        \\\"num_classes\\\": 4\\n\",\n    \"    },\\n\",\n    \"    \\\"training\\\": {\\n\",\n    \"        \\\"batch_size\\\": 64, # batchsize to use\\n\",\n    \"        \\\"num_epochs\\\": 10, # number of epochs to train\\n\",\n    \"        \\\"optimizer_cls\\\": torch.optim.Adam, # optimization algorithm to use\\n\",\n    \"        \\\"optimizer_params\\\": {'lr': 1e-3}, # initialization parameters for this algorithm\\n\",\n    \"        \\\"losses\\\": {\\\"CE\\\": torch.nn.CrossEntropyLoss()}, # the loss function\\n\",\n    \"        \\\"lr_sched_cls\\\": None,  # the learning rate scheduling algorithm to use\\n\",\n    \"        \\\"lr_sched_params\\\": {}, # the corresponding initialization parameters\\n\",\n    \"        \\\"metrics\\\": {} # and some evaluation metrics\\n\",\n    \"    }\\n\",\n    \"}) \"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"Since we did not specify any metric, only the `CrossEntropyLoss` will be 
calculated for each batch. For our segmentation task, this should be sufficient. We will train our network with a batch size of 64, using `Adam` as the optimizer of choice.\\n\",\n    \"\\n\",\n    \"## Logging and Visualization\\n\",\n    \"To get a visualization of our results, we should monitor them somehow. For logging we will use `Visdom`. To start a visdom server you need to execute the following command inside an environment which has visdom installed: \\n\",\n    \"```shell\\n\",\n    \"visdom -port=9999\\n\",\n    \"```\\n\",\n    \"This will start a visdom server on port 9999 of your machine and now we can start to configure our logging environment. To view your results you can open [http://localhost:9999](http://localhost:9999) in your browser.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from trixi.logger import PytorchVisdomLogger\\n\",\n    \"from delira.logging import TrixiHandler\\n\",\n    \"import logging\\n\",\n    \"\\n\",\n    \"logger_kwargs = {\\n\",\n    \"    'name': 'ClassificationExampleLogger', # name of our logging environment\\n\",\n    \"    'port': 9999 # port on which our visdom server is alive\\n\",\n    \"}\\n\",\n    \"\\n\",\n    \"logger_cls = PytorchVisdomLogger\\n\",\n    \"\\n\",\n    \"# configure logging module (and root logger)\\n\",\n    \"logging.basicConfig(level=logging.INFO,\\n\",\n    \"                    handlers=[TrixiHandler(logger_cls, **logger_kwargs)])\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"# derive logger from root logger\\n\",\n    \"# (don't do `logger = logging.Logger(\\\"...\\\")` since this will create a new\\n\",\n    \"# logger which is unrelated to the root logger)\\n\",\n    \"logger = logging.getLogger(\\\"Test Logger\\\")\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"Since a single 
visdom server can run multiple environments, we need to specify a (unique) name for our environment and need to tell the logger on which port it can find the visdom server.\\n\",\n    \"\\n\",\n    \"## Data Preparation\\n\",\n    \"### Loading\\n\",\n    \"Next we will create a small train and validation set (in this case they will be the same to show the overfitting capability of the UNet).\\n\",\n    \"\\n\",\n    \"Our data is a brain MR-image thankfully provided by the [FSL](https://fsl.fmrib.ox.ac.uk/fsl/fslwiki) in their [introduction](http://www.fmrib.ox.ac.uk/primers/intro_primer/ExBox3/IntroBox3.html).\\n\",\n    \"\\n\",\n    \"We first download the data and extract the T1 image and the corresponding segmentation:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from io import BytesIO\\n\",\n    \"from zipfile import ZipFile\\n\",\n    \"from urllib.request import urlopen\\n\",\n    \"\\n\",\n    \"resp = urlopen(\\\"http://www.fmrib.ox.ac.uk/primers/intro_primer/ExBox3/ExBox3.zip\\\")\\n\",\n    \"zipfile = ZipFile(BytesIO(resp.read()))\\n\",\n    \"#zipfile_list = zipfile.namelist()\\n\",\n    \"#print(zipfile_list)\\n\",\n    \"img_file = zipfile.extract(\\\"ExBox3/T1_brain.nii.gz\\\")\\n\",\n    \"mask_file = zipfile.extract(\\\"ExBox3/T1_brain_seg.nii.gz\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"Now, we load the image and the mask (they are both 3D), convert them to 32-bit floating point numpy arrays and ensure they have the same shape (i.e. 
that for each voxel in the image, there is a voxel in the mask):\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import SimpleITK as sitk\\n\",\n    \"import numpy as np\\n\",\n    \"\\n\",\n    \"# load image and mask\\n\",\n    \"img = sitk.GetArrayFromImage(sitk.ReadImage(img_file))\\n\",\n    \"img = img.astype(np.float32)\\n\",\n    \"mask = sitk.GetArrayFromImage(sitk.ReadImage(mask_file))\\n\",\n    \"mask = mask.astype(np.float32)\\n\",\n    \"\\n\",\n    \"assert mask.shape == img.shape\\n\",\n    \"print(img.shape)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"By querying the unique values in the mask, we get the following:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"np.unique(mask)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"This means there are 4 classes (background and 3 types of tissue) in our sample.\\n\",\n    \"\\n\",\n    \"To load the data, we have to use a `Dataset`. The following defines a very simple dataset, accepting an image, a mask and the number of samples. 
It always returns the same sample until `num_samples` samples have been returned.\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from delira.data_loading import AbstractDataset\\n\",\n    \"\\n\",\n    \"class CustomDataset(AbstractDataset):\\n\",\n    \"    def __init__(self, img, mask, num_samples=1000):\\n\",\n    \"        super().__init__(None, None, None, None)\\n\",\n    \"        self.data = {\\\"data\\\": img.reshape(1, *img.shape), \\\"label\\\": mask.reshape(1, *mask.shape)}\\n\",\n    \"        self.num_samples = num_samples\\n\",\n    \"        \\n\",\n    \"    def __getitem__(self, index):\\n\",\n    \"        return self.data\\n\",\n    \"    \\n\",\n    \"    def __len__(self):\\n\",\n    \"        return self.num_samples\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"Now, we can finally instantiate our datasets:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"dataset_train = CustomDataset(img, mask, num_samples=10000)\\n\",\n    \"dataset_val = CustomDataset(img, mask, num_samples=1)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"### Augmentation\\n\",\n    \"For Data-Augmentation we will apply a few transformations:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from batchgenerators.transforms import ContrastAugmentationTransform, Compose\\n\",\n    \"from batchgenerators.transforms.sample_normalization_transforms import MeanStdNormalizationTransform\\n\",\n    \"\\n\",\n    \"transforms = Compose([\\n\",\n    \"    
ContrastAugmentationTransform(), # randomly adjust contrast\\n\",\n    \"    MeanStdNormalizationTransform(mean=[img.mean()], std=[img.std()])]) # use concrete values since we only have one sample (have to estimate it over whole dataset otherwise)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"With these transformations we can now wrap our datasets into datamanagers:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from delira.data_loading import DataManager, SequentialSampler, RandomSampler\\n\",\n    \"\\n\",\n    \"manager_train = DataManager(dataset_train, params.nested_get(\\\"batch_size\\\"),\\n\",\n    \"                                transforms=transforms,\\n\",\n    \"                                sampler_cls=RandomSampler,\\n\",\n    \"                                n_process_augmentation=4)\\n\",\n    \"\\n\",\n    \"manager_val = DataManager(dataset_val, params.nested_get(\\\"batch_size\\\"),\\n\",\n    \"                              transforms=transforms,\\n\",\n    \"                              sampler_cls=SequentialSampler,\\n\",\n    \"                              n_process_augmentation=4)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"## Training\\n\",\n    \"\\n\",\n    \"After we have done that, we can finally specify our experiment and run it. 
We will therefore use the already implemented `UNet3dPyTorch`:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import warnings\\n\",\n    \"warnings.simplefilter(\\\"ignore\\\", UserWarning) # ignore UserWarnings raised by dependency code\\n\",\n    \"warnings.simplefilter(\\\"ignore\\\", FutureWarning) # ignore FutureWarnings raised by dependency code\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"from delira.training import PyTorchExperiment\\n\",\n    \"from delira.training.train_utils import create_optims_default_pytorch\\n\",\n    \"from delira.models.segmentation import UNet3dPyTorch\\n\",\n    \"\\n\",\n    \"if logger:\\n\",\n    \"    logger.info(\\\"Init Experiment\\\")\\n\",\n    \"experiment = PyTorchExperiment(params, UNet3dPyTorch,\\n\",\n    \"                               name=\\\"Segmentation3dExample\\\",\\n\",\n    \"                               save_path=\\\"./tmp/delira_Experiments\\\",\\n\",\n    \"                               optim_builder=create_optims_default_pytorch,\\n\",\n    \"                               gpu_ids=[0], mixed_precision=True)\\n\",\n    \"experiment.save()\\n\",\n    \"\\n\",\n    \"model = experiment.run(manager_train, manager_val)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"## See Also\\n\",\n    \"For a more detailed explanation have a look at \\n\",\n    \"* [the introduction tutorial](tutorial_delira.ipynb, \\\"Introduction\\\")\\n\",\n    \"* [the classification example](classification_pytorch.ipynb, \\\"Classification\\\")\\n\",\n    \"* [the 2d segmentation example](segmentation_2d_pytorch.ipynb, \\\"Segmentation 2D\\\")\\n\",\n    \"* [the generative adversarial example](gan_pytorch.ipynb, \\\"GAN\\\")\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": 
\"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.7.3\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "notebooks/tutorial_delira.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"# Delira Introduction\\n\",\n    \"\\n\",\n    \"*Last updated: 09.05.2019*\\n\",\n    \"\\n\",\n    \"Authors: Justus Schock, Christoph Haarburger\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"## Loading Data\\n\",\n    \"\\n\",\n    \"To train your network you first need to load your training data (and probably also your validation data). This chapter will therefore deal with `delira`'s capabilities to load your data (and apply some augmentation). \\n\",\n    \"\\n\",\n    \"### The Dataset\\n\",\n    \"There are mainly two ways to load your data: Lazy or non-lazy. Loading in a lazy way means that you load the data just in time and keep the used memory to a bare minimum. This has, however, the disadvantage that your loading function could be a bottleneck since all postponed operations may have to wait until the needed data samples are loaded. In a non-lazy way, one would preload all data to RAM before starting any other operations. This has the advantage that there cannot be a loading bottleneck during later operations. This advantage comes at the cost of a higher memory usage and a (possibly) huge latency at the beginning of each experiment. Both ways to load your data are implemented in `delira` and they are named `BaseLazyDataset` and `BaseCacheDataset`. In the following steps you will only see the `BaseLazyDataset` since exchanging them is trivial. 
All Datasets (including the ones you might want to create yourself later) must be derived from `delira.data_loading.AbstractDataset` to ensure a minimum common API.\\n\",\n    \"\\n\",\n    \"The dataset's `__init__` has the following signature:\\n\",\n    \"\\n\",\n    \"```python\\n\",\n    \"def __init__(self, data_path, load_fn, **load_kwargs):\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"This means you have to pass the path to the directory containing your data (`data_path`) and a function to load a single sample of your data (`load_fn`). To get a single sample of your dataset after creating it, you can index it like this: `dataset[0]`.\\n\",\n    \"Additionally you can iterate over your dataset just like over any other `python` iterator via\\n\",\n    \"\\n\",\n    \"```python\\n\",\n    \"for sample in dataset:\\n\",\n    \"    # do your stuff here\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"or enumerate it via\\n\",\n    \"\\n\",\n    \"```python\\n\",\n    \"for idx, sample in enumerate(dataset):\\n\",\n    \"    # do your stuff here\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"The missing argument `**load_kwargs` accepts an arbitrary number of additional keyword arguments which are directly passed to your loading function.\\n\",\n    \"\\n\",\n    \"An example of how loading your data may look is given below:\\n\",\n    \"```python\\n\",\n    \"from delira.data_loading import BaseLazyDataset, default_load_fn_2d\\n\",\n    \"dataset_train = BaseLazyDataset(\\\"/images/datasets/external/mnist/train\\\",\\n\",\n    \"                                default_load_fn_2d, img_shape=(224, 224))\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"In this case all data lying in `/images/datasets/external/mnist/train` is loaded by `default_load_fn_2d`. The files containing the data must be PNG-files, while the groundtruth is defined in TXT-files. 
The `default_load_fn_2d` needs the additional argument `img_shape` which is passed as keyword argument via `**load_kwargs`.\\n\",\n    \"\\n\",\n    \"> **Note:** for reproducability we decided to use some wrapped PyTorch datasets for this introduction. \\n\",\n    \"\\n\",\n    \"Now, let's just initialize our trainset:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from delira.data_loading import TorchvisionClassificationDataset\\n\",\n    \"dataset_train = TorchvisionClassificationDataset(\\\"mnist\\\", train=True,\\n\",\n    \"                                                 img_shape=(224, 224))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"Getting a single sample of your dataset with dataset_train[0] will produce:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"dataset_train[0]\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"which means, that our data is stored in a dictionary containing the keys `data` and `label`, each of them holding the corresponding numpy arrays. The dataloading works on `numpy` purely and is thus backend agnostic. 
It does not matter in which format or with which library you load/preprocess your data, but at the end it must be converted to numpy arrays\\n\",\n    \"For validation purposes another dataset could be created with the test data like this:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"dataset_val = TorchvisionClassificationDataset(\\\"mnist\\\", train=False,\\n\",\n    \"                                               img_shape=(224, 224))\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"### The Dataloader\\n\",\n    \"The Dataloader wraps your dataset to privode the ability to load whole batches with an abstract interface. To create a dataloader, one would have to pass the following arguments to it's `__init__`: the previously created `dataset`.Additionally, it is possible to pass the `batch_size` defining the number of samples per batch, the total number of batches (`num_batches`), which will be the number of samples in your dataset devided by the batchsize per default, a random `seed`for always getting the same behaviour of random number generators and a [`sampler`]() defining your sampling strategy. 
This would create a dataloader for your `dataset_train`:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from delira.data_loading import DataLoader\\n\",\n    \"\\n\",\n    \"batch_size = 32\\n\",\n    \"\\n\",\n    \"loader_train = DataLoader(dataset_train, batch_size)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"Since the batch_size has been set to 32, the loader will load 32 samples as one batch.\\n\",\n    \"\\n\",\n    \"Even though it would be possible to train your network with an instance of `DataLoader`, `malira` also offers a different approach that covers multithreaded data loading and augmentation:\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"### The Datamanager\\n\",\n    \"\\n\",\n    \"The data manager is implemented as `delira.data_loading.DataManager` and wraps a `DataLoader`. It also encapsulates augmentations. Having a view on the `DataManager`'s signature, it becomes obvious that it accepts the same arguments as the [`DataLoader`](#The-Dataloader). You can either pass a `dataset` or a combination of path, dataset class and load function. Additionally, you can pass a custom dataloder class if necessary and a sampler class to choose a sampling algorithm. \\n\",\n    \"\\n\",\n    \"The parameter `transforms` accepts augmentation transformations as implemented in `batchgenerators`. 
Augmentation is applied on the fly using `n_process_augmentation` threads.\\n\",\n    \"\\n\",\n    \"All in all the DataManager is the recommended way to generate batches from your dataset.\\n\",\n    \"\\n\",\n    \"The following example shows how to create a data manager instance:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from delira.data_loading import DataManager\\n\",\n    \"from batchgenerators.transforms.abstract_transforms import Compose\\n\",\n    \"from batchgenerators.transforms.sample_normalization_transforms import MeanStdNormalizationTransform\\n\",\n    \"\\n\",\n    \"batchsize = 64\\n\",\n    \"transforms = Compose([MeanStdNormalizationTransform(mean=1*[0], std=1*[1])])\\n\",\n    \"\\n\",\n    \"data_manager_train = DataManager(dataset_train,  # dataset to use\\n\",\n    \"                                    batchsize,  # batchsize\\n\",\n    \"                                    n_process_augmentation=1,  # number of augmentation processes\\n\",\n    \"                                    transforms=transforms)  # augmentation transforms\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"The approach to initialize a DataManager from a datapath takes more arguments since, in opposite to initializaton from dataset, it needs all the arguments which are necessary to internally create a dataset.\\n\",\n    \"\\n\",\n    \"Since we want to validate our model we have to create a second manager containing our `dataset_val`:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"data_manager_val = DataManager(dataset_val, \\n\",\n    \"                                    batchsize, \\n\",\n    \"                                    
n_process_augmentation=1, \\n\",\n    \"                                    transforms=transforms)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"\\n\",\n    \"That's it - we just finished loading our data!\\n\",\n    \"\\n\",\n    \"Iterating over a DataManager is possible in simple loops:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from tqdm.auto import tqdm # utility for progress bars\\n\",\n    \"\\n\",\n    \"# create actual batch generator from DataManager\\n\",\n    \"batchgen = data_manager_val.get_batchgen()\\n\",\n    \"\\n\",\n    \"for data in tqdm(batchgen):\\n\",\n    \"    pass # here you can access the data of the current batch\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"### Sampler\\n\",\n    \"In previous section samplers have been already mentioned but not yet explained. A sampler implements an algorithm how a batch should be assembled from single samples in a dataset. `delira` provides the following sampler classes in it's subpackage `delira.data_loading.sampler`:\\n\",\n    \"\\n\",\n    \"* `AbstractSampler`\\n\",\n    \"* `SequentialSampler`\\n\",\n    \"* `PrevalenceSequentialSampler`\\n\",\n    \"* `RandomSampler`\\n\",\n    \"* `PrevalenceRandomSampler`\\n\",\n    \"* `WeightedRandomSampler`\\n\",\n    \"* `LambdaSampler`\\n\",\n    \"\\n\",\n    \"The `AbstractSampler` implements no sampling algorithm but defines a sampling API and thus all custom samplers must inherit from this class. The `Sequential` sampler builds batches by just iterating over the samples' indices in a sequential way. Following this, the `RandomSampler` builds batches by randomly drawing the samples' indices with replacement. 
\\n\",\n    \"If the class each sample belongs to is known for each sample at the beginning, the `PrevalenceSequentialSampler` and the `PrevalenceRandomSampler` perform a per-class sequential or random sampling and building each batch with the exactly same number of samples from each class. \\n\",\n    \"The `WeightedRandomSampler`accepts custom weights to give specific samples a higher probability during random sampling than others.\\n\",\n    \"\\n\",\n    \"The `LambdaSampler` is a wrapper for a custom sampling function, which can be passed to the wrapper during it's initialization, to ensure API conformity.\\n\",\n    \"\\n\",\n    \"It can be passed to the DataLoader or DataManager as class (argument `sampler_cls`) or as instance (argument `sampler`).\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"## Models\\n\",\n    \"\\n\",\n    \"Since the purpose of this framework is to use machine learning algorithms, there has to be a way to define them. Defining models is straight forward. `delira` provides a class `delira.models.AbstractNetwork`. *All models must inherit from this class*.\\n\",\n    \"\\n\",\n    \"To inherit this class four functions must be implemented in the subclass:\\n\",\n    \"\\n\",\n    \"* `__init__`\\n\",\n    \"* `closure`\\n\",\n    \"* `prepare_batch`\\n\",\n    \"* `__call__`\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"### `__init__`\\n\",\n    \"The `__init__`function is a classes constructor. In our case it builds the entire model (maybe using some helper functions). 
If writing your own custom model, you have to override this method.\\n\",\n    \"\\n\",\n    \"> **Note:** If you want the best experience for saving your model and completely recreating it during the loading process you need to take care of a few things:\\n\",\n    \"> * if using `torchvision.models` to build your model, always import it with `from torchvision import models as t_models`\\n\",\n    \"> * register all arguments in your custom `__init__` in the abstract class. A init_prototype could look like this:\\n\",\n    \">\\n\",\n    \"```python\\n\",\n    \"def __init__(self, in_channels: int, n_outputs: int, **kwargs):\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"    Parameters\\n\",\n    \"    ----------\\n\",\n    \"    in_channels: int\\n\",\n    \"        number of input_channels\\n\",\n    \"    n_outputs: int\\n\",\n    \"        number of outputs (usually same as number of classes)\\n\",\n    \"    \\\"\\\"\\\"\\n\",\n    \"    # register params by passing them as kwargs to parent class __init__\\n\",\n    \"    # only params registered like this will be saved!\\n\",\n    \"    super().__init__(in_channels=in_channels,\\n\",\n    \"                     n_outputs=n_outputs,\\n\",\n    \"                     **kwargs)\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"### `closure`\\n\",\n    \"The `closure`function defines one batch iteration to train the network. 
This function is needed for the framework to provide a generic trainer function which works with all kind of networks and loss functions.\\n\",\n    \"\\n\",\n    \"The closure function must implement all steps from forwarding, over loss calculation, metric calculation, logging (for which `delira.logging_handlers` provides some extensions for pythons logging module), and the actual backpropagation.\\n\",\n    \"\\n\",\n    \"It is called with an empty optimizer-dict to evaluate and should thus work with optional optimizers.\\n\",\n    \"\\n\",\n    \"### `prepare_batch`\\n\",\n    \"The `prepare_batch`function defines the transformation from loaded data to match the networks input and output shape and pushes everything to the right device.\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"## Abstract Networks for specific Backends\\n\",\n    \"### PyTorch\\n\",\n    \"At the time of writing, PyTorch is the only backend which is supported, but other backends are planned.\\n\",\n    \"In PyTorch every network should be implemented as a subclass of `torch.nn.Module`, which also provides a `__call__` method.\\n\",\n    \"\\n\",\n    \"This results in sloghtly different requirements for PyTorch networks: instead of implementing a `__call__` method, we simply call the `torch.nn.Module.__call__` and therefore have to implement the `forward` method, which defines the module's behaviour and is internally called by `torch.nn.Module.__call__` (among other stuff). To give a default behaviour suiting most cases and not have to care about internals, `delira` provides the `AbstractPyTorchNetwork` which is a more specific case of the `AbstractNetwork` for PyTorch modules.\\n\",\n    \"\\n\",\n    \"#### `forward`\\n\",\n    \"The `forward` function defines what has to be done to forward your input through your network and must return a dictionary. 
Assuming your network has three convolutional layers stored in `self.conv1`, `self.conv2` and `self.conv3` and a ReLU stored in `self.relu`, a simple `forward` function could look like this:\\n\",\n    \"\\n\",\n    \"```python\\n\",\n    \"def forward(self, input_batch: torch.Tensor):\\n\",\n    \"    out_1 = self.relu(self.conv1(input_batch))\\n\",\n    \"    out_2 = self.relu(self.conv2(out_1))\\n\",\n    \"    out_3 = self.conv3(out2)\\n\",\n    \"    \\n\",\n    \"    return {\\\"pred\\\": out_3}\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"#### `prepare_batch`\\n\",\n    \"The default `prepare_batch` function for PyTorch networks looks like this:\\n\",\n    \"\\n\",\n    \"```python\\n\",\n    \"    @staticmethod\\n\",\n    \"    def prepare_batch(batch: dict, input_device, output_device):\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        Helper Function to prepare Network Inputs and Labels (convert them to\\n\",\n    \"        correct type and shape and push them to correct devices)\\n\",\n    \"\\n\",\n    \"        Parameters\\n\",\n    \"        ----------\\n\",\n    \"        batch : dict\\n\",\n    \"            dictionary containing all the data\\n\",\n    \"        input_device : torch.device\\n\",\n    \"            device for network inputs\\n\",\n    \"        output_device : torch.device\\n\",\n    \"            device for network outputs\\n\",\n    \"\\n\",\n    \"        Returns\\n\",\n    \"        -------\\n\",\n    \"        dict\\n\",\n    \"            dictionary containing data in correct type and shape and on correct\\n\",\n    \"            device\\n\",\n    \"\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        return_dict = {\\\"data\\\": torch.from_numpy(batch.pop(\\\"data\\\")).to(\\n\",\n    \"            input_device)}\\n\",\n    \"\\n\",\n    \"        for key, vals in batch.items():\\n\",\n    \"            return_dict[key] = torch.from_numpy(vals).to(output_device)\\n\",\n    \"\\n\",\n    \"        return 
return_dict\\n\",\n    \"\\n\",\n    \"```\\n\",\n    \"and can be customized by subclassing the `AbstractPyTorchNetwork`.\\n\",\n    \"\\n\",\n    \"#### `closure example`\\n\",\n    \"A simple closure function for a PyTorch module could look like this:\\n\",\n    \"```python\\n\",\n    \"    @staticmethod\\n\",\n    \"    def closure(model: AbstractPyTorchNetwork, data_dict: dict,\\n\",\n    \"                optimizers: dict, criterions={}, metrics={},\\n\",\n    \"                fold=0, **kwargs):\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"        closure method to do a single backpropagation step\\n\",\n    \"\\n\",\n    \"        Parameters\\n\",\n    \"        ----------\\n\",\n    \"        model : :class:`ClassificationNetworkBasePyTorch`\\n\",\n    \"            trainable model\\n\",\n    \"        data_dict : dict\\n\",\n    \"            dictionary containing the data\\n\",\n    \"        optimizers : dict\\n\",\n    \"            dictionary of optimizers to optimize model's parameters\\n\",\n    \"        criterions : dict\\n\",\n    \"            dict holding the criterions to calculate errors\\n\",\n    \"            (gradients from different criterions will be accumulated)\\n\",\n    \"        metrics : dict\\n\",\n    \"            dict holding the metrics to calculate\\n\",\n    \"        fold : int\\n\",\n    \"            Current Fold in Crossvalidation (default: 0)\\n\",\n    \"        **kwargs:\\n\",\n    \"            additional keyword arguments\\n\",\n    \"\\n\",\n    \"        Returns\\n\",\n    \"        -------\\n\",\n    \"        dict\\n\",\n    \"            Metric values (with same keys as input dict metrics)\\n\",\n    \"        dict\\n\",\n    \"            Loss values (with same keys as input dict criterions)\\n\",\n    \"        list\\n\",\n    \"            Arbitrary number of predictions as torch.Tensor\\n\",\n    \"\\n\",\n    \"        Raises\\n\",\n    \"        ------\\n\",\n    \"        AssertionError\\n\",\n    \"  
          if optimizers or criterions are empty or the optimizers are not\\n\",\n    \"            specified\\n\",\n    \"\\n\",\n    \"        \\\"\\\"\\\"\\n\",\n    \"\\n\",\n    \"        assert (optimizers and criterions) or not optimizers, \\\\\\n\",\n    \"            \\\"Criterion dict cannot be emtpy, if optimizers are passed\\\"\\n\",\n    \"\\n\",\n    \"        loss_vals = {}\\n\",\n    \"        metric_vals = {}\\n\",\n    \"        total_loss = 0\\n\",\n    \"\\n\",\n    \"        # choose suitable context manager:\\n\",\n    \"        if optimizers:\\n\",\n    \"            context_man = torch.enable_grad\\n\",\n    \"\\n\",\n    \"        else:\\n\",\n    \"            context_man = torch.no_grad\\n\",\n    \"\\n\",\n    \"        with context_man():\\n\",\n    \"\\n\",\n    \"            inputs = data_dict.pop(\\\"data\\\")\\n\",\n    \"            # obtain outputs from network\\n\",\n    \"            preds = model(inputs)[\\\"pred\\\"]\\n\",\n    \"\\n\",\n    \"            if data_dict:\\n\",\n    \"\\n\",\n    \"                for key, crit_fn in criterions.items():\\n\",\n    \"                    _loss_val = crit_fn(preds, *data_dict.values())\\n\",\n    \"                    loss_vals[key] = _loss_val.detach()\\n\",\n    \"                    total_loss += _loss_val\\n\",\n    \"\\n\",\n    \"                with torch.no_grad():\\n\",\n    \"                    for key, metric_fn in metrics.items():\\n\",\n    \"                        metric_vals[key] = metric_fn(\\n\",\n    \"                            preds, *data_dict.values())\\n\",\n    \"\\n\",\n    \"        if optimizers:\\n\",\n    \"            optimizers['default'].zero_grad()\\n\",\n    \"            total_loss.backward()\\n\",\n    \"            optimizers['default'].step()\\n\",\n    \"\\n\",\n    \"        else:\\n\",\n    \"\\n\",\n    \"            # add prefix \\\"val\\\" in validation mode\\n\",\n    \"            eval_loss_vals, eval_metrics_vals = {}, {}\\n\",\n    
\"            for key in loss_vals.keys():\\n\",\n    \"                eval_loss_vals[\\\"val_\\\" + str(key)] = loss_vals[key]\\n\",\n    \"\\n\",\n    \"            for key in metric_vals:\\n\",\n    \"                eval_metrics_vals[\\\"val_\\\" + str(key)] = metric_vals[key]\\n\",\n    \"\\n\",\n    \"            loss_vals = eval_loss_vals\\n\",\n    \"            metric_vals = eval_metrics_vals\\n\",\n    \"\\n\",\n    \"        for key, val in {**metric_vals, **loss_vals}.items():\\n\",\n    \"            logging.info({\\\"value\\\": {\\\"value\\\": val.item(), \\\"name\\\": key,\\n\",\n    \"                                    \\\"env_appendix\\\": \\\"_%02d\\\" % fold\\n\",\n    \"                                    }})\\n\",\n    \"\\n\",\n    \"        logging.info({'image_grid': {\\\"images\\\": inputs, \\\"name\\\": \\\"input_images\\\",\\n\",\n    \"                                     \\\"env_appendix\\\": \\\"_%02d\\\" % fold}})\\n\",\n    \"\\n\",\n    \"        return metric_vals, loss_vals, preds\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"> **Note:** This closure is taken from the `delira.models.classification.ClassificationNetworkBasePyTorch`\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"### Other examples\\n\",\n    \"In `delira.models` you can find exemplaric implementations of generative adversarial networks, classification and regression approaches or segmentation networks.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"## Training\\n\",\n    \"\\n\",\n    \"### Parameters\\n\",\n    \"Training-parameters (often called hyperparameters) can be defined in the `delira.training.Parameters` class. 
\\n\",\n    \"\\n\",\n    \"The class accepts the parameters `batch_size` and `num_epochs` to define the batchsize and the number of epochs to train, the parameters `optimizer_cls` and `optimizer_params` to create an optimizer or training, the parameter `criterions` to specify the training criterions (whose gradients will be accumulated by defaut), the parameters `lr_sched_cls` and `lr_sched_params` to define the learning rate scheduling and the parameter `metrics` to specify evaluation metrics.\\n\",\n    \"\\n\",\n    \"Additionally, it is possible to pass an aritrary number of keyword arguments to the class\\n\",\n    \"\\n\",\n    \"It is good practice to create a `Parameters` object at the beginning and then use it for creating other objects which are needed for training, since you can use the classes attributes and changes in hyperparameters only have to be done once:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"import torch\\n\",\n    \"from delira.training import Parameters\\n\",\n    \"from delira.data_loading import RandomSampler, SequentialSampler\\n\",\n    \"\\n\",\n    \"params = Parameters(fixed_params={\\n\",\n    \"    \\\"model\\\": {},\\n\",\n    \"    \\\"training\\\": {\\n\",\n    \"        \\\"batch_size\\\": 64, # batchsize to use\\n\",\n    \"        \\\"num_epochs\\\": 2, # number of epochs to train\\n\",\n    \"        \\\"optimizer_cls\\\": torch.optim.Adam, # optimization algorithm to use\\n\",\n    \"        \\\"optimizer_params\\\": {'lr': 1e-3}, # initialization parameters for this algorithm\\n\",\n    \"        \\\"criterions\\\": {\\\"CE\\\": torch.nn.CrossEntropyLoss()}, # the loss function\\n\",\n    \"        \\\"lr_sched_cls\\\": None,  # the learning rate scheduling algorithm to use\\n\",\n    \"        \\\"lr_sched_params\\\": {}, # the corresponding initialization parameters\\n\",\n    \"        
\\\"metrics\\\": {} # and some evaluation metrics\\n\",\n    \"    }\\n\",\n    \"}) \\n\",\n    \"\\n\",\n    \"# recreating the data managers with the batchsize of the params object\\n\",\n    \"manager_train = DataManager(dataset_train, params.nested_get(\\\"batch_size\\\"), 1,\\n\",\n    \"                                transforms=None, sampler_cls=RandomSampler,\\n\",\n    \"                                n_process_loading=4)\\n\",\n    \"manager_val = DataManager(dataset_val, params.nested_get(\\\"batch_size\\\"), 3,\\n\",\n    \"                              transforms=None, sampler_cls=SequentialSampler,\\n\",\n    \"                              n_process_loading=4)\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"### Trainer\\n\",\n    \"\\n\",\n    \"The `delira.training.NetworkTrainer` class provides functions to train a single network by passing attributes from your parameter object, a `save_freq` to specify how often your model should be saved (`save_freq=1` indicates every epoch, `save_freq=2` every second epoch etc.) and `gpu_ids`. If you don't pass any ids at all, your network will be trained on CPU (and probably take a lot of time). If you specify 1 id, the network will be trained on the GPU with the corresponding index and if you pass multiple `gpu_ids` your network will be trained on multiple GPUs in parallel.\\n\",\n    \"\\n\",\n    \"> **Note:** The GPU indices are refering to the devices listed in `CUDA_VISIBLE_DEVICES`. 
E.g if `CUDA_VISIBLE_DEVICES` lists GPUs 3, 4, 5 then gpu_id 0 will be the index for GPU 3 etc.\\n\",\n    \"\\n\",\n    \"> **Note:** training on multiple GPUs is not recommended for easy and small networks, since for these networks the synchronization overhead is far greater than the parallelization benefit.\\n\",\n    \"\\n\",\n    \"Training your network might look like this:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from delira.training import PyTorchNetworkTrainer\\n\",\n    \"from delira.models.classification import ClassificationNetworkBasePyTorch\\n\",\n    \"\\n\",\n    \"# path where checkpoints should be saved\\n\",\n    \"save_path = \\\"./results/checkpoints\\\"\\n\",\n    \"\\n\",\n    \"model = ClassificationNetworkBasePyTorch(in_channels=1, n_outputs=10)\\n\",\n    \"\\n\",\n    \"trainer = PyTorchNetworkTrainer(network=model,\\n\",\n    \"                                save_path=save_path,\\n\",\n    \"                                criterions=params.nested_get(\\\"criterions\\\"),\\n\",\n    \"                                optimizer_cls=params.nested_get(\\\"optimizer_cls\\\"),\\n\",\n    \"                                optimizer_params=params.nested_get(\\\"optimizer_params\\\"),\\n\",\n    \"                                metrics=params.nested_get(\\\"metrics\\\"),\\n\",\n    \"                                lr_scheduler_cls=params.nested_get(\\\"lr_sched_cls\\\"),\\n\",\n    \"                                lr_scheduler_params=params.nested_get(\\\"lr_sched_params\\\"),\\n\",\n    \"                                gpu_ids=[0]\\n\",\n    \"                        )\\n\",\n    \"\\n\",\n    \"#trainer.train(params.nested_get(\\\"num_epochs\\\"), manager_train, manager_val)\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"### 
Experiment\\n\",\n    \"The `delira.training.AbstractExperiment` class needs an experiment name, a path to save it's results to, a parameter object, a model class and the keyword arguments to create an instance of this class. It provides methods to perform a single training and also a method for running a kfold-cross validation. In order to create it, you must choose the `PyTorchExperiment`, which is basically just a subclass of the `AbstractExperiment` to provide a general setup for PyTorch modules. Running an experiment could look like this:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"from delira.training import PyTorchExperiment\\n\",\n    \"from delira.training.train_utils import create_optims_default_pytorch\\n\",\n    \"\\n\",\n    \"# Add model parameters to Parameter class\\n\",\n    \"params.fixed.model = {\\\"in_channels\\\": 1, \\\"n_outputs\\\": 10}\\n\",\n    \"\\n\",\n    \"experiment = PyTorchExperiment(params=params, \\n\",\n    \"                               model_cls=ClassificationNetworkBasePyTorch,\\n\",\n    \"                               name=\\\"TestExperiment\\\", \\n\",\n    \"                               save_path=\\\"./results\\\",\\n\",\n    \"                               optim_builder=create_optims_default_pytorch,\\n\",\n    \"                               gpu_ids=[0])\\n\",\n    \"\\n\",\n    \"experiment.run(manager_train, manager_val)\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"An `Experiment` is the most abstract (and recommended) way to define, train and validate your network.\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"## Logging\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   
\"source\": [\n    \"Previous class and function definitions used pythons's `logging` library. As extensions for this library `delira` provides a package (`delira.logging`) containing handlers to realize different logging methods. \\n\",\n    \"\\n\",\n    \"To use these handlers simply add them to your logger like this:\\n\",\n    \"```python\\n\",\n    \"logger.addHandler(logging.StreamHandler())\\n\",\n    \"```\\n\",\n    \"\\n\",\n    \"Nowadays, delira mainly relies on [trixi](https://github.com/MIC-DKFZ/trixi/) for logging and provides only a `MultiStreamHandler` and a `TrixiHandler`, which is a binding to `trixi`'s loggers and integrates them into the python `logging` module\\n\",\n    \"\\n\",\n    \"### `MultiStreamHandler`\\n\",\n    \"The `MultiStreamHandler` accepts an arbitrary number of streams during initialization and writes the message to all of it's streams during logging.\\n\",\n    \"\\n\",\n    \"### Logging with `Visdom` - The `trixi` Loggers\\n\",\n    \"[`Visdom`](https://github.com/facebookresearch/visdom) is a tool designed to visualize your logs. To use this tool you need to open a port on the machine you want to train on via `visdom -port YOUR_PORTNUMBER` Afterwards just add the handler of your choice to the logger. 
For more detailed information and customization have a look at [this](https://github.com/facebookresearch/visdom) website.\\n\",\n    \"\\n\",\n    \"Logging the scalar tensors containing `1`, `2`, `3`, `4` (at the beginning; will increase to show epochwise logging) with the corresponding keys `\\\"one\\\"`, `\\\"two\\\"`, `\\\"three\\\"`, `\\\"four\\\"` and two random images with the keys `\\\"prediction\\\"` and `\\\"groundtruth\\\"` would look like this:\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"outputs\": [],\n   \"source\": [\n    \"NUM_ITERS = 4\\n\",\n    \"\\n\",\n    \"# import logging handler and logging module\\n\",\n    \"from delira.logging import TrixiHandler\\n\",\n    \"from trixi.logger import PytorchVisdomLogger\\n\",\n    \"import logging\\n\",\n    \"\\n\",\n    \"# configure logging module (and root logger)\\n\",\n    \"logger_kwargs = {\\n\",\n    \"    'name': 'test_env', # name of loggin environment\\n\",\n    \"    'port': 9999 # visdom port to connect to\\n\",\n    \"}\\n\",\n    \"logger_cls = PytorchVisdomLogger\\n\",\n    \"\\n\",\n    \"# configure logging module (and root logger)\\n\",\n    \"logging.basicConfig(level=logging.INFO,\\n\",\n    \"                    handlers=[TrixiHandler(logger_cls, **logger_kwargs)])\\n\",\n    \"# derive logger from root logger\\n\",\n    \"# (don't do `logger = logging.Logger(\\\"...\\\")` since this will create a new\\n\",\n    \"# logger which is unrelated to the root logger\\n\",\n    \"logger = logging.getLogger(\\\"Test Logger\\\")\\n\",\n    \"\\n\",\n    \"# create dict containing the scalar numbers as torch.Tensor\\n\",\n    \"scalars = {\\\"one\\\": torch.Tensor([1]),\\n\",\n    \"           \\\"two\\\": torch.Tensor([2]),\\n\",\n    \"           \\\"three\\\": torch.Tensor([3]),\\n\",\n    \"           \\\"four\\\": torch.Tensor([4])}\\n\",\n    \"\\n\",\n    \"# create dict containing the images as 
torch.Tensor\\n\",\n    \"# pytorch awaits tensor dimensionality of \\n\",\n    \"# batchsize x image channels x height x width\\n\",\n    \"images = {\\\"prediction\\\": torch.rand(1, 3, 224, 224),\\n\",\n    \"          \\\"groundtruth\\\": torch.rand(1, 3, 224, 224)}\\n\",\n    \"\\n\",\n    \"# Simulate 4 Epochs\\n\",\n    \"for i in range(4*NUM_ITERS): \\n\",\n    \"    logger.info({\\\"image_grid\\\": {\\\"images\\\": images[\\\"prediction\\\"], \\\"name\\\": \\\"predictions\\\"}})\\n\",\n    \"    \\n\",\n    \"    for key, val_tensor in scalars.items():\\n\",\n    \"        logger.info({\\\"value\\\": {\\\"value\\\": val_tensor.item(), \\\"name\\\": key}})\\n\",\n    \"        scalars[key] += 1\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"## More Examples\"\n   ]\n  },\n  {\n   \"cell_type\": \"markdown\",\n   \"metadata\": {\n    \"pycharm\": {}\n   },\n   \"source\": [\n    \"More Examples can be found in \\n\",\n    \"* [the classification example](classification_pytorch.ipynb, \\\"Classification\\\")\\n\",\n    \"* [the 2d segmentation example](segmentation_2d_pytorch.ipynb, \\\"Segmentation 2D\\\")\\n\",\n    \"* [the 3d segmentation example](segmentation_3d_pytorch.ipynb, \\\"Segmentation 3D\\\")\\n\",\n    \"* [the generative adversarial example](gan_pytorch.ipynb, \\\"GAN\\\")\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"anaconda-cloud\": {},\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.7.3\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "paper/paper.bib",
    "content": "@online{batchgenerators,\n  author = {MIC-DKFZ},\n  title = {batchgenerators},\n  year = 2019,\n  url = {https://github.com/MIC-DKFZ/batchgenerators},\n  urldate = {2019-05-17}\n}\n\n@inproceedings{tensorflow,\n  title={TensorFlow: A system for large-scale machine learning},\n  author={Abadi, Mart{\\'\\i}n and Barham, Paul and Chen, Jianmin and Chen, Zhifeng and Davis, Andy and Dean, Jeffrey and Devin, Matthieu and Ghemawat, Sanjay and Irving, Geoffrey and Isard, Michael and others},\n  booktitle={12th {USENIX} Symposium on Operating Systems Design and Implementation ({OSDI} 16)},\n  pages={265--283},\n  year={2016}\n}\n\n@inproceedings{pytorch,\n  title={Automatic differentiation in PyTorch},\n  author={Paszke, Adam and Gross, Sam and Chintala, Soumith and Chanan, Gregory and Yang, Edward and DeVito, Zachary and Lin, Zeming and Desmaison, Alban and Antiga, Luca and Lerer, Adam},\n  booktitle={NIPS 2017 Autodiff Workshop},\n  year={2017}\n}\n\n@inproceedings{gan,\n  title = {Generative Adversarial Nets},\n  author = {Goodfellow, Ian and Pouget-Abadie, Jean and Mirza, Mehdi and Xu, Bing and Warde-Farley, David and Ozair, Sherjil and Courville, Aaron and Bengio, Yoshua},\n  booktitle = {Advances in Neural Information Processing Systems 27},\n  editor = {Z. Ghahramani and M. Welling and C. Cortes and N. D. Lawrence and K. Q. Weinberger},\n  pages = {2672--2680},\n  year = {2014},\n  publisher = {Curran Associates, Inc.},\n  url = {http://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf}\n}\n"
  },
  {
    "path": "paper/paper.md",
    "content": "---\ntitle: 'Delira: A High-Level Framework for Deep Learning in Medical Image Analysis'\ntags:\n  - python\n  - deep learning\n  - medical image analysis\n  - pytorch\n  - tensorflow\nauthors:\n - name: Christoph Haarburger\n   affiliation: \"1\"\n - name: Justus Schock\n   affiliation: \"1\"\n - name: Michael Baumgartner\n   affiliation: \"1\"\n - name: Oliver Rippel\n   affiliation: \"1\"\n - name: Dorit Merhof\n   affiliation: \"1\"\naffiliations:\n - name: Institute of Imaging and Computer Vision, RWTH Aachen University, Germany\n   index: 1\ndate: 17 May 2019\nbibliography: paper.bib\n---\n\n# Summary\n\nMedical image analysis research using deep neural networks often involves the development of problem-specific network architectures and the evaluation of models on several datasets.\nContemporary deep learning frameworks such as PyTorch [@pytorch] and TensorFlow [@tensorflow], however, operate on a low level, such that comparing different models on several datasets requires a lot of boilerplate code.\nSo far, this boilerplate code is often copied and pasted for new projects and experiments.\nReference implementations of new methods may be implemented in either PyTorch or TensorFlow, leading to a lot of friction when comparing two methods that are implemented in different low-level frameworks.\nMoreover, data augmentation for 3D medical images, e.g. from computed tomography or magnetic resonance imaging, is not natively supported by many low-level frameworks.\nAs a result, stand-alone data augmentation solutions are often applied [@batchgenerators].\n\nIn order to integrate high-level functionality such as logging, data structures for image datasets, data augmentation, trainer classes and model saving and loading in a way that is agnostic with respect to the low-level framework, we developed ``Delira`` (Deep Learning in Radiology).\n\n``Delira`` consists of several subpackages and modules that are structured into ``data_loading``, ``io``, ``logging``, ``models``, ``training`` and ``utils``.\nThis modular structure enables the reuse of datasets and data loading pipelines across different models.\nMoreover, reference models for classification, segmentation and data synthesis problems using generative adversarial networks [@gan] are provided in the ``models`` subpackage.\n\nTraining is carried out by a ``NetworkTrainer`` class that implements the training routine for a given dataset and model.\nAn ``Experiment`` class runs the training using ``NetworkTrainer``, e.g. in a cross-validation scheme.\nA quick tutorial showing how the most important data structures interact with each other, as well as HTML documentation, is provided at https://delira.readthedocs.io/en/master/classification_pytorch.html.\n\nCurrently, PyTorch and TensorFlow backends are supported and tested.\nAdding more backends is easily possible if needed.\n\n``Delira`` is released under the BSD 2-Clause license.\nThe source code can be found at https://github.com/justusschock/delira.\n\n# References\n"
  },
  {
    "path": "pytest.ini",
    "content": "[pytest]\ntestpaths = tests\naddopts = --cov=delira\npython_files = *.py\n"
  },
  {
    "path": "requirements/base.txt",
    "content": "numpy>=1.15.0\nscikit-learn>=0.20.0\njupyter>=1.0.0\nipython\njoblib\npylint\ntqdm\nvisdom>=0.1.8.5\npyyaml\nbatchgenerators>=0.18.2,!=0.19.2,<0.19.4\ntensorboardX\nnested_lookup\n"
  },
  {
    "path": "requirements/chainer.txt",
    "content": "chainer >= 6.0.0\nh5py\n"
  },
  {
    "path": "requirements/tensorflow.txt",
    "content": "tensorflow-gpu==1.14\n"
  },
  {
    "path": "requirements/torch.txt",
    "content": "torchvision>=0.2.1\ntorch>=1.0.0\n"
  },
  {
    "path": "scripts/ci/build_docs.sh",
    "content": "#!/usr/bin/env bash\n\ncd ./docs;\n# build several times so that sphinx can resolve all cross-references\nmake html;\nmake html;\nmake html;\ntouch _build/html/.nojekyll;\n"
  },
  {
    "path": "scripts/ci/install_before_docs.sh",
    "content": "#!/usr/bin/env bash\npip install -r docs/requirements.txt;\n"
  },
  {
    "path": "scripts/ci/install_before_style_check.sh",
    "content": "#!/usr/bin/env bash\n\npip install pycodestyle;\npip install autopep8;\n"
  },
  {
    "path": "scripts/ci/install_before_tests.sh",
    "content": "#!/usr/bin/env bash\n\npip install -U pip wheel;\npip install -r requirements/base.txt;\n\nif [[ \"$BACKEND\" == \"TFEager\" ]]; then\n    pip install -r requirements/tensorflow.txt\n    pip uninstall -y tensorflow-gpu;\n    pip install tensorflow==1.14;\nelif [[ \"$BACKEND\" == \"TFGraph\" ]]; then\n    pip install -r requirements/tensorflow.txt\n    pip uninstall -y tensorflow-gpu;\n    pip install tensorflow==1.14;\nelif [[ \"$BACKEND\" == \"Torch\" ]]; then\n    pip install -r requirements/torch.txt\nelif [[ \"$BACKEND\" == \"TorchScript\" ]]; then\n    pip install -r requirements/torch.txt\nelif [[ \"$BACKEND\" == \"Chainer\" ]]; then\n    pip install -r requirements/chainer.txt\nelse\n    pip install slackclient==1.3.1\nfi\n\npip install coverage;\npip install codecov;\n"
  },
  {
    "path": "scripts/ci/run_style_checks.sh",
    "content": "#!/usr/bin/env bash\n\n# based on https://gist.github.com/MichaelCurrie/802ce28c993ff2dd632c\n\n# find pep8 errors; ignore E402 (module level import not at top of file) due to logging\nnum_errors_before=`find . -name \\*.py -exec pycodestyle --ignore=E402 {} + | wc -l`;\necho $num_errors_before;\n\ncd \"$TRAVIS_BUILD_DIR\";\n# try with combination of maintainer email and github token\ngit config user.name \"Travis AutoPEP8 Fixes\";\ngit checkout $TRAVIS_BRANCH;\n\n# fix pep8 errors in place if possible\nfind . -name \\*.py -exec autopep8 --recursive --aggressive --aggressive --in-place --exclude *conf.py {} +;\nnum_errors_after=`find . -name \\*.py -exec pycodestyle --ignore=E402 {} + | wc -l`;\necho $num_errors_after;\n\nif (( $num_errors_after < $num_errors_before )); then\n    git commit -a -m \"PEP-8 Auto-Fix\";\n    git config --global push.default simple; # Push only to the current branch.  \n    # Make sure to make the output quiet, or else the API token will \n    # leak!  This works because the API key can replace your password.\n    git push https://$GITHUB_TOKEN@github.com/delira-dev/delira.git;\nfi\n\ncd \"$TRAVIS_BUILD_DIR\";\n# List remaining errors, which have to be fixed manually\nfind . -name \\*.py -exec pycodestyle --ignore=E402 {} +;\n"
  },
  {
    "path": "scripts/ci/run_tests.sh",
    "content": "#!/usr/bin/env bash\n\ncoverage run -m unittest\n"
  },
  {
    "path": "setup.cfg",
    "content": "[pycodestyle]\nexclude = .eggs,*.egg,build,docs/*,.git,versioneer.py,*/conf.py\nignore = E721\n\n[versioneer]\nVCS = git\nstyle = pep440\nversionfile_source = delira/_version.py\nversionfile_build = delira/_version.py\ntag_prefix = v\nparentdir_prefix = \n"
  },
  {
    "path": "setup.py",
    "content": "import os\nfrom setuptools import find_packages, setup\nimport versioneer\n\n\ndef resolve_requirements(file):\n    if not os.path.isfile(file):\n        file = os.path.join(os.path.dirname(__file__), \"requirements\", file)\n    requirements = []\n    with open(file) as f:\n        req = f.read().splitlines()\n        for r in req:\n            if r.startswith(\"-r\"):\n                requirements += resolve_requirements(\n                    os.path.join(os.path.dirname(file), r.split(\" \")[1]))\n            else:\n                requirements.append(r)\n    return requirements\n\n\ndef read_file(file):\n    with open(file) as f:\n        content = f.read()\n    return content\n\n\ndef unify_requirements(base_requirements: list, *additional_requirement_lists):\n    for reqs in additional_requirement_lists:\n        for req in reqs:\n            if req not in base_requirements:\n                base_requirements.append(req)\n\n    return base_requirements\n\n\ndef parse_all_requirements(backend_requirement_dict: dict):\n    backend_requirements = {\"full\": []}\n\n    # parse all requirements\n    for backend_name, requirement_file in backend_requirement_dict.items():\n        _reqs = resolve_requirements(requirement_file)\n        backend_requirements[backend_name] = _reqs\n\n        # add all requirements to full if not already part of it\n        backend_requirements[\"full\"] = unify_requirements(\n            backend_requirements[\"full\"], _reqs)\n\n    # for each backend: drop requirements already contained in the base\n    # requirements (rebuild the list instead of popping while iterating,\n    # which would skip elements)\n    for backend_name, reqs in backend_requirements.items():\n        if backend_name == \"base\":\n            continue\n\n        reqs = [_req for _req in reqs\n                if _req not in backend_requirements[\"base\"]]\n\n        backend_requirements[backend_name] = reqs\n\n    return backend_requirements\n\n\nrequirement_files = {\n    \"base\": \"base.txt\",\n    \"sklearn\": 
\"base.txt\",  # no extra requirements necessary\n    \"torch\": \"torch.txt\",\n    \"torchscript\": \"torch.txt\",\n    \"tensorflow\": \"tensorflow.txt\",\n    \"tensorflow_eager\": \"tensorflow.txt\",\n    \"chainer\": \"chainer.txt\"\n}\n\n\nrequirement_dict = parse_all_requirements(requirement_files)\n\nreadme = read_file(os.path.join(os.path.dirname(__file__), \"README.md\"))\n\nsetup(\n    name='delira',\n    version=versioneer.get_version(),\n    cmdclass=versioneer.get_cmdclass(),\n    packages=find_packages(),\n    url='https://github.com/delira-dev/delira/',\n    test_suite=\"unittest\",\n    long_description=readme,\n    long_description_content_type='text/markdown',\n    maintainer=\"Justus Schock\",\n    maintainer_email=\"justus.schock@rwth-aachen.de\",\n    license='BSD-2',\n    install_requires=requirement_dict.pop(\"base\"),\n    tests_require=[\"coverage\"],\n    python_requires=\">=3.5\",\n    extras_require=requirement_dict\n)\n"
  },
  {
    "path": "tests/__init__.py",
    "content": ""
  },
  {
    "path": "tests/data_loading/__init__.py",
    "content": ""
  },
  {
    "path": "tests/data_loading/test_augmenters.py",
    "content": "from delira.data_loading import Augmenter, DataLoader, SequentialSampler, \\\n    AbstractDataset\nimport numpy as np\nfrom .utils import DummyDataset\nfrom ..utils import check_for_no_backend\n\nimport unittest\n\n\nclass TestAugmenters(unittest.TestCase):\n    def setUp(self) -> None:\n        self._dset_len = 500\n        self._batchsize = 3\n\n        if \"drop_last\" in self._testMethodName:\n            self._drop_last = True\n        else:\n            self._drop_last = False\n\n        dataset = DummyDataset(self._dset_len)\n        data_loader = DataLoader(dataset)\n        sampler = SequentialSampler.from_dataset(dataset)\n\n        if \"parallel\" in self._testMethodName:\n            self.aug = Augmenter(data_loader, self._batchsize, sampler, 2,\n                                 drop_last=self._drop_last)\n        else:\n            self.aug = Augmenter(data_loader, self._batchsize, sampler, 0,\n                                 drop_last=self._drop_last)\n\n    def _aug_test(self):\n\n        num_batches = self._dset_len // self._batchsize\n        if not self._drop_last:\n            num_batches += int(bool(self._dset_len % self._batchsize))\n\n        last_idx = 0\n\n        for batch in self.aug:\n            self.assertIsInstance(batch, dict)\n\n            for v in batch.values():\n                # check batchsize for all batches except the last\n                # (which can be smaller)\n                if self._drop_last or last_idx < num_batches - 1:\n                    self.assertEqual(len(v), self._batchsize)\n                else:\n                    self.assertLess(len(v), self._batchsize)\n\n            last_idx += 1\n\n        self.assertEqual(last_idx, num_batches)\n\n    # multiple test functions running the same test with different\n    # configurations. Must be done in different functions, because\n    # configurations are switched based on the function name\n    @unittest.skipUnless(check_for_no_backend(),\n                         \"Test should be only executed if no \"\n                         \"backend was installed\")\n    def test_parallel(self):\n        self._aug_test()\n\n    @unittest.skipUnless(check_for_no_backend(),\n                         \"Test should be only executed if no \"\n                         \"backend was installed\")\n    def test_parallel_drop_last(self):\n        self._aug_test()\n\n    @unittest.skipUnless(check_for_no_backend(),\n                         \"Test should be only executed if no \"\n                         \"backend was installed\")\n    def test_sequential(self):\n        self._aug_test()\n\n    @unittest.skipUnless(check_for_no_backend(),\n                         \"Test should be only executed if no \"\n                         \"backend was installed\")\n    def test_sequential_drop_last(self):\n        self._aug_test()\n\n    def _test_sampler_indices(self, parallel: bool):\n        class Dataset(AbstractDataset):\n            def __init__(self):\n                super().__init__(None, None)\n\n                self.data = []\n\n                for i in range(50):\n                    self.data.append({\"data\": i})\n\n            def __getitem__(self, item):\n                return self.data[item]\n\n            def __len__(self):\n                return 50\n\n        dataset = Dataset()\n\n        data_loader = DataLoader(dataset)\n        sampler = SequentialSampler.from_dataset(dataset)\n\n        if parallel:\n            aug = Augmenter(data_loader, 1, sampler, 2,\n                            drop_last=False)\n        else:\n            aug = Augmenter(data_loader, 1, sampler, 0,\n                            drop_last=False)\n\n        for idx, batch in enumerate(aug):\n            self.assertEqual(batch[\"data\"].item(), idx)\n\n    @unittest.skipUnless(check_for_no_backend(),\n                         \"Test should be only executed if no \"\n                         \"backend was installed\")\n    def test_sampling_order_parallel(self):\n        self._test_sampler_indices(True)\n\n    @unittest.skipUnless(check_for_no_backend(),\n                         \"Test should be only executed if no \"\n                         \"backend was installed\")\n    def test_sampling_order_sequential(self):\n        self._test_sampler_indices(False)\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/data_loading/test_data_loader.py",
    "content": "import unittest\nfrom delira.data_loading import DataLoader, SequentialSampler, BatchSampler\nfrom .utils import DummyDataset\nimport numpy as np\nfrom ..utils import check_for_no_backend\n\n\nclass DataLoaderTest(unittest.TestCase):\n\n    def _test_data_loader(self, data):\n        loader = DataLoader(data)\n        sampler = SequentialSampler.from_dataset(loader.dataset)\n\n        batch_sampler = BatchSampler(sampler, 16)\n        sampler_iter = iter(batch_sampler)\n\n        self.assertIsInstance(loader(next(sampler_iter)), dict)\n\n        for key, val in loader(next(sampler_iter)).items():\n            self.assertEqual(len(val), 16)\n\n        self.assertIn(\"label\", loader(next(sampler_iter)))\n        self.assertIn(\"data\", loader(next(sampler_iter)))\n\n        self.assertEqual(loader.process_id, 0)\n        loader.process_id = 456\n        self.assertEqual(loader.process_id, 456)\n        with self.assertRaises(AttributeError):\n            loader.process_id = 123\n\n    @unittest.skipUnless(check_for_no_backend(),\n                         \"Test should be only executed if no \"\n                         \"backend was installed\")\n    def test_data_loader_dset(self):\n        dset = DummyDataset(600, [0.5, 0.3, 0.2])\n        self._test_data_loader(dset)\n\n    @unittest.skipUnless(check_for_no_backend(),\n                         \"Test should be only executed if no \"\n                         \"backend was installed\")\n    def test_data_loader_dict(self):\n        data = {\"label\": np.random.rand(600),\n                \"data\": np.random.rand(600, 1, 3, 3)}\n        self._test_data_loader(data)\n\n    @unittest.skipUnless(check_for_no_backend(),\n                         \"Test should be only executed if no \"\n                         \"backend was installed\")\n    def test_data_loader_iterable(self):\n        data = [{\"label\": np.random.rand(1), \"data\": np.random.rand(1, 3, 3)}\n                for i in range(600)]\n        self._test_data_loader(data)\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/data_loading/test_data_manager.py",
    "content": "import unittest\n\nimport numpy as np\n\nfrom delira.data_loading import DataManager\n\nfrom delira.data_loading.data_manager import Augmenter\nfrom ..utils import check_for_no_backend\nfrom .utils import DummyDataset\n\n\nclass DataManagerTest(unittest.TestCase):\n\n    @unittest.skipUnless(check_for_no_backend(),\n                         \"Test should be only executed if no \"\n                         \"backend was installed\")\n    def test_datamanager(self):\n\n        batch_size = 16\n\n        np.random.seed(1)\n        dset = DummyDataset(600, [0.5, 0.3, 0.2])\n\n        manager = DataManager(dset, batch_size, n_process_augmentation=0,\n                              transforms=None)\n\n        self.assertIsInstance(manager.get_batchgen(), Augmenter)\n\n        # create batch manually\n        data, labels = [], []\n        for i in range(batch_size):\n            data.append(dset[i][\"data\"])\n            labels.append(dset[i][\"label\"])\n\n        batch_dict = {\"data\": np.asarray(data), \"label\": np.asarray(labels)}\n\n        augmenter = manager.get_batchgen()\n        augmenter_iter = iter(augmenter)\n        for key, val in next(augmenter_iter).items():\n            self.assertTrue((val == batch_dict[key]).all())\n\n        for key, val in next(augmenter_iter).items():\n            self.assertEqual(len(val), batch_size)\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/data_loading/test_dataset.py",
    "content": "import unittest\n\nimport numpy as np\n\nfrom delira.data_loading import ConcatDataset, BaseCacheDataset, \\\n    BaseExtendCacheDataset, BaseLazyDataset, LoadSample, LoadSampleLabel\nfrom delira.data_loading.load_utils import norm_zero_mean_unit_std\n\nfrom ..utils import check_for_no_backend\n\n\nclass DataSubsetConcatTest(unittest.TestCase):\n\n    @staticmethod\n    def load_dummy_sample(path, label_load_fct):\n        \"\"\"\n        Returns dummy data, independent of path or label_load_fct\n        Parameters\n        ----------\n        path\n        label_load_fct\n        Returns\n        -------\n        : dict\n            dict with data and label\n        \"\"\"\n\n        return {'data': np.random.rand(1, 256, 256),\n                'label': np.random.randint(2)}\n\n    @unittest.skipUnless(check_for_no_backend(),\n                         \"Test should be only executed if no \"\n                         \"backend was installed\")\n    def test_data_subset_concat(self):\n\n        class DummyCacheDataset(BaseCacheDataset):\n            def __init__(self, num: int, label_load_fct, *args, **kwargs):\n                \"\"\"\n                Generates random samples with _make_dataset\n                Parameters\n                ----------\n                num : int\n                    number of random samples\n                args :\n                    passed to BaseCacheDataset\n                kwargs :\n                    passed to BaseCacheDataset\n\n                \"\"\"\n                self.label_load_fct = label_load_fct\n                super().__init__(data_path=num, *args, **kwargs)\n\n            def _make_dataset(self, path):\n                data = []\n                for i in range(path):\n                    data.append(self._load_fn(i, self.label_load_fct))\n                return data\n\n        dset_a = DummyCacheDataset(500, None, load_fn=self.load_dummy_sample,\n                                   img_extensions=[], 
gt_extensions=[])\n        dset_b = DummyCacheDataset(700, None, load_fn=self.load_dummy_sample,\n                                   img_extensions=[], gt_extensions=[])\n\n        # test concatenating\n        concat_dataset = ConcatDataset(dset_a, dset_b)\n\n        self.assertEqual(len(concat_dataset), len(dset_a) + len(dset_b))\n\n        self.assertTrue(concat_dataset[0])\n\n        # test slicing:\n        half_len_a = len(dset_a) // 2\n        half_len_b = len(dset_b) // 2\n\n        self.assertEqual(len(dset_a.get_subset(range(half_len_a))), half_len_a)\n        self.assertEqual(len(dset_b.get_subset(range(half_len_b))), half_len_b)\n\n        sliced_concat_set = concat_dataset.get_subset(\n            range(half_len_a + half_len_b))\n\n        self.assertEqual(len(sliced_concat_set), half_len_a + half_len_b)\n\n        # check if entries are valid\n        self.assertTrue(sliced_concat_set[0])\n\n    @unittest.skipUnless(check_for_no_backend(),\n                         \"Test should be only executed if no \"\n                         \"backend was installed\")\n    def test_cache_dataset(self):\n        def load_mul_sample(path):\n            \"\"\"\n            Return a list of random samples\n            Parameters\n            ----------\n            path\n\n            Returns\n            -------\n            list\n                list of samples\n            \"\"\"\n            return [self.load_dummy_sample(path, None)] * 4\n\n        # test normal cache dataset\n        paths = list(range(10))\n        dataset = BaseCacheDataset(paths, self.load_dummy_sample,\n                                   label_load_fct=None)\n        assert len(dataset) == 10\n        try:\n            a = dataset[0]\n            a = dataset[5]\n            a = dataset[9]\n        except BaseException:\n            raise AssertionError('Dataset access failed.')\n\n        try:\n            j = 0\n            for i in dataset:\n                assert 'data' in i\n            
    assert 'label' in i\n                j += 1\n            assert j == len(dataset)\n        except BaseException:\n            raise AssertionError('Dataset iteration failed.')\n\n        # test extend cache dataset\n        dataset = BaseExtendCacheDataset(paths, load_mul_sample)\n        assert len(dataset) == 40\n        try:\n            a = dataset[0]\n            a = dataset[20]\n            a = dataset[39]\n        except BaseException:\n            raise AssertionError('Dataset access failed.')\n\n        try:\n            j = 0\n            for i in dataset:\n                assert 'data' in i\n                assert 'label' in i\n                j += 1\n            assert j == len(dataset)\n        except BaseException:\n            raise AssertionError('Dataset iteration failed.')\n\n    @unittest.skipUnless(check_for_no_backend(),\n                         \"Test should be only executed if no \"\n                         \"backend was installed\")\n    def test_lazy_dataset(self):\n        # test lazy dataset\n        paths = list(range(10))\n        dataset = BaseLazyDataset(paths, self.load_dummy_sample,\n                                  label_load_fct=None)\n        assert len(dataset) == 10\n        try:\n            a = dataset[0]\n            a = dataset[5]\n            a = dataset[9]\n        except BaseException:\n            raise AssertionError('Dataset access failed.')\n\n        try:\n            j = 0\n            for i in dataset:\n                assert 'data' in i\n                assert 'label' in i\n                j += 1\n            assert j == len(dataset)\n        except BaseException:\n            raise AssertionError('Dataset iteration failed.')\n\n    @unittest.skipUnless(check_for_no_backend(),\n                         \"Test should be only executed if no \"\n                         \"backend was installed\")\n    def test_load_sample(self):\n        def load_dummy_label(path):\n            return {'label': 42}\n\n        
def load_dummy_data(path):\n            return np.random.rand(1, 256, 256) * np.random.randint(2, 20) + \\\n                np.random.randint(20)\n\n        # check loading of a single sample\n        sample_fn = LoadSample({'data': ['data', 'data', 'data'],\n                                'seg': ['data'],\n                                'data2': ['data', 'data', 'data']},\n                               load_dummy_data,\n                               dtype={'seg': 'uint8'},\n                               normalize=('data2',))\n        sample = sample_fn('load')\n        assert not np.isclose(np.mean(sample['data']), 0)\n        assert not np.isclose(np.mean(sample['seg']), 0)\n        assert sample['seg'].dtype == 'uint8'\n        assert np.isclose(sample['data2'].max(), 1)\n        assert np.isclose(sample['data2'].min(), -1)\n\n        # check different normalization function\n        sample_fn = LoadSample({'data': ['data', 'data', 'data']},\n                               load_dummy_data,\n                               normalize=('data',),\n                               norm_fn=norm_zero_mean_unit_std)\n        sample = sample_fn('load')\n        assert np.isclose(np.mean(sample['data']), 0)\n        assert np.isclose(np.std(sample['data']), 1)\n\n        # check label and loading of single sample\n        sample_fn = LoadSampleLabel(\n            {'data': ['data', 'data', 'data'], 'seg': ['data'],\n             'data2': ['data', 'data', 'data']}, load_dummy_data,\n            'label', load_dummy_label,\n            dtype={'seg': 'uint8'}, normalize=('data2',))\n        sample = sample_fn('load')\n        assert not np.isclose(np.mean(sample['data']), 0)\n        assert not np.isclose(np.mean(sample['seg']), 0)\n        assert sample['seg'].dtype == 'uint8'\n        assert np.isclose(sample['data2'].max(), 1)\n        assert np.isclose(sample['data2'].min(), -1)\n        assert sample['label'] == 42\n\n\nif __name__ == \"__main__\":\n    
unittest.main()\n"
  },
  {
    "path": "tests/data_loading/test_numba_transforms.py",
    "content": "import unittest\n\nfrom batchgenerators.transforms import ZoomTransform, PadTransform, Compose\nimport numpy as np\nfrom ..utils import check_for_no_backend\n\ntry:\n    import numba\nexcept ImportError:\n    numba = None\n\n\nclass NumbaTest(unittest.TestCase):\n    def setUp(self) -> None:\n        from delira.data_loading.numba_transform import NumbaTransform, \\\n            NumbaCompose\n        self._basic_zoom_trafo = ZoomTransform(3)\n        self._numba_zoom_trafo = NumbaTransform(ZoomTransform, zoom_factors=3)\n        self._basic_pad_trafo = PadTransform(new_size=(30, 30))\n        self._numba_pad_trafo = NumbaTransform(PadTransform,\n                                               new_size=(30, 30))\n\n        self._basic_compose_trafo = Compose([self._basic_pad_trafo,\n                                             self._basic_zoom_trafo])\n        self._numba_compose_trafo = NumbaCompose([self._basic_pad_trafo,\n                                                  self._basic_zoom_trafo])\n\n        self._input = {\"data\": np.random.rand(10, 1, 24, 24)}\n\n    def compare_transform_outputs(self, transform, numba_transform):\n        output_normal = transform(**self._input)[\"data\"]\n        output_numba = numba_transform(**self._input)[\"data\"]\n\n        # only check for same shapes, since numba might apply slightly\n        # different interpolations\n        self.assertTupleEqual(output_normal.shape, output_numba.shape)\n\n    @unittest.skipIf(numba is None, \"Numba must be imported successfully\")\n    @unittest.skipUnless(check_for_no_backend(),\n                         \"Test should be only executed if no \"\n                         \"backend was installed\")\n    def test_zoom(self):\n        self.compare_transform_outputs(self._basic_zoom_trafo,\n                                       self._numba_zoom_trafo)\n\n    @unittest.skipIf(numba is None, \"Numba must be imported successfully\")\n    
@unittest.skipUnless(check_for_no_backend(),\n                         \"Test should be only executed if no \"\n                         \"backend was installed\")\n    def test_pad(self):\n        self.compare_transform_outputs(self._basic_pad_trafo,\n                                       self._numba_pad_trafo)\n\n    @unittest.skipIf(numba is None, \"Numba must be imported successfully\")\n    @unittest.skipUnless(check_for_no_backend(),\n                         \"Test should be only executed if no \"\n                         \"backend was installed\")\n    def test_compose(self):\n        self.compare_transform_outputs(self._basic_compose_trafo,\n                                       self._numba_compose_trafo)\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/data_loading/test_sampler.py",
    "content": "import unittest\nimport numpy as np\nfrom delira.data_loading.sampler import RandomSamplerWithReplacement, \\\n    PrevalenceRandomSampler, SequentialSampler, \\\n    RandomSamplerNoReplacement, BatchSampler, AbstractSampler\n\nfrom ..utils import check_for_no_backend\nfrom .utils import DummyDataset\n\n\nclass SamplerTest(unittest.TestCase):\n    def setUp(self) -> None:\n        self.dset = DummyDataset(600, [0.5, 0.3, 0.2])\n\n    @unittest.skipUnless(check_for_no_backend(),\n                         \"Test should only be executed \"\n                         \"if no backend is installed/specified\")\n    def test_batch_sampler(self):\n        for batchsize in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]:\n            for truncate in [True, False]:\n\n                with self.subTest(batchsize=batchsize, truncate=truncate):\n                    sampler = BatchSampler(\n                        SequentialSampler.from_dataset(self.dset),\n                        batchsize, truncate)\n\n                    sampler_iter = iter(sampler)\n                    for i in range(len(sampler)):\n                        batch = next(sampler_iter)\n\n                        if i < len(sampler) - 1:\n                            self.assertEqual(len(batch), batchsize)\n\n                        else:\n                            if truncate:\n                                self.assertLessEqual(len(batch), batchsize)\n\n    @unittest.skipUnless(check_for_no_backend(),\n                         \"Test should only be executed \"\n                         \"if no backend is installed/specified\")\n    def test_sequential(self):\n        prev_index = None\n        sampler = SequentialSampler.from_dataset(self.dset)\n\n        for idx in sampler:\n            if prev_index is not None:\n                self.assertEqual(idx, prev_index + 1)\n\n            prev_index = idx\n\n    @unittest.skipUnless(check_for_no_backend(),\n                         \"Test should only be executed 
\"\n                         \"if no backend is installed/specified\")\n    def test_random_replacement(self):\n\n        sampler = RandomSamplerWithReplacement.from_dataset(self.dset)\n        samples = []\n\n        self.assertEquals(len(sampler), len(self.dset))\n\n        for idx in sampler:\n            self.assertIn(idx, np.arange(len(self.dset)))\n            samples.append(idx)\n\n        # check if all samples are only sampled once (extremly unlikely)\n        self.assertFalse((np.bincount(samples) == 1).all())\n\n    @unittest.skipUnless(check_for_no_backend(),\n                         \"Test should only be executed \"\n                         \"if no backend is installed/specified\")\n    def test_random_no_replacement(self):\n\n        sampler = RandomSamplerNoReplacement.from_dataset(self.dset)\n        samples = []\n\n        self.assertEquals(len(sampler), len(self.dset))\n\n        for idx in sampler:\n            self.assertIn(idx, np.arange(len(self.dset)))\n            samples.append(idx)\n\n        # check if all samples are only sampled once\n        self.assertTrue((np.bincount(samples) == 1).all())\n\n    @unittest.skipUnless(check_for_no_backend(),\n                         \"Test should only be executed \"\n                         \"if no backend is installed/specified\")\n    def test_prevalence_sampler(self):\n\n        sampler = PrevalenceRandomSampler.from_dataset(self.dset)\n        sample_classes = []\n\n        for idx in sampler:\n            self.assertIn(idx, np.arange(len(self.dset)))\n\n            sample_classes.append(self.dset[idx][\"label\"])\n\n        num_samples_per_class = np.bincount(sample_classes)\n\n        self.assertTrue(\n            (num_samples_per_class.min() - num_samples_per_class.max()) <= 1)\n\n    @unittest.skipUnless(check_for_no_backend(),\n                         \"Test should only be executed \"\n                         \"if no backend is installed/specified\"\n                         )\n    def 
test_abstract_sampler_iter(self):\n        sampler = AbstractSampler.from_dataset(self.dset)\n\n        with self.assertRaises(NotImplementedError):\n            iter(sampler)\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/data_loading/utils.py",
    "content": "import math\nimport numpy as np\n\nfrom delira.data_loading import AbstractDataset\n\n\nclass DummyDataset(AbstractDataset):\n    def __init__(self, length=600, class_weights=[0.5, 0.3, 0.2]):\n        super().__init__(None, None)\n\n        assert math.isclose(sum(class_weights), 1)\n\n        self._data = [np.random.rand(1, 28, 28) for i in range(length)]\n        _labels = []\n        for idx, weight in enumerate(class_weights):\n            _labels += [idx] * int(length * weight)\n\n        self._labels = _labels\n\n    def __getitem__(self, index):\n        return {\"data\": self._data[index], \"label\": self._labels[index]}\n\n    def __len__(self):\n        return len(self._data)\n"
  },
  {
    "path": "tests/io/__init__.py",
    "content": ""
  },
  {
    "path": "tests/io/test_chainer.py",
    "content": "import unittest\n\nfrom ..utils import check_for_chainer_backend\n\nif check_for_chainer_backend():\n    import chainer\n    from delira.models import AbstractChainerNetwork\n\n    # define model outside actual test to make it pickleable\n    class Model(AbstractChainerNetwork):\n        def __init__(self):\n            super().__init__()\n\n            with self.init_scope():\n                self.dense = chainer.links.Linear(1, 1)\n\n        def forward(self, x):\n            return {\n                \"pred\":\n                    chainer.functions.relu(\n                        self.dense(x))\n            }\n\n\nclass IoChainerTest(unittest.TestCase):\n\n    @unittest.skipUnless(check_for_chainer_backend(),\n                         \"Test should be only executed if chainer backend is \"\n                         \"installed and specified\")\n    def test_load_save(self):\n\n        from delira.io.chainer import load_checkpoint, save_checkpoint\n\n        net = Model()\n\n        save_checkpoint(\"./model_chainer.chain\", model=net)\n        self.assertTrue(load_checkpoint(\"./model_chainer.chain\", model=net))\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/io/test_sklearn.py",
    "content": "import unittest\n\nfrom ..utils import check_for_sklearn_backend\n\n\nclass IoSklearnTest(unittest.TestCase):\n\n    @unittest.skipUnless(check_for_sklearn_backend(),\n                         \"Test should be only executed if sklearn backend is \"\n                         \"installed and specified\")\n    def test_load_save(self):\n\n        from delira.io.sklearn import load_checkpoint, save_checkpoint\n        from delira.models import SklearnEstimator\n        from sklearn.tree import DecisionTreeRegressor\n        import numpy as np\n\n        net = SklearnEstimator(DecisionTreeRegressor())\n        net.fit(X=np.random.rand(2, 32), y=np.random.rand(2))\n        save_checkpoint(\"./model_sklearn.pkl\", model=net)\n        self.assertTrue(load_checkpoint(\"./model_sklearn.pkl\"))\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/io/test_tf.py",
    "content": "import unittest\n\nfrom ..utils import check_for_tf_eager_backend, check_for_tf_graph_backend\n\n\nclass IoTfTest(unittest.TestCase):\n\n    def setUp(self) -> None:\n        import tensorflow as tf\n        tf.reset_default_graph()\n        if \"_eager\" in self._testMethodName:\n            tf.enable_eager_execution()\n        else:\n            tf.disable_eager_execution()\n\n    @unittest.skipUnless(check_for_tf_graph_backend(),\n                         \"Test should be only executed if tensorflow backend \"\n                         \"is installed and specified\")\n    def test_load_save(self):\n        import tensorflow as tf\n        tf.disable_eager_execution()\n        from delira.io.tf import load_checkpoint, save_checkpoint\n        from delira.models import AbstractTfGraphNetwork\n        from delira.training.backends import initialize_uninitialized\n\n        import numpy as np\n\n        class DummyNetwork(AbstractTfGraphNetwork):\n            def __init__(self, in_channels, n_outputs):\n                super().__init__(in_channels=in_channels, n_outputs=n_outputs)\n                self.net = self._build_model(in_channels, n_outputs)\n\n            @staticmethod\n            def _build_model(in_channels, n_outputs):\n                return tf.keras.models.Sequential(\n                    layers=[\n                        tf.keras.layers.Dense(\n                            64,\n                            input_shape=in_channels,\n                            bias_initializer='glorot_uniform'),\n                        tf.keras.layers.ReLU(),\n                        tf.keras.layers.Dense(\n                            n_outputs,\n                            bias_initializer='glorot_uniform')])\n\n        net = DummyNetwork((32,), 1)\n        initialize_uninitialized(net._sess)\n\n        vars_1 = net._sess.run(tf.global_variables())\n\n        save_checkpoint(\"./model\", model=net)\n\n        
net._sess.run(tf.initializers.global_variables())\n\n        vars_2 = net._sess.run(tf.global_variables())\n\n        load_checkpoint(\"./model\", model=net)\n\n        vars_3 = net._sess.run(tf.global_variables())\n\n        for var_1, var_2 in zip(vars_1, vars_2):\n            with self.subTest(var_1=var_1, var2=var_2):\n                self.assertTrue(np.all(var_1 != var_2))\n\n        for var_1, var_3 in zip(vars_1, vars_3):\n            with self.subTest(var_1=var_1, var_3=var_3):\n                self.assertTrue(np.all(var_1 == var_3))\n\n    @unittest.skipUnless(check_for_tf_eager_backend(),\n                         \"Test should be only executed if tensorflow backend \"\n                         \"is installed and specified\")\n    def test_load_save_eager(self):\n        import tensorflow as tf\n        tf.enable_eager_execution()\n        from delira.io.tf import load_checkpoint_eager, save_checkpoint_eager\n        from delira.models import AbstractTfEagerNetwork\n\n        import numpy as np\n\n        class DummyNetwork(AbstractTfEagerNetwork):\n            def __init__(self, in_channels, n_outputs):\n                super().__init__(in_channels=in_channels, n_outputs=n_outputs)\n                with tf.init_scope():\n                    self.net = self._build_model(in_channels, n_outputs)\n\n            @staticmethod\n            def _build_model(in_channels, n_outputs):\n                return tf.keras.models.Sequential(\n                    layers=[\n                        tf.keras.layers.Dense(\n                            64,\n                            input_shape=in_channels,\n                            bias_initializer='glorot_uniform'),\n                        tf.keras.layers.ReLU(),\n                        tf.keras.layers.Dense(\n                            n_outputs,\n                            bias_initializer='glorot_uniform')])\n\n            def call(self, inputs):\n                return self.net(inputs)\n\n        net = 
DummyNetwork((32,), 1)\n        input_tensor = tf.constant(np.random.rand(1, 32).astype(np.float32))\n        result_pre_save = net(input_tensor)\n        save_checkpoint_eager(\"./model_eager\", model=net)\n\n        loaded_state = load_checkpoint_eager(\"./model_eager\", model=net)\n        loaded_net = loaded_state[\"model\"]\n\n        result_post_save = loaded_net(input_tensor)\n\n        self.assertTrue(np.array_equal(result_post_save, result_pre_save))\n\n    def tearDown(self) -> None:\n        import gc\n        import sys\n\n        try:\n            del sys.modules[\"tf\"]\n        except KeyError:\n            pass\n        try:\n            del tf\n        except (UnboundLocalError, NameError):\n            pass\n        try:\n            del sys.modules[\"tensorflow\"]\n        except KeyError:\n            pass\n        try:\n            del tensorflow\n        except (UnboundLocalError, NameError):\n            pass\n\n        gc.collect()\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/io/test_torch.py",
    "content": "import unittest\n\nfrom ..utils import check_for_torch_backend, check_for_torchscript_backend\n\n\nclass IoTorchTest(unittest.TestCase):\n\n    @unittest.skipUnless(check_for_torch_backend(),\n                         \"Test should be only executed if torch backend is \"\n                         \"installed and specified\")\n    def test_load_save(self):\n\n        from delira.io.torch import load_checkpoint_torch, \\\n            save_checkpoint_torch\n        from delira.models import AbstractPyTorchNetwork\n        import torch\n\n        class DummyNetwork(AbstractPyTorchNetwork):\n            def __init__(self, in_channels, n_outputs):\n                super().__init__(in_channels=in_channels, n_outputs=n_outputs)\n                self.net = self._build_model(in_channels, n_outputs)\n\n            def forward(self, x):\n                return self.module(x)\n\n            @staticmethod\n            def _build_model(in_channels, n_outputs):\n                return torch.nn.Sequential(\n                    torch.nn.Linear(in_channels, 64),\n                    torch.nn.ReLU(),\n                    torch.nn.Linear(64, n_outputs)\n                )\n\n        net = DummyNetwork(32, 1)\n        save_checkpoint_torch(\"./model_torch.pt\", model=net)\n        self.assertTrue(load_checkpoint_torch(\"./model_torch.pt\"))\n\n    @unittest.skipUnless(check_for_torchscript_backend(),\n                         \"Test should be only executed if torch backend is \"\n                         \"installed and specified\")\n    def test_torchscript_save(self):\n        from delira.io.torch import load_checkpoint_torchscript, \\\n            save_checkpoint_torchscript\n        from delira.models import AbstractTorchScriptNetwork\n        import torch\n\n        class DummyNetwork(AbstractTorchScriptNetwork):\n\n            def __init__(self):\n                super().__init__()\n                self.dense = torch.nn.Linear(3, 1)\n\n            
@torch.jit.script_method\n            def forward(self, x):\n                return self.dense(x)\n\n        net = DummyNetwork()\n        save_checkpoint_torchscript(\"./model_jit.ptj\", model=net)\n        self.assertTrue(load_checkpoint_torchscript(\"./model_jit.ptj\"))\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/logging/__init__.py",
    "content": ""
  },
  {
    "path": "tests/logging/test_logging_frequency.py",
    "content": "import unittest\nfrom delira.logging import BaseBackend, SingleThreadedLogger\nimport logging\n\n\nclass DummyBackend(BaseBackend):\n    def _text(self, logging_no: int, tag: str, global_step=None):\n        logging.info(\"INFO: Logging Item Number %d\" % logging_no)\n\n    # implement dummy funtions to be able to instantiate backend\n    def _image(self, *args, **kwargs):\n        pass\n\n    def _images(self, *args, **kwargs):\n        pass\n\n    def _image_with_boxes(self, *args, **kwargs):\n        pass\n\n    def _scalar(self, *args, **kwargs):\n        pass\n\n    def _scalars(self, *args, **kwargs):\n        pass\n\n    def _histogram(self, *args, **kwargs):\n        pass\n\n    def _figure(self, *args, **kwargs):\n        pass\n\n    def _audio(self, *args, **kwargs):\n        pass\n\n    def _video(self, *args, **kwargs):\n        pass\n\n    def _graph_pytorch(self, *args, **kwargs):\n        pass\n\n    def _graph_tf(self, *args, **kwargs):\n        pass\n\n    def _graph_onnx(self, *args, **kwargs):\n        pass\n\n    def _embedding(self, *args, **kwargs):\n        pass\n\n    def _pr_curve(self, *args, **kwargs):\n        pass\n\n\nclass LoggingFrequencyTestCase(unittest.TestCase):\n\n    def _logging_freq_test(self, frequencies, num_runs: int, check_freq=None):\n        logger = SingleThreadedLogger(DummyBackend(),\n                                      logging_frequencies=frequencies,\n                                      reduce_types=\"last\")\n\n        if check_freq is None and isinstance(frequencies, int):\n            check_freq = frequencies\n\n        assert check_freq is not None\n\n        target_messages = 0\n\n        with self.assertLogs() as cm:\n            for idx in range(num_runs):\n                logger.log({\"text\": {\"logging_no\": idx, \"tag\": \"dummy\"}})\n\n                target_messages += int((idx + 1) % check_freq == 0)\n\n        self.assertIsNotNone(cm.output)\n        
self.assertEqual(target_messages, len(cm.output))\n\n    def test_logging_freq(self):\n        for frequencies, check_freq in zip([1, 5, 10, {\"text\": 15}],\n                                           [None, None, None, 15]):\n            with self.subTest(frequencies=frequencies):\n                self._logging_freq_test(frequencies, 50, check_freq)\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/logging/test_logging_outside_trainer.py",
    "content": "import unittest\nfrom delira.logging import log\nfrom delira.training import BaseNetworkTrainer\nfrom delira.models import AbstractNetwork\nimport os\nfrom tests.utils import check_for_tf_graph_backend\n\ntry:\n    import tensorflow as tf\nexcept ImportError:\n    tf = None\n\n\nclass LoggingOutsideTrainerTestCase(unittest.TestCase):\n\n    @unittest.skipUnless(check_for_tf_graph_backend(),\n                         \"TF Backend not installed\")\n    def test_logging_freq(self):\n        save_path = os.path.abspath(\"./logs\")\n        config = {\n            \"num_epochs\": 2,\n            \"losses\": {},\n            \"optimizer_cls\": None,\n            \"optimizer_params\": {\"learning_rate\": 1e-3},\n            \"metrics\": {},\n            \"lr_scheduler_cls\": None,\n            \"lr_scheduler_params\": {}\n        }\n        trainer = BaseNetworkTrainer(\n            AbstractNetwork(),\n            save_path,\n            **config,\n            gpu_ids=[],\n            save_freq=1,\n            optim_fn=None,\n            key_mapping={},\n            logging_type=\"tensorboardx\",\n            logging_kwargs={\n                'logdir': save_path\n            })\n\n        trainer._setup(\n            AbstractNetwork(),\n            lr_scheduler_cls=None,\n            lr_scheduler_params={},\n            gpu_ids=[],\n            key_mapping={},\n            convert_batch_to_npy_fn=None,\n            prepare_batch_fn=None,\n            callbacks=[])\n\n        tag = 'dummy'\n\n        log({\"scalar\": {\"scalar_value\": 1234, \"tag\": tag}})\n\n        file = [os.path.join(save_path, x)\n                for x in os.listdir(save_path)\n                if os.path.isfile(os.path.join(save_path, x))][0]\n\n        ret_val = False\n        if tf is not None:\n            for e in tf.train.summary_iterator(file):\n                for v in e.summary.value:\n                    if v.tag == tag:\n                        ret_val = True\n               
         break\n                if ret_val:\n                    break\n\n        self.assertTrue(ret_val)\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/logging/test_single_threaded_logging.py",
    "content": "from delira.logging import Logger, TensorboardBackend, make_logger\n\nfrom tests.utils import check_for_torch_backend, check_for_tf_graph_backend\n\nimport unittest\n\ntry:\n    import tensorflow as tf\nexcept ImportError:\n    tf = None\n\ntry:\n    import torch\nexcept ImportError:\n    torch = None\n\ntry:\n    import onnx\nexcept ImportError:\n    onnx = None\n\nimport numpy as np\nimport os\nimport gc\n\n\nclass TestTensorboardLogging(unittest.TestCase):\n\n    def setUp(self) -> None:\n\n        self._npy_imgs = np.random.rand(2, 3, 24, 24)\n        self._boxes_npy = np.array([[5, 5, 10, 10], [4, 8, 5, 16]])\n        self._scalars = [{\"1\": 4, \"2\": 14, \"3\": 24},\n                         {\"1\": 5, \"2\": 15, \"3\": 25},\n                         {\"1\": 6, \"2\": 16, \"3\": 26}]\n\n        self._hist_vals = np.random.randint(0, 10, size=(100,))\n        from scipy.signal import chirp\n        self._audio_sample_npy = chirp(np.linspace(0, 100), 500, 2, 100)\n\n        self._text_string = \"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrs\" \\\n                            \"tuvwxyz0123456789\"\n\n        if tf is not None:\n            tf.reset_default_graph()\n            input = np.zeros(shape=(1, 28, 28, 1))\n\n            layers = tf.keras.layers\n            self._model_tf = tf.keras.Sequential(\n                [layers.Conv2D(\n                    32,\n                    5,\n                    padding='same',\n                    data_format=\"channels_last\",\n                    activation=tf.nn.relu),\n                    layers.Conv2D(\n                        64,\n                        5,\n                        padding='same',\n                        data_format=\"channels_last\",\n                        activation=tf.nn.relu),\n                 ])\n            self._model_tf.build(input_shape=input.shape)\n\n        else:\n            self._model_tf = None\n\n        if torch is not None:\n            self._model_torch = 
torch.nn.Sequential(\n                torch.nn.Conv2d(3, 8, 3, padding=1),\n                torch.nn.ReLU(),\n                torch.nn.Conv2d(8, 1, 3, padding=1),\n                torch.nn.LeakyReLU(),\n                torch.nn.Conv2d(1, 23, 3),\n            )\n\n        else:\n            self._model_torch = None\n\n        self._embedding_npy = np.random.rand(500, 3)\n\n        self._labels_npy = np.random.randint(0, 10, 100)\n        self._predictions_npy = np.random.randint(0, 10, 100)\n\n        self._logger = self._setup_logger()\n\n    def _setup_logger(self):\n        return make_logger(TensorboardBackend(\n            {\"logdir\": os.path.join(\".\", \"runs\", self._testMethodName)}\n        ))\n\n    def _check_for_tag(self, tag, logdir=None):\n\n        if logdir is None:\n            try:\n                logdir = self._logger._backend._writer.logdir\n            except AttributeError:\n                logdir = self._logger._backend._writer.log_dir\n\n        file = [os.path.join(logdir, x)\n                for x in os.listdir(logdir)\n                if os.path.isfile(os.path.join(logdir, x))][0]\n\n        if tf is not None:\n            ret_val = False\n            for e in tf.train.summary_iterator(file):\n                for v in e.summary.value:\n                    if v.tag == tag:\n                        ret_val = True\n                        break\n                if ret_val:\n                    break\n\n            self.assertTrue(ret_val)\n\n    @staticmethod\n    def _destroy_logger(logger: Logger):\n        logger.close()\n        del logger\n        gc.collect()\n\n    def test_image_npy(self):\n        self._logger.log({\"image\": {\"tag\": \"image_npy\",\n                                    \"img_tensor\": self._npy_imgs[0]}})\n        self._check_for_tag(\"image_npy\")\n\n    @unittest.skipUnless(check_for_torch_backend(),\n                         \"Torch Backend not installed\")\n    def test_image_torch(self):\n        
self._logger.log({\"image\": {\"tag\": \"image_torch\",\n                                    \"img_tensor\":\n                                        torch.from_numpy(self._npy_imgs[0])}})\n        self._check_for_tag(\"image_torch\")\n\n    def test_img_npy(self):\n        self._logger.log({\"img\": {\"tag\": \"img_npy\",\n                                  \"img_tensor\": self._npy_imgs[0]}})\n        self._check_for_tag(\"img_npy\")\n\n    @unittest.skipUnless(check_for_torch_backend(),\n                         \"Torch Backend not installed\")\n    def test_img_torch(self):\n        self._logger.log({\"img\": {\"tag\": \"img_torch\",\n                                  \"img_tensor\":\n                                      torch.from_numpy(self._npy_imgs[0])}})\n        self._check_for_tag(\"img_torch\")\n\n    def test_picture_npy(self):\n        self._logger.log({\"picture\": {\"tag\": \"picture_npy\",\n                                      \"img_tensor\": self._npy_imgs[0]}})\n        self._check_for_tag(\"picture_npy\")\n\n    @unittest.skipUnless(check_for_torch_backend(),\n                         \"Torch Backend not installed\")\n    def test_picture_torch(self):\n        self._logger.log({\n            \"picture\": {\n                \"tag\": \"picture_torch\",\n                \"img_tensor\": torch.from_numpy(self._npy_imgs[0])}})\n        self._check_for_tag(\"picture_torch\")\n\n    def test_images_npy(self):\n        self._logger.log({\"images\": {\"tag\": \"images_npy\",\n                                     \"img_tensor\": self._npy_imgs}})\n        self._check_for_tag(\"images_npy\")\n\n    @unittest.skipUnless(check_for_torch_backend(),\n                         \"Torch Backend not installed\")\n    def test_images_torch(self):\n        self._logger.log({\"images\": {\"tag\": \"images_torch\",\n                                     \"img_tensor\":\n                                         torch.from_numpy(self._npy_imgs)}})\n        
self._check_for_tag(\"images_torch\")\n\n    def test_imgs_npy(self):\n        self._logger.log({\"imgs\": {\"tag\": \"imgs_npy\",\n                                   \"img_tensor\": self._npy_imgs}})\n        self._check_for_tag(\"imgs_npy\")\n\n    @unittest.skipUnless(check_for_torch_backend(),\n                         \"Torch Backend not installed\")\n    def test_imgs_torch(self):\n        self._logger.log({\"imgs\": {\"tag\": \"imgs_torch\",\n                                   \"img_tensor\":\n                                       torch.from_numpy(self._npy_imgs)}})\n        self._check_for_tag(\"imgs_torch\")\n\n    def test_pictures_npy(self):\n        self._logger.log({\"pictures\": {\"tag\": \"pictures_npy\",\n                                       \"img_tensor\": self._npy_imgs}})\n        self._check_for_tag(\"pictures_npy\")\n\n    @unittest.skipUnless(check_for_torch_backend(),\n                         \"Torch Backend not installed\")\n    def test_pictures_torch(self):\n        self._logger.log({\"pictures\": {\"tag\": \"pictures_torch\",\n                                       \"img_tensor\":\n                                           torch.from_numpy(self._npy_imgs)}})\n        self._check_for_tag(\"pictures_torch\")\n\n    def test_image_with_boxes_npy(self):\n        self._logger.log({\"image_with_boxes\": {\n            \"tag\": \"image_with_boxes_npy\",\n            \"img_tensor\": self._npy_imgs[0],\n            \"box_tensor\": self._boxes_npy\n        }})\n        self._check_for_tag(\"image_with_boxes_npy\")\n\n    @unittest.skipUnless(check_for_torch_backend(),\n                         \"Torch Backend not installed\")\n    def test_image_with_boxes_torch(self):\n        self._logger.log({\"image_with_boxes\": {\n            \"tag\": \"image_with_boxes_torch\",\n            \"img_tensor\": torch.from_numpy(self._npy_imgs[0]),\n            \"box_tensor\": torch.from_numpy(self._boxes_npy)\n        }})\n        
self._check_for_tag(\"image_with_boxes_torch\")\n\n    def test_bounding_boxes_npy(self):\n        self._logger.log({\"bounding_boxes\": {\n            \"tag\": \"bounding_boxes_npy\",\n            \"img_tensor\": self._npy_imgs[0],\n            \"box_tensor\": self._boxes_npy\n        }})\n        self._check_for_tag(\"bounding_boxes_npy\")\n\n    @unittest.skipUnless(check_for_torch_backend(),\n                         \"Torch Backend not installed\")\n    def test_bounding_boxes_torch(self):\n\n        self._logger.log({\"bounding_boxes\": {\n            \"tag\": \"bounding_boxes_torch\",\n            \"img_tensor\": torch.from_numpy(self._npy_imgs[0]),\n            \"box_tensor\": torch.from_numpy(self._boxes_npy)\n        }})\n        self._check_for_tag(\"bounding_boxes_torch\")\n\n    def test_bboxes_npy(self):\n        self._logger.log({\"bboxes\": {\n            \"tag\": \"bboxes_npy\",\n            \"img_tensor\": self._npy_imgs[0],\n            \"box_tensor\": self._boxes_npy\n        }})\n        self._check_for_tag(\"bboxes_npy\")\n\n    @unittest.skipUnless(check_for_torch_backend(),\n                         \"Torch Backend not installed\")\n    def test_bboxes_torch(self):\n        self._logger.log({\"bboxes\": {\n            \"tag\": \"bboxes_torch\",\n            \"img_tensor\": torch.from_numpy(self._npy_imgs[0]),\n            \"box_tensor\": torch.from_numpy(self._boxes_npy)\n        }})\n        self._check_for_tag(\"bboxes_torch\")\n\n    def test_scalar(self):\n        for _scalar in self._scalars:\n            self._logger.log({\n                \"scalar\": {\n                    \"tag\": \"scalar\",\n                    \"scalar_value\": _scalar[\"1\"]\n                }\n            })\n        self._check_for_tag(\"scalar\")\n\n    def test_scalar_npy(self):\n        for _scalar in self._scalars:\n            self._logger.log({\n                \"scalar\": {\n                    \"tag\": \"scalar_npy\",\n                    
\"scalar_value\": np.array(_scalar[\"1\"])\n                }\n            })\n\n        self._check_for_tag(\"scalar_npy\")\n\n    @unittest.skipUnless(check_for_torch_backend(),\n                         \"Torch Backend not installed\")\n    def test_scalar_torch(self):\n        for _scalar in self._scalars:\n            self._logger.log({\n                \"scalar\": {\n                    \"tag\": \"scalar_torch\",\n                    \"scalar_value\": torch.tensor(_scalar[\"1\"])\n                }\n            })\n\n    def test_value(self):\n        for _scalar in self._scalars:\n            self._logger.log({\n                \"value\": {\n                    \"tag\": \"value\",\n                    \"scalar_value\": _scalar[\"1\"]\n                }\n            })\n        self._check_for_tag(\"value\")\n\n    def test_value_npy(self):\n        for _scalar in self._scalars:\n            self._logger.log({\n                \"value\": {\n                    \"tag\": \"value_npy\",\n                    \"scalar_value\": np.array(_scalar[\"1\"])\n                }\n            })\n        self._check_for_tag(\"value_npy\")\n\n    @unittest.skipUnless(check_for_torch_backend(),\n                         \"Torch Backend not installed\")\n    def test_value_torch(self):\n        for _scalar in self._scalars:\n            self._logger.log({\n                \"value\": {\n                    \"tag\": \"value_torch\",\n                    \"scalar_value\": torch.tensor(_scalar[\"1\"])\n                }\n            })\n        self._check_for_tag(\"value_torch\")\n\n    def test_scalars(self):\n        for _scalar in self._scalars:\n            self._logger.log({\n                \"scalars\": {\n                    \"main_tag\": \"scalars\",\n                    \"tag_scalar_dict\": _scalar,\n                    \"sep\": \"/\"\n                }\n            })\n\n        for k in self._scalars[0].keys():\n            self._check_for_tag(\"scalars/\" + k)\n\n    
def test_scalars_npy(self):\n        for _scalar in self._scalars:\n            self._logger.log({\n                \"scalars\": {\n                    \"main_tag\": \"scalars_npy\",\n                    \"tag_scalar_dict\": {k: np.array(v)\n                                        for k, v in _scalar.items()},\n                    \"sep\": \"/\"\n                }\n            })\n\n        for k in self._scalars[0].keys():\n            self._check_for_tag(\"scalars_npy/\" + k)\n\n    @unittest.skipUnless(check_for_torch_backend(),\n                         \"Torch Backend not installed\")\n    def test_scalars_torch(self):\n        for _scalar in self._scalars:\n            self._logger.log({\n                \"scalars\": {\n                    \"main_tag\": \"scalars_torch\",\n                    \"tag_scalar_dict\": {k: torch.tensor(v)\n                                        for k, v in _scalar.items()},\n                    \"sep\": \"/\"\n                }\n            })\n\n        for k in self._scalars[0].keys():\n            self._check_for_tag(\"scalars_torch/\" + k)\n\n    def test_values(self):\n        for _scalar in self._scalars:\n            self._logger.log({\n                \"values\": {\n                    \"main_tag\": \"values\",\n                    \"tag_scalar_dict\": _scalar,\n                    \"sep\": \"/\"\n                }\n            })\n\n        for k in self._scalars[0].keys():\n            self._check_for_tag(\"values/\" + k)\n\n    def test_values_npy(self):\n        for _scalar in self._scalars:\n            self._logger.log({\n                \"values\": {\n                    \"main_tag\": \"values_npy\",\n                    \"tag_scalar_dict\": {k: np.array(v)\n                                        for k, v in _scalar.items()},\n                    \"sep\": \"/\"\n                }\n            })\n\n        for k in self._scalars[0].keys():\n            self._check_for_tag(\"values_npy/\" + k)\n\n    
@unittest.skipUnless(check_for_torch_backend(),\n                         \"Torch Backend not installed\")\n    def test_values_torch(self):\n        for _scalar in self._scalars:\n            self._logger.log({\n                \"values\": {\n                    \"main_tag\": \"values_torch\",\n                    \"tag_scalar_dict\": {k: torch.tensor(v)\n                                        for k, v in _scalar.items()},\n                    \"sep\": \"/\"\n                }\n            })\n\n        for k in self._scalars[0].keys():\n            self._check_for_tag(\"values_torch/\" + k)\n\n    def test_histogram_npy(self):\n        self._logger.log({\n            \"histogram\": {\n                \"tag\": \"histogram_npy\",\n                \"values\": self._hist_vals\n            }\n        })\n\n        self._check_for_tag(\"histogram_npy\")\n\n    @unittest.skipUnless(check_for_torch_backend(),\n                         \"Torch Backend not installed\")\n    def test_histogram_torch(self):\n        self._logger.log({\n            \"histogram\": {\n                \"tag\": \"histogram_torch\",\n                \"values\": torch.from_numpy(self._hist_vals)\n            }\n        })\n\n        self._check_for_tag(\"histogram_torch\")\n\n    def test_hist_npy(self):\n        self._logger.log({\n            \"hist\": {\n                \"tag\": \"hist_npy\",\n                \"values\": self._hist_vals\n            }\n        })\n\n        self._check_for_tag(\"hist_npy\")\n\n    @unittest.skipUnless(check_for_torch_backend(),\n                         \"Torch Backend not installed\")\n    def test_hist_torch(self):\n        self._logger.log({\n            \"hist\": {\n                \"tag\": \"hist_torch\",\n                \"values\": torch.from_numpy(self._hist_vals)\n            }\n        })\n\n        self._check_for_tag(\"hist_torch\")\n\n    def test_figure(self):\n        from matplotlib.pyplot import figure, imshow, close\n        _fig = figure()\n  
      imshow(self._npy_imgs[0][0])\n        self._logger.log({\n            \"figure\": {\n                \"tag\": \"figure\",\n                \"figure\": _fig\n            }\n        })\n        close()\n\n        self._check_for_tag(\"figure\")\n\n    def test_fig(self):\n        from matplotlib.pyplot import figure, imshow, close\n        _fig = figure()\n        imshow(self._npy_imgs[0][0])\n        self._logger.log({\n            \"fig\": {\n                \"tag\": \"fig\",\n                \"figure\": _fig\n            }\n        })\n        close()\n\n        self._check_for_tag(\"fig\")\n\n    def test_audio_npy(self):\n        self._logger.log({\"audio\": {\n            \"tag\": \"audio_npy\",\n            \"snd_tensor\": self._audio_sample_npy\n        }})\n\n        self._check_for_tag(\"audio_npy\")\n\n    @unittest.skipUnless(check_for_torch_backend(),\n                         \"Torch Backend not installed\")\n    def test_audio_torch(self):\n        self._logger.log({\"audio\": {\n            \"tag\": \"audio_torch\",\n            \"snd_tensor\": torch.from_numpy(self._audio_sample_npy)\n        }})\n\n        self._check_for_tag(\"audio_torch\")\n\n    def test_sound_npy(self):\n        self._logger.log({\"sound\": {\n            \"tag\": \"sound_npy\",\n            \"snd_tensor\": self._audio_sample_npy\n        }})\n\n        self._check_for_tag(\"sound_npy\")\n\n    @unittest.skipUnless(check_for_torch_backend(),\n                         \"Torch Backend not installed\")\n    def test_sound_torch(self):\n        self._logger.log({\"sound\": {\n            \"tag\": \"sound_torch\",\n            \"snd_tensor\": torch.from_numpy(self._audio_sample_npy)\n        }})\n\n        self._check_for_tag(\"sound_torch\")\n\n    def test_video_npy(self):\n        # add channel and batch dimension for format BTCHW\n        vid = self._npy_imgs.reshape((1, *self._npy_imgs.shape))\n\n        self._logger.log({\"video\": {\n            \"tag\": 
\"video_npy\",\n            \"vid_tensor\": vid,\n            \"fps\": 1\n        }})\n        self._check_for_tag(\"video_npy\")\n\n    @unittest.skipUnless(check_for_torch_backend(),\n                         \"Torch Backend not installed\")\n    def test_video_torch(self):\n        # add channel and batch dimension for format BTCHW\n        vid = self._npy_imgs.reshape((1, *self._npy_imgs.shape))\n\n        self._logger.log({\"video\": {\n            \"tag\": \"video_torch\",\n            \"vid_tensor\": torch.from_numpy(vid),\n            \"fps\": 1\n        }})\n\n        self._check_for_tag(\"video_torch\")\n\n    def test_text(self):\n        self._logger.log({\"text\": {\n            \"tag\": \"text\",\n            \"text_string\": self._text_string\n        }})\n\n        self._check_for_tag(\"text/text_summary\")\n\n    @unittest.skipUnless(check_for_tf_graph_backend(),\n                         \"TF Backend not installed\")\n    def test_graph_tf(self):\n\n        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)\n        run_metadata = tf.RunMetadata()\n\n        with tf.Session() as sess:\n            outputs = self._model_tf(\n                np.zeros(\n                    shape=(\n                        1,\n                        28,\n                        28,\n                        1),\n                    dtype=np.float32))\n            sess.run(tf.initializers.global_variables())\n            sess.run(outputs, options=run_options, run_metadata=run_metadata)\n\n        self._logger.log({\"graph_tf\": {\n            \"graph\": self._model_tf._graph.as_graph_def(add_shapes=True),\n            \"run_metadata\": run_metadata\n        }})\n\n    @unittest.skipUnless(check_for_torch_backend(),\n                         \"Torch Backend not installed\")\n    def test_graph_torch(self):\n\n        input_tensor = self._npy_imgs[0]\n        input_tensor = input_tensor.reshape(1, *input_tensor.shape)\n\n        self._logger.log({\n        
    \"graph_pytorch\": {\n                \"model\": self._model_torch,\n                \"input_to_model\": torch.from_numpy(input_tensor).float()\n            }\n        })\n\n    @unittest.skipUnless(check_for_torch_backend(),\n                         \"Torch Backend not installed\")\n    @unittest.skipIf(onnx is None, reason=\"ONNX not installed\")\n    def test_graph_onnx(self):\n        import os\n        input_tensor = self._npy_imgs[0]\n        input_tensor = input_tensor.reshape(1, *input_tensor.shape)\n        torch.onnx.export(self._model_torch,\n                          torch.from_numpy(input_tensor).float(),\n                          os.path.abspath(\"model.onnx\"))\n        self._logger.log({\n            \"graph_onnx\": {\"prototxt\": os.path.abspath(\"model.onnx\")}\n        })\n\n    def test_embedding_npy(self):\n        self._logger.log({\"embedding\": {\n            \"mat\": self._embedding_npy\n        }})\n\n    @unittest.skipUnless(check_for_torch_backend(),\n                         \"Torch Backend not installed\")\n    def test_embedding_torch(self):\n        self._logger.log({\"embedding\": {\n            \"mat\": torch.from_numpy(self._embedding_npy)\n        }})\n\n    def test_pr_curve_npy(self):\n        self._logger.log({\"pr_curve\": {\n            \"tag\": \"pr_curve_npy\",\n            \"labels\": self._labels_npy,\n            \"predictions\": self._predictions_npy\n        }})\n        self._check_for_tag(\"pr_curve_npy\")\n\n    @unittest.skipUnless(check_for_torch_backend(),\n                         \"Torch Backend not installed\")\n    def test_pr_curve_torch(self):\n        self._logger.log({\"pr_curve\": {\n            \"tag\": \"pr_curve_torch\",\n            \"labels\": torch.from_numpy(self._labels_npy),\n            \"predictions\": torch.from_numpy(self._predictions_npy)\n        }})\n        self._check_for_tag(\"pr_curve_torch\")\n\n    def test_pr_npy(self):\n        self._logger.log({\"pr\": {\n            
\"tag\": \"pr_npy\",\n            \"labels\": self._labels_npy,\n            \"predictions\": self._predictions_npy\n        }})\n        self._check_for_tag(\"pr_npy\")\n\n    @unittest.skipUnless(check_for_torch_backend(),\n                         \"Torch Backend not installed\")\n    def test_pr_torch(self):\n        self._logger.log({\"pr\": {\n            \"tag\": \"pr_torch\",\n            \"labels\": torch.from_numpy(self._labels_npy),\n            \"predictions\": torch.from_numpy(self._predictions_npy)\n        }})\n        self._check_for_tag(\"pr_torch\")\n\n    def tearDown(self) -> None:\n        self._destroy_logger(self._logger)\n        self._logger = None\n\n\nif __name__ == '__main__':\n    from multiprocessing import freeze_support\n    freeze_support()\n    unittest.main()\n"
  },
  {
    "path": "tests/models/__init__.py",
    "content": ""
  },
  {
    "path": "tests/models/data_parallel/__init__.py",
    "content": ""
  },
  {
    "path": "tests/models/data_parallel/test_chainer.py",
    "content": "import unittest\nfrom tests.utils import check_for_chainer_backend\n\n\nclass TestDataParallelChainer(unittest.TestCase):\n\n    def setUp(self) -> None:\n        if check_for_chainer_backend():\n            import chainer\n            import chainer.link\n            import chainer.links\n            import chainer.functions\n            import chainer.optimizers\n            from delira.models.backends.chainer.data_parallel import \\\n                DataParallelChainerOptimizer, \\\n                DataParallelChainerNetwork\n            from delira.models.backends.chainer.abstract_network import \\\n                AbstractChainerNetwork\n\n            # creating a really simple model to test dataparallel behavior\n            class SimpleModel(AbstractChainerNetwork):\n                def __init__(self):\n                    super(SimpleModel, self).__init__()\n\n                    with self.init_scope():\n                        self.dense_1 = chainer.links.Linear(3, 32)\n                        self.dense_2 = chainer.links.Linear(32, 2)\n\n                def forward(self, x):\n                    return self.dense_2(\n                        chainer.functions.relu(\n                            self.dense_1(x)))\n\n            self.model = DataParallelChainerNetwork(SimpleModel(),\n                                                    devices=[\"@numpy\",\n                                                             \"@numpy\"])\n\n            self.optimizer = DataParallelChainerOptimizer.from_optimizer_class(\n                chainer.optimizers.Adam\n            )\n            self.optimizer.setup(self.model)\n\n    @unittest.skipUnless(check_for_chainer_backend(),\n                         \"Test should be only executed if chainer backend is \"\n                         \"installed and specified\")\n    def test_update(self):\n        import numpy as np\n        import chainer\n\n        input_tensor = np.random.rand(10, 
3).astype(np.float32)\n        label_tensor = np.random.rand(10, 2).astype(np.float32)\n\n        model_copy = self.model.copy()\n\n        preds = self.model(input_tensor)\n\n        loss = chainer.functions.sum(preds - label_tensor)\n\n        self.model.cleargrads()\n        loss.backward()\n        self.optimizer.update()\n\n        # check if param was updated\n        for orig_param, updated_param in zip(model_copy.params(),\n                                             self.model.params()):\n\n            self.assertFalse(np.array_equal(orig_param, updated_param))\n\n        # check if all grads were cleared\n        self.model.cleargrads()\n        for module in self.model.modules:\n            for updated_param in module.params():\n                self.assertIsNone(updated_param.grad_var)\n\n    # test with keyword arguments\n    @unittest.skipUnless(check_for_chainer_backend(),\n                         \"Test should be only executed if chainer backend is \"\n                         \"installed and specified\")\n    def test_keyword_arguments_different_batchsize(self):\n        import numpy as np\n        import chainer\n\n        # test batchsize smaller than, equal to and greater than number devices\n        for batchsize in [1, 2, 3]:\n            with self.subTest(batchsize=batchsize):\n                input_kwargs = {\n                    \"x\": np.random.rand(batchsize, 3).astype(np.float32)\n                }\n\n                pred = self.model(**input_kwargs)\n                self.assertTupleEqual(pred.shape,\n                                      (batchsize, 2))\n                self.assertEqual(chainer.get_device(pred.device),\n                                 chainer.get_device(\"@numpy\"))\n\n    # test with positional arguments\n    @unittest.skipUnless(check_for_chainer_backend(),\n                         \"Test should be only executed if chainer backend is \"\n                         \"installed and specified\")\n    def 
test_positional_arguments(self):\n        import numpy as np\n        import chainer\n\n        # test batchsize smaller than, equal to and greater than number devices\n        for batchsize in [1, 2, 3]:\n            with self.subTest(batchsize=batchsize):\n                input_args = [\n                    np.random.rand(batchsize, 3).astype(np.float32)\n                ]\n\n                pred = self.model(*input_args)\n                self.assertTupleEqual(pred.shape,\n                                      (batchsize, 2))\n\n                self.assertEqual(chainer.get_device(pred.device),\n                                 chainer.get_device(\"@numpy\"))\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/models/data_parallel/test_torch.py",
    "content": "import unittest\nfrom copy import deepcopy\nimport numpy as np\n\nfrom tests.utils import check_for_torch_backend\n\n\nclass TestDataParallelTorch(unittest.TestCase):\n\n    def setUp(self) -> None:\n        if check_for_torch_backend():\n            from delira.models.backends.torch import AbstractPyTorchNetwork, \\\n                DataParallelPyTorchNetwork\n            import torch\n\n            class SimpleModel(AbstractPyTorchNetwork):\n                def __init__(self):\n                    super().__init__()\n\n                    self.dense_1 = torch.nn.Linear(3, 32)\n                    self.dense_2 = torch.nn.Linear(32, 2)\n                    self.relu = torch.nn.ReLU()\n\n                def forward(self, x):\n                    return {\"pred\": self.dense_2(self.relu(self.dense_1(x)))}\n\n            model = SimpleModel()\n\n            self.optimizer = torch.optim.Adam(model.parameters())\n\n            if torch.cuda.is_available() and torch.cuda.device_count() > 1:\n                self.model = DataParallelPyTorchNetwork(model, [0, 1])\n            else:\n                self.model = model\n\n    @unittest.skipUnless(check_for_torch_backend(),\n                         \"Test should be only executed if torch backend is \"\n                         \"installed and specified\")\n    def test_update(self):\n        import torch\n\n        input_tensor = torch.rand(10, 3)\n        label_tensor = torch.rand(10, 2)\n\n        model_copy = deepcopy(self.model)\n\n        preds = self.model(input_tensor)\n\n        loss = (preds[\"pred\"] - label_tensor).sum()\n\n        self.optimizer.zero_grad()\n        loss.backward()\n        self.optimizer.step()\n\n        for orig_param, updated_param in zip(model_copy.parameters(),\n                                             self.model.parameters()):\n            self.assertFalse(\n                np.array_equal(\n                    orig_param.detach().cpu().numpy(),\n                    
updated_param.detach().cpu().numpy()))\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/models/test_abstract_models.py",
    "content": "import unittest\nimport numpy as np\nfrom ..utils import check_for_chainer_backend, check_for_torch_backend, \\\n    check_for_tf_graph_backend, check_for_tf_eager_backend, \\\n    check_for_torchscript_backend, check_for_sklearn_backend\n\n\nclass TestAbstractModels(unittest.TestCase):\n\n    @staticmethod\n    def _setup_torch(*args):\n        import torch\n        from delira.models.backends.torch import AbstractPyTorchNetwork\n\n        class Model(AbstractPyTorchNetwork):\n            def __init__(self):\n                super().__init__()\n                self.dense = torch.nn.Linear(1, 1)\n                self.relu = torch.nn.ReLU()\n\n            def forward(self, x):\n                return {\"pred\": self.relu(self.dense(x))}\n\n        return Model()\n\n    @staticmethod\n    def _setup_torchscript(*args):\n        import torch\n        from delira.models.backends.torchscript import \\\n            AbstractTorchScriptNetwork\n\n        class Model(AbstractTorchScriptNetwork):\n            def __init__(self):\n                super().__init__()\n                self.dense = torch.nn.Linear(1, 1)\n                self.relu = torch.nn.ReLU()\n\n            @torch.jit.script_method\n            def forward(self, x):\n                return {\"pred\": self.relu(self.dense(x))}\n\n        return Model()\n\n    @staticmethod\n    def _setup_tfeager(*args):\n        import tensorflow as tf\n        tf.enable_eager_execution()\n        tf.reset_default_graph()\n        from delira.models.backends.tf_eager import AbstractTfEagerNetwork\n\n        class Model(AbstractTfEagerNetwork):\n            def __init__(self):\n                super().__init__()\n\n                self.dense = tf.keras.layers.Dense(1, activation=\"relu\")\n\n            def call(self, x: tf.Tensor):\n                return {\"pred\": self.dense(x)}\n\n        return Model()\n\n    @staticmethod\n    def _setup_tfgraph(*args):\n        import tensorflow as tf\n        
tf.disable_eager_execution()\n        tf.reset_default_graph()\n        from delira.models import AbstractTfGraphNetwork\n        from delira.training.backends.tf_graph.utils import \\\n            initialize_uninitialized\n\n        class Model(AbstractTfGraphNetwork):\n            def __init__(self):\n                super().__init__()\n                self.dense = tf.keras.layers.Dense(1, activation=\"relu\")\n\n                data = tf.placeholder(shape=[None, 1],\n                                      dtype=tf.float32)\n\n                labels = tf.placeholder_with_default(\n                    tf.zeros([tf.shape(data)[0], 1]), shape=[None, 1])\n\n                preds_train = self.dense(data)\n                preds_eval = self.dense(data)\n\n                self.inputs[\"data\"] = data\n                self.inputs[\"labels\"] = labels\n                self.outputs_train[\"pred\"] = preds_train\n                self.outputs_eval[\"pred\"] = preds_eval\n\n        model = Model()\n        initialize_uninitialized(model._sess)\n        return model\n\n    @staticmethod\n    def _setup_chainer(*args):\n        import chainer\n        from delira.models import AbstractChainerNetwork\n\n        class Model(AbstractChainerNetwork):\n            def __init__(self):\n                super().__init__()\n\n                with self.init_scope():\n                    self.dense = chainer.links.Linear(1, 1)\n\n            def forward(self, x):\n                return {\n                    \"pred\":\n                        chainer.functions.relu(\n                            self.dense(x))\n                }\n\n        return Model()\n\n    @staticmethod\n    def _setup_sklearn(*args):\n\n        from delira.models import SklearnEstimator\n        from sklearn.neural_network import MLPRegressor\n\n        class Model(SklearnEstimator):\n            def __init__(self):\n                # prefit to enable prediction mode afterwards\n                module = 
MLPRegressor()\n                module.fit(*args)\n                super().__init__(module)\n\n            @staticmethod\n            def prepare_batch(batch: dict, input_device, output_device):\n                return batch\n\n        return Model()\n\n    def run_model_arg(self, device=None):\n        prep_data = self._model.prepare_batch(self._data, input_device=device,\n                                              output_device=device)\n\n        pred = self._model(prep_data[\"data\"])\n        self.assertIsInstance(pred, dict)\n\n    def run_model_kwarg(self, device=None, keyword=\"data\"):\n        prep_data = self._model.prepare_batch(self._data, input_device=device,\n                                              output_device=device)\n\n        pred = self._model(**{keyword: prep_data[\"data\"]})\n        self.assertIsInstance(pred, dict)\n\n    def setUp(self) -> None:\n        self._data = {\"data\": np.random.rand(100, 1),\n                      \"label\": np.random.rand(100, 1)}\n\n        if \"sklearn\" in self._testMethodName.lower():\n            self._model = self._setup_sklearn(self._data[\"data\"],\n                                              self._data[\"label\"])\n\n        elif \"chainer\" in self._testMethodName.lower():\n            self._model = self._setup_chainer()\n\n        elif \"pytorch\" in self._testMethodName.lower():\n            self._model = self._setup_torch()\n\n        elif \"torchscript\" in self._testMethodName.lower():\n            self._model = self._setup_torchscript()\n\n        elif \"tf_graph\" in self._testMethodName.lower():\n            self._model = self._setup_tfgraph()\n\n        elif \"tf_eager\" in self._testMethodName.lower():\n            self._model = self._setup_tfeager()\n\n    @unittest.skipUnless(check_for_sklearn_backend(),\n                         \"Test should be only executed if sklearn backend is \"\n                         \"installed and specified\")\n    def test_sklearn(self):\n        
self.run_model_arg()\n\n    @unittest.skipUnless(check_for_chainer_backend(),\n                         \"Test should be only executed if chainer backend is \"\n                         \"installed and specified\")\n    def test_chainer(self):\n        import chainer\n        self.run_model_arg(chainer.backend.CpuDevice())\n\n    @unittest.skipUnless(check_for_torch_backend(),\n                         \"Test should be only executed if torch backend is \"\n                         \"installed and specified\")\n    def test_pytorch(self):\n        self.run_model_arg(\"cpu\")\n\n    @unittest.skipUnless(check_for_torchscript_backend(),\n                         \"Test should be only executed if torch backend is \"\n                         \"installed and specified\")\n    def test_torchscript(self):\n        self.run_model_arg(\"cpu\")\n\n    @unittest.skipUnless(check_for_tf_eager_backend(),\n                         \"Test should be only executed if tf backend is \"\n                         \"installed and specified\")\n    def test_tf_eager(self):\n        self.run_model_arg(\"/cpu:0\")\n\n    @unittest.skipUnless(check_for_tf_graph_backend(),\n                         \"Test should be only executed if tf backend is \"\n                         \"installed and specified\")\n    def test_tf_graph(self):\n\n        self.run_model_kwarg()\n\n    def tearDown(self) -> None:\n        import sys\n        import gc\n        try:\n            del sys.modules[\"tf\"]\n        except KeyError:\n            pass\n        try:\n            del tf\n        except (UnboundLocalError, NameError):\n            pass\n        try:\n            del sys.modules[\"tensorflow\"]\n        except KeyError:\n            pass\n        try:\n            del tensorflow\n        except (UnboundLocalError, NameError):\n            pass\n        gc.collect()\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/training/__init__.py",
    "content": ""
  },
  {
    "path": "tests/training/backends/__init__.py",
    "content": ""
  },
  {
    "path": "tests/training/backends/test_chainer.py",
"content": "import unittest\nfrom delira.utils import DeliraConfig\nfrom sklearn.metrics import mean_absolute_error\nfrom .utils import create_experiment_test_template_for_backend\n\nfrom tests.utils import check_for_chainer_backend\n\n\nif check_for_chainer_backend():\n    from delira.models import AbstractChainerNetwork\n    import chainer\n\n    # define this outside, because it has to be pickleable, which it won't be,\n    # when defined inside a function\n    class DummyNetworkChainer(AbstractChainerNetwork):\n        def __init__(self):\n            super().__init__()\n\n            with self.init_scope():\n                self.dense_1 = chainer.links.Linear(32, 64)\n                self.dense_2 = chainer.links.Linear(64, 1)\n\n        def forward(self, x):\n            return {\n                \"pred\":\n                    self.dense_2(chainer.functions.relu(\n                        self.dense_1(x)))\n            }\n\n\nclass TestChainerBackend(\n    create_experiment_test_template_for_backend(\"CHAINER\")\n):\n    def setUp(self) -> None:\n        if check_for_chainer_backend():\n            from delira.training import ChainerExperiment\n            import chainer\n\n            config = DeliraConfig()\n            config.fixed_params = {\n                \"model\": {},\n                \"training\": {\n                    \"losses\": {\n                        \"L1\":\n                            chainer.functions.mean_absolute_error},\n                    \"optimizer_cls\": chainer.optimizers.Adam,\n                    \"optimizer_params\": {},\n                    \"num_epochs\": 2,\n                    \"metrics\": {\"mae\": mean_absolute_error},\n                    \"lr_sched_cls\": None,\n                    \"lr_sched_params\": {}}\n            }\n            model_cls = DummyNetworkChainer\n            experiment_cls = ChainerExperiment\n\n        else:\n            config = None\n            model_cls = None\n            experiment_cls = 
None\n\n        len_train = 50\n        len_test = 50\n\n        self._test_cases = [\n            {\n                \"config\": config,\n                \"network_cls\": model_cls,\n                \"len_train\": len_train,\n                \"len_test\": len_test,\n                \"key_mapping\": {\"x\": \"data\"}\n            }\n        ]\n        self._experiment_cls = experiment_cls\n\n        super().setUp()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/training/backends/test_sklearn.py",
    "content": "import unittest\nimport numpy as np\nfrom tests.utils import check_for_sklearn_backend\nfrom delira.utils import DeliraConfig\nfrom sklearn.metrics import mean_absolute_error\nfrom .utils import create_experiment_test_template_for_backend, DummyDataset\n\n\nclass TestSklearnBackend(\n    create_experiment_test_template_for_backend(\"SKLEARN\")\n):\n    def setUp(self) -> None:\n        if check_for_sklearn_backend():\n            from delira.training import SklearnExperiment\n            from sklearn.tree import DecisionTreeClassifier\n            from sklearn.neural_network import MLPClassifier\n\n            config = DeliraConfig()\n            config.fixed_params = {\n                \"model\": {},\n                \"training\": {\n                    \"losses\": {\n                        \"L1\":\n                            mean_absolute_error},\n                    \"optimizer_cls\": None,\n                    \"optimizer_params\": {},\n                    \"num_epochs\": 2,\n                    \"metrics\": {\"mae\": mean_absolute_error},\n                    \"lr_sched_cls\": None,\n                    \"lr_sched_params\": {}}\n            }\n\n            # run tests for estimator with and without partial_fit\n            model_cls = [\n                DecisionTreeClassifier,\n                MLPClassifier\n            ]\n\n            experiment_cls = SklearnExperiment\n\n        else:\n            config = None\n            model_cls = []\n            experiment_cls = None\n\n        len_train = 50\n        len_test = 50\n\n        self._test_cases = [\n            {\n                \"config\": config,\n                \"network_cls\": _cls,\n                \"len_train\": len_train,\n                \"len_test\": len_test,\n                \"key_mapping\": {\"X\": \"X\"},\n                \"metric_keys\": {\"L1\": (\"pred\", \"y\"),\n                                \"mae\": (\"pred\", \"y\")}\n            } for _cls in model_cls\n       
 ]\n        self._experiment_cls = experiment_cls\n\n        super().setUp()\n\n    @unittest.skipUnless(check_for_sklearn_backend(),\n                         \"Test should only be executed if SKLEARN backend is \"\n                         \"installed and specified\")\n    def test_experiment_test(self):\n        from delira.data_loading import DataManager\n\n        # iterate over test cases\n        for case in self._test_cases:\n            with self.subTest(case=case):\n\n                # pop arguments (to use remaining case as kwargs later)\n                _ = case.pop(\"len_train\")\n                config = case.pop(\"config\")\n                metric_keys = case.pop(\"metric_keys\")\n                network_cls = case.pop(\"network_cls\")\n                len_test = case.pop(\"len_test\")\n                exp = self._experiment_cls(config, network_cls, **case)\n\n                # create data\n                dset_test = DummyDataset(len_test)\n                dmgr_test = DataManager(dset_test, 16, 1, None)\n\n                model = network_cls()\n\n                # must fit on 2 samples to initialize coefficients\n                model.fit(np.random.rand(2, 32), np.array([[0], [1]]))\n\n                exp.test(model, dmgr_test,\n                         config.nested_get(\"metrics\", {}),\n                         metric_keys)\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/training/backends/test_tf_eager.py",
    "content": "import unittest\nimport gc\nfrom tests.utils import check_for_tf_eager_backend\nfrom delira.utils import DeliraConfig\nfrom sklearn.metrics import mean_absolute_error\nfrom .utils import create_experiment_test_template_for_backend\n\n\nif check_for_tf_eager_backend():\n    from delira.models import AbstractTfEagerNetwork\n    import tensorflow as tf\n\n    class DummyNetworkTfEager(AbstractTfEagerNetwork):\n        def __init__(self):\n            super().__init__()\n\n            self.model = tf.keras.models.Sequential(\n                layers=[\n                    tf.keras.layers.Dense(64, input_shape=(\n                        32,), bias_initializer='glorot_uniform'),\n                    tf.keras.layers.ReLU(),\n                    tf.keras.layers.Dense(\n                        1,\n                        bias_initializer='glorot_uniform')]\n            )\n\n        def call(self, x: tf.Tensor):\n            return {\"pred\": self.model(x)}\n\n\nclass TestTfEagerBackend(\n    create_experiment_test_template_for_backend(\"TFEAGER\")\n):\n    def setUp(self) -> None:\n        if check_for_tf_eager_backend():\n            import tensorflow as tf\n            tf.enable_eager_execution()\n            from delira.training import TfEagerExperiment\n\n            config = DeliraConfig()\n            config.fixed_params = {\n                \"model\": {},\n                \"training\": {\n                    \"losses\": {\n                        \"L1\":\n                            tf.losses.absolute_difference},\n                    \"optimizer_cls\": tf.train.AdamOptimizer,\n                    \"optimizer_params\": {\"learning_rate\": 1e-3},\n                    \"num_epochs\": 2,\n                    \"metrics\": {\"mae\": mean_absolute_error},\n                    \"lr_sched_cls\": None,\n                    \"lr_sched_params\": {}}\n            }\n            model_cls = DummyNetworkTfEager\n            experiment_cls = TfEagerExperiment\n\n     
   else:\n            config = None\n            model_cls = None\n            experiment_cls = None\n\n        len_train = 100\n        len_test = 50\n\n        self._test_cases = [\n            {\n                \"config\": config,\n                \"network_cls\": model_cls,\n                \"len_train\": len_train,\n                \"len_test\": len_test,\n                \"key_mapping\": {\"x\": \"data\"},\n            }\n        ]\n        self._experiment_cls = experiment_cls\n\n        super().setUp()\n\n    def tearDown(self):\n        import sys\n        try:\n            del sys.modules[\"tf\"]\n        except KeyError:\n            pass\n        try:\n            del tf\n        except (UnboundLocalError, NameError):\n            pass\n        try:\n            del sys.modules[\"tensorflow\"]\n        except KeyError:\n            pass\n        try:\n            del tensorflow\n        except (UnboundLocalError, NameError):\n            pass\n        gc.collect()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/training/backends/test_tf_graph.py",
    "content": "import unittest\nimport gc\nfrom tests.utils import check_for_tf_graph_backend\nfrom delira.utils import DeliraConfig\nfrom sklearn.metrics import mean_absolute_error\nfrom .utils import create_experiment_test_template_for_backend\n\n\nif check_for_tf_graph_backend():\n    from delira.models import AbstractTfGraphNetwork\n    import tensorflow as tf\n\n    class DummyNetworkTfGraph(AbstractTfGraphNetwork):\n        def __init__(self):\n            super().__init__()\n\n            self.model = tf.keras.models.Sequential(\n                layers=[\n                    tf.keras.layers.Dense(64, input_shape=(\n                        32,), bias_initializer='glorot_uniform'),\n                    tf.keras.layers.ReLU(),\n                    tf.keras.layers.Dense(\n                        1,\n                        bias_initializer='glorot_uniform')]\n            )\n\n            data = tf.placeholder(shape=[None, 32], dtype=tf.float32)\n            labels = tf.placeholder_with_default(\n                tf.zeros([tf.shape(data)[0], 1]), shape=[None, 1])\n\n            preds_train = self.model(data)\n            preds_eval = self.model(data)\n\n            self.inputs[\"data\"] = data\n            self.inputs[\"label\"] = labels\n            self.outputs_train[\"pred\"] = preds_train\n            self.outputs_eval[\"pred\"] = preds_eval\n\n\nclass TestTfGraphBackend(\n    create_experiment_test_template_for_backend(\"TFGRAPH\")\n):\n    def setUp(self) -> None:\n        if check_for_tf_graph_backend():\n            import tensorflow as tf\n            tf.disable_eager_execution()\n            from delira.training import TfGraphExperiment\n\n            config = DeliraConfig()\n            config.fixed_params = {\n                \"model\": {},\n                \"training\": {\n                    \"losses\": {\n                        \"CE\":\n                            tf.losses.softmax_cross_entropy},\n                    \"optimizer_cls\": 
tf.train.AdamOptimizer,\n                    \"optimizer_params\": {\"learning_rate\": 1e-3},\n                    \"num_epochs\": 2,\n                    \"metrics\": {\"mae\": mean_absolute_error},\n                    \"lr_sched_cls\": None,\n                    \"lr_sched_params\": {}}\n            }\n            model_cls = DummyNetworkTfGraph\n            experiment_cls = TfGraphExperiment\n\n        else:\n            config = None\n            model_cls = None\n            experiment_cls = None\n\n        len_train = 100\n        len_test = 50\n\n        self._test_cases = [\n            {\n                \"config\": config,\n                \"network_cls\": model_cls,\n                \"len_train\": len_train,\n                \"len_test\": len_test,\n                \"key_mapping\": {\"data\": \"data\"},\n            }\n        ]\n        self._experiment_cls = experiment_cls\n\n        super().setUp()\n\n    def tearDown(self):\n        import sys\n        try:\n            del sys.modules[\"tf\"]\n        except KeyError:\n            pass\n        try:\n            del tf\n        except (UnboundLocalError, NameError):\n            pass\n        try:\n            del sys.modules[\"tensorflow\"]\n        except KeyError:\n            pass\n        try:\n            del tensorflow\n        except (UnboundLocalError, NameError):\n            pass\n        gc.collect()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/training/backends/test_torch.py",
    "content": "import unittest\nfrom tests.utils import check_for_torch_backend\nfrom delira.utils import DeliraConfig\nfrom sklearn.metrics import mean_absolute_error\nfrom .utils import create_experiment_test_template_for_backend\n\n\nif check_for_torch_backend():\n    from delira.models import AbstractPyTorchNetwork\n    import torch\n\n    class DummyNetworkTorch(AbstractPyTorchNetwork):\n        def __init__(self):\n            super().__init__()\n\n            self.module = torch.nn.Sequential(\n                torch.nn.Linear(32, 64),\n                torch.nn.ReLU(),\n                torch.nn.Linear(64, 1)\n            )\n\n        def forward(self, x):\n            return {\n                \"pred\":\n                    self.module(x)\n            }\n\n\nclass TestTorchBackend(\n    create_experiment_test_template_for_backend(\"TORCH\")\n):\n    def setUp(self) -> None:\n        if check_for_torch_backend():\n            import torch\n            from delira.training import PyTorchExperiment\n\n            config = DeliraConfig()\n            config.fixed_params = {\n                \"model\": {},\n                \"training\": {\n                    \"losses\": {\n                        \"L1\":\n                        torch.nn.BCEWithLogitsLoss()},\n                    \"optimizer_cls\": torch.optim.Adam,\n                    \"optimizer_params\": {},\n                    \"num_epochs\": 2,\n                    \"metrics\": {\"mae\": mean_absolute_error},\n                    \"lr_sched_cls\": None,\n                    \"lr_sched_params\": {}}\n            }\n            model_cls = DummyNetworkTorch\n            experiment_cls = PyTorchExperiment\n\n        else:\n            config = None\n            model_cls = None\n            experiment_cls = None\n\n        len_train = 100\n        len_test = 50\n\n        self._test_cases = [\n            {\n                \"config\": config,\n                \"network_cls\": model_cls,\n                
\"len_train\": len_train,\n                \"len_test\": len_test,\n                \"key_mapping\": {\"x\": \"data\"},\n            }\n        ]\n        self._experiment_cls = experiment_cls\n\n        super().setUp()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/training/backends/test_torchscript.py",
    "content": "import unittest\nfrom tests.utils import check_for_torchscript_backend\nfrom delira.utils import DeliraConfig\nfrom sklearn.metrics import mean_absolute_error\nfrom .utils import create_experiment_test_template_for_backend\n\n\nif check_for_torchscript_backend():\n    from delira.models import AbstractTorchScriptNetwork\n    import torch\n\n    class DummyNetworkTorchScript(AbstractTorchScriptNetwork):\n        __constants__ = [\"module\"]\n\n        def __init__(self):\n            super().__init__()\n\n            self.module = torch.nn.Sequential(\n                torch.nn.Linear(32, 64),\n                torch.nn.ReLU(),\n                torch.nn.Linear(64, 1)\n            )\n\n        @torch.jit.script_method\n        def forward(self, x):\n            return {\n                \"pred\":\n                    self.module(x)\n            }\n\n\nclass TestTorchScriptBackend(\n    create_experiment_test_template_for_backend(\"TORCHSCRIPT\")\n):\n    def setUp(self) -> None:\n        if check_for_torchscript_backend():\n            import torch\n            from delira.training import TorchScriptExperiment\n\n            config = DeliraConfig()\n            config.fixed_params = {\n                \"model\": {},\n                \"training\": {\n                    \"losses\": {\n                        \"L1\":\n                            torch.nn.BCEWithLogitsLoss()},\n                    \"optimizer_cls\": torch.optim.Adam,\n                    \"optimizer_params\": {},\n                    \"num_epochs\": 2,\n                    \"metrics\": {\"mae\": mean_absolute_error},\n                    \"lr_sched_cls\": None,\n                    \"lr_sched_params\": {}}\n            }\n            model_cls = DummyNetworkTorchScript\n            experiment_cls = TorchScriptExperiment\n\n        else:\n            config = None\n            model_cls = None\n            experiment_cls = None\n\n        len_train = 100\n        len_test = 50\n\n        
self._test_cases = [\n            {\n                \"config\": config,\n                \"network_cls\": model_cls,\n                \"len_train\": len_train,\n                \"len_test\": len_test,\n                \"key_mapping\": {\"x\": \"data\"},\n            }\n        ]\n        self._experiment_cls = experiment_cls\n\n        super().setUp()\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/training/backends/utils.py",
    "content": "import numpy as np\nfrom delira.data_loading import AbstractDataset, DataManager\nfrom delira.training import BaseExperiment\nfrom tests.utils import check_for_chainer_backend, \\\n    check_for_tf_eager_backend, check_for_tf_graph_backend, \\\n    check_for_sklearn_backend, check_for_torch_backend, \\\n    check_for_torchscript_backend\nimport unittest\nimport logging\n\nfrom delira.training.callbacks import AbstractCallback\n\ncallback_logger = logging.getLogger(\"CallbackLogger\")\n\n_SKIP_CONDITIONS = {\n    \"CHAINER\": check_for_chainer_backend,\n    \"TFEAGER\": check_for_tf_eager_backend,\n    \"TFGRAPH\": check_for_tf_graph_backend,\n    \"TORCH\": check_for_torch_backend,\n    \"TORCHSCRIPT\": check_for_torchscript_backend,\n    \"SKLEARN\": check_for_sklearn_backend\n}\n\n\nclass DummyDataset(AbstractDataset):\n    def __init__(self, length):\n        super().__init__(None, None)\n        self.length = length\n\n    def __getitem__(self, index):\n        return {\"data\": np.random.rand(32),\n                \"label\": np.random.randint(0, 1, 1)}\n\n    def __len__(self):\n        return self.length\n\n    def get_sample_from_index(self, index):\n        return self.__getitem__(index)\n\n\nclass LoggingCallback():\n    def at_epoch_begin(self, trainer, curr_epoch, **kwargs):\n        callback_logger.info(\"AtEpochBegin_epoch{}\".format(curr_epoch))\n        return {}\n\n    def at_epoch_end(self, trainer, curr_epoch, **kwargs):\n        callback_logger.info(\"AtEpochEnd_epoch{}\".format(curr_epoch))\n        return {}\n\n    def at_training_begin(self, trainer, **kwargs):\n        callback_logger.info(\"AtTrainingBegin_fold{}\".format(trainer.fold))\n        return {}\n\n    def at_training_end(self, trainer, **kwargs):\n        callback_logger.info(\"AtTrainingEnd_fold{}\".format(trainer.fold))\n        return {}\n\n    def at_iter_begin(self, trainer, iter_num, **kwargs):\n        
callback_logger.info(\"AtIterBegin_iter{}\".format(iter_num))\n        return {}\n\n    def at_iter_end(self, trainer, iter_num, **kwargs):\n        callback_logger.info(\"AtIterEnd_iter{}\".format(iter_num))\n        return {}\n\n\ndef add_logging_callback(dict_like):\n    callbacks = list(dict_like.pop(\"callbacks\", []))\n    callbacks.append(LoggingCallback())\n    dict_like[\"callbacks\"] = callbacks\n    return dict_like\n\n\ndef run_experiment(experiment_cls, config, network_cls, len_train, len_test,\n                   **kwargs):\n    assert issubclass(experiment_cls, BaseExperiment)\n    exp = experiment_cls(config, network_cls, **kwargs)\n\n    dset_train = DummyDataset(len_train)\n    dset_test = DummyDataset(len_test)\n\n    dmgr_train = DataManager(dset_train, 16, 4, None)\n    dmgr_test = DataManager(dset_test, 16, 1, None)\n\n    return exp.run(dmgr_train, dmgr_test)\n\n\ndef test_experiment(experiment_cls, config, network_cls, len_test, **kwargs):\n    assert issubclass(experiment_cls, BaseExperiment)\n\n    exp = experiment_cls(config, network_cls, **kwargs)\n\n    dset_test = DummyDataset(len_test)\n    dmgr_test = DataManager(dset_test, 16, 1, None)\n\n    model = network_cls()\n\n    return exp.test(model, dmgr_test, config.nested_get(\"metrics\", {}),\n                    kwargs.get(\"metric_keys\", None))\n\n\ndef kfold_experiment(experiment_cls, config, network_cls, len_data,\n                     shuffle=True, split_type=\"random\",\n                     num_splits=2, val_split=None, **kwargs):\n    assert issubclass(experiment_cls, BaseExperiment)\n\n    metric_keys = kwargs.pop(\"metric_keys\", None)\n\n    exp = experiment_cls(config, network_cls, **kwargs)\n\n    dset = DummyDataset(len_data)\n    dmgr = DataManager(dset, 16, 1, None)\n\n    return exp.kfold(data=dmgr, metrics=config.nested_get(\"metrics\"),\n                     shuffle=shuffle, split_type=split_type,\n                     num_splits=num_splits, val_split=val_split,\n   
                  metric_keys=metric_keys)\n\n\ndef create_experiment_test_template_for_backend(backend: str):\n    backend_skip = unittest.skipUnless(_SKIP_CONDITIONS[backend](),\n                                       \"Test should only be executed if \"\n                                       \"backend %s is installed and specified\"\n                                       % backend)\n\n    class TestCase(unittest.TestCase):\n\n        def setUp(self) -> None:\n            # check that the provided test case has the required attributes\n            assert hasattr(self, \"_experiment_cls\")\n            assert hasattr(self, \"_test_cases\")\n            self.logging_msg_run = [\n                'INFO:CallbackLogger:AtEpochBegin_epoch1',\n                'INFO:CallbackLogger:AtEpochEnd_epoch1',\n                'INFO:CallbackLogger:AtIterBegin_iter0',\n                'INFO:CallbackLogger:AtIterEnd_iter0',\n                'INFO:CallbackLogger:AtTrainingBegin_fold0',\n                'INFO:CallbackLogger:AtTrainingEnd_fold0',\n            ]\n            self.logging_msg_test = [\n                'INFO:CallbackLogger:AtIterBegin_iter0',\n                'INFO:CallbackLogger:AtIterEnd_iter0',\n            ]\n            self.logging_msg_kfold = [\n                'INFO:CallbackLogger:AtEpochBegin_epoch1',\n                'INFO:CallbackLogger:AtEpochEnd_epoch1',\n                'INFO:CallbackLogger:AtIterBegin_iter0',\n                'INFO:CallbackLogger:AtIterEnd_iter0',\n                'INFO:CallbackLogger:AtTrainingBegin_fold0',\n                'INFO:CallbackLogger:AtTrainingEnd_fold0',\n                'INFO:CallbackLogger:AtTrainingBegin_fold1',\n                'INFO:CallbackLogger:AtTrainingEnd_fold1',\n            ]\n\n        @backend_skip\n        def test_experiment_run(self):\n            # prototype to run an experiment once for each testcase\n            for case in self._test_cases:\n                with self.subTest(case=case):\n              
      case = add_logging_callback(case)\n                    with self.assertLogs(callback_logger, \"INFO\") as cm:\n                        run_experiment(self._experiment_cls, **case)\n\n                    for msg in self.logging_msg_run:\n                        self.assertIn(msg, cm.output)\n\n        @backend_skip\n        def test_experiment_test(self):\n            # prototype to test an experiment once with each testcase\n            for case in self._test_cases:\n                with self.subTest(case=case):\n                    _ = case.pop(\"len_train\")\n                    case = add_logging_callback(case)\n                    with self.assertLogs(callback_logger, \"INFO\") as cm:\n                        test_experiment(self._experiment_cls,\n                                        **case)\n\n                    for msg in self.logging_msg_test:\n                        self.assertIn(msg, cm.output)\n\n        @backend_skip\n        def test_experiment_kfold(self):\n            # runs multiple kfolds with each testcase\n            # ( 1 for each combination of split_type and val_split)\n            for case in self._test_cases:\n                with self.subTest(case=case):\n\n                    # combine test and train data to len_data\n                    len_data = case.pop(\"len_test\") + case.pop(\"len_train\")\n                    case[\"len_data\"] = len_data\n                    case = add_logging_callback(case)\n\n                    for split_type in [\"random\", \"stratified\", \"error\"]:\n                        with self.subTest(split_type=split_type):\n\n                            if split_type == \"error\":\n\n                                # must raise ValueError\n                                with self.assertRaises(ValueError):\n                                    kfold_experiment(\n                                        self._experiment_cls, **case,\n                                        split_type=split_type, 
num_splits=2)\n\n                                continue\n\n                            else:\n                                for val_split in [0.2, None]:\n                                    with self.subTest(val_split=val_split):\n                                        with self.assertLogs(\n                                                callback_logger, \"INFO\") as cm:\n                                            kfold_experiment(\n                                                self._experiment_cls, **case,\n                                                val_split=val_split,\n                                                split_type=split_type,\n                                                num_splits=2,\n                                            )\n\n                                        for msg in self.logging_msg_kfold:\n                                            self.assertIn(msg, cm.output)\n\n    return TestCase\n"
  },
  {
    "path": "tests/training/test_losses_torch.py",
    "content": "\nimport unittest\n\nfrom ..utils import check_for_torch_backend\n\n\nclass FocalLossTestPyTorch(unittest.TestCase):\n\n    @unittest.skipUnless(check_for_torch_backend(),\n                         reason=\"No torch backend installed\")\n    def test_focalloss(self):\n        \"\"\"\n        Test some predefines focal loss values\n        \"\"\"\n\n        from delira.training.losses import BCEFocalLossLogitPyTorch, \\\n            BCEFocalLossPyTorch\n        import torch.nn as nn\n        import torch\n        import torch.nn.functional as F\n\n        # examples\n        #######################################################################\n        # binary values\n        p = torch.Tensor([[0, 0.2, 0.5, 1.0], [0, 0.2, 0.5, 1.0]])\n        t = torch.Tensor([[0, 0, 0, 0], [1, 1, 1, 1]])\n        p_l = torch.Tensor([[-2, -1, 0, 2], [-2, -1, 0, 1]])\n\n        #######################################################################\n        # params\n        gamma = 2\n        alpha = 0.25\n        eps = 1e-8\n\n        #######################################################################\n        # compute targets\n        # target for focal loss\n        p_t = p * t + (1 - p) * (1 - t)\n        alpha_t = torch.Tensor([alpha]).expand_as(t) * t + \\\n            (1 - t) * (1 - torch.Tensor([alpha]).expand_as(t))\n        w = alpha_t * (1 - p_t).pow(torch.Tensor([gamma]))\n        fc_value = F.binary_cross_entropy(p, t, w, reduction='none')\n\n        # target for focal loss with logit\n        p_tmp = torch.sigmoid(p_l)\n        p_t = p_tmp * t + (1 - p_tmp) * (1 - t)\n        alpha_t = torch.Tensor([alpha]).expand_as(t) * t + \\\n            (1 - t) * (1 - torch.Tensor([alpha]).expand_as(t))\n        w = alpha_t * (1 - p_t).pow(torch.Tensor([gamma]))\n\n        fc_value_logit = \\\n            F.binary_cross_entropy_with_logits(p_l, t, w, reduction='none')\n\n        #######################################################################\n       
 # test against BCE and CE =>focal loss with gamma=0, alpha=None\n        # test against binary_cross_entropy\n        bce = nn.BCELoss(reduction='none')\n        focal = BCEFocalLossPyTorch(alpha=None, gamma=0, reduction='none')\n        bce_loss = bce(p, t)\n        focal_loss = focal(p, t)\n\n        self.assertTrue((torch.abs(bce_loss - focal_loss) < eps).all())\n\n        # test against binary_cross_entropy with logit\n        bce = nn.BCEWithLogitsLoss()\n        focal = BCEFocalLossLogitPyTorch(alpha=None, gamma=0)\n        bce_loss = bce(p_l, t)\n        focal_loss = focal(p_l, t)\n        self.assertTrue((torch.abs(bce_loss - focal_loss) < eps).all())\n\n        #######################################################################\n        # test focal loss with pre computed values\n        # test focal loss binary (values manually pre computed)\n        focal = BCEFocalLossPyTorch(gamma=gamma, alpha=alpha, reduction='none')\n        focal_loss = focal(p, t)\n        self.assertTrue((torch.abs(fc_value - focal_loss) < eps).all())\n\n        # test focal loss binary with logit (values manually pre computed)\n        # Note that now p_l is used as prediction\n        focal = BCEFocalLossLogitPyTorch(\n            gamma=gamma, alpha=alpha, reduction='none')\n        focal_loss = focal(p_l, t)\n        self.assertTrue((torch.abs(fc_value_logit - focal_loss) < eps).all())\n\n        #######################################################################\n        # test if backward function works\n        p.requires_grad = True\n        focal = BCEFocalLossPyTorch(gamma=gamma, alpha=alpha)\n        focal_loss = focal(p, t)\n        try:\n            focal_loss.backward()\n        except BaseException:\n            self.assertTrue(False, \"Backward function failed for focal loss\")\n\n        p_l.requires_grad = True\n        focal = BCEFocalLossLogitPyTorch(gamma=gamma, alpha=alpha)\n        focal_loss = focal(p_l, t)\n        try:\n            
focal_loss.backward()\n        except BaseException:\n            self.assertTrue(\n                False, \"Backward function failed for focal loss with logits\")\n\n\nif __name__ == \"__main__\":\n    unittest.main()\n"
  },
  {
    "path": "tests/training/test_metrics.py",
    "content": "import numpy as np\nfrom sklearn.metrics import accuracy_score\nimport unittest\n\nfrom delira.training.metrics import SklearnClassificationMetric, \\\n    SklearnAccuracyScore, AurocMetric\n\nfrom ..utils import check_for_no_backend\n\n\nclass TestMetrics(unittest.TestCase):\n    @unittest.skipUnless(\n        check_for_no_backend(),\n        \"Test should only be executed \"\n        \"if no backend is specified\")\n    def test_sklearn_classification_metric(self):\n        \"\"\"\n        Test metric wrapper for sklearn metrics\n        \"\"\"\n        target = np.array([1, 1, 1, 1, 1])\n        pred = np.array([0, 1, 0, 1, 0])\n        dummy_fn = accuracy_score\n\n        metric_wrapped = SklearnClassificationMetric(dummy_fn,\n                                                     pred_logits=False,\n                                                     gt_logits=False)\n        wrapped_score = metric_wrapped(target, pred)\n        self.assertLess(np.abs(wrapped_score - 0.4), 1e-8)\n\n        metric_ac = SklearnAccuracyScore(gt_logits=False, pred_logits=False)\n        score = metric_ac(target, pred)\n        self.assertLess(np.abs(score - 0.4), 1e-8)\n\n    @unittest.skipUnless(\n        check_for_no_backend(),\n        \"Test should only be executed \"\n        \"if no backend is specified\")\n    def test_auroc_metric(self):\n        \"\"\"\n        Test auroc metric\n        \"\"\"\n        pred = np.array([1, 1, 1, 1])\n        target = np.array([1, 0, 1, 0])\n\n        metric_auc = AurocMetric()\n        score_auc = metric_auc(target, pred)\n        self.assertEqual(score_auc, 0.5)\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/__init__.py",
    "content": "from delira import get_backends\nimport os\n\n\ndef check_for_environment_variable(variable: str, value: str):\n    if variable not in os.environ or os.environ[variable] == value:\n        return True\n    return False\n\n\ndef check_for_backend(backend_name, environment_variable):\n    backend_installed = backend_name in get_backends()\n    backend_specified = check_for_environment_variable(\"BACKEND\",\n                                                       environment_variable)\n\n    return backend_installed and backend_specified\n\n\ndef check_for_torch_backend():\n    return check_for_backend(\"TORCH\", \"Torch\")\n\n\ndef check_for_torchscript_backend():\n    return check_for_backend(\"TORCH\", \"TorchScript\")\n\n\ndef check_for_tf_eager_backend():\n    return check_for_backend(\"TF\", \"TFEager\")\n\n\ndef check_for_tf_graph_backend():\n    return check_for_backend(\"TF\", \"TFGraph\")\n\n\ndef check_for_chainer_backend():\n    return check_for_backend(\"CHAINER\", \"Chainer\")\n\n\ndef check_for_sklearn_backend():\n    return check_for_backend(\"SKLEARN\", \"Sklearn\")\n\n\ndef check_for_no_backend():\n    # sklearn backend is always installed, so this check is mainly a check if\n    # installation was successfull and checks for environment variable\n    return check_for_backend(\"SKLEARN\", \"None\")\n"
  },
  {
    "path": "tests/utils/dict_reductions.py",
    "content": "import unittest\nimport numpy as np\n\nfrom delira.utils.dict_reductions import possible_reductions, \\\n    flatten_dict, unflatten_dict, reduce_dict, get_reduction\n\n\nclass TestDictReductions(unittest.TestCase):\n    def setUp(self) -> None:\n        self._reduce_sequence = [2, 3, 4, 5, 6]\n\n        self._test_dict = {\n            \"a\": self._reduce_sequence,\n            \"b\": {\n                \"c\": self._reduce_sequence\n            },\n            \"d\": {\n                \"e\": {\n                    \"f\": self._reduce_sequence\n                }\n            }\n        }\n\n        self._flattened_test_dict = {\n            \"a\": self._reduce_sequence,\n            \"b.c\": self._reduce_sequence,\n            \"d.e.f\": self._reduce_sequence\n        }\n\n        self._reduction_results = {\"max\": max(self._reduce_sequence),\n                                   \"min\": min(self._reduce_sequence),\n                                   \"mean\": np.mean(self._reduce_sequence),\n                                   \"median\": np.median(self._reduce_sequence),\n                                   \"first\": self._reduce_sequence[0],\n                                   \"last\": self._reduce_sequence[-1]}\n\n        self._reduce_dicts = []\n        for i in self._reduce_sequence:\n            self._reduce_dicts.append(\n                {\n                    \"a\": i,\n                    \"b\": {\n                        \"c\": i\n                    },\n                    \"d\": {\n                        \"e\": {\n                            \"f\": i\n                        }\n                    }\n\n                }\n            )\n\n    def test_dict_flatten(self):\n        result_dict = flatten_dict(self._test_dict, parent_key='', sep=\".\")\n        self.assertDictEqual(result_dict, self._flattened_test_dict)\n\n    def test_dict_unflatten(self):\n        result_dict = unflatten_dict(self._flattened_test_dict, sep=\".\")\n      
  self.assertDictEqual(result_dict, self._test_dict)\n\n    def test_dict_flatten_unflatten(self):\n        result_dict = unflatten_dict(flatten_dict(self._test_dict,\n                                                  parent_key='', sep=\".\"),\n                                     sep=\".\")\n\n        self.assertDictEqual(result_dict, self._test_dict)\n\n    def test_reduction_functions(self):\n        for key in possible_reductions():\n            with self.subTest(reduce_type=key):\n                result = get_reduction(key)(self._reduce_sequence)\n\n                # convert array to scalar if necessary\n                if isinstance(result, np.ndarray):\n                    result = result.item()\n\n                self.assertEqual(result, self._reduction_results[key])\n\n    def test_reduce_dict(self):\n        for key in possible_reductions():\n            with self.subTest(reduce_type=key):\n                result_dict = reduce_dict(self._reduce_dicts,\n                                          get_reduction(key))\n\n                target_dict = {\n                    \"a\": self._reduction_results[key],\n                    \"b\": {\n                        \"c\": self._reduction_results[key]\n                    },\n                    \"d\": {\n                        \"e\": {\n                            \"f\": self._reduction_results[key]\n                        }\n                    }\n\n                }\n\n                self.assertDictEqual(result_dict, target_dict)\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/test_codecs.py",
    "content": "import unittest\nimport numpy as np\nfrom functools import partial\n\nfrom delira.utils.codecs import Encoder, Decoder\n\nfrom . import check_for_no_backend\n\n\nclass CodecsTest(unittest.TestCase):\n    @unittest.skipUnless(\n        check_for_no_backend(),\n        \"Test should only be executed if no backend is specified\")\n    def test_encoder(self):\n        test_dict = {}\n        test_dict['number'] = 1\n        test_dict['string'] = \"test_string\"\n        test_dict['list'] = [0, 1, 2, \"skjd\"]\n        test_dict['dict'] = {\"key0\": 0, \"key1\": 1, \"key2\": 2}\n        test_dict['tuple'] = (1, 2, 3)\n        test_dict['none'] = None\n        test_dict['nparray'] = np.array([0, 1, 2])\n        test_dict['function'] = partial\n        test_dict['class'] = np.ndarray\n\n        encoded_test_dict = Encoder().encode(test_dict)\n\n        self.assertTrue(encoded_test_dict['number'] == 1)\n        self.assertTrue(encoded_test_dict['string'] == \"test_string\")\n        self.assertListEqual(encoded_test_dict['list'], [0, 1, 2, \"skjd\"])\n        self.assertDictEqual(encoded_test_dict['dict'], {\n                             \"key0\": 0, \"key1\": 1, \"key2\": 2})\n        self.assertDictEqual(encoded_test_dict['tuple'], {\n            \"__convert__\": {\n                \"repr\": [1, 2, 3],\n                \"type\": {\n                    \"__type__\": {\"module\": \"builtins\", \"name\": \"tuple\"}}\n            }})\n        self.assertIsNone(encoded_test_dict[\"none\"])\n        self.assertDictEqual(encoded_test_dict[\"nparray\"],\n                             {\"__array__\": [0, 1, 2]})\n        self.assertDictEqual(encoded_test_dict[\"function\"], {\n            \"__type__\": {\"module\": \"functools\",\n                             \"name\": \"partial\"}})\n\n    @unittest.skipUnless(\n        check_for_no_backend(),\n        \"Test should only be executed if no backend is specified\")\n    def test_decoder(self):\n        test_dict = 
{}\n        test_dict['number'] = 1\n        test_dict['string'] = \"test_string\"\n        test_dict['list'] = [0, 1, 2, \"skjd\"]\n        test_dict['dict'] = {\"key0\": 0, \"key1\": 1, \"key2\": 2}\n        test_dict['tuple'] = {\"__convert__\": {\n            \"repr\": [1, 2, 3],\n            \"type\": {\"__type__\": {\"module\": \"builtins\", \"name\": \"tuple\"}}\n        }}\n        test_dict['none'] = None\n        test_dict['nparray'] = {\"__array__\": [0, 1, 2]}\n        test_dict['function'] = {\"__function__\": {\n            \"module\": \"numpy\", \"name\": \"amin\"}}\n        test_dict['class'] = {\"__type__\": {\n            \"module\": \"numpy\", \"name\": \"ndarray\"}}\n        test_dict[\"classargs\"] = {\"__classargs__\":\n                                  {\"module\": \"numpy\",\n                                   \"name\": \"ndarray\",\n                                   \"args\": [[1, 2, 3]]\n                                   }\n                                  }\n        test_dict[\"funcargs\"] = {\"__functionargs__\":\n                                 {\"module\": \"numpy\",\n                                  \"name\": \"min\",\n                                  \"kwargs\": {\"axis\": (1, 2)}}\n                                 }\n\n        decoded_dict = Decoder().decode(test_dict)\n\n        self.assertTrue(decoded_dict['number'] == 1)\n        self.assertTrue(decoded_dict['string'] == \"test_string\")\n        self.assertListEqual(decoded_dict['list'], [0, 1, 2, \"skjd\"])\n        self.assertDictEqual(decoded_dict['dict'], {\n                             \"key0\": 0, \"key1\": 1, \"key2\": 2})\n        self.assertTupleEqual(decoded_dict['tuple'], (1, 2, 3))\n        self.assertIsNone(decoded_dict[\"none\"])\n        self.assertTrue((decoded_dict[\"nparray\"] == np.array([0, 1, 2])).all())\n        self.assertTrue(\n            decoded_dict[\"function\"].__module__ == np.min.__module__)\n        self.assertTrue(\n            
decoded_dict[\"function\"].__name__ == np.min.__name__)\n        self.assertTrue(\n            decoded_dict[\"class\"].__module__ == np.ndarray.__module__)\n        self.assertTrue(\n            decoded_dict[\"class\"].__name__ == np.ndarray.__name__)\n        self.assertTrue(test_dict[\"classargs\"].shape == (1, 2, 3))\n        self.assertTrue(test_dict[\"funcargs\"].args[0] == [])\n        self.assertTrue(test_dict[\"funcargs\"].args[1][\"axis\"] == (1, 2))\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/test_config.py",
    "content": "import unittest\nimport os\nimport sys\nimport copy\nimport argparse\nfrom unittest.mock import patch\nfrom delira._version import get_versions\n\nfrom delira.utils.config import Config, LookupConfig, DeliraConfig\nfrom delira.logging import Logger, TensorboardBackend, make_logger, \\\n    register_logger\nimport warnings\n\nfrom . import check_for_no_backend\n\n\nclass ConfigTest(unittest.TestCase):\n    def setUp(self):\n        self.config_cls = Config\n        self.example_dict = {\n            \"shallowStr\": \"a\",\n            \"shallowNum\": 1,\n            \"deep\": {\"deepStr\": \"b\", \"deepNum\": 2},\n            \"nestedListOrig\": [{\"dictList\": [1, 2, 3]}],\n        }\n        self.update_dict = {\n            \"deep\": {\"deepStr\": \"c\"},\n            \"shallowNew\": 3,\n            \"deepNew\": {\"newNum\": 4},\n            \"nestedList\": [{\"dictList\": [1, 2, 3]}],\n            \"nestedList2\": [{\"dictList\": [1, 2, 3]}],\n        }\n\n        self._logger = self._setup_logger()\n        register_logger(self._logger, __file__)\n\n    def _setup_logger(self):\n        return make_logger(TensorboardBackend(\n            {\"logdir\": os.path.join(\".\", \"runs\", self._testMethodName)}\n        ))\n\n    @unittest.skipUnless(\n        check_for_no_backend(),\n        \"Test should only be executed if no backend is specified\")\n    def test_config_access(self):\n        # initialization from dict\n        cf = self.config_cls(self.example_dict)\n        self.assertEqual(cf[\"shallowStr\"], self.example_dict[\"shallowStr\"])\n        self.assertEqual(cf[\"shallowNum\"], self.example_dict[\"shallowNum\"])\n\n        # check if parameters were written correctly\n        self.assertEqual(cf[\"deep\"][\"deepStr\"],\n                         self.example_dict[\"deep\"][\"deepStr\"])\n        self.assertEqual(cf[\"deep\"][\"deepNum\"],\n                         self.example_dict[\"deep\"][\"deepNum\"])\n\n        # check deep acces 
with operators\n        self.assertEqual(cf[\"deep.deepStr\"],\n                         self.example_dict[\"deep\"][\"deepStr\"])\n        self.assertEqual(cf.deep.deepNum,\n                         self.example_dict[\"deep\"][\"deepNum\"])\n\n        # empty initialization\n        cf = self.config_cls()\n\n        # set shallow attributes\n        cf.shallowString = \"string\"\n        cf.shallowNum = 1\n        cf.deep = {}\n        cf.deep.string = \"deepString\"\n        cf.deep.num = 2\n\n        cf[\"shallowString2\"] = \"string2\"\n        cf[\"shallowNum2\"] = 1\n        cf[\"deep.string2\"] = \"deepString2\"\n        cf[\"deep.num2\"] = 2\n\n        # check if parameters were written correctly\n        self.assertEqual(cf[\"shallowString\"], \"string\")\n        self.assertEqual(cf[\"shallowNum\"], 1)\n        self.assertEqual(cf[\"deep.string\"], \"deepString\")\n        self.assertEqual(cf[\"deep.num\"], 2)\n\n        self.assertEqual(cf[\"shallowString2\"], \"string2\")\n        self.assertEqual(cf[\"shallowNum2\"], 1)\n        self.assertEqual(cf[\"deep.string2\"], \"deepString2\")\n        self.assertEqual(cf[\"deep.num2\"], 2)\n\n        # check contains operator\n        self.assertTrue(\"shallowString\" in cf)\n        self.assertTrue(\"shallowString2\" in cf)\n        self.assertTrue(\"deep.string\" in cf)\n        self.assertTrue(\"deep.string2\" in cf)\n\n        warning_msg = (\"The key 5 is not a string, but a <class 'int'>. 
\"\n                       \"This may lead to unwanted behavior!\")\n        with self.assertWarns(RuntimeWarning, msg=warning_msg):\n            cf[5] = 10\n\n    @unittest.skipUnless(\n        check_for_no_backend(),\n        \"Test should only be executed if no backend is specified\")\n    def test_config_access_with_non_existing_keys(self):\n        cf = self.config_cls(self.example_dict)\n\n        with self.assertRaises(KeyError):\n            cf[\"unknown_key\"]\n\n        with self.assertRaises(KeyError):\n            cf[\"shallowStr.unknown_key\"]\n\n    @unittest.skipUnless(\n        check_for_no_backend(),\n        \"Test should only be executed if no backend is specified\")\n    def test_update(self):\n        cf = self.config_cls.create_from_dict(self.example_dict)\n        with self.assertRaises(ValueError):\n            cf.update(self.update_dict)\n\n        # update with overwrite\n        cf.update(self.update_dict, overwrite=True)\n        self.assertEqual(cf[\"deep.deepStr\"],\n                         self.update_dict[\"deep\"][\"deepStr\"])\n\n        # add new values\n        self.assertEqual(cf[\"shallowNew\"],\n                         self.update_dict[\"shallowNew\"])\n        self.assertEqual(cf[\"deepNew.newNum\"],\n                         self.update_dict[\"deepNew\"][\"newNum\"])\n\n        # check for shallow copy\n        cf[\"nestedList\"][0][\"dictList\"][0] = 10\n        self.assertEqual(self.update_dict[\"nestedList\"][0][\"dictList\"][0],\n                         cf[\"nestedList\"][0][\"dictList\"][0])\n\n        # check for deepcopy\n        cf.update(self.update_dict, overwrite=True, deepcopy=True)\n        cf[\"nestedList2\"][0][\"dictList\"][0] = 10\n        self.assertNotEqual(self.update_dict[\"nestedList2\"][0][\"dictList\"][0],\n                            cf[\"nestedList2\"][0][\"dictList\"][0])\n\n        # check for no error when only updating nested keys\n        cf = 
self.config_cls.create_from_dict(self.example_dict)\n        update_dict = copy.deepcopy(self.update_dict)\n        update_dict[\"deep\"].pop(\"deepStr\")\n        update_dict[\"deep\"][\"deepStr2\"] = \"deepStr2\"\n        cf.update(update_dict)\n        self.assertEqual(cf[\"deep.deepStr2\"],\n                         update_dict[\"deep\"][\"deepStr2\"])\n\n    @unittest.skipUnless(\n        check_for_no_backend(),\n        \"Test should only be executed if no backend is specified\")\n    def test_dump_and_load(self):\n        cf = self.config_cls.create_from_dict(self.example_dict)\n        path = os.path.join(\".\", \"test_config.yaml\")\n        # check dump\n        cf.dump(path)\n\n        # check load\n        cf_loaded = self.config_cls()\n        cf_loaded.load(path)\n        self.assertDictEqual(cf, cf_loaded)\n\n        cf_loaded_file = self.config_cls.create_from_file(path)\n        self.assertDictEqual(cf, cf_loaded_file)\n\n        # check dump\n        cf_string = cf.dumps()\n\n        # check load\n        cf_loaded = self.config_cls()\n        cf_loaded.loads(cf_string)\n        self.assertDictEqual(cf, cf_loaded)\n\n        cf_loaded_str = self.config_cls.create_from_str(cf_string)\n        self.assertDictEqual(cf, cf_loaded_str)\n\n    @unittest.skipUnless(\n        check_for_no_backend(),\n        \"Test should only be executed if no backend is specified\")\n    def test_copy(self):\n        cf = self.config_cls.create_from_dict(self.example_dict)\n\n        # check for shallow copy\n        cf_shallow = copy.copy(cf)\n        cf_shallow[\"nestedListOrig\"][0][\"dictList\"][0] = 10\n        self.assertEqual(cf[\"nestedListOrig\"][0][\"dictList\"][0],\n                         cf_shallow[\"nestedListOrig\"][0][\"dictList\"][0])\n\n        # check for deepcopy\n        cf_deep = copy.deepcopy(cf)\n        cf_deep[\"nestedListOrig\"][0][\"dictList\"][0] = 20\n        self.assertNotEqual(cf[\"nestedListOrig\"][0][\"dictList\"][0],\n                 
           cf_deep[\"nestedListOrig\"][0][\"dictList\"][0])\n\n    @unittest.skipUnless(\n        check_for_no_backend(),\n        \"Test should only be executed if no backend is specified\")\n    def test_create_from_argparse(self):\n        parser = argparse.ArgumentParser()\n        parser.add_argument('-p1')\n        parser.add_argument('--param2')\n        cf1 = self.config_cls.create_from_argparse(\n            parser, args=['-p1', 'parameter1', '--param2', 'parameter2'])\n        self.assertEqual(cf1['p1'], 'parameter1')\n        self.assertEqual(cf1['param2'], 'parameter2')\n\n        args = parser.parse_args(\n            ['-p1', 'parameter1', '--param2', 'parameter2'])\n        self.assertEqual(cf1['p1'], 'parameter1')\n        self.assertEqual(cf1['param2'], 'parameter2')\n\n    @unittest.skipUnless(\n        check_for_no_backend(),\n        \"Test should only be executed if no backend is specified\")\n    def test_internal_type(self):\n        cf = self.config_cls.create_from_dict(self.example_dict)\n        self.assertTrue(isinstance(cf[\"deep\"], self.config_cls))\n\n    @unittest.skipUnless(\n        check_for_no_backend(),\n        \"Test should only be executed if no backend is specified\")\n    def test_create_argparser(self):\n        cf = self.config_cls.create_from_dict(self.example_dict)\n        testargs = [\n            '--shallowNum',\n            '10',\n            '--deep.deepStr',\n            'check',\n            '--testlist',\n            'ele1',\n            'ele2',\n            '--setflag']\n        parser = cf.create_argparser()\n        known, unknown = parser.parse_known_args(testargs)\n        self.assertEqual(vars(known)['shallowNum'], 10)\n        self.assertEqual(vars(known)['deep.deepStr'], 'check')\n        self.assertEqual(unknown, ['--testlist', 'ele1', 'ele2', '--setflag'])\n\n    @unittest.skipUnless(\n        check_for_no_backend(),\n        \"Test should only be executed if no backend is specified\")\n    def 
test_update_from_argparse(self):\n        cf = self.config_cls.create_from_dict(self.example_dict)\n        testargs = ['--shallowNum', '10',\n                    '--deep.deepStr', 'check',\n                    '--testlist', 'ele1', 'ele2',\n                    '--setflag']\n        # placeholder script name because the first sys.argv entry (the\n        # program name) is omitted when parsing\n        with patch.object(sys, 'argv', ['pyfile.py'] + testargs):\n            cf.update_from_argparse(add_unknown_items=True)\n        self.assertEqual(cf['shallowNum'], int(testargs[1]))\n        self.assertEqual(cf['deep']['deepStr'], testargs[3])\n        self.assertEqual(cf['testlist'], testargs[5:7])\n        self.assertEqual(cf['setflag'], True)\n        with warnings.catch_warnings(record=True) as w:\n            with patch.object(sys, 'argv', ['pyfile.py', '--unknown', 'arg']):\n                cf.update_from_argparse(add_unknown_items=False)\n        self.assertEqual(len(w), 1)\n\n\nclass LookupConfigTest(ConfigTest):\n    def setUp(self):\n        super().setUp()\n        self.config_cls = LookupConfig\n\n    @unittest.skipUnless(\n        check_for_no_backend(),\n        \"Test should only be executed if no backend is specified\")\n    def test_nested_lookup(self):\n        cf = self.config_cls.create_from_dict(self.example_dict)\n        self.assertEqual(cf[\"deep.deepStr\"],\n                         cf.nested_get(\"deep.deepStr\"))\n        self.assertEqual(cf[\"deep.deepNum\"], cf.nested_get(\"deepNum\"))\n\n        with self.assertRaises(KeyError):\n            cf.nested_get(\"nonExistingKey\")\n\n        cf[\"deepStr\"] = \"duplicate\"\n        with self.assertRaises(KeyError):\n            cf.nested_get(\"deepStr\")\n\n        self.assertIsNone(cf.nested_get(\"nonExistingKey\", None))\n        self.assertIsNone(cf.nested_get(\"nonExistingKey\", default=None))\n\n        cf[\"nested_duplicate.deep\"] = \"duplicate\"\n        with self.assertRaises(KeyError):\n            
cf.nested_get(\"deep\")\n\n        multiple_val = cf.nested_get(\"deep\", allow_multiple=True)\n\n        expected_result = [{\"deepStr\": \"b\", \"deepNum\": 2},\n                           \"duplicate\"]\n\n        for val in multiple_val:\n            self.assertIn(val, expected_result)\n            expected_result.pop(expected_result.index(val))\n\n        self.assertEquals(len(expected_result), 0)\n\n\nclass DeliraConfigTest(LookupConfigTest):\n    def setUp(self):\n        super().setUp()\n        self.config_cls = DeliraConfig\n\n    @unittest.skipUnless(\n        check_for_no_backend(),\n        \"Test should only be executed if no backend is specified\")\n    def test_property_params(self):\n        for mode in [\"fixed\", \"variable\"]:\n            cf = self.config_cls.create_from_dict({})\n            setattr(cf, \"{}_params\".format(mode),\n                    {\"model\": {\"num_classes\": 3}, \"training\": {\"epochs\": 2}})\n\n            # manual checking of values\n            self.assertEqual(cf[\"{}_model.num_classes\".format(mode)], 3)\n            self.assertEqual(cf[\"{}_training.epochs\".format(mode)], 2)\n\n            # check getter\n            params = getattr(cf, \"{}_params\".format(mode))\n            self.assertEqual(params[\"model.num_classes\"], 3)\n            self.assertEqual(params[\"training.epochs\"], 2)\n\n        for mode in [\"training\", \"model\"]:\n            cf = self.config_cls.create_from_dict(self.example_dict)\n            setattr(cf, \"{}_params\".format(mode),\n                    {\"fixed\": {\"num_classes\": 3}, \"variable\": {\"epochs\": 2}})\n\n            # manual checking of values\n            self.assertEqual(cf[\"fixed_{}.num_classes\".format(mode)], 3)\n            self.assertEqual(cf[\"variable_{}.epochs\".format(mode)], 2)\n\n            # check getter\n            params = getattr(cf, \"{}_params\".format(mode))\n            self.assertEqual(params[\"fixed.num_classes\"], 3)\n            
self.assertEqual(params[\"variable.epochs\"], 2)\n\n    @unittest.skipUnless(\n        check_for_no_backend(),\n        \"Test should only be executed if no backend is specified\")\n    def test_logging_as_string(self):\n        cf = self.config_cls()\n        cf.update({\"augment\": True})\n        cf.update({\"fixed_model\": \"fm\", \"fixed_training\": \"ft\",\n                   \"variable_model\": \"vm\", \"variable_training\": \"vt\"},\n                  overwrite=True)\n\n        cf_str = cf.log_as_string()\n        cf_str_full = cf.log_as_string(full_config=True)\n\n        self.assertEqual(cf_str,\n                         (\"__convert__:\\n\"\n                          \"  repr:\\n\"\n                          \"    _timestamp: {}\\n\"\n                          \"    fixed_model: fm\\n\"\n                          \"    fixed_training: ft\\n\"\n                          \"    variable_model: vm\\n\"\n                          \"    variable_training: vt\\n\"\n                          \"  type:\\n\"\n                          \"    __type__:\\n\"\n                          \"      module: delira.utils.config\\n\"\n                          \"      name: LookupConfig\\n\".format(\n                              cf[\"_timestamp\"])))\n\n        self.assertEqual(cf_str_full,\n                         (\"__convert__:\\n\"\n                          \"  repr:\\n\"\n                          \"    _timestamp: {}\\n\"\n                          \"    _version: {}\\n\"\n                          \"    augment: true\\n\"\n                          \"    fixed_model: fm\\n\"\n                          \"    fixed_training: ft\\n\"\n                          \"    variable_model: vm\\n\"\n                          \"    variable_training: vt\\n\"\n                          \"  type:\\n\"\n                          \"    __type__:\\n\"\n                          \"      module: delira.utils.config\\n\"\n                          \"      name: 
DeliraConfig\\n\".format(\n                              cf[\"_timestamp\"], cf[\"_version\"])))\n\n    @unittest.skipUnless(\n        check_for_no_backend(),\n        \"Test should only be executed if no backend is specified\")\n    def test_internal_type(self):\n        cf = self.config_cls.create_from_dict(self.example_dict)\n        self.assertTrue(isinstance(cf[\"deep\"], LookupConfig))\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "tests/utils/test_messenger.py",
    "content": "from delira.training import BaseExperiment, BaseNetworkTrainer, Predictor\nfrom delira.utils import DeliraConfig\nfrom delira.models import AbstractNetwork\n\nfrom delira.data_loading import DataManager\n\nfrom delira.training.utils import convert_to_numpy_identity\n\n\nfrom delira.utils.messenger import BaseMessenger, SlackMessenger\n\nfrom ..training.backends.utils import DummyDataset\n\nfrom . import check_for_no_backend\n\nimport unittest\nimport logging\nimport copy\n\nlogger = logging.getLogger(\"UnitTestMessenger\")\n\n\nclass DummyNetwork(AbstractNetwork):\n    \"\"\"\n    Emulate Network\n    \"\"\"\n\n    def __init__(self, **kwargs):\n        super().__init__()\n\n    def __call__(self, *args, **kwargs):\n        return {}\n\n    @staticmethod\n    def closure(model, data_dict: dict, optimizers: dict, losses=None,\n                metrics=None, fold=0, **kwargs):\n        return {}, {}, {}\n\n    @staticmethod\n    def prepare_batch(batch: dict, input_device, output_device):\n        return {}\n\n\nclass DummyTrainer(BaseNetworkTrainer):\n    \"\"\"\n    Emulate Trainer states\n    \"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.module = DummyNetwork()\n        callbacks = kwargs.pop(\"callbacks\", [])\n        self._setup(network=self.module, lr_scheduler_cls=None,\n                    lr_scheduler_params={}, gpu_ids=[], key_mapping={},\n                    convert_batch_to_npy_fn=convert_to_numpy_identity,\n                    prepare_batch_fn=self.module.prepare_batch,\n                    callbacks=callbacks)\n\n    def train(self, *args, num_epochs=2, **kwargs):\n        self._at_training_begin()\n        for epoch in range(self.start_epoch, num_epochs + 1):\n            self._at_epoch_begin(None, epoch, num_epochs)\n            is_best = True if epoch % 2 == 1 else False\n            self._at_epoch_end({}, None, epoch, is_best)\n        self._at_training_end()\n        
return DummyNetwork()\n\n    def test(self, *args, **kwargs):\n        return [{}], [{}]\n\n    def save_state(self, file_name, *args, **kwargs):\n        pass\n\n\nclass DummyPredictor(Predictor):\n    \"\"\"\n    Emulate predictor\n    \"\"\"\n\n    def predict(self, *args, **kwargs):\n        return {}\n\n    def predict_data_mgr(self, *args, **kwargs):\n        yield {}, {}\n        return\n\n\nclass DummyExperiment(BaseExperiment):\n    def __init__(self):\n        dummy_config = DeliraConfig()\n        dummy_config.fixed_params = {\n            \"model\": {},\n            \"training\": {\n                \"losses\": {},\n                \"optimizer_cls\": None,\n                \"optimizer_params\": {},\n                \"num_epochs\": 2,\n                \"lr_sched_cls\": None,\n                \"lr_sched_params\": {}}\n        }\n        super().__init__(dummy_config,\n                         DummyNetwork,\n                         key_mapping={},\n                         name=\"TestExperiment\",\n                         trainer_cls=DummyTrainer,\n                         predictor_cls=DummyPredictor)\n\n    def run(self, *args, raise_error=False, **kwargs):\n        if raise_error:\n            raise RuntimeError()\n        else:\n            return super().run(*args, **kwargs)\n\n    def resume(self, *args, raise_error=False, **kwargs):\n        if raise_error:\n            raise RuntimeError()\n        else:\n            return super().resume(*args, **kwargs)\n\n    def test(self, *args, raise_error=False, **kwargs):\n        if raise_error:\n            raise RuntimeError()\n        else:\n            return super().test(*args, **kwargs)\n\n    def kfold(self, *args, raise_error=False, **kwargs):\n        if raise_error:\n            raise RuntimeError()\n        else:\n            return super().kfold(*args, **kwargs)\n\n\nclass LoggingBaseMessenger(BaseMessenger):\n    def __init__(\n            self,\n            experiment,\n            
notify_epochs=None,\n            **kwargs):\n        \"\"\"\n        Test messenger for BaseMessenger\n        \"\"\"\n        super().__init__(experiment, notify_epochs=notify_epochs,\n                         **kwargs)\n\n    def emit_message(self, msg):\n        logger.info(msg)\n\n\nclass TestBaseMessenger(unittest.TestCase):\n    def setUp(self) -> None:\n        self.msg_run_successful = [\n            \"INFO:UnitTestMessenger:TestExperiment : Training started.\",\n            \"INFO:UnitTestMessenger:Epoch 1 trained.\",\n            \"INFO:UnitTestMessenger:Epoch 2 trained.\",\n            \"INFO:UnitTestMessenger:TestExperiment : Training completed.\",\n        ]\n        self.msg_run_failed = [\n            \"INFO:UnitTestMessenger:TestExperiment : Training started.\",\n            \"INFO:UnitTestMessenger:TestExperiment : Training failed. \\n\",\n        ]\n        # self.msg_resume_successful = []\n        # self.msg_resume_failed = []\n        self.msg_test_successful = [\n            \"INFO:UnitTestMessenger:TestExperiment : Test started.\",\n            \"INFO:UnitTestMessenger:TestExperiment : Test completed.\",\n        ]\n        self.msg_test_failed = [\n            \"INFO:UnitTestMessenger:TestExperiment : Test started.\",\n            \"INFO:UnitTestMessenger:TestExperiment : Test failed. 
\\n\",\n        ]\n        self.msg_kfold_successful = [\n            \"INFO:UnitTestMessenger:TestExperiment : Kfold started.\",\n            \"INFO:UnitTestMessenger:Fold 0 started.\",\n            \"INFO:UnitTestMessenger:Epoch 1 trained.\",\n            \"INFO:UnitTestMessenger:Epoch 2 trained.\",\n            \"INFO:UnitTestMessenger:Fold 0 completed.\",\n            \"INFO:UnitTestMessenger:Fold 1 started.\",\n            \"INFO:UnitTestMessenger:Epoch 1 trained.\",\n            \"INFO:UnitTestMessenger:Epoch 2 trained.\",\n            \"INFO:UnitTestMessenger:Fold 1 completed.\",\n            \"INFO:UnitTestMessenger:TestExperiment : Kfold completed.\",\n        ]\n        self.msg_kfold_failed = [\n            \"INFO:UnitTestMessenger:TestExperiment : Kfold started.\",\n            \"INFO:UnitTestMessenger:TestExperiment : Kfold failed. \\n\",\n        ]\n\n        self.msg_create_experiment = []\n\n        self.messenger_cls = LoggingBaseMessenger\n        self.messenger_kwargs = {\"notify_epochs\": 1}\n        self.run_kwargs = {\"gpu_ids\": [], \"logging_type\": \"tensorboardX\",\n                           \"logging_kwargs\": {}, \"fold\": 3}\n\n    def create_experiment(self, expected_msg=None):\n        with self.assertLogs(logger, level='INFO') as cm:\n            dummy_exp = DummyExperiment()\n            dummy_exp = self.messenger_cls(dummy_exp, **self.messenger_kwargs)\n\n            if expected_msg is None or not expected_msg:\n                logger.info(\"NoExpectedMessage\")\n\n        if expected_msg is None or not expected_msg:\n            self.assertEqual(cm.output,\n                             [\"INFO:UnitTestMessenger:NoExpectedMessage\"])\n        else:\n            self.assertEqual(cm.output, expected_msg)\n\n    def run_experiment(self, raise_error=False, expected_msg=None):\n        dummy_exp = DummyExperiment()\n        dummy_exp = self.messenger_cls(dummy_exp, **self.messenger_kwargs)\n\n        dset_train = DummyDataset(10)\n     
   dset_test = DummyDataset(10)\n\n        dmgr_train = DataManager(dset_train, 2, 1, None)\n        dmgr_test = DataManager(dset_test, 2, 1, None)\n\n        with self.assertLogs(logger, level='INFO') as cm:\n            if raise_error:\n                with self.assertRaises(RuntimeError):\n                    dummy_exp.run(dmgr_train, dmgr_test,\n                                  raise_error=True, **self.run_kwargs)\n            else:\n                dummy_exp.run(dmgr_train, dmgr_test, raise_error=False,\n                              **self.run_kwargs,)\n\n            if expected_msg is None or not expected_msg:\n                logger.info(\"NoExpectedMessage\")\n\n        if expected_msg is None or not expected_msg:\n            self.assertEqual(cm.output,\n                             [\"INFO:UnitTestMessenger:NoExpectedMessage\"])\n        else:\n            self.assertEqual(cm.output, expected_msg)\n\n    def t_experiment(self, raise_error=False, expected_msg=None):\n        dummy_exp = DummyExperiment()\n        dummy_exp = self.messenger_cls(dummy_exp, **self.messenger_kwargs)\n\n        dset_test = DummyDataset(10)\n        dmgr_test = DataManager(dset_test, 2, 1, None)\n\n        model = DummyNetwork()\n\n        with self.assertLogs(logger, level='INFO') as cm:\n            if raise_error:\n                with self.assertRaises(RuntimeError):\n                    dummy_exp.test(model, dmgr_test, {},\n                                   raise_error=True)\n            else:\n                dummy_exp.test(model, dmgr_test, {}, raise_error=False)\n\n            if expected_msg is None or not expected_msg:\n                logger.info(\"NoExpectedMessage\")\n\n        if expected_msg is None or not expected_msg:\n            self.assertEqual(cm.output,\n                             [\"INFO:UnitTestMessenger:NoExpectedMessage\"])\n        else:\n            self.assertEqual(cm.output, expected_msg)\n\n    def kfold_experiment(self, raise_error=False, 
expected_msg=None):\n        kfold_kwargs = copy.deepcopy(self.run_kwargs)\n        kfold_kwargs.pop(\"fold\")\n\n        dummy_exp = DummyExperiment()\n        dummy_exp = self.messenger_cls(dummy_exp, **self.messenger_kwargs)\n\n        dset = DummyDataset(10)\n        dmgr = DataManager(dset, 2, 1, None)\n\n        with self.assertLogs(logger, level='INFO') as cm:\n            if raise_error:\n                with self.assertRaises(RuntimeError):\n                    dummy_exp.kfold(data=dmgr, metrics={}, num_splits=2,\n                                    raise_error=True, **kfold_kwargs)\n            else:\n                dummy_exp.kfold(data=dmgr, metrics={}, num_splits=2,\n                                raise_error=False, **kfold_kwargs)\n\n            if expected_msg is None:\n                logger.info(\"NoExpectedMessage\")\n\n        if expected_msg is None:\n            self.assertEqual(cm.output,\n                             [\"INFO:UnitTestMessenger:NoExpectedMessage\"])\n        else:\n            self.assertEqual(cm.output, expected_msg)\n\n    @unittest.skipUnless(\n        check_for_no_backend(),\n        \"Test should only be executed \"\n        \"if no backend is installed\")\n    def test_create_experiment(self):\n        self.create_experiment(self.msg_create_experiment)\n\n    @unittest.skipUnless(\n        check_for_no_backend(),\n        \"Test should only be executed \"\n        \"if no backend is installed\")\n    def test_run_successful(self):\n        self.run_experiment(raise_error=False,\n                            expected_msg=self.msg_run_successful)\n\n    @unittest.skipUnless(\n        check_for_no_backend(),\n        \"Test should only be executed \"\n        \"if no backend is installed\")\n    def test_run_failed(self):\n        self.run_experiment(raise_error=True,\n                            expected_msg=self.msg_run_failed)\n\n    @unittest.skipUnless(\n        check_for_no_backend(),\n        \"Test should only be 
executed \"\n        \"if no backend is installed\")\n    def test_test_successful(self):\n        self.t_experiment(raise_error=False,\n                          expected_msg=self.msg_test_successful)\n\n    @unittest.skipUnless(\n        check_for_no_backend(),\n        \"Test should only be executed \"\n        \"if no backend is installed\")\n    def test_test_failed(self):\n        self.t_experiment(raise_error=True,\n                          expected_msg=self.msg_test_failed)\n\n    @unittest.skipUnless(\n        check_for_no_backend(),\n        \"Test should only be executed \"\n        \"if no backend is installed\")\n    def test_kfold_successful(self):\n        self.kfold_experiment(raise_error=False,\n                              expected_msg=self.msg_kfold_successful)\n\n    @unittest.skipUnless(\n        check_for_no_backend(),\n        \"Test should only be executed \"\n        \"if no backend is installed\")\n    def test_kfold_failed(self):\n        self.kfold_experiment(raise_error=True,\n                              expected_msg=self.msg_kfold_failed)\n\n\nclass LoggingSlackMessenger(SlackMessenger):\n    def emit_message(self, msg):\n        logger.info(msg)\n        return {}\n\n\nclass TestSlackMessenger(TestBaseMessenger):\n    def setUp(self) -> None:\n        super().setUp()\n\n        self.msg_create_experiment = [\n            \"INFO:UnitTestMessenger:Created new experiment: TestExperiment\",\n        ]\n\n        self.messenger_cls = LoggingSlackMessenger\n        self.messenger_kwargs = {\"notify_epochs\": 1, \"token\": \"dummyToken\",\n                                 \"channel\": \"dummyChannel\"}\n\n\nif __name__ == '__main__':\n    unittest.main()\n"
  },
  {
    "path": "versioneer.py",
    "content": "\n# Version: 0.18\n\n\"\"\"The Versioneer - like a rocketeer, but for versions.\n\nThe Versioneer\n==============\n\n* like a rocketeer, but for versions!\n* https://github.com/warner/python-versioneer\n* Brian Warner\n* License: Public Domain\n* Compatible With: python2.6, 2.7, 3.2, 3.3, 3.4, 3.5, 3.6, and pypy\n* [![Latest Version]\n(https://pypip.in/version/versioneer/badge.svg?style=flat)\n](https://pypi.python.org/pypi/versioneer/)\n* [![Build Status]\n(https://travis-ci.org/warner/python-versioneer.png?branch=master)\n](https://travis-ci.org/warner/python-versioneer)\n\nThis is a tool for managing a recorded version number in distutils-based\npython projects. The goal is to remove the tedious and error-prone \"update\nthe embedded version string\" step from your release process. Making a new\nrelease should be as easy as recording a new tag in your version-control\nsystem, and maybe making new tarballs.\n\n\n## Quick Install\n\n* `pip install versioneer` to somewhere to your $PATH\n* add a `[versioneer]` section to your setup.cfg (see below)\n* run `versioneer install` in your source tree, commit the results\n\n## Version Identifiers\n\nSource trees come from a variety of places:\n\n* a version-control system checkout (mostly used by developers)\n* a nightly tarball, produced by build automation\n* a snapshot tarball, produced by a web-based VCS browser, like github's\n  \"tarball from tag\" feature\n* a release tarball, produced by \"setup.py sdist\", distributed through PyPI\n\nWithin each source tree, the version identifier (either a string or a number,\nthis tool is format-agnostic) can come from a variety of places:\n\n* ask the VCS tool itself, e.g. 
\"git describe\" (for checkouts), which knows\n  about recent \"tags\" and an absolute revision-id\n* the name of the directory into which the tarball was unpacked\n* an expanded VCS keyword ($Id$, etc)\n* a `_version.py` created by some earlier build step\n\nFor released software, the version identifier is closely related to a VCS\ntag. Some projects use tag names that include more than just the version\nstring (e.g. \"myproject-1.2\" instead of just \"1.2\"), in which case the tool\nneeds to strip the tag prefix to extract the version identifier. For\nunreleased software (between tags), the version identifier should provide\nenough information to help developers recreate the same tree, while also\ngiving them an idea of roughly how old the tree is (after version 1.2, before\nversion 1.3). Many VCS systems can report a description that captures this,\nfor example `git describe --tags --dirty --always` reports things like\n\"0.7-1-g574ab98-dirty\" to indicate that the checkout is one revision past the\n0.7 tag, has a unique revision id of \"574ab98\", and is \"dirty\" (it has\nuncommitted changes.\n\nThe version identifier is used for multiple purposes:\n\n* to allow the module to self-identify its version: `myproject.__version__`\n* to choose a name and prefix for a 'setup.py sdist' tarball\n\n## Theory of Operation\n\nVersioneer works by adding a special `_version.py` file into your source\ntree, where your `__init__.py` can import it. This `_version.py` knows how to\ndynamically ask the VCS tool for version information at import time.\n\n`_version.py` also contains `$Revision$` markers, and the installation\nprocess marks `_version.py` to have this marker rewritten with a tag name\nduring the `git archive` command. 
As a result, generated tarballs will\ncontain enough information to get the proper version.\n\nTo allow `setup.py` to compute a version too, a `versioneer.py` is added to\nthe top level of your source tree, next to `setup.py` and the `setup.cfg`\nthat configures it. This overrides several distutils/setuptools commands to\ncompute the version when invoked, and changes `setup.py build` and `setup.py\nsdist` to replace `_version.py` with a small static file that contains just\nthe generated version data.\n\n## Installation\n\nSee [INSTALL.md](./INSTALL.md) for detailed installation instructions.\n\n## Version-String Flavors\n\nCode which uses Versioneer can learn about its version string at runtime by\nimporting `_version` from your main `__init__.py` file and running the\n`get_versions()` function. From the \"outside\" (e.g. in `setup.py`), you can\nimport the top-level `versioneer.py` and run `get_versions()`.\n\nBoth functions return a dictionary with different flavors of version\ninformation:\n\n* `['version']`: A condensed version string, rendered using the selected\n  style. This is the most commonly used value for the project's version\n  string. The default \"pep440\" style yields strings like `0.11`,\n  `0.11+2.g1076c97`, or `0.11+2.g1076c97.dirty`. See the \"Styles\" section\n  below for alternative styles.\n\n* `['full-revisionid']`: detailed revision identifier. For Git, this is the\n  full SHA1 commit id, e.g. \"1076c978a8d3cfc70f408fe5974aa6c092c949ac\".\n\n* `['date']`: Date and time of the latest `HEAD` commit. For Git, it is the\n  commit date in ISO 8601 format. This will be None if the date is not\n  available.\n\n* `['dirty']`: a boolean, True if the tree has uncommitted changes. Note that\n  this is only accurate if run in a VCS checkout, otherwise it is likely to\n  be False or None\n\n* `['error']`: if the version string could not be computed, this will be set\n  to a string describing the problem, otherwise it will be None. 
It may be\n  useful to throw an exception in setup.py if this is set, to avoid e.g.\n  creating tarballs with a version string of \"unknown\".\n\nSome variants are more useful than others. Including `full-revisionid` in a\nbug report should allow developers to reconstruct the exact code being tested\n(or indicate the presence of local changes that should be shared with the\ndevelopers). `version` is suitable for display in an \"about\" box or a CLI\n`--version` output: it can be easily compared against release notes and lists\nof bugs fixed in various releases.\n\nThe installer adds the following text to your `__init__.py` to place a basic\nversion in `YOURPROJECT.__version__`:\n\n    from ._version import get_versions\n    __version__ = get_versions()['version']\n    del get_versions\n\n## Styles\n\nThe setup.cfg `style=` configuration controls how the VCS information is\nrendered into a version string.\n\nThe default style, \"pep440\", produces a PEP440-compliant string, equal to the\nun-prefixed tag name for actual releases, and containing an additional \"local\nversion\" section with more detail for in-between builds. For Git, this is\nTAG[+DISTANCE.gHEX[.dirty]] , using information from `git describe --tags\n--dirty --always`. For example \"0.11+2.g1076c97.dirty\" indicates that the\ntree is like the \"1076c97\" commit but has uncommitted changes (\".dirty\"), and\nthat this commit is two revisions (\"+2\") beyond the \"0.11\" tag. For released\nsoftware (exactly equal to a known tag), the identifier will only contain the\nstripped tag, e.g. \"0.11\".\n\nOther styles are available. See [details.md](details.md) in the Versioneer\nsource tree for descriptions.\n\n## Debugging\n\nVersioneer tries to avoid fatal errors: if something goes wrong, it will tend\nto return a version of \"0+unknown\". 
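Rather than shipping such a fallback by accident, a `setup.py` can fail fast
when the `error` flavor is set. A minimal sketch (the `check_versions` helper
is hypothetical; in a real `setup.py` it would be fed the dictionary returned
by `versioneer.get_versions()`):

```python
def check_versions(info):
    """Raise if Versioneer could not compute a version string.

    `info` is a get_versions()-style dict with 'version' and 'error' keys.
    """
    if info.get("error"):
        raise RuntimeError("version lookup failed: %s" % info["error"])
    return info["version"]


# A healthy checkout yields an error of None and a usable version string:
print(check_versions({"version": "0.11+2.g1076c97", "error": None}))
# prints 0.11+2.g1076c97
```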
To investigate the problem, run `setup.py\nversion`, which will run the version-lookup code in a verbose mode, and will\ndisplay the full contents of `get_versions()` (including the `error` string,\nwhich may help identify what went wrong).\n\n## Known Limitations\n\nSome situations are known to cause problems for Versioneer. This section\ndetails the most significant ones. More can be found on the GitHub\n[issues page](https://github.com/warner/python-versioneer/issues).\n\n### Subprojects\n\nVersioneer has limited support for source trees in which `setup.py` is not in\nthe root directory (e.g. `setup.py` and `.git/` are *not* siblings). There\nare two common reasons why `setup.py` might not be in the root:\n\n* Source trees which contain multiple subprojects, such as\n  [Buildbot](https://github.com/buildbot/buildbot), which contains both\n  \"master\" and \"slave\" subprojects, each with their own `setup.py`,\n  `setup.cfg`, and `tox.ini`. Projects like these produce multiple PyPI\n  distributions (and upload multiple independently-installable tarballs).\n* Source trees whose main purpose is to contain a C library, but which also\n  provide bindings to Python (and perhaps other languages) in subdirectories.\n\nVersioneer will look for `.git` in parent directories, and most operations\nshould get the right version string. However `pip` and `setuptools` have bugs\nand implementation details which frequently cause `pip install .` from a\nsubproject directory to fail to find a correct version string (so it usually\ndefaults to `0+unknown`).\n\n`pip install --editable .` should work correctly. `setup.py install` might\nwork too.\n\nPip-8.1.1 is known to have this problem, but hopefully it will get fixed in\nsome later version.\n\n[Bug #38](https://github.com/warner/python-versioneer/issues/38) is tracking\nthis issue. 
The discussion in\n[PR #61](https://github.com/warner/python-versioneer/pull/61) describes the\nissue from the Versioneer side in more detail.\n[pip PR#3176](https://github.com/pypa/pip/pull/3176) and\n[pip PR#3615](https://github.com/pypa/pip/pull/3615) contain work to improve\npip to let Versioneer work correctly.\n\nVersioneer-0.16 and earlier only looked for a `.git` directory next to the\n`setup.cfg`, so subprojects were completely unsupported with those releases.\n\n### Editable installs with setuptools <= 18.5\n\n`setup.py develop` and `pip install --editable .` allow you to install a\nproject into a virtualenv once, then continue editing the source code (and\ntest) without re-installing after every change.\n\n\"Entry-point scripts\" (`setup(entry_points={\"console_scripts\": ..})`) are a\nconvenient way to specify executable scripts that should be installed along\nwith the python package.\n\nThese both work as expected when using modern setuptools. When using\nsetuptools-18.5 or earlier, however, certain operations will cause\n`pkg_resources.DistributionNotFound` errors when running the entrypoint\nscript, which must be resolved by re-installing the package. This happens\nwhen the install happens with one version, then the egg_info data is\nregenerated while a different version is checked out. Many setup.py commands\ncause egg_info to be rebuilt (including `sdist`, `wheel`, and installing into\na different virtualenv), so this can be surprising.\n\n[Bug #83](https://github.com/warner/python-versioneer/issues/83) describes\nthis one, but upgrading to a newer version of setuptools should probably\nresolve it.\n\n### Unicode version strings\n\nWhile Versioneer works (and is continually tested) with both Python 2 and\nPython 3, it is not entirely consistent with bytes-vs-unicode distinctions.\nNewer releases probably generate unicode version strings on py2. 
It's not\nclear that this is wrong, but it may be surprising for applications when they\nwrite these strings to a network connection or include them in bytes-oriented\nAPIs like cryptographic checksums.\n\n[Bug #71](https://github.com/warner/python-versioneer/issues/71) investigates\nthis question.\n\n\n## Updating Versioneer\n\nTo upgrade your project to a new release of Versioneer, do the following:\n\n* install the new Versioneer (`pip install -U versioneer` or equivalent)\n* edit `setup.cfg`, if necessary, to include any new configuration settings\n  indicated by the release notes. See [UPGRADING](./UPGRADING.md) for details.\n* re-run `versioneer install` in your source tree, to replace\n  `SRC/_version.py`\n* commit any changed files\n\n## Future Directions\n\nThis tool is designed to be easily extended to other version-control\nsystems: all VCS-specific components are in separate directories like\nsrc/git/ . The top-level `versioneer.py` script is assembled from these\ncomponents by running make-versioneer.py . In the future, make-versioneer.py\nwill take a VCS name as an argument, and will construct a version of\n`versioneer.py` that is specific to the given VCS. It might also take the\nconfiguration arguments that are currently provided manually during\ninstallation by editing setup.py . Alternatively, it might go the other\ndirection and include code from all supported VCS systems, reducing the\nnumber of intermediate scripts.\n\n\n## License\n\nTo make Versioneer easier to embed, all its code is dedicated to the public\ndomain. 
The `_version.py` that it creates is also in the public domain.\nSpecifically, both are released under the Creative Commons \"Public Domain\nDedication\" license (CC0-1.0), as described in\nhttps://creativecommons.org/publicdomain/zero/1.0/ .\n\n\"\"\"\n\nfrom __future__ import print_function\ntry:\n    import configparser\nexcept ImportError:\n    import ConfigParser as configparser\nimport errno\nimport json\nimport os\nimport re\nimport subprocess\nimport sys\n\n\nclass VersioneerConfig:\n    \"\"\"Container for Versioneer configuration parameters.\"\"\"\n\n\ndef get_root():\n    \"\"\"Get the project root directory.\n\n    We require that all commands are run from the project root, i.e. the\n    directory that contains setup.py, setup.cfg, and versioneer.py .\n    \"\"\"\n    root = os.path.realpath(os.path.abspath(os.getcwd()))\n    setup_py = os.path.join(root, \"setup.py\")\n    versioneer_py = os.path.join(root, \"versioneer.py\")\n    if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)):\n        # allow 'python path/to/setup.py COMMAND'\n        root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0])))\n        setup_py = os.path.join(root, \"setup.py\")\n        versioneer_py = os.path.join(root, \"versioneer.py\")\n    if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)):\n        err = (\"Versioneer was unable to find the project root directory. 
\"\n               \"Versioneer requires setup.py to be executed from \"\n               \"its immediate directory (like 'python setup.py COMMAND'), \"\n               \"or in a way that lets it use sys.argv[0] to find the root \"\n               \"(like 'python path/to/setup.py COMMAND').\")\n        raise VersioneerBadRootError(err)\n    try:\n        # Certain runtime workflows (setup.py install/develop in a setuptools\n        # tree) execute all dependencies in a single python process, so\n        # \"versioneer\" may be imported multiple times, and python's shared\n        # module-import table will cache the first one. So we can't use\n        # os.path.dirname(__file__), as that will find whichever\n        # versioneer.py was first imported, even in later projects.\n        me = os.path.realpath(os.path.abspath(__file__))\n        me_dir = os.path.normcase(os.path.splitext(me)[0])\n        vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0])\n        if me_dir != vsr_dir:\n            print(\"Warning: build in %s is using versioneer.py from %s\"\n                  % (os.path.dirname(me), versioneer_py))\n    except NameError:\n        pass\n    return root\n\n\ndef get_config_from_root(root):\n    \"\"\"Read the project setup.cfg file to determine Versioneer config.\"\"\"\n    # This might raise EnvironmentError (if setup.cfg is missing), or\n    # configparser.NoSectionError (if it lacks a [versioneer] section), or\n    # configparser.NoOptionError (if it lacks \"VCS=\"). 
See the docstring at\n    # the top of versioneer.py for instructions on writing your setup.cfg .\n    setup_cfg = os.path.join(root, \"setup.cfg\")\n    parser = configparser.SafeConfigParser()\n    with open(setup_cfg, \"r\") as f:\n        parser.readfp(f)\n    VCS = parser.get(\"versioneer\", \"VCS\")  # mandatory\n\n    def get(parser, name):\n        if parser.has_option(\"versioneer\", name):\n            return parser.get(\"versioneer\", name)\n        return None\n    cfg = VersioneerConfig()\n    cfg.VCS = VCS\n    cfg.style = get(parser, \"style\") or \"\"\n    cfg.versionfile_source = get(parser, \"versionfile_source\")\n    cfg.versionfile_build = get(parser, \"versionfile_build\")\n    cfg.tag_prefix = get(parser, \"tag_prefix\")\n    if cfg.tag_prefix in (\"''\", '\"\"'):\n        cfg.tag_prefix = \"\"\n    cfg.parentdir_prefix = get(parser, \"parentdir_prefix\")\n    cfg.verbose = get(parser, \"verbose\")\n    return cfg\n\n\nclass NotThisMethod(Exception):\n    \"\"\"Exception raised if a method is not valid for the current scenario.\"\"\"\n\n\n# these dictionaries contain VCS-specific tools\nLONG_VERSION_PY = {}\nHANDLERS = {}\n\n\ndef register_vcs_handler(vcs, method):  # decorator\n    \"\"\"Decorator to mark a method as the handler for a particular VCS.\"\"\"\n    def decorate(f):\n        \"\"\"Store f in HANDLERS[vcs][method].\"\"\"\n        if vcs not in HANDLERS:\n            HANDLERS[vcs] = {}\n        HANDLERS[vcs][method] = f\n        return f\n    return decorate\n\n\ndef run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,\n                env=None):\n    \"\"\"Call the given command(s).\"\"\"\n    assert isinstance(commands, list)\n    p = None\n    for c in commands:\n        try:\n            dispcmd = str([c] + args)\n            # remember shell=False, so use git.cmd on windows, not just git\n            p = subprocess.Popen([c] + args, cwd=cwd, env=env,\n                                 
stdout=subprocess.PIPE,\n                                 stderr=(subprocess.PIPE if hide_stderr\n                                         else None))\n            break\n        except EnvironmentError:\n            e = sys.exc_info()[1]\n            if e.errno == errno.ENOENT:\n                continue\n            if verbose:\n                print(\"unable to run %s\" % dispcmd)\n                print(e)\n            return None, None\n    else:\n        if verbose:\n            print(\"unable to find command, tried %s\" % (commands,))\n        return None, None\n    stdout = p.communicate()[0].strip()\n    if sys.version_info[0] >= 3:\n        stdout = stdout.decode()\n    if p.returncode != 0:\n        if verbose:\n            print(\"unable to run %s (error)\" % dispcmd)\n            print(\"stdout was %s\" % stdout)\n        return None, p.returncode\n    return stdout, p.returncode\n\n\nLONG_VERSION_PY['git'] = '''\n# This file helps to compute a version number in source trees obtained from\n# git-archive tarball (such as those provided by githubs download-from-tag\n# feature). Distribution tarballs (built by setup.py sdist) and build\n# directories (produced by setup.py build) will contain a much shorter file\n# that just contains the computed version number.\n\n# This file is released into the public domain. Generated by\n# versioneer-0.18 (https://github.com/warner/python-versioneer)\n\n\"\"\"Git implementation of _version.py.\"\"\"\n\nimport errno\nimport os\nimport re\nimport subprocess\nimport sys\n\n\ndef get_keywords():\n    \"\"\"Get the keywords needed to look up the version information.\"\"\"\n    # these strings will be replaced by git during git-archive.\n    # setup.py/versioneer.py will grep for the variable names, so they must\n    # each be defined on a line of their own. 
_version.py will just call\n    # get_keywords().\n    git_refnames = \"%(DOLLAR)sFormat:%%d%(DOLLAR)s\"\n    git_full = \"%(DOLLAR)sFormat:%%H%(DOLLAR)s\"\n    git_date = \"%(DOLLAR)sFormat:%%ci%(DOLLAR)s\"\n    keywords = {\"refnames\": git_refnames, \"full\": git_full, \"date\": git_date}\n    return keywords\n\n\nclass VersioneerConfig:\n    \"\"\"Container for Versioneer configuration parameters.\"\"\"\n\n\ndef get_config():\n    \"\"\"Create, populate and return the VersioneerConfig() object.\"\"\"\n    # these strings are filled in when 'setup.py versioneer' creates\n    # _version.py\n    cfg = VersioneerConfig()\n    cfg.VCS = \"git\"\n    cfg.style = \"%(STYLE)s\"\n    cfg.tag_prefix = \"%(TAG_PREFIX)s\"\n    cfg.parentdir_prefix = \"%(PARENTDIR_PREFIX)s\"\n    cfg.versionfile_source = \"%(VERSIONFILE_SOURCE)s\"\n    cfg.verbose = False\n    return cfg\n\n\nclass NotThisMethod(Exception):\n    \"\"\"Exception raised if a method is not valid for the current scenario.\"\"\"\n\n\nLONG_VERSION_PY = {}\nHANDLERS = {}\n\n\ndef register_vcs_handler(vcs, method):  # decorator\n    \"\"\"Decorator to mark a method as the handler for a particular VCS.\"\"\"\n    def decorate(f):\n        \"\"\"Store f in HANDLERS[vcs][method].\"\"\"\n        if vcs not in HANDLERS:\n            HANDLERS[vcs] = {}\n        HANDLERS[vcs][method] = f\n        return f\n    return decorate\n\n\ndef run_command(commands, args, cwd=None, verbose=False, hide_stderr=False,\n                env=None):\n    \"\"\"Call the given command(s).\"\"\"\n    assert isinstance(commands, list)\n    p = None\n    for c in commands:\n        try:\n            dispcmd = str([c] + args)\n            # remember shell=False, so use git.cmd on windows, not just git\n            p = subprocess.Popen([c] + args, cwd=cwd, env=env,\n                                 stdout=subprocess.PIPE,\n                                 stderr=(subprocess.PIPE if hide_stderr\n                                         else 
None))\n            break\n        except EnvironmentError:\n            e = sys.exc_info()[1]\n            if e.errno == errno.ENOENT:\n                continue\n            if verbose:\n                print(\"unable to run %%s\" %% dispcmd)\n                print(e)\n            return None, None\n    else:\n        if verbose:\n            print(\"unable to find command, tried %%s\" %% (commands,))\n        return None, None\n    stdout = p.communicate()[0].strip()\n    if sys.version_info[0] >= 3:\n        stdout = stdout.decode()\n    if p.returncode != 0:\n        if verbose:\n            print(\"unable to run %%s (error)\" %% dispcmd)\n            print(\"stdout was %%s\" %% stdout)\n        return None, p.returncode\n    return stdout, p.returncode\n\n\ndef versions_from_parentdir(parentdir_prefix, root, verbose):\n    \"\"\"Try to determine the version from the parent directory name.\n\n    Source tarballs conventionally unpack into a directory that includes both\n    the project name and a version string. 
We will also support searching up\n    two directory levels for an appropriately named parent directory\n    \"\"\"\n    rootdirs = []\n\n    for i in range(3):\n        dirname = os.path.basename(root)\n        if dirname.startswith(parentdir_prefix):\n            return {\"version\": dirname[len(parentdir_prefix):],\n                    \"full-revisionid\": None,\n                    \"dirty\": False, \"error\": None, \"date\": None}\n        else:\n            rootdirs.append(root)\n            root = os.path.dirname(root)  # up a level\n\n    if verbose:\n        print(\"Tried directories %%s but none started with prefix %%s\" %%\n              (str(rootdirs), parentdir_prefix))\n    raise NotThisMethod(\"rootdir doesn't start with parentdir_prefix\")\n\n\n@register_vcs_handler(\"git\", \"get_keywords\")\ndef git_get_keywords(versionfile_abs):\n    \"\"\"Extract version information from the given file.\"\"\"\n    # the code embedded in _version.py can just fetch the value of these\n    # keywords. When used from setup.py, we don't want to import _version.py,\n    # so we do it with a regexp instead. 
This function is not used from\n    # _version.py.\n    keywords = {}\n    try:\n        f = open(versionfile_abs, \"r\")\n        for line in f.readlines():\n            if line.strip().startswith(\"git_refnames =\"):\n                mo = re.search(r'=\\s*\"(.*)\"', line)\n                if mo:\n                    keywords[\"refnames\"] = mo.group(1)\n            if line.strip().startswith(\"git_full =\"):\n                mo = re.search(r'=\\s*\"(.*)\"', line)\n                if mo:\n                    keywords[\"full\"] = mo.group(1)\n            if line.strip().startswith(\"git_date =\"):\n                mo = re.search(r'=\\s*\"(.*)\"', line)\n                if mo:\n                    keywords[\"date\"] = mo.group(1)\n        f.close()\n    except EnvironmentError:\n        pass\n    return keywords\n\n\n@register_vcs_handler(\"git\", \"keywords\")\ndef git_versions_from_keywords(keywords, tag_prefix, verbose):\n    \"\"\"Get version information from git keywords.\"\"\"\n    if not keywords:\n        raise NotThisMethod(\"no keywords at all, weird\")\n    date = keywords.get(\"date\")\n    if date is not None:\n        # git-2.2.0 added \"%%cI\", which expands to an ISO-8601 -compliant\n        # datestamp. 
However we prefer \"%%ci\" (which expands to an \"ISO-8601\n        # -like\" string, which we must then edit to make compliant), because\n        # it's been around since git-1.5.3, and it's too difficult to\n        # discover which version we're using, or to work around using an\n        # older one.\n        date = date.strip().replace(\" \", \"T\", 1).replace(\" \", \"\", 1)\n    refnames = keywords[\"refnames\"].strip()\n    if refnames.startswith(\"$Format\"):\n        if verbose:\n            print(\"keywords are unexpanded, not using\")\n        raise NotThisMethod(\"unexpanded keywords, not a git-archive tarball\")\n    refs = set([r.strip() for r in refnames.strip(\"()\").split(\",\")])\n    # starting in git-1.8.3, tags are listed as \"tag: foo-1.0\" instead of\n    # just \"foo-1.0\". If we see a \"tag: \" prefix, prefer those.\n    TAG = \"tag: \"\n    tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)])\n    if not tags:\n        # Either we're using git < 1.8.3, or there really are no tags. We use\n        # a heuristic: assume all version tags have a digit. The old git %%d\n        # expansion behaves like git log --decorate=short and strips out the\n        # refs/heads/ and refs/tags/ prefixes that would let us distinguish\n        # between branches and tags. By ignoring refnames without digits, we\n        # filter out many common branch names like \"release\" and\n        # \"stabilization\", as well as \"HEAD\" and \"master\".\n        tags = set([r for r in refs if re.search(r'\\d', r)])\n        if verbose:\n            print(\"discarding '%%s', no digits\" %% \",\".join(refs - tags))\n    if verbose:\n        print(\"likely tags: %%s\" %% \",\".join(sorted(tags)))\n    for ref in sorted(tags):\n        # sorting will prefer e.g. 
\"2.0\" over \"2.0rc1\"\n        if ref.startswith(tag_prefix):\n            r = ref[len(tag_prefix):]\n            if verbose:\n                print(\"picking %%s\" %% r)\n            return {\"version\": r,\n                    \"full-revisionid\": keywords[\"full\"].strip(),\n                    \"dirty\": False, \"error\": None,\n                    \"date\": date}\n    # no suitable tags, so version is \"0+unknown\", but full hex is still there\n    if verbose:\n        print(\"no suitable tags, using unknown + full revision id\")\n    return {\"version\": \"0+unknown\",\n            \"full-revisionid\": keywords[\"full\"].strip(),\n            \"dirty\": False, \"error\": \"no suitable tags\", \"date\": None}\n\n\n@register_vcs_handler(\"git\", \"pieces_from_vcs\")\ndef git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):\n    \"\"\"Get version from 'git describe' in the root of the source tree.\n\n    This only gets called if the git-archive 'subst' keywords were *not*\n    expanded, and _version.py hasn't already been rewritten with a short\n    version string, meaning we're inside a checked out source tree.\n    \"\"\"\n    GITS = [\"git\"]\n    if sys.platform == \"win32\":\n        GITS = [\"git.cmd\", \"git.exe\"]\n\n    out, rc = run_command(GITS, [\"rev-parse\", \"--git-dir\"], cwd=root,\n                          hide_stderr=True)\n    if rc != 0:\n        if verbose:\n            print(\"Directory %%s not under git control\" %% root)\n        raise NotThisMethod(\"'git rev-parse --git-dir' returned error\")\n\n    # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]\n    # if there isn't one, this yields HEX[-dirty] (no NUM)\n    describe_out, rc = run_command(GITS, [\"describe\", \"--tags\", \"--dirty\",\n                                          \"--always\", \"--long\",\n                                          \"--match\", \"%%s*\" %% tag_prefix],\n                                   cwd=root)\n    # 
--long was added in git-1.5.5\n    if describe_out is None:\n        raise NotThisMethod(\"'git describe' failed\")\n    describe_out = describe_out.strip()\n    full_out, rc = run_command(GITS, [\"rev-parse\", \"HEAD\"], cwd=root)\n    if full_out is None:\n        raise NotThisMethod(\"'git rev-parse' failed\")\n    full_out = full_out.strip()\n\n    pieces = {}\n    pieces[\"long\"] = full_out\n    pieces[\"short\"] = full_out[:7]  # maybe improved later\n    pieces[\"error\"] = None\n\n    # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]\n    # TAG might have hyphens.\n    git_describe = describe_out\n\n    # look for -dirty suffix\n    dirty = git_describe.endswith(\"-dirty\")\n    pieces[\"dirty\"] = dirty\n    if dirty:\n        git_describe = git_describe[:git_describe.rindex(\"-dirty\")]\n\n    # now we have TAG-NUM-gHEX or HEX\n\n    if \"-\" in git_describe:\n        # TAG-NUM-gHEX\n        mo = re.search(r'^(.+)-(\\d+)-g([0-9a-f]+)$', git_describe)\n        if not mo:\n            # unparseable. 
Maybe git-describe is misbehaving?\n            pieces[\"error\"] = (\"unable to parse git-describe output: '%%s'\"\n                               %% describe_out)\n            return pieces\n\n        # tag\n        full_tag = mo.group(1)\n        if not full_tag.startswith(tag_prefix):\n            if verbose:\n                fmt = \"tag '%%s' doesn't start with prefix '%%s'\"\n                print(fmt %% (full_tag, tag_prefix))\n            pieces[\"error\"] = (\"tag '%%s' doesn't start with prefix '%%s'\"\n                               %% (full_tag, tag_prefix))\n            return pieces\n        pieces[\"closest-tag\"] = full_tag[len(tag_prefix):]\n\n        # distance: number of commits since tag\n        pieces[\"distance\"] = int(mo.group(2))\n\n        # commit: short hex revision ID\n        pieces[\"short\"] = mo.group(3)\n\n    else:\n        # HEX: no tags\n        pieces[\"closest-tag\"] = None\n        count_out, rc = run_command(GITS, [\"rev-list\", \"HEAD\", \"--count\"],\n                                    cwd=root)\n        pieces[\"distance\"] = int(count_out)  # total number of commits\n\n    # commit date: see ISO-8601 comment in git_versions_from_keywords()\n    date = run_command(GITS, [\"show\", \"-s\", \"--format=%%ci\", \"HEAD\"],\n                       cwd=root)[0].strip()\n    pieces[\"date\"] = date.strip().replace(\" \", \"T\", 1).replace(\" \", \"\", 1)\n\n    return pieces\n\n\ndef plus_or_dot(pieces):\n    \"\"\"Return a + if we don't already have one, else return a .\"\"\"\n    if \"+\" in pieces.get(\"closest-tag\", \"\"):\n        return \".\"\n    return \"+\"\n\n\ndef render_pep440(pieces):\n    \"\"\"Build up version string, with post-release \"local version identifier\".\n\n    Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you\n    get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty\n\n    Exceptions:\n    1: no tags. git_describe was just HEX. 
0+untagged.DISTANCE.gHEX[.dirty]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += plus_or_dot(pieces)\n            rendered += \"%%d.g%%s\" %% (pieces[\"distance\"], pieces[\"short\"])\n            if pieces[\"dirty\"]:\n                rendered += \".dirty\"\n    else:\n        # exception #1\n        rendered = \"0+untagged.%%d.g%%s\" %% (pieces[\"distance\"],\n                                          pieces[\"short\"])\n        if pieces[\"dirty\"]:\n            rendered += \".dirty\"\n    return rendered\n\n\ndef render_pep440_pre(pieces):\n    \"\"\"TAG[.post.devDISTANCE] -- No -dirty.\n\n    Exceptions:\n    1: no tags. 0.post.devDISTANCE\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"]:\n            rendered += \".post.dev%%d\" %% pieces[\"distance\"]\n    else:\n        # exception #1\n        rendered = \"0.post.dev%%d\" %% pieces[\"distance\"]\n    return rendered\n\n\ndef render_pep440_post(pieces):\n    \"\"\"TAG[.postDISTANCE[.dev0]+gHEX] .\n\n    The \".dev0\" means dirty. Note that .dev0 sorts backwards\n    (a dirty tree will appear \"older\" than the corresponding clean one),\n    but you shouldn't be releasing software with -dirty anyways.\n\n    Exceptions:\n    1: no tags. 
0.postDISTANCE[.dev0]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += \".post%%d\" %% pieces[\"distance\"]\n            if pieces[\"dirty\"]:\n                rendered += \".dev0\"\n            rendered += plus_or_dot(pieces)\n            rendered += \"g%%s\" %% pieces[\"short\"]\n    else:\n        # exception #1\n        rendered = \"0.post%%d\" %% pieces[\"distance\"]\n        if pieces[\"dirty\"]:\n            rendered += \".dev0\"\n        rendered += \"+g%%s\" %% pieces[\"short\"]\n    return rendered\n\n\ndef render_pep440_old(pieces):\n    \"\"\"TAG[.postDISTANCE[.dev0]] .\n\n    The \".dev0\" means dirty.\n\n    Exceptions:\n    1: no tags. 0.postDISTANCE[.dev0]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += \".post%%d\" %% pieces[\"distance\"]\n            if pieces[\"dirty\"]:\n                rendered += \".dev0\"\n    else:\n        # exception #1\n        rendered = \"0.post%%d\" %% pieces[\"distance\"]\n        if pieces[\"dirty\"]:\n            rendered += \".dev0\"\n    return rendered\n\n\ndef render_git_describe(pieces):\n    \"\"\"TAG[-DISTANCE-gHEX][-dirty].\n\n    Like 'git describe --tags --dirty --always'.\n\n    Exceptions:\n    1: no tags. 
HEX[-dirty]  (note: no 'g' prefix)\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"]:\n            rendered += \"-%%d-g%%s\" %% (pieces[\"distance\"], pieces[\"short\"])\n    else:\n        # exception #1\n        rendered = pieces[\"short\"]\n    if pieces[\"dirty\"]:\n        rendered += \"-dirty\"\n    return rendered\n\n\ndef render_git_describe_long(pieces):\n    \"\"\"TAG-DISTANCE-gHEX[-dirty].\n\n    Like 'git describe --tags --dirty --always --long'.\n    The distance/hash is unconditional.\n\n    Exceptions:\n    1: no tags. HEX[-dirty]  (note: no 'g' prefix)\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        rendered += \"-%%d-g%%s\" %% (pieces[\"distance\"], pieces[\"short\"])\n    else:\n        # exception #1\n        rendered = pieces[\"short\"]\n    if pieces[\"dirty\"]:\n        rendered += \"-dirty\"\n    return rendered\n\n\ndef render(pieces, style):\n    \"\"\"Render the given version pieces into the requested style.\"\"\"\n    if pieces[\"error\"]:\n        return {\"version\": \"unknown\",\n                \"full-revisionid\": pieces.get(\"long\"),\n                \"dirty\": None,\n                \"error\": pieces[\"error\"],\n                \"date\": None}\n\n    if not style or style == \"default\":\n        style = \"pep440\"  # the default\n\n    if style == \"pep440\":\n        rendered = render_pep440(pieces)\n    elif style == \"pep440-pre\":\n        rendered = render_pep440_pre(pieces)\n    elif style == \"pep440-post\":\n        rendered = render_pep440_post(pieces)\n    elif style == \"pep440-old\":\n        rendered = render_pep440_old(pieces)\n    elif style == \"git-describe\":\n        rendered = render_git_describe(pieces)\n    elif style == \"git-describe-long\":\n        rendered = render_git_describe_long(pieces)\n    else:\n        raise ValueError(\"unknown style '%%s'\" %% style)\n\n    return 
{\"version\": rendered, \"full-revisionid\": pieces[\"long\"],\n            \"dirty\": pieces[\"dirty\"], \"error\": None,\n            \"date\": pieces.get(\"date\")}\n\n\ndef get_versions():\n    \"\"\"Get version information or return default if unable to do so.\"\"\"\n    # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have\n    # __file__, we can work backwards from there to the root. Some\n    # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which\n    # case we can only use expanded keywords.\n\n    cfg = get_config()\n    verbose = cfg.verbose\n\n    try:\n        return git_versions_from_keywords(get_keywords(), cfg.tag_prefix,\n                                          verbose)\n    except NotThisMethod:\n        pass\n\n    try:\n        root = os.path.realpath(__file__)\n        # versionfile_source is the relative path from the top of the source\n        # tree (where the .git directory might live) to this file. Invert\n        # this to find the root from __file__.\n        for i in cfg.versionfile_source.split('/'):\n            root = os.path.dirname(root)\n    except NameError:\n        return {\"version\": \"0+unknown\", \"full-revisionid\": None,\n                \"dirty\": None,\n                \"error\": \"unable to find root of source tree\",\n                \"date\": None}\n\n    try:\n        pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)\n        return render(pieces, cfg.style)\n    except NotThisMethod:\n        pass\n\n    try:\n        if cfg.parentdir_prefix:\n            return versions_from_parentdir(cfg.parentdir_prefix, root, verbose)\n    except NotThisMethod:\n        pass\n\n    return {\"version\": \"0+unknown\", \"full-revisionid\": None,\n            \"dirty\": None,\n            \"error\": \"unable to compute version\", \"date\": None}\n'''\n\n\n@register_vcs_handler(\"git\", \"get_keywords\")\ndef git_get_keywords(versionfile_abs):\n    \"\"\"Extract version information 
from the given file.\"\"\"\n    # the code embedded in _version.py can just fetch the value of these\n    # keywords. When used from setup.py, we don't want to import _version.py,\n    # so we do it with a regexp instead. This function is not used from\n    # _version.py.\n    keywords = {}\n    try:\n        f = open(versionfile_abs, \"r\")\n        for line in f.readlines():\n            if line.strip().startswith(\"git_refnames =\"):\n                mo = re.search(r'=\\s*\"(.*)\"', line)\n                if mo:\n                    keywords[\"refnames\"] = mo.group(1)\n            if line.strip().startswith(\"git_full =\"):\n                mo = re.search(r'=\\s*\"(.*)\"', line)\n                if mo:\n                    keywords[\"full\"] = mo.group(1)\n            if line.strip().startswith(\"git_date =\"):\n                mo = re.search(r'=\\s*\"(.*)\"', line)\n                if mo:\n                    keywords[\"date\"] = mo.group(1)\n        f.close()\n    except EnvironmentError:\n        pass\n    return keywords\n\n\n@register_vcs_handler(\"git\", \"keywords\")\ndef git_versions_from_keywords(keywords, tag_prefix, verbose):\n    \"\"\"Get version information from git keywords.\"\"\"\n    if not keywords:\n        raise NotThisMethod(\"no keywords at all, weird\")\n    date = keywords.get(\"date\")\n    if date is not None:\n        # git-2.2.0 added \"%cI\", which expands to an ISO-8601 -compliant\n        # datestamp. 
However we prefer \"%ci\" (which expands to an \"ISO-8601\n        # -like\" string, which we must then edit to make compliant), because\n        # it's been around since git-1.5.3, and it's too difficult to\n        # discover which version we're using, or to work around using an\n        # older one.\n        date = date.strip().replace(\" \", \"T\", 1).replace(\" \", \"\", 1)\n    refnames = keywords[\"refnames\"].strip()\n    if refnames.startswith(\"$Format\"):\n        if verbose:\n            print(\"keywords are unexpanded, not using\")\n        raise NotThisMethod(\"unexpanded keywords, not a git-archive tarball\")\n    refs = set([r.strip() for r in refnames.strip(\"()\").split(\",\")])\n    # starting in git-1.8.3, tags are listed as \"tag: foo-1.0\" instead of\n    # just \"foo-1.0\". If we see a \"tag: \" prefix, prefer those.\n    TAG = \"tag: \"\n    tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)])\n    if not tags:\n        # Either we're using git < 1.8.3, or there really are no tags. We use\n        # a heuristic: assume all version tags have a digit. The old git %d\n        # expansion behaves like git log --decorate=short and strips out the\n        # refs/heads/ and refs/tags/ prefixes that would let us distinguish\n        # between branches and tags. By ignoring refnames without digits, we\n        # filter out many common branch names like \"release\" and\n        # \"stabilization\", as well as \"HEAD\" and \"master\".\n        tags = set([r for r in refs if re.search(r'\\d', r)])\n        if verbose:\n            print(\"discarding '%s', no digits\" % \",\".join(refs - tags))\n    if verbose:\n        print(\"likely tags: %s\" % \",\".join(sorted(tags)))\n    for ref in sorted(tags):\n        # sorting will prefer e.g. 
\"2.0\" over \"2.0rc1\"\n        if ref.startswith(tag_prefix):\n            r = ref[len(tag_prefix):]\n            if verbose:\n                print(\"picking %s\" % r)\n            return {\"version\": r,\n                    \"full-revisionid\": keywords[\"full\"].strip(),\n                    \"dirty\": False, \"error\": None,\n                    \"date\": date}\n    # no suitable tags, so version is \"0+unknown\", but full hex is still there\n    if verbose:\n        print(\"no suitable tags, using unknown + full revision id\")\n    return {\"version\": \"0+unknown\",\n            \"full-revisionid\": keywords[\"full\"].strip(),\n            \"dirty\": False, \"error\": \"no suitable tags\", \"date\": None}\n\n\n@register_vcs_handler(\"git\", \"pieces_from_vcs\")\ndef git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command):\n    \"\"\"Get version from 'git describe' in the root of the source tree.\n\n    This only gets called if the git-archive 'subst' keywords were *not*\n    expanded, and _version.py hasn't already been rewritten with a short\n    version string, meaning we're inside a checked out source tree.\n    \"\"\"\n    GITS = [\"git\"]\n    if sys.platform == \"win32\":\n        GITS = [\"git.cmd\", \"git.exe\"]\n\n    out, rc = run_command(GITS, [\"rev-parse\", \"--git-dir\"], cwd=root,\n                          hide_stderr=True)\n    if rc != 0:\n        if verbose:\n            print(\"Directory %s not under git control\" % root)\n        raise NotThisMethod(\"'git rev-parse --git-dir' returned error\")\n\n    # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]\n    # if there isn't one, this yields HEX[-dirty] (no NUM)\n    describe_out, rc = run_command(GITS, [\"describe\", \"--tags\", \"--dirty\",\n                                          \"--always\", \"--long\",\n                                          \"--match\", \"%s*\" % tag_prefix],\n                                   cwd=root)\n    # --long 
was added in git-1.5.5\n    if describe_out is None:\n        raise NotThisMethod(\"'git describe' failed\")\n    describe_out = describe_out.strip()\n    full_out, rc = run_command(GITS, [\"rev-parse\", \"HEAD\"], cwd=root)\n    if full_out is None:\n        raise NotThisMethod(\"'git rev-parse' failed\")\n    full_out = full_out.strip()\n\n    pieces = {}\n    pieces[\"long\"] = full_out\n    pieces[\"short\"] = full_out[:7]  # maybe improved later\n    pieces[\"error\"] = None\n\n    # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]\n    # TAG might have hyphens.\n    git_describe = describe_out\n\n    # look for -dirty suffix\n    dirty = git_describe.endswith(\"-dirty\")\n    pieces[\"dirty\"] = dirty\n    if dirty:\n        git_describe = git_describe[:git_describe.rindex(\"-dirty\")]\n\n    # now we have TAG-NUM-gHEX or HEX\n\n    if \"-\" in git_describe:\n        # TAG-NUM-gHEX\n        mo = re.search(r'^(.+)-(\\d+)-g([0-9a-f]+)$', git_describe)\n        if not mo:\n            # unparseable. 
Maybe git-describe is misbehaving?\n            pieces[\"error\"] = (\"unable to parse git-describe output: '%s'\"\n                               % describe_out)\n            return pieces\n\n        # tag\n        full_tag = mo.group(1)\n        if not full_tag.startswith(tag_prefix):\n            if verbose:\n                fmt = \"tag '%s' doesn't start with prefix '%s'\"\n                print(fmt % (full_tag, tag_prefix))\n            pieces[\"error\"] = (\"tag '%s' doesn't start with prefix '%s'\"\n                               % (full_tag, tag_prefix))\n            return pieces\n        pieces[\"closest-tag\"] = full_tag[len(tag_prefix):]\n\n        # distance: number of commits since tag\n        pieces[\"distance\"] = int(mo.group(2))\n\n        # commit: short hex revision ID\n        pieces[\"short\"] = mo.group(3)\n\n    else:\n        # HEX: no tags\n        pieces[\"closest-tag\"] = None\n        count_out, rc = run_command(GITS, [\"rev-list\", \"HEAD\", \"--count\"],\n                                    cwd=root)\n        pieces[\"distance\"] = int(count_out)  # total number of commits\n\n    # commit date: see ISO-8601 comment in git_versions_from_keywords()\n    date = run_command(GITS, [\"show\", \"-s\", \"--format=%ci\", \"HEAD\"],\n                       cwd=root)[0].strip()\n    pieces[\"date\"] = date.strip().replace(\" \", \"T\", 1).replace(\" \", \"\", 1)\n\n    return pieces\n\n\ndef do_vcs_install(manifest_in, versionfile_source, ipy):\n    \"\"\"Git-specific installation logic for Versioneer.\n\n    For Git, this means creating/changing .gitattributes to mark _version.py\n    for export-subst keyword substitution.\n    \"\"\"\n    GITS = [\"git\"]\n    if sys.platform == \"win32\":\n        GITS = [\"git.cmd\", \"git.exe\"]\n    files = [manifest_in, versionfile_source]\n    if ipy:\n        files.append(ipy)\n    try:\n        me = __file__\n        if me.endswith(\".pyc\") or me.endswith(\".pyo\"):\n            me = 
os.path.splitext(me)[0] + \".py\"\n        versioneer_file = os.path.relpath(me)\n    except NameError:\n        versioneer_file = \"versioneer.py\"\n    files.append(versioneer_file)\n    present = False\n    try:\n        f = open(\".gitattributes\", \"r\")\n        for line in f.readlines():\n            if line.strip().startswith(versionfile_source):\n                if \"export-subst\" in line.strip().split()[1:]:\n                    present = True\n        f.close()\n    except EnvironmentError:\n        pass\n    if not present:\n        f = open(\".gitattributes\", \"a+\")\n        f.write(\"%s export-subst\\n\" % versionfile_source)\n        f.close()\n        files.append(\".gitattributes\")\n    run_command(GITS, [\"add\", \"--\"] + files)\n\n\ndef versions_from_parentdir(parentdir_prefix, root, verbose):\n    \"\"\"Try to determine the version from the parent directory name.\n\n    Source tarballs conventionally unpack into a directory that includes both\n    the project name and a version string. We will also support searching up\n    two directory levels for an appropriately named parent directory\n    \"\"\"\n    rootdirs = []\n\n    for i in range(3):\n        dirname = os.path.basename(root)\n        if dirname.startswith(parentdir_prefix):\n            return {\"version\": dirname[len(parentdir_prefix):],\n                    \"full-revisionid\": None,\n                    \"dirty\": False, \"error\": None, \"date\": None}\n        else:\n            rootdirs.append(root)\n            root = os.path.dirname(root)  # up a level\n\n    if verbose:\n        print(\"Tried directories %s but none started with prefix %s\" %\n              (str(rootdirs), parentdir_prefix))\n    raise NotThisMethod(\"rootdir doesn't start with parentdir_prefix\")\n\n\nSHORT_VERSION_PY = \"\"\"\n# This file was generated by 'versioneer.py' (0.18) from\n# revision-control system data, or from the parent directory name of an\n# unpacked source archive. 
Distribution tarballs contain a pre-generated copy\n# of this file.\n\nimport json\n\nversion_json = '''\n%s\n'''  # END VERSION_JSON\n\n\ndef get_versions():\n    return json.loads(version_json)\n\"\"\"\n\n\ndef versions_from_file(filename):\n    \"\"\"Try to determine the version from _version.py if present.\"\"\"\n    try:\n        with open(filename) as f:\n            contents = f.read()\n    except EnvironmentError:\n        raise NotThisMethod(\"unable to read _version.py\")\n    mo = re.search(r\"version_json = '''\\n(.*)'''  # END VERSION_JSON\",\n                   contents, re.M | re.S)\n    if not mo:\n        mo = re.search(r\"version_json = '''\\r\\n(.*)'''  # END VERSION_JSON\",\n                       contents, re.M | re.S)\n    if not mo:\n        raise NotThisMethod(\"no version_json in _version.py\")\n    return json.loads(mo.group(1))\n\n\ndef write_to_version_file(filename, versions):\n    \"\"\"Write the given version number to the given _version.py file.\"\"\"\n    os.unlink(filename)\n    contents = json.dumps(versions, sort_keys=True,\n                          indent=1, separators=(\",\", \": \"))\n    with open(filename, \"w\") as f:\n        f.write(SHORT_VERSION_PY % contents)\n\n    print(\"set %s to '%s'\" % (filename, versions[\"version\"]))\n\n\ndef plus_or_dot(pieces):\n    \"\"\"Return a + if we don't already have one, else return a .\"\"\"\n    if \"+\" in pieces.get(\"closest-tag\", \"\"):\n        return \".\"\n    return \"+\"\n\n\ndef render_pep440(pieces):\n    \"\"\"Build up version string, with post-release \"local version identifier\".\n\n    Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you\n    get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty\n\n    Exceptions:\n    1: no tags. git_describe was just HEX. 
0+untagged.DISTANCE.gHEX[.dirty]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += plus_or_dot(pieces)\n            rendered += \"%d.g%s\" % (pieces[\"distance\"], pieces[\"short\"])\n            if pieces[\"dirty\"]:\n                rendered += \".dirty\"\n    else:\n        # exception #1\n        rendered = \"0+untagged.%d.g%s\" % (pieces[\"distance\"],\n                                          pieces[\"short\"])\n        if pieces[\"dirty\"]:\n            rendered += \".dirty\"\n    return rendered\n\n\ndef render_pep440_pre(pieces):\n    \"\"\"TAG[.post.devDISTANCE] -- No -dirty.\n\n    Exceptions:\n    1: no tags. 0.post.devDISTANCE\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"]:\n            rendered += \".post.dev%d\" % pieces[\"distance\"]\n    else:\n        # exception #1\n        rendered = \"0.post.dev%d\" % pieces[\"distance\"]\n    return rendered\n\n\ndef render_pep440_post(pieces):\n    \"\"\"TAG[.postDISTANCE[.dev0]+gHEX] .\n\n    The \".dev0\" means dirty. Note that .dev0 sorts backwards\n    (a dirty tree will appear \"older\" than the corresponding clean one),\n    but you shouldn't be releasing software with -dirty anyways.\n\n    Exceptions:\n    1: no tags. 
0.postDISTANCE[.dev0]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += \".post%d\" % pieces[\"distance\"]\n            if pieces[\"dirty\"]:\n                rendered += \".dev0\"\n            rendered += plus_or_dot(pieces)\n            rendered += \"g%s\" % pieces[\"short\"]\n    else:\n        # exception #1\n        rendered = \"0.post%d\" % pieces[\"distance\"]\n        if pieces[\"dirty\"]:\n            rendered += \".dev0\"\n        rendered += \"+g%s\" % pieces[\"short\"]\n    return rendered\n\n\ndef render_pep440_old(pieces):\n    \"\"\"TAG[.postDISTANCE[.dev0]] .\n\n    The \".dev0\" means dirty.\n\n    Exceptions:\n    1: no tags. 0.postDISTANCE[.dev0]\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"] or pieces[\"dirty\"]:\n            rendered += \".post%d\" % pieces[\"distance\"]\n            if pieces[\"dirty\"]:\n                rendered += \".dev0\"\n    else:\n        # exception #1\n        rendered = \"0.post%d\" % pieces[\"distance\"]\n        if pieces[\"dirty\"]:\n            rendered += \".dev0\"\n    return rendered\n\n\ndef render_git_describe(pieces):\n    \"\"\"TAG[-DISTANCE-gHEX][-dirty].\n\n    Like 'git describe --tags --dirty --always'.\n\n    Exceptions:\n    1: no tags. 
HEX[-dirty]  (note: no 'g' prefix)\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        if pieces[\"distance\"]:\n            rendered += \"-%d-g%s\" % (pieces[\"distance\"], pieces[\"short\"])\n    else:\n        # exception #1\n        rendered = pieces[\"short\"]\n    if pieces[\"dirty\"]:\n        rendered += \"-dirty\"\n    return rendered\n\n\ndef render_git_describe_long(pieces):\n    \"\"\"TAG-DISTANCE-gHEX[-dirty].\n\n    Like 'git describe --tags --dirty --always --long'.\n    The distance/hash is unconditional.\n\n    Exceptions:\n    1: no tags. HEX[-dirty]  (note: no 'g' prefix)\n    \"\"\"\n    if pieces[\"closest-tag\"]:\n        rendered = pieces[\"closest-tag\"]\n        rendered += \"-%d-g%s\" % (pieces[\"distance\"], pieces[\"short\"])\n    else:\n        # exception #1\n        rendered = pieces[\"short\"]\n    if pieces[\"dirty\"]:\n        rendered += \"-dirty\"\n    return rendered\n\n\ndef render(pieces, style):\n    \"\"\"Render the given version pieces into the requested style.\"\"\"\n    if pieces[\"error\"]:\n        return {\"version\": \"unknown\",\n                \"full-revisionid\": pieces.get(\"long\"),\n                \"dirty\": None,\n                \"error\": pieces[\"error\"],\n                \"date\": None}\n\n    if not style or style == \"default\":\n        style = \"pep440\"  # the default\n\n    if style == \"pep440\":\n        rendered = render_pep440(pieces)\n    elif style == \"pep440-pre\":\n        rendered = render_pep440_pre(pieces)\n    elif style == \"pep440-post\":\n        rendered = render_pep440_post(pieces)\n    elif style == \"pep440-old\":\n        rendered = render_pep440_old(pieces)\n    elif style == \"git-describe\":\n        rendered = render_git_describe(pieces)\n    elif style == \"git-describe-long\":\n        rendered = render_git_describe_long(pieces)\n    else:\n        raise ValueError(\"unknown style '%s'\" % style)\n\n    return {\"version\": 
rendered, \"full-revisionid\": pieces[\"long\"],\n            \"dirty\": pieces[\"dirty\"], \"error\": None,\n            \"date\": pieces.get(\"date\")}\n\n\nclass VersioneerBadRootError(Exception):\n    \"\"\"The project root directory is unknown or missing key files.\"\"\"\n\n\ndef get_versions(verbose=False):\n    \"\"\"Get the project version from whatever source is available.\n\n    Returns dict with two keys: 'version' and 'full'.\n    \"\"\"\n    if \"versioneer\" in sys.modules:\n        # see the discussion in cmdclass.py:get_cmdclass()\n        del sys.modules[\"versioneer\"]\n\n    root = get_root()\n    cfg = get_config_from_root(root)\n\n    assert cfg.VCS is not None, \"please set [versioneer]VCS= in setup.cfg\"\n    handlers = HANDLERS.get(cfg.VCS)\n    assert handlers, \"unrecognized VCS '%s'\" % cfg.VCS\n    verbose = verbose or cfg.verbose\n    assert cfg.versionfile_source is not None, \\\n        \"please set versioneer.versionfile_source\"\n    assert cfg.tag_prefix is not None, \"please set versioneer.tag_prefix\"\n\n    versionfile_abs = os.path.join(root, cfg.versionfile_source)\n\n    # extract version from first of: _version.py, VCS command (e.g. 'git\n    # describe'), parentdir. 
This is meant to work for developers using a\n    # source checkout, for users of a tarball created by 'setup.py sdist',\n    # and for users of a tarball/zipball created by 'git archive' or github's\n    # download-from-tag feature or the equivalent in other VCSes.\n\n    get_keywords_f = handlers.get(\"get_keywords\")\n    from_keywords_f = handlers.get(\"keywords\")\n    if get_keywords_f and from_keywords_f:\n        try:\n            keywords = get_keywords_f(versionfile_abs)\n            ver = from_keywords_f(keywords, cfg.tag_prefix, verbose)\n            if verbose:\n                print(\"got version from expanded keyword %s\" % ver)\n            return ver\n        except NotThisMethod:\n            pass\n\n    try:\n        ver = versions_from_file(versionfile_abs)\n        if verbose:\n            print(\"got version from file %s %s\" % (versionfile_abs, ver))\n        return ver\n    except NotThisMethod:\n        pass\n\n    from_vcs_f = handlers.get(\"pieces_from_vcs\")\n    if from_vcs_f:\n        try:\n            pieces = from_vcs_f(cfg.tag_prefix, root, verbose)\n            ver = render(pieces, cfg.style)\n            if verbose:\n                print(\"got version from VCS %s\" % ver)\n            return ver\n        except NotThisMethod:\n            pass\n\n    try:\n        if cfg.parentdir_prefix:\n            ver = versions_from_parentdir(cfg.parentdir_prefix, root, verbose)\n            if verbose:\n                print(\"got version from parentdir %s\" % ver)\n            return ver\n    except NotThisMethod:\n        pass\n\n    if verbose:\n        print(\"unable to compute version\")\n\n    return {\"version\": \"0+unknown\", \"full-revisionid\": None,\n            \"dirty\": None, \"error\": \"unable to compute version\",\n            \"date\": None}\n\n\ndef get_version():\n    \"\"\"Get the short version string for this project.\"\"\"\n    return get_versions()[\"version\"]\n\n\ndef get_cmdclass():\n    \"\"\"Get the custom 
setuptools/distutils subclasses used by Versioneer.\"\"\"\n    if \"versioneer\" in sys.modules:\n        del sys.modules[\"versioneer\"]\n        # this fixes the \"python setup.py develop\" case (also 'install' and\n        # 'easy_install .'), in which subdependencies of the main project are\n        # built (using setup.py bdist_egg) in the same python process. Assume\n        # a main project A and a dependency B, which use different versions\n        # of Versioneer. A's setup.py imports A's Versioneer, leaving it in\n        # sys.modules by the time B's setup.py is executed, causing B to run\n        # with the wrong versioneer. Setuptools wraps the sub-dep builds in a\n        # sandbox that restores sys.modules to its pre-build state, so the\n        # parent is protected against the child's \"import versioneer\". By\n        # removing ourselves from sys.modules here, before the child build\n        # happens, we protect the child from the parent's versioneer too.\n        # Also see https://github.com/warner/python-versioneer/issues/52\n\n    cmds = {}\n\n    # we add \"version\" to both distutils and setuptools\n    from distutils.core import Command\n\n    class cmd_version(Command):\n        description = \"report generated version string\"\n        user_options = []\n        boolean_options = []\n\n        def initialize_options(self):\n            pass\n\n        def finalize_options(self):\n            pass\n\n        def run(self):\n            vers = get_versions(verbose=True)\n            print(\"Version: %s\" % vers[\"version\"])\n            print(\" full-revisionid: %s\" % vers.get(\"full-revisionid\"))\n            print(\" dirty: %s\" % vers.get(\"dirty\"))\n            print(\" date: %s\" % vers.get(\"date\"))\n            if vers[\"error\"]:\n                print(\" error: %s\" % vers[\"error\"])\n    cmds[\"version\"] = cmd_version\n\n    # we override \"build_py\" in both distutils and setuptools\n    #\n    # most invocation 
pathways end up running build_py:\n    #  distutils/build -> build_py\n    #  distutils/install -> distutils/build ->..\n    #  setuptools/bdist_wheel -> distutils/install ->..\n    #  setuptools/bdist_egg -> distutils/install_lib -> build_py\n    #  setuptools/install -> bdist_egg ->..\n    #  setuptools/develop -> ?\n    #  pip install:\n    #   copies source tree to a tempdir before running egg_info/etc\n    #   if .git isn't copied too, 'git describe' will fail\n    #   then does setup.py bdist_wheel, or sometimes setup.py install\n    #  setup.py egg_info -> ?\n\n    # we override different \"build_py\" commands for both environments\n    if \"setuptools\" in sys.modules:\n        from setuptools.command.build_py import build_py as _build_py\n    else:\n        from distutils.command.build_py import build_py as _build_py\n\n    class cmd_build_py(_build_py):\n        def run(self):\n            root = get_root()\n            cfg = get_config_from_root(root)\n            versions = get_versions()\n            _build_py.run(self)\n            # now locate _version.py in the new build/ directory and replace\n            # it with an updated value\n            if cfg.versionfile_build:\n                target_versionfile = os.path.join(self.build_lib,\n                                                  cfg.versionfile_build)\n                print(\"UPDATING %s\" % target_versionfile)\n                write_to_version_file(target_versionfile, versions)\n    cmds[\"build_py\"] = cmd_build_py\n\n    if \"cx_Freeze\" in sys.modules:  # cx_freeze enabled?\n        from cx_Freeze.dist import build_exe as _build_exe\n        # nczeczulin reports that py2exe won't like the pep440-style string\n        # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g.\n        # setup(console=[{\n        #   \"version\": versioneer.get_version().split(\"+\", 1)[0], # FILEVERSION\n        #   \"product_version\": versioneer.get_version(),\n        #   ...\n\n        class 
cmd_build_exe(_build_exe):\n            def run(self):\n                root = get_root()\n                cfg = get_config_from_root(root)\n                versions = get_versions()\n                target_versionfile = cfg.versionfile_source\n                print(\"UPDATING %s\" % target_versionfile)\n                write_to_version_file(target_versionfile, versions)\n\n                _build_exe.run(self)\n                os.unlink(target_versionfile)\n                with open(cfg.versionfile_source, \"w\") as f:\n                    LONG = LONG_VERSION_PY[cfg.VCS]\n                    f.write(LONG %\n                            {\"DOLLAR\": \"$\",\n                             \"STYLE\": cfg.style,\n                             \"TAG_PREFIX\": cfg.tag_prefix,\n                             \"PARENTDIR_PREFIX\": cfg.parentdir_prefix,\n                             \"VERSIONFILE_SOURCE\": cfg.versionfile_source,\n                             })\n        cmds[\"build_exe\"] = cmd_build_exe\n        del cmds[\"build_py\"]\n\n    if 'py2exe' in sys.modules:  # py2exe enabled?\n        try:\n            from py2exe.distutils_buildexe import py2exe as _py2exe  # py3\n        except ImportError:\n            from py2exe.build_exe import py2exe as _py2exe  # py2\n\n        class cmd_py2exe(_py2exe):\n            def run(self):\n                root = get_root()\n                cfg = get_config_from_root(root)\n                versions = get_versions()\n                target_versionfile = cfg.versionfile_source\n                print(\"UPDATING %s\" % target_versionfile)\n                write_to_version_file(target_versionfile, versions)\n\n                _py2exe.run(self)\n                os.unlink(target_versionfile)\n                with open(cfg.versionfile_source, \"w\") as f:\n                    LONG = LONG_VERSION_PY[cfg.VCS]\n                    f.write(LONG %\n                            {\"DOLLAR\": \"$\",\n                             \"STYLE\": 
cfg.style,\n                             \"TAG_PREFIX\": cfg.tag_prefix,\n                             \"PARENTDIR_PREFIX\": cfg.parentdir_prefix,\n                             \"VERSIONFILE_SOURCE\": cfg.versionfile_source,\n                             })\n        cmds[\"py2exe\"] = cmd_py2exe\n\n    # we override different \"sdist\" commands for both environments\n    if \"setuptools\" in sys.modules:\n        from setuptools.command.sdist import sdist as _sdist\n    else:\n        from distutils.command.sdist import sdist as _sdist\n\n    class cmd_sdist(_sdist):\n        def run(self):\n            versions = get_versions()\n            self._versioneer_generated_versions = versions\n            # unless we update this, the command will keep using the old\n            # version\n            self.distribution.metadata.version = versions[\"version\"]\n            return _sdist.run(self)\n\n        def make_release_tree(self, base_dir, files):\n            root = get_root()\n            cfg = get_config_from_root(root)\n            _sdist.make_release_tree(self, base_dir, files)\n            # now locate _version.py in the new base_dir directory\n            # (remembering that it may be a hardlink) and replace it with an\n            # updated value\n            target_versionfile = os.path.join(base_dir, cfg.versionfile_source)\n            print(\"UPDATING %s\" % target_versionfile)\n            write_to_version_file(target_versionfile,\n                                  self._versioneer_generated_versions)\n    cmds[\"sdist\"] = cmd_sdist\n\n    return cmds\n\n\nCONFIG_ERROR = \"\"\"\nsetup.cfg is missing the necessary Versioneer configuration. 
You need\na section like:\n\n [versioneer]\n VCS = git\n style = pep440\n versionfile_source = src/myproject/_version.py\n versionfile_build = myproject/_version.py\n tag_prefix =\n parentdir_prefix = myproject-\n\nYou will also need to edit your setup.py to use the results:\n\n import versioneer\n setup(version=versioneer.get_version(),\n       cmdclass=versioneer.get_cmdclass(), ...)\n\nPlease read the docstring in ./versioneer.py for configuration instructions,\nedit setup.cfg, and re-run the installer or 'python versioneer.py setup'.\n\"\"\"\n\nSAMPLE_CONFIG = \"\"\"\n# See the docstring in versioneer.py for instructions. Note that you must\n# re-run 'versioneer.py setup' after changing this section, and commit the\n# resulting files.\n\n[versioneer]\n#VCS = git\n#style = pep440\n#versionfile_source =\n#versionfile_build =\n#tag_prefix =\n#parentdir_prefix =\n\n\"\"\"\n\nINIT_PY_SNIPPET = \"\"\"\nfrom ._version import get_versions\n__version__ = get_versions()['version']\ndel get_versions\n\"\"\"\n\n\ndef do_setup():\n    \"\"\"Main VCS-independent setup function for installing Versioneer.\"\"\"\n    root = get_root()\n    try:\n        cfg = get_config_from_root(root)\n    except (EnvironmentError, configparser.NoSectionError,\n            configparser.NoOptionError) as e:\n        if isinstance(e, (EnvironmentError, configparser.NoSectionError)):\n            print(\"Adding sample versioneer config to setup.cfg\",\n                  file=sys.stderr)\n            with open(os.path.join(root, \"setup.cfg\"), \"a\") as f:\n                f.write(SAMPLE_CONFIG)\n        print(CONFIG_ERROR, file=sys.stderr)\n        return 1\n\n    print(\" creating %s\" % cfg.versionfile_source)\n    with open(cfg.versionfile_source, \"w\") as f:\n        LONG = LONG_VERSION_PY[cfg.VCS]\n        f.write(LONG % {\"DOLLAR\": \"$\",\n                        \"STYLE\": cfg.style,\n                        \"TAG_PREFIX\": cfg.tag_prefix,\n                        \"PARENTDIR_PREFIX\": 
cfg.parentdir_prefix,\n                        \"VERSIONFILE_SOURCE\": cfg.versionfile_source,\n                        })\n\n    ipy = os.path.join(os.path.dirname(cfg.versionfile_source),\n                       \"__init__.py\")\n    if os.path.exists(ipy):\n        try:\n            with open(ipy, \"r\") as f:\n                old = f.read()\n        except EnvironmentError:\n            old = \"\"\n        if INIT_PY_SNIPPET not in old:\n            print(\" appending to %s\" % ipy)\n            with open(ipy, \"a\") as f:\n                f.write(INIT_PY_SNIPPET)\n        else:\n            print(\" %s unmodified\" % ipy)\n    else:\n        print(\" %s doesn't exist, ok\" % ipy)\n        ipy = None\n\n    # Make sure both the top-level \"versioneer.py\" and versionfile_source\n    # (PKG/_version.py, used by runtime code) are in MANIFEST.in, so\n    # they'll be copied into source distributions. Pip won't be able to\n    # install the package without this.\n    manifest_in = os.path.join(root, \"MANIFEST.in\")\n    simple_includes = set()\n    try:\n        with open(manifest_in, \"r\") as f:\n            for line in f:\n                if line.startswith(\"include \"):\n                    for include in line.split()[1:]:\n                        simple_includes.add(include)\n    except EnvironmentError:\n        pass\n    # That doesn't cover everything MANIFEST.in can do\n    # (http://docs.python.org/2/distutils/sourcedist.html#commands), so\n    # it might give some false negatives. 
Appending redundant 'include'\n    # lines is safe, though.\n    if \"versioneer.py\" not in simple_includes:\n        print(\" appending 'versioneer.py' to MANIFEST.in\")\n        with open(manifest_in, \"a\") as f:\n            f.write(\"include versioneer.py\\n\")\n    else:\n        print(\" 'versioneer.py' already in MANIFEST.in\")\n    if cfg.versionfile_source not in simple_includes:\n        print(\" appending versionfile_source ('%s') to MANIFEST.in\" %\n              cfg.versionfile_source)\n        with open(manifest_in, \"a\") as f:\n            f.write(\"include %s\\n\" % cfg.versionfile_source)\n    else:\n        print(\" versionfile_source already in MANIFEST.in\")\n\n    # Make VCS-specific changes. For git, this means creating/changing\n    # .gitattributes to mark _version.py for export-subst keyword\n    # substitution.\n    do_vcs_install(manifest_in, cfg.versionfile_source, ipy)\n    return 0\n\n\ndef scan_setup_py():\n    \"\"\"Validate the contents of setup.py against Versioneer's expectations.\"\"\"\n    found = set()\n    setters = False\n    errors = 0\n    with open(\"setup.py\", \"r\") as f:\n        for line in f.readlines():\n            if \"import versioneer\" in line:\n                found.add(\"import\")\n            if \"versioneer.get_cmdclass()\" in line:\n                found.add(\"cmdclass\")\n            if \"versioneer.get_version()\" in line:\n                found.add(\"get_version\")\n            if \"versioneer.VCS\" in line:\n                setters = True\n            if \"versioneer.versionfile_source\" in line:\n                setters = True\n    if len(found) != 3:\n        print(\"\")\n        print(\"Your setup.py appears to be missing some important items\")\n        print(\"(but I might be wrong). 
Please make sure it has something\")\n        print(\"roughly like the following:\")\n        print(\"\")\n        print(\" import versioneer\")\n        print(\" setup( version=versioneer.get_version(),\")\n        print(\"        cmdclass=versioneer.get_cmdclass(),  ...)\")\n        print(\"\")\n        errors += 1\n    if setters:\n        print(\"You should remove lines like 'versioneer.VCS = ' and\")\n        print(\"'versioneer.versionfile_source = ' . This configuration\")\n        print(\"now lives in setup.cfg, and should be removed from setup.py\")\n        print(\"\")\n        errors += 1\n    return errors\n\n\nif __name__ == \"__main__\":\n    cmd = sys.argv[1]\n    if cmd == \"setup\":\n        errors = do_setup()\n        errors += scan_setup_py()\n        if errors:\n            sys.exit(1)\n"
  }
]