Repository: mgharbi/demosaicnet
Branch: master
Commit: 959e9d163097
Files: 24
Total size: 82.4 KB
Directory structure:
gitextract_p3udv4rd/
├── .gitignore
├── LICENSE
├── MANIFEST.in
├── Makefile
├── README.md
├── demosaicnet/
│ ├── .gitignore
│ ├── __init__.py
│ ├── dataset.py
│ ├── modules.py
│ ├── mosaic.py
│ ├── utils.py
│ └── version.py
├── docs/
│ ├── .gitignore
│ ├── Makefile
│ └── source/
│ ├── conf.py
│ ├── dataset.rst
│ ├── helpers.rst
│ ├── index.rst
│ └── models.rst
├── requirements.txt
├── scripts/
│ ├── demosaicnet_demo.py
│ ├── eval.py
│ └── train.py
└── setup.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
output/
dist/
demosaicnet.egg-info
build/
data
.DS_Store
================================================
FILE: LICENSE
================================================
MIT License
Deep Joint Demosaicking and Denoising
Siggraph Asia 2016
Michael Gharbi, Gaurav Chaurasia, Sylvain Paris, Fredo Durand
Copyright (c) 2016 Michael Gharbi
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: MANIFEST.in
================================================
include demosaicnet/data/bayer.pth
include demosaicnet/data/xtrans.pth
include demosaicnet/data/test_input.png
================================================
FILE: Makefile
================================================
test:
py.test tests
.PHONY: docs
docs:
$(MAKE) -C docs html
clean:
python setup.py clean
rm -rf build demosaicnet.egg-info dist .pytest_cache
distribution:
python setup.py sdist bdist_wheel
twine check dist/*
test_upload:
twine upload --repository-url https://test.pypi.org/legacy/ dist/*
upload_distribution:
twine upload dist/*
================================================
FILE: README.md
================================================
# Deep Joint Demosaicking and Denoising
SIGGRAPH Asia 2016
Michaël Gharbi (gharbi@mit.edu), Gaurav Chaurasia, Sylvain Paris, Frédo Durand
A minimal PyTorch implementation of "Deep Joint Demosaicking and Denoising" [Gharbi2016].
# Installation
From this repo:
```shell
python setup.py install
```
Using pip:
```shell
pip install demosaicnet
```
Then run the demo script with:
```shell
python scripts/demosaicnet_demo.py output
```
To train a dummy model on the demo dataset provided, run:
```shell
python scripts/train.py --data demosaicnet/data/dummy_dataset --checkpoint_dir ckpt
```
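This sketch assumes the folder passed with `--data` becomes the `root` of `demosaicnet.Dataset`. The training script is not shown here, so that is an assumption. `Dataset` reads `<root>/<subset>/filelist.txt`, which lists one image filename per line, and resolves each entry relative to `<root>/<subset>`. The code below mirrors that lookup using only the standard library, so you can sanity-check a custom data folder before training.

Some names are hypothetical: the helper `list_subset_files` and the file names `000001.png` / `000002.png` are not part of the package. Unlike `Dataset`, the helper also skips blank lines.

```python
import os
import tempfile


def list_subset_files(root, subset="train"):
    """Hypothetical helper mirroring how demosaicnet.Dataset builds its file list.

    Reads <root>/<subset>/filelist.txt and returns absolute paths to the
    listed images (one filename per line, relative to <root>/<subset>).
    """
    root = os.path.abspath(root)
    listfile = os.path.join(root, subset, "filelist.txt")
    if not os.path.exists(listfile):
        raise ValueError("Filelist %s not found" % listfile)
    with open(listfile, "r") as fid:
        # Skipping blank lines is a deviation: Dataset keeps every line.
        names = [line.strip() for line in fid if line.strip()]
    return [os.path.join(root, subset, name) for name in names]


if __name__ == "__main__":
    # Build a throwaway <root>/train/filelist.txt and resolve it.
    with tempfile.TemporaryDirectory() as root:
        os.makedirs(os.path.join(root, "train"))
        with open(os.path.join(root, "train", "filelist.txt"), "w") as fid:
            fid.write("000001.png\n000002.png\n")
        for path in list_subset_files(root, "train"):
            print(path)
```

The images themselves are not opened here. `Dataset.__getitem__` only reads them (with `imageio.imread`) when a sample is requested.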
To build the wheel and upload it to PyPI:
```shell
pip install wheel twine
make distribution
make upload_distribution
```
# FAQ
- **How is noise handled? Where is the pretrained model?** The noise-aware model is not implemented here; see the earlier Caffe implementation: <https://github.com/mgharbi/demosaicnet_caffe>
- **How do I train this?** The script `scripts/train.py` is a good starting point for setting up a training job, but I haven't tested it yet; I recommend rolling your own.
================================================
FILE: demosaicnet/.gitignore
================================================
__pycache__
================================================
FILE: demosaicnet/__init__.py
================================================
from .modules import BayerDemosaick
from .modules import XTransDemosaick
from .mosaic import xtrans
from .mosaic import bayer
from .mosaic import xtrans_cell
from .dataset import *
from . import utils
================================================
FILE: demosaicnet/dataset.py
================================================
"""Dataset loader for demosaicnet."""
import os
import subprocess
import shutil
import hashlib
import logging
import numpy as np
from imageio import imread
from torch.utils.data import Dataset as TorchDataset
import wget
from .mosaic import bayer, xtrans
__all__ = ["BAYER_MODE", "XTRANS_MODE", "Dataset",
"TRAIN_SUBSET", "VAL_SUBSET", "TEST_SUBSET"]
log = logging.getLogger(__name__)
BAYER_MODE = "bayer"
"""Applies a Bayer mosaic pattern."""
XTRANS_MODE = "xtrans"
"""Applies an X-Trans mosaic pattern."""
TRAIN_SUBSET = "train"
"""Loads the 'train' subset of the data."""
VAL_SUBSET = "val"
"""Loads the 'val' subset of the data."""
TEST_SUBSET = "test"
"""Loads the 'test' subset of the data."""
class Dataset(TorchDataset):
"""Dataset of challenging image patches for demosaicking.
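Args:
root(str): directory holding the 'train', 'val' and 'test' subfolders, each with a 'filelist.txt' (also the download destination when `download` is True).
download(bool): if True, download the dataset when the file list is missing.
mode(:class:`BAYER_MODE` or :class:`XTRANS_MODE`): mosaic pattern to apply to the data.
subset(:class:`TRAIN_SUBSET`, :class:`VAL_SUBSET` or :class:`TEST_SUBSET`): subset of the data to load.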
"""
def __init__(self, root, download=False,
mode=BAYER_MODE, subset="train"):
super(Dataset, self).__init__()
self.root = os.path.abspath(root)
if subset not in [TRAIN_SUBSET, VAL_SUBSET, TEST_SUBSET]:
raise ValueError("Dataset subset should be '%s', '%s' or '%s', got"
" %s" % (TRAIN_SUBSET, VAL_SUBSET, TEST_SUBSET,
subset))
if mode not in [BAYER_MODE, XTRANS_MODE]:
raise ValueError("Dataset mode should be '%s' or '%s', got"
" %s" % (BAYER_MODE, XTRANS_MODE, mode))
self.mode = mode
listfile = os.path.join(self.root, subset, "filelist.txt")
log.debug("Reading image list from %s", listfile)
if not os.path.exists(listfile):
if download:
_download(self.root)
else:
log.error("Filelist %s not found", listfile)
raise ValueError("Filelist %s not found" % listfile)
else:
log.debug("No need to download the data, filelist exists.")
self.files = []
with open(listfile, "r") as fid:
for fname in fid.readlines():
self.files.append(os.path.join(self.root, subset, fname.strip()))
def __len__(self):
return len(self.files)
def __getitem__(self, idx):
"""Fetches a mosaic / demosaicked pair of images.
Returns:
mosaic(np.array): array of size [3, h, w], the mosaicked data with the color channels kept separate.
img(np.array): array of size [3, h, w], the ground-truth image, with 8-bit values rescaled to [0, 1].
"""
fname = self.files[idx]
img = np.array(imread(fname)).astype(np.float32) / (2**8-1)
img = np.transpose(img, [2, 0, 1])
if self.mode == BAYER_MODE:
mosaic = bayer(img)
else:
mosaic = xtrans(img)
return mosaic, img
CHECKSUMS = {
'datasets.z01': 'da46277afe85d3a91c065e4751fb8175',
'datasets.z02': 'e274a9646323d954b00094ea424e4e4c',
'datasets.z03': 'e071cc595a99a5aa4545d06350e5165f',
'datasets.z04': 'c3d2f229834569cd5ae6e2d1467c4a95',
'datasets.z05': 'daf90136c7b1ee4bb4653e9b6bf4b67d',
'datasets.z06': '87e85d2854d40116e066b28e5a8750cc',
'datasets.z07': 'a0b2854bf025c87c0bfdf83ce9aa9055',
'datasets.z08': '62125ccf29cd4b182dd81a4bb82f94c4',
'datasets.z09': 'f990f8a5090d586f2f31e61b5e6434bd',
'datasets.z10': '41ecf8d8b7d981604d661d258bf988db',
'datasets.z100': '923a536ece64cd036eec4a13156531c8',
'datasets.z101': '44a936558af2e830fdf65d9acb3960ab',
'datasets.z102': 'b24870482b41200ab7b91f0bcd3ed718',
'datasets.z103': 'a85521c1fe0b8d2d1a074b0b52bf9db1',
'datasets.z104': 'aacc7a81ec9e9a7849e3a45b1cb12f7c',
'datasets.z105': '19b62c0f0ae008b77df6465182f43dc4',
'datasets.z106': '4b0c414ce5825a9e2249e5810f0e55f0',
'datasets.z107': '7f6df7fea899a656fcde898225890daf',
'datasets.z108': '16a877c357f112367200a2534b5e54f3',
'datasets.z109': '9180129bf9c204184f729bdf1c284c9c',
'datasets.z11': 'cff5b0e9950933fa9dd6ced8ffb9528f',
'datasets.z110': 'a95a6fbfd32d90058b9e4b9f0645c646',
'datasets.z111': 'f3c7894a7d04178ca417dd5ed3a9e649',
'datasets.z112': 'd46a73703a72c07137424cad90c9c0bd',
'datasets.z113': '04b3421e465c5ef8bd64fc23730b58f7',
'datasets.z114': '9c31d2d94bd6f1ba321039b18c462175',
'datasets.z115': '427a3a6f3f936b0ba435da35be3e4bc3',
'datasets.z116': 'c633a66e7644d7d8e8148d651b76d93f',
'datasets.z117': 'cf316d3acaf301fc7b2b7e250ef734dd',
'datasets.z118': 'c3b53a604492499930dfd000d7d33fa4',
'datasets.z119': 'e64ede2179c589abd9e587347d9ba3b0',
'datasets.z12': '17b70298245ae7965f4e4b4fb01f19cc',
'datasets.z120': 'ece09ee0bc30eab71f06716cca393029',
'datasets.z121': '79db008803bf58593df6c32db8c0b3d0',
'datasets.z122': '647b151eb30a44d123ce9ddfbb380094',
'datasets.z123': '5065f755fe61c666d6ed28096e4047a3',
'datasets.z124': '6306215855e30112c495291d5928e0b7',
'datasets.z125': 'a55c6e31e7ad42170016a15791e25134',
'datasets.z126': '582eb81f0251840507ca0b53e624b1e0',
'datasets.z127': '3beefb769a01481a8bc7ae39bc2f539b',
'datasets.z128': 'fac96b38f96ea364ea51020386597b5d',
'datasets.z129': '3aaca8ce2b67d2c1fe764ae2b306d17f',
'datasets.z13': '7ec2aa595441d9f698a46d707f299e8c',
'datasets.z130': '9f94105dcc39cf5421b2df2532c06ec9',
'datasets.z131': '5de2901388a2e531d6874cdaf23bf15e',
'datasets.z132': '6851ed45004ae6864892532d1ad44b20',
'datasets.z133': '0f6b417d57f9bdc9fa91d85ff5e3378d',
'datasets.z134': '0602b14c8828f7a9fed92713047c695f',
'datasets.z135': '3ab79fd4da5c4c5b5a7896a189ec43a7',
'datasets.z136': 'dd05152db786d8189cdb419ac5d0018a',
'datasets.z137': 'b4e97abef22ee8b81ac232760d9f539d',
'datasets.z138': '4ce939f1fa4f3e110e989db07d53d33f',
'datasets.z139': 'a096add471cc5e074852c063eec3863a',
'datasets.z14': '08571b6629b8856813fb35b45bbc082c',
'datasets.z140': '3d84e8cc84ab26969c5239be60222ad3',
'datasets.z141': 'd77062f59ca9957d33c9a671657fd795',
'datasets.z142': '82901d01917006348deb89aa37fe3629',
'datasets.z143': '736f77856f0854b26fa951479691df8f',
'datasets.z144': '55c44320975f4278a8837085c5e02eda',
'datasets.z145': '087d3b7634bf4720a916767d5c6b7d70',
'datasets.z146': '5659d6f0495dcdc5f5d98bf2efaaa09b',
'datasets.z147': '66dd69b2f9348e3c0d0c93c3e61416dd',
'datasets.z148': 'f3fc8f15aceb0f9bf04d786b894caa44',
'datasets.z149': '3863be1d2b130f79399432cfc1281c2e',
'datasets.z15': 'cc57e0c4466575436f670ac3e07ad2f3',
'datasets.z150': '09750f2019da9ff7132b904b8bcbd895',
'datasets.z151': 'b1573f086c0f7d1fdf249a8e3a9bb178',
'datasets.z152': '1a2d4374aea1e22c0b676a6a7eac49ec',
'datasets.z153': 'b24320708d2019ed71ab16055e971b1d',
'datasets.z154': '7ba27e1946afa610e131f3afefe78326',
'datasets.z155': '2f02c8b5470be4cb6b53e4c9e512394e',
'datasets.z156': 'fa4f0977409f181820bf78174257d657',
'datasets.z157': '6736b97a29d1393ec65ddc9376a06369',
'datasets.z158': 'b4b72842b13ec3877bc530ca2470a0db',
'datasets.z159': 'fa0dfa57c9d299175719bbbdf319c935',
'datasets.z16': '02c213b708e2ee7ebd68464dfb2279fc',
'datasets.z160': '6edbb9dc7fa6d12d2e21631ae14eaa8c',
'datasets.z161': '4ab093ce5af2726e9ee71fdf1943e8e2',
'datasets.z162': '4e21401db9f9884d953df20381c5fd97',
'datasets.z163': '8d39f0ed1a1d9b5de22b583d00081522',
'datasets.z164': '40fe425e2c5e89b87b44a6e9735590d5',
'datasets.z165': '9552ff9b03e2dbb45befc9d1cc99ad81',
'datasets.z166': '6c0098a36d6827aea846d8522c578751',
'datasets.z167': 'ce6b8b981d92f5a61f2ec40089a400a2',
'datasets.z168': '60f6e16a3e5e409a3fc89edd3e0034d5',
'datasets.z169': '75eb975a10d5cbf136796651a1789b42',
'datasets.z17': 'c92ae62205eaef02db27996f0dc6c282',
'datasets.z170': '287966840fff015ed36da3a08a18ebfb',
'datasets.z171': 'e08193b722af492a78ac36a3125ac8d9',
'datasets.z172': '32c795461f194c38b25047faeb46fdb5',
'datasets.z173': '58fbb396dbbb902ac2f2c43722573200',
'datasets.z174': '506fee2b982ee81689f3fe4d89133cac',
'datasets.z175': '9d553e31b07b23e30c427800168eec6b',
'datasets.z176': '0f6e3048824ead093d3127434ac83a72',
'datasets.z177': 'ece15c004fa708849295987b8b1aba9c',
'datasets.z178': 'd9db14d92d56ae2970798417030b5bb4',
'datasets.z179': '25ddbe866d0a6b9cebe8d90f7b801fa6',
'datasets.z18': 'f55ddc31cf203f495e352182a5bbadc3',
'datasets.z180': 'd2fc49c68c77d1da592aff4ab90c0915',
'datasets.z181': '5a4a635b1f3535311c6caecf4ab3ba80',
'datasets.z182': 'df51725daff3edfc377a5f6bc158ec3a',
'datasets.z183': '626add199ec4f263ff278d5392f41c9c',
'datasets.z184': '5069483fd064ee5e8c24a240e6ee7736',
'datasets.z185': '589249e98db0a4ded1d3e4acefd07509',
'datasets.z186': 'e4415c64463ef16bceb9d2e2fa934d71',
'datasets.z187': 'e070cbaaf88a1085964244f6505c713c',
'datasets.z188': '71b38eb51edff8b049a302bacbe344d6',
'datasets.z189': '8fa7a8b58c9e7cb9e86bfd0ca5f6d2ea',
'datasets.z19': '6c34cd0e39a33737983ebf89f6cabf5f',
'datasets.z190': 'daad0ea7c87d0935e014a370c38cc926',
'datasets.z191': 'c355e9ee9d0afa67faa34739b7f7cf79',
'datasets.z192': 'c97d5a784625795cdf3c36c337986afe',
'datasets.z193': '5f3b8425e215798c9e454cdbe586db90',
'datasets.z194': '31b5d74c1cbbabbf58ee470467b40d12',
'datasets.z195': '4c65958343bc2ea1a28e779ee7e5e498',
'datasets.z196': '26ab3664e62c7fd5d0be673c45dd0d93',
'datasets.z197': '32d690086b6e9f05e3ced3a126af870c',
'datasets.z198': '323071827db89626c9f186455fbb38c9',
'datasets.z199': '72475a8500be1ff21407a66f0e2e91a6',
'datasets.z20': '4161d6eda0ca5ed9500f953f789a25b2',
'datasets.z200': '44a863b9c9760cb87a23f1422f242c0d',
'datasets.z201': 'dc38b455fa45e3ef0d5f06397507982e',
'datasets.z202': 'b9ba231b317b008602f9472325b40e65',
'datasets.z203': '36f4afc46258d80be626040956550028',
'datasets.z204': '5c522fbdfe1f9d449c9189173f2ed2b2',
'datasets.z205': '58d39995017eed2c4abcf9fcfd07e695',
'datasets.z206': '2efdfc2abc834f0f0f1cabe10423f865',
'datasets.z207': '04ead7536e5c13936c724f644ab1cb3a',
'datasets.z208': '9e3e0a02a07bebcc7cdb62a0ad047946',
'datasets.z209': 'c1fc44cd8b6f50955c8b3b317155ecb6',
'datasets.z21': '8669b8bb9fa90628d4423c45648868b2',
'datasets.z210': '210df79f8434bdb4e2a7d12c4078d972',
'datasets.z211': 'f078d2d8a14b6c59f58c67865bbc3334',
'datasets.z212': 'ef7d08a6cc39f6cb96b631ca61b440a0',
'datasets.z213': '723057f7619d8820f142944f55f9542b',
'datasets.z214': '2ba38ca8561b51710f660c03f84c0eb1',
'datasets.z215': '92a6e97dfaff295110ddead242ebe932',
'datasets.z216': '05f40901ae70f73b3c099fcdd4ca945e',
'datasets.z217': 'b690ed3e8c6ba9f8bba9154d7e8f7ece',
'datasets.z218': 'e290ca6f5573579df9f3aa7c5158891e',
'datasets.z219': 'a8e4626968f089163179f30066ce732c',
'datasets.z22': '0f90463abdc8f0f81d81249302cf2d09',
'datasets.z220': '3ecd2c0c855505d2957046d784944fce',
'datasets.z221': '6c9c28287fbadcab2ce777ef3134e5d6',
'datasets.z222': '29361a77f05e5e68113fc23e11b54b4d',
'datasets.z223': '9214257b9a87c0037e88709addba8948',
'datasets.z224': '7596a516fb7e308f33a81c5b3c36810a',
'datasets.z225': 'c1dc079f5261a976b1bd7f5c05cd4a02',
'datasets.z226': '5b87815b0ccc5cacec83a399a52874aa',
'datasets.z227': '6bed353dc50263b2c720af663c833bbc',
'datasets.z228': '37ed3574bf978bccd6e2db9be00bae94',
'datasets.z229': '8fd57367808fd77581f998850e5f935a',
'datasets.z23': '90ae2cdcdc1663b80c20e080e5c0e038',
'datasets.z230': '3f5ef3234da0236d2fbfaf7366407d70',
'datasets.z231': 'f67d8320028620c8bdd9a800a78afa27',
'datasets.z232': '5f831f25f8e8557168b38a7a28f8e7f9',
'datasets.z233': '56ee8c4b01825ba7f12340ae8b990db3',
'datasets.z234': '5428e98487b0e077cf9c24dc60599286',
'datasets.z235': '883c0ea97facca4d57d5c9c54922e8be',
'datasets.z236': 'e23fb6f610a3b528d5b310df4e452256',
'datasets.z237': 'bd858e84a47668edc851dab131239ae4',
'datasets.z238': 'db51bed3e3e5c6a40881f22532618533',
'datasets.z239': 'c1be852117739fc63227a503b08a8436',
'datasets.z24': '00dcb2e2a72b15a9aa9a646ecaea0019',
'datasets.z240': 'f4e6da02349b03b4f433b3399dcf8b3c',
'datasets.z241': '4f29898105aaf9f1a753a1c639947c2f',
'datasets.z242': '168a15bd8367f5d5f3e5e8cf4d0da6af',
'datasets.z243': '2930bced33bd1ecabce070fc831567e7',
'datasets.z244': 'aed9c5ec05f57e3fa9e7b224d47fa7b0',
'datasets.z245': '5aa83729fec805c166e48e5ec21530a5',
'datasets.z246': '646ceabfae028d631568930b4056227a',
'datasets.z247': '7381491175c1a63cc04ecac81148925b',
'datasets.z248': '4064df81449c1980d0abaa8c7262b315',
'datasets.z249': '7ae84d1dde2e935d86138d1e7b077df8',
'datasets.z25': '5ae383bcd01d4ce22387680e28833f06',
'datasets.z250': 'c018b41fbc4982a561b07cf0d52137f2',
'datasets.z251': '9c9dc7a889d537fd1e02f4549529a5f1',
'datasets.z252': 'b74df0680e7a62794902186fe1e3fec2',
'datasets.z253': '894cbb618ddffe65ce2ada0f250ad79c',
'datasets.z254': 'd5d8bc590d109c7d592d4df183c495a8',
'datasets.z255': 'faf313a3edd70129c212d3dbd1de5042',
'datasets.z256': '0fefcb84b66518df03a93ae53079409f',
'datasets.z257': '887e8e8abaf09682903b9b1060fb8153',
'datasets.z258': 'a700ae13abb7707032123468fab1bf55',
'datasets.z259': 'cad7c974b832d27d1cfb4ae0e4dc6c3c',
'datasets.z26': '5e0cdf281eeca969a4e0adfc44e11dcb',
'datasets.z260': '5e1c11df40440e4e84354d40efe7940e',
'datasets.z261': '2a75579c238569356855244d9fedf50b',
'datasets.z262': '51d508271fef5558df387542b3561b67',
'datasets.z263': 'e11b54dcd5069e838a34dcc2daebf4b5',
'datasets.z264': '8a5a3b288d21a3c2ef641370d436703e',
'datasets.z265': '56954047cba7c8732f0490323540af43',
'datasets.z266': 'cad0325e494cc720c385ac7420acd2d7',
'datasets.z267': '25de522b499ec7af12f583dd89a31769',
'datasets.z268': '0e9882f392ca679e8c47276371384efd',
'datasets.z269': '28487bc8eb731d4913254a9d63bb13ae',
'datasets.z27': 'c78804dfac1e395156abd235ca416b33',
'datasets.z270': '276fde25d412d8a1197e3dda307580d7',
'datasets.z271': 'ba6f46558aa9ebc64efca2485b4f18ca',
'datasets.z272': '5acceb08940d937d3023d98e745b8197',
'datasets.z273': '57201a53390fe9c6a8c069397dcb81b8',
'datasets.z274': '92491b6786ccea7b6ccccfb4e09c6d75',
'datasets.z275': '349136c52b8554f03967d7083e5cd95c',
'datasets.z276': '1525075dbf4d5d101d63cdade8bed9e7',
'datasets.z277': 'eeb6365ef482cd2c6bac20aab8181081',
'datasets.z278': 'ab1b04860b27fc11b7f57074a0815877',
'datasets.z279': 'dc1cee0d4b69da9fd7aeb47f91768589',
'datasets.z28': '0257b938256a2b7b55637970ebb3edcc',
'datasets.z280': 'd168adf5e7c223a1d8dddfc663ebaeb0',
'datasets.z281': '540fc1de91a90e9bd91b3f2b590ddbf6',
'datasets.z282': 'f3fd6fbe05cfd53eb4a4c2e41bf75cc7',
'datasets.z283': '20e89e60dd6a582bcb98b74df82699c3',
'datasets.z284': '6e6ba6077437285881609999ead45463',
'datasets.z285': 'fdc4bad8adb36b3d6653a438f1fc000f',
'datasets.z286': '39c0f6fc7aace7e30a33e7e73afdb6ae',
'datasets.z287': 'e60818eefd7426f3de0cc0746550be7f',
'datasets.z288': 'd9db9107b8e92c0bf6a311a219363554',
'datasets.z289': 'a250aa672d1f9165981cfaf1c6c8fff6',
'datasets.z29': 'e1240710e1c4dd506aec03c02caf5606',
'datasets.z290': '6dcb5a7674c927ec4e965a42d04a0ccd',
'datasets.z291': 'c058fca515b7f714338816b72672ce20',
'datasets.z292': '3fa254ed46ad6a6f7686836fe6fb7991',
'datasets.z293': '782344c6620582d6b1681e142853a61e',
'datasets.z294': '5263c5eb50cfec20adf82a89973a7547',
'datasets.z295': '0c217819aa7308ce8f744511e572a632',
'datasets.z296': 'cf8e7f1d3503ad6371ebcc9f827a29f8',
'datasets.z297': 'f16861471be342b291e55991654c882b',
'datasets.z298': 'ceb8b1170f1f2ec8113ef1c004df236c',
'datasets.z299': '4f66a50c5dfb03143a6456c7fd925ec1',
'datasets.z30': '75b0444187fc6c3df7bd3182108bc647',
'datasets.z300': 'becbeea47cef192a1a13b35911c4795a',
'datasets.z301': '92a67108c6bde11111f3b94f690f4b42',
'datasets.z302': '06292210e05575cf1632a099648c16af',
'datasets.z303': 'ecb7a48777719322c957ecc99340f04a',
'datasets.z304': 'b318fc0b7d645d16467b78bbce95befa',
'datasets.z305': '5f83e6a17e5977c2b99fc29893a1f479',
'datasets.z306': 'dde2ec22363740081f54032a3add00e0',
'datasets.z307': 'e0b7f9a7eddf2117a6127542883eb767',
'datasets.z308': 'fff180911deaf4b476f42e6e47d78e6e',
'datasets.z309': 'c55a8ac017f7ea69e77519fbcc617301',
'datasets.z31': '04bee9374c7436d66f560bbbfc22299d',
'datasets.z310': '3b4282fd1ab2b4f885df196a83726d34',
'datasets.z311': '08a7c41c5d290750d9ebf6266a86bec8',
'datasets.z312': 'b0f844ffd0ae6785e077ef8871dfc5da',
'datasets.z313': 'bd03ba03a63877b17274e21ddb828218',
'datasets.z314': '392f4528f9c345355434e7448a80b28b',
'datasets.z315': 'e499cdf8acf1561720d4eb8f9ad9daad',
'datasets.z316': '2575496c4d9c5082c6c4ef2f0cedeb69',
'datasets.z317': '6f387284dbef478e02f7a5954da66015',
'datasets.z318': '2ea47b6b7c9790bbfc2d3c238ffba391',
'datasets.z319': '8e1e48b2aa1d6c7ae840a23360a3a8f8',
'datasets.z32': 'ab930ee97741da82bfc778474f528a28',
'datasets.z320': 'a785b37410ccf0366f97c0d221faa629',
'datasets.z321': '16c73037fa0ca7704a3c4ceddbd7d599',
'datasets.z322': 'af97027070cf5d057934dfd6bd819e61',
'datasets.z323': '79078baf8d5fee935620f48aed2980a6',
'datasets.z324': '987b626c8731b0092df593f0eaddb32a',
'datasets.z325': 'fcc9a289015b40044c55ca96bc3dbe1e',
'datasets.z326': 'f2921b782a23f2729e1a6641e8c99954',
'datasets.z327': '6d1b488d88cf0fcaf5105b6544e5ea66',
'datasets.z328': '268885eaeb6290be4ddfaefb34943985',
'datasets.z329': '3951d3dd0527b54c5302720ea037cb12',
'datasets.z33': 'fff4585bf34e5a7ed8f369b337aac901',
'datasets.z330': '4b546d783366442d95eed64f882ae9f6',
'datasets.z331': '4605615511ea901c18256306263b2226',
'datasets.z332': '4127b2a8dd549b02ce3e465cfcfcc0a3',
'datasets.z333': '711d282247b4fd56758c96c13bfe1b8e',
'datasets.z34': '39b5d53f995fa8c223ed0d7a2de34652',
'datasets.z35': 'ae351c1e2961f99a6d3ac37fbae27548',
'datasets.z36': 'c2102c4a984d03c32c7c99378df953dc',
'datasets.z37': '76ba463895984049a1814654b7290890',
'datasets.z38': '5481901e75ee9d55066ba2731b2f36b7',
'datasets.z39': '59815df91b2532d1e300bc71c976da12',
'datasets.z40': '321ebe9b7812c14ee185fe2d6f16300c',
'datasets.z41': '862bc3fcb9fc1df2bef61561bcca8090',
'datasets.z42': '10303d540fea2150a7574cefcec92977',
'datasets.z43': '2269db3212f2d2db86982408a2b24948',
'datasets.z44': '80673bdbd722d02d97febeca00e57cf1',
'datasets.z45': '42329dbb5c5165788902f33265db66a3',
'datasets.z46': 'f81999194d418c515bb5df32c278d7f7',
'datasets.z47': '2e3d3520636b5f3eb0cd2d649b6b4dab',
'datasets.z48': 'e57e7aee104e80deaf068fbdd3292410',
'datasets.z49': '36146291e4af0eff44e763e7e1facb4f',
'datasets.z50': '60ef115c5c621c757b9b7075d5590c20',
'datasets.z51': '02326e4392d176c2aa6d479cf43f29f8',
'datasets.z52': 'b454f81ef7cda24cb13b8176c641d7ef',
'datasets.z53': '95a4857fe7bc6230e6a3cc379085e989',
'datasets.z54': '8ff35d1e9d738eb8b2afde04699ba73f',
'datasets.z55': '75d19ea4a9a283b6d236e3353063d82d',
'datasets.z56': '2ec9121fb364cc76eaf6fe7ff947dd04',
'datasets.z57': 'f5d0b9b2b0a91f82dea784023f71cf3c',
'datasets.z58': '79e318093d97bd573cc7aab4ed68dd6d',
'datasets.z59': '68350ada72cfc22d2d6fc8ca636ffef2',
'datasets.z60': '1ff9d36ec2b3723253503d524c90c9df',
'datasets.z61': '9fb963355e8c895c1a0a02fa45eb6fb2',
'datasets.z62': '5900cdc6c3d5f4b0c05825b47946096a',
'datasets.z63': '9d039ddf86aef563bf7834faa2766e84',
'datasets.z64': '15decde2fb2b4f869684f88f067378c1',
'datasets.z65': '7cf283ebffd79c38bc09835caba483e1',
'datasets.z66': 'e3b65f4facef36a355b444bd6b4ec73b',
'datasets.z67': 'b007466dd2ac3388a5902fdf92979b2d',
'datasets.z68': 'c07cd3fbc9a99e8f4d3cc9bbc55682e0',
'datasets.z69': 'cff012a6a3af8a9f286cf68baf37ac73',
'datasets.z70': '032ad7661e468723242b995357870a05',
'datasets.z71': 'fd7aa78706b8ae7ff50eb2f312a14ea5',
'datasets.z72': 'c2b90b8e8f75e36c927d523138d8ae93',
'datasets.z73': '8c001e26b5f4ca0e061ff0b685d509c6',
'datasets.z74': '484dcaf9311cc4315175e114acf01f32',
'datasets.z75': '19f6896b426b99f34b99c88f3b68209f',
'datasets.z76': '4a64c040ba9aa6e455fee774a08873e5',
'datasets.z77': 'e96b1d94e51741ca916580c5f8287ef5',
'datasets.z78': 'e0f5c9d93a3a5cfd95eeb61ee322650d',
'datasets.z79': '75fbea4391cd38d85c583c0f504325f1',
'datasets.z80': '7a99d15beff97f2dc620300f5ea82506',
'datasets.z81': '1e303b839a0349a8497eb6c49e3ab4c3',
'datasets.z82': '48c23365901b8d9cb51a20f7c71fa0a7',
'datasets.z83': 'fed6445cb47f0b6cd6ed6f5482e38be7',
'datasets.z84': '3d052276fc186ec151b5c237c5774e66',
'datasets.z85': 'd84dccdf4fd4e6f6126b746d0519d19d',
'datasets.z86': '772ab63f1ae266761db99124a152fc29',
'datasets.z87': '44208b663adf97563018a416560cce6b',
'datasets.z88': '6be7aaff3349239507b103c928afeaf5',
'datasets.z89': '515daa16bcaf83b0be91dcc32e0e4985',
'datasets.z90': 'f0fe0aa02bf0f19c0a7e301b532afa4c',
'datasets.z91': '87fa807089240e01886389a6c4c77641',
'datasets.z92': '1fd3154b3d8ec321f58f7685aedc5441',
'datasets.z93': 'b24ddb4b3e2c0c1c96d7650a8320e446',
'datasets.z94': 'b5e98cd90a629c4abb70d66a6976e0c7',
'datasets.z95': '87c8d5d7f26e0fbf7422fda2194bef97',
'datasets.z96': 'bb18c027a71f326a33b0a8fbe2d3a11f',
'datasets.z97': '0489ed1c8b4fb7d5965d1565076d5c6a',
'datasets.z98': 'a85303e08ce59fdf28f36ae9d0f20dcb',
'datasets.z99': '992e63e126f05eca9fa9a84bbf66165c',
'datasets.zip': '3434f60f5e9b263ef78e207b54e9debe',
}
def _download(dst):
dst = os.path.abspath(dst)
files = CHECKSUMS.keys()
fullzip = os.path.join(dst, "datasets.zip")
joinedzip = os.path.join(dst, "joined.zip")
URL_ROOT = "https://data.csail.mit.edu/graphics/demosaicnet"
if not os.path.exists(joinedzip):
log.info("Downloading %d files to %s (this will take a while and needs ~80GB)", len(
files), dst)
os.makedirs(dst, exist_ok=True)
for f in files:
fname = os.path.join(dst, f)
url = os.path.join(URL_ROOT, f)
do_download = True
if os.path.exists(fname):
checksum = md5sum(fname)
if checksum == CHECKSUMS[f]: # File exists and is correct
log.info('%s already downloaded, with correct checksum', f)
do_download = False
else:
log.warning('%s checksums do not match, got %s, should be %s',
f, checksum, CHECKSUMS[f])
try:
os.remove(fname)
except OSError as e:
log.error("Could not delete broken part %s: %s", f, e)
raise ValueError
if do_download:
log.info('Downloading %s', f)
wget.download(url, fname)
checksum = md5sum(fname)
if checksum == CHECKSUMS[f]:
log.info("%s MD5 correct", f)
else:
log.error('%s checksums do not match, got %s, should be %s. Downloading failed',
f, checksum, CHECKSUMS[f])
log.info("Joining zip files")
cmd = " ".join(["zip", "-FF", fullzip, "--out", joinedzip])
subprocess.check_call(cmd, shell=True)
# Cleanup the parts
for f in files:
fname = os.path.join(dst, f)
try:
os.remove(fname)
except OSError as e:
log.warning("Could not delete file %s", f)
# Extract
wd = os.path.abspath(os.curdir)
os.chdir(dst)
log.info("Extracting files from %s", joinedzip)
cmd = " ".join(["unzip", joinedzip])
subprocess.check_call(cmd, shell=True)
os.chdir(wd) # restore the caller's working directory
try:
os.remove(joinedzip)
except OSError as e:
log.warning("Could not delete file %s", joinedzip)
log.info("Moving subfolders")
for k in ["train", "test", "val"]:
shutil.move(os.path.join(dst, "images", k), os.path.join(dst, k))
images = os.path.join(dst, "images")
log.info("removing '%s' folder", images)
shutil.rmtree(images)
def md5sum(filename, blocksize=65536):
hash = hashlib.md5()
with open(filename, "rb") as f:
for block in iter(lambda: f.read(blocksize), b""):
hash.update(block)
return hash.hexdigest()
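`_download` fetches the dataset as several hundred archive parts. It compares each part against `CHECKSUMS` using `md5sum`, which hashes the file in 64 KiB blocks so that a large part never has to be read into memory at once.

Below is a standalone sketch of the same check. It assumes a `{filename: md5}` mapping shaped like `CHECKSUMS`. The helper `verify_parts` is a hypothetical name and not part of the package. It reports which parts would need to be downloaded again.

```python
import hashlib
import os
import tempfile


def md5sum(filename, blocksize=65536):
    """Hash a file incrementally, as demosaicnet.dataset.md5sum does."""
    digest = hashlib.md5()
    with open(filename, "rb") as fid:
        # iter(callable, sentinel) keeps reading until read() returns b"".
        for block in iter(lambda: fid.read(blocksize), b""):
            digest.update(block)
    return digest.hexdigest()


def verify_parts(directory, checksums):
    """Hypothetical helper: names of parts that are missing or corrupted."""
    bad = []
    for name, expected in sorted(checksums.items()):
        path = os.path.join(directory, name)
        if not os.path.exists(path) or md5sum(path) != expected:
            bad.append(name)
    return bad


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        with open(os.path.join(d, "part.z01"), "wb") as fid:
            fid.write(b"demosaicnet")
        good = {"part.z01": hashlib.md5(b"demosaicnet").hexdigest()}
        print(verify_parts(d, good))                     # []
        print(verify_parts(d, {"part.z01": "0" * 32}))  # ['part.z01']
        print(verify_parts(d, {"part.z02": "0" * 32}))  # ['part.z02']
```

The block size only affects memory use. The digest is identical for any `blocksize`, because MD5 is computed over the concatenated stream.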
================================================
FILE: demosaicnet/modules.py
================================================
"""Models for [Gharbi2016] Deep Joint demosaicking and denoising."""
import os
from collections import OrderedDict
from pkg_resources import resource_filename
import numpy as np
import torch as th
import torch.nn as nn
__all__ = ["BayerDemosaick", "XTransDemosaick"]
_BAYER_WEIGHTS = resource_filename(__name__, 'data/bayer.pth')
_XTRANS_WEIGHTS = resource_filename(__name__, 'data/xtrans.pth')
class BayerDemosaick(nn.Module):
"""Released version of the network, best quality.
This model differs from the published description: the last convolution of the
main processor outputs 2*width channels, which are split into filters and masks
that are multiplied element-wise. This split is not key to performance and can
be ignored when training new models from scratch.
"""
def __init__(self, depth=15, width=64, pretrained=True, pad=False):
super(BayerDemosaick, self).__init__()
self.depth = depth
self.width = width
if pad:
pad = 1
else:
pad = 0
layers = OrderedDict([
("pack_mosaic", nn.Conv2d(3, 4, 2, stride=2)), # Downsample 2x2 to re-establish translation invariance
])
for i in range(depth):
n_out = width
n_in = width
if i == 0:
n_in = 4
if i == depth-1:
n_out = 2*width
layers["conv{}".format(i+1)] = nn.Conv2d(n_in, n_out, 3, padding=pad)
layers["relu{}".format(i+1)] = nn.ReLU(inplace=True)
self.main_processor = nn.Sequential(layers)
self.residual_predictor = nn.Conv2d(width, 12, 1)
self.upsampler = nn.ConvTranspose2d(12, 3, 2, stride=2, groups=3)
self.fullres_processor = nn.Sequential(OrderedDict([
("post_conv", nn.Conv2d(6, width, 3, padding=pad)),
("post_relu", nn.ReLU(inplace=True)),
("output", nn.Conv2d(width, 3, 1)),
]))
# Load weights
if pretrained:
assert depth == 15, "pretrained bayer model has depth=15."
assert width == 64, "pretrained bayer model has width=64."
state_dict = th.load(_BAYER_WEIGHTS)
self.load_state_dict(state_dict)
def forward(self, mosaic):
"""Demosaicks a Bayer image.
Args:
mosaic (th.Tensor): input Bayer mosaic of size [b, 3, h, w], with the color channels kept separate (see :func:`bayer`)
Returns:
th.Tensor: the demosaicked image
"""
# 1/4 resolution features
features = self.main_processor(mosaic)
filters, masks = features[:, 0:self.width], features[:, self.width:2*self.width]
filtered = filters * masks
residual = self.residual_predictor(filtered)
# Match mosaic and residual
upsampled = self.upsampler(residual)
cropped = _crop_like(mosaic, upsampled)
packed = th.cat([cropped, upsampled], 1) # skip connection
output = self.fullres_processor(packed)
return output
class XTransDemosaick(nn.Module):
"""Released version of the network.
Unlike :class:`BayerDemosaick`, there is no downsampling here: all layers run at full resolution.
"""
def __init__(self, depth=11, width=64, pretrained=True, pad=False):
super(XTransDemosaick, self).__init__()
self.depth = depth
self.width = width
if pad:
pad = 1
else:
pad = 0
layers = OrderedDict([])
for i in range(depth):
n_in = width
n_out = width
if i == 0:
n_in = 3
layers["conv{}".format(i+1)] = nn.Conv2d(n_in, n_out, 3, padding=pad)
layers["relu{}".format(i+1)] = nn.ReLU(inplace=True)
self.main_processor = nn.Sequential(layers)
self.fullres_processor = nn.Sequential(OrderedDict([
("post_conv", nn.Conv2d(3+width, width, 3, padding=pad)),
("post_relu", nn.ReLU(inplace=True)),
("output", nn.Conv2d(width, 3, 1)),
]))
# Load weights
if pretrained:
assert depth == 11, "pretrained xtrans model has depth=11."
assert width == 64, "pretrained xtrans model has width=64."
state_dict = th.load(_XTRANS_WEIGHTS)
self.load_state_dict(state_dict)
def forward(self, mosaic):
"""Demosaicks an XTrans image.
Args:
mosaic (th.Tensor): input XTrans mosaic of size [b, 3, h, w], with the color channels kept separate (see :func:`xtrans`)
Returns:
th.Tensor: the demosaicked image
"""
features = self.main_processor(mosaic)
cropped = _crop_like(mosaic, features) # Match mosaic and features
packed = th.cat([cropped, features], 1) # skip connection
output = self.fullres_processor(packed)
return output
def _crop_like(src, tgt):
"""Crop a source image to match the spatial dimensions of a target.
Args:
src (th.Tensor or np.ndarray): image to be cropped
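tgt (th.Tensor or np.ndarray): reference image
Returns:
th.Tensor or np.ndarray: `src` cropped to the spatial size of `tgt` (centered; an odd extra pixel is removed at the bottom/right).
"""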
src_sz = np.array(src.shape)
tgt_sz = np.array(tgt.shape)
# Assumes the spatial dimensions are the last two
crop = (src_sz[-2:]-tgt_sz[-2:])
crop_t = crop[0] // 2
crop_b = crop[0] - crop_t
crop_l = crop[1] // 2
crop_r = crop[1] - crop_l
if (np.array([crop_t, crop_b, crop_r, crop_l])> 0).any():
return src[..., crop_t:src_sz[-2]-crop_b, crop_l:src_sz[-1]-crop_r]
else:
return src
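With `pad=False` (the default), every 3x3 convolution trims one pixel from each border. As a result, both models return a smaller image than they receive, and `_crop_like` trims the skip-connection input to match.

The arithmetic below is a sketch derived from the layer definitions above and assumes the default layer stacks. The helper names are hypothetical. It can help when choosing how large an input crop to feed the network.

```python
def bayer_output_size(size, depth=15, pad=False):
    """Output height (or width) of BayerDemosaick for an input of `size` pixels."""
    shrink = 0 if pad else 2  # an unpadded 3x3 conv trims one pixel per border
    s = size // 2             # pack_mosaic: 2x2 conv with stride 2
    s -= depth * shrink       # `depth` 3x3 convs at half resolution per axis
    s *= 2                    # upsampler: 2x2 transposed conv with stride 2
    s -= shrink               # post_conv; the final 1x1 conv keeps the size
    return s


def xtrans_output_size(size, depth=11, pad=False):
    """Output height (or width) of XTransDemosaick for an input of `size` pixels."""
    shrink = 0 if pad else 2
    return size - depth * shrink - shrink  # `depth` convs, then post_conv


def crop_offsets(src_size, tgt_size):
    """(before, after) crop that _crop_like applies along one spatial axis."""
    diff = src_size - tgt_size
    before = diff // 2
    return before, diff - before


if __name__ == "__main__":
    print(bayer_output_size(128))   # 66: the default Bayer model loses 62 px per axis
    print(xtrans_output_size(120))  # 96: the default X-Trans model loses 24 px per axis
    print(crop_offsets(128, 68))    # (30, 30): mosaic crop for the Bayer skip connection
```

By this arithmetic, a 128x128 output from the default Bayer model needs a 190x190 input (128 + 62). The Bayer input should also have even dimensions, because of the stride-2 packing.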
================================================
FILE: demosaicnet/mosaic.py
================================================
"""Utilities to make a mosaic mask and apply it to an image."""
import numpy as np
import torch as th
__all__ = ["bayer", "xtrans"]
def bayer(im, return_mask=False):
"""Bayer mosaic.
The pattern assumed is::
G r
b G
Args:
im (np.array or th.Tensor): image to mosaic. Dimensions are [c, h, w] (or [b, c, h, w]).
return_mask (bool): if true return the binary mosaic mask, instead of the mosaic image.
Returns:
np.array or th.Tensor: mosaicked image (if return_mask==False), or binary mask (if return_mask==True)
"""
numpy = isinstance(im, np.ndarray)
if numpy:
mask = np.ones_like(im)
else:
mask = th.ones_like(im)
# red
mask[..., 0, ::2, 0::2] = 0
mask[..., 0, 1::2, :] = 0
# green
mask[..., 1, ::2, 1::2] = 0
mask[..., 1, 1::2, ::2] = 0
# blue
mask[..., 2, 0::2, :] = 0
mask[..., 2, 1::2, 1::2] = 0
if not numpy: # make it a constant for ONNX conversion
mask = th.from_numpy(mask.cpu().detach().numpy()).to(im.device)
if mask.shape[0] == 1:
mask = mask.squeeze(0) # coreml hack
if return_mask:
return mask
return im*mask
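

# Illustrative sketch (hypothetical helper, not part of the demosaicnet API):
# builds the same "G r / b G" Bayer mask as `bayer` for a [3, h, w] image by
# setting the sampled sites to 1 (instead of zeroing the unsampled ones),
# which makes the 2x2 pattern explicit. Channels are ordered (r, g, b).
import numpy as np


def _sketch_bayer_mask(h, w):
    mask = np.zeros((3, h, w), dtype=np.float32)
    mask[0, 0::2, 1::2] = 1  # red: even rows, odd columns
    mask[1, 0::2, 0::2] = 1  # green: even rows, even columns
    mask[1, 1::2, 1::2] = 1  # green: odd rows, odd columns
    mask[2, 1::2, 0::2] = 1  # blue: odd rows, even columns
    return mask

# Every pixel samples exactly one channel, and half of the pixels are green:
#   m = _sketch_bayer_mask(4, 4)
#   m.sum(axis=0)  -> all ones
#   m[1].mean()    -> 0.5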
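def xtrans_cell(torch=False):
    """Binary mask of a single 6x6 XTrans cell.

    See `xtrans` for the pattern.

    Args:
        torch (bool): if True, return a th.Tensor, otherwise a float32
            np.array.

    Returns:
        np.array or th.Tensor: mask with shape [3, 6, 6]; channels are
        ordered (r, g, b).
    """
    g_pos = [(0, 0), (0, 2), (0, 3), (0, 5),
             (1, 1), (1, 4),
             (2, 0), (2, 2), (2, 3), (2, 5),
             (3, 0), (3, 2), (3, 3), (3, 5),
             (4, 1), (4, 4),
             (5, 0), (5, 2), (5, 3), (5, 5)]
    r_pos = [(0, 4),
             (1, 0), (1, 2),
             (2, 4),
             (3, 1),
             (4, 3), (4, 5),
             (5, 1)]
    b_pos = [(0, 1),
             (1, 3), (1, 5),
             (2, 1),
             (3, 4),
             (4, 0), (4, 2),
             (5, 4)]

    if torch:
        mask = th.zeros(3, 6, 6)
    else:
        mask = np.zeros((3, 6, 6), dtype=np.float32)

    for idx, coord in enumerate([r_pos, g_pos, b_pos]):
        for y, x in coord:
            mask[..., idx, y, x] = 1

    return mask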
def xtrans(im, return_mask=False):
    """XTrans mosaic.

    The pattern assumed is::

        G b G G r G
        r G r b G b
        G b G G r G
        G r G G b G
        b G b r G r
        G r G G b G

    Args:
        im (np.array or th.Tensor): image to mosaic. Dimensions are
            [c, h, w] (or [bs, c, h, w] for a th.Tensor).
        return_mask (bool): if True, return the binary mosaic mask instead
            of the mosaicked image.

    Returns:
        np.array or th.Tensor: the mosaicked image if ``return_mask`` is
        False, or the binary mask if ``return_mask`` is True.
    """
    numpy = isinstance(im, np.ndarray)
    if numpy:
        mask = xtrans_cell(torch=False)
    else:
        mask = xtrans_cell(torch=True).to(im.device)
        if len(im.shape) == 4:
            mask = mask.unsqueeze(0).repeat(im.shape[0], 1, 1, 1)

    h, w = im.shape[-2:]
    h = int(h)
    w = int(w)

    # Number of 6x6 cells needed to cover the image, rounded up
    new_sz = [int(np.ceil(h / 6)), int(np.ceil(w / 6))]
    sz = np.array(mask.shape)
    sz[:-2] = 1
    sz[-2:] = new_sz
    sz = [int(s) for s in sz]

    if numpy:
        mask = np.tile(mask, sz)
    else:
        mask = mask.repeat(*sz)

    # Crop the tiled mask when h or w is not a multiple of 6
    mask = mask[..., :h, :w]

    if return_mask:
        return mask

    return mask*im
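

# Illustrative sketch (hypothetical helper, not part of the demosaicnet API):
# derives the XTrans mask directly from the 6x6 pattern in the `xtrans`
# docstring, tiles it over the image and crops it to [h, w], so it also
# covers sizes that are not multiples of 6 (e.g. a 7x13 image needs a 2x3
# grid of cells, which is then cropped). Channels are ordered (r, g, b),
# as in `xtrans_cell`.
import numpy as np

_SKETCH_XTRANS_PATTERN = [
    "GbGGrG",
    "rGrbGb",
    "GbGGrG",
    "GrGGbG",
    "bGbrGr",
    "GrGGbG",
]


def _sketch_xtrans_mask(h, w):
    channel = {"r": 0, "G": 1, "b": 2}
    cell = np.zeros((3, 6, 6), dtype=np.float32)
    for y, row in enumerate(_SKETCH_XTRANS_PATTERN):
        for x, c in enumerate(row):
            cell[channel[c], y, x] = 1
    # Number of cells needed to cover the image (ceiling division)
    reps = (1, -(-h // 6), -(-w // 6))
    return np.tile(cell, reps)[:, :h, :w]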
================================================
FILE: demosaicnet/utils.py
================================================
"""Helper functions."""
from abc import ABCMeta, abstractmethod
import argparse
import logging
import os
import re
import signal
import time
import numpy as np
import torch as th
from tqdm import tqdm
log = logging.getLogger(__name__)
def crop_like(src, tgt):
    """Crop a source image to match the spatial dimensions of a target.

    The crop is centered; when the size difference is odd, the extra row
    (resp. column) is removed from the bottom (resp. right). No cropping is
    applied along a dimension where ``src`` is smaller than ``tgt``.

    Args:
        src (th.Tensor or np.ndarray): image to be cropped
        tgt (th.Tensor or np.ndarray): reference image

    Returns:
        th.Tensor or np.ndarray: the cropped ``src``.
    """
    src_sz = np.array(src.shape)
    tgt_sz = np.array(tgt.shape)

    # Assumes the spatial dimensions are the last two
    delta = src_sz[-2:] - tgt_sz[-2:]
    crop = np.maximum(delta // 2, 0)  # no negative crop
    crop2 = delta - crop
    if (crop > 0).any() or (crop2 > 0).any():
        # NOTE: convert to ints to enable static slicing in ONNX conversion
        src_sz = [int(x) for x in src_sz]
        crop = [int(x) for x in crop]
        crop2 = [int(x) for x in crop2]
        return src[..., crop[0]:src_sz[-2]-crop2[0],
                   crop[1]:src_sz[-1]-crop2[1]]
    else:
        return src
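

# Illustrative sketch (hypothetical helper, not part of the demosaicnet API):
# the per-dimension slice computed by `crop_like` for one spatial size.
# Negative crops are clamped to 0 at the start, and a negative end-crop makes
# the slice end past the array, so a dimension where `src` is smaller than
# `tgt` is left unchanged.
def _sketch_crop_slice(src_size, tgt_size):
    delta = src_size - tgt_size
    crop = max(delta // 2, 0)  # no negative crop at the start
    crop2 = delta - crop       # negative when src is smaller than tgt
    return slice(crop, src_size - crop2)

# Examples on 1-D lists:
#   list(range(6))[_sketch_crop_slice(6, 3)]  -> [1, 2, 3]
#   list(range(4))[_sketch_crop_slice(4, 7)]  -> [0, 1, 2, 3]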
class ExponentialMovingAverage(object):
"""Keyed tracker that maintains an exponential moving average for each key.
Args:
keys(list of str): keys to track.
alpha(float): exponential smoothing factor (higher = smoother).
"""
def __init__(self, keys, alpha=0.999):
self._is_first_update = {k: True for k in keys}
self._alpha = alpha
self._values = {k: 0 for k in keys}
def __getitem__(self, key):
return self._values[key]
def update(self, key, value):
if value is None:
return
if self._is_first_update[key]:
self._values[key] = value
self._is_first_update[key] = False
else:
self._values[key] = self._values[key] * \
self._alpha + value*(1.0-self._alpha)
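

# Illustrative sketch (hypothetical helper, not part of the demosaicnet API):
# the recurrence implemented by `ExponentialMovingAverage.update` for a single
# key. The first value initializes the average; later values are blended as
# avg = alpha * avg + (1 - alpha) * value, and None values are skipped.
def _sketch_ema_trace(values, alpha=0.999):
    trace = []
    avg = None
    for v in values:
        if v is None:
            continue
        avg = v if avg is None else avg * alpha + v * (1.0 - alpha)
        trace.append(avg)
    return trace

# Example with alpha = 0.5:
#   _sketch_ema_trace([1.0, 3.0, None, 5.0], alpha=0.5)  -> [1.0, 2.0, 3.5]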
class BasicArgumentParser(argparse.ArgumentParser):
"""A basic argument parser with commonly used training options."""
def __init__(self, *args, **kwargs):
super(BasicArgumentParser, self).__init__(*args, **kwargs)
self.add_argument("--data", required=True, help="path to the training data.")
self.add_argument("--val_data", help="path to the validation data.")
self.add_argument("--config", help="path to a config file.")
self.add_argument("--checkpoint_dir", required=True,
help="Output directory where checkpoints are saved")
self.add_argument("--init_from", help="path to a checkpoint from which to try and initialize the weights.")
self.add_argument("--lr", type=float, default=1e-4,
help="Learning rate for the optimizer")
self.add_argument("--bs", type=int, default=4, help="Batch size")
self.add_argument("--num_epochs", type=int,
help="Number of epochs to train for")
self.add_argument("--num_worker_threads", default=4, type=int,
help="Number of threads that load data")
# self.add_argument("--experiment_log",
# help="csv file in which we log our experiments")
self.add_argument("--cuda", action="store_true",
dest="cuda", help="Force GPU")
self.add_argument("--no-cuda", action="store_false",
dest="cuda", help="Force CPU")
self.add_argument("--server", help="Visdom server url")
self.add_argument("--base_url", default="/", help="Visdom base url")
self.add_argument("--env", default="main", help="Visdom environment")
self.add_argument("--port", default=8097, type=int,
help="Visdom server port")
self.add_argument('--debug', dest="debug", action="store_true")
self.set_defaults(cuda=th.cuda.is_available(), debug=False)
class ModelInterface(metaclass=ABCMeta):
"""An adapter to run or train a model."""
def __init__(self):
pass
@abstractmethod
def training_step(self, batch):
"""Training step given a batch of data.
This should implement a forward pass of the model, compute gradients,
take an optimizer step and return useful metrics and tensors for
visualization and training callbacks.
Args:
batch (dict): batch of data provided by a data pipeline.
Returns:
train_step_data (dict): a dictionary of outputs.
"""
return {}
def init_validation(self):
"""Initializes the quantities to be reported during validation.
The default implementation is a no-op
Returns:
data (dict): initialized values
"""
        log.warning("Running a ModelInterface validation initialization that was not overridden: this is a no-op.")
data = {}
return data
def validation_step(self, batch, running_val_data):
        """Updates the running validation data with the current batch's results.
The default implementation is a no-op
Args:
batch (dict): batch of data provided by a data pipeline.
running_val_data (dict): current aggregates of the validation loop.
Returns:
updated_data (dict): new updated value for the running_val_data.
"""
        log.warning("Running a ModelInterface validation step that was not overridden: this is a no-op.")
return {}
def __repr__(self):
return self.__class__.__name__
class Checkpointer(object):
    """Save and restore model and optimizer variables.

    Args:
        root (string): path to the root directory where the files are stored.
        model (torch.nn.Module): model whose parameters are checkpointed.
        meta (dict): a dictionary of training or configuration parameters useful
            to initialize the model upon loading the checkpoint again.
        optimizers (single or list of torch.optim.Optimizer): optimizers whose
            parameters will be checkpointed together with the model.
        schedulers (single or list of torch.optim.lr_scheduler): schedulers
            whose parameters will be checkpointed together with the model.
        prefix (str): unique prefix name in case several models are stored in
            the same folder.
    """

    EXTENSION = ".pth"

    def __init__(self, root, model=None, meta=None, optimizers=None,
                 schedulers=None, prefix=None):
        self.root = root
        self.model = model
        self.meta = meta

        # TODO(mgharbi): verify the prefixes are unique.

        if optimizers is None:
            log.info("No optimizer state will be stored in the "
                     "checkpointer")
        elif not isinstance(optimizers, list):
            # if we have only one optimizer, make it a list
            optimizers = [optimizers]
        # Always set the attribute so `save`/`load` can test it against None
        self.optimizers = optimizers

        if schedulers is not None and not isinstance(schedulers, list):
            schedulers = [schedulers]
        self.schedulers = schedulers

        log.debug(self)

        self.prefix = ""
        if prefix is not None:
            self.prefix = prefix
def __repr__(self):
return "Checkpointer with root at \"{}\"".format(self.root)
def __path(self, path, prefix=None):
if prefix is None:
prefix = ""
return os.path.join(self.root, prefix+os.path.splitext(path)[0] + ".pth")
def save(self, path, extras=None):
"""Save model, metaparams and extras to relative path.
Args:
path (string): relative path to the file being saved (without extension).
extras (dict): extra user-provided information to be saved with the model.
"""
if self.model is None:
model_state = None
else:
log.debug("Saving model state dict")
model_state = self.model.state_dict()
opt_dicts = []
if self.optimizers is not None:
for opt in self.optimizers:
opt_dicts.append(opt.state_dict())
sched_dicts = []
if self.schedulers is not None:
for s in self.schedulers:
sched_dicts.append(s.state_dict())
filename = self.__path(path, prefix=self.prefix)
os.makedirs(self.root, exist_ok=True)
th.save({'model': model_state,
'meta': self.meta,
'extras': extras,
'optimizers': opt_dicts,
'schedulers': sched_dicts,
}, filename)
log.debug("Checkpoint saved to \"{}\"".format(filename))
def try_and_init_from(self, path):
        """Try to initialize the model's weights from an external checkpoint.

        Args:
            path(str): full path to the checkpoint to load model parameters
                from.
        """
log.info("Loading weights from foreign checkpoint {}".format(path))
if not os.path.exists(path):
raise ValueError("Checkpoint {} does not exist".format(path))
chkpt = th.load(path, map_location=th.device("cpu"))
if "model" not in chkpt.keys() or chkpt["model"] is None:
raise ValueError("{} has no model saved".format(path))
mdl = chkpt["model"]
for n, p in self.model.named_parameters():
if n in mdl:
p2 = mdl[n]
if p2.shape != p.shape:
log.warning("Parameter {} ignored, checkpoint size does not match: {}, should be {}".format(n, p2.shape, p.shape))
continue
log.debug("Parameter {} copied".format(n))
p.data.copy_(p2)
else:
log.warning("Parameter {} ignored, not found in source checkpoint.".format(n))
log.info("Weights loaded from foreign checkpoint {}".format(path))
def load(self, path):
"""Loads a checkpoint, updates the model and returns extra data.
Args:
path (string): path to the checkpoint file, relative to the root dir.
Returns:
extras (dict): extra information passed by the user at save time.
meta (dict): metaparameters of the model passed at save time.
"""
filename = self.__path(path, prefix=None)
chkpt = th.load(filename, map_location="cpu") # TODO: check behavior
if self.model is not None and chkpt["model"] is not None:
log.debug("Loading model state dict")
self.model.load_state_dict(chkpt["model"])
if "optimizers" in chkpt.keys():
if self.optimizers is not None and chkpt["optimizers"] is not None:
try:
for opt, state in zip(self.optimizers,
chkpt["optimizers"]):
log.debug("Loading optimizers state dict for %s", opt)
opt.load_state_dict(state)
                except Exception:
# We do not raise an error here, e.g. in case the user simply
# changes optimizer
log.warning("Could not load optimizer state dicts, "
"starting from scratch")
if "schedulers" in chkpt.keys():
if self.schedulers is not None and chkpt["schedulers"] is not None:
try:
for s, state in zip(self.schedulers,
chkpt["schedulers"]):
log.debug("Loading scheduler state dict for %s", s)
s.load_state_dict(state)
                except Exception:
log.warning("Could not load scheduler state dicts, "
"starting from scratch")
log.debug("Loaded checkpoint \"{}\"".format(filename))
return tuple(chkpt[k] for k in ["extras", "meta"])
def load_latest(self):
"""Try to load the most recent checkpoint, skip failing files.
Returns:
extras (dict): extra user-defined information that was saved in the
checkpoint.
meta (dict): metaparameters of the model passed at save time.
"""
all_checkpoints = self.sorted_checkpoints()
extras = None
meta = None
for f in all_checkpoints:
try:
extras, meta = self.load(f)
return extras, meta
except Exception as e:
log.debug(
"Could not load checkpoint \"{}\", moving on ({}).".format(f, e))
log.debug("No checkpoint found to load.")
return extras, meta
def sorted_checkpoints(self):
        """Get list of all checkpoints in root directory, sorted by modification time.

        Returns:
            chkpts (list of str): checkpoint filenames in the root folder,
                most recent first.
        """
        reg = re.compile(r"{}.*{}".format(re.escape(self.prefix),
                                          re.escape(Checkpointer.EXTENSION)))
if not os.path.exists(self.root):
all_checkpoints = []
else:
all_checkpoints = [f for f in os.listdir(
self.root) if reg.match(f)]
mtimes = []
for f in all_checkpoints:
mtimes.append(os.path.getmtime(os.path.join(self.root, f)))
mf = sorted(zip(mtimes, all_checkpoints))
chkpts = [m[1] for m in reversed(mf)]
log.debug("Sorted checkpoints {}".format(chkpts))
return chkpts
def delete(self, path):
        """Delete checkpoint at path.

        Args:
            path(str): filename of the checkpoint to delete, relative to the
                root directory.
        """
        if path in self.sorted_checkpoints():
            os.remove(os.path.join(self.root, path))
        else:
            log.warning("Trying to delete a checkpoint that does not exist.")
@staticmethod
def load_meta(root, prefix=None):
"""Fetch model metadata without touching the saved parameters.
This loads the metadata from the most recent checkpoint in the root
directory.
Args:
root(str): path to the root directory containing the checkpoints
prefix(str): unique prefix for the checkpoint to be loaded (e.g. if
multiple models are saved in the same location)
"""
chkptr = Checkpointer(root, model=None, meta=None, prefix=prefix,
optimizers=[])
log.debug("checkpoints: %s", chkptr.sorted_checkpoints())
_, meta = chkptr.load_latest()
return meta
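

# Illustrative sketch (hypothetical helper, not part of the demosaicnet API):
# the filtering and ordering applied by `Checkpointer.sorted_checkpoints`,
# on in-memory (filename, mtime) pairs instead of the filesystem. `re.match`
# only anchors at the start of the name, so the extension may appear anywhere
# after the prefix.
import re


def _sketch_sorted_checkpoints(entries, prefix="", extension=".pth"):
    reg = re.compile(re.escape(prefix) + ".*" + re.escape(extension))
    matching = [(mtime, name) for name, mtime in entries if reg.match(name)]
    # Most recently modified first
    return [name for _, name in sorted(matching, reverse=True)]

# Example:
#   _sketch_sorted_checkpoints([("epoch_1.pth", 10), ("notes.txt", 30),
#                               ("periodic_a.pth", 15), ("epoch_2.pth", 20)])
#   -> ["epoch_2.pth", "periodic_a.pth", "epoch_1.pth"]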
class Trainer(object):
"""Implements a simple training loop with hooks for callbacks.
Args:
interface (ModelInterface): adapter to run forward and backward
pass on the model being trained.
Attributes:
callbacks (list of Callbacks): hooks that will be called while training
progresses.
"""
def __init__(self, interface):
super(Trainer, self).__init__()
self.callbacks = []
self.interface = interface
log.debug("Creating {}".format(self))
signal.signal(signal.SIGINT, self.interrupt_handler)
self._keep_running = True
def interrupt_handler(self, signo, frame):
"""Stop the training process upon receiving a SIGINT (Ctrl+C)."""
log.debug("interrupting run")
self._keep_running = False
def _stop(self):
# Reset the run flag
self._keep_running = True
self.__training_end()
def add_callback(self, callback):
"""Adds a callback to the list of training hooks.
Args:
            callback(Callback): callback to add.
"""
log.debug("Adding callback {}".format(callback))
# pass an interface reference to the callback
callback.model_interface = self.interface
self.callbacks.append(callback)
def train(self, dataloader, starting_epoch=None, num_epochs=None,
val_dataloader=None):
"""Main training loop. This starts the training procedure.
Args:
dataloader (DataLoader): loader that yields training batches.
starting_epoch (int, optional): index of the epoch we are starting from.
num_epochs (int, optional): max number of epochs to run.
val_dataloader (DataLoader, optional): loader that yields validation
batches
"""
self.__training_start(dataloader)
if starting_epoch is None:
starting_epoch = 0
        log.info("Starting training from epoch %d", starting_epoch)
epoch = starting_epoch
while num_epochs is None or epoch < starting_epoch + num_epochs:
self.__epoch_start(epoch)
for batch_idx, batch in enumerate(dataloader):
if not self._keep_running:
self._stop()
return
self.__batch_start(batch_idx, batch)
train_step_data = self.__training_step(batch)
self.__batch_end(batch, train_step_data)
self.__epoch_end()
# TODO: allow validation at intermediate steps during one epoch
# Validate
if val_dataloader:
with th.no_grad():
running_val_data = self.__validation_start(val_dataloader)
for batch_idx, batch in enumerate(val_dataloader):
if not self._keep_running:
self._stop()
return
self.__val_batch_start(batch_idx, batch)
running_val_data = self.__validation_step(batch, running_val_data)
self.__val_batch_end(batch, running_val_data)
self.__validation_end(running_val_data)
epoch += 1
if not self._keep_running:
self._stop()
return
self._stop()
def __repr__(self):
return "Trainer({}, {} callbacks)".format(
self.interface, len(self.callbacks))
def __training_start(self, dataloader):
for cb in self.callbacks:
cb.training_start(dataloader)
def __training_end(self):
for cb in self.callbacks:
cb.training_end()
def __epoch_start(self, epoch_idx):
for cb in self.callbacks:
cb.epoch_start(epoch_idx)
def __epoch_end(self):
for cb in self.callbacks:
cb.epoch_end()
def __batch_start(self, batch_idx, batch):
for cb in self.callbacks:
cb.batch_start(batch_idx, batch)
def __batch_end(self, batch, train_step_data):
for cb in self.callbacks:
cb.batch_end(batch, train_step_data)
def __val_batch_start(self, batch_idx, batch):
for cb in self.callbacks:
cb.val_batch_start(batch_idx, batch)
def __val_batch_end(self, batch, running_val_data):
for cb in self.callbacks:
cb.val_batch_end(batch, running_val_data)
def __validation_start(self, dataloader):
for cb in self.callbacks:
cb.validation_start(dataloader)
return self.interface.init_validation()
def __validation_end(self, running_val_data):
for cb in self.callbacks:
cb.validation_end(running_val_data)
def __training_step(self, batch):
return self.interface.training_step(batch)
def __validation_step(self, batch, running_val_data):
return self.interface.validation_step(batch, running_val_data)
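

# Illustrative sketch (hypothetical helper, not part of the demosaicnet API):
# the order in which `Trainer.train` invokes callback hooks, with plain lists
# standing in for the dataloaders and epochs counted from 0. SIGINT handling
# and the interface's training/validation steps are omitted.
def _sketch_hook_order(train_batches, val_batches=None, num_epochs=1):
    events = ["training_start"]
    for epoch in range(num_epochs):
        events.append("epoch_start({})".format(epoch))
        for idx, _ in enumerate(train_batches):
            events += ["batch_start({})".format(idx), "batch_end"]
        events.append("epoch_end")
        # Validation runs after each epoch, only if a loader is given
        if val_batches:
            events.append("validation_start")
            for idx, _ in enumerate(val_batches):
                events += ["val_batch_start({})".format(idx), "val_batch_end"]
            events.append("validation_end")
    events.append("training_end")
    return events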
class Callback(object):
"""Base class for all training callbacks.
    Attributes:
        epoch(int): current epoch index.
        batch(int): current batch index.
        val_batch(int): current validation batch index.
        datasize(int): number of batches in the training dataset.
        val_datasize(int): number of batches in the validation dataset.
        model_interface(ModelInterface): parent interface driving the training.
    """
def __repr__(self):
return self.__class__.__name__
def __init__(self):
super(Callback, self).__init__()
self.epoch = 0
self.batch = 0
self.val_batch = 0
self.datasize = 0
self.val_datasize = 0
self.model_interface = None
def training_start(self, dataloader):
"""Hook to execute code when the training begins.
Args:
dataloader(th.utils.data.Dataloader): a data loading class that
provides batches of data for training.
"""
self.datasize = len(dataloader)
def training_end(self):
"""Hook to execute code when the training ends."""
pass
def epoch_start(self, epoch):
"""Hook to execute code when a new epoch starts.
Args:
epoch(int): index of the current epoch.
Note: self.epoch is never incremented. Instead, it should be set by the
caller.
"""
self.epoch = epoch
def epoch_end(self):
"""Hook to execute code when an epoch ends.
NOTE: self.epoch is not incremented. Instead it is set externally in
the `epoch_start` method.
"""
pass
def validation_start(self, dataloader):
"""Hook to execute code when a validation run starts.
Args:
dataloader(th.utils.data.Dataloader): a data loading class that
provides batches of data for evaluation.
"""
self.val_datasize = len(dataloader)
def validation_end(self, val_data):
"""Hook to execute code when a validation run ends."""
pass
    def batch_start(self, batch_idx, batch_data):
        """Hook to execute code when a training step starts.

        Args:
            batch_idx(int): index of the current batch.
            batch_data: a Tensor, tuple or dict with the current batch of data.
        """
        self.batch = batch_idx

    def batch_end(self, batch_data, train_step_data):
        """Hook to execute code when a training step ends.

        Args:
            batch_data: a Tensor, tuple or dict with the current batch of data.
            train_step_data(dict): outputs from the `training_step` of a
                ModelInterface.
        """
        pass

    def val_batch_start(self, batch_idx, batch_data):
        """Hook to execute code when a validation step starts.

        Args:
            batch_idx(int): index of the current validation batch.
            batch_data: a Tensor, tuple or dict with the current batch of data.
        """
        self.val_batch = batch_idx

    def val_batch_end(self, batch_data, running_val_data):
        """Hook to execute code when a validation step ends.

        Args:
            batch_data: a Tensor, tuple or dict with the current batch of data.
            running_val_data(dict): running outputs produced by the
                `validation_step` of a ModelInterface.
        """
        pass
class CheckpointingCallback(Callback):
"""A callback that periodically saves model checkpoints to disk.
Args:
checkpointer (Checkpointer): actual checkpointer responsible for the I/O.
        interval (int, optional): minimum time in seconds between periodic
            checkpoints (within an epoch). No periodic checkpoints are saved
            if this value is None.
max_files (int, optional): maximum number of periodic checkpoints to keep
on disk.
max_epochs (int, optional): maximum number of epoch checkpoints to keep
on disk.
"""
PERIODIC_PREFIX = "periodic_"
EPOCH_PREFIX = "epoch_"
def __init__(self, checkpointer, interval=600,
max_files=5, max_epochs=10):
super(CheckpointingCallback, self).__init__()
self.checkpointer = checkpointer
self.interval = interval
self.max_files = max_files
self.max_epochs = max_epochs
self.last_checkpoint_time = time.time()
def epoch_end(self):
"""Save a checkpoint at the end of each epoch."""
super(CheckpointingCallback, self).epoch_end()
path = "{}{}".format(CheckpointingCallback.EPOCH_PREFIX, self.epoch)
self.checkpointer.save(path, extras={"epoch": self.epoch + 1})
self.__purge_old_files()
def training_end(self):
super(CheckpointingCallback, self).training_end()
self.checkpointer.save("training_end", extras={"epoch": self.epoch + 1})
def batch_end(self, batch_data, train_step_data):
"""Save a periodic checkpoint if requested."""
super(CheckpointingCallback, self).batch_end(
batch_data, train_step_data)
if self.interval is None: # We skip periodic checkpoints
return
now = time.time()
delta = now - self.last_checkpoint_time
if delta < self.interval: # last checkpoint is too recent
return
log.debug("Periodic checkpoint")
self.last_checkpoint_time = now
filename = "{}{}".format(CheckpointingCallback.PERIODIC_PREFIX,
time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime()))
self.checkpointer.save(filename, extras={"epoch": self.epoch})
self.__purge_old_files()
def __purge_old_files(self):
"""Delete checkpoints that are beyond the max to keep."""
chkpts = self.checkpointer.sorted_checkpoints()
p_chkpts = []
e_chkpts = []
for c in chkpts:
if c.startswith(self.checkpointer.prefix + CheckpointingCallback.PERIODIC_PREFIX):
p_chkpts.append(c)
if c.startswith(self.checkpointer.prefix + CheckpointingCallback.EPOCH_PREFIX):
e_chkpts.append(c)
# Delete periodic checkpoints
if self.max_files is not None and len(p_chkpts) > self.max_files:
for c in p_chkpts[self.max_files:]:
log.debug("CheckpointingCallback deleting {}".format(c))
self.checkpointer.delete(c)
# Delete older epochs
if self.max_epochs is not None and len(e_chkpts) > self.max_epochs:
for c in e_chkpts[self.max_epochs:]:
log.debug("CheckpointingCallback deleting (epoch) {}".format(c))
self.checkpointer.delete(c)
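

# Illustrative sketch (hypothetical helper, not part of the demosaicnet API):
# the retention policy of `CheckpointingCallback.__purge_old_files`. Given
# checkpoint names sorted most recent first, it keeps the newest `max_files`
# periodic and `max_epochs` epoch checkpoints and returns the rest, which the
# callback would delete. Other files (e.g. "training_end.pth") are never
# returned.
def _sketch_checkpoints_to_delete(sorted_names, prefix="", max_files=5,
                                  max_epochs=10):
    periodic = [c for c in sorted_names if c.startswith(prefix + "periodic_")]
    epochs = [c for c in sorted_names if c.startswith(prefix + "epoch_")]
    to_delete = []
    if max_files is not None:
        to_delete += periodic[max_files:]
    if max_epochs is not None:
        to_delete += epochs[max_epochs:]
    return to_delete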
class KeyedCallback(Callback):
    """An abstract Callback that performs the same action for all keys in a list.

    The keys (resp. val_keys) are used to access the train_step_data (resp.
    validation data) produced by a ModelInterface.

    Args:
        keys (list of str or None): list of keys whose values will be logged
            during training.
        val_keys (list of str or None): list of keys whose values will be
            logged during validation.
        smoothing (float): exponential moving average factor applied to the
            training keys only (higher = smoother).
    """
def __init__(self, keys=None, val_keys=None, smoothing=0.999):
super(KeyedCallback, self).__init__()
if keys is None and val_keys is None:
log.warning("Logger has no keys, nor val_keys")
if keys is None:
self.keys = []
else:
self.keys = keys
if val_keys is None:
self.val_keys = []
else:
self.val_keys = val_keys
# Only smooth the training keys
self.ema = ExponentialMovingAverage(self.keys, alpha=smoothing)
def batch_end(self, batch_data, train_step_data):
for k in self.keys:
self.ema.update(k, train_step_data[k])
class ProgressBarCallback(KeyedCallback):
"""A progress bar optimization logger.
Args:
label(str): a prefix label to identify the experiment currently
running.
"""
def __init__(self, keys=None, val_keys=None, smoothing=0.99, label=None):
super(ProgressBarCallback, self).__init__(
keys=keys, val_keys=val_keys, smoothing=smoothing)
self.pbar = None
if label is None:
self.label = ""
else:
self.label = label
def training_start(self, dataloader):
super(ProgressBarCallback, self).training_start(dataloader)
print("Training start")
def training_end(self):
super(ProgressBarCallback, self).training_end()
print("Training ends")
def epoch_start(self, epoch):
super(ProgressBarCallback, self).epoch_start(epoch)
        desc = "Epoch {}".format(self.epoch)
        if self.label:  # label defaults to "", so test for emptiness
            desc = "%s | " % self.label + desc
self.pbar = tqdm(total=self.datasize, unit=" batches",
desc=desc)
def epoch_end(self):
super(ProgressBarCallback, self).epoch_end()
self.pbar.close()
self.pbar = None
def validation_start(self, dataloader):
super(ProgressBarCallback, self).validation_start(dataloader)
print("Running validation...")
self.pbar = tqdm(total=len(dataloader), unit=" batches",
desc="Validation {}".format(self.epoch))
def val_batch_end(self, batch, running_val_data):
self.pbar.update(1)
def validation_end(self, val_data):
super(ProgressBarCallback, self).validation_end(val_data)
self.pbar.close()
self.pbar = None
s = " "*ProgressBarCallback.TABSTOPS + "Validation {} | ".format(
self.epoch)
for k in self.val_keys:
s += "{} = {:.2f} ".format(k, val_data[k])
print(s)
def batch_end(self, batch_data, train_step_data):
super(ProgressBarCallback, self).batch_end(batch_data, train_step_data)
d = {}
for k in self.keys:
d[k] = self.ema[k]
self.pbar.update(1)
self.pbar.set_postfix(d)
TABSTOPS = 2
================================================
FILE: demosaicnet/version.py
================================================
__version__ = "0.0.14"
================================================
FILE: docs/.gitignore
================================================
build
================================================
FILE: docs/Makefile
================================================
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line.
SPHINXOPTS =
SPHINXBUILD = sphinx-build
SOURCEDIR = source
BUILDDIR = build
# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
================================================
FILE: docs/source/conf.py
================================================
# -*- coding: utf-8 -*-
#
# Configuration file for the Sphinx documentation builder.
#
# This file does only contain a selection of the most common options. For a
# full list see the documentation:
# http://www.sphinx-doc.org/en/master/config
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
dirname = os.path.dirname
rootdir = dirname(dirname(dirname(os.path.abspath(__file__))))
sys.path.insert(0, rootdir)
# -- Project information -----------------------------------------------------
project = 'demosaicnet'
copyright = '2019, Michael Gharbi'
author = 'Michael Gharbi'
import re
with open(os.path.join(rootdir, "demosaicnet", "version.py")) as fid:
    try:
        __version__, = re.findall('__version__ = "(.*)"', fid.read())
    except ValueError:
        raise ValueError("could not find version number")
# The full version, including alpha/beta/rc tags
release = __version__
# -- General configuration ---------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'

# Third-party modules mocked by autodoc, so the API docs can be built without
# installing them.
autodoc_mock_imports = ["torch", "numpy", "imageio", "torchvision", "wget"]
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.napoleon',
'sphinx.ext.doctest',
'sphinx.ext.todo',
'sphinx.ext.coverage',
'sphinx.ext.mathjax',
'sphinx.ext.viewcode',
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ['ytemplates']
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'
# The master toctree document.
master_doc = 'index'
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = []
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = None
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'alabaster'
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
# html_theme_options = {}
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['ystatic']
# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# The default sidebars (for documents that don't match any pattern) are
# defined by theme itself. Builtin themes are using these templates by
# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
# 'searchbox.html']``.
#
# html_sidebars = {}
# -- Options for HTMLHelp output ---------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = 'demosaicnetdoc'
# -- Options for LaTeX output ------------------------------------------------
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'demosaicnet.tex', 'demosaicnet Documentation',
'Michael Gharbi', 'manual'),
]
# -- Options for manual page output ------------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'demosaicnet', 'demosaicnet Documentation',
[author], 1)
]
# -- Options for Texinfo output ----------------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'demosaicnet', 'demosaicnet Documentation',
author, 'demosaicnet', 'One line description of project.',
'Miscellaneous'),
]
# -- Options for Epub output -------------------------------------------------
# Bibliographic Dublin Core info.
epub_title = project
# The unique identifier of the text. This can be a ISBN number
# or the project homepage.
#
# epub_identifier = ''
# A unique identification for the text.
#
# epub_uid = ''
# A list of files that should not be packed into the epub file.
epub_exclude_files = ['search.html']
# -- Extension configuration -------------------------------------------------
# -- Options for todo extension ----------------------------------------------
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True
================================================
FILE: docs/source/dataset.rst
================================================
Dataset
=======

.. automodule:: demosaicnet.dataset
   :members:
================================================
FILE: docs/source/helpers.rst
================================================
Helpers
=======

.. automodule:: demosaicnet.mosaic
   :members:
================================================
FILE: docs/source/index.rst
================================================
.. demosaicnet documentation master file, created by
   sphinx-quickstart on Thu Mar 14 13:14:16 2019.
   You can adapt this file completely to your liking, but it should at least
   contain the root `toctree` directive.

Welcome to demosaicnet's documentation!
=======================================

.. toctree::
   :maxdepth: 2
   :caption: Contents:

   models
   dataset
   helpers

.. automodule:: demosaicnet
   :members:

Indices and tables
==================

* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
================================================
FILE: docs/source/models.rst
================================================
Models
======

.. automodule:: demosaicnet.modules
   :members:
================================================
FILE: requirements.txt
================================================
imageio==2.31.1
numpy==1.25.0
setuptools==49.2.1
torch==2.0.1
tqdm==4.65.0
wget
================================================
FILE: scripts/demosaicnet_demo.py
================================================
#!/usr/bin/env python
"""Demo script on using demosaicnet for inference."""
import os
from pkg_resources import resource_filename
import argparse

import numpy as np
import torch as th
import imageio

import demosaicnet

_TEST_INPUT = resource_filename("demosaicnet", 'data/test_input.png')


def main(args):
    print("Running demosaicnet demo on {}, outputting to {}".format(args.input, args.output))
    bayer = demosaicnet.BayerDemosaick()
    xtrans = demosaicnet.XTransDemosaick()

    # Load some ground-truth image
    gt = imageio.imread(args.input).astype(np.float32) / 255.0
    gt = np.array(gt)

    h, w, _ = gt.shape

    # Make the image size a multiple of 6 (for xtrans pattern)
    gt = gt[:6*(h//6), :6*(w//6)]

    # Network expects channel first
    gt = np.transpose(gt, [2, 0, 1])
    mosaicked = demosaicnet.bayer(gt)
    xmosaicked = demosaicnet.xtrans(gt)

    # Run the model (expects batch as first dimension)
    mosaicked = th.from_numpy(mosaicked).unsqueeze(0)
    xmosaicked = th.from_numpy(xmosaicked).unsqueeze(0)
    with th.no_grad():  # inference only
        out = bayer(mosaicked).squeeze(0).cpu().numpy()
        out = np.clip(out, 0, 1)

        xout = xtrans(xmosaicked).squeeze(0).cpu().numpy()
        xout = np.clip(xout, 0, 1)
    print("done")

    os.makedirs(args.output, exist_ok=True)
    output = args.output

    # Convert the mosaics back to HxWxC NumPy arrays before writing them
    imageio.imsave(os.path.join(output, "bayer_mosaick.tif"),
                   mosaicked.squeeze(0).permute([1, 2, 0]).numpy())
    imageio.imsave(os.path.join(output, "bayer_result.tif"), np.transpose(out, [1, 2, 0]))
    imageio.imsave(os.path.join(output, "xtrans_mosaick.tif"),
                   xmosaicked.squeeze(0).permute([1, 2, 0]).numpy())
    imageio.imsave(os.path.join(output, "xtrans_result.tif"), np.transpose(xout, [1, 2, 0]))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("output", help="output directory")
    parser.add_argument("--input", default=_TEST_INPUT,
                        help="test input; uses the bundled test image if omitted.")
    args = parser.parse_args()
    main(args)
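The demo relies on `demosaicnet.bayer` to turn the RGB ground truth into a single-sensor mosaic before running the network. As a rough illustration of what such a function does, here is a minimal NumPy sketch that keeps one color sample per pixel in a channel-first `(3, H, W)` array and zeroes the other two, consistent with the demo saving the mosaic as a 3-channel image. The 2x2 layout below (green/red on even rows, blue/green on odd rows) and the name `bayer_mosaic_sketch` are assumptions for illustration; the pattern `demosaicnet.bayer` actually uses may differ.

```python
import numpy as np


def bayer_mosaic_sketch(im):
    """Hypothetical Bayer mosaic: keep one color per pixel, zero the rest.

    im: float array of shape (3, H, W), channel-first, with H and W even.
    Assumed 2x2 cell (may differ from demosaicnet.bayer):
        G R
        B G
    Returns the masked image and the 0/1 mask, both of shape (3, H, W).
    """
    mask = np.zeros_like(im)
    mask[1, 0::2, 0::2] = 1  # green: even rows, even columns
    mask[0, 0::2, 1::2] = 1  # red:   even rows, odd columns
    mask[2, 1::2, 0::2] = 1  # blue:  odd rows, even columns
    mask[1, 1::2, 1::2] = 1  # green: odd rows, odd columns
    return im * mask, mask


rgb = np.random.rand(3, 4, 6).astype(np.float32)
mosaic, mask = bayer_mosaic_sketch(rgb)
```

Because each pixel keeps exactly one channel, `mask.sum(axis=0)` is 1 everywhere, and green is sampled twice as often as red or blue.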
================================================
FILE: scripts/eval.py
================================================
#!/usr/bin/env python
"""Evaluate a demosaicking model."""
import argparse
import logging

import torch as th
from torch.utils.data import DataLoader

import demosaicnet

log = logging.getLogger(__name__)


class PSNR(th.nn.Module):
    def __init__(self):
        super(PSNR, self).__init__()
        self.mse = th.nn.MSELoss()

    def forward(self, out, ref):
        mse = self.mse(out, ref)
        # Epsilon avoids log10(0) on a perfect reconstruction (as in train.py)
        return -10*th.log10(mse + 1e-12)


def main(args):
    """Entrypoint to the evaluation."""
    # Load model parameters from checkpoint, if any
    meta = demosaicnet.utils.Checkpointer.load_meta(args.checkpoint_dir)
    if meta is None:
        log.warning("No checkpoint found at %s, aborting.", args.checkpoint_dir)
        return

    data = demosaicnet.Dataset(args.data, download=False,
                               mode=meta["mode"],
                               subset=demosaicnet.TEST_SUBSET)
    # No shuffling, so the logged image indices are stable across runs
    dataloader = DataLoader(
        data, batch_size=1, num_workers=4, pin_memory=True, shuffle=False)

    if meta["mode"] == demosaicnet.BAYER_MODE:
        model = demosaicnet.BayerDemosaick(depth=meta["depth"],
                                           width=meta["width"],
                                           pretrained=True,
                                           pad=False)
    elif meta["mode"] == demosaicnet.XTRANS_MODE:
        model = demosaicnet.XTransDemosaick(depth=meta["depth"],
                                            width=meta["width"],
                                            pretrained=True,
                                            pad=False)
    else:
        raise ValueError("Unknown mosaic mode: {}".format(meta["mode"]))

    checkpointer = demosaicnet.utils.Checkpointer(args.checkpoint_dir, model, meta=meta)
    checkpointer.load_latest()  # Resume from checkpoint, if any.

    # No need for gradients
    for p in model.parameters():
        p.requires_grad = False

    mse_fn = th.nn.MSELoss()
    psnr_fn = PSNR()

    device = "cpu"
    if th.cuda.is_available():
        device = "cuda"
        log.info("Using CUDA")
    # The inputs are moved to `device` below, so the model must live there too
    model.to(device)
    model.eval()

    count = 0
    mse = 0.0
    psnr = 0.0
    for idx, batch in enumerate(dataloader):
        mosaic = batch[0].to(device)
        target = batch[1].to(device)
        output = model(mosaic)

        # Remove boundaries to match output size
        target = demosaicnet.utils.crop_like(target, output)

        output = th.clamp(output, 0, 1)

        psnr_ = psnr_fn(output, target).item()
        mse_ = mse_fn(output, target).item()

        psnr += psnr_
        mse += mse_
        count += 1

        log.info("Image %04d, PSNR = %.1f dB, MSE = %.5f", idx, psnr_, mse_)

    mse /= count
    psnr /= count

    log.info("-----------------------------------")
    log.info("Average, PSNR = %.1f dB, MSE = %.5f", psnr, mse)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("data", help="root directory for the demosaicnet dataset.")
    parser.add_argument("checkpoint_dir", help="directory with the model checkpoints.")
    args = parser.parse_args()
    # Without an INFO-level handler, the log.info calls above print nothing
    logging.basicConfig(level=logging.INFO)
    main(args)
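`eval.py` scores each image as `-10*log10(MSE)`, which treats 1.0 as the peak signal value (hence the `th.clamp(output, 0, 1)` before scoring), and then reports the mean of the per-image PSNRs. Averaging in decibels is not the same as converting the mean MSE, so the two summaries can disagree. The pure-Python sketch below illustrates that arithmetic with illustrative MSE values (not measured results); `psnr_from_mse` is a hypothetical helper, not part of demosaicnet.

```python
import math


def psnr_from_mse(mse, peak=1.0, eps=1e-12):
    """PSNR in dB for signals whose maximum value is `peak` (1.0 for [0, 1] images)."""
    return 10.0 * math.log10(peak ** 2 / (mse + eps))


# Illustrative per-image MSEs (not measured results).
mses = [1e-2, 1e-4]
per_image = [psnr_from_mse(m) for m in mses]  # roughly 20 dB and 40 dB

# The summary eval.py logs: mean of the per-image dB values.
mean_of_psnr = sum(per_image) / len(per_image)
# An alternative summary: PSNR of the mean MSE.
psnr_of_mean = psnr_from_mse(sum(mses) / len(mses))
```

Here `mean_of_psnr` is about 30 dB while `psnr_of_mean` is about 23 dB: the dB average rewards the easy image more, whereas the MSE average is dominated by the hardest one.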
================================================
FILE: scripts/train.py
================================================
#!/usr/bin/env python
"""Train a demosaicking model."""
import logging

import torch as th
from torch.utils.data import DataLoader

import demosaicnet

log = logging.getLogger(__name__)


class PSNR(th.nn.Module):
    def __init__(self):
        super(PSNR, self).__init__()
        self.mse = th.nn.MSELoss()

    def forward(self, out, ref):
        mse = self.mse(out, ref)
        return -10*th.log10(mse+1e-12)


class DemosaicnetInterface(demosaicnet.utils.ModelInterface):
    """Training and validation interface.

    Args:
      model(th.nn.Module): model to train.
      lr(float): learning rate for the optimizer.
      cuda(bool): if True, train on the GPU; otherwise train on the CPU.
    """

    def __init__(self, model, lr=1e-4, cuda=th.cuda.is_available()):
        self.model = model
        self.device = "cpu"
        if cuda:
            self.device = "cuda"
        self.model.to(self.device)
        self.opt = th.optim.Adam(self.model.parameters(), lr=lr)
        self.loss = th.nn.MSELoss()
        self.psnr = PSNR()

    def training_step(self, batch):
        fwd_data = self.forward(batch)
        bwd_data = self.backward(batch, fwd_data)
        return bwd_data

    def forward(self, batch):
        mosaic = batch[0]
        mosaic = mosaic.to(self.device)
        output = self.model(mosaic)
        return output

    def backward(self, batch, fwd_output):
        target = batch[1].to(self.device)

        # remove boundaries to match output size
        target = demosaicnet.utils.crop_like(target, fwd_output)

        loss = self.loss(fwd_output, target)

        self.opt.zero_grad()
        loss.backward()
        self.opt.step()

        with th.no_grad():
            psnr = self.psnr(th.clamp(fwd_output, 0, 1), target)

        return {"loss": loss.item(), "psnr": psnr.item()}

    def init_validation(self):
        return {"count": 0, "psnr": 0}

    def update_validation(self, batch, fwd_output, running_data):
        target = batch[1].to(self.device)

        # remove boundaries to match output size
        target = demosaicnet.utils.crop_like(target, fwd_output)

        with th.no_grad():
            psnr = self.psnr(th.clamp(fwd_output, 0, 1), target)

        # Weight each batch by its size so the final average is per sample
        n = target.shape[0]
        return {
            "psnr": running_data["psnr"] + psnr.item()*n,
            "count": running_data["count"] + n
        }

    def finalize_validation(self, running_data):
        return {
            "psnr": running_data["psnr"] / running_data["count"]
        }


def main(args):
    """Entrypoint to the training."""

    # Load model parameters from checkpoint, if any
    meta = demosaicnet.utils.Checkpointer.load_meta(args.checkpoint_dir)
    if meta is None:
        log.info("No metadata or checkpoint, "
                 "parsing model parameters from command line.")
        meta = {
            "depth": args.depth,
            "width": args.width,
            "mode": args.mode,
        }

    data = demosaicnet.Dataset(args.data, download=False,
                               mode=meta["mode"],
                               subset=demosaicnet.TRAIN_SUBSET)
    dataloader = DataLoader(
        data, batch_size=args.bs, num_workers=args.num_worker_threads,
        pin_memory=True, shuffle=True)

    val_dataloader = None
    if args.val_data:
        val_data = demosaicnet.Dataset(args.data, download=False,
                                       mode=meta["mode"],
                                       subset=demosaicnet.VAL_SUBSET)
        val_dataloader = DataLoader(
            val_data, batch_size=args.bs, num_workers=1,
            pin_memory=True, shuffle=False)

    if meta["mode"] == demosaicnet.BAYER_MODE:
        model = demosaicnet.BayerDemosaick(depth=meta["depth"],
                                           width=meta["width"],
                                           pretrained=True,
                                           pad=False)
    elif meta["mode"] == demosaicnet.XTRANS_MODE:
        model = demosaicnet.XTransDemosaick(depth=meta["depth"],
                                            width=meta["width"],
                                            pretrained=True,
                                            pad=False)

    checkpointer = demosaicnet.utils.Checkpointer(
        args.checkpoint_dir, model, meta=meta)

    interface = DemosaicnetInterface(model, lr=args.lr, cuda=args.cuda)

    checkpointer.load_latest()  # Resume from checkpoint, if any.

    trainer = demosaicnet.utils.Trainer(interface)

    keys = ["loss", "psnr"]
    val_keys = ["psnr"]

    trainer.add_callback(demosaicnet.utils.ProgressBarCallback(
        keys=keys, val_keys=val_keys))
    trainer.add_callback(demosaicnet.utils.CheckpointingCallback(
        checkpointer, max_files=8, interval=3600, max_epochs=10))

    if args.cuda:
        log.info("Training with CUDA enabled")
    else:
        log.info("Training on CPU")

    trainer.train(
        dataloader, num_epochs=args.num_epochs,
        val_dataloader=val_dataloader)


if __name__ == "__main__":
    parser = demosaicnet.utils.BasicArgumentParser()
    parser.add_argument("--depth", type=int, default=15,
                        help="number of net layers.")
    parser.add_argument("--width", type=int, default=64,
                        help="number of features per layer.")
    parser.add_argument("--mode", default=demosaicnet.BAYER_MODE,
                        choices=[demosaicnet.BAYER_MODE,
                                 demosaicnet.XTRANS_MODE],
                        help="mosaick pattern to train for (Bayer or X-Trans).")
    args = parser.parse_args()
    # Without an INFO-level handler, the log.info calls above print nothing
    logging.basicConfig(level=logging.INFO)
    main(args)
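The models are built with `pad=False`, and both `backward` and `update_validation` call `demosaicnet.utils.crop_like(target, fwd_output)` to "remove boundaries to match output size" before comparing prediction and reference. The NumPy sketch below shows one way to perform such a centered crop on `(N, C, H, W)` arrays. `center_crop_like` is a hypothetical stand-in; how the real `crop_like` splits an odd margin is an assumption here.

```python
import numpy as np


def center_crop_like(src, tgt):
    """Crop the last two (spatial) dims of `src` to match `tgt`, keeping the center.

    Hypothetical stand-in for demosaicnet.utils.crop_like. For an odd margin,
    the extra row/column is removed from the bottom/right.
    """
    src_h, src_w = src.shape[-2:]
    tgt_h, tgt_w = tgt.shape[-2:]
    dh, dw = src_h - tgt_h, src_w - tgt_w
    if dh < 0 or dw < 0:
        raise ValueError("target is larger than source")
    top, left = dh // 2, dw // 2
    return src[..., top:top + tgt_h, left:left + tgt_w]


target = np.arange(2 * 3 * 8 * 8, dtype=np.float32).reshape(2, 3, 8, 8)
output = np.zeros((2, 3, 4, 5), dtype=np.float32)  # e.g. a smaller, unpadded output
cropped = center_crop_like(target, output)
```

Cropping the target, rather than padding the prediction, keeps the loss free of border pixels the network never saw a full receptive field for.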
================================================
FILE: setup.py
================================================
import re
import setuptools


with open('demosaicnet/version.py') as fid:
    try:
        __version__, = re.findall('__version__ = "(.*)"', fid.read())
    except ValueError:  # zero or several matches
        raise ValueError("could not find version number")

# README.md contains non-ASCII characters ("Michaël"), so read it as UTF-8
with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

setuptools.setup(
    name='demosaicnet',
    version=__version__,
    scripts=["scripts/demosaicnet_demo.py"],
    author="Michaël Gharbi",
    author_email="gharbi@csail.mit.edu",
    description="Minimal implementation of Deep Joint Demosaicking and Denoising [Gharbi2016]",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/mgharbi/",
    packages=setuptools.find_packages(exclude=["tests"]),
    include_package_data=True,
    install_requires=["wget", "tqdm", "torch", "imageio", "numpy"],
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: MacOS :: MacOS X",
        "Operating System :: POSIX",
    ],
)
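`setup.py` reads the package version by scanning `demosaicnet/version.py` with a regular expression rather than importing the package; a common reason for this pattern is that importing `demosaicnet` would pull in its dependencies (such as `torch`) before they are installed. A stand-alone sketch of the same idea, operating on a string instead of the file; `read_version` is a hypothetical helper.

```python
import re


def read_version(text):
    """Return the version from a `__version__ = "..."` line, mirroring setup.py."""
    # Tuple unpacking raises ValueError unless there is exactly one match.
    try:
        (version,) = re.findall(r'__version__ = "(.*)"', text)
    except ValueError:
        raise ValueError("could not find version number") from None
    return version


v = read_version('__version__ = "0.0.14"\n')
```

With the contents of `demosaicnet/version.py` shown in this repository, `v` is `"0.0.14"`; a file with no (or more than one) matching line raises `ValueError`.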
SYMBOL INDEX (108 symbols across 7 files)
FILE: demosaicnet/dataset.py
class Dataset (line 38) | class Dataset(TorchDataset):
method __init__ (line 47) | def __init__(self, root, download=False,
method __len__ (line 81) | def __len__(self):
method __getitem__ (line 84) | def __getitem__(self, idx):
function _download (line 441) | def _download(dst):
function md5sum (line 517) | def md5sum(filename, blocksize=65536):
FILE: demosaicnet/modules.py
class BayerDemosaick (line 18) | class BayerDemosaick(nn.Module):
method __init__ (line 26) | def __init__(self, depth=15, width=64, pretrained=True, pad=False):
method forward (line 67) | def forward(self, mosaic):
class XTransDemosaick (line 92) | class XTransDemosaick(nn.Module):
method __init__ (line 98) | def __init__(self, depth=11, width=64, pretrained=True, pad=False):
method forward (line 134) | def forward(self, mosaic):
function _crop_like (line 151) | def _crop_like(src, tgt):
FILE: demosaicnet/mosaic.py
function bayer (line 9) | def bayer(im, return_mask=False):
function xtrans_cell (line 58) | def xtrans_cell(torch=False):
function xtrans (line 89) | def xtrans(im, return_mask=False):
FILE: demosaicnet/utils.py
function crop_like (line 20) | def crop_like(src, tgt):
class ExponentialMovingAverage (line 48) | class ExponentialMovingAverage(object):
method __init__ (line 56) | def __init__(self, keys, alpha=0.999):
method __getitem__ (line 61) | def __getitem__(self, key):
method update (line 64) | def update(self, key, value):
class BasicArgumentParser (line 75) | class BasicArgumentParser(argparse.ArgumentParser):
method __init__ (line 78) | def __init__(self, *args, **kwargs):
class ModelInterface (line 115) | class ModelInterface(metaclass=ABCMeta):
method __init__ (line 118) | def __init__(self):
method training_step (line 122) | def training_step(self, batch):
method init_validation (line 137) | def init_validation(self):
method validation_step (line 149) | def validation_step(self, batch, running_val_data):
method __repr__ (line 164) | def __repr__(self):
class Checkpointer (line 168) | class Checkpointer(object):
method __init__ (line 188) | def __init__(self, root, model=None, meta=None, optimizers=None,
method __repr__ (line 215) | def __repr__(self):
method __path (line 218) | def __path(self, path, prefix=None):
method save (line 223) | def save(self, path, extras=None):
method try_and_init_from (line 257) | def try_and_init_from(self, path):
method load (line 286) | def load(self, path):
method load_latest (line 331) | def load_latest(self):
method sorted_checkpoints (line 353) | def sorted_checkpoints(self):
method delete (line 374) | def delete(self, path):
method load_meta (line 386) | def load_meta(root, prefix=None):
class Trainer (line 404) | class Trainer(object):
method __init__ (line 416) | def __init__(self, interface):
method interrupt_handler (line 426) | def interrupt_handler(self, signo, frame):
method _stop (line 431) | def _stop(self):
method add_callback (line 436) | def add_callback(self, callback):
method train (line 447) | def train(self, dataloader, starting_epoch=None, num_epochs=None,
method __repr__ (line 500) | def __repr__(self):
method __training_start (line 504) | def __training_start(self, dataloader):
method __training_end (line 508) | def __training_end(self):
method __epoch_start (line 512) | def __epoch_start(self, epoch_idx):
method __epoch_end (line 516) | def __epoch_end(self):
method __batch_start (line 520) | def __batch_start(self, batch_idx, batch):
method __batch_end (line 524) | def __batch_end(self, batch, train_step_data):
method __val_batch_start (line 528) | def __val_batch_start(self, batch_idx, batch):
method __val_batch_end (line 532) | def __val_batch_end(self, batch, running_val_data):
method __validation_start (line 536) | def __validation_start(self, dataloader):
method __validation_end (line 541) | def __validation_end(self, running_val_data):
method __training_step (line 545) | def __training_step(self, batch):
method __validation_step (line 548) | def __validation_step(self, batch, running_val_data):
class Callback (line 552) | class Callback(object):
method __repr__ (line 563) | def __repr__(self):
method __init__ (line 566) | def __init__(self):
method training_start (line 575) | def training_start(self, dataloader):
method training_end (line 584) | def training_end(self):
method epoch_start (line 588) | def epoch_start(self, epoch):
method epoch_end (line 599) | def epoch_end(self):
method validation_start (line 607) | def validation_start(self, dataloader):
method validation_end (line 616) | def validation_end(self, val_data):
method batch_start (line 620) | def batch_start(self, batch_idx, batch_data):
method batch_end (line 629) | def batch_end(self, batch_data, train_step_data):
method val_batch_start (line 639) | def val_batch_start(self, batch_idx, batch_data):
method val_batch_end (line 648) | def val_batch_end(self, batch_data, running_val_data):
class CheckpointingCallback (line 658) | class CheckpointingCallback(Callback):
method __init__ (line 675) | def __init__(self, checkpointer, interval=600,
method epoch_end (line 685) | def epoch_end(self):
method training_end (line 692) | def training_end(self):
method batch_end (line 696) | def batch_end(self, batch_data, train_step_data):
method __purge_old_files (line 720) | def __purge_old_files(self):
class KeyedCallback (line 746) | class KeyedCallback(Callback):
method __init__ (line 758) | def __init__(self, keys=None, val_keys=None, smoothing=0.999):
method batch_end (line 776) | def batch_end(self, batch_data, train_step_data):
class ProgressBarCallback (line 780) | class ProgressBarCallback(KeyedCallback):
method __init__ (line 787) | def __init__(self, keys=None, val_keys=None, smoothing=0.99, label=None):
method training_start (line 796) | def training_start(self, dataloader):
method training_end (line 800) | def training_end(self):
method epoch_start (line 804) | def epoch_start(self, epoch):
method epoch_end (line 812) | def epoch_end(self):
method validation_start (line 817) | def validation_start(self, dataloader):
method val_batch_end (line 823) | def val_batch_end(self, batch, running_val_data):
method validation_end (line 826) | def validation_end(self, val_data):
method batch_end (line 836) | def batch_end(self, batch_data, train_step_data):
FILE: scripts/demosaicnet_demo.py
function main (line 16) | def main(args):
FILE: scripts/eval.py
class PSNR (line 14) | class PSNR(th.nn.Module):
method __init__ (line 15) | def __init__(self):
method forward (line 18) | def forward(self, out, ref):
function main (line 22) | def main(args):
FILE: scripts/train.py
class PSNR (line 14) | class PSNR(th.nn.Module):
method __init__ (line 15) | def __init__(self):
method forward (line 18) | def forward(self, out, ref):
class DemosaicnetInterface (line 23) | class DemosaicnetInterface(demosaicnet.utils.ModelInterface):
method __init__ (line 31) | def __init__(self, model, lr=1e-4, cuda=th.cuda.is_available()):
method training_step (line 41) | def training_step(self, batch):
method forward (line 46) | def forward(self, batch):
method backward (line 52) | def backward(self, batch, fwd_output):
method init_validation (line 69) | def init_validation(self):
method update_validation (line 72) | def update_validation(self, batch, fwd_output, running_data):
method finalize_validation (line 87) | def finalize_validation(self, running_data):
function main (line 93) | def main(args):