[
  {
    "path": ".bumpversion.toml",
    "content": "[tool.bumpversion]\ncurrent_version = \"0.8.1\"\ntag = true\ncommit = true\nmessage = \"Bump version: {current_version} → {new_version}\"\n\n[[tool.bumpversion.files]]\nfilename = \"pyproject.toml\"\nsearch = 'version = \"{current_version}\"'\nreplace = 'version = \"{new_version}\"'\n\n[[tool.bumpversion.files]]\nfilename = \"unet/__init__.py\"\nsearch = '__version__ = \"{current_version}\"'\nreplace = '__version__ = \"{new_version}\"'\n"
  },
  {
    "path": ".github/ISSUE_TEMPLATE.md",
    "content": "* unet version:\n* Python version:\n* Operating System:\n\n### Description\n\nDescribe what you were trying to get done.\nTell us what happened, what went wrong, and what you expected to happen.\n\n### What I Did\n\n```\nPaste the command(s) you ran and the output.\nIf there was a crash, please include the traceback here.\n```\n"
  },
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# pyenv\n.python-version\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n\n# PyCharm\n.idea\n\n# VS Code\n.vscode\n\n# uv\nuv.lock\n\n# ruff\n.ruff_cache\n"
  },
  {
    "path": ".zenodo.json",
    "content": "{\n  \"creators\": [\n    {\n      \"affiliation\": \"University College London, United Kingdom\",\n      \"name\": \"P\\u00e9rez-Garc\\u00eda, Fernando\",\n      \"orcid\": \"0000-0001-9090-3024\"\n    }\n  ],\n  \"description\": \"PyTorch implementation of 1D, 2D and 3D U-Net\",\n  \"keywords\": [\n    \"pytorch\",\n    \"medical-image-computing\",\n    \"deep-learning\",\n    \"machine-learning\",\n    \"convolutional-neural-networks\"\n  ],\n  \"license\": \"mit-license\",\n  \"upload_type\": \"software\"\n}\n"
  },
  {
    "path": "CONTRIBUTING.rst",
    "content": ".. highlight:: shell\n\n============\nContributing\n============\n\nContributions are welcome, and they are greatly appreciated! Every little bit\nhelps, and credit will always be given.\n\nYou can contribute in many ways:\n\nTypes of Contributions\n----------------------\n\nReport Bugs\n~~~~~~~~~~~\n\nReport bugs at https://github.com/fepegar/unet/issues.\n\nIf you are reporting a bug, please include:\n\n* Your operating system name and version.\n* Any details about your local setup that might be helpful in troubleshooting.\n* Detailed steps to reproduce the bug.\n\nFix Bugs\n~~~~~~~~\n\nLook through the GitHub issues for bugs. Anything tagged with \"bug\" and \"help\nwanted\" is open to whoever wants to implement it.\n\nImplement Features\n~~~~~~~~~~~~~~~~~~\n\nLook through the GitHub issues for features. Anything tagged with \"enhancement\"\nand \"help wanted\" is open to whoever wants to implement it.\n\nWrite Documentation\n~~~~~~~~~~~~~~~~~~~\n\nunet could always use more documentation, whether as part of the\nofficial unet docs, in docstrings, or even on the web in blog posts,\narticles, and such.\n\nSubmit Feedback\n~~~~~~~~~~~~~~~\n\nThe best way to send feedback is to file an issue at https://github.com/fepegar/unet/issues.\n\nIf you are proposing a feature:\n\n* Explain in detail how it would work.\n* Keep the scope as narrow as possible, to make it easier to implement.\n* Remember that this is a volunteer-driven project, and that contributions\n  are welcome :)\n\nGet Started!\n------------\n\nReady to contribute? Here's how to set up \`unet\` for local development.\n\n1. Fork the \`unet\` repo on GitHub.\n2. Clone your fork locally::\n\n    $ git clone git@github.com:your_name_here/unet.git\n\n3. Install the project and its development dependencies using uv (the\n   \`\`just setup\`\` recipe installs uv first if it is missing)::\n\n    $ cd unet/\n    $ just setup\n\n   Alternatively, run \`\`uv sync --all-extras --all-groups\`\` directly.\n\n4. Create a branch for local development::\n\n    $ git checkout -b name-of-your-bugfix-or-feature\n\n   Now you can make your changes locally.\n\n5. When you're done making changes, check that your changes pass the linter\n   and the tests, including testing other Python versions with tox::\n\n    $ just ruff\n    $ uv run pytest\n    $ uv run tox -p\n\n   ruff and tox are installed as part of the development dependencies.\n\n6. Commit your changes and push your branch to GitHub::\n\n    $ git add .\n    $ git commit -m \"Your detailed description of your changes.\"\n    $ git push origin name-of-your-bugfix-or-feature\n\n7. Submit a pull request through the GitHub website.\n\nPull Request Guidelines\n-----------------------\n\nBefore you submit a pull request, check that it meets these guidelines:\n\n1. The pull request should include tests.\n2. If the pull request adds functionality, the docs should be updated. Put\n   your new functionality into a function with a docstring, and add the\n   feature to the list in README.rst.\n3. The pull request should work for Python 3.9, 3.10, 3.11 and 3.12. Make\n   sure that the tests pass for all supported Python versions.\n\nTips\n----\n\nTo run a subset of tests::\n\n    $ uv run pytest tests/test_unet.py\n\nDeploying\n---------\n\nA reminder for the maintainers on how to deploy.\nMake sure all your changes are committed.\nThen run::\n\n    $ just bump patch  # possible: major / minor / patch\n    $ git push\n    $ git push --tags\n    $ just release\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2019 Fernando Perez-Garcia\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.rst",
    "content": "U-Net\n=====\n\n\n.. image:: https://zenodo.org/badge/DOI/10.5281/zenodo.3522306.svg\n        :target: https://doi.org/10.5281/zenodo.3522306\n        :alt: DOI\n\n.. image:: https://img.shields.io/badge/License-MIT-yellow.svg\n        :target: https://opensource.org/licenses/MIT\n        :alt: License\n\n.. image:: https://img.shields.io/pypi/v/unet.svg\n        :target: https://pypi.python.org/pypi/unet\n\n\nPyTorch implementation of 1D, 2D and 3D U-Net.\n\nThe U-Net architecture was first described in\n`Ronneberger et al. 2015, U-Net: Convolutional Networks for Biomedical Image\nSegmentation <https://arxiv.org/abs/1505.04597>`_.\nThe 3D version was described in\n`Çiçek et al. 2016, 3D U-Net: Learning Dense Volumetric Segmentation from\nSparse Annotation <https://arxiv.org/abs/1606.06650>`_.\n\n\nInstallation\n------------\n\n::\n\n   pip install unet\n\n\nUsage\n-----\n\nA minimal example (the hyperparameters below are just one possible\nconfiguration, mirroring the test suite)::\n\n   import torch\n   from unet import UNet2D\n\n   model = UNet2D(normalization=\"batch\", preactivation=True, residual=True).eval()\n   x = torch.rand(1, 1, 512, 512)\n   with torch.no_grad():\n       y = model(x)  # shape: (1, 2, 512, 512)\n\n\nCredits\n-------\n\nIf you use this code for your research, please cite this repository using the\ninformation available on its\n`Zenodo entry <https://doi.org/10.5281/zenodo.3697931>`_:\n\n    Pérez-García, Fernando. (2020). fepegar/unet: PyTorch implementation of 2D and 3D U-Net (v0.7.5). Zenodo. https://doi.org/10.5281/zenodo.3697931\n"
  },
  {
    "path": "justfile",
    "content": "@install_uv:\n\tif ! command -v uv >/dev/null 2>&1; then \\\n\t\techo \"uv is not installed. Installing...\"; \\\n\t\tcurl -LsSf https://astral.sh/uv/install.sh | sh; \\\n\tfi\n\nsetup: install_uv\n    uv sync --all-extras --all-groups\n\nbump part='patch': install_uv\n    uv run bump-my-version bump {{part}} --verbose\n\nrelease: install_uv\n    rm -rf dist\n    uv build --no-sources\n    uv publish\n\nchangelog: install_uv\n    uvx git-changelog --output CHANGELOG.md\n\nruff: install_uv\n    uvx ruff check --fix\n    uvx ruff format\n\ntest: install_uv\n    uv run tox -p\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"hatchling\"]\nbuild-backend = \"hatchling.build\"\n\n[project]\nname = \"unet\"\nversion = \"0.8.1\"\ndescription = \"PyTorch implementation of 1D, 2D and 3D U-Net.\"\nauthors = [{ name = \"Fernando Perez-Garcia\", email = \"fepegar@gmail.com\" }]\nreadme = { file = \"README.rst\", content-type = \"text/x-rst\" }\nrequires-python = \">=3.9\"\nclassifiers = [\n    \"Intended Audience :: Science/Research\",\n    \"License :: OSI Approved :: MIT License\",\n    \"Natural Language :: English\",\n    \"Operating System :: OS Independent\",\n]\ndependencies = [\"torch\"]\n\n[project.urls]\nHomepage = \"https://github.com/fepegar/unet\"\nSource = \"https://github.com/fepegar/unet\"\n\"Issue tracker\" = \"https://github.com/fepegar/unet/issues\"\n\n[dependency-groups]\ndev = [\"bump-my-version\", \"pytest\", \"pytest-sugar\", \"ruff\", \"tox-uv\"]\n\n[tool.ruff.lint]\nselect = [\n    # pycodestyle\n    \"E\",\n    # Pyflakes\n    \"F\",\n    # pyupgrade\n    \"UP\",\n    # flake8-bugbear\n    \"B\",\n    # isort\n    \"I\",\n]\n"
  },
  {
    "path": "tests/__init__.py",
    "content": "# -*- coding: utf-8 -*-\n\n\"\"\"Unit test package for unet.\"\"\"\n"
  },
  {
    "path": "tests/test_unet.py",
    "content": "import torch\n\nfrom unet import UNet1D, UNet2D, UNet3D\n\ntorch.manual_seed(0)\ntorch.set_grad_enabled(False)\n\n\ndef run(model, shape):\n    x_sample = torch.rand(*shape)\n    with torch.no_grad():\n        y = model(x_sample)\n    return y\n\n\ndef test_unet_1d():\n    model = UNet1D(\n        normalization=\"batch\",\n        preactivation=True,\n        residual=False,\n    ).eval()\n    shape = 1, 1, 572\n    result = 1, 2, 388\n    y = run(model, shape)\n    assert tuple(y.shape) == result\n\n\ndef test_unet_1d_residual():\n    model = UNet1D(\n        normalization=\"batch\",\n        preactivation=True,\n        residual=True,\n    ).eval()\n    shape = 1, 1, 512\n    result = 1, 2, 512\n    y = run(model, shape)\n    assert tuple(y.shape) == result\n\n\ndef test_unet_2d():\n    model = UNet2D(\n        normalization=\"batch\",\n        preactivation=True,\n        residual=False,\n    ).eval()\n    shape = 1, 1, 572, 572\n    result = 1, 2, 388, 388\n    y = run(model, shape)\n    assert tuple(y.shape) == result\n\n\ndef test_unet_2d_residual():\n    model = UNet2D(\n        normalization=\"batch\",\n        preactivation=True,\n        residual=True,\n    ).eval()\n    shape = 1, 1, 512, 512\n    result = 1, 2, 512, 512\n    y = run(model, shape)\n    assert tuple(y.shape) == result\n\n\ndef test_unet_3d():\n    model = UNet3D(\n        normalization=\"batch\",\n        preactivation=True,\n        residual=False,\n    ).eval()\n    shape = 1, 1, 132, 132, 116\n    result = 1, 2, 44, 44, 28\n    y = run(model, shape)\n    assert tuple(y.shape) == result\n\n\ndef test_unet_3d_residual():\n    model = UNet3D(\n        normalization=\"batch\",\n        preactivation=True,\n        residual=True,\n        num_encoding_blocks=2,\n        upsampling_type=\"trilinear\",\n    ).eval()\n    shape = 1, 1, 64, 64, 56\n    result = 1, 2, 64, 64, 56\n    y = run(model, shape)\n    assert tuple(y.shape) == result\n"
  },
  {
    "path": "tox.ini",
    "content": "[tox]\nenvlist = py3{9,10,11,12}\n\n[testenv]\ncommands = pytest\n"
  },
  {
    "path": "unet/__init__.py",
    "content": "__version__ = \"0.8.1\"\n\nfrom .unet import UNet, UNet1D, UNet2D, UNet3D\n\n__all__ = [\n    \"UNet\",\n    \"UNet1D\",\n    \"UNet2D\",\n    \"UNet3D\",\n]\n"
  },
  {
    "path": "unet/conv.py",
    "content": "from typing import Optional\n\nimport torch.nn as nn\n\n\nclass ConvolutionalBlock(nn.Module):\n    def __init__(\n        self,\n        dimensions: int,\n        in_channels: int,\n        out_channels: int,\n        normalization: Optional[str] = None,\n        kernel_size: int = 3,\n        activation: Optional[str] = \"ReLU\",\n        preactivation: bool = False,\n        padding: int = 0,\n        padding_mode: str = \"zeros\",\n        dilation: Optional[int] = None,\n        dropout: float = 0,\n    ):\n        super().__init__()\n\n        block = nn.ModuleList()\n\n        dilation = 1 if dilation is None else dilation\n        if padding:\n            total_padding = kernel_size + 2 * (dilation - 1) - 1\n            padding = total_padding // 2\n\n        class_name = \"Conv{}d\".format(dimensions)\n        conv_class = getattr(nn, class_name)\n        no_bias = not preactivation and (normalization is not None)\n        conv_layer = conv_class(\n            in_channels,\n            out_channels,\n            kernel_size,\n            padding=padding,\n            padding_mode=padding_mode,\n            dilation=dilation,\n            bias=not no_bias,\n        )\n\n        norm_layer = None\n        if normalization is not None:\n            class_name = \"{}Norm{}d\".format(normalization.capitalize(), dimensions)\n            norm_class = getattr(nn, class_name)\n            num_features = in_channels if preactivation else out_channels\n            norm_layer = norm_class(num_features)\n\n        activation_layer = None\n        if activation is not None:\n            activation_layer = getattr(nn, activation)()\n\n        if preactivation:\n            self.add_if_not_none(block, norm_layer)\n            self.add_if_not_none(block, activation_layer)\n            self.add_if_not_none(block, conv_layer)\n        else:\n            self.add_if_not_none(block, conv_layer)\n            self.add_if_not_none(block, norm_layer)\n            
self.add_if_not_none(block, activation_layer)\n\n        dropout_layer = None\n        if dropout:\n            class_name = \"Dropout{}d\".format(dimensions)\n            dropout_class = getattr(nn, class_name)\n            dropout_layer = dropout_class(p=dropout)\n            self.add_if_not_none(block, dropout_layer)\n\n        self.conv_layer = conv_layer\n        self.norm_layer = norm_layer\n        self.activation_layer = activation_layer\n        self.dropout_layer = dropout_layer\n\n        self.block = nn.Sequential(*block)\n\n    def forward(self, x):\n        return self.block(x)\n\n    @staticmethod\n    def add_if_not_none(module_list, module):\n        if module is not None:\n            module_list.append(module)\n"
  },
  {
    "path": "unet/decoding.py",
    "content": "from typing import Optional\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .conv import ConvolutionalBlock\n\nCHANNELS_DIMENSION = 1\nUPSAMPLING_MODES = (\n    \"nearest\",\n    \"linear\",\n    \"bilinear\",\n    \"bicubic\",\n    \"trilinear\",\n)\n\n\nclass Decoder(nn.Module):\n    def __init__(\n        self,\n        in_channels_skip_connection: int,\n        dimensions: int,\n        upsampling_type: str,\n        num_decoding_blocks: int,\n        normalization: Optional[str],\n        preactivation: bool = False,\n        residual: bool = False,\n        padding: int = 0,\n        padding_mode: str = \"zeros\",\n        activation: Optional[str] = \"ReLU\",\n        initial_dilation: Optional[int] = None,\n        dropout: float = 0,\n    ):\n        super().__init__()\n        upsampling_type = fix_upsampling_type(upsampling_type, dimensions)\n        self.decoding_blocks = nn.ModuleList()\n        self.dilation = initial_dilation\n        for _ in range(num_decoding_blocks):\n            decoding_block = DecodingBlock(\n                in_channels_skip_connection,\n                dimensions,\n                upsampling_type,\n                normalization=normalization,\n                preactivation=preactivation,\n                residual=residual,\n                padding=padding,\n                padding_mode=padding_mode,\n                activation=activation,\n                dilation=self.dilation,\n                dropout=dropout,\n            )\n            self.decoding_blocks.append(decoding_block)\n            in_channels_skip_connection //= 2\n            if self.dilation is not None:\n                self.dilation //= 2\n\n    def forward(self, skip_connections, x):\n        zipped = zip(reversed(skip_connections), self.decoding_blocks)\n        for skip_connection, decoding_block in zipped:\n            x = decoding_block(skip_connection, x)\n        return x\n\n\nclass DecodingBlock(nn.Module):\n    def __init__(\n        self,\n        in_channels_skip_connection: int,\n        dimensions: int,\n        upsampling_type: str,\n        normalization: Optional[str],\n        preactivation: bool = True,\n        residual: bool = False,\n        padding: int = 0,\n        padding_mode: str = \"zeros\",\n        activation: Optional[str] = \"ReLU\",\n        dilation: Optional[int] = None,\n        dropout: float = 0,\n    ):\n        super().__init__()\n\n        self.residual = residual\n\n        if upsampling_type == \"conv\":\n            in_channels = out_channels = 2 * in_channels_skip_connection\n            self.upsample = get_conv_transpose_layer(\n                dimensions, in_channels, out_channels\n            )\n        else:\n            self.upsample = get_upsampling_layer(upsampling_type)\n        # The skip connection (C channels) is concatenated with the upsampled\n        # features (2C channels), so the first convolution receives 3C channels\n        in_channels_first = in_channels_skip_connection * (1 + 2)\n        out_channels = in_channels_skip_connection\n        self.conv1 = ConvolutionalBlock(\n            dimensions,\n            in_channels_first,\n            out_channels,\n            normalization=normalization,\n            preactivation=preactivation,\n            padding=padding,\n            padding_mode=padding_mode,\n            activation=activation,\n            dilation=dilation,\n            dropout=dropout,\n        )\n        in_channels_second = out_channels\n        self.conv2 = ConvolutionalBlock(\n            dimensions,\n            in_channels_second,\n            out_channels,\n            normalization=normalization,\n            preactivation=preactivation,\n            padding=padding,\n            padding_mode=padding_mode,\n            activation=activation,\n            dilation=dilation,\n            dropout=dropout,\n        )\n\n        if residual:\n            self.conv_residual = ConvolutionalBlock(\n                dimensions,\n                in_channels_first,\n                out_channels,\n                kernel_size=1,\n                normalization=None,\n                activation=None,\n            )\n\n    def forward(self, skip_connection, x):\n        x = self.upsample(x)\n        skip_connection = self.center_crop(skip_connection, x)\n        x = torch.cat((skip_connection, x), dim=CHANNELS_DIMENSION)\n        if self.residual:\n            connection = self.conv_residual(x)\n            x = self.conv1(x)\n            x = self.conv2(x)\n            x = x + connection\n        else:\n            x = self.conv1(x)\n            x = self.conv2(x)\n        return x\n\n    def center_crop(self, skip_connection, x):\n        skip_shape = torch.tensor(skip_connection.shape)\n        x_shape = torch.tensor(x.shape)\n        crop = skip_shape[2:] - x_shape[2:]\n        half_crop = torch.div(crop, 2, rounding_mode=\"floor\")\n        # F.pad expects the pair for the *last* spatial dimension first, and an\n        # odd crop needs one extra element removed from one side.\n        # If skip_connection is (10, 20, 30) and x is (6, 14, 12),\n        # then pad will be (-9, -9, -3, -3, -2, -2)\n        pad = -torch.stack((half_crop, crop - half_crop)).t().flip(0).flatten()\n        skip_connection = F.pad(skip_connection, pad.tolist())\n        return skip_connection\n\n\ndef get_upsampling_layer(upsampling_type: str) -> nn.Upsample:\n    if upsampling_type not in UPSAMPLING_MODES:\n        message = 'Upsampling type is \"{}\" but should be one of the following: {}'\n        raise ValueError(message.format(upsampling_type, UPSAMPLING_MODES))\n    # align_corners may only be set for the interpolating modes, not \"nearest\"\n    align_corners = None if upsampling_type == \"nearest\" else False\n    upsample = nn.Upsample(\n        scale_factor=2,\n        mode=upsampling_type,\n        align_corners=align_corners,\n    )\n    return upsample\n\n\ndef get_conv_transpose_layer(dimensions, in_channels, out_channels):\n    class_name = \"ConvTranspose{}d\".format(dimensions)\n    conv_class = getattr(nn, class_name)\n    conv_layer = conv_class(in_channels, out_channels, kernel_size=2, stride=2)\n    return conv_layer\n\n\ndef fix_upsampling_type(upsampling_type: str, dimensions: int):\n    if upsampling_type == \"linear\":\n        if dimensions == 2:\n            upsampling_type = \"bilinear\"\n        elif dimensions == 3:\n            upsampling_type = \"trilinear\"\n    return upsampling_type\n"
  },
  {
    "path": "unet/encoding.py",
    "content": "from typing import Optional\n\nimport torch.nn as nn\n\nfrom .conv import ConvolutionalBlock\n\n\nclass Encoder(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels_first: int,\n        dimensions: int,\n        pooling_type: str,\n        num_encoding_blocks: int,\n        normalization: Optional[str],\n        preactivation: bool = False,\n        residual: bool = False,\n        padding: int = 0,\n        padding_mode: str = \"zeros\",\n        activation: Optional[str] = \"ReLU\",\n        initial_dilation: Optional[int] = None,\n        dropout: float = 0,\n    ):\n        super().__init__()\n\n        self.encoding_blocks = nn.ModuleList()\n        self.dilation = initial_dilation\n        is_first_block = True\n        for _ in range(num_encoding_blocks):\n            encoding_block = EncodingBlock(\n                in_channels,\n                out_channels_first,\n                dimensions,\n                normalization,\n                pooling_type,\n                preactivation,\n                is_first_block=is_first_block,\n                residual=residual,\n                padding=padding,\n                padding_mode=padding_mode,\n                activation=activation,\n                dilation=self.dilation,\n                dropout=dropout,\n            )\n            is_first_block = False\n            self.encoding_blocks.append(encoding_block)\n            if dimensions in (1, 2):\n                in_channels = out_channels_first\n                out_channels_first = in_channels * 2\n            elif dimensions == 3:\n                in_channels = 2 * out_channels_first\n                out_channels_first = in_channels\n            if self.dilation is not None:\n                self.dilation *= 2\n\n    def forward(self, x):\n        skip_connections = []\n        for encoding_block in self.encoding_blocks:\n            x, skip_connection = encoding_block(x)\n            skip_connections.append(skip_connection)\n        return skip_connections, x\n\n    @property\n    def out_channels(self):\n        return self.encoding_blocks[-1].out_channels\n\n\nclass EncodingBlock(nn.Module):\n    def __init__(\n        self,\n        in_channels: int,\n        out_channels_first: int,\n        dimensions: int,\n        normalization: Optional[str],\n        pooling_type: Optional[str],\n        preactivation: bool = False,\n        is_first_block: bool = False,\n        residual: bool = False,\n        padding: int = 0,\n        padding_mode: str = \"zeros\",\n        activation: Optional[str] = \"ReLU\",\n        dilation: Optional[int] = None,\n        dropout: float = 0,\n    ):\n        super().__init__()\n\n        self.preactivation = preactivation\n        self.normalization = normalization\n\n        self.residual = residual\n\n        if is_first_block:\n            normalization = None\n            preactivation = None\n        else:\n            normalization = self.normalization\n            preactivation = self.preactivation\n\n        self.conv1 = ConvolutionalBlock(\n            dimensions,\n            in_channels,\n            out_channels_first,\n            normalization=normalization,\n            preactivation=preactivation,\n            padding=padding,\n            padding_mode=padding_mode,\n            activation=activation,\n            dilation=dilation,\n            dropout=dropout,\n        )\n\n        if dimensions in (1, 2):\n            out_channels_second = out_channels_first\n        elif dimensions == 3:\n            out_channels_second = 2 * out_channels_first\n        self.conv2 = ConvolutionalBlock(\n            dimensions,\n            out_channels_first,\n            out_channels_second,\n            normalization=self.normalization,\n            preactivation=self.preactivation,\n            padding=padding,\n            padding_mode=padding_mode,\n            activation=activation,\n            dilation=dilation,\n            dropout=dropout,\n        )\n\n        if residual:\n            self.conv_residual = ConvolutionalBlock(\n                dimensions,\n                in_channels,\n                out_channels_second,\n                kernel_size=1,\n                normalization=None,\n                activation=None,\n            )\n\n        self.downsample = None\n        if pooling_type is not None:\n            self.downsample = get_downsampling_layer(dimensions, pooling_type)\n\n    def forward(self, x):\n        if self.residual:\n            connection = self.conv_residual(x)\n            x = self.conv1(x)\n            x = self.conv2(x)\n            x = x + connection\n        else:\n            x = self.conv1(x)\n            x = self.conv2(x)\n        if self.downsample is None:\n            return x\n        else:\n            skip_connection = x\n            x = self.downsample(x)\n            return x, skip_connection\n\n    @property\n    def out_channels(self):\n        return self.conv2.conv_layer.out_channels\n\n\ndef get_downsampling_layer(\n    dimensions: int,\n    pooling_type: str,\n    kernel_size: int = 2,\n) -> nn.Module:\n    class_name = \"{}Pool{}d\".format(pooling_type.capitalize(), dimensions)\n    class_ = getattr(nn, class_name)\n    return class_(kernel_size)\n"
  },
  {
    "path": "unet/unet.py",
    "content": "# -*- coding: utf-8 -*-\n\n\"\"\"Main module.\"\"\"\n\nfrom typing import Optional\nimport torch.nn as nn\nfrom .encoding import Encoder, EncodingBlock\nfrom .decoding import Decoder\nfrom .conv import ConvolutionalBlock\n\n__all__ = [\"UNet\", \"UNet1D\", \"UNet2D\", \"UNet3D\"]\n\n\nclass UNet(nn.Module):\n    def __init__(\n        self,\n        in_channels: int = 1,\n        out_classes: int = 2,\n        dimensions: int = 2,\n        num_encoding_blocks: int = 5,\n        out_channels_first_layer: int = 64,\n        normalization: Optional[str] = None,\n        pooling_type: str = \"max\",\n        upsampling_type: str = \"conv\",\n        preactivation: bool = False,\n        residual: bool = False,\n        padding: int = 0,\n        padding_mode: str = \"zeros\",\n        activation: Optional[str] = \"ReLU\",\n        initial_dilation: Optional[int] = None,\n        dropout: float = 0,\n        monte_carlo_dropout: float = 0,\n    ):\n        super().__init__()\n        depth = num_encoding_blocks - 1\n\n        # Force padding if residual blocks\n        if residual:\n            padding = 1\n\n        # Encoder\n        self.encoder = Encoder(\n            in_channels,\n            out_channels_first_layer,\n            dimensions,\n            pooling_type,\n            depth,\n            normalization,\n            preactivation=preactivation,\n            residual=residual,\n            padding=padding,\n            padding_mode=padding_mode,\n            activation=activation,\n            initial_dilation=initial_dilation,\n            dropout=dropout,\n        )\n\n        # Bottom (last encoding block)\n        in_channels = self.encoder.out_channels\n        if dimensions in (1, 2):\n            out_channels_first = 2 * in_channels\n        else:\n            out_channels_first = in_channels\n\n        self.bottom_block = EncodingBlock(\n            in_channels,\n            out_channels_first,\n            dimensions,\n          
  normalization,\n            pooling_type=None,\n            preactivation=preactivation,\n            residual=residual,\n            padding=padding,\n            padding_mode=padding_mode,\n            activation=activation,\n            dilation=self.encoder.dilation,\n            dropout=dropout,\n        )\n\n        # Decoder\n        if dimensions in (1, 2):\n            power = depth - 1\n        elif dimensions == 3:\n            power = depth\n        in_channels = self.bottom_block.out_channels\n        in_channels_skip_connection = out_channels_first_layer * 2**power\n        num_decoding_blocks = depth\n        self.decoder = Decoder(\n            in_channels_skip_connection,\n            dimensions,\n            upsampling_type,\n            num_decoding_blocks,\n            normalization=normalization,\n            preactivation=preactivation,\n            residual=residual,\n            padding=padding,\n            padding_mode=padding_mode,\n            activation=activation,\n            initial_dilation=self.encoder.dilation,\n            dropout=dropout,\n        )\n\n        # Monte Carlo dropout\n        self.monte_carlo_layer = None\n        if monte_carlo_dropout:\n            dropout_class = getattr(nn, \"Dropout{}d\".format(dimensions))\n            self.monte_carlo_layer = dropout_class(p=monte_carlo_dropout)\n\n        # Classifier\n        if dimensions in (1, 2):\n            in_channels = out_channels_first_layer\n        elif dimensions == 3:\n            in_channels = 2 * out_channels_first_layer\n        self.classifier = ConvolutionalBlock(\n            dimensions,\n            in_channels,\n            out_classes,\n            kernel_size=1,\n            activation=None,\n        )\n\n    def forward(self, x):\n        skip_connections, encoding = self.encoder(x)\n        encoding = self.bottom_block(encoding)\n        x = self.decoder(skip_connections, encoding)\n        if self.monte_carlo_layer is not None:\n            x 
