[
  {
    "path": ".gitignore",
    "content": "\n# Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks,vscode\n# Edit at https://www.toptal.com/developers/gitignore?templates=python,jupyternotebooks,vscode\n\n### JupyterNotebooks ###\n# gitignore template for Jupyter Notebooks\n# website: http://jupyter.org/\n\n.ipynb_checkpoints\n*/.ipynb_checkpoints/*\n\n# IPython\nprofile_default/\nipython_config.py\n\n# Remove previous ipynb_checkpoints\n#   git rm -r .ipynb_checkpoints/\n\n### Python ###\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\npytestdebug.log\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\ndoc/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n\n# IPython\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\npythonenv*\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# pytype static type analyzer\n.pytype/\n\n# profiling data\n.prof\n\n### vscode ###\n.vscode/*\n\n# End of https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks,vscode\n\n.DS_Store\ntest_audio\n\nwandb\nwandb/*\nlightning_logs\nlightning_logs/*\n\n*.ipynb\n*.wav\nneural-timbre-shaping\n*.webm"
  },
  {
    "path": "LICENSE",
    "content": "Mozilla Public License Version 2.0\n==================================\n\n1. Definitions\n--------------\n\n1.1. \"Contributor\"\n    means each individual or legal entity that creates, contributes to\n    the creation of, or owns Covered Software.\n\n1.2. \"Contributor Version\"\n    means the combination of the Contributions of others (if any) used\n    by a Contributor and that particular Contributor's Contribution.\n\n1.3. \"Contribution\"\n    means Covered Software of a particular Contributor.\n\n1.4. \"Covered Software\"\n    means Source Code Form to which the initial Contributor has attached\n    the notice in Exhibit A, the Executable Form of such Source Code\n    Form, and Modifications of such Source Code Form, in each case\n    including portions thereof.\n\n1.5. \"Incompatible With Secondary Licenses\"\n    means\n\n    (a) that the initial Contributor has attached the notice described\n        in Exhibit B to the Covered Software; or\n\n    (b) that the Covered Software was made available under the terms of\n        version 1.1 or earlier of the License, but not also under the\n        terms of a Secondary License.\n\n1.6. \"Executable Form\"\n    means any form of the work other than Source Code Form.\n\n1.7. \"Larger Work\"\n    means a work that combines Covered Software with other material, in\n    a separate file or files, that is not Covered Software.\n\n1.8. \"License\"\n    means this document.\n\n1.9. \"Licensable\"\n    means having the right to grant, to the maximum extent possible,\n    whether at the time of the initial grant or subsequently, any and\n    all of the rights conveyed by this License.\n\n1.10. \"Modifications\"\n    means any of the following:\n\n    (a) any file in Source Code Form that results from an addition to,\n        deletion from, or modification of the contents of Covered\n        Software; or\n\n    (b) any new file in Source Code Form that contains any Covered\n        Software.\n\n1.11. \"Patent Claims\" of a Contributor\n    means any patent claim(s), including without limitation, method,\n    process, and apparatus claims, in any patent Licensable by such\n    Contributor that would be infringed, but for the grant of the\n    License, by the making, using, selling, offering for sale, having\n    made, import, or transfer of either its Contributions or its\n    Contributor Version.\n\n1.12. \"Secondary License\"\n    means either the GNU General Public License, Version 2.0, the GNU\n    Lesser General Public License, Version 2.1, the GNU Affero General\n    Public License, Version 3.0, or any later versions of those\n    licenses.\n\n1.13. \"Source Code Form\"\n    means the form of the work preferred for making modifications.\n\n1.14. \"You\" (or \"Your\")\n    means an individual or a legal entity exercising rights under this\n    License. For legal entities, \"You\" includes any entity that\n    controls, is controlled by, or is under common control with You. For\n    purposes of this definition, \"control\" means (a) the power, direct\n    or indirect, to cause the direction or management of such entity,\n    whether by contract or otherwise, or (b) ownership of more than\n    fifty percent (50%) of the outstanding shares or beneficial\n    ownership of such entity.\n\n2. License Grants and Conditions\n--------------------------------\n\n2.1. 
Grants\n\nEach Contributor hereby grants You a world-wide, royalty-free,\nnon-exclusive license:\n\n(a) under intellectual property rights (other than patent or trademark)\n    Licensable by such Contributor to use, reproduce, make available,\n    modify, display, perform, distribute, and otherwise exploit its\n    Contributions, either on an unmodified basis, with Modifications, or\n    as part of a Larger Work; and\n\n(b) under Patent Claims of such Contributor to make, use, sell, offer\n    for sale, have made, import, and otherwise transfer either its\n    Contributions or its Contributor Version.\n\n2.2. Effective Date\n\nThe licenses granted in Section 2.1 with respect to any Contribution\nbecome effective for each Contribution on the date the Contributor first\ndistributes such Contribution.\n\n2.3. Limitations on Grant Scope\n\nThe licenses granted in this Section 2 are the only rights granted under\nthis License. No additional rights or licenses will be implied from the\ndistribution or licensing of Covered Software under this License.\nNotwithstanding Section 2.1(b) above, no patent license is granted by a\nContributor:\n\n(a) for any code that a Contributor has removed from Covered Software;\n    or\n\n(b) for infringements caused by: (i) Your and any other third party's\n    modifications of Covered Software, or (ii) the combination of its\n    Contributions with other software (except as part of its Contributor\n    Version); or\n\n(c) under Patent Claims infringed by Covered Software in the absence of\n    its Contributions.\n\nThis License does not grant any rights in the trademarks, service marks,\nor logos of any Contributor (except as may be necessary to comply with\nthe notice requirements in Section 3.4).\n\n2.4. Subsequent Licenses\n\nNo Contributor makes additional grants as a result of Your choice to\ndistribute the Covered Software under a subsequent version of this\nLicense (see Section 10.2) or under the terms of a Secondary License (if\npermitted under the terms of Section 3.3).\n\n2.5. Representation\n\nEach Contributor represents that the Contributor believes its\nContributions are its original creation(s) or it has sufficient rights\nto grant the rights to its Contributions conveyed by this License.\n\n2.6. Fair Use\n\nThis License is not intended to limit any rights You have under\napplicable copyright doctrines of fair use, fair dealing, or other\nequivalents.\n\n2.7. Conditions\n\nSections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted\nin Section 2.1.\n\n3. Responsibilities\n-------------------\n\n3.1. Distribution of Source Form\n\nAll distribution of Covered Software in Source Code Form, including any\nModifications that You create or to which You contribute, must be under\nthe terms of this License. You must inform recipients that the Source\nCode Form of the Covered Software is governed by the terms of this\nLicense, and how they can obtain a copy of this License. You may not\nattempt to alter or restrict the recipients' rights in the Source Code\nForm.\n\n3.2. 
Distribution of Executable Form\n\nIf You distribute Covered Software in Executable Form then:\n\n(a) such Covered Software must also be made available in Source Code\n    Form, as described in Section 3.1, and You must inform recipients of\n    the Executable Form how they can obtain a copy of such Source Code\n    Form by reasonable means in a timely manner, at a charge no more\n    than the cost of distribution to the recipient; and\n\n(b) You may distribute such Executable Form under the terms of this\n    License, or sublicense it under different terms, provided that the\n    license for the Executable Form does not attempt to limit or alter\n    the recipients' rights in the Source Code Form under this License.\n\n3.3. Distribution of a Larger Work\n\nYou may create and distribute a Larger Work under terms of Your choice,\nprovided that You also comply with the requirements of this License for\nthe Covered Software. If the Larger Work is a combination of Covered\nSoftware with a work governed by one or more Secondary Licenses, and the\nCovered Software is not Incompatible With Secondary Licenses, this\nLicense permits You to additionally distribute such Covered Software\nunder the terms of such Secondary License(s), so that the recipient of\nthe Larger Work may, at their option, further distribute the Covered\nSoftware under the terms of either this License or such Secondary\nLicense(s).\n\n3.4. Notices\n\nYou may not remove or alter the substance of any license notices\n(including copyright notices, patent notices, disclaimers of warranty,\nor limitations of liability) contained within the Source Code Form of\nthe Covered Software, except that You may alter any license notices to\nthe extent required to remedy known factual inaccuracies.\n\n3.5. Application of Additional Terms\n\nYou may choose to offer, and to charge a fee for, warranty, support,\nindemnity or liability obligations to one or more recipients of Covered\nSoftware. However, You may do so only on Your own behalf, and not on\nbehalf of any Contributor. You must make it absolutely clear that any\nsuch warranty, support, indemnity, or liability obligation is offered by\nYou alone, and You hereby agree to indemnify every Contributor for any\nliability incurred by such Contributor as a result of warranty, support,\nindemnity or liability terms You offer. You may include additional\ndisclaimers of warranty and limitations of liability specific to any\njurisdiction.\n\n4. Inability to Comply Due to Statute or Regulation\n---------------------------------------------------\n\nIf it is impossible for You to comply with any of the terms of this\nLicense with respect to some or all of the Covered Software due to\nstatute, judicial order, or regulation then You must: (a) comply with\nthe terms of this License to the maximum extent possible; and (b)\ndescribe the limitations and the code they affect. Such description must\nbe placed in a text file included with all distributions of the Covered\nSoftware under this License. Except to the extent prohibited by statute\nor regulation, such description must be sufficiently detailed for a\nrecipient of ordinary skill to be able to understand it.\n\n5. Termination\n--------------\n\n5.1. The rights granted under this License will terminate automatically\nif You fail to comply with any of its terms. 
However, if You become\ncompliant, then the rights granted under this License from a particular\nContributor are reinstated (a) provisionally, unless and until such\nContributor explicitly and finally terminates Your grants, and (b) on an\nongoing basis, if such Contributor fails to notify You of the\nnon-compliance by some reasonable means prior to 60 days after You have\ncome back into compliance. Moreover, Your grants from a particular\nContributor are reinstated on an ongoing basis if such Contributor\nnotifies You of the non-compliance by some reasonable means, this is the\nfirst time You have received notice of non-compliance with this License\nfrom such Contributor, and You become compliant prior to 30 days after\nYour receipt of the notice.\n\n5.2. If You initiate litigation against any entity by asserting a patent\ninfringement claim (excluding declaratory judgment actions,\ncounter-claims, and cross-claims) alleging that a Contributor Version\ndirectly or indirectly infringes any patent, then the rights granted to\nYou by any and all Contributors for the Covered Software under Section\n2.1 of this License shall terminate.\n\n5.3. In the event of termination under Sections 5.1 or 5.2 above, all\nend user license agreements (excluding distributors and resellers) which\nhave been validly granted by You or Your distributors under this License\nprior to termination shall survive termination.\n\n************************************************************************\n*                                                                      *\n*  6. Disclaimer of Warranty                                           *\n*  -------------------------                                           *\n*                                                                      *\n*  Covered Software is provided under this License on an \"as is\"       *\n*  basis, without warranty of any kind, either expressed, implied, or  *\n*  statutory, including, without limitation, warranties that the       *\n*  Covered Software is free of defects, merchantable, fit for a        *\n*  particular purpose or non-infringing. The entire risk as to the     *\n*  quality and performance of the Covered Software is with You.        *\n*  Should any Covered Software prove defective in any respect, You     *\n*  (not any Contributor) assume the cost of any necessary servicing,   *\n*  repair, or correction. This disclaimer of warranty constitutes an   *\n*  essential part of this License. No use of any Covered Software is   *\n*  authorized under this License except under this disclaimer.         *\n*                                                                      *\n************************************************************************\n\n************************************************************************\n*                                                                      *\n*  7. 
Limitation of Liability                                          *\n*  --------------------------                                          *\n*                                                                      *\n*  Under no circumstances and under no legal theory, whether tort      *\n*  (including negligence), contract, or otherwise, shall any           *\n*  Contributor, or anyone who distributes Covered Software as          *\n*  permitted above, be liable to You for any direct, indirect,         *\n*  special, incidental, or consequential damages of any character      *\n*  including, without limitation, damages for lost profits, loss of    *\n*  goodwill, work stoppage, computer failure or malfunction, or any    *\n*  and all other commercial damages or losses, even if such party      *\n*  shall have been informed of the possibility of such damages. This   *\n*  limitation of liability shall not apply to liability for death or   *\n*  personal injury resulting from such party's negligence to the       *\n*  extent applicable law prohibits such limitation. Some               *\n*  jurisdictions do not allow the exclusion or limitation of           *\n*  incidental or consequential damages, so this exclusion and          *\n*  limitation may not apply to You.                                    *\n*                                                                      *\n************************************************************************\n\n8. Litigation\n-------------\n\nAny litigation relating to this License may be brought only in the\ncourts of a jurisdiction where the defendant maintains its principal\nplace of business and such litigation shall be governed by laws of that\njurisdiction, without reference to its conflict-of-law provisions.\nNothing in this Section shall prevent a party's ability to bring\ncross-claims or counter-claims.\n\n9. Miscellaneous\n----------------\n\nThis License represents the complete agreement concerning the subject\nmatter hereof. If any provision of this License is held to be\nunenforceable, such provision shall be reformed only to the extent\nnecessary to make it enforceable. Any law or regulation which provides\nthat the language of a contract shall be construed against the drafter\nshall not be used to construe this License against a Contributor.\n\n10. Versions of the License\n---------------------------\n\n10.1. New Versions\n\nMozilla Foundation is the license steward. Except as provided in Section\n10.3, no one other than the license steward has the right to modify or\npublish new versions of this License. Each version will be given a\ndistinguishing version number.\n\n10.2. Effect of New Versions\n\nYou may distribute the Covered Software under the terms of the version\nof the License under which You originally received the Covered Software,\nor under the terms of any subsequent version published by the license\nsteward.\n\n10.3. Modified Versions\n\nIf you create software not governed by this License, and you want to\ncreate a new license for such software, you may create and use a\nmodified version of this License if you rename the license and remove\nany references to the name of the license steward (except to note that\nsuch modified license differs from this License).\n\n10.4. 
Distributing Source Code Form that is Incompatible With Secondary\nLicenses\n\nIf You choose to distribute Source Code Form that is Incompatible With\nSecondary Licenses under the terms of this version of the License, the\nnotice described in Exhibit B of this License must be attached.\n\nExhibit A - Source Code Form License Notice\n-------------------------------------------\n\n  This Source Code Form is subject to the terms of the Mozilla Public\n  License, v. 2.0. If a copy of the MPL was not distributed with this\n  file, You can obtain one at http://mozilla.org/MPL/2.0/.\n\nIf it is not possible or desirable to put the notice in a particular\nfile, then You may include the notice in a location (such as a LICENSE\nfile in a relevant directory) where a recipient would be likely to look\nfor such a notice.\n\nYou may add additional accurate notices of copyright ownership.\n\nExhibit B - \"Incompatible With Secondary Licenses\" Notice\n---------------------------------------------------------\n\n  This Source Code Form is \"Incompatible With Secondary Licenses\", as\n  defined by the Mozilla Public License, v. 2.0.\n"
  },
  {
    "path": "README.md",
    "content": "<h1 align=\"center\">neural waveshaping synthesis</h1>\n<h4 align=\"center\">real-time neural audio synthesis in the waveform domain</h4>\n<div align=\"center\">\n<h4>\n    <a href=\"https://benhayes.net/assets/pdf/nws_arxiv.pdf\" target=\"_blank\">paper</a> •\n        <a href=\"https://benhayes.net/projects/nws/\" target=\"_blank\">website</a> • \n        <a href=\"https://colab.research.google.com/github/ben-hayes/neural-waveshaping-synthesis/blob/main/colab/NEWT_Timbre_Transfer.ipynb\" target=\"_blank\">colab</a> • \n        <a href=\"https://benhayes.net/projects/nws/#audio-examples\">audio</a>\n    </h4>\n    <p>\n    by <em>Ben Hayes, Charalampos Saitis, György Fazekas</em>\n    </p>\n</div>\n<p align=\"center\"><img src=\"https://benhayes.net/assets/img/newt_shapers.png\" /></p>\n\nThis repository is the official implementation of [Neural Waveshaping Synthesis](https://benhayes.net/projects/nws/).\n\n## Model Architecture\n\n<p align=\"center\"><img src=\"https://benhayes.net/assets/img/nws.png\" /></p>\n\n## Requirements\n\nTo install:\n\n```setup\npip install -r requirements.txt\npip install -e .\n```\n\nWe recommend installing in a virtual environment.\n\n## Data\n\nWe trained our checkpoints on the [URMP](http://www2.ece.rochester.edu/projects/air/projects/URMP.html) dataset.\nOnce downloaded, the dataset can be preprocessed using `scripts/create_urmp_dataset.py`. \nThis will consolidate recordings of each instrument within the dataset and preprocess them according to the pipeline in the paper.\n\n```bash\npython scripts/create_urmp_dataset.py \\\n  --gin-file gin/data/urmp_4second_crepe.gin \\ \n  --data-directory /path/to/urmp \\\n  --output-directory /path/to/output \\\n  --device cuda:0  # torch device string for CREPE model\n```\n\nAlternatively, you can supply your own dataset and use the general `create_dataset.py` script:\n\n```bash\npython scripts/create_dataset.py \\\n  --gin-file gin/data/urmp_4second_crepe.gin \\ \n  --data-directory /path/to/dataset \\\n  --output-directory /path/to/output \\\n  --device cuda:0  # torch device string for CREPE model\n```\n\n## Training\n\nTo train a model on the URMP dataset, use this command:\n\n```bash\npython scripts/train.py \\\n  --gin-file gin/train/train_newt.gin \\\n  --dataset-path /path/to/processed/urmp \\\n  --urmp \\\n  --instrument vn \\  # select URMP instrument with abbreviated string\n  --load-data-to-memory\n```\n\nOr to use a non-URMP dataset:\n```bash\npython scripts/train.py \\\n  --gin-file gin/train/train_newt.gin \\\n  --dataset-path /path/to/processed/data \\\n  --load-data-to-memory\n```\n"
  },
  {
    "path": "gin/data/urmp_4second_crepe.gin",
    "content": "sample_rate = 16000\ninterpolation = None\ncontrol_hop = 128\n\nextract_f0_with_crepe.sample_rate = %sample_rate\nextract_f0_with_crepe.device = %device\nextract_f0_with_crepe.full_model = True\nextract_f0_with_crepe.interpolate_fn = %interpolation\nextract_f0_with_crepe.hop_length = %control_hop\n\nextract_perceptual_loudness.sample_rate = %sample_rate\nextract_perceptual_loudness.interpolate_fn = %interpolation\nextract_perceptual_loudness.n_fft = 1024\nextract_perceptual_loudness.hop_length = %control_hop\n\nextract_mfcc.sample_rate = %sample_rate\nextract_mfcc.n_fft = 1024\nextract_mfcc.hop_length = 128\nextract_mfcc.n_mfcc = 16\n\npreprocess_audio.target_sr = %sample_rate\npreprocess_audio.f0_extractor = @extract_f0_with_crepe\npreprocess_audio.loudness_extractor = @extract_perceptual_loudness\npreprocess_audio.segment_length_in_seconds = 4\npreprocess_audio.hop_length_in_seconds = 4\npreprocess_audio.normalise_audio = True\npreprocess_audio.control_decimation_factor = %control_hop"
  },
  {
    "path": "gin/models/newt.gin",
    "content": "sample_rate = 16000\n\ncontrol_embedding_size = 128\nn_waveshapers = 64\ncontrol_hop = 128\n\nHarmonicOscillator.n_harmonics = 101\nHarmonicOscillator.sample_rate = %sample_rate\n\nNEWT.n_waveshapers = %n_waveshapers\nNEWT.control_embedding_size = %control_embedding_size\nNEWT.shaping_fn_size = 8\nNEWT.out_channels = 1\nTrainableNonlinearity.depth = 4\n\nControlModule.control_size = 2\nControlModule.hidden_size = 128\nControlModule.embedding_size = %control_embedding_size\n\nnoise_synth/TimeDistributedMLP.in_size = %control_embedding_size\nnoise_synth/TimeDistributedMLP.hidden_size = %control_embedding_size\nnoise_synth/TimeDistributedMLP.out_size = 129\nnoise_synth/TimeDistributedMLP.depth = 4\nnoise_synth/FIRNoiseSynth.ir_length = 256\nnoise_synth/FIRNoiseSynth.hop_length = %control_hop\n\nReverb.length_in_seconds = 2\nReverb.sr = %sample_rate\n\n\nNeuralWaveshaping.n_waveshapers = %n_waveshapers\nNeuralWaveshaping.control_hop = %control_hop\nNeuralWaveshaping.sample_rate = %sample_rate"
  },
  {
    "path": "gin/train/train_newt.gin",
    "content": "get_model.model = @NeuralWaveshaping\n\ninclude 'gin/models/newt.gin'\n\nURMPDataModule.batch_size = 8\n\nNeuralWaveshaping.learning_rate = 0.001\nNeuralWaveshaping.lr_decay = 0.9\nNeuralWaveshaping.lr_decay_interval = 10000\n\ntrainer_kwargs.max_steps = 120000\ntrainer_kwargs.gradient_clip_val = 2.0\ntrainer_kwargs.accelerator = 'dp'"
  },
  {
    "path": "neural_waveshaping_synthesis/__init__.py",
    "content": ""
  },
  {
    "path": "neural_waveshaping_synthesis/data/__init__.py",
    "content": ""
  },
  {
    "path": "neural_waveshaping_synthesis/data/general.py",
    "content": "import os\n\nimport gin\nimport numpy as np\nimport pytorch_lightning as pl\nimport torch\n\n\nclass GeneralDataset(torch.utils.data.Dataset):\n    def __init__(self, path: str, split: str = \"train\", load_to_memory: bool = True):\n        super().__init__()\n        # split = \"train\"\n        self.load_to_memory = load_to_memory\n\n        self.split_path = os.path.join(path, split)\n        self.data_list = [\n            f.replace(\"audio_\", \"\")\n            for f in os.listdir(os.path.join(self.split_path, \"audio\"))\n            if f[-4:] == \".npy\"\n        ]\n        if load_to_memory:\n            self.audio = [\n                np.load(os.path.join(self.split_path, \"audio\", \"audio_%s\" % name))\n                for name in self.data_list\n            ]\n            self.control = [\n                np.load(os.path.join(self.split_path, \"control\", \"control_%s\" % name))\n                for name in self.data_list\n            ]\n\n        self.data_mean = np.load(os.path.join(path, \"data_mean.npy\"))\n        self.data_std = np.load(os.path.join(path, \"data_std.npy\"))\n\n    def __len__(self):\n        return len(self.data_list)\n\n    def __getitem__(self, idx):\n        # idx = 10\n        name = self.data_list[idx]\n        if self.load_to_memory:\n            audio = self.audio[idx]\n            control = self.control[idx]\n        else:\n            audio_name = \"audio_%s\" % name\n            control_name = \"control_%s\" % name\n\n            audio = np.load(os.path.join(self.split_path, \"audio\", audio_name))\n            control = np.load(os.path.join(self.split_path, \"control\", control_name))\n        denormalised_control = (control * self.data_std) + self.data_mean\n\n        return {\n            \"audio\": audio,\n            \"f0\": denormalised_control[0:1, :],\n            \"amp\": denormalised_control[1:2, :],\n            \"control\": control,\n            \"name\": os.path.splitext(os.path.basename(name))[0],\n        }\n\n\n@gin.configurable\nclass GeneralDataModule(pl.LightningDataModule):\n    def __init__(\n        self,\n        data_root: str,\n        batch_size: int = 16,\n        load_to_memory: bool = True,\n        **dataloader_args\n    ):\n        super().__init__()\n        self.data_dir = data_root\n        self.batch_size = batch_size\n        self.dataloader_args = dataloader_args\n        self.load_to_memory = load_to_memory\n\n    def prepare_data(self):\n        pass\n\n    def setup(self, stage: str = None):\n        if stage == \"fit\":\n            self.urmp_train = GeneralDataset(self.data_dir, \"train\", self.load_to_memory)\n            self.urmp_val = GeneralDataset(self.data_dir, \"val\", self.load_to_memory)\n        elif stage == \"test\" or stage is None:\n            self.urmp_test = GeneralDataset(self.data_dir, \"test\", self.load_to_memory)\n\n    def _make_dataloader(self, dataset):\n        return torch.utils.data.DataLoader(\n            dataset, self.batch_size, **self.dataloader_args\n        )\n\n    def train_dataloader(self):\n        return self._make_dataloader(self.urmp_train)\n\n    def val_dataloader(self):\n        return self._make_dataloader(self.urmp_val)\n\n    def test_dataloader(self):\n        return self._make_dataloader(self.urmp_test)\n"
  },
  {
    "path": "neural_waveshaping_synthesis/data/urmp.py",
    "content": "import os\n\nimport gin\n\nfrom .general import GeneralDataModule\n\n\n@gin.configurable\nclass URMPDataModule(GeneralDataModule):\n    def __init__(\n        self,\n        urmp_root: str,\n        instrument: str,\n        batch_size: int = 16,\n        load_to_memory: bool = True,\n        **dataloader_args\n    ):\n        super().__init__(\n            os.path.join(urmp_root, instrument),\n            batch_size,\n            load_to_memory,\n            **dataloader_args\n        )\n"
  },
  {
    "path": "neural_waveshaping_synthesis/data/utils/__init__.py",
    "content": ""
  },
  {
    "path": "neural_waveshaping_synthesis/data/utils/create_dataset.py",
    "content": "import os\nimport shutil\nfrom typing import Sequence\n\nimport gin\nimport numpy as np\nfrom sklearn.model_selection import train_test_split\n\nfrom .preprocess_audio import preprocess_audio\nfrom ...utils import seed_all\n\n\ndef create_directory(path):\n    if not os.path.isdir(path):\n        try:\n            os.mkdir(path)\n        except OSError:\n            print(\"Failed to create directory %s\" % path)\n        else:\n            print(\"Created directory %s...\" % path)\n    else:\n        print(\"Directory %s already exists. Skipping...\" % path)\n\n\ndef create_directories(target_root, names):\n    create_directory(target_root)\n    for name in names:\n        create_directory(os.path.join(target_root, name))\n\n\ndef make_splits(\n    audio_list: Sequence[str],\n    control_list: Sequence[str],\n    splits: Sequence[str],\n    split_proportions: Sequence[float],\n):\n    assert len(splits) == len(\n        split_proportions\n    ), \"Length of splits and split_proportions must be equal\"\n\n    train_size = split_proportions[0] / np.sum(split_proportions)\n    audio_0, audio_1, control_0, control_1 = train_test_split(\n        audio_list, control_list, train_size=train_size\n    )\n    if len(splits) == 2:\n        return {\n            splits[0]: {\n                \"audio\": audio_0,\n                \"control\": control_0,\n            },\n            splits[1]: {\n                \"audio\": audio_1,\n                \"control\": control_1,\n            },\n        }\n    elif len(splits) > 2:\n        return {\n            splits[0]: {\n                \"audio\": audio_0,\n                \"control\": control_0,\n            },\n            **make_splits(audio_1, control_1, splits[1:], split_proportions[1:]),\n        }\n    elif len(splits) == 1:\n        return {\n            splits[0]: {\n                \"audio\": audio_list,\n                \"control\": control_list,\n            }\n        }\n\n\ndef lazy_create_dataset(\n    files: Sequence[str],\n    output_directory: str,\n    splits: Sequence[str],\n    split_proportions: Sequence[float],\n):\n    audio_files = []\n    control_files = []\n    audio_max = 1e-5\n    means = []\n    stds = []\n    lengths = []\n    control_mean = 0\n    control_std = 1\n\n    for i, (all_audio, all_f0, all_confidence, all_loudness, all_mfcc) in enumerate(\n        preprocess_audio(files)\n    ):\n        file = os.path.split(files[i])[-1].replace(\".wav\", \"\")\n        for j, (audio, f0, confidence, loudness, mfcc) in enumerate(\n            zip(all_audio, all_f0, all_confidence, all_loudness, all_mfcc)\n        ):\n            audio_file_name = \"audio_%s_%d.npy\" % (file, j)\n            control_file_name = \"control_%s_%d.npy\" % (file, j)\n\n            max_sample = np.abs(audio).max()\n            if max_sample > audio_max:\n                audio_max = max_sample\n\n            np.save(\n                os.path.join(output_directory, \"temp\", \"audio\", audio_file_name),\n                audio,\n            )\n            control = np.stack((f0, loudness, confidence), axis=0)\n            control = np.concatenate((control, mfcc), axis=0)\n            np.save(\n                os.path.join(output_directory, \"temp\", \"control\", control_file_name),\n                control,\n            )\n\n            audio_files.append(audio_file_name)\n            control_files.append(control_file_name)\n\n            means.append(control.mean(axis=-1))\n            stds.append(control.std(axis=-1))\n            
lengths.append(control.shape[-1])\n\n    if len(audio_files) == 0:\n        print(\"No datapoints to split. Skipping...\")\n        return\n\n    data_mean = np.mean(np.stack(means, axis=-1), axis=-1)[:, np.newaxis]\n    lengths = np.stack(lengths)[np.newaxis, :]\n    stds = np.stack(stds, axis=-1)\n    data_std = np.sqrt(np.sum(lengths * stds ** 2, axis=-1) / np.sum(lengths))[\n        :, np.newaxis\n    ]\n\n    print(\"Saving dataset stats...\")\n    np.save(os.path.join(output_directory, \"data_mean.npy\"), data_mean)\n    np.save(os.path.join(output_directory, \"data_std.npy\"), data_std)\n\n    splits = make_splits(audio_files, control_files, splits, split_proportions)\n    for split in splits:\n        for audio_file in splits[split][\"audio\"]:\n            audio = np.load(os.path.join(output_directory, \"temp\", \"audio\", audio_file))\n            audio = audio / audio_max\n            np.save(os.path.join(output_directory, split, \"audio\", audio_file), audio)\n        for control_file in splits[split][\"control\"]:\n            control = np.load(\n                os.path.join(output_directory, \"temp\", \"control\", control_file)\n            )\n            control = (control - data_mean) / data_std\n            np.save(\n                os.path.join(output_directory, split, \"control\", control_file), control\n            )\n\n\n@gin.configurable\ndef create_dataset(\n    files: Sequence[str],\n    output_directory: str,\n    splits: Sequence[str] = (\"train\", \"val\", \"test\"),\n    split_proportions: Sequence[float] = (0.8, 0.1, 0.1),\n    lazy: bool = True,\n):\n    create_directories(output_directory, (*splits, \"temp\"))\n    for split in (*splits, \"temp\"):\n        create_directories(os.path.join(output_directory, split), (\"audio\", \"control\"))\n\n    if lazy:\n        lazy_create_dataset(files, output_directory, splits, split_proportions)\n\n    shutil.rmtree(os.path.join(output_directory, \"temp\"))"
  },
  {
    "path": "neural_waveshaping_synthesis/data/utils/f0_extraction.py",
    "content": "from functools import partial\nfrom typing import Callable, Optional, Sequence, Union\n\nimport gin\nimport librosa\nimport numpy as np\nimport torch\nimport torchcrepe\n\nfrom .upsampling import linear_interpolation\nfrom ...utils import apply\n\n\nCREPE_WINDOW_LENGTH = 1024\n\n@gin.configurable\ndef extract_f0_with_crepe(\n    audio: np.ndarray,\n    sample_rate: float,\n    hop_length: int = 128,\n    minimum_frequency: float = 50.0,\n    maximum_frequency: float = 2000.0,\n    full_model: bool = True,\n    batch_size: int = 2048,\n    device: Union[str, torch.device] = \"cpu\",\n    interpolate_fn: Optional[Callable] = linear_interpolation,\n):\n    # convert to torch tensor with channel dimension (necessary for CREPE)\n    audio = torch.tensor(audio).unsqueeze(0)\n    f0, confidence = torchcrepe.predict(\n        audio,\n        sample_rate,\n        hop_length,\n        minimum_frequency,\n        maximum_frequency,\n        \"full\" if full_model else \"tiny\",\n        batch_size=batch_size,\n        device=device,\n        decoder=torchcrepe.decode.viterbi,\n        # decoder=torchcrepe.decode.weighted_argmax,\n        return_harmonicity=True,\n    )\n\n    f0, confidence = f0.squeeze().numpy(), confidence.squeeze().numpy()\n\n    if interpolate_fn:\n        f0 = interpolate_fn(\n            f0, CREPE_WINDOW_LENGTH, hop_length, original_length=audio.shape[-1]\n        )\n        confidence = interpolate_fn(\n            confidence,\n            CREPE_WINDOW_LENGTH,\n            hop_length,\n            original_length=audio.shape[-1],\n        )\n\n    return f0, confidence\n\n\n@gin.configurable\ndef extract_f0_with_pyin(\n    audio: np.ndarray,\n    sample_rate: float,\n    minimum_frequency: float = 65.0,  # recommended minimum freq from librosa docs\n    maximum_frequency: float = 2093.0,  # recommended maximum freq from librosa docs\n    frame_length: int = 1024,\n    hop_length: int = 128,\n    fill_na: Optional[float] = None,\n    interpolate_fn: Optional[Callable] = linear_interpolation,\n):\n    f0, _, voiced_prob = librosa.pyin(\n        audio,\n        sr=sample_rate,\n        fmin=minimum_frequency,\n        fmax=maximum_frequency,\n        frame_length=frame_length,\n        hop_length=hop_length,\n        fill_na=fill_na,\n    )\n\n    if interpolate_fn:\n        f0 = interpolate_fn(\n            f0, frame_length, hop_length, original_length=audio.shape[-1]\n        )\n        voiced_prob = interpolate_fn(\n            voiced_prob,\n            frame_length,\n            hop_length,\n            original_length=audio.shape[-1],\n        )\n\n    return f0, voiced_prob\n"
  },
  {
    "path": "neural_waveshaping_synthesis/data/utils/loudness_extraction.py",
    "content": "from typing import Callable, Optional\nimport warnings\n\nimport gin\nimport librosa\nimport numpy as np\n\nfrom .upsampling import linear_interpolation\n\n\ndef compute_power_spectrogram(\n    audio: np.ndarray,\n    n_fft: int,\n    hop_length: int,\n    window: str,\n    epsilon: float,\n):\n    spectrogram = librosa.stft(audio, n_fft=n_fft, hop_length=hop_length, window=window)\n    magnitude_spectrogram = np.abs(spectrogram)\n    power_spectrogram = librosa.amplitude_to_db(\n        magnitude_spectrogram, ref=np.max, amin=epsilon\n    )\n    return power_spectrogram\n\n\ndef perform_perceptual_weighting(\n    power_spectrogram_in_db: np.ndarray, sample_rate: float, n_fft: int\n):\n    centre_frequencies = librosa.fft_frequencies(sample_rate, n_fft)\n\n    # We know that we will get a log(0) warning here due to the DC component -- we can\n    # safely ignore as it is clipped to the default min dB value of -80.0 dB\n    with warnings.catch_warnings():\n        warnings.simplefilter(\"ignore\")\n        weights = librosa.A_weighting(centre_frequencies)\n\n    weights = np.expand_dims(weights, axis=1)\n    weighted_spectrogram = power_spectrogram_in_db  # + weights\n    return weighted_spectrogram\n\n\n@gin.configurable\ndef extract_perceptual_loudness(\n    audio: np.ndarray,\n    sample_rate: float = 16000,\n    n_fft: int = 2048,\n    hop_length: int = 512,\n    window: str = \"hann\",\n    epsilon: float = 1e-5,\n    interpolate_fn: Optional[Callable] = linear_interpolation,\n    normalise: bool = True,\n):\n    power_spectrogram = compute_power_spectrogram(\n        audio, n_fft=n_fft, hop_length=hop_length, window=window, epsilon=epsilon\n    )\n    perceptually_weighted_spectrogram = perform_perceptual_weighting(\n        power_spectrogram, sample_rate=sample_rate, n_fft=n_fft\n    )\n    loudness = np.mean(perceptually_weighted_spectrogram, axis=0)\n    if interpolate_fn:\n        loudness = interpolate_fn(\n            loudness, n_fft, hop_length, original_length=audio.size\n        )\n\n    if normalise:\n        loudness = (loudness + 80) / 80\n\n    return loudness\n\n\n@gin.configurable\ndef extract_rms(\n    audio: np.ndarray,\n    window_size: int = 2048,\n    hop_length: int = 512,\n    sample_rate: Optional[float] = 16000.0,\n    interpolate_fn: Optional[Callable] = linear_interpolation,\n):\n    # pad audio to centre frames\n    padded_audio = np.pad(audio, (window_size // 2, window_size // 2))\n    frames = librosa.util.frame(padded_audio, window_size, hop_length)\n    squared = frames ** 2\n    mean = np.mean(squared, axis=0)\n    root = np.sqrt(mean)\n    if interpolate_fn:\n        assert sample_rate is not None, \"Must provide sample rate if upsampling\"\n        root = interpolate_fn(root, window_size, hop_length, original_length=audio.size)\n\n    return root\n"
  },
  {
    "path": "neural_waveshaping_synthesis/data/utils/mfcc_extraction.py",
    "content": "import gin\nimport librosa\nimport numpy as np\n\n\n@gin.configurable\ndef extract_mfcc(\n    audio: np.ndarray, sample_rate: float, n_fft: int, hop_length: int, n_mfcc: int\n):\n    mfcc = librosa.feature.mfcc(\n        audio, sr=sample_rate, n_mfcc=n_mfcc, n_fft=n_fft, hop_length=hop_length\n    )\n    return mfcc"
  },
  {
    "path": "neural_waveshaping_synthesis/data/utils/preprocess_audio.py",
    "content": "from functools import partial\nfrom typing import Callable, Sequence, Union\n\nimport gin\nimport librosa\nimport numpy as np\nimport resampy\nimport scipy.io.wavfile as wavfile\n\nfrom .f0_extraction import extract_f0_with_crepe, extract_f0_with_pyin\nfrom .loudness_extraction import extract_perceptual_loudness, extract_rms\nfrom .mfcc_extraction import extract_mfcc\nfrom ...utils import apply, apply_unpack, unzip\n\n\ndef read_audio_files(files: list):\n    rates_and_audios = apply(wavfile.read, files)\n    return unzip(rates_and_audios)\n\n\ndef convert_to_float32_audio(audio: np.ndarray):\n    if audio.dtype == np.float32:\n        return audio\n\n    max_sample_value = np.iinfo(audio.dtype).max\n    floating_point_audio = audio / max_sample_value\n    return floating_point_audio.astype(np.float32)\n\n\ndef make_monophonic(audio: np.ndarray, strategy: str = \"keep_left\"):\n    # deal with non stereo array formats\n    if len(audio.shape) == 1:\n        return audio\n    elif len(audio.shape) != 2:\n        raise ValueError(\"Unknown audio array format.\")\n\n    # deal with single audio channel\n    if audio.shape[0] == 1:\n        return audio[0]\n    elif audio.shape[1] == 1:\n        return audio[:, 0]\n    # deal with more than two channels\n    elif audio.shape[0] != 2 and audio.shape[1] != 2:\n        raise ValueError(\"Expected stereo input audio but got too many channels.\")\n\n    # put channel first\n    if audio.shape[1] == 2:\n        audio = audio.T\n\n    # make stereo audio monophonic\n    if strategy == \"keep_left\":\n        return audio[0]\n    elif strategy == \"keep_right\":\n        return audio[1]\n    elif strategy == \"sum\":\n        return np.mean(audio, axis=0)\n    elif strategy == \"diff\":\n        return audio[0] - audio[1]\n\n\ndef normalise_signal(audio: np.ndarray, factor: float):\n    return audio / factor\n\n\ndef resample_audio(audio: np.ndarray, original_sr: float, target_sr: float):\n    return resampy.resample(audio, original_sr, target_sr)\n\n\ndef segment_signal(\n    signal: np.ndarray,\n    sample_rate: float,\n    segment_length_in_seconds: float,\n    hop_length_in_seconds: float,\n):\n    segment_length_in_samples = int(sample_rate * segment_length_in_seconds)\n    hop_length_in_samples = int(sample_rate * hop_length_in_seconds)\n    segments = librosa.util.frame(\n        signal, segment_length_in_samples, hop_length_in_samples\n    )\n    return segments\n\n\ndef filter_segments(\n    threshold: float,\n    key_segments: np.ndarray,\n    segments: Sequence[np.ndarray],\n):\n    mean_keys = key_segments.mean(axis=0)\n    mask = mean_keys > threshold\n    filtered_segments = apply(\n        lambda x: x[:, mask] if len(x.shape) == 2 else x[:, :, mask], segments\n    )\n    return filtered_segments\n\n\ndef preprocess_single_audio_file(\n    file: str,\n    control_decimation_factor: float,\n    target_sr: float = 16000.0,\n    segment_length_in_seconds: float = 4.0,\n    hop_length_in_seconds: float = 2.0,\n    confidence_threshold: float = 0.85,\n    f0_extractor: Callable = extract_f0_with_crepe,\n    loudness_extractor: Callable = extract_perceptual_loudness,\n    mfcc_extractor: Callable = extract_mfcc,\n    normalisation_factor: Union[float, None] = None,\n):\n    print(\"Loading audio file: %s...\" % file)\n    original_sr, audio = wavfile.read(file)\n    audio = convert_to_float32_audio(audio)\n    audio = make_monophonic(audio)\n\n    if normalisation_factor:\n        audio = normalise_signal(audio, 
normalisation_factor)\n\n    print(\"Resampling audio file: %s...\" % file)\n    audio = resample_audio(audio, original_sr, target_sr)\n\n    print(\"Extracting f0 with extractor '%s': %s...\" % (f0_extractor.__name__, file))\n    f0, confidence = f0_extractor(audio)\n\n    print(\n        \"Extracting loudness with extractor '%s': %s...\"\n        % (loudness_extractor.__name__, file)\n    )\n    loudness = loudness_extractor(audio)\n\n    print(\n        \"Extracting MFCC with extractor '%s': %s...\" % (mfcc_extractor.__name__, file)\n    )\n    mfcc = mfcc_extractor(audio)\n\n    print(\"Segmenting audio file: %s...\" % file)\n    segmented_audio = segment_signal(\n        audio, target_sr, segment_length_in_seconds, hop_length_in_seconds\n    )\n\n    print(\"Segmenting control signals: %s...\" % file)\n    segmented_f0 = segment_signal(\n        f0,\n        target_sr / (control_decimation_factor or 1),\n        segment_length_in_seconds,\n        hop_length_in_seconds,\n    )\n    segmented_confidence = segment_signal(\n        confidence,\n        target_sr / (control_decimation_factor or 1),\n        segment_length_in_seconds,\n        hop_length_in_seconds,\n    )\n    segmented_loudness = segment_signal(\n        loudness,\n        target_sr / (control_decimation_factor or 1),\n        segment_length_in_seconds,\n        hop_length_in_seconds,\n    )\n    segmented_mfcc = segment_signal(\n        mfcc,\n        target_sr / (control_decimation_factor or 1),\n        segment_length_in_seconds,\n        hop_length_in_seconds,\n    )\n\n    (\n        filtered_audio,\n        filtered_f0,\n        filtered_confidence,\n        filtered_loudness,\n        filtered_mfcc,\n    ) = filter_segments(\n        confidence_threshold,\n        segmented_confidence,\n        (\n            segmented_audio,\n            segmented_f0,\n            segmented_confidence,\n            segmented_loudness,\n            segmented_mfcc,\n        ),\n    )\n\n    if filtered_audio.shape[-1] == 0:\n        print(\"No segments exceeding confidence threshold...\")\n        audio_split, f0_split, confidence_split, loudness_split, mfcc_split = (\n            [],\n            [],\n            [],\n            [],\n            [],\n        )\n    else:\n        split = lambda x: [e.squeeze() for e in np.split(x, x.shape[-1], -1)]\n        audio_split = split(filtered_audio)\n        f0_split = split(filtered_f0)\n        confidence_split = split(filtered_confidence)\n        loudness_split = split(filtered_loudness)\n        mfcc_split = split(filtered_mfcc)\n\n    return audio_split, f0_split, confidence_split, loudness_split, mfcc_split\n\n\n@gin.configurable\ndef preprocess_audio(\n    files: list,\n    control_decimation_factor: float,\n    target_sr: float = 16000,\n    segment_length_in_seconds: float = 4.0,\n    hop_length_in_seconds: float = 2.0,\n    confidence_threshold: float = 0.85,\n    f0_extractor: Callable = extract_f0_with_crepe,\n    loudness_extractor: Callable = extract_perceptual_loudness,\n    normalise_audio: bool = False,\n):\n    if normalise_audio:\n        print(\"Finding normalisation factor...\")\n        normalisation_factor = 0\n        for file in files:\n            _, audio = wavfile.read(file)\n            audio = convert_to_float32_audio(audio)\n            audio = make_monophonic(audio)\n            max_value = np.abs(audio).max()\n            normalisation_factor = (\n                max_value if max_value > normalisation_factor else normalisation_factor\n            )\n\n  
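  # second pass: preprocess each file, sharing the normalisation factor found above\n  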
  processor = partial(\n        preprocess_single_audio_file,\n        control_decimation_factor=control_decimation_factor,\n        target_sr=target_sr,\n        segment_length_in_seconds=segment_length_in_seconds,\n        hop_length_in_seconds=hop_length_in_seconds,\n        confidence_threshold=confidence_threshold,\n        f0_extractor=f0_extractor,\n        loudness_extractor=loudness_extractor,\n        normalisation_factor=None if not normalise_audio else normalisation_factor,\n    )\n    for file in files:\n        yield processor(file)\n"
  },
  {
    "path": "neural_waveshaping_synthesis/data/utils/upsampling.py",
    "content": "from typing import Optional\n\nimport gin\nimport numpy as np\nimport scipy.interpolate\nimport scipy.signal.windows\n\n\ndef get_padded_length(frames: int, window_length: int, hop_length: int):\n    return frames * hop_length + window_length - hop_length\n\n\ndef get_source_target_axes(frames: int, window_length: int, hop_length: int):\n    padded_length = get_padded_length(frames, window_length, hop_length)\n    source_x = np.linspace(0, frames - 1, frames)\n    target_x = np.linspace(0, frames - 1, padded_length)\n    return source_x, target_x\n\n\n@gin.configurable\ndef linear_interpolation(\n    signal: np.ndarray,\n    window_length: int,\n    hop_length: int,\n    original_length: Optional[int] = None,\n):\n    source_x, target_x = get_source_target_axes(signal.size, window_length, hop_length)\n\n    interpolated = np.interp(target_x, source_x, signal)\n    if original_length:\n        interpolated = interpolated[window_length // 2 :]\n        interpolated = interpolated[:original_length]\n\n    return interpolated\n\n\n@gin.configurable\ndef cubic_spline_interpolation(\n    signal: np.ndarray,\n    window_length: int,\n    hop_length: int,\n    original_length: Optional[int] = None,\n):\n    source_x, target_x = get_source_target_axes(signal.size, window_length, hop_length)\n\n    interpolant = scipy.interpolate.interp1d(source_x, signal, kind=\"cubic\")\n    interpolated = interpolant(target_x)\n    if original_length:\n        interpolated = interpolated[window_length // 2 :]\n        interpolated = interpolated[:original_length]\n\n    return interpolated\n\n\n@gin.configurable\ndef overlap_add_upsample(\n    signal: np.ndarray,\n    window_length: int,\n    hop_length: int,\n    window_fn: str = \"hann\",\n    window_scale: int = 2,\n    original_length: Optional[int] = None,\n):\n    window = scipy.signal.windows.get_window(window_fn, hop_length * window_scale)\n    padded_length = get_padded_length(signal.size, window_length, hop_length)\n    padded_output = np.zeros(padded_length)\n\n    for i, value in enumerate(signal):\n        window_start = i * hop_length\n        window_end = window_start + hop_length * window_scale\n        padded_output[window_start:window_end] += window * value\n\n    if original_length:\n        output = padded_output[(padded_length - original_length) // 2:]\n        output = output[:original_length]\n    else:\n        output = padded_output\n\n    return output\n"
  },
  {
    "path": "neural_waveshaping_synthesis/models/__init__.py",
    "content": ""
  },
  {
    "path": "neural_waveshaping_synthesis/models/modules/__init__.py",
    "content": ""
  },
  {
    "path": "neural_waveshaping_synthesis/models/modules/dynamic.py",
    "content": "import gin\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass FiLM(nn.Module):\n    def forward(self, x, gamma, beta):\n        return gamma * x + beta\n\n\nclass TimeDistributedLayerNorm(nn.Module):\n    def __init__(self, size: int):\n        super().__init__()\n        self.layer_norm = nn.LayerNorm(size)\n\n    def forward(self, x):\n        return self.layer_norm(x.transpose(1, 2)).transpose(1, 2)\n\n\n@gin.configurable\nclass TimeDistributedMLP(nn.Module):\n    def __init__(self, in_size: int, hidden_size: int, out_size: int, depth: int = 3):\n        super().__init__()\n        assert depth >= 3, \"Depth must be at least 3\"\n        layers = []\n        for i in range(depth):\n            layers.append(\n                nn.Conv1d(\n                    in_size if i == 0 else hidden_size,\n                    hidden_size if i < depth - 1 else out_size,\n                    1,\n                )\n            )\n            if i < depth - 1:\n                layers.append(TimeDistributedLayerNorm(hidden_size))\n                layers.append(nn.LeakyReLU())\n        self.net = nn.Sequential(*layers)\n\n    def forward(self, x):\n        return self.net(x)\n"
  },
  {
    "path": "neural_waveshaping_synthesis/models/modules/generators.py",
    "content": "import math\nfrom typing import Callable\n\nimport gin\nimport torch\nimport torch.fft\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\n@gin.configurable\nclass FIRNoiseSynth(nn.Module):\n    def __init__(\n        self, ir_length: int, hop_length: int, window_fn: Callable = torch.hann_window\n    ):\n        super().__init__()\n        self.ir_length = ir_length\n        self.hop_length = hop_length\n        self.register_buffer(\"window\", window_fn(ir_length))\n\n    def forward(self, H_re):\n        H_im = torch.zeros_like(H_re)\n        H_z = torch.complex(H_re, H_im)\n\n        h = torch.fft.irfft(H_z.transpose(1, 2))\n        h = h.roll(self.ir_length // 2, -1)\n        h = h * self.window.view(1, 1, -1)\n        H = torch.fft.rfft(h)\n\n        noise = torch.rand(self.hop_length * H_re.shape[-1] - 1, device=H_re.device)\n        X = torch.stft(noise, self.ir_length, self.hop_length, return_complex=True)\n        X = X.unsqueeze(0)\n        Y = X * H.transpose(1, 2)\n        y = torch.istft(Y, self.ir_length, self.hop_length, center=False)\n        return y.unsqueeze(1)[:, :, : H_re.shape[-1] * self.hop_length]\n\n\n@gin.configurable\nclass HarmonicOscillator(nn.Module):\n    def __init__(self, n_harmonics, sample_rate):\n        super().__init__()\n        self.sample_rate = sample_rate\n        self.n_harmonics = n_harmonics\n        self.register_buffer(\"harmonic_axis\", self._create_harmonic_axis(n_harmonics))\n        self.register_buffer(\"rand_phase\", torch.ones(1, n_harmonics, 1) * math.tau)\n\n    def _create_harmonic_axis(self, n_harmonics):\n        return torch.arange(1, n_harmonics + 1).view(1, -1, 1)\n\n    def _create_antialias_mask(self, f0):\n        freqs = f0.unsqueeze(1) * self.harmonic_axis\n        return freqs < (self.sample_rate / 2)\n\n    def _create_phase_shift(self, n_harmonics):\n        shift = torch.rand_like(self.rand_phase) * self.rand_phase - math.pi\n        return shift\n\n    def forward(self, f0):\n        phase = math.tau * f0.cumsum(-1) / self.sample_rate\n        harmonic_phase = self.harmonic_axis * phase.unsqueeze(1)\n        harmonic_phase = harmonic_phase + self._create_phase_shift(self.n_harmonics)\n        antialias_mask = self._create_antialias_mask(f0)\n\n        output = torch.sin(harmonic_phase) * antialias_mask\n\n        return output\n"
  },
  {
    "path": "neural_waveshaping_synthesis/models/modules/shaping.py",
    "content": "import gin\nimport torch\nimport torch.fft\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .dynamic import FiLM, TimeDistributedMLP\n\n\nclass Sine(nn.Module):\n    def forward(self, x: torch.Tensor):\n        return torch.sin(x)\n\n\n@gin.configurable\nclass TrainableNonlinearity(nn.Module):\n    def __init__(\n        self, channels, width, nonlinearity=nn.ReLU, final_nonlinearity=Sine, depth=3\n    ):\n        super().__init__()\n        self.input_scale = nn.Parameter(torch.randn(1, channels, 1) * 10)\n        layers = []\n        for i in range(depth):\n            layers.append(\n                nn.Conv1d(\n                    channels if i == 0 else channels * width,\n                    channels * width if i < depth - 1 else channels,\n                    1,\n                    groups=channels,\n                )\n            )\n            layers.append(nonlinearity() if i < depth - 1 else final_nonlinearity())\n\n        self.net = nn.Sequential(*layers)\n\n    def forward(self, x):\n        return self.net(self.input_scale * x)\n\n\n@gin.configurable\nclass NEWT(nn.Module):\n    def __init__(\n        self,\n        n_waveshapers: int,\n        control_embedding_size: int,\n        shaping_fn_size: int = 16,\n        out_channels: int = 1,\n    ):\n        super().__init__()\n\n        self.n_waveshapers = n_waveshapers\n\n        self.mlp = TimeDistributedMLP(\n            control_embedding_size, control_embedding_size, n_waveshapers * 4, depth=4\n        )\n\n        self.waveshaping_index = FiLM()\n        self.shaping_fn = TrainableNonlinearity(\n            n_waveshapers, shaping_fn_size, nonlinearity=Sine\n        )\n        self.normalising_coeff = FiLM()\n\n        self.mixer = nn.Sequential(\n            nn.Conv1d(n_waveshapers, out_channels, 1),\n        )\n\n    def forward(self, exciter, control_embedding):\n        film_params = self.mlp(control_embedding)\n        film_params = F.upsample(film_params, exciter.shape[-1], mode=\"linear\")\n        gamma_index, beta_index, gamma_norm, beta_norm = torch.split(\n            film_params, self.n_waveshapers, 1\n        )\n\n        x = self.waveshaping_index(exciter, gamma_index, beta_index)\n        x = self.shaping_fn(x)\n        x = self.normalising_coeff(x, gamma_norm, beta_norm)\n\n        # return x\n        return self.mixer(x)\n\n\nclass FastNEWT(NEWT):\n    def __init__(\n        self,\n        newt: NEWT,\n        table_size: int = 4096,\n        table_min: float = -3.0,\n        table_max: float = 3.0,\n    ):\n        super().__init__()\n        self.table_size = table_size\n        self.table_min = table_min\n        self.table_max = table_max\n\n        self.n_waveshapers = newt.n_waveshapers\n        self.mlp = newt.mlp\n\n        self.waveshaping_index = newt.waveshaping_index\n        self.normalising_coeff = newt.normalising_coeff\n        self.mixer = newt.mixer\n\n        self.lookup_table = self._init_lookup_table(\n            newt, table_size, self.n_waveshapers, table_min, table_max\n        )\n        self.to(next(iter(newt.parameters())).device)\n\n    def _init_lookup_table(\n        self,\n        newt: NEWT,\n        table_size: int,\n        n_waveshapers: int,\n        table_min: float,\n        table_max: float,\n    ):\n        sample_values = torch.linspace(table_min, table_max, table_size, device=next(iter(newt.parameters())).device).expand(\n            1, n_waveshapers, table_size\n        )\n        lookup_table = newt.shaping_fn(sample_values)[0]\n    
    return nn.Parameter(lookup_table)\n\n    def _lookup(self, idx):\n        return torch.stack(\n            [\n                torch.stack(\n                    [\n                        self.lookup_table[shaper, idx[batch, shaper]]\n                        for shaper in range(idx.shape[1])\n                    ],\n                    dim=0,\n                )\n                for batch in range(idx.shape[0])\n            ],\n            dim=0,\n        )\n\n    def shaping_fn(self, x):\n        idx = self.table_size * (x - self.table_min) / (self.table_max - self.table_min)\n\n        lower = torch.floor(idx).long()\n        lower[lower < 0] = 0\n        lower[lower >= self.table_size] = self.table_size - 1\n\n        upper = lower + 1\n        upper[upper >= self.table_size] = self.table_size - 1\n\n        fract = idx - lower\n        lower_v = self._lookup(lower)\n        upper_v = self._lookup(upper)\n\n        output = (upper_v - lower_v) * fract + lower_v\n        return output\n\n\n@gin.configurable\nclass Reverb(nn.Module):\n    def __init__(self, length_in_seconds, sr):\n        super().__init__()\n        self.ir = nn.Parameter(torch.randn(1, sr * length_in_seconds - 1) * 1e-6)\n        self.register_buffer(\"initial_zero\", torch.zeros(1, 1))\n\n    def forward(self, x):\n        ir_ = torch.cat((self.initial_zero, self.ir), dim=-1)\n        if x.shape[-1] > ir_.shape[-1]:\n            ir_ = F.pad(ir_, (0, x.shape[-1] - ir_.shape[-1]))\n            x_ = x\n        else:\n            x_ = F.pad(x, (0, ir_.shape[-1] - x.shape[-1]))\n        return (\n            x\n            + torch.fft.irfft(torch.fft.rfft(x_) * torch.fft.rfft(ir_))[\n                ..., : x.shape[-1]\n            ]\n        )\n"
  },
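  {
    "path": "examples/fastnewt_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original package: checking that a\nFastNEWT lookup table tracks the NEWT it was built from.\n\nThe gin bindings and tensor sizes below are stand-in values chosen to make the\ndemo self-contained; real models take these from the repository's configs.\n\"\"\"\nimport gin\nimport torch\n\nfrom neural_waveshaping_synthesis.models.modules import shaping\n\n# Bind the two required NEWT constructor arguments so NEWT() can be called\n# with no arguments, as FastNEWT's constructor does internally.\ngin.parse_config(\n    \"\"\"\nNEWT.n_waveshapers = 4\nNEWT.control_embedding_size = 8\n\"\"\"\n)\n\nnewt = shaping.NEWT().eval()\n# FastNEWT reuses the trained FiLM and mixer modules and swaps the learned\n# shaping functions for a sampled lookup table, so it is a drop-in replacement.\nfast = shaping.FastNEWT(newt).eval()\n\nexciter = torch.rand(1, 4, 512) * 2 - 1\ncontrol_embedding = torch.rand(1, 8, 4)\nwith torch.no_grad():\n    reference = newt(exciter, control_embedding)\n    approximation = fast(exciter, control_embedding)\n\n# The outputs should agree up to the table's linear interpolation error.\nprint((reference - approximation).abs().max())\n"
  },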
  {
    "path": "neural_waveshaping_synthesis/models/neural_waveshaping.py",
    "content": "import auraloss\nimport gin\nimport pytorch_lightning as pl\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport wandb\n\nfrom .modules.dynamic import TimeDistributedMLP\nfrom .modules.generators import FIRNoiseSynth, HarmonicOscillator\nfrom .modules.shaping import NEWT, Reverb\n\ngin.external_configurable(nn.GRU, module=\"torch.nn\")\ngin.external_configurable(nn.Conv1d, module=\"torch.nn\")\n\n\n@gin.configurable\nclass ControlModule(nn.Module):\n    def __init__(self, control_size: int, hidden_size: int, embedding_size: int):\n        super().__init__()\n        self.gru = nn.GRU(control_size, hidden_size, batch_first=True)\n        self.proj = nn.Conv1d(hidden_size, embedding_size, 1)\n\n    def forward(self, x):\n        x, _ = self.gru(x.transpose(1, 2))\n        return self.proj(x.transpose(1, 2))\n\n\n@gin.configurable\nclass NeuralWaveshaping(pl.LightningModule):\n    def __init__(\n        self,\n        n_waveshapers: int,\n        control_hop: int,\n        sample_rate: float = 16000,\n        learning_rate: float = 1e-3,\n        lr_decay: float = 0.9,\n        lr_decay_interval: int = 10000,\n        log_audio: bool = False,\n    ):\n        super().__init__()\n        self.save_hyperparameters()\n        self.learning_rate = learning_rate\n        self.lr_decay = lr_decay\n        self.lr_decay_interval = lr_decay_interval\n        self.control_hop = control_hop\n        self.log_audio = log_audio\n\n        self.sample_rate = sample_rate\n\n        self.embedding = ControlModule()\n\n        self.osc = HarmonicOscillator()\n        self.harmonic_mixer = nn.Conv1d(self.osc.n_harmonics, n_waveshapers, 1)\n\n        self.newt = NEWT()\n\n        with gin.config_scope(\"noise_synth\"):\n            self.h_generator = TimeDistributedMLP()\n            self.noise_synth = FIRNoiseSynth()\n\n        self.reverb = Reverb()\n\n    def render_exciter(self, f0):\n        sig = self.osc(f0[:, 0])\n        sig = self.harmonic_mixer(sig)\n        return sig\n\n    def get_embedding(self, control):\n        f0, other = control[:, 0:1], control[:, 1:2]\n        control = torch.cat((f0, other), dim=1)\n        return self.embedding(control)\n\n    def forward(self, f0, control):\n        f0_upsampled = F.upsample(f0, f0.shape[-1] * self.control_hop, mode=\"linear\")\n        x = self.render_exciter(f0_upsampled)\n\n        control_embedding = self.get_embedding(control)\n\n        x = self.newt(x, control_embedding)\n\n        H = self.h_generator(control_embedding)\n        noise = self.noise_synth(H)\n\n        x = torch.cat((x, noise), dim=1)\n        x = x.sum(1)\n\n        x = self.reverb(x)\n\n        return x\n\n    def configure_optimizers(self):\n        self.stft_loss = auraloss.freq.MultiResolutionSTFTLoss()\n\n        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n        scheduler = torch.optim.lr_scheduler.StepLR(\n            optimizer, self.lr_decay_interval, self.lr_decay\n        )\n        return {\n            \"optimizer\": optimizer,\n            \"lr_scheduler\": {\"scheduler\": scheduler, \"interval\": \"step\"},\n        }\n\n    def _run_step(self, batch):\n        audio = batch[\"audio\"].float()\n        f0 = batch[\"f0\"].float()\n        control = batch[\"control\"].float()\n\n        recon = self(f0, control)\n\n        loss = self.stft_loss(recon, audio)\n        return loss, recon, audio\n\n    def _log_audio(self, name, audio):\n        wandb.log(\n            {\n                
\"audio/%s\"\n                % name: wandb.Audio(audio, sample_rate=self.sample_rate, caption=name)\n            },\n            commit=False,\n        )\n\n    def training_step(self, batch, batch_idx):\n        loss, _, _ = self._run_step(batch)\n        self.log(\n            \"train/loss\",\n            loss.item(),\n            on_step=False,\n            on_epoch=True,\n            prog_bar=True,\n            logger=True,\n            sync_dist=True,\n        )\n        return loss\n\n    def validation_step(self, batch, batch_idx):\n        loss, recon, audio = self._run_step(batch)\n        self.log(\n            \"val/loss\",\n            loss.item(),\n            on_step=False,\n            on_epoch=True,\n            prog_bar=True,\n            logger=True,\n            sync_dist=True,\n        )\n        if batch_idx == 0 and self.log_audio:\n            self._log_audio(\"original\", audio[0].detach().cpu().squeeze())\n            self._log_audio(\"recon\", recon[0].detach().cpu().squeeze())\n        return loss\n\n    def test_step(self, batch, batch_idx):\n        loss, recon, audio = self._run_step(batch)\n        self.log(\n            \"test/loss\",\n            loss.item(),\n            on_step=False,\n            on_epoch=True,\n            prog_bar=True,\n            logger=True,\n            sync_dist=True,\n        )\n        if batch_idx == 0:\n            self._log_audio(\"original\", audio[0].detach().cpu().squeeze())\n            self._log_audio(\"recon\", recon[0].detach().cpu().squeeze())\n"
  },
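  {
    "path": "examples/control_module_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original package: the shape contract\nof ControlModule, which maps channels-first control signals to an embedding\nsequence with a GRU followed by a 1x1 convolution.\n\nThe sizes below are stand-ins; passing explicit arguments bypasses gin.\n\"\"\"\nimport torch\n\nfrom neural_waveshaping_synthesis.models.neural_waveshaping import ControlModule\n\nmodule = ControlModule(control_size=2, hidden_size=16, embedding_size=8)\n\n# (batch, channels, time): two control channels, e.g. f0 and a loudness-like\n# signal, as used elsewhere in this repository.\ncontrol = torch.rand(3, 2, 250)\nembedding = module(control)\nprint(embedding.shape)  # torch.Size([3, 8, 250])\n"
  },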
  {
    "path": "neural_waveshaping_synthesis/utils/__init__.py",
    "content": "from .utils import *\nfrom .seed_all import *"
  },
  {
    "path": "neural_waveshaping_synthesis/utils/seed_all.py",
    "content": "import numpy as np\nimport os\nimport random\nimport torch\n\ndef seed_all(seed):\n    np.random.seed(seed)\n    os.environ['PYTHONHASHSEED'] = str(seed)\n    random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed(seed)\n    torch.backends.cudnn.deterministic = True"
  },
  {
    "path": "neural_waveshaping_synthesis/utils/utils.py",
    "content": "import os\nfrom typing import Callable, Sequence\n\n\ndef apply(fn: Callable[[any], any], x: Sequence[any]):\n    if type(x) not in (tuple, list):\n        raise TypeError(\"x must be a tuple or list.\")\n    return type(x)([fn(element) for element in x])\n\n\ndef apply_unpack(fn: Callable[[any], any], x: Sequence[Sequence[any]]):\n    if type(x) not in (tuple, list):\n        raise TypeError(\"x must be a tuple or list.\")\n    return type(x)([fn(*element) for element in x])\n\n\ndef unzip(x: Sequence[any]):\n    return list(zip(*x))\n\n\ndef make_dir_if_not_exists(path):\n    if not os.path.exists(path):\n        os.makedirs(path, exist_ok=True)\n"
  },
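  {
    "path": "examples/utils_sketch.py",
    "content": "\"\"\"Illustrative sketch, not part of the original package: the small functional\nhelpers in neural_waveshaping_synthesis.utils.utils preserve the container\ntype of their input.\n\"\"\"\nfrom neural_waveshaping_synthesis.utils.utils import apply, apply_unpack, unzip\n\n# apply maps a function over a tuple or list and returns the same type.\nprint(apply(lambda v: v * 2, (1, 2, 3)))  # (2, 4, 6)\nprint(apply(str.upper, [\"a\", \"b\"]))  # ['A', 'B']\n\n# apply_unpack splats each inner sequence into the function.\nprint(apply_unpack(lambda a, b: a + b, [(1, 2), (3, 4)]))  # [3, 7]\n\n# unzip transposes a sequence of pairs into a list of grouped values.\nprint(unzip([(1, \"a\"), (2, \"b\")]))  # [(1, 2), ('a', 'b')]\n"
  },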
  {
    "path": "requirements.txt",
    "content": "auraloss==0.2.1\nblack==20.8b1\nclick==7.1.2\ngin-config==0.4.0\nlibrosa==0.8.0\nnumpy==1.20.1\npytorch_lightning==1.1.2\nresampy==0.2.2\nscipy==1.6.1\ntorch==1.7.1\ntorchcrepe==0.0.12\n"
  },
  {
    "path": "scripts/create_dataset.py",
    "content": "import os\n\nimport click\nimport gin\n\nfrom neural_waveshaping_synthesis.data.utils.create_dataset import create_dataset\nfrom neural_waveshaping_synthesis.utils import seed_all\n\n\ndef get_filenames(directory):\n    return [os.path.join(directory, f) for f in os.listdir(directory) if \".wav\" in f]\n\n\n@click.command()\n@click.option(\"--gin-file\", prompt=\"Gin config file\")\n@click.option(\"--data-directory\", prompt=\"Data directory\")\n@click.option(\"--output-directory\", prompt=\"Output directory\")\n@click.option(\"--seed\", default=0)\n@click.option(\"--device\", default=\"cpu\")\ndef main(gin_file, data_directory, output_directory, seed=0, device=\"cpu\"):\n    gin.constant(\"device\", device)\n    gin.parse_config_file(gin_file)\n\n    seed_all(seed)\n\n    files = get_filenames(data_directory)\n    create_dataset(files, output_directory)\n\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "scripts/create_urmp_dataset.py",
    "content": "import os\nfrom pathlib import Path\n\nimport click\nimport gin\n\nfrom neural_waveshaping_synthesis.data.utils.create_dataset import create_dataset\nfrom neural_waveshaping_synthesis.utils import seed_all\n\nINSTRUMENTS = (\n    \"vn\",\n    \"vc\",\n    \"fl\",\n    \"cl\",\n    \"tpt\",\n    \"sax\",\n    \"tbn\",\n    \"ob\",\n    \"va\",\n    \"bn\",\n    \"hn\",\n    \"db\",\n)\n\n\ndef get_instrument_file_list(instrument_string, directory):\n    return [\n        str(f)\n        for f in Path(directory).glob(\n            \"**/*_%s_*/AuSep*_%s_*.wav\" % (instrument_string, instrument_string)\n        )\n    ]\n\n\n@click.command()\n@click.option(\"--gin-file\", prompt=\"Gin config file\")\n@click.option(\"--data-directory\", prompt=\"Data directory\")\n@click.option(\"--output-directory\", prompt=\"Output directory\")\n@click.option(\"--seed\", default=0)\n@click.option(\"--device\", default=\"cpu\")\ndef main(gin_file, data_directory, output_directory, seed=0, device=\"cpu\"):\n    gin.constant(\"device\", device)\n    gin.parse_config_file(gin_file)\n\n    seed_all(seed)\n\n    file_lists = {\n        instrument: get_instrument_file_list(instrument, data_directory)\n        for instrument in INSTRUMENTS\n    }\n    for instrument in file_lists:\n        create_dataset(\n            file_lists[instrument], os.path.join(output_directory, instrument)\n        )\n\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "scripts/resynthesise_dataset.py",
    "content": "import os\n\nimport click\nimport gin\nfrom scipy.io import wavfile\nfrom tqdm import tqdm\nimport torch\n\nfrom neural_waveshaping_synthesis.data.urmp import URMPDataset\nfrom neural_waveshaping_synthesis.models.modules.shaping import FastNEWT\nfrom neural_waveshaping_synthesis.models.neural_waveshaping import NeuralWaveshaping\nfrom neural_waveshaping_synthesis.utils import make_dir_if_not_exists\n\n\n@click.command()\n@click.option(\"--model-gin\", prompt=\"Model .gin file\")\n@click.option(\"--model-checkpoint\", prompt=\"Model checkpoint\")\n@click.option(\"--dataset-root\", prompt=\"Dataset root directory\")\n@click.option(\"--dataset-split\", default=\"test\")\n@click.option(\"--output-path\", default=\"audio_output\")\n@click.option(\"--load-data-to-memory\", default=False)\n@click.option(\"--device\", default=\"cuda:0\")\n@click.option(\"--batch-size\", default=8)\n@click.option(\"--num_workers\", default=16)\n@click.option(\"--use-fastnewt\", is_flag=True)\ndef main(\n    model_gin,\n    model_checkpoint,\n    dataset_root,\n    dataset_split,\n    output_path,\n    load_data_to_memory,\n    device,\n    batch_size,\n    num_workers,\n    use_fastnewt\n):\n    gin.parse_config_file(model_gin)\n    make_dir_if_not_exists(output_path)\n\n    data = URMPDataset(dataset_root, dataset_split, load_data_to_memory)\n    data_loader = torch.utils.data.DataLoader(\n        data, batch_size=batch_size, num_workers=num_workers\n    )\n\n    device = torch.device(device)\n    model = NeuralWaveshaping.load_from_checkpoint(model_checkpoint)\n    model.eval()\n\n    if use_fastnewt:\n        model.newt = FastNEWT(model.newt)\n    \n    model = model.to(device)\n\n    for i, batch in enumerate(tqdm(data_loader)):\n        with torch.no_grad():\n            f0 = batch[\"f0\"].float().to(device)\n            control = batch[\"control\"].float().to(device)\n            output = model(f0, control)\n\n        target_audio = batch[\"audio\"].float().numpy()\n        output_audio = output.cpu().numpy()\n        for j in range(output_audio.shape[0]):\n            name = batch[\"name\"][j]\n            target_name = \"%s.target.wav\" % name\n            output_name = \"%s.output.wav\" % name\n            wavfile.write(\n                os.path.join(output_path, target_name),\n                model.sample_rate,\n                target_audio[j],\n            )\n            wavfile.write(\n                os.path.join(output_path, output_name),\n                model.sample_rate,\n                output_audio[j],\n            )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "scripts/time_buffer_sizes.py",
    "content": "import time\n\nimport click\nimport gin\nimport numpy as np\nimport pandas as pd\nimport torch\nfrom tqdm import trange\n\nfrom neural_waveshaping_synthesis.models.neural_waveshaping import NeuralWaveshaping\nfrom neural_waveshaping_synthesis.models.modules.shaping import FastNEWT\n\nBUFFER_SIZES = [256, 512, 1024, 2048, 4096, 8192, 16384, 32768]\n\n@click.command()\n@click.option(\"--gin-file\", prompt=\"Model config gin file\")\n@click.option(\"--output-file\", prompt=\"output file\")\n@click.option(\"--num-iters\", default=100)\n@click.option(\"--batch-size\", default=1)\n@click.option(\"--device\", default=\"cpu\")\n@click.option(\"--length-in-seconds\", default=4)\n@click.option(\"--use-fast-newt\", is_flag=True)\n@click.option(\"--model-name\", default=\"ours\")\ndef main(\n    gin_file,\n    output_file,\n    num_iters,\n    batch_size,\n    device,\n    length_in_seconds,\n    use_fast_newt,\n    model_name,\n):\n    gin.parse_config_file(gin_file)\n    model = NeuralWaveshaping()\n    if use_fast_newt:\n        model.newt = FastNEWT(model.newt)\n    model.eval()\n    model = model.to(device)\n\n    # eliminate any lazy init costs\n    with torch.no_grad():\n        for i in range(10):\n            model(\n                torch.rand(4, 1, 250, device=device),\n                torch.rand(4, 2, 250, device=device),\n            )\n\n    times = []\n    with torch.no_grad():\n        for bs in BUFFER_SIZES:\n            dummy_control = torch.rand(\n                batch_size,\n                2,\n                bs // 128,\n                device=device,\n                requires_grad=False,\n            )\n            dummy_f0 = torch.rand(\n                batch_size,\n                1,\n                bs // 128,\n                device=device,\n                requires_grad=False,\n            )\n            for i in trange(num_iters):\n                start_time = time.time()\n                model(dummy_f0, dummy_control)\n                time_elapsed = time.time() - start_time\n                times.append(\n                    [model_name, device if device == \"cpu\" else \"gpu\", bs, time_elapsed]\n                )\n\n    df = pd.DataFrame(times)\n    df.to_csv(output_file)\n\n\nif __name__ == \"__main__\":\n    main()"
  },
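  {
    "path": "examples/summarise_buffer_timings.py",
    "content": "\"\"\"Illustrative sketch, not part of the original package: summarising the CSV\nwritten by scripts/time_buffer_sizes.py into mean seconds per buffer size.\n\nAssumes the column names used in that script; the input path is a placeholder.\n\"\"\"\nimport pandas as pd\n\ndf = pd.read_csv(\"buffer_timings.csv\", index_col=0)\nsummary = df.groupby([\"model\", \"device\", \"buffer_size\"])[\"seconds\"].mean()\nprint(summary)\n"
  },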
  {
    "path": "scripts/time_forward_pass.py",
    "content": "import time\n\nimport click\nimport gin\nimport numpy as np\nfrom scipy.stats import describe\nimport torch\nfrom tqdm import trange\n\nfrom neural_waveshaping_synthesis.models.neural_waveshaping import NeuralWaveshaping\nfrom neural_waveshaping_synthesis.models.modules.shaping import FastNEWT\n\n\n@click.command()\n@click.option(\"--gin-file\", prompt=\"Model config gin file\")\n@click.option(\"--num-iters\", default=100)\n@click.option(\"--batch-size\", default=1)\n@click.option(\"--device\", default=\"cpu\")\n@click.option(\"--length-in-seconds\", default=4)\n@click.option(\"--sample-rate\", default=16000)\n@click.option(\"--control-hop\", default=128)\n@click.option(\"--use-fast-newt\", is_flag=True)\ndef main(\n    gin_file, num_iters, batch_size, device, length_in_seconds, sample_rate, control_hop, use_fast_newt\n):\n    gin.parse_config_file(gin_file)\n    dummy_control = torch.rand(\n        batch_size,\n        2,\n        sample_rate * length_in_seconds // control_hop,\n        device=device,\n        requires_grad=False,\n    )\n    dummy_f0 = torch.rand(\n        batch_size,\n        1,\n        sample_rate * length_in_seconds // control_hop,\n        device=device,\n        requires_grad=False,\n    )\n    model = NeuralWaveshaping()\n    if use_fast_newt:\n        model.newt = FastNEWT(model.newt)\n    model.eval()\n    model = model.to(device)\n\n    times = []\n    with torch.no_grad():\n        for i in trange(num_iters):\n            start_time = time.time()\n            model(dummy_f0, dummy_control)\n            time_elapsed = time.time() - start_time\n            times.append(time_elapsed)\n\n    print(describe(times))\n    rtfs = np.array(times) / length_in_seconds\n    print(\"Mean RTF: %.4f\" % np.mean(rtfs))\n    print(\"90th percentile RTF: %.4f\" % np.percentile(rtfs, 90))\n\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "scripts/train.py",
    "content": "import click\nimport gin\nimport pytorch_lightning as pl\n\nfrom neural_waveshaping_synthesis.data.general import GeneralDataModule\nfrom neural_waveshaping_synthesis.data.urmp import URMPDataModule\nfrom neural_waveshaping_synthesis.models.neural_waveshaping import NeuralWaveshaping\n\n\n@gin.configurable\ndef get_model(model, with_wandb):\n    return model(log_audio=with_wandb)\n\n\n@gin.configurable\ndef trainer_kwargs(**kwargs):\n    return kwargs\n\n\n@click.command()\n@click.option(\"--gin-file\", prompt=\"Gin config file\")\n@click.option(\"--dataset-path\", prompt=\"Dataset root\")\n@click.option(\"--urmp\", is_flag=True)\n@click.option(\"--device\", default=\"0\")\n@click.option(\"--instrument\", default=\"vn\")\n@click.option(\"--load-data-to-memory\", is_flag=True)\n@click.option(\"--with-wandb\", is_flag=True)\n@click.option(\"--restore-checkpoint\", default=\"\")\ndef main(\n    gin_file,\n    dataset_path,\n    urmp,\n    device,\n    instrument,\n    load_data_to_memory,\n    with_wandb,\n    restore_checkpoint,\n):\n    gin.parse_config_file(gin_file)\n    model = get_model(with_wandb=with_wandb)\n\n    if urmp:\n        data = URMPDataModule(\n            dataset_path,\n            instrument,\n            load_to_memory=load_data_to_memory,\n            num_workers=16,\n            shuffle=True,\n        )\n    else:\n        data = GeneralDataModule(\n            dataset_path,\n            load_to_memory=load_data_to_memory,\n            num_workers=16,\n            shuffle=True,\n        )\n\n    checkpointing = pl.callbacks.ModelCheckpoint(\n        monitor=\"val/loss\", save_top_k=1, save_last=True\n    )\n    callbacks = [checkpointing]\n    if with_wandb:\n        lr_logger = pl.callbacks.LearningRateMonitor(logging_interval=\"epoch\")\n        callbacks.append(lr_logger)\n        logger = pl.loggers.WandbLogger(project=\"neural-waveshaping-synthesis\")\n        logger.watch(model, log=\"parameters\")\n\n\n    kwargs = trainer_kwargs()\n    trainer = pl.Trainer(\n        logger=logger if with_wandb else None,\n        callbacks=callbacks,\n        gpus=device,\n        resume_from_checkpoint=restore_checkpoint if restore_checkpoint != \"\" else None,\n        **kwargs\n    )\n    trainer.fit(model, data)\n\n\nif __name__ == \"__main__\":\n    main()"
  },
  {
    "path": "setup.py",
    "content": "from setuptools import setup, find_packages\n\nsetup(name=\"neural_waveshaping_synthesis\", version=\"0.0.1\", packages=find_packages())\n"
  }
]