[
  {
    "path": ".gitignore",
    "content": "output/\ndist/\ndemosaicnet.egg-info\nbuild/\ndata\n.DS_Store\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\r\n\r\nDeep Joint Demosaicking and Denoising\r\nSiggraph Asia 2016\r\nMichael Gharbi, Gaurav Chaurasia, Sylvain Paris, Fredo Durand\r\n\r\nCopyright (c) 2016 Michael Gharbi\r\n\r\nPermission is hereby granted, free of charge, to any person obtaining a copy\r\nof this software and associated documentation files (the \"Software\"), to deal\r\nin the Software without restriction, including without limitation the rights\r\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\r\ncopies of the Software, and to permit persons to whom the Software is\r\nfurnished to do so, subject to the following conditions:\r\n\r\nThe above copyright notice and this permission notice shall be included in all\r\ncopies or substantial portions of the Software.\r\n\r\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\r\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\r\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\r\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\r\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\r\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\r\nSOFTWARE.\r\n"
  },
  {
    "path": "MANIFEST.in",
    "content": "include demosaicnet/data/bayer.pth\ninclude demosaicnet/data/xtrans.pth\ninclude demosaicnet/data/test_input.png\n"
  },
  {
    "path": "Makefile",
    "content": "test:\n\tpy.test tests\n\n.PHONY: docs\ndocs:\n\t$(MAKE) -C docs html\n\nclean:\n\tpython setup.py clean\n\trm -rf build demosaicnet.egg-info dist .pytest_cache\n\ndistribution:\n\tpython setup.py sdist bdist_wheel\n\ttwine check dist/*\n\ntest_upload:\n\ttwine upload --repository-url https://test.pypi.org/legacy/ dist/*\n\nupload_distribution:\n\ttwine upload dist/*\n"
  },
  {
    "path": "README.md",
    "content": "# Deep Joint Demosaicking and Denoising\nSiGGRAPH Asia 2016\n\nMichaël Gharbi gharbi@mit.edu Gaurav Chaurasia Sylvain Paris Frédo Durand\n\nA minimal pytorch implementation of \"Deep Joint Demosaicking and Denoising\" [Gharbi2016]\n\n# Installation\n\nFrom this repo:\n\n```shell\npython setup.py install\n```\n\nUsing pip:\n\n```shell\npip install demosaicnet\n```\n\nThen run the demo script with:\n\n```shell\npython scripts/demosaicnet_demo.py output\n```\n\nTo train a dummy model on the demo dataset provided, run:\n\n```shell\npython scripts/train.py --data demosaicnet/data/dummy_dataset --checkpoint_dir ckpt\n```\n\nTo build and update the whee:\n\n```shell\npip install wheel twine\nmake distribution\nmake upload_distribution\n```\n\n# FAQ\n\n- **How is noise handled? Where is the pretrained model?** The noise-aware model is not implementation, see the earlier Caffe implementation for that <https://github.com/mgharbi/demosaicnet_caffe>\n- **How do I train this?** The script `scripts/train.py` is a good start to setup your training job, but I haven't tested it yet, I recommend rolling your own.\n"
  },
  {
    "path": "demosaicnet/.gitignore",
    "content": "__pycache__\n"
  },
  {
    "path": "demosaicnet/__init__.py",
    "content": "from .modules import BayerDemosaick\nfrom .modules import XTransDemosaick\nfrom .mosaic import xtrans\nfrom .mosaic import bayer\nfrom .mosaic import xtrans_cell\nfrom .dataset import *\nfrom . import utils\n"
  },
  {
    "path": "demosaicnet/dataset.py",
    "content": "\"\"\"Dataset loader for demosaicnet.\"\"\"\nimport os\nimport subprocess\nimport shutil\nimport hashlib\nimport logging\n\n\nimport numpy as np\nfrom imageio import imread\nfrom torch.utils.data import Dataset as TorchDataset\nimport wget\n\nfrom .mosaic import bayer, xtrans\n\n__all__ = [\"BAYER_MODE\", \"XTRANS_MODE\", \"Dataset\",\n           \"TRAIN_SUBSET\", \"VAL_SUBSET\", \"TEST_SUBSET\"]\n\n\nlog = logging.getLogger(__name__)\n\nBAYER_MODE = \"bayer\"\n\"\"\"Applies a Bayer mosaic pattern.\"\"\"\n\nXTRANS_MODE = \"xtrans\"\n\"\"\"Applies an X-Trans mosaic pattern.\"\"\"\n\nTRAIN_SUBSET = \"train\"\n\"\"\"Loads the 'train' subset of the data.\"\"\"\n\nVAL_SUBSET = \"val\"\n\"\"\"Loads the 'val' subset of the data.\"\"\"\n\nTEST_SUBSET = \"test\"\n\"\"\"Loads the 'test' subset of the data.\"\"\"\n\n\nclass Dataset(TorchDataset):\n    \"\"\"Dataset of challenging image patches for demosaicking.\n\n    Args:\n        download(bool): if True, automatically download the dataset.\n        mode(:class:`BAYER_MODE` or :class:`XTRANS_MODE`): mosaic pattern to apply to the data.\n        subset(:class:`TRAIN_SUBET`, :class:`VAL_SUBSET` or :class:`TEST_SUBSET`): subset of the data to load.\n    \"\"\"\n\n    def __init__(self, root, download=False,\n                 mode=BAYER_MODE, subset=\"train\"):\n\n        super(Dataset, self).__init__()\n\n        self.root = os.path.abspath(root)\n\n        if subset not in [TRAIN_SUBSET, VAL_SUBSET, TEST_SUBSET]:\n            raise ValueError(\"Dataset subet should be '%s', '%s' or '%s', got\"\n                             \" %s\" % (TRAIN_SUBSET, TEST_SUBSET, VAL_SUBSET,\n                                      subset))\n\n        if mode not in [BAYER_MODE, XTRANS_MODE]:\n            raise ValueError(\"Dataset mode should be '%s' or '%s', got\"\n                             \" %s\" % (BAYER_MODE, XTRANS_MODE, mode))\n        self.mode = mode\n\n        listfile = os.path.join(self.root, subset, \"filelist.txt\")\n        log.debug(\"Reading image list from %s\", listfile)\n\n        if not os.path.exists(listfile):\n            if download:\n                _download(self.root)\n            else:\n                log.error(\"Filelist %s not found\", listfile)\n                raise ValueError(\"Filelist %s not found\" % listfile)\n        else:\n            log.debug(\"No need no download the data, filelist exists.\")\n\n        self.files = []\n        with open(listfile, \"r\") as fid:\n            for fname in fid.readlines():\n                self.files.append(os.path.join(self.root, subset, fname.strip()))\n\n    def __len__(self):\n        return len(self.files)\n\n    def __getitem__(self, idx):\n        \"\"\"Fetches a mosaic / demosaicked pair of images.\n\n        Returns\n            mosaic(np.array): with size [3, h, w] the mosaic data with separated color channels.\n            img(np.array): with size [3, h, w] the groundtruth image.\n        \"\"\"\n        fname = self.files[idx]\n        img = np.array(imread(fname)).astype(np.float32) / (2**8-1)\n        img = np.transpose(img, [2, 0, 1])\n\n        if self.mode == BAYER_MODE:\n            mosaic = bayer(img)\n        else:\n            mosaic = xtrans(img)\n\n        return mosaic, img\n\n\nCHECKSUMS = {\n    'datasets.z01': 'da46277afe85d3a91c065e4751fb8175',\n    'datasets.z02': 'e274a9646323d954b00094ea424e4e4c',\n    'datasets.z03': 'e071cc595a99a5aa4545d06350e5165f',\n    'datasets.z04': 'c3d2f229834569cd5ae6e2d1467c4a95',\n    'datasets.z05': 
'daf90136c7b1ee4bb4653e9b6bf4b67d',\n    'datasets.z06': '87e85d2854d40116e066b28e5a8750cc',\n    'datasets.z07': 'a0b2854bf025c87c0bfdf83ce9aa9055',\n    'datasets.z08': '62125ccf29cd4b182dd81a4bb82f94c4',\n    'datasets.z09': 'f990f8a5090d586f2f31e61b5e6434bd',\n    'datasets.z10': '41ecf8d8b7d981604d661d258bf988db',\n    'datasets.z100': '923a536ece64cd036eec4a13156531c8',\n    'datasets.z101': '44a936558af2e830fdf65d9acb3960ab',\n    'datasets.z102': 'b24870482b41200ab7b91f0bcd3ed718',\n    'datasets.z103': 'a85521c1fe0b8d2d1a074b0b52bf9db1',\n    'datasets.z104': 'aacc7a81ec9e9a7849e3a45b1cb12f7c',\n    'datasets.z105': '19b62c0f0ae008b77df6465182f43dc4',\n    'datasets.z106': '4b0c414ce5825a9e2249e5810f0e55f0',\n    'datasets.z107': '7f6df7fea899a656fcde898225890daf',\n    'datasets.z108': '16a877c357f112367200a2534b5e54f3',\n    'datasets.z109': '9180129bf9c204184f729bdf1c284c9c',\n    'datasets.z11': 'cff5b0e9950933fa9dd6ced8ffb9528f',\n    'datasets.z110': 'a95a6fbfd32d90058b9e4b9f0645c646',\n    'datasets.z111': 'f3c7894a7d04178ca417dd5ed3a9e649',\n    'datasets.z112': 'd46a73703a72c07137424cad90c9c0bd',\n    'datasets.z113': '04b3421e465c5ef8bd64fc23730b58f7',\n    'datasets.z114': '9c31d2d94bd6f1ba321039b18c462175',\n    'datasets.z115': '427a3a6f3f936b0ba435da35be3e4bc3',\n    'datasets.z116': 'c633a66e7644d7d8e8148d651b76d93f',\n    'datasets.z117': 'cf316d3acaf301fc7b2b7e250ef734dd',\n    'datasets.z118': 'c3b53a604492499930dfd000d7d33fa4',\n    'datasets.z119': 'e64ede2179c589abd9e587347d9ba3b0',\n    'datasets.z12': '17b70298245ae7965f4e4b4fb01f19cc',\n    'datasets.z120': 'ece09ee0bc30eab71f06716cca393029',\n    'datasets.z121': '79db008803bf58593df6c32db8c0b3d0',\n    'datasets.z122': '647b151eb30a44d123ce9ddfbb380094',\n    'datasets.z123': '5065f755fe61c666d6ed28096e4047a3',\n    'datasets.z124': '6306215855e30112c495291d5928e0b7',\n    'datasets.z125': 'a55c6e31e7ad42170016a15791e25134',\n    'datasets.z126': '582eb81f0251840507ca0b53e624b1e0',\n    'datasets.z127': '3beefb769a01481a8bc7ae39bc2f539b',\n    'datasets.z128': 'fac96b38f96ea364ea51020386597b5d',\n    'datasets.z129': '3aaca8ce2b67d2c1fe764ae2b306d17f',\n    'datasets.z13': '7ec2aa595441d9f698a46d707f299e8c',\n    'datasets.z130': '9f94105dcc39cf5421b2df2532c06ec9',\n    'datasets.z131': '5de2901388a2e531d6874cdaf23bf15e',\n    'datasets.z132': '6851ed45004ae6864892532d1ad44b20',\n    'datasets.z133': '0f6b417d57f9bdc9fa91d85ff5e3378d',\n    'datasets.z134': '0602b14c8828f7a9fed92713047c695f',\n    'datasets.z135': '3ab79fd4da5c4c5b5a7896a189ec43a7',\n    'datasets.z136': 'dd05152db786d8189cdb419ac5d0018a',\n    'datasets.z137': 'b4e97abef22ee8b81ac232760d9f539d',\n    'datasets.z138': '4ce939f1fa4f3e110e989db07d53d33f',\n    'datasets.z139': 'a096add471cc5e074852c063eec3863a',\n    'datasets.z14': '08571b6629b8856813fb35b45bbc082c',\n    'datasets.z140': '3d84e8cc84ab26969c5239be60222ad3',\n    'datasets.z141': 'd77062f59ca9957d33c9a671657fd795',\n    'datasets.z142': '82901d01917006348deb89aa37fe3629',\n    'datasets.z143': '736f77856f0854b26fa951479691df8f',\n    'datasets.z144': '55c44320975f4278a8837085c5e02eda',\n    'datasets.z145': '087d3b7634bf4720a916767d5c6b7d70',\n    'datasets.z146': '5659d6f0495dcdc5f5d98bf2efaaa09b',\n    'datasets.z147': '66dd69b2f9348e3c0d0c93c3e61416dd',\n    'datasets.z148': 'f3fc8f15aceb0f9bf04d786b894caa44',\n    'datasets.z149': '3863be1d2b130f79399432cfc1281c2e',\n    'datasets.z15': 'cc57e0c4466575436f670ac3e07ad2f3',\n    'datasets.z150': 
'09750f2019da9ff7132b904b8bcbd895',\n    'datasets.z151': 'b1573f086c0f7d1fdf249a8e3a9bb178',\n    'datasets.z152': '1a2d4374aea1e22c0b676a6a7eac49ec',\n    'datasets.z153': 'b24320708d2019ed71ab16055e971b1d',\n    'datasets.z154': '7ba27e1946afa610e131f3afefe78326',\n    'datasets.z155': '2f02c8b5470be4cb6b53e4c9e512394e',\n    'datasets.z156': 'fa4f0977409f181820bf78174257d657',\n    'datasets.z157': '6736b97a29d1393ec65ddc9376a06369',\n    'datasets.z158': 'b4b72842b13ec3877bc530ca2470a0db',\n    'datasets.z159': 'fa0dfa57c9d299175719bbbdf319c935',\n    'datasets.z16': '02c213b708e2ee7ebd68464dfb2279fc',\n    'datasets.z160': '6edbb9dc7fa6d12d2e21631ae14eaa8c',\n    'datasets.z161': '4ab093ce5af2726e9ee71fdf1943e8e2',\n    'datasets.z162': '4e21401db9f9884d953df20381c5fd97',\n    'datasets.z163': '8d39f0ed1a1d9b5de22b583d00081522',\n    'datasets.z164': '40fe425e2c5e89b87b44a6e9735590d5',\n    'datasets.z165': '9552ff9b03e2dbb45befc9d1cc99ad81',\n    'datasets.z166': '6c0098a36d6827aea846d8522c578751',\n    'datasets.z167': 'ce6b8b981d92f5a61f2ec40089a400a2',\n    'datasets.z168': '60f6e16a3e5e409a3fc89edd3e0034d5',\n    'datasets.z169': '75eb975a10d5cbf136796651a1789b42',\n    'datasets.z17': 'c92ae62205eaef02db27996f0dc6c282',\n    'datasets.z170': '287966840fff015ed36da3a08a18ebfb',\n    'datasets.z171': 'e08193b722af492a78ac36a3125ac8d9',\n    'datasets.z172': '32c795461f194c38b25047faeb46fdb5',\n    'datasets.z173': '58fbb396dbbb902ac2f2c43722573200',\n    'datasets.z174': '506fee2b982ee81689f3fe4d89133cac',\n    'datasets.z175': '9d553e31b07b23e30c427800168eec6b',\n    'datasets.z176': '0f6e3048824ead093d3127434ac83a72',\n    'datasets.z177': 'ece15c004fa708849295987b8b1aba9c',\n    'datasets.z178': 'd9db14d92d56ae2970798417030b5bb4',\n    'datasets.z179': '25ddbe866d0a6b9cebe8d90f7b801fa6',\n    'datasets.z18': 'f55ddc31cf203f495e352182a5bbadc3',\n    'datasets.z180': 'd2fc49c68c77d1da592aff4ab90c0915',\n    'datasets.z181': '5a4a635b1f3535311c6caecf4ab3ba80',\n    'datasets.z182': 'df51725daff3edfc377a5f6bc158ec3a',\n    'datasets.z183': '626add199ec4f263ff278d5392f41c9c',\n    'datasets.z184': '5069483fd064ee5e8c24a240e6ee7736',\n    'datasets.z185': '589249e98db0a4ded1d3e4acefd07509',\n    'datasets.z186': 'e4415c64463ef16bceb9d2e2fa934d71',\n    'datasets.z187': 'e070cbaaf88a1085964244f6505c713c',\n    'datasets.z188': '71b38eb51edff8b049a302bacbe344d6',\n    'datasets.z189': '8fa7a8b58c9e7cb9e86bfd0ca5f6d2ea',\n    'datasets.z19': '6c34cd0e39a33737983ebf89f6cabf5f',\n    'datasets.z190': 'daad0ea7c87d0935e014a370c38cc926',\n    'datasets.z191': 'c355e9ee9d0afa67faa34739b7f7cf79',\n    'datasets.z192': 'c97d5a784625795cdf3c36c337986afe',\n    'datasets.z193': '5f3b8425e215798c9e454cdbe586db90',\n    'datasets.z194': '31b5d74c1cbbabbf58ee470467b40d12',\n    'datasets.z195': '4c65958343bc2ea1a28e779ee7e5e498',\n    'datasets.z196': '26ab3664e62c7fd5d0be673c45dd0d93',\n    'datasets.z197': '32d690086b6e9f05e3ced3a126af870c',\n    'datasets.z198': '323071827db89626c9f186455fbb38c9',\n    'datasets.z199': '72475a8500be1ff21407a66f0e2e91a6',\n    'datasets.z20': '4161d6eda0ca5ed9500f953f789a25b2',\n    'datasets.z200': '44a863b9c9760cb87a23f1422f242c0d',\n    'datasets.z201': 'dc38b455fa45e3ef0d5f06397507982e',\n    'datasets.z202': 'b9ba231b317b008602f9472325b40e65',\n    'datasets.z203': '36f4afc46258d80be626040956550028',\n    'datasets.z204': '5c522fbdfe1f9d449c9189173f2ed2b2',\n    'datasets.z205': '58d39995017eed2c4abcf9fcfd07e695',\n    'datasets.z206': 
'2efdfc2abc834f0f0f1cabe10423f865',\n    'datasets.z207': '04ead7536e5c13936c724f644ab1cb3a',\n    'datasets.z208': '9e3e0a02a07bebcc7cdb62a0ad047946',\n    'datasets.z209': 'c1fc44cd8b6f50955c8b3b317155ecb6',\n    'datasets.z21': '8669b8bb9fa90628d4423c45648868b2',\n    'datasets.z210': '210df79f8434bdb4e2a7d12c4078d972',\n    'datasets.z211': 'f078d2d8a14b6c59f58c67865bbc3334',\n    'datasets.z212': 'ef7d08a6cc39f6cb96b631ca61b440a0',\n    'datasets.z213': '723057f7619d8820f142944f55f9542b',\n    'datasets.z214': '2ba38ca8561b51710f660c03f84c0eb1',\n    'datasets.z215': '92a6e97dfaff295110ddead242ebe932',\n    'datasets.z216': '05f40901ae70f73b3c099fcdd4ca945e',\n    'datasets.z217': 'b690ed3e8c6ba9f8bba9154d7e8f7ece',\n    'datasets.z218': 'e290ca6f5573579df9f3aa7c5158891e',\n    'datasets.z219': 'a8e4626968f089163179f30066ce732c',\n    'datasets.z22': '0f90463abdc8f0f81d81249302cf2d09',\n    'datasets.z220': '3ecd2c0c855505d2957046d784944fce',\n    'datasets.z221': '6c9c28287fbadcab2ce777ef3134e5d6',\n    'datasets.z222': '29361a77f05e5e68113fc23e11b54b4d',\n    'datasets.z223': '9214257b9a87c0037e88709addba8948',\n    'datasets.z224': '7596a516fb7e308f33a81c5b3c36810a',\n    'datasets.z225': 'c1dc079f5261a976b1bd7f5c05cd4a02',\n    'datasets.z226': '5b87815b0ccc5cacec83a399a52874aa',\n    'datasets.z227': '6bed353dc50263b2c720af663c833bbc',\n    'datasets.z228': '37ed3574bf978bccd6e2db9be00bae94',\n    'datasets.z229': '8fd57367808fd77581f998850e5f935a',\n    'datasets.z23': '90ae2cdcdc1663b80c20e080e5c0e038',\n    'datasets.z230': '3f5ef3234da0236d2fbfaf7366407d70',\n    'datasets.z231': 'f67d8320028620c8bdd9a800a78afa27',\n    'datasets.z232': '5f831f25f8e8557168b38a7a28f8e7f9',\n    'datasets.z233': '56ee8c4b01825ba7f12340ae8b990db3',\n    'datasets.z234': '5428e98487b0e077cf9c24dc60599286',\n    'datasets.z235': '883c0ea97facca4d57d5c9c54922e8be',\n    'datasets.z236': 'e23fb6f610a3b528d5b310df4e452256',\n    'datasets.z237': 'bd858e84a47668edc851dab131239ae4',\n    'datasets.z238': 'db51bed3e3e5c6a40881f22532618533',\n    'datasets.z239': 'c1be852117739fc63227a503b08a8436',\n    'datasets.z24': '00dcb2e2a72b15a9aa9a646ecaea0019',\n    'datasets.z240': 'f4e6da02349b03b4f433b3399dcf8b3c',\n    'datasets.z241': '4f29898105aaf9f1a753a1c639947c2f',\n    'datasets.z242': '168a15bd8367f5d5f3e5e8cf4d0da6af',\n    'datasets.z243': '2930bced33bd1ecabce070fc831567e7',\n    'datasets.z244': 'aed9c5ec05f57e3fa9e7b224d47fa7b0',\n    'datasets.z245': '5aa83729fec805c166e48e5ec21530a5',\n    'datasets.z246': '646ceabfae028d631568930b4056227a',\n    'datasets.z247': '7381491175c1a63cc04ecac81148925b',\n    'datasets.z248': '4064df81449c1980d0abaa8c7262b315',\n    'datasets.z249': '7ae84d1dde2e935d86138d1e7b077df8',\n    'datasets.z25': '5ae383bcd01d4ce22387680e28833f06',\n    'datasets.z250': 'c018b41fbc4982a561b07cf0d52137f2',\n    'datasets.z251': '9c9dc7a889d537fd1e02f4549529a5f1',\n    'datasets.z252': 'b74df0680e7a62794902186fe1e3fec2',\n    'datasets.z253': '894cbb618ddffe65ce2ada0f250ad79c',\n    'datasets.z254': 'd5d8bc590d109c7d592d4df183c495a8',\n    'datasets.z255': 'faf313a3edd70129c212d3dbd1de5042',\n    'datasets.z256': '0fefcb84b66518df03a93ae53079409f',\n    'datasets.z257': '887e8e8abaf09682903b9b1060fb8153',\n    'datasets.z258': 'a700ae13abb7707032123468fab1bf55',\n    'datasets.z259': 'cad7c974b832d27d1cfb4ae0e4dc6c3c',\n    'datasets.z26': '5e0cdf281eeca969a4e0adfc44e11dcb',\n    'datasets.z260': '5e1c11df40440e4e84354d40efe7940e',\n    'datasets.z261': 
'2a75579c238569356855244d9fedf50b',\n    'datasets.z262': '51d508271fef5558df387542b3561b67',\n    'datasets.z263': 'e11b54dcd5069e838a34dcc2daebf4b5',\n    'datasets.z264': '8a5a3b288d21a3c2ef641370d436703e',\n    'datasets.z265': '56954047cba7c8732f0490323540af43',\n    'datasets.z266': 'cad0325e494cc720c385ac7420acd2d7',\n    'datasets.z267': '25de522b499ec7af12f583dd89a31769',\n    'datasets.z268': '0e9882f392ca679e8c47276371384efd',\n    'datasets.z269': '28487bc8eb731d4913254a9d63bb13ae',\n    'datasets.z27': 'c78804dfac1e395156abd235ca416b33',\n    'datasets.z270': '276fde25d412d8a1197e3dda307580d7',\n    'datasets.z271': 'ba6f46558aa9ebc64efca2485b4f18ca',\n    'datasets.z272': '5acceb08940d937d3023d98e745b8197',\n    'datasets.z273': '57201a53390fe9c6a8c069397dcb81b8',\n    'datasets.z274': '92491b6786ccea7b6ccccfb4e09c6d75',\n    'datasets.z275': '349136c52b8554f03967d7083e5cd95c',\n    'datasets.z276': '1525075dbf4d5d101d63cdade8bed9e7',\n    'datasets.z277': 'eeb6365ef482cd2c6bac20aab8181081',\n    'datasets.z278': 'ab1b04860b27fc11b7f57074a0815877',\n    'datasets.z279': 'dc1cee0d4b69da9fd7aeb47f91768589',\n    'datasets.z28': '0257b938256a2b7b55637970ebb3edcc',\n    'datasets.z280': 'd168adf5e7c223a1d8dddfc663ebaeb0',\n    'datasets.z281': '540fc1de91a90e9bd91b3f2b590ddbf6',\n    'datasets.z282': 'f3fd6fbe05cfd53eb4a4c2e41bf75cc7',\n    'datasets.z283': '20e89e60dd6a582bcb98b74df82699c3',\n    'datasets.z284': '6e6ba6077437285881609999ead45463',\n    'datasets.z285': 'fdc4bad8adb36b3d6653a438f1fc000f',\n    'datasets.z286': '39c0f6fc7aace7e30a33e7e73afdb6ae',\n    'datasets.z287': 'e60818eefd7426f3de0cc0746550be7f',\n    'datasets.z288': 'd9db9107b8e92c0bf6a311a219363554',\n    'datasets.z289': 'a250aa672d1f9165981cfaf1c6c8fff6',\n    'datasets.z29': 'e1240710e1c4dd506aec03c02caf5606',\n    'datasets.z290': '6dcb5a7674c927ec4e965a42d04a0ccd',\n    'datasets.z291': 'c058fca515b7f714338816b72672ce20',\n    'datasets.z292': '3fa254ed46ad6a6f7686836fe6fb7991',\n    'datasets.z293': '782344c6620582d6b1681e142853a61e',\n    'datasets.z294': '5263c5eb50cfec20adf82a89973a7547',\n    'datasets.z295': '0c217819aa7308ce8f744511e572a632',\n    'datasets.z296': 'cf8e7f1d3503ad6371ebcc9f827a29f8',\n    'datasets.z297': 'f16861471be342b291e55991654c882b',\n    'datasets.z298': 'ceb8b1170f1f2ec8113ef1c004df236c',\n    'datasets.z299': '4f66a50c5dfb03143a6456c7fd925ec1',\n    'datasets.z30': '75b0444187fc6c3df7bd3182108bc647',\n    'datasets.z300': 'becbeea47cef192a1a13b35911c4795a',\n    'datasets.z301': '92a67108c6bde11111f3b94f690f4b42',\n    'datasets.z302': '06292210e05575cf1632a099648c16af',\n    'datasets.z303': 'ecb7a48777719322c957ecc99340f04a',\n    'datasets.z304': 'b318fc0b7d645d16467b78bbce95befa',\n    'datasets.z305': '5f83e6a17e5977c2b99fc29893a1f479',\n    'datasets.z306': 'dde2ec22363740081f54032a3add00e0',\n    'datasets.z307': 'e0b7f9a7eddf2117a6127542883eb767',\n    'datasets.z308': 'fff180911deaf4b476f42e6e47d78e6e',\n    'datasets.z309': 'c55a8ac017f7ea69e77519fbcc617301',\n    'datasets.z31': '04bee9374c7436d66f560bbbfc22299d',\n    'datasets.z310': '3b4282fd1ab2b4f885df196a83726d34',\n    'datasets.z311': '08a7c41c5d290750d9ebf6266a86bec8',\n    'datasets.z312': 'b0f844ffd0ae6785e077ef8871dfc5da',\n    'datasets.z313': 'bd03ba03a63877b17274e21ddb828218',\n    'datasets.z314': '392f4528f9c345355434e7448a80b28b',\n    'datasets.z315': 'e499cdf8acf1561720d4eb8f9ad9daad',\n    'datasets.z316': '2575496c4d9c5082c6c4ef2f0cedeb69',\n    'datasets.z317': 
'6f387284dbef478e02f7a5954da66015',\n    'datasets.z318': '2ea47b6b7c9790bbfc2d3c238ffba391',\n    'datasets.z319': '8e1e48b2aa1d6c7ae840a23360a3a8f8',\n    'datasets.z32': 'ab930ee97741da82bfc778474f528a28',\n    'datasets.z320': 'a785b37410ccf0366f97c0d221faa629',\n    'datasets.z321': '16c73037fa0ca7704a3c4ceddbd7d599',\n    'datasets.z322': 'af97027070cf5d057934dfd6bd819e61',\n    'datasets.z323': '79078baf8d5fee935620f48aed2980a6',\n    'datasets.z324': '987b626c8731b0092df593f0eaddb32a',\n    'datasets.z325': 'fcc9a289015b40044c55ca96bc3dbe1e',\n    'datasets.z326': 'f2921b782a23f2729e1a6641e8c99954',\n    'datasets.z327': '6d1b488d88cf0fcaf5105b6544e5ea66',\n    'datasets.z328': '268885eaeb6290be4ddfaefb34943985',\n    'datasets.z329': '3951d3dd0527b54c5302720ea037cb12',\n    'datasets.z33': 'fff4585bf34e5a7ed8f369b337aac901',\n    'datasets.z330': '4b546d783366442d95eed64f882ae9f6',\n    'datasets.z331': '4605615511ea901c18256306263b2226',\n    'datasets.z332': '4127b2a8dd549b02ce3e465cfcfcc0a3',\n    'datasets.z333': '711d282247b4fd56758c96c13bfe1b8e',\n    'datasets.z34': '39b5d53f995fa8c223ed0d7a2de34652',\n    'datasets.z35': 'ae351c1e2961f99a6d3ac37fbae27548',\n    'datasets.z36': 'c2102c4a984d03c32c7c99378df953dc',\n    'datasets.z37': '76ba463895984049a1814654b7290890',\n    'datasets.z38': '5481901e75ee9d55066ba2731b2f36b7',\n    'datasets.z39': '59815df91b2532d1e300bc71c976da12',\n    'datasets.z40': '321ebe9b7812c14ee185fe2d6f16300c',\n    'datasets.z41': '862bc3fcb9fc1df2bef61561bcca8090',\n    'datasets.z42': '10303d540fea2150a7574cefcec92977',\n    'datasets.z43': '2269db3212f2d2db86982408a2b24948',\n    'datasets.z44': '80673bdbd722d02d97febeca00e57cf1',\n    'datasets.z45': '42329dbb5c5165788902f33265db66a3',\n    'datasets.z46': 'f81999194d418c515bb5df32c278d7f7',\n    'datasets.z47': '2e3d3520636b5f3eb0cd2d649b6b4dab',\n    'datasets.z48': 'e57e7aee104e80deaf068fbdd3292410',\n    'datasets.z49': '36146291e4af0eff44e763e7e1facb4f',\n    'datasets.z50': '60ef115c5c621c757b9b7075d5590c20',\n    'datasets.z51': '02326e4392d176c2aa6d479cf43f29f8',\n    'datasets.z52': 'b454f81ef7cda24cb13b8176c641d7ef',\n    'datasets.z53': '95a4857fe7bc6230e6a3cc379085e989',\n    'datasets.z54': '8ff35d1e9d738eb8b2afde04699ba73f',\n    'datasets.z55': '75d19ea4a9a283b6d236e3353063d82d',\n    'datasets.z56': '2ec9121fb364cc76eaf6fe7ff947dd04',\n    'datasets.z57': 'f5d0b9b2b0a91f82dea784023f71cf3c',\n    'datasets.z58': '79e318093d97bd573cc7aab4ed68dd6d',\n    'datasets.z59': '68350ada72cfc22d2d6fc8ca636ffef2',\n    'datasets.z60': '1ff9d36ec2b3723253503d524c90c9df',\n    'datasets.z61': '9fb963355e8c895c1a0a02fa45eb6fb2',\n    'datasets.z62': '5900cdc6c3d5f4b0c05825b47946096a',\n    'datasets.z63': '9d039ddf86aef563bf7834faa2766e84',\n    'datasets.z64': '15decde2fb2b4f869684f88f067378c1',\n    'datasets.z65': '7cf283ebffd79c38bc09835caba483e1',\n    'datasets.z66': 'e3b65f4facef36a355b444bd6b4ec73b',\n    'datasets.z67': 'b007466dd2ac3388a5902fdf92979b2d',\n    'datasets.z68': 'c07cd3fbc9a99e8f4d3cc9bbc55682e0',\n    'datasets.z69': 'cff012a6a3af8a9f286cf68baf37ac73',\n    'datasets.z70': '032ad7661e468723242b995357870a05',\n    'datasets.z71': 'fd7aa78706b8ae7ff50eb2f312a14ea5',\n    'datasets.z72': 'c2b90b8e8f75e36c927d523138d8ae93',\n    'datasets.z73': '8c001e26b5f4ca0e061ff0b685d509c6',\n    'datasets.z74': '484dcaf9311cc4315175e114acf01f32',\n    'datasets.z75': '19f6896b426b99f34b99c88f3b68209f',\n    'datasets.z76': '4a64c040ba9aa6e455fee774a08873e5',\n    'datasets.z77': 
'e96b1d94e51741ca916580c5f8287ef5',\n    'datasets.z78': 'e0f5c9d93a3a5cfd95eeb61ee322650d',\n    'datasets.z79': '75fbea4391cd38d85c583c0f504325f1',\n    'datasets.z80': '7a99d15beff97f2dc620300f5ea82506',\n    'datasets.z81': '1e303b839a0349a8497eb6c49e3ab4c3',\n    'datasets.z82': '48c23365901b8d9cb51a20f7c71fa0a7',\n    'datasets.z83': 'fed6445cb47f0b6cd6ed6f5482e38be7',\n    'datasets.z84': '3d052276fc186ec151b5c237c5774e66',\n    'datasets.z85': 'd84dccdf4fd4e6f6126b746d0519d19d',\n    'datasets.z86': '772ab63f1ae266761db99124a152fc29',\n    'datasets.z87': '44208b663adf97563018a416560cce6b',\n    'datasets.z88': '6be7aaff3349239507b103c928afeaf5',\n    'datasets.z89': '515daa16bcaf83b0be91dcc32e0e4985',\n    'datasets.z90': 'f0fe0aa02bf0f19c0a7e301b532afa4c',\n    'datasets.z91': '87fa807089240e01886389a6c4c77641',\n    'datasets.z92': '1fd3154b3d8ec321f58f7685aedc5441',\n    'datasets.z93': 'b24ddb4b3e2c0c1c96d7650a8320e446',\n    'datasets.z94': 'b5e98cd90a629c4abb70d66a6976e0c7',\n    'datasets.z95': '87c8d5d7f26e0fbf7422fda2194bef97',\n    'datasets.z96': 'bb18c027a71f326a33b0a8fbe2d3a11f',\n    'datasets.z97': '0489ed1c8b4fb7d5965d1565076d5c6a',\n    'datasets.z98': 'a85303e08ce59fdf28f36ae9d0f20dcb',\n    'datasets.z99': '992e63e126f05eca9fa9a84bbf66165c',\n    'datasets.zip': '3434f60f5e9b263ef78e207b54e9debe',\n}\n\n\ndef _download(dst):\n    dst = os.path.abspath(dst)\n    files = CHECKSUMS.keys()\n    fullzip = os.path.join(dst, \"datasets.zip\")\n    joinedzip = os.path.join(dst, \"joined.zip\")\n\n    URL_ROOT = \"https://data.csail.mit.edu/graphics/demosaicnet\"\n\n    if not os.path.exists(joinedzip):\n        log.info(\"Dowloading %d files to %s (This will take a while, and ~80GB)\", len(\n            files), dst)\n\n        os.makedirs(dst, exist_ok=True)\n        for f in files:\n            fname = os.path.join(dst, f)\n            url = os.path.join(URL_ROOT, f)\n\n            do_download = True\n            if os.path.exists(fname):\n                checksum = md5sum(fname)\n                if checksum == CHECKSUMS[f]:  # File is is and correct\n                    log.info('%s already downloaded, with correct checksum', f)\n                    do_download = False\n                else:\n                    log.warning('%s checksums do not match, got %s, should be %s',\n                                f, checksum, CHECKSUMS[f])\n                    try:\n                        os.remove(fname)\n                    except OSError as e:\n                        log.error(\"Could not delete broken part %s: %s\", f, e)\n                        raise ValueError\n\n            if do_download:\n                log.info('Downloading %s', f)\n                wget.download(url, fname)\n\n            checksum = md5sum(fname)\n\n            if checksum == CHECKSUMS[f]:\n                log.info(\"%s MD5 correct\", f)\n            else:\n                log.error('%s checksums do not match, got %s, should be %s. 
Downloading failed',\n                          f, checksum, CHECKSUMS[f])\n\n        log.info(\"Joining zip files\")\n        cmd = \" \".join([\"zip\", \"-FF\", fullzip, \"--out\", joinedzip])\n        subprocess.check_call(cmd, shell=True)\n\n        # Cleanup the parts\n        for f in files:\n            fname = os.path.join(dst, f)\n            try:\n                os.remove(fname)\n            except OSError as e:\n                log.warning(\"Could not delete file %s\", f)\n\n    # Extract\n    wd = os.path.abspath(os.curdir)\n    os.chdir(dst)\n    log.info(\"Extracting files from %s\", joinedzip)\n    cmd = \" \".join([\"unzip\", joinedzip])\n    subprocess.check_call(cmd, shell=True)\n\n    try:\n        os.remove(joinedzip)\n    except OSError as e:\n        log.warning(\"Could not delete file %s\", joinedzip)\n\n    log.info(\"Moving subfolders\")\n    for k in [\"train\", \"test\", \"val\"]:\n        shutil.move(os.path.join(dst, \"images\", k), os.path.join(dst, k))\n    images = os.path.join(dst, \"images\")\n    log.info(\"removing '%s' folder\", images)\n    shutil.rmtree(images)\n\n    os.chdir(wd)  # restore the caller's working directory\n\n\ndef md5sum(filename, blocksize=65536):\n    hash = hashlib.md5()\n    with open(filename, \"rb\") as f:\n        for block in iter(lambda: f.read(blocksize), b\"\"):\n            hash.update(block)\n    return hash.hexdigest()\n
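\n\nif __name__ == \"__main__\":\n    # Minimal usage sketch (an addition, not part of the original module): iterate\n    # over a few Bayer training pairs from an already-downloaded copy of the\n    # dataset. \"data\" is a placeholder root path, and the loader assumes all\n    # patches share the same spatial size.\n    from torch.utils.data import DataLoader\n\n    dset = Dataset(\"data\", download=False, mode=BAYER_MODE, subset=TRAIN_SUBSET)\n    loader = DataLoader(dset, batch_size=4, num_workers=0)\n    for mosaic, target in loader:\n        # Both tensors have shape [batch, 3, h, w] with values in [0, 1].\n        print(\"mosaic:\", tuple(mosaic.shape), \"target:\", tuple(target.shape))\n        break\n"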
  },
  {
    "path": "demosaicnet/modules.py",
    "content": "\"\"\"Models for [Gharbi2016] Deep Joint demosaicking and denoising.\"\"\"\nimport os\nfrom collections import OrderedDict\nfrom pkg_resources import resource_filename\n\nimport numpy as np\nimport torch as th\nimport torch.nn as nn\n\n\n__all__ = [\"BayerDemosaick\", \"XTransDemosaick\"]\n\n\n_BAYER_WEIGHTS = resource_filename(__name__, 'data/bayer.pth')\n_XTRANS_WEIGHTS = resource_filename(__name__, 'data/xtrans.pth')\n\n\nclass BayerDemosaick(nn.Module):\n  \"\"\"Released version of the network, best quality.\n\n  This model differs from the published description. It has a mask/filter split\n  towards the end of the processing. Masks and filters are multiplied with each\n  other. This is not key to performance and can be ignored when training new\n  models from scratch.\n  \"\"\"\n  def __init__(self, depth=15, width=64, pretrained=True, pad=False):\n    super(BayerDemosaick, self).__init__()\n\n    self.depth = depth\n    self.width = width\n\n    if pad:\n      pad = 1\n    else:\n      pad = 0\n\n    layers = OrderedDict([\n        (\"pack_mosaic\", nn.Conv2d(3, 4, 2, stride=2)),  # Downsample 2x2 to re-establish translation invariance\n      ])\n    for i in range(depth):\n      n_out = width\n      n_in = width\n      if i == 0:\n        n_in = 4\n      if i == depth-1:\n        n_out = 2*width\n      layers[\"conv{}\".format(i+1)] = nn.Conv2d(n_in, n_out, 3, padding=pad)\n      layers[\"relu{}\".format(i+1)] = nn.ReLU(inplace=True)\n\n    self.main_processor = nn.Sequential(layers)\n    self.residual_predictor = nn.Conv2d(width, 12, 1)\n    self.upsampler = nn.ConvTranspose2d(12, 3, 2, stride=2, groups=3)\n\n    self.fullres_processor = nn.Sequential(OrderedDict([\n      (\"post_conv\", nn.Conv2d(6, width, 3, padding=pad)),\n      (\"post_relu\", nn.ReLU(inplace=True)),\n      (\"output\", nn.Conv2d(width, 3, 1)),\n      ]))\n\n    # Load weights\n    if pretrained:\n      assert depth == 15, \"pretrained bayer model has depth=15.\"\n      assert width == 64, \"pretrained bayer model has width=64.\"\n      state_dict = th.load(_BAYER_WEIGHTS)\n      self.load_state_dict(state_dict)\n\n  def forward(self, mosaic):\n    \"\"\"Demosaicks a Bayer image.\n\n    Args:\n      mosaic (th.Tensor):  input Bayer mosaic\n\n    Returns:\n      th.Tensor: the demosaicked image\n    \"\"\"\n\n    # 1/4 resolution features\n    features = self.main_processor(mosaic)\n    filters, masks = features[:, 0:self.width], features[:, self.width:2*self.width]\n    filtered = filters * masks\n    residual = self.residual_predictor(filtered)\n\n    # Match mosaic and residual\n    upsampled = self.upsampler(residual)\n    cropped = _crop_like(mosaic, upsampled)\n\n    packed = th.cat([cropped, upsampled], 1)  # skip connection\n    output = self.fullres_processor(packed)\n    return output\n\n\nclass XTransDemosaick(nn.Module):\n  \"\"\"Released version of the network.\n\n  There is no downsampling here.\n\n  \"\"\"\n  def __init__(self, depth=11, width=64, pretrained=True, pad=False):\n    super(XTransDemosaick, self).__init__()\n\n    self.depth = depth\n    self.width = width\n\n    if pad:\n      pad = 1\n    else:\n      pad = 0\n\n    layers = OrderedDict([])\n    for i in range(depth):\n      n_in = width\n      n_out = width\n      if i == 0:\n        n_in = 3\n      layers[\"conv{}\".format(i+1)] = nn.Conv2d(n_in, n_out, 3, padding=pad)\n      layers[\"relu{}\".format(i+1)] = nn.ReLU(inplace=True)\n\n    self.main_processor = nn.Sequential(layers)\n\n    self.fullres_processor = 
nn.Sequential(OrderedDict([\n      (\"post_conv\", nn.Conv2d(3+width, width, 3, padding=pad)),\n      (\"post_relu\", nn.ReLU(inplace=True)),\n      (\"output\", nn.Conv2d(width, 3, 1)),\n      ]))\n\n    # Load weights\n    if pretrained:\n      assert depth == 11, \"pretrained xtrans model has depth=11.\"\n      assert width == 64, \"pretrained xtrans model has width=64.\"\n      state_dict = th.load(_XTRANS_WEIGHTS)\n      self.load_state_dict(state_dict)\n\n\n  def forward(self, mosaic):\n    \"\"\"Demosaicks an XTrans image.\n\n    Args:\n      mosaic (th.Tensor):  input XTrans mosaic\n\n    Returns:\n      th.Tensor: the demosaicked image\n    \"\"\"\n\n    features = self.main_processor(mosaic)\n    cropped = _crop_like(mosaic, features)  # Match mosaic and residual\n    packed = th.cat([cropped, features], 1)  # skip connection\n    output = self.fullres_processor(packed)\n    return output\n\n\ndef _crop_like(src, tgt):\n    \"\"\"Crop a source image to match the spatial dimensions of a target.\n\n    Args:\n        src (th.Tensor or np.ndarray): image to be cropped\n        tgt (th.Tensor or np.ndarray): reference image\n    \"\"\"\n    src_sz = np.array(src.shape)\n    tgt_sz = np.array(tgt.shape)\n\n    # Assumes the spatial dimensions are the last two\n    crop = (src_sz[-2:]-tgt_sz[-2:])\n    crop_t = crop[0] // 2\n    crop_b = crop[0] - crop_t\n    crop_l = crop[1] // 2\n    crop_r = crop[1] - crop_l\n    crop //= 2\n    if (np.array([crop_t, crop_b, crop_r, crop_l])> 0).any():\n        return src[..., crop_t:src_sz[-2]-crop_b, crop_l:src_sz[-1]-crop_r]\n    else:\n        return src\n\n"
  },
  {
    "path": "demosaicnet/mosaic.py",
    "content": "\"\"\"Utilities to make a mosaic mask and apply it to an image.\"\"\"\nimport numpy as np\nimport torch as th\n\n\n__all__ = [\"bayer\", \"xtrans\"]\n\n\ndef bayer(im, return_mask=False):\n  \"\"\"Bayer mosaic.\n\n  The patterned assumed is::\n\n    G r\n    b G\n\n  Args:\n    im (np.array): image to mosaic. Dimensions are [c, h, w]\n    return_mask (bool): if true return the binary mosaic mask, instead of the mosaic image.\n\n  Returns:\n    np.array: mosaicked image (if return_mask==False), or binary mask if (return_mask==True)\n  \"\"\"\n\n  numpy = False\n  if type(im) == np.ndarray:\n    numpy = True\n\n  if type(im) == np.ndarray:\n    mask = np.ones_like(im)\n  else:\n    mask = th.ones_like(im)\n\n  # red\n  mask[..., 0, ::2, 0::2] = 0\n  mask[..., 0, 1::2, :] = 0\n\n  # green\n  mask[..., 1, ::2, 1::2] = 0\n  mask[..., 1, 1::2, ::2] = 0\n\n  # blue\n  mask[..., 2, 0::2, :] = 0\n  mask[..., 2, 1::2, 1::2] = 0\n\n  if not numpy:  # make it a constant for ONNX conversion\n    mask = th.from_numpy(mask.cpu().detach().numpy()).to(im.device)\n\n  if mask.shape[0] == 1:\n    mask = mask.squeeze(0) # coreml hack\n\n  if return_mask:\n    return mask\n\n  return im*mask\n\n\ndef xtrans_cell(torch=False):\n  g_pos = [(0,0),        (0,2), (0,3),        (0,5),\n                  (1,1),               (1,4),\n           (2,0),        (2,2), (2,3),        (2,5),\n           (3,0),        (3,2), (3,3),        (3,5),\n                  (4,1),               (4,4),\n           (5,0),        (5,2), (5,3),        (5,5)]\n  r_pos = [(0,4),\n           (1,0), (1,2),\n           (2,4),\n           (3,1),\n           (4,3), (4,5),\n           (5,1)]\n  b_pos = [(0,1),\n           (1,3), (1,5),\n           (2,1),\n           (3,4),\n           (4,0), (4,2),\n           (5,4)]\n\n  if torch:\n    mask = th.zeros(3, 6, 6)\n  else:\n    mask = np.zeros((3, 6, 6), dtype=np.float32)\n\n  for idx, coord in enumerate([r_pos, g_pos, b_pos]):\n    for y, x in coord:\n      mask[..., idx, y, x] = 1\n\n  return mask\n\ndef xtrans(im, return_mask=False):\n  \"\"\"XTrans Mosaick.\n\n   The patterned assumed is::\n\n     G b G G r G\n     r G r b G b\n     G b G G r G\n     G r G G b G\n     b G b r G r\n     G r G G b G\n\n  Args:\n    im(np.array, th.Tensor): image to mosaic. Dimensions are [c, h, w]\n    mask(bool): if true return the binary mosaic mask, instead of the mosaic image.\n\n  Returns:\n    np.array: mosaicked image (if mask==False), or binary mask if (mask==True)\n  \"\"\"\n\n  numpy = False\n  if type(im) == np.ndarray:\n    numpy = True\n    mask = xtrans_cell(torch=False)\n    # mask = np.zeros((3, 6, 6), dtype=np.float32)\n  else:\n    # mask = th.zeros(3, 6, 6).to(im.device)\n    mask = xtrans_cell(torch=True).to(im.device)\n    if len(im.shape) == 4:\n      mask = mask.unsqueeze(0).repeat(im.shape[0], 1, 1, 1)\n\n  h, w = im.shape[-2:]\n  h = int(h)\n  w = int(w)\n\n  new_sz = [np.ceil(h / 6).astype(np.int32), np.ceil(w / 6).astype(np.int32)]\n\n  sz = np.array(mask.shape)\n  sz[:-2] = 1\n  sz[-2:] = new_sz\n  sz = list(sz)\n\n  if numpy:\n    mask = np.tile(mask, sz)\n  else:\n    mask = mask.repeat(*sz)\n\n  if return_mask:\n    return mask\n\n  return mask*im\n"
  },
  {
    "path": "demosaicnet/utils.py",
    "content": "\"\"\"Helper functions.\"\"\"\n\nfrom abc import ABCMeta, abstractmethod\nimport argparse\nimport logging\nimport os\nimport re\nimport signal\nimport time\n\nimport torch as th\nimport numpy as np\nimport torch as th\nfrom tqdm import tqdm\n\n\nlog = logging.getLogger(__name__)\n\n\ndef crop_like(src, tgt):\n    \"\"\"Crop a source image to match the spatial dimensions of a target.\n\n    Assumes sizes are even.\n\n    Args:\n        src (th.Tensor or np.ndarray): image to be cropped\n        tgt (th.Tensor or np.ndarray): reference image\n    \"\"\"\n    src_sz = np.array(src.shape)\n    tgt_sz = np.array(tgt.shape)\n\n    # Assumes the spatial dimensions are the last two\n    delta = (src_sz[2:4]-tgt_sz[2:4])\n    crop = np.maximum(delta // 2, 0)  # no negative crop\n    crop2 = delta - crop\n\n    if (crop > 0).any() or (crop2 > 0).any():\n        # NOTE: convert to ints to enable static slicing in ONNX conversion\n        src_sz = [int(x) for x in src_sz]\n        crop = [int(x) for x in crop]\n        crop2 = [int(x) for x in crop2]\n        return src[..., crop[0]:src_sz[-2]-crop2[0],\n                   crop[1]:src_sz[-1]-crop2[1]]\n    else:\n        return src\n\n\nclass ExponentialMovingAverage(object):\n    \"\"\"Keyed tracker that maintains an exponential moving average for each key.\n\n    Args:\n      keys(list of str): keys to track.\n      alpha(float): exponential smoothing factor (higher = smoother).\n    \"\"\"\n\n    def __init__(self, keys, alpha=0.999):\n        self._is_first_update = {k: True for k in keys}\n        self._alpha = alpha\n        self._values = {k: 0 for k in keys}\n\n    def __getitem__(self, key):\n        return self._values[key]\n\n    def update(self, key, value):\n        if value is None:\n            return\n        if self._is_first_update[key]:\n            self._values[key] = value\n            self._is_first_update[key] = False\n        else:\n            self._values[key] = self._values[key] * \\\n                self._alpha + value*(1.0-self._alpha)\n\n\nclass BasicArgumentParser(argparse.ArgumentParser):\n    \"\"\"A basic argument parser with commonly used training options.\"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super(BasicArgumentParser, self).__init__(*args, **kwargs)\n\n        self.add_argument(\"--data\", required=True, help=\"path to the training data.\")\n        self.add_argument(\"--val_data\", help=\"path to the validation data.\")\n        self.add_argument(\"--config\", help=\"path to a config file.\")\n        self.add_argument(\"--checkpoint_dir\", required=True,\n                          help=\"Output directory where checkpoints are saved\")\n        self.add_argument(\"--init_from\", help=\"path to a checkpoint from which to try and initialize the weights.\")\n\n        self.add_argument(\"--lr\", type=float, default=1e-4,\n                          help=\"Learning rate for the optimizer\")\n        self.add_argument(\"--bs\", type=int, default=4, help=\"Batch size\")\n        self.add_argument(\"--num_epochs\", type=int,\n                          help=\"Number of epochs to train for\")\n        self.add_argument(\"--num_worker_threads\", default=4, type=int,\n                          help=\"Number of threads that load data\")\n\n        # self.add_argument(\"--experiment_log\",\n        #                   help=\"csv file in which we log our experiments\")\n\n        self.add_argument(\"--cuda\", action=\"store_true\",\n                          dest=\"cuda\", help=\"Force 
GPU\")\n        self.add_argument(\"--no-cuda\", action=\"store_false\",\n                          dest=\"cuda\", help=\"Force CPU\")\n\n        self.add_argument(\"--server\", help=\"Visdom server url\")\n        self.add_argument(\"--base_url\", default=\"/\", help=\"Visdom base url\")\n        self.add_argument(\"--env\", default=\"main\", help=\"Visdom environment\")\n        self.add_argument(\"--port\", default=8097, type=int,\n                          help=\"Visdom server port\")\n\n        self.add_argument('--debug', dest=\"debug\", action=\"store_true\")\n\n        self.set_defaults(cuda=th.cuda.is_available(), debug=False)\n\n\nclass ModelInterface(metaclass=ABCMeta):\n    \"\"\"An adapter to run or train a model.\"\"\"\n\n    def __init__(self):\n        pass\n\n    @abstractmethod\n    def training_step(self, batch):\n        \"\"\"Training step given a batch of data.\n\n        This should implement a forward pass of the model, compute gradients,\n        take an optimizer step and return useful metrics and tensors for\n        visualization and training callbacks. \n\n        Args:\n          batch (dict): batch of data provided by a data pipeline.\n\n        Returns:\n          train_step_data (dict): a dictionary of outputs.\n        \"\"\"\n        return {}\n\n    def init_validation(self):\n        \"\"\"Initializes the quantities to be reported during validation.\n\n        The default implementation is a no-op\n\n        Returns:\n          data (dict): initialized values\n        \"\"\"\n        log.warning(\"Running a ModelInterface validation initialization that was not overriden: this is a no-op.\")\n        data = {}\n        return data\n\n    def validation_step(self, batch, running_val_data):\n        \"\"\"Updates the running validataion with the current batch's results.\n\n        The default implementation is a no-op\n\n        Args:\n          batch (dict): batch of data provided by a data pipeline.\n          running_val_data (dict): current aggregates of the validation loop.\n\n        Returns:\n          updated_data (dict): new updated value for the running_val_data.\n        \"\"\"\n        log.warning(\"Running a ModelInterface validation step that was not overriden: this is a no-op.\")\n        return {}\n\n    def __repr__(self):\n        return self.__class__.__name__\n\n\nclass Checkpointer(object):\n    \"\"\"Save and restore model and optimizer variables.\n\n    Args:\n      root (string): path to the root directory where the files are stored.\n      model (torch.nn.Module):\n      meta (dict): a dictionary of training or configuration parameters useful\n          to initialize the model upon loading the checkpoint again.\n      optimizers (single or list of torch.optimizer): optimizers whose parameters will\n        be checkpointed together with the model.\n      schedulers (single or list of\n      torch.optim.lr_scheduler): schedulers whose\n          parameters will be checkpointed together with\n          the model.\n      prefix (str): unique prefix name in case several models are stored in the\n        same folder.\n    \"\"\"\n\n    EXTENSION = \".pth\"\n\n    def __init__(self, root, model=None, meta=None, optimizers=None,\n                 schedulers=None, prefix=None):\n        self.root = root\n        self.model = model\n        self.meta = meta\n\n        # TODO(mgharbi): verify the prefixes are unique.\n\n        if optimizers is None:\n            log.info(\"No optimizer state will be stored in the \"\n                        
\"checkpointer\")\n        else:\n            # if we have only one optimizer, make it a list\n            if not isinstance(optimizers, list):\n                optimizers = [optimizers]\n        self.optimizers = optimizers\n        if schedulers is not None:\n            if not isinstance(schedulers, list):\n                schedulers = [schedulers]\n        self.schedulers = schedulers\n\n        log.debug(self)\n\n        self.prefix = \"\"\n        if prefix is not None:\n            self.prefix = prefix\n\n    def __repr__(self):\n        return \"Checkpointer with root at \\\"{}\\\"\".format(self.root)\n\n    def __path(self, path, prefix=None):\n        if prefix is None:\n            prefix = \"\"\n        return os.path.join(self.root, prefix+os.path.splitext(path)[0] + \".pth\")\n\n    def save(self, path, extras=None):\n        \"\"\"Save model, metaparams and extras to relative path.\n\n        Args:\n          path (string): relative path to the file being saved (without extension).\n          extras (dict): extra user-provided information to be saved with the model.\n        \"\"\"\n\n        if self.model is None:\n            model_state = None\n        else:\n            log.debug(\"Saving model state dict\")\n            model_state = self.model.state_dict()\n\n        opt_dicts = []\n        if self.optimizers is not None:\n            for opt in self.optimizers:\n                opt_dicts.append(opt.state_dict())\n\n        sched_dicts = []\n        if self.schedulers is not None:\n            for s in self.schedulers:\n                sched_dicts.append(s.state_dict())\n\n        filename = self.__path(path, prefix=self.prefix)\n        os.makedirs(self.root, exist_ok=True)\n        th.save({'model': model_state,\n                 'meta': self.meta,\n                 'extras': extras,\n                 'optimizers': opt_dicts,\n                 'schedulers': sched_dicts,\n                 }, filename)\n        log.debug(\"Checkpoint saved to \\\"{}\\\"\".format(filename))\n\n    def try_and_init_from(self, path):\n        \"\"\"Try to initialize the models's weights from an external checkpoint.\n\n        Args:\n            path(str): full path to the checkpoints to load model parameters\n                from.\n        \"\"\"\n        log.info(\"Loading weights from foreign checkpoint {}\".format(path))\n        if not os.path.exists(path):\n            raise ValueError(\"Checkpoint {} does not exist\".format(path))\n\n        chkpt = th.load(path, map_location=th.device(\"cpu\"))\n        if \"model\" not in chkpt.keys() or chkpt[\"model\"] is None:\n            raise ValueError(\"{} has no model saved\".format(path))\n\n        mdl = chkpt[\"model\"]\n        for n, p in self.model.named_parameters():\n            if n in mdl:\n                p2 = mdl[n]\n                if p2.shape != p.shape:\n                    log.warning(\"Parameter {} ignored, checkpoint size does not match: {}, should be {}\".format(n, p2.shape, p.shape))\n                    continue\n                log.debug(\"Parameter {} copied\".format(n))\n                p.data.copy_(p2)\n            else:\n                log.warning(\"Parameter {} ignored, not found in source checkpoint.\".format(n))\n\n        log.info(\"Weights loaded from foreign checkpoint {}\".format(path))\n\n    def load(self, path):\n        \"\"\"Loads a checkpoint, updates the model and returns extra data.\n\n        Args:\n          path (string): path to the checkpoint file, relative to the root dir.\n\n        
Returns:\n          extras (dict): extra information passed by the user at save time.\n          meta (dict): metaparameters of the model passed at save time.\n        \"\"\"\n\n        filename = self.__path(path, prefix=None)\n        chkpt = th.load(filename, map_location=\"cpu\")  # TODO: check behavior\n\n        if self.model is not None and chkpt[\"model\"] is not None:\n            log.debug(\"Loading model state dict\")\n            self.model.load_state_dict(chkpt[\"model\"])\n\n        if \"optimizers\" in chkpt.keys():\n            if self.optimizers is not None and chkpt[\"optimizers\"] is not None:\n                try:\n                    for opt, state in zip(self.optimizers,\n                                          chkpt[\"optimizers\"]):\n                        log.debug(\"Loading optimizers state dict for %s\", opt)\n                        opt.load_state_dict(state)\n                except:\n                    # We do not raise an error here, e.g. in case the user simply\n                    # changes optimizer\n                    log.warning(\"Could not load optimizer state dicts, \"\n                                \"starting from scratch\")\n\n        if \"schedulers\" in chkpt.keys():\n            if self.schedulers is not None and chkpt[\"schedulers\"] is not None:\n                try:\n                    for s, state in zip(self.schedulers,\n                                          chkpt[\"schedulers\"]):\n                        log.debug(\"Loading scheduler state dict for %s\", s)\n                        s.load_state_dict(state)\n                except:\n                    log.warning(\"Could not load scheduler state dicts, \"\n                                \"starting from scratch\")\n\n        log.debug(\"Loaded checkpoint \\\"{}\\\"\".format(filename))\n        return tuple(chkpt[k] for k in [\"extras\", \"meta\"])\n\n    def load_latest(self):\n        \"\"\"Try to load the most recent checkpoint, skip failing files.\n\n        Returns:\n          extras (dict): extra user-defined information that was saved in the\n              checkpoint.\n          meta (dict): metaparameters of the model passed at save time.\n        \"\"\"\n        all_checkpoints = self.sorted_checkpoints()\n\n        extras = None\n        meta = None\n        for f in all_checkpoints:\n            try:\n                extras, meta = self.load(f)\n                return extras, meta\n            except Exception as e:\n                log.debug(\n                    \"Could not load checkpoint \\\"{}\\\", moving on ({}).\".format(f, e))\n        log.debug(\"No checkpoint found to load.\")\n        return extras, meta\n\n    def sorted_checkpoints(self):\n        \"\"\"Get list of all checkpoints in root directory, sorted by creation date.\n\n        Returns:\n            chkpts (list of str): sorted checkpoints in the root folder.\n        \"\"\"\n        reg = re.compile(r\"{}.*\\{}\".format(self.prefix, Checkpointer.EXTENSION))\n        if not os.path.exists(self.root):\n            all_checkpoints = []\n        else:\n            all_checkpoints = [f for f in os.listdir(\n                self.root) if reg.match(f)]\n        mtimes = []\n        for f in all_checkpoints:\n            mtimes.append(os.path.getmtime(os.path.join(self.root, f)))\n\n        mf = sorted(zip(mtimes, all_checkpoints))\n        chkpts = [m[1] for m in reversed(mf)]\n        log.debug(\"Sorted checkpoints {}\".format(chkpts))\n        return chkpts\n\n    def delete(self, path):\n        
\"\"\"Delete checkpoint at path.\n\n        Args:\n            path(str): full path to the checkpoint to delete.\n        \"\"\"\n        if path in self.sorted_checkpoints():\n            os.remove(os.path.join(self.root, path))\n        else:\n            log.warning(\"Trying to delete a checkpoint that does not exists.\")\n\n    @staticmethod\n    def load_meta(root, prefix=None):\n        \"\"\"Fetch model metadata without touching the saved parameters.\n\n        This loads the metadata from the most recent checkpoint in the root\n        directory.\n\n        Args:\n            root(str): path to the root directory containing the checkpoints\n            prefix(str): unique prefix for the checkpoint to be loaded (e.g. if\n                multiple models are saved in the same location)\n        \"\"\"\n        chkptr = Checkpointer(root, model=None, meta=None, prefix=prefix, \n                              optimizers=[])\n        log.debug(\"checkpoints: %s\", chkptr.sorted_checkpoints())\n        _, meta = chkptr.load_latest()\n        return meta\n\n\nclass Trainer(object):\n    \"\"\"Implements a simple training loop with hooks for callbacks.\n\n    Args:\n      interface (ModelInterface): adapter to run forward and backward\n        pass on the model being trained.\n\n    Attributes:\n      callbacks (list of Callbacks): hooks that will be called while training\n        progresses.\n    \"\"\"\n\n    def __init__(self, interface):\n        super(Trainer, self).__init__()\n        self.callbacks = []\n        self.interface = interface\n        log.debug(\"Creating {}\".format(self))\n\n        signal.signal(signal.SIGINT, self.interrupt_handler)\n\n        self._keep_running = True\n\n    def interrupt_handler(self, signo, frame):\n        \"\"\"Stop the training process upon receiving a SIGINT (Ctrl+C).\"\"\"\n        log.debug(\"interrupting run\")\n        self._keep_running = False\n\n    def _stop(self):\n        # Reset the run flag\n        self._keep_running = True\n        self.__training_end()\n\n    def add_callback(self, callback):\n        \"\"\"Adds a callback to the list of training hooks.\n\n        Args:\n            callback(ttools.Callback): callback to add.\n        \"\"\"\n        log.debug(\"Adding callback {}\".format(callback))\n        # pass an interface reference to the callback\n        callback.model_interface = self.interface\n        self.callbacks.append(callback)\n\n    def train(self, dataloader, starting_epoch=None, num_epochs=None,\n              val_dataloader=None):\n        \"\"\"Main training loop. 
This starts the training procedure.\n\n        Args:\n          dataloader (DataLoader): loader that yields training batches.\n          starting_epoch (int, optional): index of the epoch we are starting from.\n          num_epochs (int, optional): max number of epochs to run.\n          val_dataloader (DataLoader, optional): loader that yields validation\n            batches\n        \"\"\"\n        self.__training_start(dataloader)\n        if starting_epoch is None:\n            starting_epoch = 0\n\n        log.info(\"Starting taining from epoch %d\", starting_epoch)\n\n        epoch = starting_epoch\n\n        while num_epochs is None or epoch < starting_epoch + num_epochs:\n            self.__epoch_start(epoch)\n            for batch_idx, batch in enumerate(dataloader):\n                if not self._keep_running:\n                    self._stop()\n                    return\n                self.__batch_start(batch_idx, batch)\n                train_step_data = self.__training_step(batch)\n                self.__batch_end(batch, train_step_data)\n            self.__epoch_end()\n\n            # TODO: allow validation at intermediate steps during one epoch\n\n            # Validate\n            if val_dataloader:\n                with th.no_grad():\n                    running_val_data = self.__validation_start(val_dataloader)\n                    for batch_idx, batch in enumerate(val_dataloader):\n                        if not self._keep_running:\n                            self._stop()\n                            return\n                        self.__val_batch_start(batch_idx, batch)\n                        running_val_data = self.__validation_step(batch, running_val_data)\n                        self.__val_batch_end(batch, running_val_data)\n                    self.__validation_end(running_val_data)\n\n            epoch += 1\n\n            if not self._keep_running:\n                self._stop()\n                return\n\n        self._stop()\n\n    def __repr__(self):\n        return \"Trainer({}, {} callbacks)\".format(\n            self.interface, len(self.callbacks))\n\n    def __training_start(self, dataloader):\n        for cb in self.callbacks:\n            cb.training_start(dataloader)\n\n    def __training_end(self):\n        for cb in self.callbacks:\n            cb.training_end()\n\n    def __epoch_start(self, epoch_idx):\n        for cb in self.callbacks:\n            cb.epoch_start(epoch_idx)\n\n    def __epoch_end(self):\n        for cb in self.callbacks:\n            cb.epoch_end()\n\n    def __batch_start(self, batch_idx, batch):\n        for cb in self.callbacks:\n            cb.batch_start(batch_idx, batch)\n\n    def __batch_end(self, batch, train_step_data):\n        for cb in self.callbacks:\n            cb.batch_end(batch, train_step_data)\n\n    def __val_batch_start(self, batch_idx, batch):\n        for cb in self.callbacks:\n            cb.val_batch_start(batch_idx, batch)\n\n    def __val_batch_end(self, batch, running_val_data):\n        for cb in self.callbacks:\n            cb.val_batch_end(batch, running_val_data)\n\n    def __validation_start(self, dataloader):\n        for cb in self.callbacks:\n            cb.validation_start(dataloader)\n        return self.interface.init_validation()\n\n    def __validation_end(self, running_val_data):\n        for cb in self.callbacks:\n            cb.validation_end(running_val_data)\n\n    def __training_step(self, batch):\n        return self.interface.training_step(batch)\n\n    def 
\n        return self.interface.validation_step(batch, running_val_data)\n\n\nclass Callback(object):\n    \"\"\"Base class for all training callbacks.\n\n    Attributes:\n        epoch(int): current epoch index.\n        batch(int): current batch index.\n        datasize(int): number of batches in the training dataset.\n        val_datasize(int): number of batches in the validation dataset.\n        model_interface(ModelInterface): parent interface driving the training.\n    \"\"\"\n\n    def __repr__(self):\n        return self.__class__.__name__\n\n    def __init__(self):\n        super(Callback, self).__init__()\n        self.epoch = 0\n        self.batch = 0\n        self.val_batch = 0\n        self.datasize = 0\n        self.val_datasize = 0\n        self.model_interface = None\n\n    def training_start(self, dataloader):\n        \"\"\"Hook to execute code when the training begins.\n\n        Args:\n            dataloader(th.utils.data.DataLoader): a data loading class that\n            provides batches of data for training.\n        \"\"\"\n        self.datasize = len(dataloader)\n\n    def training_end(self):\n        \"\"\"Hook to execute code when the training ends.\"\"\"\n        pass\n\n    def epoch_start(self, epoch):\n        \"\"\"Hook to execute code when a new epoch starts.\n\n        Args:\n            epoch(int): index of the current epoch.\n\n        Note: self.epoch is never incremented. Instead, it should be set by the\n        caller.\n        \"\"\"\n        self.epoch = epoch\n\n    def epoch_end(self):\n        \"\"\"Hook to execute code when an epoch ends.\n\n        NOTE: self.epoch is not incremented. Instead it is set externally in\n        the `epoch_start` method.\n        \"\"\"\n        pass\n\n    def validation_start(self, dataloader):\n        \"\"\"Hook to execute code when a validation run starts.\n\n        Args:\n            dataloader(th.utils.data.DataLoader): a data loading class that\n            provides batches of data for evaluation.\n        \"\"\"\n        self.val_datasize = len(dataloader)\n\n    def validation_end(self, val_data):\n        \"\"\"Hook to execute code when a validation run ends.\"\"\"\n        pass\n\n    def batch_start(self, batch_idx, batch_data):\n        \"\"\"Hook to execute code when a training step starts.\n\n        Args:\n            batch_idx(int): index of the current batch.\n            batch_data: a Tensor, tuple or dict with the current batch of data.\n        \"\"\"\n        self.batch = batch_idx\n\n    def batch_end(self, batch_data, train_step_data):\n        \"\"\"Hook to execute code when a training step ends.\n\n        Args:\n            batch_data: a Tensor, tuple or dict with the current batch of data.\n            train_step_data(dict): outputs from the `training_step` of a\n                ModelInterface.\n        \"\"\"\n        pass\n\n    def val_batch_start(self, batch_idx, batch_data):\n        \"\"\"Hook to execute code when a validation step starts.\n\n        Args:\n            batch_idx(int): index of the current batch.\n            batch_data: a Tensor, tuple or dict with the current batch of data.\n        \"\"\"\n        self.val_batch = batch_idx\n\n    def val_batch_end(self, batch_data, running_val_data):\n        \"\"\"Hook to execute code when a validation step ends.\n\n        Args:\n            batch_data: a Tensor, tuple or dict with the current batch of data.\n            running_val_data(dict): running outputs produced by the `validation_step` of a
\n                ModelInterface.\n        \"\"\"\n        pass\n\n\nclass CheckpointingCallback(Callback):\n    \"\"\"A callback that periodically saves model checkpoints to disk.\n\n    Args:\n      checkpointer (Checkpointer): actual checkpointer responsible for the I/O.\n      interval (int, optional): minimum time in seconds between periodic\n          checkpoints (within an epoch). There is no periodic checkpoint if\n          this value is None.\n      max_files (int, optional): maximum number of periodic checkpoints to keep\n          on disk.\n      max_epochs (int, optional): maximum number of epoch checkpoints to keep\n          on disk.\n    \"\"\"\n\n    PERIODIC_PREFIX = \"periodic_\"\n    EPOCH_PREFIX = \"epoch_\"\n\n    def __init__(self, checkpointer, interval=600,\n                 max_files=5, max_epochs=10):\n        super(CheckpointingCallback, self).__init__()\n        self.checkpointer = checkpointer\n        self.interval = interval\n        self.max_files = max_files\n        self.max_epochs = max_epochs\n\n        self.last_checkpoint_time = time.time()\n\n    def epoch_end(self):\n        \"\"\"Save a checkpoint at the end of each epoch.\"\"\"\n        super(CheckpointingCallback, self).epoch_end()\n        path = \"{}{}\".format(CheckpointingCallback.EPOCH_PREFIX, self.epoch)\n        self.checkpointer.save(path, extras={\"epoch\": self.epoch + 1})\n        self.__purge_old_files()\n\n    def training_end(self):\n        super(CheckpointingCallback, self).training_end()\n        self.checkpointer.save(\"training_end\", extras={\"epoch\": self.epoch + 1})\n\n    def batch_end(self, batch_data, train_step_data):\n        \"\"\"Save a periodic checkpoint if requested.\"\"\"\n\n        super(CheckpointingCallback, self).batch_end(\n            batch_data, train_step_data)\n\n        if self.interval is None:  # We skip periodic checkpoints\n            return\n\n        now = time.time()\n\n        delta = now - self.last_checkpoint_time\n\n        if delta < self.interval:  # last checkpoint is too recent\n            return\n\n        log.debug(\"Periodic checkpoint\")\n        self.last_checkpoint_time = now\n\n        filename = \"{}{}\".format(CheckpointingCallback.PERIODIC_PREFIX,\n                                 time.strftime(\"%Y-%m-%d_%H-%M-%S\", time.localtime()))\n        self.checkpointer.save(filename, extras={\"epoch\": self.epoch})\n        self.__purge_old_files()\n\n    def __purge_old_files(self):\n        \"\"\"Delete checkpoints that are beyond the max to keep.\"\"\"\n\n        chkpts = self.checkpointer.sorted_checkpoints()\n        p_chkpts = []\n        e_chkpts = []\n        for c in chkpts:\n            if c.startswith(self.checkpointer.prefix + CheckpointingCallback.PERIODIC_PREFIX):\n                p_chkpts.append(c)\n\n            if c.startswith(self.checkpointer.prefix + CheckpointingCallback.EPOCH_PREFIX):\n                e_chkpts.append(c)\n\n        # Delete periodic checkpoints\n        if self.max_files is not None and len(p_chkpts) > self.max_files:\n            for c in p_chkpts[self.max_files:]:\n                log.debug(\"CheckpointingCallback deleting {}\".format(c))\n                self.checkpointer.delete(c)\n\n        # Delete older epochs\n        if self.max_epochs is not None and len(e_chkpts) > self.max_epochs:\n            for c in e_chkpts[self.max_epochs:]:\n                log.debug(\"CheckpointingCallback deleting (epoch) {}\".format(c))
\n                self.checkpointer.delete(c)\n\n\nclass KeyedCallback(Callback):\n    \"\"\"An abstract Callback that performs the same action for all keys in a list.\n\n    The keys (resp. val_keys) are used to access the backward_data (resp.\n    validation_data) produced by a ModelInterface.\n\n    Args:\n      keys (list of str or None): list of keys whose values will be logged during\n          training.\n      val_keys (list of str or None): list of keys whose values will be logged during\n          validation.\n    \"\"\"\n    def __init__(self, keys=None, val_keys=None, smoothing=0.999):\n        super(KeyedCallback, self).__init__()\n        if keys is None and val_keys is None:\n            log.warning(\"Logger has no keys nor val_keys\")\n\n        if keys is None:\n            self.keys = []\n        else:\n            self.keys = keys\n\n        if val_keys is None:\n            self.val_keys = []\n        else:\n            self.val_keys = val_keys\n\n        # Only smooth the training keys\n        self.ema = ExponentialMovingAverage(self.keys, alpha=smoothing)\n\n    def batch_end(self, batch_data, train_step_data):\n        for k in self.keys:\n            self.ema.update(k, train_step_data[k])\n\n\nclass ProgressBarCallback(KeyedCallback):\n    \"\"\"A progress bar optimization logger.\n\n    Args:\n        keys (list of str): training keys displayed in the progress bar postfix.\n        val_keys (list of str): validation keys reported when a validation run ends.\n        label(str): a prefix label to identify the experiment currently\n            running.\n    \"\"\"\n    def __init__(self, keys=None, val_keys=None, smoothing=0.99, label=None):\n        super(ProgressBarCallback, self).__init__(\n            keys=keys, val_keys=val_keys, smoothing=smoothing)\n        self.pbar = None\n        if label is None:\n            self.label = \"\"\n        else:\n            self.label = label\n\n    def training_start(self, dataloader):\n        super(ProgressBarCallback, self).training_start(dataloader)\n        print(\"Training start\")\n\n    def training_end(self):\n        super(ProgressBarCallback, self).training_end()\n        print(\"Training ends\")\n\n    def epoch_start(self, epoch):\n        super(ProgressBarCallback, self).epoch_start(epoch)\n        desc = \"Epoch {}\".format(self.epoch)\n        if self.label:\n            desc = \"%s | \" % self.label + desc\n        self.pbar = tqdm(total=self.datasize, unit=\" batches\",\n                         desc=desc)\n\n    def epoch_end(self):\n        super(ProgressBarCallback, self).epoch_end()\n        self.pbar.close()\n        self.pbar = None\n\n    def validation_start(self, dataloader):\n        super(ProgressBarCallback, self).validation_start(dataloader)\n        print(\"Running validation...\")\n        self.pbar = tqdm(total=len(dataloader), unit=\" batches\",\n                         desc=\"Validation {}\".format(self.epoch))\n\n    def val_batch_end(self, batch, running_val_data):\n        self.pbar.update(1)\n\n    def validation_end(self, val_data):\n        super(ProgressBarCallback, self).validation_end(val_data)\n        self.pbar.close()\n        self.pbar = None\n        s = \" \"*ProgressBarCallback.TABSTOPS + \"Validation {} | \".format(\n            self.epoch)\n        for k in self.val_keys:\n            s += \"{} = {:.2f} \".format(k, val_data[k])\n        print(s)\n\n    def batch_end(self, batch_data, train_step_data):\n        super(ProgressBarCallback, self).batch_end(batch_data, train_step_data)\n        d = {}\n        for k in self.keys:\n            d[k] = self.ema[k]\n        self.pbar.update(1)\n        self.pbar.set_postfix(d)\n\n    TABSTOPS = 2\n"
  },
  {
    "path": "demosaicnet/version.py",
    "content": "__version__ = \"0.0.14\"\n"
  },
  {
    "path": "docs/.gitignore",
    "content": "build\n"
  },
  {
    "path": "docs/Makefile",
    "content": "# Minimal makefile for Sphinx documentation\n#\n\n# You can set these variables from the command line.\nSPHINXOPTS    =\nSPHINXBUILD   = sphinx-build\nSOURCEDIR     = source\nBUILDDIR      = build\n\n# Put it first so that \"make\" without argument is like \"make help\".\nhelp:\n\t@$(SPHINXBUILD) -M help \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)\n\n.PHONY: help Makefile\n\n# Catch-all target: route all unknown targets to Sphinx using the new\n# \"make mode\" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).\n%: Makefile\n\t@$(SPHINXBUILD) -M $@ \"$(SOURCEDIR)\" \"$(BUILDDIR)\" $(SPHINXOPTS) $(O)"
  },
  {
    "path": "docs/source/conf.py",
    "content": "# -*- coding: utf-8 -*-\n#\n# Configuration file for the Sphinx documentation builder.\n#\n# This file does only contain a selection of the most common options. For a\n# full list see the documentation:\n# http://www.sphinx-doc.org/en/master/config\n\n# -- Path setup --------------------------------------------------------------\n\n# If extensions (or modules to document with autodoc) are in another directory,\n# add these directories to sys.path here. If the directory is relative to the\n# documentation root, use os.path.abspath to make it absolute, like shown here.\n#\nimport os\nimport sys\ndirname = os.path.dirname\nrootdir = dirname(dirname(dirname(os.path.abspath(__file__))))\nsys.path.insert(0, rootdir)\n\n\n# -- Project information -----------------------------------------------------\n\nproject = 'demosaicnet'\ncopyright = '2019, Michael Gharbi'\nauthor = 'Michael Gharbi'\n\nimport re\nwith open(os.path.join(rootdir, \"demosaicnet\", \"version.py\")) as fid:\n    try:\n        __version__, = re.findall( '__version__ = \"(.*)\"', fid.read() )\n    except:\n        raise ValueError(\"could not find version number\")\n\n# The full version, including alpha/beta/rc tags\nrelease = __version__\n\n\n# -- General configuration ---------------------------------------------------\n\n# If your documentation needs a minimal Sphinx version, state it here.\nautodoc_mock_imports = [\"torch\", \"numpy\", \"imageio\", \"torchvision\", \"wget\"]\n#\n# needs_sphinx = '1.0'\n\n# Add any Sphinx extension module names here, as strings. They can be\n# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom\n# ones.\nextensions = [\n    'sphinx.ext.autodoc',\n    'sphinx.ext.napoleon',\n    'sphinx.ext.doctest',\n    'sphinx.ext.todo',\n    'sphinx.ext.coverage',\n    'sphinx.ext.mathjax',\n    'sphinx.ext.viewcode',\n]\n\n# Add any paths that contain templates here, relative to this directory.\ntemplates_path = ['ytemplates']\n\n# The suffix(es) of source filenames.\n# You can specify multiple suffix as a list of string:\n#\n# source_suffix = ['.rst', '.md']\nsource_suffix = '.rst'\n\n# The master toctree document.\nmaster_doc = 'index'\n\n# The language for content autogenerated by Sphinx. Refer to documentation\n# for a list of supported languages.\n#\n# This is also used if you do content translation via gettext catalogs.\n# Usually you set \"language\" from the command line for these cases.\nlanguage = None\n\n# List of patterns, relative to source directory, that match files and\n# directories to ignore when looking for source files.\n# This pattern also affects html_static_path and html_extra_path.\nexclude_patterns = []\n\n# The name of the Pygments (syntax highlighting) style to use.\npygments_style = None\n\n\n# -- Options for HTML output -------------------------------------------------\n\n# The theme to use for HTML and HTML Help pages.  See the documentation for\n# a list of builtin themes.\n#\nhtml_theme = 'alabaster'\n\n# Theme options are theme-specific and customize the look and feel of a theme\n# further.  For a list of options available for each theme, see the\n# documentation.\n#\n# html_theme_options = {}\n\n# Add any paths that contain custom static files (such as style sheets) here,\n# relative to this directory. 
They are copied after the builtin static files,\n# so a file named \"default.css\" will overwrite the builtin \"default.css\".\nhtml_static_path = ['ystatic']\n\n# Custom sidebar templates, must be a dictionary that maps document names\n# to template names.\n#\n# The default sidebars (for documents that don't match any pattern) are\n# defined by theme itself.  Builtin themes are using these templates by\n# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',\n# 'searchbox.html']``.\n#\n# html_sidebars = {}\n\n\n# -- Options for HTMLHelp output ---------------------------------------------\n\n# Output file base name for HTML help builder.\nhtmlhelp_basename = 'demosaicnetdoc'\n\n\n# -- Options for LaTeX output ------------------------------------------------\n\nlatex_elements = {\n    # The paper size ('letterpaper' or 'a4paper').\n    #\n    # 'papersize': 'letterpaper',\n\n    # The font size ('10pt', '11pt' or '12pt').\n    #\n    # 'pointsize': '10pt',\n\n    # Additional stuff for the LaTeX preamble.\n    #\n    # 'preamble': '',\n\n    # Latex figure (float) alignment\n    #\n    # 'figure_align': 'htbp',\n}\n\n# Grouping the document tree into LaTeX files. List of tuples\n# (source start file, target name, title,\n#  author, documentclass [howto, manual, or own class]).\nlatex_documents = [\n    (master_doc, 'demosaicnet.tex', 'demosaicnet Documentation',\n     'Michael Gharbi', 'manual'),\n]\n\n\n# -- Options for manual page output ------------------------------------------\n\n# One entry per manual page. List of tuples\n# (source start file, name, description, authors, manual section).\nman_pages = [\n    (master_doc, 'demosaicnet', 'demosaicnet Documentation',\n     [author], 1)\n]\n\n\n# -- Options for Texinfo output ----------------------------------------------\n\n# Grouping the document tree into Texinfo files. List of tuples\n# (source start file, target name, title, author,\n#  dir menu entry, description, category)\ntexinfo_documents = [\n    (master_doc, 'demosaicnet', 'demosaicnet Documentation',\n     author, 'demosaicnet', 'One line description of project.',\n     'Miscellaneous'),\n]\n\n\n# -- Options for Epub output -------------------------------------------------\n\n# Bibliographic Dublin Core info.\nepub_title = project\n\n# The unique identifier of the text. This can be a ISBN number\n# or the project homepage.\n#\n# epub_identifier = ''\n\n# A unique identification for the text.\n#\n# epub_uid = ''\n\n# A list of files that should not be packed into the epub file.\nepub_exclude_files = ['search.html']\n\n\n# -- Extension configuration -------------------------------------------------\n\n# -- Options for todo extension ----------------------------------------------\n\n# If true, `todo` and `todoList` produce output, else they produce nothing.\ntodo_include_todos = True\n"
  },
  {
    "path": "docs/source/dataset.rst",
    "content": "Dataset\n=======\n\n.. automodule:: demosaicnet.dataset\n   :members:\n"
  },
  {
    "path": "docs/source/helpers.rst",
    "content": "Helpers\n=======\n\n.. automodule:: demosaicnet.mosaic\n   :members:\n"
  },
  {
    "path": "docs/source/index.rst",
    "content": ".. demosaicnet documentation master file, created by\n   sphinx-quickstart on Thu Mar 14 13:14:16 2019.\n   You can adapt this file completely to your liking, but it should at least\n   contain the root `toctree` directive.\n\nWelcome to demosaicnet's documentation!\n=======================================\n\n.. toctree::\n   :maxdepth: 2\n   :caption: Contents:\n\n   models\n   dataset\n   helpers\n\n.. automodule:: demosaicnet\n   :members:\n\n\nIndices and tables\n==================\n\n* :ref:`genindex`\n* :ref:`modindex`\n* :ref:`search`\n"
  },
  {
    "path": "docs/source/models.rst",
    "content": "Models\n======\n\n.. automodule:: demosaicnet.modules\n   :members:\n"
  },
  {
    "path": "requirements.txt",
    "content": "imageio==2.31.1\nnumpy==1.25.0\nsetuptools==49.2.1\ntorch==2.0.1\ntqdm==4.65.0\nwget\n"
  },
  {
    "path": "scripts/demosaicnet_demo.py",
    "content": "#!/usr/bin/env python\n\"\"\"Demo script on using demosaicnet for inference.\"\"\"\n\nimport os\nfrom pkg_resources import resource_filename\n\nimport argparse\nimport numpy as np\nimport torch as th\nimport imageio\n\nimport demosaicnet\n\n_TEST_INPUT = resource_filename(\"demosaicnet\", 'data/test_input.png')\n\ndef main(args):\n  print(\"Running demosaicnet demo on {}, outputing to {}\".format(_TEST_INPUT, args.output))\n  bayer = demosaicnet.BayerDemosaick()\n  xtrans = demosaicnet.XTransDemosaick()\n\n  # Load some ground-truth image\n  gt = imageio.imread(args.input).astype(np.float32) / 255.0\n  gt = np.array(gt)\n\n  h, w, _ = gt.shape\n\n  # Make the image size a multiple of 6 (for xtrans pattern)\n  gt = gt[:6*(h//6), :6*(w//6)]\n\n\n  # Network expects channel first\n  gt = np.transpose(gt, [2, 0, 1])\n  mosaicked = demosaicnet.bayer(gt)\n  xmosaicked = demosaicnet.xtrans(gt)\n\n  # Run the model (expects batch as first dimension)\n  mosaicked = th.from_numpy(mosaicked).unsqueeze(0)\n  xmosaicked = th.from_numpy(xmosaicked).unsqueeze(0)\n  with th.no_grad():  # inference only\n    out = bayer(mosaicked).squeeze(0).cpu().numpy()\n    out = np.clip(out, 0, 1)\n    xout = xtrans(xmosaicked).squeeze(0).cpu().numpy()\n    xout = np.clip(xout, 0, 1)\n  print(\"done\")\n\n  os.makedirs(args.output, exist_ok=True)\n  output = args.output\n\n  imageio.imsave(os.path.join(output, \"bayer_mosaick.tif\"), mosaicked.squeeze(0).permute([1, 2, 0]))\n  imageio.imsave(os.path.join(output, \"bayer_result.tif\"), np.transpose(out, [1, 2, 0]))\n  imageio.imsave(os.path.join(output, \"xtrans_mosaick.tif\"), xmosaicked.squeeze(0).permute([1, 2, 0]))\n  imageio.imsave(os.path.join(output, \"xtrans_result.tif\"), np.transpose(xout, [1, 2, 0]))\n\n  \nif __name__ == \"__main__\":\n  parser = argparse.ArgumentParser()\n  parser.add_argument(\"output\", help=\"output directory\")\n  parser.add_argument(\"--input\", default=_TEST_INPUT, help=\"test input, uses the default test input provided if no argument.\")\n  args = parser.parse_args()\n  main(args)\n  \n"
  },
  {
    "path": "scripts/eval.py",
    "content": "#!/bin/env python\n\"\"\"Evaluate a demosaicking model.\"\"\"\nimport argparse\nimport logging\n\nimport torch as th\nfrom torch.utils.data import DataLoader\n\nimport demosaicnet\n\n\nlog = logging.getLogger(__name__)\n\nclass PSNR(th.nn.Module):\n    def __init__(self):\n        super(PSNR, self).__init__()\n        self.mse = th.nn.MSELoss()\n    def forward(self, out, ref):\n        mse = self.mse(out, ref)\n        return -10*th.log10(mse)\n\ndef main(args):\n    \"\"\"Entrypoint to the training.\"\"\"\n\n    # Load model parameters from checkpoint, if any\n    meta = demosaicnet.utils.Checkpointer.load_meta(args.checkpoint_dir)\n    if meta is None:\n        log.warning(\"No checkpoint found at %s, aborting.\", args.checkpoint_dir)\n        return\n\n    data = demosaicnet.Dataset(args.data, download=False,\n                               mode=meta[\"mode\"],\n                               subset=demosaicnet.TEST_SUBSET)\n    dataloader = DataLoader(\n        data, batch_size=1, num_workers=4, pin_memory=True, shuffle=True)\n\n    if meta[\"mode\"] == demosaicnet.BAYER_MODE:\n        model = demosaicnet.BayerDemosaick(depth=meta[\"depth\"],\n                                           width=meta[\"width\"],\n                                           pretrained=True,\n                                           pad=False)\n    elif meta[\"mode\"] == demosaicnet.XTRANS_MODE:\n        model = demosaicnet.XTransDemosaick(depth=meta[\"depth\"],\n                                            width=meta[\"width\"],\n                                            pretrained=True,\n                                            pad=False)\n\n    checkpointer = demosaicnet.utils.Checkpointer(args.checkpoint_dir, model, meta=meta)\n    checkpointer.load_latest()  # Resume from checkpoint, if any.\n\n    # No need for gradients\n    for p in model.parameters():\n        p.requires_grad = False\n\n    mse_fn = th.nn.MSELoss()\n    psnr_fn = PSNR()\n\n    device = \"cpu\"\n    if th.cuda.is_available():\n        device = \"cuda\"\n        log.info(\"Using CUDA\")\n\n    count = 0\n    mse = 0.0\n    psnr = 0.0\n    for idx, batch in enumerate(dataloader):\n        mosaic = batch[0].to(device)\n        target = batch[1].to(device)\n        output = model(mosaic)\n\n        target = demosaicnet.utils.crop_like(target, output)\n\n        output = th.clamp(output, 0, 1)\n\n        psnr_ = psnr_fn(output, target).item()\n        mse_ = mse_fn(output, target).item()\n\n        psnr += psnr_\n        mse += mse_\n        count += 1\n\n        log.info(\"Image %04d, PSNR = %.1f dB, MSE = %.5f\", idx, psnr_, mse_)\n\n    mse /= count\n    psnr /= count\n\n    log.info(\"-----------------------------------\")\n    log.info(\"Average, PSNR = %.1f dB, MSE = %.5f\", psnr, mse)\n\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"data\", help=\"root directory for the demosaicnet dataset.\")\n    parser.add_argument(\"checkpoint_dir\", help=\"directory with the model checkpoints.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "scripts/train.py",
    "content": "#!/bin/env python\n\"\"\"Train a demosaicking model.\"\"\"\nimport logging\n\nimport torch as th\nfrom torch.utils.data import DataLoader\n\nimport demosaicnet\n\n\nlog = logging.getLogger(__name__)\n\n\nclass PSNR(th.nn.Module):\n    def __init__(self):\n        super(PSNR, self).__init__()\n        self.mse = th.nn.MSELoss()\n    def forward(self, out, ref):\n        mse = self.mse(out, ref)\n        return -10*th.log10(mse+1e-12)\n\n\nclass DemosaicnetInterface(demosaicnet.utils.ModelInterface):\n    \"\"\"Training and validation interface.\n\n    Args:\n        model(th.nn.Module): model to train.\n        lr(float): learning rate for the optimizer.\n        cuda(bool): whether to use CPU or GPU for training.\n    \"\"\"\n    def __init__(self, model, lr=1e-4, cuda=th.cuda.is_available()):\n        self.model = model\n        self.device = \"cpu\"\n        if cuda:\n            self.device = \"cuda\"\n        self.model.to(self.device)\n        self.opt = th.optim.Adam(self.model.parameters(), lr=lr)\n        self.loss = th.nn.MSELoss()\n        self.psnr = PSNR()\n\n    def training_step(self, batch):\n        fwd_data = self.forward(batch)\n        bwd_data = self.backward(batch, fwd_data)\n        return bwd_data\n\n    def forward(self, batch):\n        mosaic = batch[0]\n        mosaic = mosaic.to(self.device)\n        output = self.model(mosaic)\n        return output\n\n    def backward(self, batch, fwd_output):\n        target = batch[1].to(self.device)\n\n        # remove boundaries to match output size\n        target = demosaicnet.utils.crop_like(target, fwd_output)\n\n        loss = self.loss(fwd_output, target)\n\n        self.opt.zero_grad()\n        loss.backward()\n        self.opt.step()\n\n        with th.no_grad():\n            psnr = self.psnr(th.clamp(fwd_output, 0, 1), target)\n\n        return {\"loss\": loss.item(), \"psnr\": psnr.item()}\n\n    def init_validation(self):\n        return {\"count\": 0, \"psnr\": 0}\n\n    def update_validation(self, batch, fwd_output, running_data):\n        target = batch[1].to(self.device)\n\n        # remove boundaries to match output size\n        target = demosaicnet.utils.crop_like(target, fwd_output)\n\n        with th.no_grad():\n            psnr = self.psnr(th.clamp(fwd_output, 0, 1), target)\n            n = target.shape[0]\n\n        return {\n            \"psnr\": running_data[\"psnr\"] + psnr.item()*n,\n            \"count\": running_data[\"count\"] + n\n        }\n\n    def finalize_validation(self, running_data):\n        return {\n            \"psnr\": running_data[\"psnr\"] / running_data[\"count\"]\n        }\n\n\ndef main(args):\n    \"\"\"Entrypoint to the training.\"\"\"\n\n    # Load model parameters from checkpoint, if any\n    meta = demosaicnet.utils.Checkpointer.load_meta(args.checkpoint_dir)\n    if meta is None:\n        log.info(\"No metadata or checkpoint, \"\n                 \"parsing model parameters from command line.\")\n        meta = {\n            \"depth\": args.depth,\n            \"width\": args.width,\n            \"mode\": args.mode,\n        }\n\n    data = demosaicnet.Dataset(args.data, download=False,\n                               mode=meta[\"mode\"],\n                               subset=demosaicnet.TRAIN_SUBSET)\n    dataloader = DataLoader(\n        data, batch_size=args.bs, num_workers=args.num_worker_threads,\n        pin_memory=True, shuffle=True)\n\n    val_dataloader = None\n    if args.val_data:\n        val_data = demosaicnet.Dataset(args.data, 
\n                                       mode=meta[\"mode\"],\n                                       subset=demosaicnet.VAL_SUBSET)\n        val_dataloader = DataLoader(\n            val_data, batch_size=args.bs, num_workers=1,\n            pin_memory=True, shuffle=False)\n\n    if meta[\"mode\"] == demosaicnet.BAYER_MODE:\n        model = demosaicnet.BayerDemosaick(depth=meta[\"depth\"],\n                                           width=meta[\"width\"],\n                                           pretrained=True,\n                                           pad=False)\n    elif meta[\"mode\"] == demosaicnet.XTRANS_MODE:\n        model = demosaicnet.XTransDemosaick(depth=meta[\"depth\"],\n                                            width=meta[\"width\"],\n                                            pretrained=True,\n                                            pad=False)\n    checkpointer = demosaicnet.utils.Checkpointer(\n        args.checkpoint_dir, model, meta=meta)\n\n    interface = DemosaicnetInterface(model, lr=args.lr, cuda=args.cuda)\n\n    checkpointer.load_latest()  # Resume from checkpoint, if any.\n\n    trainer = demosaicnet.utils.Trainer(interface)\n\n    keys = [\"loss\", \"psnr\"]\n    val_keys = [\"psnr\"]\n\n    trainer.add_callback(demosaicnet.utils.ProgressBarCallback(\n        keys=keys, val_keys=val_keys))\n    trainer.add_callback(demosaicnet.utils.CheckpointingCallback(\n        checkpointer, max_files=8, interval=3600, max_epochs=10))\n\n    if args.cuda:\n        log.info(\"Training with CUDA enabled\")\n    else:\n        log.info(\"Training on CPU\")\n\n    trainer.train(\n        dataloader, num_epochs=args.num_epochs,\n        val_dataloader=val_dataloader)\n\n\nif __name__ == \"__main__\":\n    parser = demosaicnet.utils.BasicArgumentParser()\n    parser.add_argument(\"--depth\", default=15, type=int,\n                        help=\"number of net layers.\")\n    parser.add_argument(\"--width\", default=64, type=int,\n                        help=\"number of features per layer.\")\n    parser.add_argument(\"--mode\", default=demosaicnet.BAYER_MODE,\n                        choices=[demosaicnet.BAYER_MODE,\n                                 demosaicnet.XTRANS_MODE],\n                        help=\"mosaic pattern to train for.\")\n    args = parser.parse_args()\n    main(args)\n"
  },
  {
    "path": "setup.py",
    "content": "import re\nimport setuptools\n\n\nwith open('demosaicnet/version.py') as fid:\n    try:\n        __version__, = re.findall( '__version__ = \"(.*)\"', fid.read() )\n    except:\n        raise ValueError(\"could not find version number\")\n\n\nwith open(\"README.md\", \"r\") as fh:\n    long_description = fh.read()\n\n\nsetuptools.setup(\n    name='demosaicnet',\n    version=__version__,\n    scripts=[\"scripts/demosaicnet_demo.py\"],\n    author=\"Michaël Gharbi\",\n    author_email=\"gharbi@csail.mit.edu\",\n    description=\"Minimal implementation of Deep Joint Demosaicking and Denoising [Gharbi2016]\",\n    long_description=long_description,\n    url=\"https://github.com/mgharbi/\",\n    packages = setuptools.find_packages(exclude=[\"tests\"]),\n    include_package_data=True,\n    install_requires=[\"wget\", \"tqdm\", \"torch\", \"imageio\", \"numpy\"],\n    classifiers=[\n      \"Programming Language :: Python :: 3\",\n      \"License :: OSI Approved :: MIT License\",\n      \"Operating System :: MacOS :: MacOS X\",\n      \"Operating System :: POSIX\",\n    ],\n)\n"
  }
]