= self.monte_carlo_layer(x)\n        return self.classifier(x)\n\n\nclass UNet1D(UNet):\n    def __init__(self, *args, **user_kwargs):\n        kwargs = {}\n        kwargs[\"dimensions\"] = 1\n        kwargs[\"num_encoding_blocks\"] = 5\n        kwargs[\"out_channels_first_layer\"] = 64\n        kwargs.update(user_kwargs)\n        super().__init__(*args, **kwargs)\n\n\nclass UNet2D(UNet):\n    def __init__(self, *args, **user_kwargs):\n        kwargs = {}\n        kwargs[\"dimensions\"] = 2\n        kwargs[\"num_encoding_blocks\"] = 5\n        kwargs[\"out_channels_first_layer\"] = 64\n        kwargs.update(user_kwargs)\n        super().__init__(*args, **kwargs)\n\n\nclass UNet3D(UNet):\n    def __init__(self, *args, **user_kwargs):\n        kwargs = {}\n        kwargs[\"dimensions\"] = 3\n        kwargs[\"num_encoding_blocks\"] = 4\n        kwargs[\"out_channels_first_layer\"] = 32\n        kwargs[\"normalization\"] = \"batch\"\n        kwargs.update(user_kwargs)\n        super().__init__(*args, **kwargs)\n"
  }
]