Repository: mgharbi/demosaicnet Branch: master Commit: 959e9d163097 Files: 24 Total size: 82.4 KB Directory structure: gitextract_p3udv4rd/ ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── demosaicnet/ │ ├── .gitignore │ ├── __init__.py │ ├── dataset.py │ ├── modules.py │ ├── mosaic.py │ ├── utils.py │ └── version.py ├── docs/ │ ├── .gitignore │ ├── Makefile │ └── source/ │ ├── conf.py │ ├── dataset.rst │ ├── helpers.rst │ ├── index.rst │ └── models.rst ├── requirements.txt ├── scripts/ │ ├── demosaicnet_demo.py │ ├── eval.py │ └── train.py └── setup.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ output/ dist/ demosaicnet.egg-info build/ data .DS_Store ================================================ FILE: LICENSE ================================================ MIT License Deep Joint Demosaicking and Denoising Siggraph Asia 2016 Michael Gharbi, Gaurav Chaurasia, Sylvain Paris, Fredo Durand Copyright (c) 2016 Michael Gharbi Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: MANIFEST.in ================================================ include demosaicnet/data/bayer.pth include demosaicnet/data/xtrans.pth include demosaicnet/data/test_input.png ================================================ FILE: Makefile ================================================ test: py.test tests .PHONY: docs docs: $(MAKE) -C docs html clean: python setup.py clean rm -rf build demosaicnet.egg-info dist .pytest_cache distribution: python setup.py sdist bdist_wheel twine check dist/* test_upload: twine upload --repository-url https://test.pypi.org/legacy/ dist/* upload_distribution: twine upload dist/* ================================================ FILE: README.md ================================================ # Deep Joint Demosaicking and Denoising SIGGRAPH Asia 2016 Michaël Gharbi gharbi@mit.edu Gaurav Chaurasia Sylvain Paris Frédo Durand A minimal pytorch implementation of "Deep Joint Demosaicking and Denoising" [Gharbi2016] # Installation From this repo: ```shell python setup.py install ``` Using pip: ```shell pip install demosaicnet ``` Then run the demo script with: ```shell python scripts/demosaicnet_demo.py output ``` To train a dummy model on the demo dataset provided, run: ```shell python scripts/train.py --data demosaicnet/data/dummy_dataset --checkpoint_dir ckpt ``` To build and update the wheel: ```shell pip install wheel twine make distribution make upload_distribution ``` # FAQ - **How is noise handled? 
__all__ = ["BAYER_MODE", "XTRANS_MODE", "Dataset",
           "TRAIN_SUBSET", "VAL_SUBSET", "TEST_SUBSET"]


log = logging.getLogger(__name__)

BAYER_MODE = "bayer"
"""Applies a Bayer mosaic pattern."""

XTRANS_MODE = "xtrans"
"""Applies an X-Trans mosaic pattern."""

TRAIN_SUBSET = "train"
"""Loads the 'train' subset of the data."""

VAL_SUBSET = "val"
"""Loads the 'val' subset of the data."""

TEST_SUBSET = "test"
"""Loads the 'test' subset of the data."""


class Dataset(TorchDataset):
    """Dataset of challenging image patches for demosaicking.

    Each subset lives in `<root>/<subset>/` and is indexed by a
    `filelist.txt` file that lists one image path per line, relative to that
    subset folder.

    Args:
        root(str): directory that contains (or will receive) the subset
            folders.
        download(bool): if True, automatically download the dataset when the
            subset's filelist is missing.
        mode(:class:`BAYER_MODE` or :class:`XTRANS_MODE`): mosaic pattern to
            apply to the data.
        subset(:class:`TRAIN_SUBSET`, :class:`VAL_SUBSET` or
            :class:`TEST_SUBSET`): subset of the data to load.

    Raises:
        ValueError: if `subset` or `mode` is not one of the supported values,
            or if the filelist is missing and `download` is False.
    """

    def __init__(self, root, download=False, mode=BAYER_MODE, subset="train"):
        super(Dataset, self).__init__()

        self.root = os.path.abspath(root)

        if subset not in [TRAIN_SUBSET, VAL_SUBSET, TEST_SUBSET]:
            raise ValueError("Dataset subet should be '%s', '%s' or '%s', got"
                             " %s" % (TRAIN_SUBSET, TEST_SUBSET, VAL_SUBSET,
                                      subset))

        if mode not in [BAYER_MODE, XTRANS_MODE]:
            raise ValueError("Dataset mode should be '%s' or '%s', got"
                             " %s" % (BAYER_MODE, XTRANS_MODE, mode))
        self.mode = mode

        listfile = os.path.join(self.root, subset, "filelist.txt")
        log.debug("Reading image list from %s", listfile)

        if not os.path.exists(listfile):
            if download:
                _download(self.root)
            else:
                log.error("Filelist %s not found", listfile)
                raise ValueError("Filelist %s not found" % listfile)
        else:
            log.debug("No need no download the data, filelist exists.")

        with open(listfile, "r") as fid:
            # Skip blank lines: an empty entry would resolve to the subset
            # directory itself, inflating the length and breaking image reads.
            names = (line.strip() for line in fid)
            self.files = [os.path.join(self.root, subset, name)
                          for name in names if name]

    def __len__(self):
        """Number of images listed for the selected subset."""
        return len(self.files)

    def __getitem__(self, idx):
        """Fetches a mosaic / demosaicked pair of images.

        Args:
            idx(int): index of the image in the subset's filelist.

        Returns
            mosaic(np.array): with size [3, h, w] the mosaic data with
                separated color channels.
            img(np.array): with size [3, h, w] the groundtruth image.
        """
        fname = self.files[idx]
        # Scale by the 8-bit maximum, then move channels first.
        img = np.array(imread(fname)).astype(np.float32) / (2**8-1)
        img = np.transpose(img, [2, 0, 1])

        if self.mode == BAYER_MODE:
            mosaic = bayer(img)
        else:
            mosaic = xtrans(img)

        return mosaic, img
'79db008803bf58593df6c32db8c0b3d0', 'datasets.z122': '647b151eb30a44d123ce9ddfbb380094', 'datasets.z123': '5065f755fe61c666d6ed28096e4047a3', 'datasets.z124': '6306215855e30112c495291d5928e0b7', 'datasets.z125': 'a55c6e31e7ad42170016a15791e25134', 'datasets.z126': '582eb81f0251840507ca0b53e624b1e0', 'datasets.z127': '3beefb769a01481a8bc7ae39bc2f539b', 'datasets.z128': 'fac96b38f96ea364ea51020386597b5d', 'datasets.z129': '3aaca8ce2b67d2c1fe764ae2b306d17f', 'datasets.z13': '7ec2aa595441d9f698a46d707f299e8c', 'datasets.z130': '9f94105dcc39cf5421b2df2532c06ec9', 'datasets.z131': '5de2901388a2e531d6874cdaf23bf15e', 'datasets.z132': '6851ed45004ae6864892532d1ad44b20', 'datasets.z133': '0f6b417d57f9bdc9fa91d85ff5e3378d', 'datasets.z134': '0602b14c8828f7a9fed92713047c695f', 'datasets.z135': '3ab79fd4da5c4c5b5a7896a189ec43a7', 'datasets.z136': 'dd05152db786d8189cdb419ac5d0018a', 'datasets.z137': 'b4e97abef22ee8b81ac232760d9f539d', 'datasets.z138': '4ce939f1fa4f3e110e989db07d53d33f', 'datasets.z139': 'a096add471cc5e074852c063eec3863a', 'datasets.z14': '08571b6629b8856813fb35b45bbc082c', 'datasets.z140': '3d84e8cc84ab26969c5239be60222ad3', 'datasets.z141': 'd77062f59ca9957d33c9a671657fd795', 'datasets.z142': '82901d01917006348deb89aa37fe3629', 'datasets.z143': '736f77856f0854b26fa951479691df8f', 'datasets.z144': '55c44320975f4278a8837085c5e02eda', 'datasets.z145': '087d3b7634bf4720a916767d5c6b7d70', 'datasets.z146': '5659d6f0495dcdc5f5d98bf2efaaa09b', 'datasets.z147': '66dd69b2f9348e3c0d0c93c3e61416dd', 'datasets.z148': 'f3fc8f15aceb0f9bf04d786b894caa44', 'datasets.z149': '3863be1d2b130f79399432cfc1281c2e', 'datasets.z15': 'cc57e0c4466575436f670ac3e07ad2f3', 'datasets.z150': '09750f2019da9ff7132b904b8bcbd895', 'datasets.z151': 'b1573f086c0f7d1fdf249a8e3a9bb178', 'datasets.z152': '1a2d4374aea1e22c0b676a6a7eac49ec', 'datasets.z153': 'b24320708d2019ed71ab16055e971b1d', 'datasets.z154': '7ba27e1946afa610e131f3afefe78326', 'datasets.z155': '2f02c8b5470be4cb6b53e4c9e512394e', 
'datasets.z156': 'fa4f0977409f181820bf78174257d657', 'datasets.z157': '6736b97a29d1393ec65ddc9376a06369', 'datasets.z158': 'b4b72842b13ec3877bc530ca2470a0db', 'datasets.z159': 'fa0dfa57c9d299175719bbbdf319c935', 'datasets.z16': '02c213b708e2ee7ebd68464dfb2279fc', 'datasets.z160': '6edbb9dc7fa6d12d2e21631ae14eaa8c', 'datasets.z161': '4ab093ce5af2726e9ee71fdf1943e8e2', 'datasets.z162': '4e21401db9f9884d953df20381c5fd97', 'datasets.z163': '8d39f0ed1a1d9b5de22b583d00081522', 'datasets.z164': '40fe425e2c5e89b87b44a6e9735590d5', 'datasets.z165': '9552ff9b03e2dbb45befc9d1cc99ad81', 'datasets.z166': '6c0098a36d6827aea846d8522c578751', 'datasets.z167': 'ce6b8b981d92f5a61f2ec40089a400a2', 'datasets.z168': '60f6e16a3e5e409a3fc89edd3e0034d5', 'datasets.z169': '75eb975a10d5cbf136796651a1789b42', 'datasets.z17': 'c92ae62205eaef02db27996f0dc6c282', 'datasets.z170': '287966840fff015ed36da3a08a18ebfb', 'datasets.z171': 'e08193b722af492a78ac36a3125ac8d9', 'datasets.z172': '32c795461f194c38b25047faeb46fdb5', 'datasets.z173': '58fbb396dbbb902ac2f2c43722573200', 'datasets.z174': '506fee2b982ee81689f3fe4d89133cac', 'datasets.z175': '9d553e31b07b23e30c427800168eec6b', 'datasets.z176': '0f6e3048824ead093d3127434ac83a72', 'datasets.z177': 'ece15c004fa708849295987b8b1aba9c', 'datasets.z178': 'd9db14d92d56ae2970798417030b5bb4', 'datasets.z179': '25ddbe866d0a6b9cebe8d90f7b801fa6', 'datasets.z18': 'f55ddc31cf203f495e352182a5bbadc3', 'datasets.z180': 'd2fc49c68c77d1da592aff4ab90c0915', 'datasets.z181': '5a4a635b1f3535311c6caecf4ab3ba80', 'datasets.z182': 'df51725daff3edfc377a5f6bc158ec3a', 'datasets.z183': '626add199ec4f263ff278d5392f41c9c', 'datasets.z184': '5069483fd064ee5e8c24a240e6ee7736', 'datasets.z185': '589249e98db0a4ded1d3e4acefd07509', 'datasets.z186': 'e4415c64463ef16bceb9d2e2fa934d71', 'datasets.z187': 'e070cbaaf88a1085964244f6505c713c', 'datasets.z188': '71b38eb51edff8b049a302bacbe344d6', 'datasets.z189': '8fa7a8b58c9e7cb9e86bfd0ca5f6d2ea', 'datasets.z19': 
'6c34cd0e39a33737983ebf89f6cabf5f', 'datasets.z190': 'daad0ea7c87d0935e014a370c38cc926', 'datasets.z191': 'c355e9ee9d0afa67faa34739b7f7cf79', 'datasets.z192': 'c97d5a784625795cdf3c36c337986afe', 'datasets.z193': '5f3b8425e215798c9e454cdbe586db90', 'datasets.z194': '31b5d74c1cbbabbf58ee470467b40d12', 'datasets.z195': '4c65958343bc2ea1a28e779ee7e5e498', 'datasets.z196': '26ab3664e62c7fd5d0be673c45dd0d93', 'datasets.z197': '32d690086b6e9f05e3ced3a126af870c', 'datasets.z198': '323071827db89626c9f186455fbb38c9', 'datasets.z199': '72475a8500be1ff21407a66f0e2e91a6', 'datasets.z20': '4161d6eda0ca5ed9500f953f789a25b2', 'datasets.z200': '44a863b9c9760cb87a23f1422f242c0d', 'datasets.z201': 'dc38b455fa45e3ef0d5f06397507982e', 'datasets.z202': 'b9ba231b317b008602f9472325b40e65', 'datasets.z203': '36f4afc46258d80be626040956550028', 'datasets.z204': '5c522fbdfe1f9d449c9189173f2ed2b2', 'datasets.z205': '58d39995017eed2c4abcf9fcfd07e695', 'datasets.z206': '2efdfc2abc834f0f0f1cabe10423f865', 'datasets.z207': '04ead7536e5c13936c724f644ab1cb3a', 'datasets.z208': '9e3e0a02a07bebcc7cdb62a0ad047946', 'datasets.z209': 'c1fc44cd8b6f50955c8b3b317155ecb6', 'datasets.z21': '8669b8bb9fa90628d4423c45648868b2', 'datasets.z210': '210df79f8434bdb4e2a7d12c4078d972', 'datasets.z211': 'f078d2d8a14b6c59f58c67865bbc3334', 'datasets.z212': 'ef7d08a6cc39f6cb96b631ca61b440a0', 'datasets.z213': '723057f7619d8820f142944f55f9542b', 'datasets.z214': '2ba38ca8561b51710f660c03f84c0eb1', 'datasets.z215': '92a6e97dfaff295110ddead242ebe932', 'datasets.z216': '05f40901ae70f73b3c099fcdd4ca945e', 'datasets.z217': 'b690ed3e8c6ba9f8bba9154d7e8f7ece', 'datasets.z218': 'e290ca6f5573579df9f3aa7c5158891e', 'datasets.z219': 'a8e4626968f089163179f30066ce732c', 'datasets.z22': '0f90463abdc8f0f81d81249302cf2d09', 'datasets.z220': '3ecd2c0c855505d2957046d784944fce', 'datasets.z221': '6c9c28287fbadcab2ce777ef3134e5d6', 'datasets.z222': '29361a77f05e5e68113fc23e11b54b4d', 'datasets.z223': '9214257b9a87c0037e88709addba8948', 
'datasets.z224': '7596a516fb7e308f33a81c5b3c36810a', 'datasets.z225': 'c1dc079f5261a976b1bd7f5c05cd4a02', 'datasets.z226': '5b87815b0ccc5cacec83a399a52874aa', 'datasets.z227': '6bed353dc50263b2c720af663c833bbc', 'datasets.z228': '37ed3574bf978bccd6e2db9be00bae94', 'datasets.z229': '8fd57367808fd77581f998850e5f935a', 'datasets.z23': '90ae2cdcdc1663b80c20e080e5c0e038', 'datasets.z230': '3f5ef3234da0236d2fbfaf7366407d70', 'datasets.z231': 'f67d8320028620c8bdd9a800a78afa27', 'datasets.z232': '5f831f25f8e8557168b38a7a28f8e7f9', 'datasets.z233': '56ee8c4b01825ba7f12340ae8b990db3', 'datasets.z234': '5428e98487b0e077cf9c24dc60599286', 'datasets.z235': '883c0ea97facca4d57d5c9c54922e8be', 'datasets.z236': 'e23fb6f610a3b528d5b310df4e452256', 'datasets.z237': 'bd858e84a47668edc851dab131239ae4', 'datasets.z238': 'db51bed3e3e5c6a40881f22532618533', 'datasets.z239': 'c1be852117739fc63227a503b08a8436', 'datasets.z24': '00dcb2e2a72b15a9aa9a646ecaea0019', 'datasets.z240': 'f4e6da02349b03b4f433b3399dcf8b3c', 'datasets.z241': '4f29898105aaf9f1a753a1c639947c2f', 'datasets.z242': '168a15bd8367f5d5f3e5e8cf4d0da6af', 'datasets.z243': '2930bced33bd1ecabce070fc831567e7', 'datasets.z244': 'aed9c5ec05f57e3fa9e7b224d47fa7b0', 'datasets.z245': '5aa83729fec805c166e48e5ec21530a5', 'datasets.z246': '646ceabfae028d631568930b4056227a', 'datasets.z247': '7381491175c1a63cc04ecac81148925b', 'datasets.z248': '4064df81449c1980d0abaa8c7262b315', 'datasets.z249': '7ae84d1dde2e935d86138d1e7b077df8', 'datasets.z25': '5ae383bcd01d4ce22387680e28833f06', 'datasets.z250': 'c018b41fbc4982a561b07cf0d52137f2', 'datasets.z251': '9c9dc7a889d537fd1e02f4549529a5f1', 'datasets.z252': 'b74df0680e7a62794902186fe1e3fec2', 'datasets.z253': '894cbb618ddffe65ce2ada0f250ad79c', 'datasets.z254': 'd5d8bc590d109c7d592d4df183c495a8', 'datasets.z255': 'faf313a3edd70129c212d3dbd1de5042', 'datasets.z256': '0fefcb84b66518df03a93ae53079409f', 'datasets.z257': '887e8e8abaf09682903b9b1060fb8153', 'datasets.z258': 
'a700ae13abb7707032123468fab1bf55', 'datasets.z259': 'cad7c974b832d27d1cfb4ae0e4dc6c3c', 'datasets.z26': '5e0cdf281eeca969a4e0adfc44e11dcb', 'datasets.z260': '5e1c11df40440e4e84354d40efe7940e', 'datasets.z261': '2a75579c238569356855244d9fedf50b', 'datasets.z262': '51d508271fef5558df387542b3561b67', 'datasets.z263': 'e11b54dcd5069e838a34dcc2daebf4b5', 'datasets.z264': '8a5a3b288d21a3c2ef641370d436703e', 'datasets.z265': '56954047cba7c8732f0490323540af43', 'datasets.z266': 'cad0325e494cc720c385ac7420acd2d7', 'datasets.z267': '25de522b499ec7af12f583dd89a31769', 'datasets.z268': '0e9882f392ca679e8c47276371384efd', 'datasets.z269': '28487bc8eb731d4913254a9d63bb13ae', 'datasets.z27': 'c78804dfac1e395156abd235ca416b33', 'datasets.z270': '276fde25d412d8a1197e3dda307580d7', 'datasets.z271': 'ba6f46558aa9ebc64efca2485b4f18ca', 'datasets.z272': '5acceb08940d937d3023d98e745b8197', 'datasets.z273': '57201a53390fe9c6a8c069397dcb81b8', 'datasets.z274': '92491b6786ccea7b6ccccfb4e09c6d75', 'datasets.z275': '349136c52b8554f03967d7083e5cd95c', 'datasets.z276': '1525075dbf4d5d101d63cdade8bed9e7', 'datasets.z277': 'eeb6365ef482cd2c6bac20aab8181081', 'datasets.z278': 'ab1b04860b27fc11b7f57074a0815877', 'datasets.z279': 'dc1cee0d4b69da9fd7aeb47f91768589', 'datasets.z28': '0257b938256a2b7b55637970ebb3edcc', 'datasets.z280': 'd168adf5e7c223a1d8dddfc663ebaeb0', 'datasets.z281': '540fc1de91a90e9bd91b3f2b590ddbf6', 'datasets.z282': 'f3fd6fbe05cfd53eb4a4c2e41bf75cc7', 'datasets.z283': '20e89e60dd6a582bcb98b74df82699c3', 'datasets.z284': '6e6ba6077437285881609999ead45463', 'datasets.z285': 'fdc4bad8adb36b3d6653a438f1fc000f', 'datasets.z286': '39c0f6fc7aace7e30a33e7e73afdb6ae', 'datasets.z287': 'e60818eefd7426f3de0cc0746550be7f', 'datasets.z288': 'd9db9107b8e92c0bf6a311a219363554', 'datasets.z289': 'a250aa672d1f9165981cfaf1c6c8fff6', 'datasets.z29': 'e1240710e1c4dd506aec03c02caf5606', 'datasets.z290': '6dcb5a7674c927ec4e965a42d04a0ccd', 'datasets.z291': 'c058fca515b7f714338816b72672ce20', 
'datasets.z292': '3fa254ed46ad6a6f7686836fe6fb7991', 'datasets.z293': '782344c6620582d6b1681e142853a61e', 'datasets.z294': '5263c5eb50cfec20adf82a89973a7547', 'datasets.z295': '0c217819aa7308ce8f744511e572a632', 'datasets.z296': 'cf8e7f1d3503ad6371ebcc9f827a29f8', 'datasets.z297': 'f16861471be342b291e55991654c882b', 'datasets.z298': 'ceb8b1170f1f2ec8113ef1c004df236c', 'datasets.z299': '4f66a50c5dfb03143a6456c7fd925ec1', 'datasets.z30': '75b0444187fc6c3df7bd3182108bc647', 'datasets.z300': 'becbeea47cef192a1a13b35911c4795a', 'datasets.z301': '92a67108c6bde11111f3b94f690f4b42', 'datasets.z302': '06292210e05575cf1632a099648c16af', 'datasets.z303': 'ecb7a48777719322c957ecc99340f04a', 'datasets.z304': 'b318fc0b7d645d16467b78bbce95befa', 'datasets.z305': '5f83e6a17e5977c2b99fc29893a1f479', 'datasets.z306': 'dde2ec22363740081f54032a3add00e0', 'datasets.z307': 'e0b7f9a7eddf2117a6127542883eb767', 'datasets.z308': 'fff180911deaf4b476f42e6e47d78e6e', 'datasets.z309': 'c55a8ac017f7ea69e77519fbcc617301', 'datasets.z31': '04bee9374c7436d66f560bbbfc22299d', 'datasets.z310': '3b4282fd1ab2b4f885df196a83726d34', 'datasets.z311': '08a7c41c5d290750d9ebf6266a86bec8', 'datasets.z312': 'b0f844ffd0ae6785e077ef8871dfc5da', 'datasets.z313': 'bd03ba03a63877b17274e21ddb828218', 'datasets.z314': '392f4528f9c345355434e7448a80b28b', 'datasets.z315': 'e499cdf8acf1561720d4eb8f9ad9daad', 'datasets.z316': '2575496c4d9c5082c6c4ef2f0cedeb69', 'datasets.z317': '6f387284dbef478e02f7a5954da66015', 'datasets.z318': '2ea47b6b7c9790bbfc2d3c238ffba391', 'datasets.z319': '8e1e48b2aa1d6c7ae840a23360a3a8f8', 'datasets.z32': 'ab930ee97741da82bfc778474f528a28', 'datasets.z320': 'a785b37410ccf0366f97c0d221faa629', 'datasets.z321': '16c73037fa0ca7704a3c4ceddbd7d599', 'datasets.z322': 'af97027070cf5d057934dfd6bd819e61', 'datasets.z323': '79078baf8d5fee935620f48aed2980a6', 'datasets.z324': '987b626c8731b0092df593f0eaddb32a', 'datasets.z325': 'fcc9a289015b40044c55ca96bc3dbe1e', 'datasets.z326': 
'f2921b782a23f2729e1a6641e8c99954', 'datasets.z327': '6d1b488d88cf0fcaf5105b6544e5ea66', 'datasets.z328': '268885eaeb6290be4ddfaefb34943985', 'datasets.z329': '3951d3dd0527b54c5302720ea037cb12', 'datasets.z33': 'fff4585bf34e5a7ed8f369b337aac901', 'datasets.z330': '4b546d783366442d95eed64f882ae9f6', 'datasets.z331': '4605615511ea901c18256306263b2226', 'datasets.z332': '4127b2a8dd549b02ce3e465cfcfcc0a3', 'datasets.z333': '711d282247b4fd56758c96c13bfe1b8e', 'datasets.z34': '39b5d53f995fa8c223ed0d7a2de34652', 'datasets.z35': 'ae351c1e2961f99a6d3ac37fbae27548', 'datasets.z36': 'c2102c4a984d03c32c7c99378df953dc', 'datasets.z37': '76ba463895984049a1814654b7290890', 'datasets.z38': '5481901e75ee9d55066ba2731b2f36b7', 'datasets.z39': '59815df91b2532d1e300bc71c976da12', 'datasets.z40': '321ebe9b7812c14ee185fe2d6f16300c', 'datasets.z41': '862bc3fcb9fc1df2bef61561bcca8090', 'datasets.z42': '10303d540fea2150a7574cefcec92977', 'datasets.z43': '2269db3212f2d2db86982408a2b24948', 'datasets.z44': '80673bdbd722d02d97febeca00e57cf1', 'datasets.z45': '42329dbb5c5165788902f33265db66a3', 'datasets.z46': 'f81999194d418c515bb5df32c278d7f7', 'datasets.z47': '2e3d3520636b5f3eb0cd2d649b6b4dab', 'datasets.z48': 'e57e7aee104e80deaf068fbdd3292410', 'datasets.z49': '36146291e4af0eff44e763e7e1facb4f', 'datasets.z50': '60ef115c5c621c757b9b7075d5590c20', 'datasets.z51': '02326e4392d176c2aa6d479cf43f29f8', 'datasets.z52': 'b454f81ef7cda24cb13b8176c641d7ef', 'datasets.z53': '95a4857fe7bc6230e6a3cc379085e989', 'datasets.z54': '8ff35d1e9d738eb8b2afde04699ba73f', 'datasets.z55': '75d19ea4a9a283b6d236e3353063d82d', 'datasets.z56': '2ec9121fb364cc76eaf6fe7ff947dd04', 'datasets.z57': 'f5d0b9b2b0a91f82dea784023f71cf3c', 'datasets.z58': '79e318093d97bd573cc7aab4ed68dd6d', 'datasets.z59': '68350ada72cfc22d2d6fc8ca636ffef2', 'datasets.z60': '1ff9d36ec2b3723253503d524c90c9df', 'datasets.z61': '9fb963355e8c895c1a0a02fa45eb6fb2', 'datasets.z62': '5900cdc6c3d5f4b0c05825b47946096a', 'datasets.z63': 
def _download(dst):
    """Downloads, verifies, joins and extracts the dataset into `dst`.

    The dataset is distributed as a split zip archive whose parts are listed
    in `CHECKSUMS`. Parts already present with a correct MD5 are not
    downloaded again; parts with a wrong MD5 are deleted and re-downloaded.
    The parts are then joined with `zip -FF`, extracted with `unzip`, and the
    `train`, `test` and `val` folders are moved from `<dst>/images` to `dst`.

    Args:
        dst(str): destination directory, created if needed.

    Raises:
        ValueError: if a corrupt, previously downloaded part cannot be deleted.
        subprocess.CalledProcessError: if `zip` or `unzip` exits with an error.
    """
    dst = os.path.abspath(dst)
    files = CHECKSUMS.keys()
    fullzip = os.path.join(dst, "datasets.zip")
    joinedzip = os.path.join(dst, "joined.zip")

    URL_ROOT = "https://data.csail.mit.edu/graphics/demosaicnet"

    if not os.path.exists(joinedzip):
        log.info("Dowloading %d files to %s (This will take a while, and ~80GB)", len(
            files), dst)

        os.makedirs(dst, exist_ok=True)
        for f in files:
            fname = os.path.join(dst, f)
            # URLs always use forward slashes; os.path.join would insert
            # backslashes on Windows.
            url = URL_ROOT + "/" + f
            do_download = True
            if os.path.exists(fname):
                checksum = md5sum(fname)
                if checksum == CHECKSUMS[f]:  # File is is and correct
                    log.info('%s already downloaded, with correct checksum', f)
                    do_download = False
                else:
                    log.warning('%s checksums do not match, got %s, should be %s',
                                f, checksum, CHECKSUMS[f])
                    try:
                        os.remove(fname)
                    except OSError as e:
                        log.error("Could not delete broken part %s: %s", f, e)
                        raise ValueError(
                            "Could not delete broken part %s" % f) from e

            if do_download:
                log.info('Downloading %s', f)
                wget.download(url, fname)

                checksum = md5sum(fname)
                if checksum == CHECKSUMS[f]:
                    log.info("%s MD5 correct", f)
                else:
                    log.error('%s checksums do not match, got %s, should be %s. Downloading failed',
                              f, checksum, CHECKSUMS[f])

        log.info("Joining zip files")
        # Argument list instead of a shell string so that paths containing
        # spaces or shell metacharacters are passed through untouched.
        subprocess.check_call(["zip", "-FF", fullzip, "--out", joinedzip])

        # Cleanup the parts
        for f in files:
            fname = os.path.join(dst, f)
            try:
                os.remove(fname)
            except OSError:
                log.warning("Could not delete file %s", f)

    # Extract inside `dst` without changing the working directory of the
    # calling process.
    log.info("Extracting files from %s", joinedzip)
    subprocess.check_call(["unzip", joinedzip], cwd=dst)

    try:
        os.remove(joinedzip)
    except OSError:
        log.warning("Could not delete file %s", joinedzip)

    log.info("Moving subfolders")
    for k in ["train", "test", "val"]:
        shutil.move(os.path.join(dst, "images", k), os.path.join(dst, k))
    images = os.path.join(dst, "images")
    log.info("removing '%s' folder", images)
    shutil.rmtree(images)


def md5sum(filename, blocksize=65536):
    """Computes the MD5 hex digest of a file, reading it block by block.

    Args:
        filename(str): path to the file to hash.
        blocksize(int): number of bytes read per iteration.

    Returns:
        str: the hexadecimal MD5 digest of the file content.
    """
    digest = hashlib.md5()
    with open(filename, "rb") as f:
        for block in iter(lambda: f.read(blocksize), b""):
            digest.update(block)
    return digest.hexdigest()
""" def __init__(self, depth=15, width=64, pretrained=True, pad=False): super(BayerDemosaick, self).__init__() self.depth = depth self.width = width if pad: pad = 1 else: pad = 0 layers = OrderedDict([ ("pack_mosaic", nn.Conv2d(3, 4, 2, stride=2)), # Downsample 2x2 to re-establish translation invariance ]) for i in range(depth): n_out = width n_in = width if i == 0: n_in = 4 if i == depth-1: n_out = 2*width layers["conv{}".format(i+1)] = nn.Conv2d(n_in, n_out, 3, padding=pad) layers["relu{}".format(i+1)] = nn.ReLU(inplace=True) self.main_processor = nn.Sequential(layers) self.residual_predictor = nn.Conv2d(width, 12, 1) self.upsampler = nn.ConvTranspose2d(12, 3, 2, stride=2, groups=3) self.fullres_processor = nn.Sequential(OrderedDict([ ("post_conv", nn.Conv2d(6, width, 3, padding=pad)), ("post_relu", nn.ReLU(inplace=True)), ("output", nn.Conv2d(width, 3, 1)), ])) # Load weights if pretrained: assert depth == 15, "pretrained bayer model has depth=15." assert width == 64, "pretrained bayer model has width=64." state_dict = th.load(_BAYER_WEIGHTS) self.load_state_dict(state_dict) def forward(self, mosaic): """Demosaicks a Bayer image. Args: mosaic (th.Tensor): input Bayer mosaic Returns: th.Tensor: the demosaicked image """ # 1/4 resolution features features = self.main_processor(mosaic) filters, masks = features[:, 0:self.width], features[:, self.width:2*self.width] filtered = filters * masks residual = self.residual_predictor(filtered) # Match mosaic and residual upsampled = self.upsampler(residual) cropped = _crop_like(mosaic, upsampled) packed = th.cat([cropped, upsampled], 1) # skip connection output = self.fullres_processor(packed) return output class XTransDemosaick(nn.Module): """Released version of the network. There is no downsampling here. 
""" def __init__(self, depth=11, width=64, pretrained=True, pad=False): super(XTransDemosaick, self).__init__() self.depth = depth self.width = width if pad: pad = 1 else: pad = 0 layers = OrderedDict([]) for i in range(depth): n_in = width n_out = width if i == 0: n_in = 3 layers["conv{}".format(i+1)] = nn.Conv2d(n_in, n_out, 3, padding=pad) layers["relu{}".format(i+1)] = nn.ReLU(inplace=True) self.main_processor = nn.Sequential(layers) self.fullres_processor = nn.Sequential(OrderedDict([ ("post_conv", nn.Conv2d(3+width, width, 3, padding=pad)), ("post_relu", nn.ReLU(inplace=True)), ("output", nn.Conv2d(width, 3, 1)), ])) # Load weights if pretrained: assert depth == 11, "pretrained xtrans model has depth=11." assert width == 64, "pretrained xtrans model has width=64." state_dict = th.load(_XTRANS_WEIGHTS) self.load_state_dict(state_dict) def forward(self, mosaic): """Demosaicks an XTrans image. Args: mosaic (th.Tensor): input XTrans mosaic Returns: th.Tensor: the demosaicked image """ features = self.main_processor(mosaic) cropped = _crop_like(mosaic, features) # Match mosaic and residual packed = th.cat([cropped, features], 1) # skip connection output = self.fullres_processor(packed) return output def _crop_like(src, tgt): """Crop a source image to match the spatial dimensions of a target. 
Args: src (th.Tensor or np.ndarray): image to be cropped tgt (th.Tensor or np.ndarray): reference image """ src_sz = np.array(src.shape) tgt_sz = np.array(tgt.shape) # Assumes the spatial dimensions are the last two crop = (src_sz[-2:]-tgt_sz[-2:]) crop_t = crop[0] // 2 crop_b = crop[0] - crop_t crop_l = crop[1] // 2 crop_r = crop[1] - crop_l crop //= 2 if (np.array([crop_t, crop_b, crop_r, crop_l])> 0).any(): return src[..., crop_t:src_sz[-2]-crop_b, crop_l:src_sz[-1]-crop_r] else: return src ================================================ FILE: demosaicnet/mosaic.py ================================================ """Utilities to make a mosaic mask and apply it to an image.""" import numpy as np import torch as th __all__ = ["bayer", "xtrans"] def bayer(im, return_mask=False): """Bayer mosaic. The patterned assumed is:: G r b G Args: im (np.array): image to mosaic. Dimensions are [c, h, w] return_mask (bool): if true return the binary mosaic mask, instead of the mosaic image. Returns: np.array: mosaicked image (if return_mask==False), or binary mask if (return_mask==True) """ numpy = False if type(im) == np.ndarray: numpy = True if type(im) == np.ndarray: mask = np.ones_like(im) else: mask = th.ones_like(im) # red mask[..., 0, ::2, 0::2] = 0 mask[..., 0, 1::2, :] = 0 # green mask[..., 1, ::2, 1::2] = 0 mask[..., 1, 1::2, ::2] = 0 # blue mask[..., 2, 0::2, :] = 0 mask[..., 2, 1::2, 1::2] = 0 if not numpy: # make it a constant for ONNX conversion mask = th.from_numpy(mask.cpu().detach().numpy()).to(im.device) if mask.shape[0] == 1: mask = mask.squeeze(0) # coreml hack if return_mask: return mask return im*mask def xtrans_cell(torch=False): g_pos = [(0,0), (0,2), (0,3), (0,5), (1,1), (1,4), (2,0), (2,2), (2,3), (2,5), (3,0), (3,2), (3,3), (3,5), (4,1), (4,4), (5,0), (5,2), (5,3), (5,5)] r_pos = [(0,4), (1,0), (1,2), (2,4), (3,1), (4,3), (4,5), (5,1)] b_pos = [(0,1), (1,3), (1,5), (2,1), (3,4), (4,0), (4,2), (5,4)] if torch: mask = th.zeros(3, 6, 6) else: mask = 
def crop_like(src, tgt):
    """Center-crop a source image to match the spatial size of a target.

    The spatial dimensions are taken to be the last two axes, so inputs with
    any number of leading axes (e.g. [c, h, w] or [b, c, h, w]) are supported.
    When the size difference along an axis is odd, the extra row/column is
    removed from the bottom/right side.

    Args:
        src (th.Tensor or np.ndarray): image to be cropped.
        tgt (th.Tensor or np.ndarray): reference image.

    Returns:
        th.Tensor or np.ndarray: a view of `src` cropped to the spatial size
        of `tgt`. `src` itself is returned when no cropping is needed. An axis
        along which `src` is already smaller than `tgt` is left untouched.
    """
    # NOTE: plain Python ints are used throughout to enable static slicing in
    # ONNX conversion.
    src_h, src_w = (int(d) for d in src.shape[-2:])
    tgt_h, tgt_w = (int(d) for d in tgt.shape[-2:])

    # No negative crop.
    delta_h = max(src_h - tgt_h, 0)
    delta_w = max(src_w - tgt_w, 0)
    if delta_h == 0 and delta_w == 0:
        return src

    top = delta_h // 2
    bottom = delta_h - top
    left = delta_w // 2
    right = delta_w - left
    return src[..., top:src_h - bottom, left:src_w - right]
class ModelInterface(metaclass=ABCMeta):
    """Adapter exposing the training and validation steps of a model.

    Subclasses must implement `training_step`. The validation hooks have
    default implementations that only log a warning and return an empty dict.
    """

    def __init__(self):
        """The base adapter holds no state."""

    def __repr__(self):
        return type(self).__name__

    @abstractmethod
    def training_step(self, batch):
        """Run one training step on a batch of data.

        Implementations should run the forward pass, compute gradients, take
        an optimizer step and return metrics and tensors useful for
        visualization and training callbacks.

        Args:
            batch (dict): batch of data provided by a data pipeline.

        Returns:
            train_step_data (dict): a dictionary of outputs.
        """
        return {}

    def init_validation(self):
        """Create the quantities reported during a validation run.

        The default implementation is a no-op.

        Returns:
            data (dict): initialized values.
        """
        log.warning(
            "Running a ModelInterface validation initialization that was "
            "not overriden: this is a no-op.")
        return {}

    def validation_step(self, batch, running_val_data):
        """Fold the current batch's results into the running validation data.

        The default implementation is a no-op.

        Args:
            batch (dict): batch of data provided by a data pipeline.
            running_val_data (dict): current aggregates of the validation
                loop.

        Returns:
            updated_data (dict): new value for the running_val_data.
        """
        log.warning(
            "Running a ModelInterface validation step that was "
            "not overriden: this is a no-op.")
        return {}
Args: path (string): relative path to the file being saved (without extension). extras (dict): extra user-provided information to be saved with the model. """ if self.model is None: model_state = None else: log.debug("Saving model state dict") model_state = self.model.state_dict() opt_dicts = [] if self.optimizers is not None: for opt in self.optimizers: opt_dicts.append(opt.state_dict()) sched_dicts = [] if self.schedulers is not None: for s in self.schedulers: sched_dicts.append(s.state_dict()) filename = self.__path(path, prefix=self.prefix) os.makedirs(self.root, exist_ok=True) th.save({'model': model_state, 'meta': self.meta, 'extras': extras, 'optimizers': opt_dicts, 'schedulers': sched_dicts, }, filename) log.debug("Checkpoint saved to \"{}\"".format(filename)) def try_and_init_from(self, path): """Try to initialize the models's weights from an external checkpoint. Args: path(str): full path to the checkpoints to load model parameters from. """ log.info("Loading weights from foreign checkpoint {}".format(path)) if not os.path.exists(path): raise ValueError("Checkpoint {} does not exist".format(path)) chkpt = th.load(path, map_location=th.device("cpu")) if "model" not in chkpt.keys() or chkpt["model"] is None: raise ValueError("{} has no model saved".format(path)) mdl = chkpt["model"] for n, p in self.model.named_parameters(): if n in mdl: p2 = mdl[n] if p2.shape != p.shape: log.warning("Parameter {} ignored, checkpoint size does not match: {}, should be {}".format(n, p2.shape, p.shape)) continue log.debug("Parameter {} copied".format(n)) p.data.copy_(p2) else: log.warning("Parameter {} ignored, not found in source checkpoint.".format(n)) log.info("Weights loaded from foreign checkpoint {}".format(path)) def load(self, path): """Loads a checkpoint, updates the model and returns extra data. Args: path (string): path to the checkpoint file, relative to the root dir. Returns: extras (dict): extra information passed by the user at save time. 
meta (dict): metaparameters of the model passed at save time. """ filename = self.__path(path, prefix=None) chkpt = th.load(filename, map_location="cpu") # TODO: check behavior if self.model is not None and chkpt["model"] is not None: log.debug("Loading model state dict") self.model.load_state_dict(chkpt["model"]) if "optimizers" in chkpt.keys(): if self.optimizers is not None and chkpt["optimizers"] is not None: try: for opt, state in zip(self.optimizers, chkpt["optimizers"]): log.debug("Loading optimizers state dict for %s", opt) opt.load_state_dict(state) except: # We do not raise an error here, e.g. in case the user simply # changes optimizer log.warning("Could not load optimizer state dicts, " "starting from scratch") if "schedulers" in chkpt.keys(): if self.schedulers is not None and chkpt["schedulers"] is not None: try: for s, state in zip(self.schedulers, chkpt["schedulers"]): log.debug("Loading scheduler state dict for %s", s) s.load_state_dict(state) except: log.warning("Could not load scheduler state dicts, " "starting from scratch") log.debug("Loaded checkpoint \"{}\"".format(filename)) return tuple(chkpt[k] for k in ["extras", "meta"]) def load_latest(self): """Try to load the most recent checkpoint, skip failing files. Returns: extras (dict): extra user-defined information that was saved in the checkpoint. meta (dict): metaparameters of the model passed at save time. """ all_checkpoints = self.sorted_checkpoints() extras = None meta = None for f in all_checkpoints: try: extras, meta = self.load(f) return extras, meta except Exception as e: log.debug( "Could not load checkpoint \"{}\", moving on ({}).".format(f, e)) log.debug("No checkpoint found to load.") return extras, meta def sorted_checkpoints(self): """Get list of all checkpoints in root directory, sorted by creation date. Returns: chkpts (list of str): sorted checkpoints in the root folder. 
""" reg = re.compile(r"{}.*\{}".format(self.prefix, Checkpointer.EXTENSION)) if not os.path.exists(self.root): all_checkpoints = [] else: all_checkpoints = [f for f in os.listdir( self.root) if reg.match(f)] mtimes = [] for f in all_checkpoints: mtimes.append(os.path.getmtime(os.path.join(self.root, f))) mf = sorted(zip(mtimes, all_checkpoints)) chkpts = [m[1] for m in reversed(mf)] log.debug("Sorted checkpoints {}".format(chkpts)) return chkpts def delete(self, path): """Delete checkpoint at path. Args: path(str): full path to the checkpoint to delete. """ if path in self.sorted_checkpoints(): os.remove(os.path.join(self.root, path)) else: log.warning("Trying to delete a checkpoint that does not exists.") @staticmethod def load_meta(root, prefix=None): """Fetch model metadata without touching the saved parameters. This loads the metadata from the most recent checkpoint in the root directory. Args: root(str): path to the root directory containing the checkpoints prefix(str): unique prefix for the checkpoint to be loaded (e.g. if multiple models are saved in the same location) """ chkptr = Checkpointer(root, model=None, meta=None, prefix=prefix, optimizers=[]) log.debug("checkpoints: %s", chkptr.sorted_checkpoints()) _, meta = chkptr.load_latest() return meta class Trainer(object): """Implements a simple training loop with hooks for callbacks. Args: interface (ModelInterface): adapter to run forward and backward pass on the model being trained. Attributes: callbacks (list of Callbacks): hooks that will be called while training progresses. 
    def train(self, dataloader, starting_epoch=None, num_epochs=None,
              val_dataloader=None):
        """Main training loop.

        This starts the training procedure. For every epoch, the training
        callbacks are notified, each batch is passed to the interface's
        training step and, if a validation loader is given, a validation pass
        is run under `th.no_grad()`. The loop exits early (after running the
        `training_end` callbacks) as soon as `self._keep_running` is found to
        be False, which is what the SIGINT handler sets.

        Args:
            dataloader (DataLoader): loader that yields training batches.
            starting_epoch (int, optional): index of the epoch we are starting
                from. Defaults to 0.
            num_epochs (int, optional): max number of epochs to run. If None,
                the loop keeps running until it is interrupted.
            val_dataloader (DataLoader, optional): loader that yields
                validation batches
        """
        self.__training_start(dataloader)
        if starting_epoch is None:
            starting_epoch = 0

        log.info("Starting taining from epoch %d", starting_epoch)

        epoch = starting_epoch

        # `num_epochs` counts epochs run in this call, not the absolute epoch
        # index, hence the offset by `starting_epoch`.
        while num_epochs is None or epoch < starting_epoch + num_epochs:
            self.__epoch_start(epoch)
            for batch_idx, batch in enumerate(dataloader):
                # Checked before every batch so an interrupt takes effect
                # without waiting for the end of the epoch.
                if not self._keep_running:
                    self._stop()
                    return
                self.__batch_start(batch_idx, batch)
                train_step_data = self.__training_step(batch)
                self.__batch_end(batch, train_step_data)
            self.__epoch_end()

            # TODO: allow validation at intermediate steps during one epoch

            # Validate
            if val_dataloader:
                with th.no_grad():
                    running_val_data = self.__validation_start(val_dataloader)
                    for batch_idx, batch in enumerate(val_dataloader):
                        if not self._keep_running:
                            self._stop()
                            return
                        self.__val_batch_start(batch_idx, batch)
                        # The interface returns the updated aggregate, which
                        # is threaded through to the next batch.
                        running_val_data = self.__validation_step(
                            batch, running_val_data)
                        self.__val_batch_end(batch, running_val_data)
                    self.__validation_end(running_val_data)

            epoch += 1

            if not self._keep_running:
                self._stop()
                return

        # Normal termination: reset the run flag and fire `training_end`.
        self._stop()
class Callback(object):
    """Base class for all training callbacks.

    The hooks below only do bookkeeping (or nothing at all); subclasses
    override the ones they need.

    Attributes:
        epoch(int): current epoch index.
        batch(int): current training batch index.
        val_batch(int): current validation batch index.
        datasize(int): number of batches in the training dataset.
        val_datasize(int): number of batches in the validation dataset.
        model_interface(ttools.ModelInterface): parent interface driving
            the training.
    """

    def __init__(self):
        super().__init__()
        self.model_interface = None
        self.epoch = self.batch = self.val_batch = 0
        self.datasize = self.val_datasize = 0

    def __repr__(self):
        return type(self).__name__

    def training_start(self, dataloader):
        """Hook called when the training begins; records the dataset size.

        Args:
            dataloader(th.utils.data.Dataloader): a data loading class that
                provides batches of data for training.
        """
        self.datasize = len(dataloader)

    def training_end(self):
        """Hook called when the training ends."""

    def epoch_start(self, epoch):
        """Hook called when a new epoch starts; records the epoch index.

        Args:
            epoch(int): index of the current epoch.

        Note: self.epoch is never incremented. Instead, it should be set by
        the caller.
        """
        self.epoch = epoch

    def epoch_end(self):
        """Hook called when an epoch ends.

        NOTE: self.epoch is not incremented. Instead it is set externally in
        the `epoch_start` method.
        """

    def validation_start(self, dataloader):
        """Hook called when a validation run starts; records its size.

        Args:
            dataloader(th.utils.data.Dataloader): a data loading class that
                provides batches of data for evaluation.
        """
        self.val_datasize = len(dataloader)

    def validation_end(self, val_data):
        """Hook called when a validation run ends."""

    def batch_start(self, batch_idx, batch_data):
        """Hook called when a training step starts; records the batch index.

        Args:
            batch_idx(int): index of the current batch.
            batch_data: a Tensor, tuple of dict with the current batch of
                data.
        """
        self.batch = batch_idx

    def batch_end(self, batch_data, train_step_data):
        """Hook called when a training step ends.

        Args:
            batch_data: a Tensor, tuple of dict with the current batch of
                data.
            train_step_data(dict): outputs from the `training_step` of a
                ModelInterface.
        """

    def val_batch_start(self, batch_idx, batch_data):
        """Hook called when a validation step starts; records the batch index.

        Args:
            batch_idx(int): index of the current batch.
            batch_data: a Tensor, tuple of dict with the current batch of
                data.
        """
        self.val_batch = batch_idx

    def val_batch_end(self, batch_data, running_val_data):
        """Hook called when a validation step ends.

        Args:
            batch_data: a Tensor, tuple of dict with the current batch of
                data.
            running_val_data(dict): running outputs produced by the
                `validation_step` of a ModelInterface.
        """
""" PERIODIC_PREFIX = "periodic_" EPOCH_PREFIX = "epoch_" def __init__(self, checkpointer, interval=600, max_files=5, max_epochs=10): super(CheckpointingCallback, self).__init__() self.checkpointer = checkpointer self.interval = interval self.max_files = max_files self.max_epochs = max_epochs self.last_checkpoint_time = time.time() def epoch_end(self): """Save a checkpoint at the end of each epoch.""" super(CheckpointingCallback, self).epoch_end() path = "{}{}".format(CheckpointingCallback.EPOCH_PREFIX, self.epoch) self.checkpointer.save(path, extras={"epoch": self.epoch + 1}) self.__purge_old_files() def training_end(self): super(CheckpointingCallback, self).training_end() self.checkpointer.save("training_end", extras={"epoch": self.epoch + 1}) def batch_end(self, batch_data, train_step_data): """Save a periodic checkpoint if requested.""" super(CheckpointingCallback, self).batch_end( batch_data, train_step_data) if self.interval is None: # We skip periodic checkpoints return now = time.time() delta = now - self.last_checkpoint_time if delta < self.interval: # last checkpoint is too recent return log.debug("Periodic checkpoint") self.last_checkpoint_time = now filename = "{}{}".format(CheckpointingCallback.PERIODIC_PREFIX, time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())) self.checkpointer.save(filename, extras={"epoch": self.epoch}) self.__purge_old_files() def __purge_old_files(self): """Delete checkpoints that are beyond the max to keep.""" chkpts = self.checkpointer.sorted_checkpoints() p_chkpts = [] e_chkpts = [] for c in chkpts: if c.startswith(self.checkpointer.prefix + CheckpointingCallback.PERIODIC_PREFIX): p_chkpts.append(c) if c.startswith(self.checkpointer.prefix + CheckpointingCallback.EPOCH_PREFIX): e_chkpts.append(c) # Delete periodic checkpoints if self.max_files is not None and len(p_chkpts) > self.max_files: for c in p_chkpts[self.max_files:]: log.debug("CheckpointingCallback deleting {}".format(c)) self.checkpointer.delete(c) # Delete older 
class KeyedCallback(Callback):
    """An abstract Callback that performs the same action for all keys in a list.

    The keys (resp. val_keys) are used to access the data produced by the
    training step (resp. validation step) of a ModelInterface. An exponential
    moving average is maintained for each training key.

    Args:
        keys (list of str or None): list of keys whose values will be logged
            during training.
        val_keys (list of str or None): list of keys whose values will be
            logged during validation.
        smoothing (float): smoothing factor passed as `alpha` to the moving
            average of the training keys.
    """

    def __init__(self, keys=None, val_keys=None, smoothing=0.999):
        super().__init__()

        nothing_to_track = keys is None and val_keys is None
        if nothing_to_track:
            log.warning("Logger has no keys, nor val_keys")

        self.keys = [] if keys is None else keys
        self.val_keys = [] if val_keys is None else val_keys

        # Only the training keys are smoothed.
        self.ema = ExponentialMovingAverage(self.keys, alpha=smoothing)

    def batch_end(self, batch_data, train_step_data):
        """Update the moving average of every training key."""
        for key in self.keys:
            self.ema.update(key, train_step_data[key])
""" def __init__(self, keys=None, val_keys=None, smoothing=0.99, label=None): super(ProgressBarCallback, self).__init__( keys=keys, val_keys=val_keys, smoothing=smoothing) self.pbar = None if label is None: self.label = "" else: self.label = label def training_start(self, dataloader): super(ProgressBarCallback, self).training_start(dataloader) print("Training start") def training_end(self): super(ProgressBarCallback, self).training_end() print("Training ends") def epoch_start(self, epoch): super(ProgressBarCallback, self).epoch_start(epoch) desc = "Epoch {}".format(self.epoch) if self.label is not None: desc = "%s | " % self.label + desc self.pbar = tqdm(total=self.datasize, unit=" batches", desc=desc) def epoch_end(self): super(ProgressBarCallback, self).epoch_end() self.pbar.close() self.pbar = None def validation_start(self, dataloader): super(ProgressBarCallback, self).validation_start(dataloader) print("Running validation...") self.pbar = tqdm(total=len(dataloader), unit=" batches", desc="Validation {}".format(self.epoch)) def val_batch_end(self, batch, running_val_data): self.pbar.update(1) def validation_end(self, val_data): super(ProgressBarCallback, self).validation_end(val_data) self.pbar.close() self.pbar = None s = " "*ProgressBarCallback.TABSTOPS + "Validation {} | ".format( self.epoch) for k in self.val_keys: s += "{} = {:.2f} ".format(k, val_data[k]) print(s) def batch_end(self, batch_data, train_step_data): super(ProgressBarCallback, self).batch_end(batch_data, train_step_data) d = {} for k in self.keys: d[k] = self.ema[k] self.pbar.update(1) self.pbar.set_postfix(d) TABSTOPS = 2 ================================================ FILE: demosaicnet/version.py ================================================ __version__ = "0.0.14" ================================================ FILE: docs/.gitignore ================================================ build ================================================ FILE: docs/Makefile 
================================================ # Minimal makefile for Sphinx documentation # # You can set these variables from the command line. SPHINXOPTS = SPHINXBUILD = sphinx-build SOURCEDIR = source BUILDDIR = build # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) ================================================ FILE: docs/source/conf.py ================================================ # -*- coding: utf-8 -*- # # Configuration file for the Sphinx documentation builder. # # This file does only contain a selection of the most common options. For a # full list see the documentation: # http://www.sphinx-doc.org/en/master/config # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. 
import os
import re
import sys

# Make the package importable by autodoc: the repository root is three levels
# above this file (docs/source/conf.py).
dirname = os.path.dirname
rootdir = dirname(dirname(dirname(os.path.abspath(__file__))))
sys.path.insert(0, rootdir)


# -- Project information -----------------------------------------------------

project = 'demosaicnet'
copyright = '2019, Michael Gharbi'
author = 'Michael Gharbi'

# Read the version from the package without importing it.
with open(os.path.join(rootdir, "demosaicnet", "version.py")) as fid:
    try:
        # The single-target unpacking raises ValueError unless exactly one
        # `__version__ = "..."` assignment is found.
        __version__, = re.findall('__version__ = "(.*)"', fid.read())
    except ValueError as err:
        raise ValueError("could not find version number") from err

# The full version, including alpha/beta/rc tags
release = __version__
exclude_patterns = [] # The name of the Pygments (syntax highlighting) style to use. pygments_style = None # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # html_theme = 'alabaster' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. # # html_theme_options = {} # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ['ystatic'] # Custom sidebar templates, must be a dictionary that maps document names # to template names. # # The default sidebars (for documents that don't match any pattern) are # defined by theme itself. Builtin themes are using these templates by # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', # 'searchbox.html']``. # # html_sidebars = {} # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. htmlhelp_basename = 'demosaicnetdoc' # -- Options for LaTeX output ------------------------------------------------ latex_elements = { # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', # Additional stuff for the LaTeX preamble. # # 'preamble': '', # Latex figure (float) alignment # # 'figure_align': 'htbp', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). 
latex_documents = [ (master_doc, 'demosaicnet.tex', 'demosaicnet Documentation', 'Michael Gharbi', 'manual'), ] # -- Options for manual page output ------------------------------------------ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ (master_doc, 'demosaicnet', 'demosaicnet Documentation', [author], 1) ] # -- Options for Texinfo output ---------------------------------------------- # Grouping the document tree into Texinfo files. List of tuples # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ (master_doc, 'demosaicnet', 'demosaicnet Documentation', author, 'demosaicnet', 'One line description of project.', 'Miscellaneous'), ] # -- Options for Epub output ------------------------------------------------- # Bibliographic Dublin Core info. epub_title = project # The unique identifier of the text. This can be a ISBN number # or the project homepage. # # epub_identifier = '' # A unique identification for the text. # # epub_uid = '' # A list of files that should not be packed into the epub file. epub_exclude_files = ['search.html'] # -- Extension configuration ------------------------------------------------- # -- Options for todo extension ---------------------------------------------- # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = True ================================================ FILE: docs/source/dataset.rst ================================================ Dataset ======= .. automodule:: demosaicnet.dataset :members: ================================================ FILE: docs/source/helpers.rst ================================================ Helpers ======= .. automodule:: demosaicnet.mosaic :members: ================================================ FILE: docs/source/index.rst ================================================ .. 
demosaicnet documentation master file, created by sphinx-quickstart on Thu Mar 14 13:14:16 2019. You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. Welcome to demosaicnet's documentation! ======================================= .. toctree:: :maxdepth: 2 :caption: Contents: models dataset helpers .. automodule:: demosaicnet :members: Indices and tables ================== * :ref:`genindex` * :ref:`modindex` * :ref:`search` ================================================ FILE: docs/source/models.rst ================================================ Models ====== .. automodule:: demosaicnet.modules :members: ================================================ FILE: requirements.txt ================================================ imageio==2.31.1 numpy==1.25.0 setuptools==49.2.1 torch==2.0.1 tqdm==4.65.0 wget ================================================ FILE: scripts/demosaicnet_demo.py ================================================ #!/usr/bin/env python """Demo script on using demosaicnet for inference.""" import os from pkg_resources import resource_filename import argparse import numpy as np import torch as th import imageio import demosaicnet _TEST_INPUT = resource_filename("demosaicnet", 'data/test_input.png') def main(args): print("Running demosaicnet demo on {}, outputing to {}".format(_TEST_INPUT, args.output)) bayer = demosaicnet.BayerDemosaick() xtrans = demosaicnet.XTransDemosaick() # Load some ground-truth image gt = imageio.imread(args.input).astype(np.float32) / 255.0 gt = np.array(gt) h, w, _ = gt.shape # Make the image size a multiple of 6 (for xtrans pattern) gt = gt[:6*(h//6), :6*(w//6)] # Network expects channel first gt = np.transpose(gt, [2, 0, 1]) mosaicked = demosaicnet.bayer(gt) xmosaicked = demosaicnet.xtrans(gt) # Run the model (expects batch as first dimension) mosaicked = th.from_numpy(mosaicked).unsqueeze(0) xmosaicked = th.from_numpy(xmosaicked).unsqueeze(0) with th.no_grad(): 
# inference only out = bayer(mosaicked).squeeze(0).cpu().numpy() out = np.clip(out, 0, 1) xout = xtrans(xmosaicked).squeeze(0).cpu().numpy() xout = np.clip(xout, 0, 1) print("done") os.makedirs(args.output, exist_ok=True) output = args.output imageio.imsave(os.path.join(output, "bayer_mosaick.tif"), mosaicked.squeeze(0).permute([1, 2, 0])) imageio.imsave(os.path.join(output, "bayer_result.tif"), np.transpose(out, [1, 2, 0])) imageio.imsave(os.path.join(output, "xtrans_mosaick.tif"), xmosaicked.squeeze(0).permute([1, 2, 0])) imageio.imsave(os.path.join(output, "xtrans_result.tif"), np.transpose(xout, [1, 2, 0])) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("output", help="output directory") parser.add_argument("--input", default=_TEST_INPUT, help="test input, uses the default test input provided if no argument.") args = parser.parse_args() main(args) ================================================ FILE: scripts/eval.py ================================================ #!/bin/env python """Evaluate a demosaicking model.""" import argparse import logging import torch as th from torch.utils.data import DataLoader import demosaicnet log = logging.getLogger(__name__) class PSNR(th.nn.Module): def __init__(self): super(PSNR, self).__init__() self.mse = th.nn.MSELoss() def forward(self, out, ref): mse = self.mse(out, ref) return -10*th.log10(mse) def main(args): """Entrypoint to the training.""" # Load model parameters from checkpoint, if any meta = demosaicnet.utils.Checkpointer.load_meta(args.checkpoint_dir) if meta is None: log.warning("No checkpoint found at %s, aborting.", args.checkpoint_dir) return data = demosaicnet.Dataset(args.data, download=False, mode=meta["mode"], subset=demosaicnet.TEST_SUBSET) dataloader = DataLoader( data, batch_size=1, num_workers=4, pin_memory=True, shuffle=True) if meta["mode"] == demosaicnet.BAYER_MODE: model = demosaicnet.BayerDemosaick(depth=meta["depth"], width=meta["width"], pretrained=True, 
pad=False) elif meta["mode"] == demosaicnet.XTRANS_MODE: model = demosaicnet.XTransDemosaick(depth=meta["depth"], width=meta["width"], pretrained=True, pad=False) checkpointer = demosaicnet.utils.Checkpointer(args.checkpoint_dir, model, meta=meta) checkpointer.load_latest() # Resume from checkpoint, if any. # No need for gradients for p in model.parameters(): p.requires_grad = False mse_fn = th.nn.MSELoss() psnr_fn = PSNR() device = "cpu" if th.cuda.is_available(): device = "cuda" log.info("Using CUDA") count = 0 mse = 0.0 psnr = 0.0 for idx, batch in enumerate(dataloader): mosaic = batch[0].to(device) target = batch[1].to(device) output = model(mosaic) target = demosaicnet.utils.crop_like(target, output) output = th.clamp(output, 0, 1) psnr_ = psnr_fn(output, target).item() mse_ = mse_fn(output, target).item() psnr += psnr_ mse += mse_ count += 1 log.info("Image %04d, PSNR = %.1f dB, MSE = %.5f", idx, psnr_, mse_) mse /= count psnr /= count log.info("-----------------------------------") log.info("Average, PSNR = %.1f dB, MSE = %.5f", psnr, mse) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("data", help="root directory for the demosaicnet dataset.") parser.add_argument("checkpoint_dir", help="directory with the model checkpoints.") args = parser.parse_args() main(args) ================================================ FILE: scripts/train.py ================================================ #!/bin/env python """Train a demosaicking model.""" import logging import torch as th from torch.utils.data import DataLoader import demosaicnet log = logging.getLogger(__name__) class PSNR(th.nn.Module): def __init__(self): super(PSNR, self).__init__() self.mse = th.nn.MSELoss() def forward(self, out, ref): mse = self.mse(out, ref) return -10*th.log10(mse+1e-12) class DemosaicnetInterface(demosaicnet.utils.ModelInterface): """Training and validation interface. Args: model(th.nn.Module): model to train. 
lr(float): learning rate for the optimizer. cuda(bool): whether to use CPU or GPU for training. """ def __init__(self, model, lr=1e-4, cuda=th.cuda.is_available()): self.model = model self.device = "cpu" if cuda: self.device = "cuda" self.model.to(self.device) self.opt = th.optim.Adam(self.model.parameters(), lr=lr) self.loss = th.nn.MSELoss() self.psnr = PSNR() def training_step(self, batch): fwd_data = self.forward(batch) bwd_data = self.backward(batch, fwd_data) return bwd_data def forward(self, batch): mosaic = batch[0] mosaic = mosaic.to(self.device) output = self.model(mosaic) return output def backward(self, batch, fwd_output): target = batch[1].to(self.device) # remove boundaries to match output size target = demosaicnet.utils.crop_like(target, fwd_output) loss = self.loss(fwd_output, target) self.opt.zero_grad() loss.backward() self.opt.step() with th.no_grad(): psnr = self.psnr(th.clamp(fwd_output, 0, 1), target) return {"loss": loss.item(), "psnr": psnr.item()} def init_validation(self): return {"count": 0, "psnr": 0} def update_validation(self, batch, fwd_output, running_data): target = batch[1].to(self.device) # remove boundaries to match output size target = demosaicnet.utils.crop_like(target, fwd_output) with th.no_grad(): psnr = self.psnr(th.clamp(fwd_output, 0, 1), target) n = target.shape[0] return { "psnr": running_data["psnr"] + psnr.item()*n, "count": running_data["count"] + n } def finalize_validation(self, running_data): return { "psnr": running_data["psnr"] / running_data["count"] } def main(args): """Entrypoint to the training.""" # Load model parameters from checkpoint, if any meta = demosaicnet.utils.Checkpointer.load_meta(args.checkpoint_dir) if meta is None: log.info("No metadata or checkpoint, " "parsing model parameters from command line.") meta = { "depth": args.depth, "width": args.width, "mode": args.mode, } data = demosaicnet.Dataset(args.data, download=False, mode=meta["mode"], subset=demosaicnet.TRAIN_SUBSET) dataloader = 
DataLoader( data, batch_size=args.bs, num_workers=args.num_worker_threads, pin_memory=True, shuffle=True) val_dataloader = None if args.val_data: val_data = demosaicnet.Dataset(args.data, download=False, mode=meta["mode"], subset=demosaicnet.VAL_SUBSET) val_dataloader = DataLoader( val_data, batch_size=args.bs, num_workers=1, pin_memory=True, shuffle=False) if meta["mode"] == demosaicnet.BAYER_MODE: model = demosaicnet.BayerDemosaick(depth=meta["depth"], width=meta["width"], pretrained=True, pad=False) elif meta["mode"] == demosaicnet.XTRANS_MODE: model = demosaicnet.XTransDemosaick(depth=meta["depth"], width=meta["width"], pretrained=True, pad=False) checkpointer = demosaicnet.utils.Checkpointer( args.checkpoint_dir, model, meta=meta) interface = DemosaicnetInterface(model, lr=args.lr, cuda=args.cuda) checkpointer.load_latest() # Resume from checkpoint, if any. trainer = demosaicnet.utils.Trainer(interface) keys = ["loss", "psnr"] val_keys = ["psnr"] trainer.add_callback(demosaicnet.utils.ProgressBarCallback( keys=keys, val_keys=val_keys)) trainer.add_callback(demosaicnet.utils.CheckpointingCallback( checkpointer, max_files=8, interval=3600, max_epochs=10)) if args.cuda: log.info("Training with CUDA enabled") else: log.info("Training on CPU") trainer.train( dataloader, num_epochs=args.num_epochs, val_dataloader=val_dataloader) if __name__ == "__main__": parser = demosaicnet.utils.BasicArgumentParser() parser.add_argument("--depth", default=15, help="number of net layers.") parser.add_argument("--width", default=64, help="number of features per layer.") parser.add_argument("--mode", default=demosaicnet.BAYER_MODE, choices=[demosaicnet.BAYER_MODE, demosaicnet.XTRANS_MODE], help="number of features per layer.") args = parser.parse_args() main(args) ================================================ FILE: setup.py ================================================ import re import setuptools with open('demosaicnet/version.py') as fid: try: __version__, = re.findall( 
'__version__ = "(.*)"', fid.read() ) except: raise ValueError("could not find version number") with open("README.md", "r") as fh: long_description = fh.read() setuptools.setup( name='demosaicnet', version=__version__, scripts=["scripts/demosaicnet_demo.py"], author="Michaël Gharbi", author_email="gharbi@csail.mit.edu", description="Minimal implementation of Deep Joint Demosaicking and Denoising [Gharbi2016]", long_description=long_description, url="https://github.com/mgharbi/", packages = setuptools.find_packages(exclude=["tests"]), include_package_data=True, install_requires=["wget", "tqdm", "torch", "imageio", "numpy"], classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", "Operating System :: MacOS :: MacOS X", "Operating System :: POSIX", ], )