Full Code of mgharbi/demosaicnet for AI

Repository: mgharbi/demosaicnet
Branch: master
Commit: 959e9d163097
Files: 24
Total size: 82.4 KB

Directory structure:
gitextract_p3udv4rd/

├── .gitignore
├── LICENSE
├── MANIFEST.in
├── Makefile
├── README.md
├── demosaicnet/
│   ├── .gitignore
│   ├── __init__.py
│   ├── dataset.py
│   ├── modules.py
│   ├── mosaic.py
│   ├── utils.py
│   └── version.py
├── docs/
│   ├── .gitignore
│   ├── Makefile
│   └── source/
│       ├── conf.py
│       ├── dataset.rst
│       ├── helpers.rst
│       ├── index.rst
│       └── models.rst
├── requirements.txt
├── scripts/
│   ├── demosaicnet_demo.py
│   ├── eval.py
│   └── train.py
└── setup.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
output/
dist/
demosaicnet.egg-info
build/
data
.DS_Store


================================================
FILE: LICENSE
================================================
MIT License

Deep Joint Demosaicking and Denoising
Siggraph Asia 2016
Michael Gharbi, Gaurav Chaurasia, Sylvain Paris, Fredo Durand

Copyright (c) 2016 Michael Gharbi

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: MANIFEST.in
================================================
include demosaicnet/data/bayer.pth
include demosaicnet/data/xtrans.pth
include demosaicnet/data/test_input.png


================================================
FILE: Makefile
================================================
test:
	py.test tests

.PHONY: docs
docs:
	$(MAKE) -C docs html

clean:
	python setup.py clean
	rm -rf build demosaicnet.egg-info dist .pytest_cache

distribution:
	python setup.py sdist bdist_wheel
	twine check dist/*

test_upload:
	twine upload --repository-url https://test.pypi.org/legacy/ dist/*

upload_distribution:
	twine upload dist/*


================================================
FILE: README.md
================================================
# Deep Joint Demosaicking and Denoising
SIGGRAPH Asia 2016

Michaël Gharbi <gharbi@mit.edu>, Gaurav Chaurasia, Sylvain Paris, Frédo Durand

A minimal PyTorch implementation of "Deep Joint Demosaicking and Denoising" [Gharbi2016].

# Installation

From this repo:

```shell
python setup.py install
```

Using pip:

```shell
pip install demosaicnet
```

Then run the demo script with:

```shell
python scripts/demosaicnet_demo.py output
```

To train a dummy model on the demo dataset provided, run:

```shell
python scripts/train.py --data demosaicnet/data/dummy_dataset --checkpoint_dir ckpt
```

To build and upload the wheel:

```shell
pip install wheel twine
make distribution
make upload_distribution
```

# FAQ

- **How is noise handled? Where is the pretrained model?** The noise-aware model is not implemented here; see the earlier Caffe implementation for that: <https://github.com/mgharbi/demosaicnet_caffe>
- **How do I train this?** The script `scripts/train.py` is a good starting point to set up your training job, but I haven't tested it yet; I recommend rolling your own.


================================================
FILE: demosaicnet/.gitignore
================================================
__pycache__


================================================
FILE: demosaicnet/__init__.py
================================================
from .modules import BayerDemosaick
from .modules import XTransDemosaick
from .mosaic import xtrans
from .mosaic import bayer
from .mosaic import xtrans_cell
from .dataset import *
from . import utils


================================================
FILE: demosaicnet/dataset.py
================================================
"""Dataset loader for demosaicnet."""
import os
import subprocess
import shutil
import hashlib
import logging


import numpy as np
from imageio import imread
from torch.utils.data import Dataset as TorchDataset
import wget

from .mosaic import bayer, xtrans

__all__ = ["BAYER_MODE", "XTRANS_MODE", "Dataset",
           "TRAIN_SUBSET", "VAL_SUBSET", "TEST_SUBSET"]


log = logging.getLogger(__name__)

BAYER_MODE = "bayer"
"""Applies a Bayer mosaic pattern."""

XTRANS_MODE = "xtrans"
"""Applies an X-Trans mosaic pattern."""

TRAIN_SUBSET = "train"
"""Loads the 'train' subset of the data."""

VAL_SUBSET = "val"
"""Loads the 'val' subset of the data."""

TEST_SUBSET = "test"
"""Loads the 'test' subset of the data."""


class Dataset(TorchDataset):
    """Dataset of challenging image patches for demosaicking.

    Args:
        root(str): directory where the dataset is stored (or downloaded to).
        download(bool): if True, automatically download the dataset.
        mode(:class:`BAYER_MODE` or :class:`XTRANS_MODE`): mosaic pattern to apply to the data.
        subset(:class:`TRAIN_SUBSET`, :class:`VAL_SUBSET` or :class:`TEST_SUBSET`): subset of the data to load.
    """

    def __init__(self, root, download=False,
                 mode=BAYER_MODE, subset="train"):

        super(Dataset, self).__init__()

        self.root = os.path.abspath(root)

        if subset not in [TRAIN_SUBSET, VAL_SUBSET, TEST_SUBSET]:
            raise ValueError("Dataset subset should be '%s', '%s' or '%s', got"
                             " %s" % (TRAIN_SUBSET, VAL_SUBSET, TEST_SUBSET,
                                      subset))

        if mode not in [BAYER_MODE, XTRANS_MODE]:
            raise ValueError("Dataset mode should be '%s' or '%s', got"
                             " %s" % (BAYER_MODE, XTRANS_MODE, mode))
        self.mode = mode

        listfile = os.path.join(self.root, subset, "filelist.txt")
        log.debug("Reading image list from %s", listfile)

        if not os.path.exists(listfile):
            if download:
                _download(self.root)
            else:
                log.error("Filelist %s not found", listfile)
                raise ValueError("Filelist %s not found" % listfile)
        else:
            log.debug("No need to download the data, filelist exists.")

        self.files = []
        with open(listfile, "r") as fid:
            for fname in fid.readlines():
                self.files.append(os.path.join(self.root, subset, fname.strip()))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Fetches a mosaic / demosaicked pair of images.

        Returns:
            mosaic(np.array): with size [3, h, w] the mosaic data with separated color channels.
            img(np.array): with size [3, h, w] the groundtruth image.
        """
        fname = self.files[idx]
        img = np.array(imread(fname)).astype(np.float32) / (2**8-1)
        img = np.transpose(img, [2, 0, 1])

        if self.mode == BAYER_MODE:
            mosaic = bayer(img)
        else:
            mosaic = xtrans(img)

        return mosaic, img
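The normalization in `__getitem__` assumes 8-bit input: pixel values are scaled to [0, 1] and the array is reordered from imageio's (H, W, C) layout to the (C, H, W) layout PyTorch expects. A tiny numpy illustration of just that step (the 4x6 white patch is a made-up stand-in for a real image):

```python
import numpy as np

# Dummy 8-bit RGB image in imageio's (H, W, C) layout.
rgb_hwc = np.full((4, 6, 3), 255, dtype=np.uint8)

# Scale to [0, 1], then reorder channels-first as in __getitem__ above.
img = rgb_hwc.astype(np.float32) / (2**8 - 1)
img = np.transpose(img, [2, 0, 1])
print(img.shape)  # (3, 4, 6)
```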


CHECKSUMS = {
    'datasets.z01': 'da46277afe85d3a91c065e4751fb8175',
    'datasets.z02': 'e274a9646323d954b00094ea424e4e4c',
    'datasets.z03': 'e071cc595a99a5aa4545d06350e5165f',
    'datasets.z04': 'c3d2f229834569cd5ae6e2d1467c4a95',
    'datasets.z05': 'daf90136c7b1ee4bb4653e9b6bf4b67d',
    'datasets.z06': '87e85d2854d40116e066b28e5a8750cc',
    'datasets.z07': 'a0b2854bf025c87c0bfdf83ce9aa9055',
    'datasets.z08': '62125ccf29cd4b182dd81a4bb82f94c4',
    'datasets.z09': 'f990f8a5090d586f2f31e61b5e6434bd',
    'datasets.z10': '41ecf8d8b7d981604d661d258bf988db',
    'datasets.z100': '923a536ece64cd036eec4a13156531c8',
    'datasets.z101': '44a936558af2e830fdf65d9acb3960ab',
    'datasets.z102': 'b24870482b41200ab7b91f0bcd3ed718',
    'datasets.z103': 'a85521c1fe0b8d2d1a074b0b52bf9db1',
    'datasets.z104': 'aacc7a81ec9e9a7849e3a45b1cb12f7c',
    'datasets.z105': '19b62c0f0ae008b77df6465182f43dc4',
    'datasets.z106': '4b0c414ce5825a9e2249e5810f0e55f0',
    'datasets.z107': '7f6df7fea899a656fcde898225890daf',
    'datasets.z108': '16a877c357f112367200a2534b5e54f3',
    'datasets.z109': '9180129bf9c204184f729bdf1c284c9c',
    'datasets.z11': 'cff5b0e9950933fa9dd6ced8ffb9528f',
    'datasets.z110': 'a95a6fbfd32d90058b9e4b9f0645c646',
    'datasets.z111': 'f3c7894a7d04178ca417dd5ed3a9e649',
    'datasets.z112': 'd46a73703a72c07137424cad90c9c0bd',
    'datasets.z113': '04b3421e465c5ef8bd64fc23730b58f7',
    'datasets.z114': '9c31d2d94bd6f1ba321039b18c462175',
    'datasets.z115': '427a3a6f3f936b0ba435da35be3e4bc3',
    'datasets.z116': 'c633a66e7644d7d8e8148d651b76d93f',
    'datasets.z117': 'cf316d3acaf301fc7b2b7e250ef734dd',
    'datasets.z118': 'c3b53a604492499930dfd000d7d33fa4',
    'datasets.z119': 'e64ede2179c589abd9e587347d9ba3b0',
    'datasets.z12': '17b70298245ae7965f4e4b4fb01f19cc',
    'datasets.z120': 'ece09ee0bc30eab71f06716cca393029',
    'datasets.z121': '79db008803bf58593df6c32db8c0b3d0',
    'datasets.z122': '647b151eb30a44d123ce9ddfbb380094',
    'datasets.z123': '5065f755fe61c666d6ed28096e4047a3',
    'datasets.z124': '6306215855e30112c495291d5928e0b7',
    'datasets.z125': 'a55c6e31e7ad42170016a15791e25134',
    'datasets.z126': '582eb81f0251840507ca0b53e624b1e0',
    'datasets.z127': '3beefb769a01481a8bc7ae39bc2f539b',
    'datasets.z128': 'fac96b38f96ea364ea51020386597b5d',
    'datasets.z129': '3aaca8ce2b67d2c1fe764ae2b306d17f',
    'datasets.z13': '7ec2aa595441d9f698a46d707f299e8c',
    'datasets.z130': '9f94105dcc39cf5421b2df2532c06ec9',
    'datasets.z131': '5de2901388a2e531d6874cdaf23bf15e',
    'datasets.z132': '6851ed45004ae6864892532d1ad44b20',
    'datasets.z133': '0f6b417d57f9bdc9fa91d85ff5e3378d',
    'datasets.z134': '0602b14c8828f7a9fed92713047c695f',
    'datasets.z135': '3ab79fd4da5c4c5b5a7896a189ec43a7',
    'datasets.z136': 'dd05152db786d8189cdb419ac5d0018a',
    'datasets.z137': 'b4e97abef22ee8b81ac232760d9f539d',
    'datasets.z138': '4ce939f1fa4f3e110e989db07d53d33f',
    'datasets.z139': 'a096add471cc5e074852c063eec3863a',
    'datasets.z14': '08571b6629b8856813fb35b45bbc082c',
    'datasets.z140': '3d84e8cc84ab26969c5239be60222ad3',
    'datasets.z141': 'd77062f59ca9957d33c9a671657fd795',
    'datasets.z142': '82901d01917006348deb89aa37fe3629',
    'datasets.z143': '736f77856f0854b26fa951479691df8f',
    'datasets.z144': '55c44320975f4278a8837085c5e02eda',
    'datasets.z145': '087d3b7634bf4720a916767d5c6b7d70',
    'datasets.z146': '5659d6f0495dcdc5f5d98bf2efaaa09b',
    'datasets.z147': '66dd69b2f9348e3c0d0c93c3e61416dd',
    'datasets.z148': 'f3fc8f15aceb0f9bf04d786b894caa44',
    'datasets.z149': '3863be1d2b130f79399432cfc1281c2e',
    'datasets.z15': 'cc57e0c4466575436f670ac3e07ad2f3',
    'datasets.z150': '09750f2019da9ff7132b904b8bcbd895',
    'datasets.z151': 'b1573f086c0f7d1fdf249a8e3a9bb178',
    'datasets.z152': '1a2d4374aea1e22c0b676a6a7eac49ec',
    'datasets.z153': 'b24320708d2019ed71ab16055e971b1d',
    'datasets.z154': '7ba27e1946afa610e131f3afefe78326',
    'datasets.z155': '2f02c8b5470be4cb6b53e4c9e512394e',
    'datasets.z156': 'fa4f0977409f181820bf78174257d657',
    'datasets.z157': '6736b97a29d1393ec65ddc9376a06369',
    'datasets.z158': 'b4b72842b13ec3877bc530ca2470a0db',
    'datasets.z159': 'fa0dfa57c9d299175719bbbdf319c935',
    'datasets.z16': '02c213b708e2ee7ebd68464dfb2279fc',
    'datasets.z160': '6edbb9dc7fa6d12d2e21631ae14eaa8c',
    'datasets.z161': '4ab093ce5af2726e9ee71fdf1943e8e2',
    'datasets.z162': '4e21401db9f9884d953df20381c5fd97',
    'datasets.z163': '8d39f0ed1a1d9b5de22b583d00081522',
    'datasets.z164': '40fe425e2c5e89b87b44a6e9735590d5',
    'datasets.z165': '9552ff9b03e2dbb45befc9d1cc99ad81',
    'datasets.z166': '6c0098a36d6827aea846d8522c578751',
    'datasets.z167': 'ce6b8b981d92f5a61f2ec40089a400a2',
    'datasets.z168': '60f6e16a3e5e409a3fc89edd3e0034d5',
    'datasets.z169': '75eb975a10d5cbf136796651a1789b42',
    'datasets.z17': 'c92ae62205eaef02db27996f0dc6c282',
    'datasets.z170': '287966840fff015ed36da3a08a18ebfb',
    'datasets.z171': 'e08193b722af492a78ac36a3125ac8d9',
    'datasets.z172': '32c795461f194c38b25047faeb46fdb5',
    'datasets.z173': '58fbb396dbbb902ac2f2c43722573200',
    'datasets.z174': '506fee2b982ee81689f3fe4d89133cac',
    'datasets.z175': '9d553e31b07b23e30c427800168eec6b',
    'datasets.z176': '0f6e3048824ead093d3127434ac83a72',
    'datasets.z177': 'ece15c004fa708849295987b8b1aba9c',
    'datasets.z178': 'd9db14d92d56ae2970798417030b5bb4',
    'datasets.z179': '25ddbe866d0a6b9cebe8d90f7b801fa6',
    'datasets.z18': 'f55ddc31cf203f495e352182a5bbadc3',
    'datasets.z180': 'd2fc49c68c77d1da592aff4ab90c0915',
    'datasets.z181': '5a4a635b1f3535311c6caecf4ab3ba80',
    'datasets.z182': 'df51725daff3edfc377a5f6bc158ec3a',
    'datasets.z183': '626add199ec4f263ff278d5392f41c9c',
    'datasets.z184': '5069483fd064ee5e8c24a240e6ee7736',
    'datasets.z185': '589249e98db0a4ded1d3e4acefd07509',
    'datasets.z186': 'e4415c64463ef16bceb9d2e2fa934d71',
    'datasets.z187': 'e070cbaaf88a1085964244f6505c713c',
    'datasets.z188': '71b38eb51edff8b049a302bacbe344d6',
    'datasets.z189': '8fa7a8b58c9e7cb9e86bfd0ca5f6d2ea',
    'datasets.z19': '6c34cd0e39a33737983ebf89f6cabf5f',
    'datasets.z190': 'daad0ea7c87d0935e014a370c38cc926',
    'datasets.z191': 'c355e9ee9d0afa67faa34739b7f7cf79',
    'datasets.z192': 'c97d5a784625795cdf3c36c337986afe',
    'datasets.z193': '5f3b8425e215798c9e454cdbe586db90',
    'datasets.z194': '31b5d74c1cbbabbf58ee470467b40d12',
    'datasets.z195': '4c65958343bc2ea1a28e779ee7e5e498',
    'datasets.z196': '26ab3664e62c7fd5d0be673c45dd0d93',
    'datasets.z197': '32d690086b6e9f05e3ced3a126af870c',
    'datasets.z198': '323071827db89626c9f186455fbb38c9',
    'datasets.z199': '72475a8500be1ff21407a66f0e2e91a6',
    'datasets.z20': '4161d6eda0ca5ed9500f953f789a25b2',
    'datasets.z200': '44a863b9c9760cb87a23f1422f242c0d',
    'datasets.z201': 'dc38b455fa45e3ef0d5f06397507982e',
    'datasets.z202': 'b9ba231b317b008602f9472325b40e65',
    'datasets.z203': '36f4afc46258d80be626040956550028',
    'datasets.z204': '5c522fbdfe1f9d449c9189173f2ed2b2',
    'datasets.z205': '58d39995017eed2c4abcf9fcfd07e695',
    'datasets.z206': '2efdfc2abc834f0f0f1cabe10423f865',
    'datasets.z207': '04ead7536e5c13936c724f644ab1cb3a',
    'datasets.z208': '9e3e0a02a07bebcc7cdb62a0ad047946',
    'datasets.z209': 'c1fc44cd8b6f50955c8b3b317155ecb6',
    'datasets.z21': '8669b8bb9fa90628d4423c45648868b2',
    'datasets.z210': '210df79f8434bdb4e2a7d12c4078d972',
    'datasets.z211': 'f078d2d8a14b6c59f58c67865bbc3334',
    'datasets.z212': 'ef7d08a6cc39f6cb96b631ca61b440a0',
    'datasets.z213': '723057f7619d8820f142944f55f9542b',
    'datasets.z214': '2ba38ca8561b51710f660c03f84c0eb1',
    'datasets.z215': '92a6e97dfaff295110ddead242ebe932',
    'datasets.z216': '05f40901ae70f73b3c099fcdd4ca945e',
    'datasets.z217': 'b690ed3e8c6ba9f8bba9154d7e8f7ece',
    'datasets.z218': 'e290ca6f5573579df9f3aa7c5158891e',
    'datasets.z219': 'a8e4626968f089163179f30066ce732c',
    'datasets.z22': '0f90463abdc8f0f81d81249302cf2d09',
    'datasets.z220': '3ecd2c0c855505d2957046d784944fce',
    'datasets.z221': '6c9c28287fbadcab2ce777ef3134e5d6',
    'datasets.z222': '29361a77f05e5e68113fc23e11b54b4d',
    'datasets.z223': '9214257b9a87c0037e88709addba8948',
    'datasets.z224': '7596a516fb7e308f33a81c5b3c36810a',
    'datasets.z225': 'c1dc079f5261a976b1bd7f5c05cd4a02',
    'datasets.z226': '5b87815b0ccc5cacec83a399a52874aa',
    'datasets.z227': '6bed353dc50263b2c720af663c833bbc',
    'datasets.z228': '37ed3574bf978bccd6e2db9be00bae94',
    'datasets.z229': '8fd57367808fd77581f998850e5f935a',
    'datasets.z23': '90ae2cdcdc1663b80c20e080e5c0e038',
    'datasets.z230': '3f5ef3234da0236d2fbfaf7366407d70',
    'datasets.z231': 'f67d8320028620c8bdd9a800a78afa27',
    'datasets.z232': '5f831f25f8e8557168b38a7a28f8e7f9',
    'datasets.z233': '56ee8c4b01825ba7f12340ae8b990db3',
    'datasets.z234': '5428e98487b0e077cf9c24dc60599286',
    'datasets.z235': '883c0ea97facca4d57d5c9c54922e8be',
    'datasets.z236': 'e23fb6f610a3b528d5b310df4e452256',
    'datasets.z237': 'bd858e84a47668edc851dab131239ae4',
    'datasets.z238': 'db51bed3e3e5c6a40881f22532618533',
    'datasets.z239': 'c1be852117739fc63227a503b08a8436',
    'datasets.z24': '00dcb2e2a72b15a9aa9a646ecaea0019',
    'datasets.z240': 'f4e6da02349b03b4f433b3399dcf8b3c',
    'datasets.z241': '4f29898105aaf9f1a753a1c639947c2f',
    'datasets.z242': '168a15bd8367f5d5f3e5e8cf4d0da6af',
    'datasets.z243': '2930bced33bd1ecabce070fc831567e7',
    'datasets.z244': 'aed9c5ec05f57e3fa9e7b224d47fa7b0',
    'datasets.z245': '5aa83729fec805c166e48e5ec21530a5',
    'datasets.z246': '646ceabfae028d631568930b4056227a',
    'datasets.z247': '7381491175c1a63cc04ecac81148925b',
    'datasets.z248': '4064df81449c1980d0abaa8c7262b315',
    'datasets.z249': '7ae84d1dde2e935d86138d1e7b077df8',
    'datasets.z25': '5ae383bcd01d4ce22387680e28833f06',
    'datasets.z250': 'c018b41fbc4982a561b07cf0d52137f2',
    'datasets.z251': '9c9dc7a889d537fd1e02f4549529a5f1',
    'datasets.z252': 'b74df0680e7a62794902186fe1e3fec2',
    'datasets.z253': '894cbb618ddffe65ce2ada0f250ad79c',
    'datasets.z254': 'd5d8bc590d109c7d592d4df183c495a8',
    'datasets.z255': 'faf313a3edd70129c212d3dbd1de5042',
    'datasets.z256': '0fefcb84b66518df03a93ae53079409f',
    'datasets.z257': '887e8e8abaf09682903b9b1060fb8153',
    'datasets.z258': 'a700ae13abb7707032123468fab1bf55',
    'datasets.z259': 'cad7c974b832d27d1cfb4ae0e4dc6c3c',
    'datasets.z26': '5e0cdf281eeca969a4e0adfc44e11dcb',
    'datasets.z260': '5e1c11df40440e4e84354d40efe7940e',
    'datasets.z261': '2a75579c238569356855244d9fedf50b',
    'datasets.z262': '51d508271fef5558df387542b3561b67',
    'datasets.z263': 'e11b54dcd5069e838a34dcc2daebf4b5',
    'datasets.z264': '8a5a3b288d21a3c2ef641370d436703e',
    'datasets.z265': '56954047cba7c8732f0490323540af43',
    'datasets.z266': 'cad0325e494cc720c385ac7420acd2d7',
    'datasets.z267': '25de522b499ec7af12f583dd89a31769',
    'datasets.z268': '0e9882f392ca679e8c47276371384efd',
    'datasets.z269': '28487bc8eb731d4913254a9d63bb13ae',
    'datasets.z27': 'c78804dfac1e395156abd235ca416b33',
    'datasets.z270': '276fde25d412d8a1197e3dda307580d7',
    'datasets.z271': 'ba6f46558aa9ebc64efca2485b4f18ca',
    'datasets.z272': '5acceb08940d937d3023d98e745b8197',
    'datasets.z273': '57201a53390fe9c6a8c069397dcb81b8',
    'datasets.z274': '92491b6786ccea7b6ccccfb4e09c6d75',
    'datasets.z275': '349136c52b8554f03967d7083e5cd95c',
    'datasets.z276': '1525075dbf4d5d101d63cdade8bed9e7',
    'datasets.z277': 'eeb6365ef482cd2c6bac20aab8181081',
    'datasets.z278': 'ab1b04860b27fc11b7f57074a0815877',
    'datasets.z279': 'dc1cee0d4b69da9fd7aeb47f91768589',
    'datasets.z28': '0257b938256a2b7b55637970ebb3edcc',
    'datasets.z280': 'd168adf5e7c223a1d8dddfc663ebaeb0',
    'datasets.z281': '540fc1de91a90e9bd91b3f2b590ddbf6',
    'datasets.z282': 'f3fd6fbe05cfd53eb4a4c2e41bf75cc7',
    'datasets.z283': '20e89e60dd6a582bcb98b74df82699c3',
    'datasets.z284': '6e6ba6077437285881609999ead45463',
    'datasets.z285': 'fdc4bad8adb36b3d6653a438f1fc000f',
    'datasets.z286': '39c0f6fc7aace7e30a33e7e73afdb6ae',
    'datasets.z287': 'e60818eefd7426f3de0cc0746550be7f',
    'datasets.z288': 'd9db9107b8e92c0bf6a311a219363554',
    'datasets.z289': 'a250aa672d1f9165981cfaf1c6c8fff6',
    'datasets.z29': 'e1240710e1c4dd506aec03c02caf5606',
    'datasets.z290': '6dcb5a7674c927ec4e965a42d04a0ccd',
    'datasets.z291': 'c058fca515b7f714338816b72672ce20',
    'datasets.z292': '3fa254ed46ad6a6f7686836fe6fb7991',
    'datasets.z293': '782344c6620582d6b1681e142853a61e',
    'datasets.z294': '5263c5eb50cfec20adf82a89973a7547',
    'datasets.z295': '0c217819aa7308ce8f744511e572a632',
    'datasets.z296': 'cf8e7f1d3503ad6371ebcc9f827a29f8',
    'datasets.z297': 'f16861471be342b291e55991654c882b',
    'datasets.z298': 'ceb8b1170f1f2ec8113ef1c004df236c',
    'datasets.z299': '4f66a50c5dfb03143a6456c7fd925ec1',
    'datasets.z30': '75b0444187fc6c3df7bd3182108bc647',
    'datasets.z300': 'becbeea47cef192a1a13b35911c4795a',
    'datasets.z301': '92a67108c6bde11111f3b94f690f4b42',
    'datasets.z302': '06292210e05575cf1632a099648c16af',
    'datasets.z303': 'ecb7a48777719322c957ecc99340f04a',
    'datasets.z304': 'b318fc0b7d645d16467b78bbce95befa',
    'datasets.z305': '5f83e6a17e5977c2b99fc29893a1f479',
    'datasets.z306': 'dde2ec22363740081f54032a3add00e0',
    'datasets.z307': 'e0b7f9a7eddf2117a6127542883eb767',
    'datasets.z308': 'fff180911deaf4b476f42e6e47d78e6e',
    'datasets.z309': 'c55a8ac017f7ea69e77519fbcc617301',
    'datasets.z31': '04bee9374c7436d66f560bbbfc22299d',
    'datasets.z310': '3b4282fd1ab2b4f885df196a83726d34',
    'datasets.z311': '08a7c41c5d290750d9ebf6266a86bec8',
    'datasets.z312': 'b0f844ffd0ae6785e077ef8871dfc5da',
    'datasets.z313': 'bd03ba03a63877b17274e21ddb828218',
    'datasets.z314': '392f4528f9c345355434e7448a80b28b',
    'datasets.z315': 'e499cdf8acf1561720d4eb8f9ad9daad',
    'datasets.z316': '2575496c4d9c5082c6c4ef2f0cedeb69',
    'datasets.z317': '6f387284dbef478e02f7a5954da66015',
    'datasets.z318': '2ea47b6b7c9790bbfc2d3c238ffba391',
    'datasets.z319': '8e1e48b2aa1d6c7ae840a23360a3a8f8',
    'datasets.z32': 'ab930ee97741da82bfc778474f528a28',
    'datasets.z320': 'a785b37410ccf0366f97c0d221faa629',
    'datasets.z321': '16c73037fa0ca7704a3c4ceddbd7d599',
    'datasets.z322': 'af97027070cf5d057934dfd6bd819e61',
    'datasets.z323': '79078baf8d5fee935620f48aed2980a6',
    'datasets.z324': '987b626c8731b0092df593f0eaddb32a',
    'datasets.z325': 'fcc9a289015b40044c55ca96bc3dbe1e',
    'datasets.z326': 'f2921b782a23f2729e1a6641e8c99954',
    'datasets.z327': '6d1b488d88cf0fcaf5105b6544e5ea66',
    'datasets.z328': '268885eaeb6290be4ddfaefb34943985',
    'datasets.z329': '3951d3dd0527b54c5302720ea037cb12',
    'datasets.z33': 'fff4585bf34e5a7ed8f369b337aac901',
    'datasets.z330': '4b546d783366442d95eed64f882ae9f6',
    'datasets.z331': '4605615511ea901c18256306263b2226',
    'datasets.z332': '4127b2a8dd549b02ce3e465cfcfcc0a3',
    'datasets.z333': '711d282247b4fd56758c96c13bfe1b8e',
    'datasets.z34': '39b5d53f995fa8c223ed0d7a2de34652',
    'datasets.z35': 'ae351c1e2961f99a6d3ac37fbae27548',
    'datasets.z36': 'c2102c4a984d03c32c7c99378df953dc',
    'datasets.z37': '76ba463895984049a1814654b7290890',
    'datasets.z38': '5481901e75ee9d55066ba2731b2f36b7',
    'datasets.z39': '59815df91b2532d1e300bc71c976da12',
    'datasets.z40': '321ebe9b7812c14ee185fe2d6f16300c',
    'datasets.z41': '862bc3fcb9fc1df2bef61561bcca8090',
    'datasets.z42': '10303d540fea2150a7574cefcec92977',
    'datasets.z43': '2269db3212f2d2db86982408a2b24948',
    'datasets.z44': '80673bdbd722d02d97febeca00e57cf1',
    'datasets.z45': '42329dbb5c5165788902f33265db66a3',
    'datasets.z46': 'f81999194d418c515bb5df32c278d7f7',
    'datasets.z47': '2e3d3520636b5f3eb0cd2d649b6b4dab',
    'datasets.z48': 'e57e7aee104e80deaf068fbdd3292410',
    'datasets.z49': '36146291e4af0eff44e763e7e1facb4f',
    'datasets.z50': '60ef115c5c621c757b9b7075d5590c20',
    'datasets.z51': '02326e4392d176c2aa6d479cf43f29f8',
    'datasets.z52': 'b454f81ef7cda24cb13b8176c641d7ef',
    'datasets.z53': '95a4857fe7bc6230e6a3cc379085e989',
    'datasets.z54': '8ff35d1e9d738eb8b2afde04699ba73f',
    'datasets.z55': '75d19ea4a9a283b6d236e3353063d82d',
    'datasets.z56': '2ec9121fb364cc76eaf6fe7ff947dd04',
    'datasets.z57': 'f5d0b9b2b0a91f82dea784023f71cf3c',
    'datasets.z58': '79e318093d97bd573cc7aab4ed68dd6d',
    'datasets.z59': '68350ada72cfc22d2d6fc8ca636ffef2',
    'datasets.z60': '1ff9d36ec2b3723253503d524c90c9df',
    'datasets.z61': '9fb963355e8c895c1a0a02fa45eb6fb2',
    'datasets.z62': '5900cdc6c3d5f4b0c05825b47946096a',
    'datasets.z63': '9d039ddf86aef563bf7834faa2766e84',
    'datasets.z64': '15decde2fb2b4f869684f88f067378c1',
    'datasets.z65': '7cf283ebffd79c38bc09835caba483e1',
    'datasets.z66': 'e3b65f4facef36a355b444bd6b4ec73b',
    'datasets.z67': 'b007466dd2ac3388a5902fdf92979b2d',
    'datasets.z68': 'c07cd3fbc9a99e8f4d3cc9bbc55682e0',
    'datasets.z69': 'cff012a6a3af8a9f286cf68baf37ac73',
    'datasets.z70': '032ad7661e468723242b995357870a05',
    'datasets.z71': 'fd7aa78706b8ae7ff50eb2f312a14ea5',
    'datasets.z72': 'c2b90b8e8f75e36c927d523138d8ae93',
    'datasets.z73': '8c001e26b5f4ca0e061ff0b685d509c6',
    'datasets.z74': '484dcaf9311cc4315175e114acf01f32',
    'datasets.z75': '19f6896b426b99f34b99c88f3b68209f',
    'datasets.z76': '4a64c040ba9aa6e455fee774a08873e5',
    'datasets.z77': 'e96b1d94e51741ca916580c5f8287ef5',
    'datasets.z78': 'e0f5c9d93a3a5cfd95eeb61ee322650d',
    'datasets.z79': '75fbea4391cd38d85c583c0f504325f1',
    'datasets.z80': '7a99d15beff97f2dc620300f5ea82506',
    'datasets.z81': '1e303b839a0349a8497eb6c49e3ab4c3',
    'datasets.z82': '48c23365901b8d9cb51a20f7c71fa0a7',
    'datasets.z83': 'fed6445cb47f0b6cd6ed6f5482e38be7',
    'datasets.z84': '3d052276fc186ec151b5c237c5774e66',
    'datasets.z85': 'd84dccdf4fd4e6f6126b746d0519d19d',
    'datasets.z86': '772ab63f1ae266761db99124a152fc29',
    'datasets.z87': '44208b663adf97563018a416560cce6b',
    'datasets.z88': '6be7aaff3349239507b103c928afeaf5',
    'datasets.z89': '515daa16bcaf83b0be91dcc32e0e4985',
    'datasets.z90': 'f0fe0aa02bf0f19c0a7e301b532afa4c',
    'datasets.z91': '87fa807089240e01886389a6c4c77641',
    'datasets.z92': '1fd3154b3d8ec321f58f7685aedc5441',
    'datasets.z93': 'b24ddb4b3e2c0c1c96d7650a8320e446',
    'datasets.z94': 'b5e98cd90a629c4abb70d66a6976e0c7',
    'datasets.z95': '87c8d5d7f26e0fbf7422fda2194bef97',
    'datasets.z96': 'bb18c027a71f326a33b0a8fbe2d3a11f',
    'datasets.z97': '0489ed1c8b4fb7d5965d1565076d5c6a',
    'datasets.z98': 'a85303e08ce59fdf28f36ae9d0f20dcb',
    'datasets.z99': '992e63e126f05eca9fa9a84bbf66165c',
    'datasets.zip': '3434f60f5e9b263ef78e207b54e9debe',
}


def _download(dst):
    dst = os.path.abspath(dst)
    files = CHECKSUMS.keys()
    fullzip = os.path.join(dst, "datasets.zip")
    joinedzip = os.path.join(dst, "joined.zip")

    URL_ROOT = "https://data.csail.mit.edu/graphics/demosaicnet"

    if not os.path.exists(joinedzip):
        log.info("Downloading %d files to %s (this will take a while, and ~80GB)",
                 len(files), dst)

        os.makedirs(dst, exist_ok=True)
        for f in files:
            fname = os.path.join(dst, f)
            url = os.path.join(URL_ROOT, f)

            do_download = True
            if os.path.exists(fname):
                checksum = md5sum(fname)
                if checksum == CHECKSUMS[f]:  # File exists and has the correct checksum
                    log.info('%s already downloaded, with correct checksum', f)
                    do_download = False
                else:
                    log.warning('%s checksums do not match, got %s, should be %s',
                                f, checksum, CHECKSUMS[f])
                    try:
                        os.remove(fname)
                    except OSError as e:
                        log.error("Could not delete broken part %s: %s", f, e)
                        raise ValueError("Could not delete broken part %s" % f)

            if do_download:
                log.info('Downloading %s', f)
                wget.download(url, fname)

            checksum = md5sum(fname)

            if checksum == CHECKSUMS[f]:
                log.info("%s MD5 correct", f)
            else:
                log.error('%s checksums do not match, got %s, should be %s. Downloading failed',
                          f, checksum, CHECKSUMS[f])

        log.info("Joining zip files")
        cmd = " ".join(["zip", "-FF", fullzip, "--out", joinedzip])
        subprocess.check_call(cmd, shell=True)

        # Cleanup the parts
        for f in files:
            fname = os.path.join(dst, f)
            try:
                os.remove(fname)
            except OSError as e:
                log.warning("Could not delete file %s: %s", f, e)

    # Extract
    wd = os.path.abspath(os.curdir)
    os.chdir(dst)
    log.info("Extracting files from %s", joinedzip)
    cmd = " ".join(["unzip", joinedzip])
    subprocess.check_call(cmd, shell=True)

    try:
        os.remove(joinedzip)
    except OSError as e:
        log.warning("Could not delete file %s: %s", joinedzip, e)

    log.info("Moving subfolders")
    for k in ["train", "test", "val"]:
        shutil.move(os.path.join(dst, "images", k), os.path.join(dst, k))
    images = os.path.join(dst, "images")
    log.info("removing '%s' folder", images)
    shutil.rmtree(images)


def md5sum(filename, blocksize=65536):
    """Computes the MD5 checksum of a file, streaming it block by block."""
    hasher = hashlib.md5()
    with open(filename, "rb") as f:
        for block in iter(lambda: f.read(blocksize), b""):
            hasher.update(block)
    return hasher.hexdigest()
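`_download` above checks each part against `CHECKSUMS` before fetching, deleting corrupt parts so they get re-downloaded. A self-contained sketch of that verify-or-refetch decision (the file name, suffix, and digests here are made up for illustration; this is not the repository's code):

```python
import hashlib
import os
import tempfile

def md5sum(filename, blocksize=65536):
    # Same streaming pattern as above: hash the file block by block so large
    # dataset parts never sit in memory all at once.
    hasher = hashlib.md5()
    with open(filename, "rb") as f:
        for block in iter(lambda: f.read(blocksize), b""):
            hasher.update(block)
    return hasher.hexdigest()

def needs_download(fname, expected_md5):
    """True if `fname` is absent or fails its checksum (corrupt parts are removed)."""
    if not os.path.exists(fname):
        return True
    if md5sum(fname) == expected_md5:
        return False
    os.remove(fname)  # broken part: delete it so it gets re-fetched
    return True

# Illustration with a throwaway file instead of a real dataset part.
with tempfile.NamedTemporaryFile(delete=False, suffix=".z01") as tmp:
    tmp.write(b"part contents")
    path = tmp.name

good = hashlib.md5(b"part contents").hexdigest()
print(needs_download(path, good))      # False: file present and intact
print(needs_download(path, "0" * 32))  # True: checksum mismatch, file removed
```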


================================================
FILE: demosaicnet/modules.py
================================================
"""Models for [Gharbi2016] Deep Joint demosaicking and denoising."""
import os
from collections import OrderedDict
from pkg_resources import resource_filename

import numpy as np
import torch as th
import torch.nn as nn


__all__ = ["BayerDemosaick", "XTransDemosaick"]


_BAYER_WEIGHTS = resource_filename(__name__, 'data/bayer.pth')
_XTRANS_WEIGHTS = resource_filename(__name__, 'data/xtrans.pth')
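Both models below call a `_crop_like` helper whose definition falls outside this excerpt. Judging from its call sites, it crops the first argument spatially to match the second (needed because the unpadded 3x3 convolutions shrink the feature maps). A minimal numpy sketch of that presumed behavior; `crop_like` here is a hypothetical stand-in, not the repository's implementation:

```python
import numpy as np

def crop_like(src, tgt):
    # Hypothetical sketch: center-crop src's trailing spatial dims (H, W)
    # down to tgt's, leaving batch/channel dims untouched.
    src_h, src_w = src.shape[-2:]
    tgt_h, tgt_w = tgt.shape[-2:]
    dh, dw = (src_h - tgt_h) // 2, (src_w - tgt_w) // 2
    return src[..., dh:dh + tgt_h, dw:dw + tgt_w]

src = np.zeros((1, 3, 10, 12))
tgt = np.zeros((1, 3, 6, 8))
print(crop_like(src, tgt).shape)  # (1, 3, 6, 8)
```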


class BayerDemosaick(nn.Module):
  """Released version of the network, best quality.

  This model differs from the published description. It has a mask/filter split
  towards the end of the processing. Masks and filters are multiplied with each
  other. This is not key to performance and can be ignored when training new
  models from scratch.
  """
  def __init__(self, depth=15, width=64, pretrained=True, pad=False):
    super(BayerDemosaick, self).__init__()

    self.depth = depth
    self.width = width

    if pad:
      pad = 1
    else:
      pad = 0

    layers = OrderedDict([
        ("pack_mosaic", nn.Conv2d(3, 4, 2, stride=2)),  # Downsample 2x2 to re-establish translation invariance
      ])
    for i in range(depth):
      n_out = width
      n_in = width
      if i == 0:
        n_in = 4
      if i == depth-1:
        n_out = 2*width
      layers["conv{}".format(i+1)] = nn.Conv2d(n_in, n_out, 3, padding=pad)
      layers["relu{}".format(i+1)] = nn.ReLU(inplace=True)

    self.main_processor = nn.Sequential(layers)
    self.residual_predictor = nn.Conv2d(width, 12, 1)
    self.upsampler = nn.ConvTranspose2d(12, 3, 2, stride=2, groups=3)

    self.fullres_processor = nn.Sequential(OrderedDict([
      ("post_conv", nn.Conv2d(6, width, 3, padding=pad)),
      ("post_relu", nn.ReLU(inplace=True)),
      ("output", nn.Conv2d(width, 3, 1)),
      ]))

    # Load weights
    if pretrained:
      assert depth == 15, "pretrained bayer model has depth=15."
      assert width == 64, "pretrained bayer model has width=64."
      state_dict = th.load(_BAYER_WEIGHTS)
      self.load_state_dict(state_dict)

  def forward(self, mosaic):
    """Demosaicks a Bayer image.

    Args:
      mosaic (th.Tensor):  input Bayer mosaic

    Returns:
      th.Tensor: the demosaicked image
    """

    # 1/4 resolution features
    features = self.main_processor(mosaic)
    filters, masks = features[:, 0:self.width], features[:, self.width:2*self.width]
    filtered = filters * masks
    residual = self.residual_predictor(filtered)

    # Match mosaic and residual
    upsampled = self.upsampler(residual)
    cropped = _crop_like(mosaic, upsampled)

    packed = th.cat([cropped, upsampled], 1)  # skip connection
    output = self.fullres_processor(packed)
    return output


class XTransDemosaick(nn.Module):
  """Released version of the network.

  There is no downsampling here.

  """
  def __init__(self, depth=11, width=64, pretrained=True, pad=False):
    super(XTransDemosaick, self).__init__()

    self.depth = depth
    self.width = width

    if pad:
      pad = 1
    else:
      pad = 0

    layers = OrderedDict([])
    for i in range(depth):
      n_in = width
      n_out = width
      if i == 0:
        n_in = 3
      layers["conv{}".format(i+1)] = nn.Conv2d(n_in, n_out, 3, padding=pad)
      layers["relu{}".format(i+1)] = nn.ReLU(inplace=True)

    self.main_processor = nn.Sequential(layers)

    self.fullres_processor = nn.Sequential(OrderedDict([
      ("post_conv", nn.Conv2d(3+width, width, 3, padding=pad)),
      ("post_relu", nn.ReLU(inplace=True)),
      ("output", nn.Conv2d(width, 3, 1)),
      ]))

    # Load weights
    if pretrained:
      assert depth == 11, "pretrained xtrans model has depth=11."
      assert width == 64, "pretrained xtrans model has width=64."
      state_dict = th.load(_XTRANS_WEIGHTS)
      self.load_state_dict(state_dict)


  def forward(self, mosaic):
    """Demosaicks an XTrans image.

    Args:
      mosaic (th.Tensor):  input XTrans mosaic

    Returns:
      th.Tensor: the demosaicked image
    """

    features = self.main_processor(mosaic)
    cropped = _crop_like(mosaic, features)  # Match mosaic and residual
    packed = th.cat([cropped, features], 1)  # skip connection
    output = self.fullres_processor(packed)
    return output


def _crop_like(src, tgt):
    """Crop a source image to match the spatial dimensions of a target.

    Args:
        src (th.Tensor or np.ndarray): image to be cropped
        tgt (th.Tensor or np.ndarray): reference image
    """
    src_sz = np.array(src.shape)
    tgt_sz = np.array(tgt.shape)

    # Assumes the spatial dimensions are the last two
    crop = (src_sz[-2:]-tgt_sz[-2:])
    crop_t = crop[0] // 2
    crop_b = crop[0] - crop_t
    crop_l = crop[1] // 2
    crop_r = crop[1] - crop_l
    if (np.array([crop_t, crop_b, crop_l, crop_r]) > 0).any():
        return src[..., crop_t:src_sz[-2]-crop_b, crop_l:src_sz[-1]-crop_r]
    else:
        return src
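

# A minimal, self-contained sketch of the center-crop arithmetic that
# _crop_like implements above. Illustration only: the `_example_` name is
# not part of the demosaicnet API.
def _example_center_crop():
    import numpy as np  # local import keeps the sketch self-contained
    src = np.arange(64, dtype=np.float32).reshape(1, 1, 8, 8)
    tgt = np.zeros((1, 1, 6, 6), dtype=np.float32)
    crop = np.array(src.shape[-2:]) - np.array(tgt.shape[-2:])  # [2, 2]
    crop_t, crop_l = crop[0] // 2, crop[1] // 2  # top/left margins
    crop_b, crop_r = crop[0] - crop_t, crop[1] - crop_l  # bottom/right margins
    out = src[..., crop_t:src.shape[-2] - crop_b, crop_l:src.shape[-1] - crop_r]
    assert out.shape[-2:] == (6, 6)
    assert out[0, 0, 0, 0] == src[0, 0, 1, 1]  # the crop is centered
    return out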



================================================
FILE: demosaicnet/mosaic.py
================================================
"""Utilities to make a mosaic mask and apply it to an image."""
import numpy as np
import torch as th


__all__ = ["bayer", "xtrans"]


def bayer(im, return_mask=False):
  """Bayer mosaic.

  The pattern assumed is::

    G r
    b G

  Args:
    im (np.array or th.Tensor): image to mosaic. Dimensions are [c, h, w].
    return_mask (bool): if true, return the binary mosaic mask instead of the
      mosaicked image.

  Returns:
    np.array or th.Tensor: mosaicked image (if return_mask==False), or binary
      mask (if return_mask==True).
  """

  numpy = type(im) == np.ndarray

  if numpy:
    mask = np.ones_like(im)
  else:
    mask = th.ones_like(im)

  # red
  mask[..., 0, ::2, 0::2] = 0
  mask[..., 0, 1::2, :] = 0

  # green
  mask[..., 1, ::2, 1::2] = 0
  mask[..., 1, 1::2, ::2] = 0

  # blue
  mask[..., 2, 0::2, :] = 0
  mask[..., 2, 1::2, 1::2] = 0

  if not numpy:  # make it a constant for ONNX conversion
    mask = th.from_numpy(mask.cpu().detach().numpy()).to(im.device)

  if mask.shape[0] == 1:
    mask = mask.squeeze(0) # coreml hack

  if return_mask:
    return mask

  return im*mask
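

# A self-contained sketch of the mask built by bayer() above: exactly one of
# {R, G, B} survives at each pixel, following the assumed "G r / b G" 2x2
# layout. Illustration only; `_example_` names are not part of the
# demosaicnet API.
def _example_bayer_mask():
  import numpy as np  # local import keeps the sketch self-contained
  mask = np.ones((3, 4, 4), dtype=np.float32)
  # red is kept at odd columns of even rows
  mask[0, ::2, ::2] = 0
  mask[0, 1::2, :] = 0
  # green is kept where row and column parities match
  mask[1, ::2, 1::2] = 0
  mask[1, 1::2, ::2] = 0
  # blue is kept at even columns of odd rows
  mask[2, ::2, :] = 0
  mask[2, 1::2, 1::2] = 0
  assert (mask.sum(axis=0) == 1).all()  # one active channel per pixel
  # the top-left 2x2 cell matches the documented "G r / b G" pattern
  assert mask[1, 0, 0] == 1 and mask[0, 0, 1] == 1
  assert mask[2, 1, 0] == 1 and mask[1, 1, 1] == 1
  return mask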


def xtrans_cell(torch=False):
  g_pos = [(0,0),        (0,2), (0,3),        (0,5),
                  (1,1),               (1,4),
           (2,0),        (2,2), (2,3),        (2,5),
           (3,0),        (3,2), (3,3),        (3,5),
                  (4,1),               (4,4),
           (5,0),        (5,2), (5,3),        (5,5)]
  r_pos = [(0,4),
           (1,0), (1,2),
           (2,4),
           (3,1),
           (4,3), (4,5),
           (5,1)]
  b_pos = [(0,1),
           (1,3), (1,5),
           (2,1),
           (3,4),
           (4,0), (4,2),
           (5,4)]

  if torch:
    mask = th.zeros(3, 6, 6)
  else:
    mask = np.zeros((3, 6, 6), dtype=np.float32)

  for idx, coord in enumerate([r_pos, g_pos, b_pos]):
    for y, x in coord:
      mask[..., idx, y, x] = 1

  return mask

def xtrans(im, return_mask=False):
  """XTrans Mosaick.

   The pattern assumed is::

     G b G G r G
     r G r b G b
     G b G G r G
     G r G G b G
     b G b r G r
     G r G G b G

  Args:
    im (np.array or th.Tensor): image to mosaic. Dimensions are [c, h, w].
    return_mask (bool): if true, return the binary mosaic mask instead of the
      mosaicked image.

  Returns:
    np.array or th.Tensor: mosaicked image (if return_mask==False), or binary
      mask (if return_mask==True).
  """

  numpy = type(im) == np.ndarray
  if numpy:
    mask = xtrans_cell(torch=False)
  else:
    mask = xtrans_cell(torch=True).to(im.device)
    if len(im.shape) == 4:
      mask = mask.unsqueeze(0).repeat(im.shape[0], 1, 1, 1)
    if len(im.shape) == 4:
      mask = mask.unsqueeze(0).repeat(im.shape[0], 1, 1, 1)

  h, w = im.shape[-2:]
  h = int(h)
  w = int(w)

  new_sz = [np.ceil(h / 6).astype(np.int32), np.ceil(w / 6).astype(np.int32)]

  sz = np.array(mask.shape)
  sz[:-2] = 1
  sz[-2:] = new_sz
  sz = list(sz)

  if numpy:
    mask = np.tile(mask, sz)
  else:
    mask = mask.repeat(*sz)

  if return_mask:
    return mask

  return mask*im
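

# A self-contained sketch of how xtrans() covers an image: the 6x6 cell is
# tiled ceil(h/6) x ceil(w/6) times. Illustration only, assuming image sides
# that are multiples of 6 (which the final mask*im product effectively
# requires); `_example_` names are not part of the demosaicnet API.
def _example_xtrans_tiling():
  import numpy as np  # local import keeps the sketch self-contained
  cell = np.zeros((3, 6, 6), dtype=np.float32)
  cell[1] = 1.0  # stand-in content; the real cell comes from xtrans_cell()
  h, w = 12, 18
  reps = [1, int(np.ceil(h / 6)), int(np.ceil(w / 6))]  # [1, 2, 3]
  mask = np.tile(cell, reps)
  assert mask.shape == (3, h, w)
  return mask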


================================================
FILE: demosaicnet/utils.py
================================================
"""Helper functions."""

from abc import ABCMeta, abstractmethod
import argparse
import logging
import os
import re
import signal
import time

import numpy as np
import torch as th
from tqdm import tqdm


log = logging.getLogger(__name__)


def crop_like(src, tgt):
    """Crop a source image to match the spatial dimensions of a target.

    Assumes 4D inputs [n, c, h, w] with the spatial dimensions last.

    Args:
        src (th.Tensor or np.ndarray): image to be cropped
        tgt (th.Tensor or np.ndarray): reference image
    """
    src_sz = np.array(src.shape)
    tgt_sz = np.array(tgt.shape)

    # Assumes the spatial dimensions are the last two
    delta = (src_sz[2:4]-tgt_sz[2:4])
    crop = np.maximum(delta // 2, 0)  # no negative crop
    crop2 = delta - crop

    if (crop > 0).any() or (crop2 > 0).any():
        # NOTE: convert to ints to enable static slicing in ONNX conversion
        src_sz = [int(x) for x in src_sz]
        crop = [int(x) for x in crop]
        crop2 = [int(x) for x in crop2]
        return src[..., crop[0]:src_sz[-2]-crop2[0],
                   crop[1]:src_sz[-1]-crop2[1]]
    else:
        return src


class ExponentialMovingAverage(object):
    """Keyed tracker that maintains an exponential moving average for each key.

    Args:
      keys(list of str): keys to track.
      alpha(float): exponential smoothing factor (higher = smoother).
    """

    def __init__(self, keys, alpha=0.999):
        self._is_first_update = {k: True for k in keys}
        self._alpha = alpha
        self._values = {k: 0 for k in keys}

    def __getitem__(self, key):
        return self._values[key]

    def update(self, key, value):
        if value is None:
            return
        if self._is_first_update[key]:
            self._values[key] = value
            self._is_first_update[key] = False
        else:
            self._values[key] = self._values[key] * \
                self._alpha + value*(1.0-self._alpha)
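

# A worked sketch of the recurrence implemented by
# ExponentialMovingAverage.update(): the first sample is copied verbatim,
# then v <- v*alpha + x*(1 - alpha). Illustration only; the `_example_` name
# is not part of this module's API.
def _example_ema_recurrence():
    alpha = 0.5
    v = None
    for x in [4.0, 2.0, 2.0]:
        v = x if v is None else v * alpha + x * (1.0 - alpha)
    # 4.0 -> 4.0*0.5 + 2.0*0.5 = 3.0 -> 3.0*0.5 + 2.0*0.5 = 2.5
    assert v == 2.5
    return v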


class BasicArgumentParser(argparse.ArgumentParser):
    """A basic argument parser with commonly used training options."""

    def __init__(self, *args, **kwargs):
        super(BasicArgumentParser, self).__init__(*args, **kwargs)

        self.add_argument("--data", required=True, help="path to the training data.")
        self.add_argument("--val_data", help="path to the validation data.")
        self.add_argument("--config", help="path to a config file.")
        self.add_argument("--checkpoint_dir", required=True,
                          help="Output directory where checkpoints are saved")
        self.add_argument("--init_from", help="path to a checkpoint from which to try and initialize the weights.")

        self.add_argument("--lr", type=float, default=1e-4,
                          help="Learning rate for the optimizer")
        self.add_argument("--bs", type=int, default=4, help="Batch size")
        self.add_argument("--num_epochs", type=int,
                          help="Number of epochs to train for")
        self.add_argument("--num_worker_threads", default=4, type=int,
                          help="Number of threads that load data")

        # self.add_argument("--experiment_log",
        #                   help="csv file in which we log our experiments")

        self.add_argument("--cuda", action="store_true",
                          dest="cuda", help="Force GPU")
        self.add_argument("--no-cuda", action="store_false",
                          dest="cuda", help="Force CPU")

        self.add_argument("--server", help="Visdom server url")
        self.add_argument("--base_url", default="/", help="Visdom base url")
        self.add_argument("--env", default="main", help="Visdom environment")
        self.add_argument("--port", default=8097, type=int,
                          help="Visdom server port")

        self.add_argument('--debug', dest="debug", action="store_true")

        self.set_defaults(cuda=th.cuda.is_available(), debug=False)


class ModelInterface(metaclass=ABCMeta):
    """An adapter to run or train a model."""

    def __init__(self):
        pass

    @abstractmethod
    def training_step(self, batch):
        """Training step given a batch of data.

        This should implement a forward pass of the model, compute gradients,
        take an optimizer step and return useful metrics and tensors for
        visualization and training callbacks. 

        Args:
          batch (dict): batch of data provided by a data pipeline.

        Returns:
          train_step_data (dict): a dictionary of outputs.
        """
        return {}

    def init_validation(self):
        """Initializes the quantities to be reported during validation.

        The default implementation is a no-op.

        Returns:
          data (dict): initialized values
        """
        log.warning("Running a ModelInterface validation initialization that was not overridden: this is a no-op.")
        data = {}
        return data

    def validation_step(self, batch, running_val_data):
        """Updates the running validation data with the current batch's results.

        The default implementation is a no-op.

        Args:
          batch (dict): batch of data provided by a data pipeline.
          running_val_data (dict): current aggregates of the validation loop.

        Returns:
          updated_data (dict): new updated value for the running_val_data.
        """
        log.warning("Running a ModelInterface validation step that was not overridden: this is a no-op.")
        return {}

    def __repr__(self):
        return self.__class__.__name__


class Checkpointer(object):
    """Save and restore model and optimizer variables.

    Args:
      root (string): path to the root directory where the files are stored.
      model (torch.nn.Module):
      meta (dict): a dictionary of training or configuration parameters useful
          to initialize the model upon loading the checkpoint again.
      optimizers (single or list of torch.optimizer): optimizers whose parameters will
        be checkpointed together with the model.
      schedulers (single or list of torch.optim.lr_scheduler): schedulers whose
        parameters will be checkpointed together with the model.
      prefix (str): unique prefix name in case several models are stored in the
        same folder.
    """

    EXTENSION = ".pth"

    def __init__(self, root, model=None, meta=None, optimizers=None,
                 schedulers=None, prefix=None):
        self.root = root
        self.model = model
        self.meta = meta

        # TODO(mgharbi): verify the prefixes are unique.

        if optimizers is None:
            log.info("No optimizer state will be stored in the "
                        "checkpointer")
        else:
            # if we have only one optimizer, make it a list
            if not isinstance(optimizers, list):
                optimizers = [optimizers]
        self.optimizers = optimizers
        if schedulers is not None:
            if not isinstance(schedulers, list):
                schedulers = [schedulers]
        self.schedulers = schedulers

        log.debug(self)

        self.prefix = ""
        if prefix is not None:
            self.prefix = prefix

    def __repr__(self):
        return "Checkpointer with root at \"{}\"".format(self.root)

    def __path(self, path, prefix=None):
        if prefix is None:
            prefix = ""
        return os.path.join(self.root, prefix+os.path.splitext(path)[0] + ".pth")

    def save(self, path, extras=None):
        """Save model, metaparams and extras to relative path.

        Args:
          path (string): relative path to the file being saved (without extension).
          extras (dict): extra user-provided information to be saved with the model.
        """

        if self.model is None:
            model_state = None
        else:
            log.debug("Saving model state dict")
            model_state = self.model.state_dict()

        opt_dicts = []
        if self.optimizers is not None:
            for opt in self.optimizers:
                opt_dicts.append(opt.state_dict())

        sched_dicts = []
        if self.schedulers is not None:
            for s in self.schedulers:
                sched_dicts.append(s.state_dict())

        filename = self.__path(path, prefix=self.prefix)
        os.makedirs(self.root, exist_ok=True)
        th.save({'model': model_state,
                 'meta': self.meta,
                 'extras': extras,
                 'optimizers': opt_dicts,
                 'schedulers': sched_dicts,
                 }, filename)
        log.debug("Checkpoint saved to \"{}\"".format(filename))

    def try_and_init_from(self, path):
        """Try to initialize the model's weights from an external checkpoint.

        Args:
            path(str): full path to the checkpoints to load model parameters
                from.
        """
        log.info("Loading weights from foreign checkpoint {}".format(path))
        if not os.path.exists(path):
            raise ValueError("Checkpoint {} does not exist".format(path))

        chkpt = th.load(path, map_location=th.device("cpu"))
        if "model" not in chkpt.keys() or chkpt["model"] is None:
            raise ValueError("{} has no model saved".format(path))

        mdl = chkpt["model"]
        for n, p in self.model.named_parameters():
            if n in mdl:
                p2 = mdl[n]
                if p2.shape != p.shape:
                    log.warning("Parameter {} ignored, checkpoint size does not match: {}, should be {}".format(n, p2.shape, p.shape))
                    continue
                log.debug("Parameter {} copied".format(n))
                p.data.copy_(p2)
            else:
                log.warning("Parameter {} ignored, not found in source checkpoint.".format(n))

        log.info("Weights loaded from foreign checkpoint {}".format(path))

    def load(self, path):
        """Loads a checkpoint, updates the model and returns extra data.

        Args:
          path (string): path to the checkpoint file, relative to the root dir.

        Returns:
          extras (dict): extra information passed by the user at save time.
          meta (dict): metaparameters of the model passed at save time.
        """

        filename = self.__path(path, prefix=None)
        chkpt = th.load(filename, map_location="cpu")  # TODO: check behavior

        if self.model is not None and chkpt["model"] is not None:
            log.debug("Loading model state dict")
            self.model.load_state_dict(chkpt["model"])

        if "optimizers" in chkpt.keys():
            if self.optimizers is not None and chkpt["optimizers"] is not None:
                try:
                    for opt, state in zip(self.optimizers,
                                          chkpt["optimizers"]):
                        log.debug("Loading optimizers state dict for %s", opt)
                        opt.load_state_dict(state)
                except Exception:
                    # We do not raise an error here, e.g. in case the user simply
                    # changes optimizer
                    log.warning("Could not load optimizer state dicts, "
                                "starting from scratch")

        if "schedulers" in chkpt.keys():
            if self.schedulers is not None and chkpt["schedulers"] is not None:
                try:
                    for s, state in zip(self.schedulers,
                                          chkpt["schedulers"]):
                        log.debug("Loading scheduler state dict for %s", s)
                        s.load_state_dict(state)
                except Exception:
                    log.warning("Could not load scheduler state dicts, "
                                "starting from scratch")

        log.debug("Loaded checkpoint \"{}\"".format(filename))
        return tuple(chkpt[k] for k in ["extras", "meta"])

    def load_latest(self):
        """Try to load the most recent checkpoint, skip failing files.

        Returns:
          extras (dict): extra user-defined information that was saved in the
              checkpoint.
          meta (dict): metaparameters of the model passed at save time.
        """
        all_checkpoints = self.sorted_checkpoints()

        extras = None
        meta = None
        for f in all_checkpoints:
            try:
                extras, meta = self.load(f)
                return extras, meta
            except Exception as e:
                log.debug(
                    "Could not load checkpoint \"{}\", moving on ({}).".format(f, e))
        log.debug("No checkpoint found to load.")
        return extras, meta

    def sorted_checkpoints(self):
        """Get list of all checkpoints in root directory, sorted by creation date.

        Returns:
            chkpts (list of str): sorted checkpoints in the root folder.
        """
        reg = re.compile(r"{}.*\{}".format(self.prefix, Checkpointer.EXTENSION))
        if not os.path.exists(self.root):
            all_checkpoints = []
        else:
            all_checkpoints = [f for f in os.listdir(
                self.root) if reg.match(f)]
        mtimes = []
        for f in all_checkpoints:
            mtimes.append(os.path.getmtime(os.path.join(self.root, f)))

        mf = sorted(zip(mtimes, all_checkpoints))
        chkpts = [m[1] for m in reversed(mf)]
        log.debug("Sorted checkpoints {}".format(chkpts))
        return chkpts

    def delete(self, path):
        """Delete checkpoint at path.

        Args:
            path(str): full path to the checkpoint to delete.
        """
        if path in self.sorted_checkpoints():
            os.remove(os.path.join(self.root, path))
        else:
            log.warning("Trying to delete a checkpoint that does not exist.")

    @staticmethod
    def load_meta(root, prefix=None):
        """Fetch model metadata without touching the saved parameters.

        This loads the metadata from the most recent checkpoint in the root
        directory.

        Args:
            root(str): path to the root directory containing the checkpoints
            prefix(str): unique prefix for the checkpoint to be loaded (e.g. if
                multiple models are saved in the same location)
        """
        chkptr = Checkpointer(root, model=None, meta=None, prefix=prefix, 
                              optimizers=[])
        log.debug("checkpoints: %s", chkptr.sorted_checkpoints())
        _, meta = chkptr.load_latest()
        return meta
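

# A self-contained sketch of the newest-first ordering produced by
# Checkpointer.sorted_checkpoints(): (mtime, name) pairs are sorted
# ascending, then reversed. Illustration only; the `_example_` name is not
# part of this module's API.
def _example_sort_by_mtime():
    mtimes = [100.0, 300.0, 200.0]
    names = ["a.pth", "b.pth", "c.pth"]
    mf = sorted(zip(mtimes, names))
    chkpts = [m[1] for m in reversed(mf)]
    assert chkpts == ["b.pth", "c.pth", "a.pth"]  # newest first
    return chkpts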


class Trainer(object):
    """Implements a simple training loop with hooks for callbacks.

    Args:
      interface (ModelInterface): adapter to run forward and backward
        pass on the model being trained.

    Attributes:
      callbacks (list of Callbacks): hooks that will be called while training
        progresses.
    """

    def __init__(self, interface):
        super(Trainer, self).__init__()
        self.callbacks = []
        self.interface = interface
        log.debug("Creating {}".format(self))

        signal.signal(signal.SIGINT, self.interrupt_handler)

        self._keep_running = True

    def interrupt_handler(self, signo, frame):
        """Stop the training process upon receiving a SIGINT (Ctrl+C)."""
        log.debug("interrupting run")
        self._keep_running = False

    def _stop(self):
        # Reset the run flag
        self._keep_running = True
        self.__training_end()

    def add_callback(self, callback):
        """Adds a callback to the list of training hooks.

        Args:
            callback(ttools.Callback): callback to add.
        """
        log.debug("Adding callback {}".format(callback))
        # pass an interface reference to the callback
        callback.model_interface = self.interface
        self.callbacks.append(callback)

    def train(self, dataloader, starting_epoch=None, num_epochs=None,
              val_dataloader=None):
        """Main training loop. This starts the training procedure.

        Args:
          dataloader (DataLoader): loader that yields training batches.
          starting_epoch (int, optional): index of the epoch we are starting from.
          num_epochs (int, optional): max number of epochs to run.
          val_dataloader (DataLoader, optional): loader that yields validation
            batches
        """
        self.__training_start(dataloader)
        if starting_epoch is None:
            starting_epoch = 0

        log.info("Starting training from epoch %d", starting_epoch)

        epoch = starting_epoch

        while num_epochs is None or epoch < starting_epoch + num_epochs:
            self.__epoch_start(epoch)
            for batch_idx, batch in enumerate(dataloader):
                if not self._keep_running:
                    self._stop()
                    return
                self.__batch_start(batch_idx, batch)
                train_step_data = self.__training_step(batch)
                self.__batch_end(batch, train_step_data)
            self.__epoch_end()

            # TODO: allow validation at intermediate steps during one epoch

            # Validate
            if val_dataloader:
                with th.no_grad():
                    running_val_data = self.__validation_start(val_dataloader)
                    for batch_idx, batch in enumerate(val_dataloader):
                        if not self._keep_running:
                            self._stop()
                            return
                        self.__val_batch_start(batch_idx, batch)
                        running_val_data = self.__validation_step(batch, running_val_data)
                        self.__val_batch_end(batch, running_val_data)
                    self.__validation_end(running_val_data)

            epoch += 1

            if not self._keep_running:
                self._stop()
                return

        self._stop()

    def __repr__(self):
        return "Trainer({}, {} callbacks)".format(
            self.interface, len(self.callbacks))

    def __training_start(self, dataloader):
        for cb in self.callbacks:
            cb.training_start(dataloader)

    def __training_end(self):
        for cb in self.callbacks:
            cb.training_end()

    def __epoch_start(self, epoch_idx):
        for cb in self.callbacks:
            cb.epoch_start(epoch_idx)

    def __epoch_end(self):
        for cb in self.callbacks:
            cb.epoch_end()

    def __batch_start(self, batch_idx, batch):
        for cb in self.callbacks:
            cb.batch_start(batch_idx, batch)

    def __batch_end(self, batch, train_step_data):
        for cb in self.callbacks:
            cb.batch_end(batch, train_step_data)

    def __val_batch_start(self, batch_idx, batch):
        for cb in self.callbacks:
            cb.val_batch_start(batch_idx, batch)

    def __val_batch_end(self, batch, running_val_data):
        for cb in self.callbacks:
            cb.val_batch_end(batch, running_val_data)

    def __validation_start(self, dataloader):
        for cb in self.callbacks:
            cb.validation_start(dataloader)
        return self.interface.init_validation()

    def __validation_end(self, running_val_data):
        for cb in self.callbacks:
            cb.validation_end(running_val_data)

    def __training_step(self, batch):
        return self.interface.training_step(batch)

    def __validation_step(self, batch, running_val_data):
        return self.interface.validation_step(batch, running_val_data)


class Callback(object):
    """Base class for all training callbacks.

    Attributes:
        epoch(int): current epoch index.
        batch(int): current batch index.
        datasize(int): number of batches in the training dataset.
        val_datasize(int): number of batches in the validation dataset.
        model_interface(ttools.ModelInterface): parent interface driving the training.
    """

    def __repr__(self):
        return self.__class__.__name__

    def __init__(self):
        super(Callback, self).__init__()
        self.epoch = 0
        self.batch = 0
        self.val_batch = 0
        self.datasize = 0
        self.val_datasize = 0
        self.model_interface = None

    def training_start(self, dataloader):
        """Hook to execute code when the training begins.

        Args:
            dataloader(th.utils.data.Dataloader): a data loading class that
            provides batches of data for training.
        """
        self.datasize = len(dataloader)

    def training_end(self):
        """Hook to execute code when the training ends."""
        pass

    def epoch_start(self, epoch):
        """Hook to execute code when a new epoch starts.

        Args:
            epoch(int): index of the current epoch.

        Note: self.epoch is never incremented. Instead, it should be set by the
        caller.
        """
        self.epoch = epoch

    def epoch_end(self):
        """Hook to execute code when an epoch ends.

        NOTE: self.epoch is not incremented. Instead it is set externally in
        the `epoch_start` method.
        """
        pass

    def validation_start(self, dataloader):
        """Hook to execute code when a validation run starts.

        Args:
            dataloader(th.utils.data.Dataloader): a data loading class that
            provides batches of data for evaluation.
        """
        self.val_datasize = len(dataloader)

    def validation_end(self, val_data):
        """Hook to execute code when a validation run ends."""
        pass

    def batch_start(self, batch_idx, batch_data):
        """Hook to execute code when a training step starts.

        Args:
            batch_idx(int): index of the current batch.
            batch_data: a Tensor, tuple, or dict with the current batch of data.
        """
        self.batch = batch_idx

    def batch_end(self, batch_data, train_step_data):
        """Hook to execute code when a training step ends.

        Args:
            batch_data: a Tensor, tuple, or dict with the current batch of data.
            train_step_data(dict): outputs from the `training_step` of a
                ModelInterface.
        """
        pass

    def val_batch_start(self, batch_idx, batch_data):
        """Hook to execute code when a validation step starts.

        Args:
            batch_idx(int): index of the current batch.
            batch_data: a Tensor, tuple, or dict with the current batch of data.
        """
        self.val_batch = batch_idx

    def val_batch_end(self, batch_data, running_val_data):
        """Hook to execute code when a validation step ends.

        Args:
            batch_data: a Tensor, tuple, or dict with the current batch of data.
            running_val_data(dict): running outputs produced by the
                `validation_step` of a ModelInterface.
        """
        pass

class CheckpointingCallback(Callback):
    """A callback that periodically saves model checkpoints to disk.

    Args:
      checkpointer (Checkpointer): actual checkpointer responsible for the I/O.
      interval (int, optional): minimum time in seconds between periodic
          checkpoints (within an epoch). There is no periodic checkpoint if
          this value is None.
      max_files (int, optional): maximum number of periodic checkpoints to keep
          on disk.
      max_epochs (int, optional): maximum number of epoch checkpoints to keep
          on disk.
    """

    PERIODIC_PREFIX = "periodic_"
    EPOCH_PREFIX = "epoch_"

    def __init__(self, checkpointer, interval=600,
                 max_files=5, max_epochs=10):
        super(CheckpointingCallback, self).__init__()
        self.checkpointer = checkpointer
        self.interval = interval
        self.max_files = max_files
        self.max_epochs = max_epochs

        self.last_checkpoint_time = time.time()

    def epoch_end(self):
        """Save a checkpoint at the end of each epoch."""
        super(CheckpointingCallback, self).epoch_end()
        path = "{}{}".format(CheckpointingCallback.EPOCH_PREFIX, self.epoch)
        self.checkpointer.save(path, extras={"epoch": self.epoch + 1})
        self.__purge_old_files()

    def training_end(self):
        super(CheckpointingCallback, self).training_end()
        self.checkpointer.save("training_end", extras={"epoch": self.epoch + 1})

    def batch_end(self, batch_data, train_step_data):
        """Save a periodic checkpoint if requested."""

        super(CheckpointingCallback, self).batch_end(
            batch_data, train_step_data)

        if self.interval is None:  # We skip periodic checkpoints
            return

        now = time.time()

        delta = now - self.last_checkpoint_time

        if delta < self.interval:  # last checkpoint is too recent
            return

        log.debug("Periodic checkpoint")
        self.last_checkpoint_time = now

        filename = "{}{}".format(CheckpointingCallback.PERIODIC_PREFIX,
                                   time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime()))
        self.checkpointer.save(filename, extras={"epoch": self.epoch})
        self.__purge_old_files()

    def __purge_old_files(self):
        """Delete checkpoints that are beyond the max to keep."""

        chkpts = self.checkpointer.sorted_checkpoints()
        p_chkpts = []
        e_chkpts = []
        for c in chkpts:
            if c.startswith(self.checkpointer.prefix + CheckpointingCallback.PERIODIC_PREFIX):
                p_chkpts.append(c)

            if c.startswith(self.checkpointer.prefix + CheckpointingCallback.EPOCH_PREFIX):
                e_chkpts.append(c)

        # Delete periodic checkpoints
        if self.max_files is not None and len(p_chkpts) > self.max_files:
            for c in p_chkpts[self.max_files:]:
                log.debug("CheckpointingCallback deleting {}".format(c))
                self.checkpointer.delete(c)

        # Delete older epochs
        if self.max_epochs is not None and len(e_chkpts) > self.max_epochs:
            for c in e_chkpts[self.max_epochs:]:
                log.debug("CheckpointingCallback deleting (epoch) {}".format(c))
                self.checkpointer.delete(c)


class KeyedCallback(Callback):
    """An abstract Callback that performs the same action for all keys in a list.

    The keys (resp. val_keys) are used to access the backward_data (resp.
    validation_data) produced by a ModelInterface.

    Args:
      keys (list of str or None): list of keys whose values will be logged during
          training.
      val_keys (list of str or None): list of keys whose values will be logged during
          validation
    """
    def __init__(self, keys=None, val_keys=None, smoothing=0.999):
        super(KeyedCallback, self).__init__()
        if keys is None and val_keys is None:
            log.warning("Logger has no keys or val_keys")

        if keys is None:
            self.keys = []
        else:
            self.keys = keys

        if val_keys is None:
            self.val_keys = []
        else:
            self.val_keys = val_keys

        # Only smooth the training keys
        self.ema = ExponentialMovingAverage(self.keys, alpha=smoothing)

    def batch_end(self, batch_data, train_step_data):
        for k in self.keys:
            self.ema.update(k, train_step_data[k])


class ProgressBarCallback(KeyedCallback):
    """A progress bar optimization logger.

    Args:
        keys (list of str): training keys displayed on the progress bar.
        val_keys (list of str): validation keys reported at the end of each
            validation run.
        smoothing (float): exponential moving-average factor used to smooth
            the training keys.
        label (str): a prefix label to identify the experiment currently
            running.
    """
    def __init__(self, keys=None, val_keys=None, smoothing=0.99, label=None):
        super(ProgressBarCallback, self).__init__(
            keys=keys, val_keys=val_keys, smoothing=smoothing)
        self.pbar = None
        if label is None:
            self.label = ""
        else:
            self.label = label

    def training_start(self, dataloader):
        super(ProgressBarCallback, self).training_start(dataloader)
        print("Training start")

    def training_end(self):
        super(ProgressBarCallback, self).training_end()
        print("Training ends")

    def epoch_start(self, epoch):
        super(ProgressBarCallback, self).epoch_start(epoch)
        desc = "Epoch {}".format(self.epoch)
        if self.label:
            desc = "%s | " % self.label + desc
        self.pbar = tqdm(total=self.datasize, unit=" batches",
                         desc=desc)

    def epoch_end(self):
        super(ProgressBarCallback, self).epoch_end()
        self.pbar.close()
        self.pbar = None

    def validation_start(self, dataloader):
        super(ProgressBarCallback, self).validation_start(dataloader)
        print("Running validation...")
        self.pbar = tqdm(total=len(dataloader), unit=" batches",
                         desc="Validation {}".format(self.epoch))

    def val_batch_end(self, batch, running_val_data):
        self.pbar.update(1)

    def validation_end(self, val_data):
        super(ProgressBarCallback, self).validation_end(val_data)
        self.pbar.close()
        self.pbar = None
        s = " "*ProgressBarCallback.TABSTOPS + "Validation {} | ".format(
            self.epoch)
        for k in self.val_keys:
            s += "{} = {:.2f} ".format(k, val_data[k])
        print(s)

    def batch_end(self, batch_data, train_step_data):
        super(ProgressBarCallback, self).batch_end(batch_data, train_step_data)
        d = {}
        for k in self.keys:
            d[k] = self.ema[k]
        self.pbar.update(1)
        self.pbar.set_postfix(d)

    TABSTOPS = 2
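`KeyedCallback` above smooths training metrics with an exponential moving average before `ProgressBarCallback` displays them. A minimal, self-contained sketch of that smoothing (a hypothetical stand-in mirroring the keys/alpha interface of the `ExponentialMovingAverage` helper, not the library's exact code):

```python
class EMA:
    """Minimal exponential-moving-average tracker, one value per key."""

    def __init__(self, keys, alpha=0.999):
        self.alpha = alpha
        self.values = {k: None for k in keys}

    def update(self, key, value):
        prev = self.values[key]
        if prev is None:
            # First sample initializes the average directly.
            self.values[key] = value
        else:
            self.values[key] = self.alpha * prev + (1 - self.alpha) * value

    def __getitem__(self, key):
        return self.values[key]


ema = EMA(["loss"], alpha=0.5)
ema.update("loss", 4.0)
ema.update("loss", 0.0)
print(ema["loss"])  # 0.5*4.0 + 0.5*0.0 = 2.0
```

With alpha close to 1 (the 0.999 default), the displayed value reacts slowly to per-batch noise, which is why the progress bar shows a stable loss rather than the raw batch values.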

================================================
FILE: demosaicnet/version.py
================================================
__version__ = "0.0.14"


================================================
FILE: docs/.gitignore
================================================
build


================================================
FILE: docs/Makefile
================================================
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line.
SPHINXOPTS    =
SPHINXBUILD   = sphinx-build
SOURCEDIR     = source
BUILDDIR      = build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

================================================
FILE: docs/source/conf.py
================================================
# -*- coding: utf-8 -*-
#
# Configuration file for the Sphinx documentation builder.
#
# This file does only contain a selection of the most common options. For a
# full list see the documentation:
# http://www.sphinx-doc.org/en/master/config

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
dirname = os.path.dirname
rootdir = dirname(dirname(dirname(os.path.abspath(__file__))))
sys.path.insert(0, rootdir)


# -- Project information -----------------------------------------------------

project = 'demosaicnet'
copyright = '2019, Michael Gharbi'
author = 'Michael Gharbi'

import re
with open(os.path.join(rootdir, "demosaicnet", "version.py")) as fid:
    try:
        __version__, = re.findall(r'__version__ = "(.*)"', fid.read())
    except ValueError:
        raise ValueError("could not find version number")

# The full version, including alpha/beta/rc tags
release = __version__


# -- General configuration ---------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.
autodoc_mock_imports = ["torch", "numpy", "imageio", "torchvision", "wget"]
#
# needs_sphinx = '1.0'

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.napoleon',
    'sphinx.ext.doctest',
    'sphinx.ext.todo',
    'sphinx.ext.coverage',
    'sphinx.ext.mathjax',
    'sphinx.ext.viewcode',
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'

# The master toctree document.
master_doc = 'index'

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = None


# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages.  See the documentation for
# a list of builtin themes.
#
html_theme = 'alabaster'

# Theme options are theme-specific and customize the look and feel of a theme
# further.  For a list of options available for each theme, see the
# documentation.
#
# html_theme_options = {}

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']

# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# The default sidebars (for documents that don't match any pattern) are
# defined by theme itself.  Builtin themes are using these templates by
# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
# 'searchbox.html']``.
#
# html_sidebars = {}


# -- Options for HTMLHelp output ---------------------------------------------

# Output file base name for HTML help builder.
htmlhelp_basename = 'demosaicnetdoc'


# -- Options for LaTeX output ------------------------------------------------

latex_elements = {
    # The paper size ('letterpaper' or 'a4paper').
    #
    # 'papersize': 'letterpaper',

    # The font size ('10pt', '11pt' or '12pt').
    #
    # 'pointsize': '10pt',

    # Additional stuff for the LaTeX preamble.
    #
    # 'preamble': '',

    # Latex figure (float) alignment
    #
    # 'figure_align': 'htbp',
}

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
#  author, documentclass [howto, manual, or own class]).
latex_documents = [
    (master_doc, 'demosaicnet.tex', 'demosaicnet Documentation',
     'Michael Gharbi', 'manual'),
]


# -- Options for manual page output ------------------------------------------

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
    (master_doc, 'demosaicnet', 'demosaicnet Documentation',
     [author], 1)
]


# -- Options for Texinfo output ----------------------------------------------

# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
#  dir menu entry, description, category)
texinfo_documents = [
    (master_doc, 'demosaicnet', 'demosaicnet Documentation',
     author, 'demosaicnet', 'Minimal implementation of Deep Joint Demosaicking and Denoising.',
     'Miscellaneous'),
]


# -- Options for Epub output -------------------------------------------------

# Bibliographic Dublin Core info.
epub_title = project

# The unique identifier of the text. This can be a ISBN number
# or the project homepage.
#
# epub_identifier = ''

# A unique identification for the text.
#
# epub_uid = ''

# A list of files that should not be packed into the epub file.
epub_exclude_files = ['search.html']


# -- Extension configuration -------------------------------------------------

# -- Options for todo extension ----------------------------------------------

# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True


================================================
FILE: docs/source/dataset.rst
================================================
Dataset
=======

.. automodule:: demosaicnet.dataset
   :members:


================================================
FILE: docs/source/helpers.rst
================================================
Helpers
=======

.. automodule:: demosaicnet.mosaic
   :members:


================================================
FILE: docs/source/index.rst
================================================
.. demosaicnet documentation master file, created by
   sphinx-quickstart on Thu Mar 14 13:14:16 2019.
   You can adapt this file completely to your liking, but it should at least
   contain the root `toctree` directive.

Welcome to demosaicnet's documentation!
=======================================

.. toctree::
   :maxdepth: 2
   :caption: Contents:

   models
   dataset
   helpers

.. automodule:: demosaicnet
   :members:


Indices and tables
==================

* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`


================================================
FILE: docs/source/models.rst
================================================
Models
======

.. automodule:: demosaicnet.modules
   :members:


================================================
FILE: requirements.txt
================================================
imageio==2.31.1
numpy==1.25.0
setuptools==49.2.1
torch==2.0.1
tqdm==4.65.0
wget


================================================
FILE: scripts/demosaicnet_demo.py
================================================
#!/usr/bin/env python
"""Demo script on using demosaicnet for inference."""

import os
from pkg_resources import resource_filename

import argparse
import numpy as np
import torch as th
import imageio

import demosaicnet

_TEST_INPUT = resource_filename("demosaicnet", 'data/test_input.png')

def main(args):
  print("Running demosaicnet demo on {}, outputting to {}".format(_TEST_INPUT, args.output))
  bayer = demosaicnet.BayerDemosaick()
  xtrans = demosaicnet.XTransDemosaick()

  # Load some ground-truth image
  gt = imageio.imread(args.input).astype(np.float32) / 255.0

  h, w, _ = gt.shape

  # Make the image size a multiple of 6 (for xtrans pattern)
  gt = gt[:6*(h//6), :6*(w//6)]


  # Network expects channel first
  gt = np.transpose(gt, [2, 0, 1])
  mosaicked = demosaicnet.bayer(gt)
  xmosaicked = demosaicnet.xtrans(gt)

  # Run the model (expects batch as first dimension)
  mosaicked = th.from_numpy(mosaicked).unsqueeze(0)
  xmosaicked = th.from_numpy(xmosaicked).unsqueeze(0)
  with th.no_grad():  # inference only
    out = bayer(mosaicked).squeeze(0).cpu().numpy()
    out = np.clip(out, 0, 1)
    xout = xtrans(xmosaicked).squeeze(0).cpu().numpy()
    xout = np.clip(xout, 0, 1)
  print("done")

  os.makedirs(args.output, exist_ok=True)
  output = args.output

  imageio.imsave(os.path.join(output, "bayer_mosaick.tif"), mosaicked.squeeze(0).permute([1, 2, 0]))
  imageio.imsave(os.path.join(output, "bayer_result.tif"), np.transpose(out, [1, 2, 0]))
  imageio.imsave(os.path.join(output, "xtrans_mosaick.tif"), xmosaicked.squeeze(0).permute([1, 2, 0]))
  imageio.imsave(os.path.join(output, "xtrans_result.tif"), np.transpose(xout, [1, 2, 0]))

  
if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument("output", help="output directory")
  parser.add_argument("--input", default=_TEST_INPUT, help="test input; defaults to the bundled test image if not provided.")
  args = parser.parse_args()
  main(args)
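`demosaicnet.bayer` keeps the image three-channel and zeroes out the samples the sensor does not record at each pixel. A rough sketch of that masking for a GRBG layout, using plain Python and channel-first nested lists (the actual pattern and implementation live in `demosaicnet/mosaic.py`, so treat the layout here as an illustrative assumption):

```python
def bayer_mosaic(im):
    """Zero the channels a GRBG Bayer sensor would not sample.

    im is channel-first: im[c][y][x] with c = 0 (R), 1 (G), 2 (B).
    The GRBG layout is an assumption for illustration.
    """
    h, w = len(im[0]), len(im[0][0])
    out = [[[0.0] * w for _ in range(h)] for _ in range(3)]
    for y in range(h):
        for x in range(w):
            if y % 2 == 0 and x % 2 == 1:    # red: even rows, odd columns
                out[0][y][x] = im[0][y][x]
            elif y % 2 == 1 and x % 2 == 0:  # blue: odd rows, even columns
                out[2][y][x] = im[2][y][x]
            elif (y + x) % 2 == 0:           # green: remaining checkerboard
                out[1][y][x] = im[1][y][x]
    return out


# A 2x2 all-ones image keeps exactly one nonzero sample per pixel site.
ones = [[[1.0, 1.0], [1.0, 1.0]] for _ in range(3)]
mosaic = bayer_mosaic(ones)
```

This also explains why the demo crops the input to a multiple of the pattern period: the X-Trans pattern repeats every 6 pixels, so the crop to `6*(h//6)` keeps the mosaic aligned.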
  


================================================
FILE: scripts/eval.py
================================================
#!/bin/env python
"""Evaluate a demosaicking model."""
import argparse
import logging

import torch as th
from torch.utils.data import DataLoader

import demosaicnet


log = logging.getLogger(__name__)

class PSNR(th.nn.Module):
    def __init__(self):
        super(PSNR, self).__init__()
        self.mse = th.nn.MSELoss()
    def forward(self, out, ref):
        mse = self.mse(out, ref)
        return -10*th.log10(mse + 1e-12)  # eps guards against log10(0)

def main(args):
    """Entrypoint to the evaluation."""

    # Load model parameters from checkpoint, if any
    meta = demosaicnet.utils.Checkpointer.load_meta(args.checkpoint_dir)
    if meta is None:
        log.warning("No checkpoint found at %s, aborting.", args.checkpoint_dir)
        return

    data = demosaicnet.Dataset(args.data, download=False,
                               mode=meta["mode"],
                               subset=demosaicnet.TEST_SUBSET)
    dataloader = DataLoader(
        data, batch_size=1, num_workers=4, pin_memory=True, shuffle=False)

    if meta["mode"] == demosaicnet.BAYER_MODE:
        model = demosaicnet.BayerDemosaick(depth=meta["depth"],
                                           width=meta["width"],
                                           pretrained=True,
                                           pad=False)
    elif meta["mode"] == demosaicnet.XTRANS_MODE:
        model = demosaicnet.XTransDemosaick(depth=meta["depth"],
                                            width=meta["width"],
                                            pretrained=True,
                                            pad=False)
    else:
        raise ValueError("unknown mosaick mode '{}'".format(meta["mode"]))

    checkpointer = demosaicnet.utils.Checkpointer(args.checkpoint_dir, model, meta=meta)
    checkpointer.load_latest()  # Resume from checkpoint, if any.

    # No need for gradients
    for p in model.parameters():
        p.requires_grad = False

    mse_fn = th.nn.MSELoss()
    psnr_fn = PSNR()

    device = "cpu"
    if th.cuda.is_available():
        device = "cuda"
        log.info("Using CUDA")

    count = 0
    mse = 0.0
    psnr = 0.0
    for idx, batch in enumerate(dataloader):
        mosaic = batch[0].to(device)
        target = batch[1].to(device)
        output = model(mosaic)

        target = demosaicnet.utils.crop_like(target, output)

        output = th.clamp(output, 0, 1)

        psnr_ = psnr_fn(output, target).item()
        mse_ = mse_fn(output, target).item()

        psnr += psnr_
        mse += mse_
        count += 1

        log.info("Image %04d, PSNR = %.1f dB, MSE = %.5f", idx, psnr_, mse_)

    mse /= count
    psnr /= count

    log.info("-----------------------------------")
    log.info("Average, PSNR = %.1f dB, MSE = %.5f", psnr, mse)



if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("data", help="root directory for the demosaicnet dataset.")
    parser.add_argument("checkpoint_dir", help="directory with the model checkpoints.")
    args = parser.parse_args()
    main(args)
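The `PSNR` module above reduces to simple arithmetic on the mean-squared error: for images normalized to [0, 1], PSNR = -10 · log10(MSE). A quick standalone check of that relation in plain Python (no torch; the small epsilon mirrors the guard used in the training script):

```python
import math


def psnr_db(mse, eps=1e-12):
    """Peak signal-to-noise ratio in dB for signals normalized to [0, 1]."""
    # eps guards against log10(0) on a perfect reconstruction.
    return -10.0 * math.log10(mse + eps)


print(round(psnr_db(0.01), 1))   # an MSE of 0.01 corresponds to 20.0 dB
print(round(psnr_db(0.001), 1))  # a 10x smaller MSE gains 10 dB: 30.0
```

Each factor-of-10 reduction in MSE adds exactly 10 dB, which is why demosaicking papers report improvements of a fraction of a dB as meaningful.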


================================================
FILE: scripts/train.py
================================================
#!/bin/env python
"""Train a demosaicking model."""
import logging

import torch as th
from torch.utils.data import DataLoader

import demosaicnet


log = logging.getLogger(__name__)


class PSNR(th.nn.Module):
    def __init__(self):
        super(PSNR, self).__init__()
        self.mse = th.nn.MSELoss()
    def forward(self, out, ref):
        mse = self.mse(out, ref)
        return -10*th.log10(mse+1e-12)


class DemosaicnetInterface(demosaicnet.utils.ModelInterface):
    """Training and validation interface.

    Args:
        model(th.nn.Module): model to train.
        lr(float): learning rate for the optimizer.
        cuda(bool): whether to use CPU or GPU for training.
    """
    def __init__(self, model, lr=1e-4, cuda=th.cuda.is_available()):
        self.model = model
        self.device = "cpu"
        if cuda:
            self.device = "cuda"
        self.model.to(self.device)
        self.opt = th.optim.Adam(self.model.parameters(), lr=lr)
        self.loss = th.nn.MSELoss()
        self.psnr = PSNR()

    def training_step(self, batch):
        fwd_data = self.forward(batch)
        bwd_data = self.backward(batch, fwd_data)
        return bwd_data

    def forward(self, batch):
        mosaic = batch[0]
        mosaic = mosaic.to(self.device)
        output = self.model(mosaic)
        return output

    def backward(self, batch, fwd_output):
        target = batch[1].to(self.device)

        # remove boundaries to match output size
        target = demosaicnet.utils.crop_like(target, fwd_output)

        loss = self.loss(fwd_output, target)

        self.opt.zero_grad()
        loss.backward()
        self.opt.step()

        with th.no_grad():
            psnr = self.psnr(th.clamp(fwd_output, 0, 1), target)

        return {"loss": loss.item(), "psnr": psnr.item()}

    def init_validation(self):
        return {"count": 0, "psnr": 0}

    def update_validation(self, batch, fwd_output, running_data):
        target = batch[1].to(self.device)

        # remove boundaries to match output size
        target = demosaicnet.utils.crop_like(target, fwd_output)

        with th.no_grad():
            psnr = self.psnr(th.clamp(fwd_output, 0, 1), target)
            n = target.shape[0]

        return {
            "psnr": running_data["psnr"] + psnr.item()*n,
            "count": running_data["count"] + n
        }

    def finalize_validation(self, running_data):
        return {
            "psnr": running_data["psnr"] / running_data["count"]
        }


def main(args):
    """Entrypoint to the training."""

    # Load model parameters from checkpoint, if any
    meta = demosaicnet.utils.Checkpointer.load_meta(args.checkpoint_dir)
    if meta is None:
        log.info("No metadata or checkpoint, "
                 "parsing model parameters from command line.")
        meta = {
            "depth": args.depth,
            "width": args.width,
            "mode": args.mode,
        }

    data = demosaicnet.Dataset(args.data, download=False,
                               mode=meta["mode"],
                               subset=demosaicnet.TRAIN_SUBSET)
    dataloader = DataLoader(
        data, batch_size=args.bs, num_workers=args.num_worker_threads,
        pin_memory=True, shuffle=True)

    val_dataloader = None
    if args.val_data:
        val_data = demosaicnet.Dataset(args.val_data, download=False,
                                       mode=meta["mode"],
                                       subset=demosaicnet.VAL_SUBSET)
        val_dataloader = DataLoader(
            val_data, batch_size=args.bs, num_workers=1,
            pin_memory=True, shuffle=False)

    if meta["mode"] == demosaicnet.BAYER_MODE:
        model = demosaicnet.BayerDemosaick(depth=meta["depth"],
                                           width=meta["width"],
                                           pretrained=True,
                                           pad=False)
    elif meta["mode"] == demosaicnet.XTRANS_MODE:
        model = demosaicnet.XTransDemosaick(depth=meta["depth"],
                                            width=meta["width"],
                                            pretrained=True,
                                            pad=False)
    else:
        raise ValueError("unknown mosaick mode '{}'".format(meta["mode"]))
    checkpointer = demosaicnet.utils.Checkpointer(
        args.checkpoint_dir, model, meta=meta)

    interface = DemosaicnetInterface(model, lr=args.lr, cuda=args.cuda)

    checkpointer.load_latest()  # Resume from checkpoint, if any.

    trainer = demosaicnet.utils.Trainer(interface)

    keys = ["loss", "psnr"]
    val_keys = ["psnr"]

    trainer.add_callback(demosaicnet.utils.ProgressBarCallback(
        keys=keys, val_keys=val_keys))
    trainer.add_callback(demosaicnet.utils.CheckpointingCallback(
        checkpointer, max_files=8, interval=3600, max_epochs=10))

    if args.cuda:
        log.info("Training with CUDA enabled")
    else:
        log.info("Training on CPU")

    trainer.train(
        dataloader, num_epochs=args.num_epochs,
        val_dataloader=val_dataloader)


if __name__ == "__main__":
    parser = demosaicnet.utils.BasicArgumentParser()
    parser.add_argument("--depth", default=15, type=int,
                        help="number of net layers.")
    parser.add_argument("--width", default=64, type=int,
                        help="number of features per layer.")
    parser.add_argument("--mode", default=demosaicnet.BAYER_MODE,
                        choices=[demosaicnet.BAYER_MODE,
                                 demosaicnet.XTRANS_MODE],
                        help="mosaick pattern to train on.")
    args = parser.parse_args()
    main(args)
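The `init_validation` / `update_validation` / `finalize_validation` trio above computes a running average weighted by batch size, so a smaller final batch does not skew the reported PSNR. The accumulation pattern, stripped of torch (function names mirror the interface but this is an illustrative sketch, not the library code):

```python
def init_validation():
    return {"psnr": 0.0, "count": 0}


def update_validation(batch_psnr, batch_size, running):
    # Accumulate a sum weighted by batch size so uneven batches
    # contribute proportionally to the final average.
    return {"psnr": running["psnr"] + batch_psnr * batch_size,
            "count": running["count"] + batch_size}


def finalize_validation(running):
    return {"psnr": running["psnr"] / running["count"]}


running = init_validation()
for p, n in [(30.0, 4), (36.0, 2)]:  # two batches of different sizes
    running = update_validation(p, n, running)
print(finalize_validation(running)["psnr"])  # (30*4 + 36*2) / 6 = 32.0
```

A naive average of the per-batch PSNRs (33.0 here) would overweight the small batch; the weighted form matches averaging over individual samples.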


================================================
FILE: setup.py
================================================
import re
import setuptools


with open('demosaicnet/version.py') as fid:
    try:
        __version__, = re.findall(r'__version__ = "(.*)"', fid.read())
    except ValueError:
        raise ValueError("could not find version number")


with open("README.md", "r") as fh:
    long_description = fh.read()


setuptools.setup(
    name='demosaicnet',
    version=__version__,
    scripts=["scripts/demosaicnet_demo.py"],
    author="Michaël Gharbi",
    author_email="gharbi@csail.mit.edu",
    description="Minimal implementation of Deep Joint Demosaicking and Denoising [Gharbi2016]",
    long_description=long_description,
    url="https://github.com/mgharbi/",
    packages = setuptools.find_packages(exclude=["tests"]),
    include_package_data=True,
    install_requires=["wget", "tqdm", "torch", "imageio", "numpy"],
    classifiers=[
      "Programming Language :: Python :: 3",
      "License :: OSI Approved :: MIT License",
      "Operating System :: MacOS :: MacOS X",
      "Operating System :: POSIX",
    ],
)
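Both `setup.py` and `docs/source/conf.py` recover the version string by scanning `version.py` with a regular expression rather than importing the package, which would pull in heavy dependencies like torch at install or doc-build time. The extraction itself, applied to the literal contents of `demosaicnet/version.py` shown above:

```python
import re

# The exact line stored in demosaicnet/version.py in this repository.
text = '__version__ = "0.0.14"'

# Unpacking the single-element list raises ValueError if no match is found,
# which is the failure mode setup.py converts into an explicit error.
version, = re.findall(r'__version__ = "(.*)"', text)
print(version)  # 0.0.14
```

Keeping the version in one file and scraping it textually avoids the classic chicken-and-egg problem of importing a package before its dependencies are installed.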
SYMBOL INDEX (108 symbols across 7 files)

FILE: demosaicnet/dataset.py
  class Dataset (line 38) | class Dataset(TorchDataset):
    method __init__ (line 47) | def __init__(self, root, download=False,
    method __len__ (line 81) | def __len__(self):
    method __getitem__ (line 84) | def __getitem__(self, idx):
  function _download (line 441) | def _download(dst):
  function md5sum (line 517) | def md5sum(filename, blocksize=65536):

FILE: demosaicnet/modules.py
  class BayerDemosaick (line 18) | class BayerDemosaick(nn.Module):
    method __init__ (line 26) | def __init__(self, depth=15, width=64, pretrained=True, pad=False):
    method forward (line 67) | def forward(self, mosaic):
  class XTransDemosaick (line 92) | class XTransDemosaick(nn.Module):
    method __init__ (line 98) | def __init__(self, depth=11, width=64, pretrained=True, pad=False):
    method forward (line 134) | def forward(self, mosaic):
  function _crop_like (line 151) | def _crop_like(src, tgt):

FILE: demosaicnet/mosaic.py
  function bayer (line 9) | def bayer(im, return_mask=False):
  function xtrans_cell (line 58) | def xtrans_cell(torch=False):
  function xtrans (line 89) | def xtrans(im, return_mask=False):

FILE: demosaicnet/utils.py
  function crop_like (line 20) | def crop_like(src, tgt):
  class ExponentialMovingAverage (line 48) | class ExponentialMovingAverage(object):
    method __init__ (line 56) | def __init__(self, keys, alpha=0.999):
    method __getitem__ (line 61) | def __getitem__(self, key):
    method update (line 64) | def update(self, key, value):
  class BasicArgumentParser (line 75) | class BasicArgumentParser(argparse.ArgumentParser):
    method __init__ (line 78) | def __init__(self, *args, **kwargs):
  class ModelInterface (line 115) | class ModelInterface(metaclass=ABCMeta):
    method __init__ (line 118) | def __init__(self):
    method training_step (line 122) | def training_step(self, batch):
    method init_validation (line 137) | def init_validation(self):
    method validation_step (line 149) | def validation_step(self, batch, running_val_data):
    method __repr__ (line 164) | def __repr__(self):
  class Checkpointer (line 168) | class Checkpointer(object):
    method __init__ (line 188) | def __init__(self, root, model=None, meta=None, optimizers=None,
    method __repr__ (line 215) | def __repr__(self):
    method __path (line 218) | def __path(self, path, prefix=None):
    method save (line 223) | def save(self, path, extras=None):
    method try_and_init_from (line 257) | def try_and_init_from(self, path):
    method load (line 286) | def load(self, path):
    method load_latest (line 331) | def load_latest(self):
    method sorted_checkpoints (line 353) | def sorted_checkpoints(self):
    method delete (line 374) | def delete(self, path):
    method load_meta (line 386) | def load_meta(root, prefix=None):
  class Trainer (line 404) | class Trainer(object):
    method __init__ (line 416) | def __init__(self, interface):
    method interrupt_handler (line 426) | def interrupt_handler(self, signo, frame):
    method _stop (line 431) | def _stop(self):
    method add_callback (line 436) | def add_callback(self, callback):
    method train (line 447) | def train(self, dataloader, starting_epoch=None, num_epochs=None,
    method __repr__ (line 500) | def __repr__(self):
    method __training_start (line 504) | def __training_start(self, dataloader):
    method __training_end (line 508) | def __training_end(self):
    method __epoch_start (line 512) | def __epoch_start(self, epoch_idx):
    method __epoch_end (line 516) | def __epoch_end(self):
    method __batch_start (line 520) | def __batch_start(self, batch_idx, batch):
    method __batch_end (line 524) | def __batch_end(self, batch, train_step_data):
    method __val_batch_start (line 528) | def __val_batch_start(self, batch_idx, batch):
    method __val_batch_end (line 532) | def __val_batch_end(self, batch, running_val_data):
    method __validation_start (line 536) | def __validation_start(self, dataloader):
    method __validation_end (line 541) | def __validation_end(self, running_val_data):
    method __training_step (line 545) | def __training_step(self, batch):
    method __validation_step (line 548) | def __validation_step(self, batch, running_val_data):
  class Callback (line 552) | class Callback(object):
    method __repr__ (line 563) | def __repr__(self):
    method __init__ (line 566) | def __init__(self):
    method training_start (line 575) | def training_start(self, dataloader):
    method training_end (line 584) | def training_end(self):
    method epoch_start (line 588) | def epoch_start(self, epoch):
    method epoch_end (line 599) | def epoch_end(self):
    method validation_start (line 607) | def validation_start(self, dataloader):
    method validation_end (line 616) | def validation_end(self, val_data):
    method batch_start (line 620) | def batch_start(self, batch_idx, batch_data):
    method batch_end (line 629) | def batch_end(self, batch_data, train_step_data):
    method val_batch_start (line 639) | def val_batch_start(self, batch_idx, batch_data):
    method val_batch_end (line 648) | def val_batch_end(self, batch_data, running_val_data):
  class CheckpointingCallback (line 658) | class CheckpointingCallback(Callback):
    method __init__ (line 675) | def __init__(self, checkpointer, interval=600,
    method epoch_end (line 685) | def epoch_end(self):
    method training_end (line 692) | def training_end(self):
    method batch_end (line 696) | def batch_end(self, batch_data, train_step_data):
    method __purge_old_files (line 720) | def __purge_old_files(self):
  class KeyedCallback (line 746) | class KeyedCallback(Callback):
    method __init__ (line 758) | def __init__(self, keys=None, val_keys=None, smoothing=0.999):
    method batch_end (line 776) | def batch_end(self, batch_data, train_step_data):
  class ProgressBarCallback (line 780) | class ProgressBarCallback(KeyedCallback):
    method __init__ (line 787) | def __init__(self, keys=None, val_keys=None, smoothing=0.99, label=None):
    method training_start (line 796) | def training_start(self, dataloader):
    method training_end (line 800) | def training_end(self):
    method epoch_start (line 804) | def epoch_start(self, epoch):
    method epoch_end (line 812) | def epoch_end(self):
    method validation_start (line 817) | def validation_start(self, dataloader):
    method val_batch_end (line 823) | def val_batch_end(self, batch, running_val_data):
    method validation_end (line 826) | def validation_end(self, val_data):
    method batch_end (line 836) | def batch_end(self, batch_data, train_step_data):

FILE: scripts/demosaicnet_demo.py
  function main (line 16) | def main(args):

FILE: scripts/eval.py
  class PSNR (line 14) | class PSNR(th.nn.Module):
    method __init__ (line 15) | def __init__(self):
    method forward (line 18) | def forward(self, out, ref):
  function main (line 22) | def main(args):

FILE: scripts/train.py
  class PSNR (line 14) | class PSNR(th.nn.Module):
    method __init__ (line 15) | def __init__(self):
    method forward (line 18) | def forward(self, out, ref):
  class DemosaicnetInterface (line 23) | class DemosaicnetInterface(demosaicnet.utils.ModelInterface):
    method __init__ (line 31) | def __init__(self, model, lr=1e-4, cuda=th.cuda.is_available()):
    method training_step (line 41) | def training_step(self, batch):
    method forward (line 46) | def forward(self, batch):
    method backward (line 52) | def backward(self, batch, fwd_output):
    method init_validation (line 69) | def init_validation(self):
    method update_validation (line 72) | def update_validation(self, batch, fwd_output, running_data):
    method finalize_validation (line 87) | def finalize_validation(self, running_data):
  function main (line 93) | def main(args):
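`DemosaicnetInterface` splits validation into `init_validation`, `update_validation`, and `finalize_validation`. A hedged sketch of that accumulate-then-finalize pattern — the dictionary keys and the size-weighted mean below are illustrative assumptions, not taken from `scripts/train.py`:

```python
def init_validation() -> dict:
    """Start an empty running tally before the validation loop."""
    return {"total": 0.0, "count": 0}


def update_validation(batch_metric: float, batch_size: int, running: dict) -> dict:
    """Fold one batch's metric into the tally, weighted by batch size."""
    running["total"] += batch_metric * batch_size
    running["count"] += batch_size
    return running


def finalize_validation(running: dict) -> float:
    """Reduce the tally to a single per-sample average."""
    return running["total"] / max(running["count"], 1)


running = init_validation()
for metric, n in [(30.0, 2), (36.0, 2)]:
    running = update_validation(metric, n, running)
print(finalize_validation(running))  # 33.0
```

Weighting by batch size keeps the final average correct when the last batch of an epoch is smaller than the rest.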
Condensed preview — 24 files, each listed with its path, character count, and a content snippet; the full structured content is roughly 89K characters.
[
  {
    "path": ".gitignore",
    "chars": 57,
    "preview": "output/\ndist/\ndemosaicnet.egg-info\nbuild/\ndata\n.DS_Store\n"
  },
  {
    "path": "LICENSE",
    "chars": 1216,
    "preview": "MIT License\r\n\r\nDeep Joint Demosaicking and Denoising\r\nSiggraph Asia 2016\r\nMichael Gharbi, Gaurav Chaurasia, Sylvain Pari"
  },
  {
    "path": "MANIFEST.in",
    "chars": 111,
    "preview": "include demosaicnet/data/bayer.pth\ninclude demosaicnet/data/xtrans.pth\ninclude demosaicnet/data/test_input.png\n"
  },
  {
    "path": "Makefile",
    "chars": 343,
    "preview": "test:\n\tpy.test tests\n\n.PHONY: docs\ndocs:\n\t$(MAKE) -C docs html\n\nclean:\n\tpython setup.py clean\n\trm -rf build demosaicnet."
  },
  {
    "path": "README.md",
    "chars": 1067,
    "preview": "# Deep Joint Demosaicking and Denoising\nSiGGRAPH Asia 2016\n\nMichaël Gharbi gharbi@mit.edu Gaurav Chaurasia Sylvain Paris"
  },
  {
    "path": "demosaicnet/.gitignore",
    "chars": 12,
    "preview": "__pycache__\n"
  },
  {
    "path": "demosaicnet/__init__.py",
    "chars": 201,
    "preview": "from .modules import BayerDemosaick\nfrom .modules import XTransDemosaick\nfrom .mosaic import xtrans\nfrom .mosaic import "
  },
  {
    "path": "demosaicnet/dataset.py",
    "chars": 24857,
    "preview": "\"\"\"Dataset loader for demosaicnet.\"\"\"\nimport os\nimport subprocess\nimport shutil\nimport hashlib\nimport logging\n\n\nimport n"
  },
  {
    "path": "demosaicnet/modules.py",
    "chars": 4922,
    "preview": "\"\"\"Models for [Gharbi2016] Deep Joint demosaicking and denoising.\"\"\"\nimport os\nfrom collections import OrderedDict\nfrom "
  },
  {
    "path": "demosaicnet/mosaic.py",
    "chars": 3052,
    "preview": "\"\"\"Utilities to make a mosaic mask and apply it to an image.\"\"\"\nimport numpy as np\nimport torch as th\n\n\n__all__ = [\"baye"
  },
  {
    "path": "demosaicnet/utils.py",
    "chars": 29621,
    "preview": "\"\"\"Helper functions.\"\"\"\n\nfrom abc import ABCMeta, abstractmethod\nimport argparse\nimport logging\nimport os\nimport re\nimpo"
  },
  {
    "path": "demosaicnet/version.py",
    "chars": 23,
    "preview": "__version__ = \"0.0.14\"\n"
  },
  {
    "path": "docs/.gitignore",
    "chars": 6,
    "preview": "build\n"
  },
  {
    "path": "docs/Makefile",
    "chars": 584,
    "preview": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line.\nSPHINXOPTS    =\nSPHI"
  },
  {
    "path": "docs/source/conf.py",
    "chars": 5933,
    "preview": "# -*- coding: utf-8 -*-\n#\n# Configuration file for the Sphinx documentation builder.\n#\n# This file does only contain a s"
  },
  {
    "path": "docs/source/dataset.rst",
    "chars": 66,
    "preview": "Dataset\n=======\n\n.. automodule:: demosaicnet.dataset\n   :members:\n"
  },
  {
    "path": "docs/source/helpers.rst",
    "chars": 65,
    "preview": "Helpers\n=======\n\n.. automodule:: demosaicnet.mosaic\n   :members:\n"
  },
  {
    "path": "docs/source/index.rst",
    "chars": 523,
    "preview": ".. demosaicnet documentation master file, created by\n   sphinx-quickstart on Thu Mar 14 13:14:16 2019.\n   You can adapt "
  },
  {
    "path": "docs/source/models.rst",
    "chars": 64,
    "preview": "Models\n======\n\n.. automodule:: demosaicnet.modules\n   :members:\n"
  },
  {
    "path": "requirements.txt",
    "chars": 80,
    "preview": "imageio==2.31.1\nnumpy==1.25.0\nsetuptools==49.2.1\ntorch==2.0.1\ntqdm==4.65.0\nwget\n"
  },
  {
    "path": "scripts/demosaicnet_demo.py",
    "chars": 1987,
    "preview": "#!/usr/bin/env python\n\"\"\"Demo script on using demosaicnet for inference.\"\"\"\n\nimport os\nfrom pkg_resources import resourc"
  },
  {
    "path": "scripts/eval.py",
    "chars": 2959,
    "preview": "#!/bin/env python\n\"\"\"Evaluate a demosaicking model.\"\"\"\nimport argparse\nimport logging\n\nimport torch as th\nfrom torch.uti"
  },
  {
    "path": "scripts/train.py",
    "chars": 5605,
    "preview": "#!/bin/env python\n\"\"\"Train a demosaicking model.\"\"\"\nimport logging\n\nimport torch as th\nfrom torch.utils.data import Data"
  },
  {
    "path": "setup.py",
    "chars": 1013,
    "preview": "import re\nimport setuptools\n\n\nwith open('demosaicnet/version.py') as fid:\n    try:\n        __version__, = re.findall( '_"
  }
]

About this extraction

This page contains the full source code of the mgharbi/demosaicnet GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction covers 24 files (82.4 KB, approximately 24.8k tokens) and includes a symbol index of 108 functions, classes, methods, constants, and types. The output works with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
