Repository: ben-hayes/neural-waveshaping-synthesis
Branch: main
Commit: cb7aaea0412c
Files: 46
Total size: 71.9 KB

Directory structure:
neural-waveshaping-synthesis/

├── .gitignore
├── LICENSE
├── README.md
├── checkpoints/
│   └── nws/
│       ├── fl/
│       │   ├── data_mean.npy
│       │   ├── data_std.npy
│       │   ├── epoch=4992-step=119831.ckpt
│       │   └── last.ckpt
│       ├── tpt/
│       │   ├── data_mean.npy
│       │   ├── data_std.npy
│       │   ├── epoch=358-step=24052.ckpt
│       │   └── last.ckpt
│       └── vn/
│           ├── data_mean.npy
│           ├── data_std.npy
│           ├── epoch=526-step=102237.ckpt
│           └── last.ckpt
├── gin/
│   ├── data/
│   │   └── urmp_4second_crepe.gin
│   ├── models/
│   │   └── newt.gin
│   └── train/
│       └── train_newt.gin
├── neural_waveshaping_synthesis/
│   ├── __init__.py
│   ├── data/
│   │   ├── __init__.py
│   │   ├── general.py
│   │   ├── urmp.py
│   │   └── utils/
│   │       ├── __init__.py
│   │       ├── create_dataset.py
│   │       ├── f0_extraction.py
│   │       ├── loudness_extraction.py
│   │       ├── mfcc_extraction.py
│   │       ├── preprocess_audio.py
│   │       └── upsampling.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── modules/
│   │   │   ├── __init__.py
│   │   │   ├── dynamic.py
│   │   │   ├── generators.py
│   │   │   └── shaping.py
│   │   └── neural_waveshaping.py
│   └── utils/
│       ├── __init__.py
│       ├── seed_all.py
│       └── utils.py
├── requirements.txt
├── scripts/
│   ├── create_dataset.py
│   ├── create_urmp_dataset.py
│   ├── resynthesise_dataset.py
│   ├── time_buffer_sizes.py
│   ├── time_forward_pass.py
│   └── train.py
└── setup.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================

# Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks,vscode
# Edit at https://www.toptal.com/developers/gitignore?templates=python,jupyternotebooks,vscode

### JupyterNotebooks ###
# gitignore template for Jupyter Notebooks
# website: http://jupyter.org/

.ipynb_checkpoints
*/.ipynb_checkpoints/*

# IPython
profile_default/
ipython_config.py

# Remove previous ipynb_checkpoints
#   git rm -r .ipynb_checkpoints/

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
pytestdebug.log

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/
doc/_build/

# PyBuilder
target/

# Jupyter Notebook

# IPython

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
pythonenv*

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# profiling data
.prof

### vscode ###
.vscode/*

# End of https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks,vscode

.DS_Store
test_audio

wandb
wandb/*
lightning_logs
lightning_logs/*

*.ipynb
*.wav
neural-timbre-shaping
*.webm

================================================
FILE: LICENSE
================================================
Mozilla Public License Version 2.0
==================================

1. Definitions
--------------

1.1. "Contributor"
    means each individual or legal entity that creates, contributes to
    the creation of, or owns Covered Software.

1.2. "Contributor Version"
    means the combination of the Contributions of others (if any) used
    by a Contributor and that particular Contributor's Contribution.

1.3. "Contribution"
    means Covered Software of a particular Contributor.

1.4. "Covered Software"
    means Source Code Form to which the initial Contributor has attached
    the notice in Exhibit A, the Executable Form of such Source Code
    Form, and Modifications of such Source Code Form, in each case
    including portions thereof.

1.5. "Incompatible With Secondary Licenses"
    means

    (a) that the initial Contributor has attached the notice described
        in Exhibit B to the Covered Software; or

    (b) that the Covered Software was made available under the terms of
        version 1.1 or earlier of the License, but not also under the
        terms of a Secondary License.

1.6. "Executable Form"
    means any form of the work other than Source Code Form.

1.7. "Larger Work"
    means a work that combines Covered Software with other material, in
    a separate file or files, that is not Covered Software.

1.8. "License"
    means this document.

1.9. "Licensable"
    means having the right to grant, to the maximum extent possible,
    whether at the time of the initial grant or subsequently, any and
    all of the rights conveyed by this License.

1.10. "Modifications"
    means any of the following:

    (a) any file in Source Code Form that results from an addition to,
        deletion from, or modification of the contents of Covered
        Software; or

    (b) any new file in Source Code Form that contains any Covered
        Software.

1.11. "Patent Claims" of a Contributor
    means any patent claim(s), including without limitation, method,
    process, and apparatus claims, in any patent Licensable by such
    Contributor that would be infringed, but for the grant of the
    License, by the making, using, selling, offering for sale, having
    made, import, or transfer of either its Contributions or its
    Contributor Version.

1.12. "Secondary License"
    means either the GNU General Public License, Version 2.0, the GNU
    Lesser General Public License, Version 2.1, the GNU Affero General
    Public License, Version 3.0, or any later versions of those
    licenses.

1.13. "Source Code Form"
    means the form of the work preferred for making modifications.

1.14. "You" (or "Your")
    means an individual or a legal entity exercising rights under this
    License. For legal entities, "You" includes any entity that
    controls, is controlled by, or is under common control with You. For
    purposes of this definition, "control" means (a) the power, direct
    or indirect, to cause the direction or management of such entity,
    whether by contract or otherwise, or (b) ownership of more than
    fifty percent (50%) of the outstanding shares or beneficial
    ownership of such entity.

2. License Grants and Conditions
--------------------------------

2.1. Grants

Each Contributor hereby grants You a world-wide, royalty-free,
non-exclusive license:

(a) under intellectual property rights (other than patent or trademark)
    Licensable by such Contributor to use, reproduce, make available,
    modify, display, perform, distribute, and otherwise exploit its
    Contributions, either on an unmodified basis, with Modifications, or
    as part of a Larger Work; and

(b) under Patent Claims of such Contributor to make, use, sell, offer
    for sale, have made, import, and otherwise transfer either its
    Contributions or its Contributor Version.

2.2. Effective Date

The licenses granted in Section 2.1 with respect to any Contribution
become effective for each Contribution on the date the Contributor first
distributes such Contribution.

2.3. Limitations on Grant Scope

The licenses granted in this Section 2 are the only rights granted under
this License. No additional rights or licenses will be implied from the
distribution or licensing of Covered Software under this License.
Notwithstanding Section 2.1(b) above, no patent license is granted by a
Contributor:

(a) for any code that a Contributor has removed from Covered Software;
    or

(b) for infringements caused by: (i) Your and any other third party's
    modifications of Covered Software, or (ii) the combination of its
    Contributions with other software (except as part of its Contributor
    Version); or

(c) under Patent Claims infringed by Covered Software in the absence of
    its Contributions.

This License does not grant any rights in the trademarks, service marks,
or logos of any Contributor (except as may be necessary to comply with
the notice requirements in Section 3.4).

2.4. Subsequent Licenses

No Contributor makes additional grants as a result of Your choice to
distribute the Covered Software under a subsequent version of this
License (see Section 10.2) or under the terms of a Secondary License (if
permitted under the terms of Section 3.3).

2.5. Representation

Each Contributor represents that the Contributor believes its
Contributions are its original creation(s) or it has sufficient rights
to grant the rights to its Contributions conveyed by this License.

2.6. Fair Use

This License is not intended to limit any rights You have under
applicable copyright doctrines of fair use, fair dealing, or other
equivalents.

2.7. Conditions

Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
in Section 2.1.

3. Responsibilities
-------------------

3.1. Distribution of Source Form

All distribution of Covered Software in Source Code Form, including any
Modifications that You create or to which You contribute, must be under
the terms of this License. You must inform recipients that the Source
Code Form of the Covered Software is governed by the terms of this
License, and how they can obtain a copy of this License. You may not
attempt to alter or restrict the recipients' rights in the Source Code
Form.

3.2. Distribution of Executable Form

If You distribute Covered Software in Executable Form then:

(a) such Covered Software must also be made available in Source Code
    Form, as described in Section 3.1, and You must inform recipients of
    the Executable Form how they can obtain a copy of such Source Code
    Form by reasonable means in a timely manner, at a charge no more
    than the cost of distribution to the recipient; and

(b) You may distribute such Executable Form under the terms of this
    License, or sublicense it under different terms, provided that the
    license for the Executable Form does not attempt to limit or alter
    the recipients' rights in the Source Code Form under this License.

3.3. Distribution of a Larger Work

You may create and distribute a Larger Work under terms of Your choice,
provided that You also comply with the requirements of this License for
the Covered Software. If the Larger Work is a combination of Covered
Software with a work governed by one or more Secondary Licenses, and the
Covered Software is not Incompatible With Secondary Licenses, this
License permits You to additionally distribute such Covered Software
under the terms of such Secondary License(s), so that the recipient of
the Larger Work may, at their option, further distribute the Covered
Software under the terms of either this License or such Secondary
License(s).

3.4. Notices

You may not remove or alter the substance of any license notices
(including copyright notices, patent notices, disclaimers of warranty,
or limitations of liability) contained within the Source Code Form of
the Covered Software, except that You may alter any license notices to
the extent required to remedy known factual inaccuracies.

3.5. Application of Additional Terms

You may choose to offer, and to charge a fee for, warranty, support,
indemnity or liability obligations to one or more recipients of Covered
Software. However, You may do so only on Your own behalf, and not on
behalf of any Contributor. You must make it absolutely clear that any
such warranty, support, indemnity, or liability obligation is offered by
You alone, and You hereby agree to indemnify every Contributor for any
liability incurred by such Contributor as a result of warranty, support,
indemnity or liability terms You offer. You may include additional
disclaimers of warranty and limitations of liability specific to any
jurisdiction.

4. Inability to Comply Due to Statute or Regulation
---------------------------------------------------

If it is impossible for You to comply with any of the terms of this
License with respect to some or all of the Covered Software due to
statute, judicial order, or regulation then You must: (a) comply with
the terms of this License to the maximum extent possible; and (b)
describe the limitations and the code they affect. Such description must
be placed in a text file included with all distributions of the Covered
Software under this License. Except to the extent prohibited by statute
or regulation, such description must be sufficiently detailed for a
recipient of ordinary skill to be able to understand it.

5. Termination
--------------

5.1. The rights granted under this License will terminate automatically
if You fail to comply with any of its terms. However, if You become
compliant, then the rights granted under this License from a particular
Contributor are reinstated (a) provisionally, unless and until such
Contributor explicitly and finally terminates Your grants, and (b) on an
ongoing basis, if such Contributor fails to notify You of the
non-compliance by some reasonable means prior to 60 days after You have
come back into compliance. Moreover, Your grants from a particular
Contributor are reinstated on an ongoing basis if such Contributor
notifies You of the non-compliance by some reasonable means, this is the
first time You have received notice of non-compliance with this License
from such Contributor, and You become compliant prior to 30 days after
Your receipt of the notice.

5.2. If You initiate litigation against any entity by asserting a patent
infringement claim (excluding declaratory judgment actions,
counter-claims, and cross-claims) alleging that a Contributor Version
directly or indirectly infringes any patent, then the rights granted to
You by any and all Contributors for the Covered Software under Section
2.1 of this License shall terminate.

5.3. In the event of termination under Sections 5.1 or 5.2 above, all
end user license agreements (excluding distributors and resellers) which
have been validly granted by You or Your distributors under this License
prior to termination shall survive termination.

************************************************************************
*                                                                      *
*  6. Disclaimer of Warranty                                           *
*  -------------------------                                           *
*                                                                      *
*  Covered Software is provided under this License on an "as is"       *
*  basis, without warranty of any kind, either expressed, implied, or  *
*  statutory, including, without limitation, warranties that the       *
*  Covered Software is free of defects, merchantable, fit for a        *
*  particular purpose or non-infringing. The entire risk as to the     *
*  quality and performance of the Covered Software is with You.        *
*  Should any Covered Software prove defective in any respect, You     *
*  (not any Contributor) assume the cost of any necessary servicing,   *
*  repair, or correction. This disclaimer of warranty constitutes an   *
*  essential part of this License. No use of any Covered Software is   *
*  authorized under this License except under this disclaimer.         *
*                                                                      *
************************************************************************

************************************************************************
*                                                                      *
*  7. Limitation of Liability                                          *
*  --------------------------                                          *
*                                                                      *
*  Under no circumstances and under no legal theory, whether tort      *
*  (including negligence), contract, or otherwise, shall any           *
*  Contributor, or anyone who distributes Covered Software as          *
*  permitted above, be liable to You for any direct, indirect,         *
*  special, incidental, or consequential damages of any character      *
*  including, without limitation, damages for lost profits, loss of    *
*  goodwill, work stoppage, computer failure or malfunction, or any    *
*  and all other commercial damages or losses, even if such party      *
*  shall have been informed of the possibility of such damages. This   *
*  limitation of liability shall not apply to liability for death or   *
*  personal injury resulting from such party's negligence to the       *
*  extent applicable law prohibits such limitation. Some               *
*  jurisdictions do not allow the exclusion or limitation of           *
*  incidental or consequential damages, so this exclusion and          *
*  limitation may not apply to You.                                    *
*                                                                      *
************************************************************************

8. Litigation
-------------

Any litigation relating to this License may be brought only in the
courts of a jurisdiction where the defendant maintains its principal
place of business and such litigation shall be governed by laws of that
jurisdiction, without reference to its conflict-of-law provisions.
Nothing in this Section shall prevent a party's ability to bring
cross-claims or counter-claims.

9. Miscellaneous
----------------

This License represents the complete agreement concerning the subject
matter hereof. If any provision of this License is held to be
unenforceable, such provision shall be reformed only to the extent
necessary to make it enforceable. Any law or regulation which provides
that the language of a contract shall be construed against the drafter
shall not be used to construe this License against a Contributor.

10. Versions of the License
---------------------------

10.1. New Versions

Mozilla Foundation is the license steward. Except as provided in Section
10.3, no one other than the license steward has the right to modify or
publish new versions of this License. Each version will be given a
distinguishing version number.

10.2. Effect of New Versions

You may distribute the Covered Software under the terms of the version
of the License under which You originally received the Covered Software,
or under the terms of any subsequent version published by the license
steward.

10.3. Modified Versions

If you create software not governed by this License, and you want to
create a new license for such software, you may create and use a
modified version of this License if you rename the license and remove
any references to the name of the license steward (except to note that
such modified license differs from this License).

10.4. Distributing Source Code Form that is Incompatible With Secondary
Licenses

If You choose to distribute Source Code Form that is Incompatible With
Secondary Licenses under the terms of this version of the License, the
notice described in Exhibit B of this License must be attached.

Exhibit A - Source Code Form License Notice
-------------------------------------------

  This Source Code Form is subject to the terms of the Mozilla Public
  License, v. 2.0. If a copy of the MPL was not distributed with this
  file, You can obtain one at http://mozilla.org/MPL/2.0/.

If it is not possible or desirable to put the notice in a particular
file, then You may include the notice in a location (such as a LICENSE
file in a relevant directory) where a recipient would be likely to look
for such a notice.

You may add additional accurate notices of copyright ownership.

Exhibit B - "Incompatible With Secondary Licenses" Notice
---------------------------------------------------------

  This Source Code Form is "Incompatible With Secondary Licenses", as
  defined by the Mozilla Public License, v. 2.0.


================================================
FILE: README.md
================================================
<h1 align="center">neural waveshaping synthesis</h1>
<h4 align="center">real-time neural audio synthesis in the waveform domain</h4>
<div align="center">
<h4>
    <a href="https://benhayes.net/assets/pdf/nws_arxiv.pdf" target="_blank">paper</a> •
        <a href="https://benhayes.net/projects/nws/" target="_blank">website</a> • 
        <a href="https://colab.research.google.com/github/ben-hayes/neural-waveshaping-synthesis/blob/main/colab/NEWT_Timbre_Transfer.ipynb" target="_blank">colab</a> • 
        <a href="https://benhayes.net/projects/nws/#audio-examples">audio</a>
    </h4>
    <p>
    by <em>Ben Hayes, Charalampos Saitis, György Fazekas</em>
    </p>
</div>
<p align="center"><img src="https://benhayes.net/assets/img/newt_shapers.png" /></p>

This repository is the official implementation of [Neural Waveshaping Synthesis](https://benhayes.net/projects/nws/).

## Model Architecture

<p align="center"><img src="https://benhayes.net/assets/img/nws.png" /></p>

## Requirements

To install:

```setup
pip install -r requirements.txt
pip install -e .
```

We recommend installing in a virtual environment.

## Data

We trained our checkpoints on the [URMP](http://www2.ece.rochester.edu/projects/air/projects/URMP.html) dataset.
Once downloaded, the dataset can be preprocessed using `scripts/create_urmp_dataset.py`. 
This will consolidate recordings of each instrument within the dataset and preprocess them according to the pipeline in the paper.

```bash
python scripts/create_urmp_dataset.py \
  --gin-file gin/data/urmp_4second_crepe.gin \
  --data-directory /path/to/urmp \
  --output-directory /path/to/output \
  --device cuda:0  # torch device string for CREPE model
```

Alternatively, you can supply your own dataset and use the general `create_dataset.py` script:

```bash
python scripts/create_dataset.py \
  --gin-file gin/data/urmp_4second_crepe.gin \
  --data-directory /path/to/dataset \
  --output-directory /path/to/output \
  --device cuda:0  # torch device string for CREPE model
```
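Preprocessing also writes `data_mean.npy` and `data_std.npy` alongside the splits: control features are stored standardised, and `GeneralDataset.__getitem__` inverts this at load time to recover f0 and amplitude. A minimal numpy sketch of that round trip (the array shapes here are illustrative, not taken from the repo):

```python
import numpy as np

# Illustrative shapes: 2 control channels (f0, amplitude) x 500 frames.
rng = np.random.default_rng(0)
raw_control = rng.uniform(0.0, 1.0, size=(2, 500))

# Per-channel statistics, as saved to data_mean.npy / data_std.npy.
data_mean = raw_control.mean(axis=1, keepdims=True)
data_std = raw_control.std(axis=1, keepdims=True)

# Stored form: standardised control features.
control = (raw_control - data_mean) / data_std

# Load-time inversion, mirroring GeneralDataset.__getitem__:
denormalised = (control * data_std) + data_mean
f0 = denormalised[0:1, :]   # fundamental frequency channel
amp = denormalised[1:2, :]  # loudness/amplitude channel
```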

## Training

To train a model on the URMP dataset, use this command (`--instrument` selects the URMP instrument by its abbreviated string):

```bash
python scripts/train.py \
  --gin-file gin/train/train_newt.gin \
  --dataset-path /path/to/processed/urmp \
  --urmp \
  --instrument vn \
  --load-data-to-memory
```

Or to use a non-URMP dataset:
```bash
python scripts/train.py \
  --gin-file gin/train/train_newt.gin \
  --dataset-path /path/to/processed/data \
  --load-data-to-memory
```
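The pretrained checkpoints under `checkpoints/nws/` are organised by the same abbreviated instrument strings that `--instrument` accepts. A small lookup helper for the three instruments shipped in this repository (the full-name expansions follow conventional URMP naming and are an assumption, not stated in this repo):

```python
# Abbreviated URMP instrument strings used by --instrument and by the
# directories under checkpoints/nws/. Full names are assumed, not
# stated in this repository.
SHIPPED_CHECKPOINTS = {
    "fl": "flute",
    "tpt": "trumpet",
    "vn": "violin",
}

def checkpoint_dir(instrument: str, root: str = "checkpoints/nws") -> str:
    """Build the path to a shipped checkpoint directory."""
    if instrument not in SHIPPED_CHECKPOINTS:
        raise KeyError(f"No shipped checkpoint for {instrument!r}")
    return f"{root}/{instrument}"
```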


================================================
FILE: gin/data/urmp_4second_crepe.gin
================================================
sample_rate = 16000
interpolation = None
control_hop = 128

extract_f0_with_crepe.sample_rate = %sample_rate
extract_f0_with_crepe.device = %device
extract_f0_with_crepe.full_model = True
extract_f0_with_crepe.interpolate_fn = %interpolation
extract_f0_with_crepe.hop_length = %control_hop

extract_perceptual_loudness.sample_rate = %sample_rate
extract_perceptual_loudness.interpolate_fn = %interpolation
extract_perceptual_loudness.n_fft = 1024
extract_perceptual_loudness.hop_length = %control_hop

extract_mfcc.sample_rate = %sample_rate
extract_mfcc.n_fft = 1024
extract_mfcc.hop_length = 128
extract_mfcc.n_mfcc = 16

preprocess_audio.target_sr = %sample_rate
preprocess_audio.f0_extractor = @extract_f0_with_crepe
preprocess_audio.loudness_extractor = @extract_perceptual_loudness
preprocess_audio.segment_length_in_seconds = 4
preprocess_audio.hop_length_in_seconds = 4
preprocess_audio.normalise_audio = True
preprocess_audio.control_decimation_factor = %control_hop
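With `sample_rate = 16000` and `control_hop = 128`, the control signals run at 125 frames per second, so each 4-second segment yields 64,000 audio samples decimated to 500 control frames. A quick sanity check of that arithmetic, using the values from this gin file:

```python
# Values from gin/data/urmp_4second_crepe.gin.
sample_rate = 16000
control_hop = 128
segment_length_in_seconds = 4

control_rate = sample_rate / control_hop       # control frames per second
audio_samples = sample_rate * segment_length_in_seconds
control_frames = audio_samples // control_hop  # frames after decimation
```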

================================================
FILE: gin/models/newt.gin
================================================
sample_rate = 16000

control_embedding_size = 128
n_waveshapers = 64
control_hop = 128

HarmonicOscillator.n_harmonics = 101
HarmonicOscillator.sample_rate = %sample_rate

NEWT.n_waveshapers = %n_waveshapers
NEWT.control_embedding_size = %control_embedding_size
NEWT.shaping_fn_size = 8
NEWT.out_channels = 1
TrainableNonlinearity.depth = 4

ControlModule.control_size = 2
ControlModule.hidden_size = 128
ControlModule.embedding_size = %control_embedding_size

noise_synth/TimeDistributedMLP.in_size = %control_embedding_size
noise_synth/TimeDistributedMLP.hidden_size = %control_embedding_size
noise_synth/TimeDistributedMLP.out_size = 129
noise_synth/TimeDistributedMLP.depth = 4
noise_synth/FIRNoiseSynth.ir_length = 256
noise_synth/FIRNoiseSynth.hop_length = %control_hop

Reverb.length_in_seconds = 2
Reverb.sr = %sample_rate


NeuralWaveshaping.n_waveshapers = %n_waveshapers
NeuralWaveshaping.control_hop = %control_hop
NeuralWaveshaping.sample_rate = %sample_rate

================================================
FILE: gin/train/train_newt.gin
================================================
get_model.model = @NeuralWaveshaping

include 'gin/models/newt.gin'

URMPDataModule.batch_size = 8

NeuralWaveshaping.learning_rate = 0.001
NeuralWaveshaping.lr_decay = 0.9
NeuralWaveshaping.lr_decay_interval = 10000

trainer_kwargs.max_steps = 120000
trainer_kwargs.gradient_clip_val = 2.0
trainer_kwargs.accelerator = 'dp'

================================================
FILE: neural_waveshaping_synthesis/__init__.py
================================================


================================================
FILE: neural_waveshaping_synthesis/data/__init__.py
================================================


================================================
FILE: neural_waveshaping_synthesis/data/general.py
================================================
import os

import gin
import numpy as np
import pytorch_lightning as pl
import torch


class GeneralDataset(torch.utils.data.Dataset):
    def __init__(self, path: str, split: str = "train", load_to_memory: bool = True):
        super().__init__()
        self.load_to_memory = load_to_memory

        self.split_path = os.path.join(path, split)
        self.data_list = [
            f.replace("audio_", "")
            for f in os.listdir(os.path.join(self.split_path, "audio"))
            if f[-4:] == ".npy"
        ]
        if load_to_memory:
            self.audio = [
                np.load(os.path.join(self.split_path, "audio", "audio_%s" % name))
                for name in self.data_list
            ]
            self.control = [
                np.load(os.path.join(self.split_path, "control", "control_%s" % name))
                for name in self.data_list
            ]

        self.data_mean = np.load(os.path.join(path, "data_mean.npy"))
        self.data_std = np.load(os.path.join(path, "data_std.npy"))

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        name = self.data_list[idx]
        if self.load_to_memory:
            audio = self.audio[idx]
            control = self.control[idx]
        else:
            audio_name = "audio_%s" % name
            control_name = "control_%s" % name

            audio = np.load(os.path.join(self.split_path, "audio", audio_name))
            control = np.load(os.path.join(self.split_path, "control", control_name))
        denormalised_control = (control * self.data_std) + self.data_mean

        return {
            "audio": audio,
            "f0": denormalised_control[0:1, :],
            "amp": denormalised_control[1:2, :],
            "control": control,
            "name": os.path.splitext(os.path.basename(name))[0],
        }


@gin.configurable
class GeneralDataModule(pl.LightningDataModule):
    def __init__(
        self,
        data_root: str,
        batch_size: int = 16,
        load_to_memory: bool = True,
        **dataloader_args
    ):
        super().__init__()
        self.data_dir = data_root
        self.batch_size = batch_size
        self.dataloader_args = dataloader_args
        self.load_to_memory = load_to_memory

    def prepare_data(self):
        pass

    def setup(self, stage: str = None):
        if stage == "fit":
            self.urmp_train = GeneralDataset(self.data_dir, "train", self.load_to_memory)
            self.urmp_val = GeneralDataset(self.data_dir, "val", self.load_to_memory)
        elif stage == "test" or stage is None:
            self.urmp_test = GeneralDataset(self.data_dir, "test", self.load_to_memory)

    def _make_dataloader(self, dataset):
        return torch.utils.data.DataLoader(
            dataset, self.batch_size, **self.dataloader_args
        )

    def train_dataloader(self):
        return self._make_dataloader(self.urmp_train)

    def val_dataloader(self):
        return self._make_dataloader(self.urmp_val)

    def test_dataloader(self):
        return self._make_dataloader(self.urmp_test)


================================================
FILE: neural_waveshaping_synthesis/data/urmp.py
================================================
import os

import gin

from .general import GeneralDataModule


@gin.configurable
class URMPDataModule(GeneralDataModule):
    def __init__(
        self,
        urmp_root: str,
        instrument: str,
        batch_size: int = 16,
        load_to_memory: bool = True,
        **dataloader_args
    ):
        super().__init__(
            os.path.join(urmp_root, instrument),
            batch_size,
            load_to_memory,
            **dataloader_args
        )


================================================
FILE: neural_waveshaping_synthesis/data/utils/__init__.py
================================================


================================================
FILE: neural_waveshaping_synthesis/data/utils/create_dataset.py
================================================
import os
import shutil
from typing import Sequence

import gin
import numpy as np
from sklearn.model_selection import train_test_split

from .preprocess_audio import preprocess_audio
from ...utils import seed_all


def create_directory(path):
    if not os.path.isdir(path):
        try:
            os.mkdir(path)
        except OSError:
            print("Failed to create directory %s" % path)
        else:
            print("Created directory %s..." % path)
    else:
        print("Directory %s already exists. Skipping..." % path)


def create_directories(target_root, names):
    create_directory(target_root)
    for name in names:
        create_directory(os.path.join(target_root, name))


def make_splits(
    audio_list: Sequence[str],
    control_list: Sequence[str],
    splits: Sequence[str],
    split_proportions: Sequence[float],
):
    assert len(splits) == len(
        split_proportions
    ), "Length of splits and split_proportions must be equal"

    train_size = split_proportions[0] / np.sum(split_proportions)
    audio_0, audio_1, control_0, control_1 = train_test_split(
        audio_list, control_list, train_size=train_size
    )
    if len(splits) == 2:
        return {
            splits[0]: {
                "audio": audio_0,
                "control": control_0,
            },
            splits[1]: {
                "audio": audio_1,
                "control": control_1,
            },
        }
    elif len(splits) > 2:
        return {
            splits[0]: {
                "audio": audio_0,
                "control": control_0,
            },
            **make_splits(audio_1, control_1, splits[1:], split_proportions[1:]),
        }
    elif len(splits) == 1:
        return {
            splits[0]: {
                "audio": audio_list,
                "control": control_list,
            }
        }
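A minimal, deterministic sketch of the recursion above, for illustration only: plain slicing stands in for sklearn's shuffled `train_test_split`, and the name `split_proportionally` is hypothetical, not part of this module.

```python
# Recursive proportional splitting: take the first split's share of the
# items, then recurse on the remainder with the remaining proportions --
# mirroring the structure of make_splits above.
def split_proportionally(items, names, proportions):
    if len(names) == 1:
        return {names[0]: list(items)}
    cut = round(len(items) * proportions[0] / sum(proportions))
    head = {names[0]: list(items[:cut])}
    tail = split_proportionally(items[cut:], names[1:], proportions[1:])
    return {**head, **tail}

splits = split_proportionally(list(range(10)), ["train", "val", "test"], (0.8, 0.1, 0.1))
```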


def lazy_create_dataset(
    files: Sequence[str],
    output_directory: str,
    splits: Sequence[str],
    split_proportions: Sequence[float],
):
    audio_files = []
    control_files = []
    audio_max = 1e-5
    means = []
    stds = []
    lengths = []

    for i, (all_audio, all_f0, all_confidence, all_loudness, all_mfcc) in enumerate(
        preprocess_audio(files)
    ):
        file = os.path.split(files[i])[-1].replace(".wav", "")
        for j, (audio, f0, confidence, loudness, mfcc) in enumerate(
            zip(all_audio, all_f0, all_confidence, all_loudness, all_mfcc)
        ):
            audio_file_name = "audio_%s_%d.npy" % (file, j)
            control_file_name = "control_%s_%d.npy" % (file, j)

            max_sample = np.abs(audio).max()
            if max_sample > audio_max:
                audio_max = max_sample

            np.save(
                os.path.join(output_directory, "temp", "audio", audio_file_name),
                audio,
            )
            control = np.stack((f0, loudness, confidence), axis=0)
            control = np.concatenate((control, mfcc), axis=0)
            np.save(
                os.path.join(output_directory, "temp", "control", control_file_name),
                control,
            )

            audio_files.append(audio_file_name)
            control_files.append(control_file_name)

            means.append(control.mean(axis=-1))
            stds.append(control.std(axis=-1))
            lengths.append(control.shape[-1])

    if len(audio_files) == 0:
        print("No datapoints to split. Skipping...")
        return

    data_mean = np.mean(np.stack(means, axis=-1), axis=-1)[:, np.newaxis]
    lengths = np.stack(lengths)[np.newaxis, :]
    stds = np.stack(stds, axis=-1)
    data_std = np.sqrt(np.sum(lengths * stds ** 2, axis=-1) / np.sum(lengths))[
        :, np.newaxis
    ]

    print("Saving dataset stats...")
    np.save(os.path.join(output_directory, "data_mean.npy"), data_mean)
    np.save(os.path.join(output_directory, "data_std.npy"), data_std)

    splits = make_splits(audio_files, control_files, splits, split_proportions)
    for split in splits:
        for audio_file in splits[split]["audio"]:
            audio = np.load(os.path.join(output_directory, "temp", "audio", audio_file))
            audio = audio / audio_max
            np.save(os.path.join(output_directory, split, "audio", audio_file), audio)
        for control_file in splits[split]["control"]:
            control = np.load(
                os.path.join(output_directory, "temp", "control", control_file)
            )
            control = (control - data_mean) / data_std
            np.save(
                os.path.join(output_directory, split, "control", control_file), control
            )


@gin.configurable
def create_dataset(
    files: Sequence[str],
    output_directory: str,
    splits: Sequence[str] = ("train", "val", "test"),
    split_proportions: Sequence[float] = (0.8, 0.1, 0.1),
    lazy: bool = True,
):
    create_directories(output_directory, (*splits, "temp"))
    for split in (*splits, "temp"):
        create_directories(os.path.join(output_directory, split), ("audio", "control"))

    if lazy:
        lazy_create_dataset(files, output_directory, splits, split_proportions)
    # note: no eager path is implemented, so lazy=False currently does nothing

    shutil.rmtree(os.path.join(output_directory, "temp"))

================================================
FILE: neural_waveshaping_synthesis/data/utils/f0_extraction.py
================================================
from typing import Callable, Optional, Union

import gin
import librosa
import numpy as np
import torch
import torchcrepe

from .upsampling import linear_interpolation


CREPE_WINDOW_LENGTH = 1024


@gin.configurable
def extract_f0_with_crepe(
    audio: np.ndarray,
    sample_rate: float,
    hop_length: int = 128,
    minimum_frequency: float = 50.0,
    maximum_frequency: float = 2000.0,
    full_model: bool = True,
    batch_size: int = 2048,
    device: Union[str, torch.device] = "cpu",
    interpolate_fn: Optional[Callable] = linear_interpolation,
):
    # convert to torch tensor with channel dimension (necessary for CREPE)
    audio = torch.tensor(audio).unsqueeze(0)
    f0, confidence = torchcrepe.predict(
        audio,
        sample_rate,
        hop_length,
        minimum_frequency,
        maximum_frequency,
        "full" if full_model else "tiny",
        batch_size=batch_size,
        device=device,
        decoder=torchcrepe.decode.viterbi,
        # decoder=torchcrepe.decode.weighted_argmax,
        return_harmonicity=True,
    )

    f0, confidence = f0.squeeze().numpy(), confidence.squeeze().numpy()

    if interpolate_fn:
        f0 = interpolate_fn(
            f0, CREPE_WINDOW_LENGTH, hop_length, original_length=audio.shape[-1]
        )
        confidence = interpolate_fn(
            confidence,
            CREPE_WINDOW_LENGTH,
            hop_length,
            original_length=audio.shape[-1],
        )

    return f0, confidence


@gin.configurable
def extract_f0_with_pyin(
    audio: np.ndarray,
    sample_rate: float,
    minimum_frequency: float = 65.0,  # recommended minimum freq from librosa docs
    maximum_frequency: float = 2093.0,  # recommended maximum freq from librosa docs
    frame_length: int = 1024,
    hop_length: int = 128,
    fill_na: Optional[float] = None,
    interpolate_fn: Optional[Callable] = linear_interpolation,
):
    f0, _, voiced_prob = librosa.pyin(
        audio,
        sr=sample_rate,
        fmin=minimum_frequency,
        fmax=maximum_frequency,
        frame_length=frame_length,
        hop_length=hop_length,
        fill_na=fill_na,
    )

    if interpolate_fn:
        f0 = interpolate_fn(
            f0, frame_length, hop_length, original_length=audio.shape[-1]
        )
        voiced_prob = interpolate_fn(
            voiced_prob,
            frame_length,
            hop_length,
            original_length=audio.shape[-1],
        )

    return f0, voiced_prob


================================================
FILE: neural_waveshaping_synthesis/data/utils/loudness_extraction.py
================================================
from typing import Callable, Optional
import warnings

import gin
import librosa
import numpy as np

from .upsampling import linear_interpolation


def compute_power_spectrogram(
    audio: np.ndarray,
    n_fft: int,
    hop_length: int,
    window: str,
    epsilon: float,
):
    spectrogram = librosa.stft(audio, n_fft=n_fft, hop_length=hop_length, window=window)
    magnitude_spectrogram = np.abs(spectrogram)
    power_spectrogram = librosa.amplitude_to_db(
        magnitude_spectrogram, ref=np.max, amin=epsilon
    )
    return power_spectrogram


def perform_perceptual_weighting(
    power_spectrogram_in_db: np.ndarray, sample_rate: float, n_fft: int
):
    centre_frequencies = librosa.fft_frequencies(sr=sample_rate, n_fft=n_fft)

    # We know that we will get a log(0) warning here due to the DC component -- we can
    # safely ignore as it is clipped to the default min dB value of -80.0 dB
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        weights = librosa.A_weighting(centre_frequencies)

    weights = np.expand_dims(weights, axis=1)
    weighted_spectrogram = power_spectrogram_in_db + weights
    return weighted_spectrogram


@gin.configurable
def extract_perceptual_loudness(
    audio: np.ndarray,
    sample_rate: float = 16000,
    n_fft: int = 2048,
    hop_length: int = 512,
    window: str = "hann",
    epsilon: float = 1e-5,
    interpolate_fn: Optional[Callable] = linear_interpolation,
    normalise: bool = True,
):
    power_spectrogram = compute_power_spectrogram(
        audio, n_fft=n_fft, hop_length=hop_length, window=window, epsilon=epsilon
    )
    perceptually_weighted_spectrogram = perform_perceptual_weighting(
        power_spectrogram, sample_rate=sample_rate, n_fft=n_fft
    )
    loudness = np.mean(perceptually_weighted_spectrogram, axis=0)
    if interpolate_fn:
        loudness = interpolate_fn(
            loudness, n_fft, hop_length, original_length=audio.size
        )

    if normalise:
        # map loudness from the default dB range of [-80, 0] to [0, 1]
        loudness = (loudness + 80) / 80

    return loudness


@gin.configurable
def extract_rms(
    audio: np.ndarray,
    window_size: int = 2048,
    hop_length: int = 512,
    sample_rate: Optional[float] = 16000.0,
    interpolate_fn: Optional[Callable] = linear_interpolation,
):
    # pad audio to centre frames
    padded_audio = np.pad(audio, (window_size // 2, window_size // 2))
    frames = librosa.util.frame(
        padded_audio, frame_length=window_size, hop_length=hop_length
    )
    squared = frames ** 2
    mean = np.mean(squared, axis=0)
    root = np.sqrt(mean)
    if interpolate_fn:
        assert sample_rate is not None, "Must provide sample rate if upsampling"
        root = interpolate_fn(root, window_size, hop_length, original_length=audio.size)

    return root
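A sanity check on the framed RMS computation above, as a standalone NumPy sketch (not part of this module): the per-frame RMS of a full-scale sine should approach 1/sqrt(2).

```python
import numpy as np

# Frame a 440 Hz sine exactly as extract_rms does (centre padding, then
# sliding windows) and compute per-frame RMS.
sr = 16000
t = np.arange(sr) / sr
sine = np.sin(2 * np.pi * 440.0 * t).astype(np.float32)

window_size, hop_length = 2048, 512
padded = np.pad(sine, (window_size // 2, window_size // 2))
n_frames = 1 + (padded.size - window_size) // hop_length
frames = np.stack(
    [padded[i * hop_length : i * hop_length + window_size] for i in range(n_frames)],
    axis=1,
)
rms = np.sqrt(np.mean(frames ** 2, axis=0))
```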


================================================
FILE: neural_waveshaping_synthesis/data/utils/mfcc_extraction.py
================================================
import gin
import librosa
import numpy as np


@gin.configurable
def extract_mfcc(
    audio: np.ndarray, sample_rate: float, n_fft: int, hop_length: int, n_mfcc: int
):
    mfcc = librosa.feature.mfcc(
        y=audio, sr=sample_rate, n_mfcc=n_mfcc, n_fft=n_fft, hop_length=hop_length
    )
    return mfcc

================================================
FILE: neural_waveshaping_synthesis/data/utils/preprocess_audio.py
================================================
from functools import partial
from typing import Callable, Sequence, Union

import gin
import librosa
import numpy as np
import resampy
import scipy.io.wavfile as wavfile

from .f0_extraction import extract_f0_with_crepe, extract_f0_with_pyin
from .loudness_extraction import extract_perceptual_loudness, extract_rms
from .mfcc_extraction import extract_mfcc
from ...utils import apply, apply_unpack, unzip


def read_audio_files(files: list):
    rates_and_audios = apply(wavfile.read, files)
    return unzip(rates_and_audios)


def convert_to_float32_audio(audio: np.ndarray):
    if audio.dtype == np.float32:
        return audio

    max_sample_value = np.iinfo(audio.dtype).max
    floating_point_audio = audio / max_sample_value
    return floating_point_audio.astype(np.float32)


def make_monophonic(audio: np.ndarray, strategy: str = "keep_left"):
    # deal with non stereo array formats
    if len(audio.shape) == 1:
        return audio
    elif len(audio.shape) != 2:
        raise ValueError("Unknown audio array format.")

    # deal with single audio channel
    if audio.shape[0] == 1:
        return audio[0]
    elif audio.shape[1] == 1:
        return audio[:, 0]
    # deal with more than two channels
    elif audio.shape[0] != 2 and audio.shape[1] != 2:
        raise ValueError("Expected stereo input audio but got too many channels.")

    # put channel first
    if audio.shape[1] == 2:
        audio = audio.T

    # make stereo audio monophonic
    if strategy == "keep_left":
        return audio[0]
    elif strategy == "keep_right":
        return audio[1]
    elif strategy == "sum":
        return np.mean(audio, axis=0)
    elif strategy == "diff":
        return audio[0] - audio[1]
    else:
        raise ValueError("Unknown mono conversion strategy: '%s'" % strategy)


def normalise_signal(audio: np.ndarray, factor: float):
    return audio / factor


def resample_audio(audio: np.ndarray, original_sr: float, target_sr: float):
    return resampy.resample(audio, original_sr, target_sr)


def segment_signal(
    signal: np.ndarray,
    sample_rate: float,
    segment_length_in_seconds: float,
    hop_length_in_seconds: float,
):
    segment_length_in_samples = int(sample_rate * segment_length_in_seconds)
    hop_length_in_samples = int(sample_rate * hop_length_in_seconds)
    segments = librosa.util.frame(
        signal, frame_length=segment_length_in_samples, hop_length=hop_length_in_samples
    )
    return segments


def filter_segments(
    threshold: float,
    key_segments: np.ndarray,
    segments: Sequence[np.ndarray],
):
    mean_keys = key_segments.mean(axis=0)
    mask = mean_keys > threshold
    filtered_segments = apply(
        lambda x: x[:, mask] if len(x.shape) == 2 else x[:, :, mask], segments
    )
    return filtered_segments


def preprocess_single_audio_file(
    file: str,
    control_decimation_factor: float,
    target_sr: float = 16000.0,
    segment_length_in_seconds: float = 4.0,
    hop_length_in_seconds: float = 2.0,
    confidence_threshold: float = 0.85,
    f0_extractor: Callable = extract_f0_with_crepe,
    loudness_extractor: Callable = extract_perceptual_loudness,
    mfcc_extractor: Callable = extract_mfcc,
    normalisation_factor: Union[float, None] = None,
):
    print("Loading audio file: %s..." % file)
    original_sr, audio = wavfile.read(file)
    audio = convert_to_float32_audio(audio)
    audio = make_monophonic(audio)

    if normalisation_factor:
        audio = normalise_signal(audio, normalisation_factor)

    print("Resampling audio file: %s..." % file)
    audio = resample_audio(audio, original_sr, target_sr)

    print("Extracting f0 with extractor '%s': %s..." % (f0_extractor.__name__, file))
    f0, confidence = f0_extractor(audio)

    print(
        "Extracting loudness with extractor '%s': %s..."
        % (loudness_extractor.__name__, file)
    )
    loudness = loudness_extractor(audio)

    print(
        "Extracting MFCC with extractor '%s': %s..." % (mfcc_extractor.__name__, file)
    )
    mfcc = mfcc_extractor(audio)

    print("Segmenting audio file: %s..." % file)
    segmented_audio = segment_signal(
        audio, target_sr, segment_length_in_seconds, hop_length_in_seconds
    )

    print("Segmenting control signals: %s..." % file)
    segmented_f0 = segment_signal(
        f0,
        target_sr / (control_decimation_factor or 1),
        segment_length_in_seconds,
        hop_length_in_seconds,
    )
    segmented_confidence = segment_signal(
        confidence,
        target_sr / (control_decimation_factor or 1),
        segment_length_in_seconds,
        hop_length_in_seconds,
    )
    segmented_loudness = segment_signal(
        loudness,
        target_sr / (control_decimation_factor or 1),
        segment_length_in_seconds,
        hop_length_in_seconds,
    )
    segmented_mfcc = segment_signal(
        mfcc,
        target_sr / (control_decimation_factor or 1),
        segment_length_in_seconds,
        hop_length_in_seconds,
    )

    (
        filtered_audio,
        filtered_f0,
        filtered_confidence,
        filtered_loudness,
        filtered_mfcc,
    ) = filter_segments(
        confidence_threshold,
        segmented_confidence,
        (
            segmented_audio,
            segmented_f0,
            segmented_confidence,
            segmented_loudness,
            segmented_mfcc,
        ),
    )

    if filtered_audio.shape[-1] == 0:
        print("No segments exceeding confidence threshold...")
        audio_split, f0_split, confidence_split, loudness_split, mfcc_split = (
            [],
            [],
            [],
            [],
            [],
        )
    else:
        split = lambda x: [e.squeeze() for e in np.split(x, x.shape[-1], -1)]
        audio_split = split(filtered_audio)
        f0_split = split(filtered_f0)
        confidence_split = split(filtered_confidence)
        loudness_split = split(filtered_loudness)
        mfcc_split = split(filtered_mfcc)

    return audio_split, f0_split, confidence_split, loudness_split, mfcc_split


@gin.configurable
def preprocess_audio(
    files: list,
    control_decimation_factor: float,
    target_sr: float = 16000,
    segment_length_in_seconds: float = 4.0,
    hop_length_in_seconds: float = 2.0,
    confidence_threshold: float = 0.85,
    f0_extractor: Callable = extract_f0_with_crepe,
    loudness_extractor: Callable = extract_perceptual_loudness,
    normalise_audio: bool = False,
):
    if normalise_audio:
        print("Finding normalisation factor...")
        normalisation_factor = 0
        for file in files:
            _, audio = wavfile.read(file)
            audio = convert_to_float32_audio(audio)
            audio = make_monophonic(audio)
            max_value = np.abs(audio).max()
            normalisation_factor = (
                max_value if max_value > normalisation_factor else normalisation_factor
            )

    processor = partial(
        preprocess_single_audio_file,
        control_decimation_factor=control_decimation_factor,
        target_sr=target_sr,
        segment_length_in_seconds=segment_length_in_seconds,
        hop_length_in_seconds=hop_length_in_seconds,
        confidence_threshold=confidence_threshold,
        f0_extractor=f0_extractor,
        loudness_extractor=loudness_extractor,
        normalisation_factor=None if not normalise_audio else normalisation_factor,
    )
    for file in files:
        yield processor(file)


================================================
FILE: neural_waveshaping_synthesis/data/utils/upsampling.py
================================================
from typing import Optional

import gin
import numpy as np
import scipy.interpolate
import scipy.signal.windows


def get_padded_length(frames: int, window_length: int, hop_length: int):
    return frames * hop_length + window_length - hop_length


def get_source_target_axes(frames: int, window_length: int, hop_length: int):
    padded_length = get_padded_length(frames, window_length, hop_length)
    source_x = np.linspace(0, frames - 1, frames)
    target_x = np.linspace(0, frames - 1, padded_length)
    return source_x, target_x


@gin.configurable
def linear_interpolation(
    signal: np.ndarray,
    window_length: int,
    hop_length: int,
    original_length: Optional[int] = None,
):
    source_x, target_x = get_source_target_axes(signal.size, window_length, hop_length)

    interpolated = np.interp(target_x, source_x, signal)
    if original_length:
        interpolated = interpolated[window_length // 2 :]
        interpolated = interpolated[:original_length]

    return interpolated


@gin.configurable
def cubic_spline_interpolation(
    signal: np.ndarray,
    window_length: int,
    hop_length: int,
    original_length: Optional[int] = None,
):
    source_x, target_x = get_source_target_axes(signal.size, window_length, hop_length)

    interpolant = scipy.interpolate.interp1d(source_x, signal, kind="cubic")
    interpolated = interpolant(target_x)
    if original_length:
        interpolated = interpolated[window_length // 2 :]
        interpolated = interpolated[:original_length]

    return interpolated


@gin.configurable
def overlap_add_upsample(
    signal: np.ndarray,
    window_length: int,
    hop_length: int,
    window_fn: str = "hann",
    window_scale: int = 2,
    original_length: Optional[int] = None,
):
    window = scipy.signal.windows.get_window(window_fn, hop_length * window_scale)
    padded_length = get_padded_length(signal.size, window_length, hop_length)
    padded_output = np.zeros(padded_length)

    for i, value in enumerate(signal):
        window_start = i * hop_length
        window_end = window_start + hop_length * window_scale
        padded_output[window_start:window_end] += window * value

    if original_length:
        output = padded_output[(padded_length - original_length) // 2:]
        output = output[:original_length]
    else:
        output = padded_output

    return output
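The frame-to-sample arithmetic shared by the three upsamplers above can be checked in isolation with a self-contained NumPy sketch: F frames analysed with window W and hop H map back to F*H + W - H padded samples.

```python
import numpy as np

# Reproduce get_padded_length and the core of linear_interpolation.
def padded_length(frames, window_length, hop_length):
    return frames * hop_length + window_length - hop_length

frames, window, hop = 8, 1024, 128
signal = np.linspace(0.0, 1.0, frames)
source_x = np.linspace(0, frames - 1, frames)
target_x = np.linspace(0, frames - 1, padded_length(frames, window, hop))
upsampled = np.interp(target_x, source_x, signal)
```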


================================================
FILE: neural_waveshaping_synthesis/models/__init__.py
================================================


================================================
FILE: neural_waveshaping_synthesis/models/modules/__init__.py
================================================


================================================
FILE: neural_waveshaping_synthesis/models/modules/dynamic.py
================================================
import gin
import torch.nn as nn
import torch.nn.functional as F


class FiLM(nn.Module):
    def forward(self, x, gamma, beta):
        return gamma * x + beta


class TimeDistributedLayerNorm(nn.Module):
    def __init__(self, size: int):
        super().__init__()
        self.layer_norm = nn.LayerNorm(size)

    def forward(self, x):
        return self.layer_norm(x.transpose(1, 2)).transpose(1, 2)


@gin.configurable
class TimeDistributedMLP(nn.Module):
    def __init__(self, in_size: int, hidden_size: int, out_size: int, depth: int = 3):
        super().__init__()
        assert depth >= 3, "Depth must be at least 3"
        layers = []
        for i in range(depth):
            layers.append(
                nn.Conv1d(
                    in_size if i == 0 else hidden_size,
                    hidden_size if i < depth - 1 else out_size,
                    1,
                )
            )
            if i < depth - 1:
                layers.append(TimeDistributedLayerNorm(hidden_size))
                layers.append(nn.LeakyReLU())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
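TimeDistributedMLP relies on the fact that a kernel-size-1 Conv1d applies the same linear map independently at every timestep; a small NumPy sketch of that equivalence (illustrative names and shapes only):

```python
import numpy as np

# A 1x1 convolution over (batch, channels, time) is one weight matrix
# applied to the channel axis at each timestep.
rng = np.random.default_rng(0)
batch, in_size, out_size, time_steps = 2, 4, 3, 5
x = rng.standard_normal((batch, in_size, time_steps))
w = rng.standard_normal((out_size, in_size))  # the 1x1 kernel, squeezed

y = np.einsum("oi,bit->bot", w, x)  # Conv1d with kernel size 1, no bias

# identical to an explicit per-timestep matrix multiply
y_loop = np.stack([x[:, :, t] @ w.T for t in range(time_steps)], axis=-1)
```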


================================================
FILE: neural_waveshaping_synthesis/models/modules/generators.py
================================================
import math
from typing import Callable

import gin
import torch
import torch.fft
import torch.nn as nn
import torch.nn.functional as F


@gin.configurable
class FIRNoiseSynth(nn.Module):
    def __init__(
        self, ir_length: int, hop_length: int, window_fn: Callable = torch.hann_window
    ):
        super().__init__()
        self.ir_length = ir_length
        self.hop_length = hop_length
        self.register_buffer("window", window_fn(ir_length))

    def forward(self, H_re):
        H_im = torch.zeros_like(H_re)
        H_z = torch.complex(H_re, H_im)

        h = torch.fft.irfft(H_z.transpose(1, 2))
        h = h.roll(self.ir_length // 2, -1)
        h = h * self.window.view(1, 1, -1)
        H = torch.fft.rfft(h)

        noise = torch.rand(self.hop_length * H_re.shape[-1] - 1, device=H_re.device)
        X = torch.stft(noise, self.ir_length, self.hop_length, return_complex=True)
        X = X.unsqueeze(0)
        Y = X * H.transpose(1, 2)
        y = torch.istft(Y, self.ir_length, self.hop_length, center=False)
        return y.unsqueeze(1)[:, :, : H_re.shape[-1] * self.hop_length]


@gin.configurable
class HarmonicOscillator(nn.Module):
    def __init__(self, n_harmonics, sample_rate):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_harmonics = n_harmonics
        self.register_buffer("harmonic_axis", self._create_harmonic_axis(n_harmonics))
        self.register_buffer("rand_phase", torch.ones(1, n_harmonics, 1) * math.tau)

    def _create_harmonic_axis(self, n_harmonics):
        return torch.arange(1, n_harmonics + 1).view(1, -1, 1)

    def _create_antialias_mask(self, f0):
        freqs = f0.unsqueeze(1) * self.harmonic_axis
        return freqs < (self.sample_rate / 2)

    def _create_phase_shift(self, n_harmonics):
        shift = torch.rand_like(self.rand_phase) * self.rand_phase - math.pi
        return shift

    def forward(self, f0):
        phase = math.tau * f0.cumsum(-1) / self.sample_rate
        harmonic_phase = self.harmonic_axis * phase.unsqueeze(1)
        harmonic_phase = harmonic_phase + self._create_phase_shift(self.n_harmonics)
        antialias_mask = self._create_antialias_mask(f0)

        output = torch.sin(harmonic_phase) * antialias_mask

        return output
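The oscillator above integrates instantaneous frequency into phase and masks harmonics above the Nyquist frequency; a standalone NumPy sketch of the same recipe (constant 440 Hz and 32 harmonics are assumed values for illustration):

```python
import numpy as np

sr = 16000
f0 = np.full(sr, 440.0)                     # one second at constant 440 Hz
phase = 2 * np.pi * np.cumsum(f0) / sr      # running integral of frequency
harmonics = np.arange(1, 33)[:, None]       # harmonic numbers 1..32
bank = np.sin(harmonics * phase[None, :])   # bank of harmonic sinusoids
antialias = (harmonics * 440.0) < (sr / 2)  # mask harmonics above Nyquist
bank = bank * antialias
```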


================================================
FILE: neural_waveshaping_synthesis/models/modules/shaping.py
================================================
import gin
import torch
import torch.fft
import torch.nn as nn
import torch.nn.functional as F

from .dynamic import FiLM, TimeDistributedMLP


class Sine(nn.Module):
    def forward(self, x: torch.Tensor):
        return torch.sin(x)


@gin.configurable
class TrainableNonlinearity(nn.Module):
    def __init__(
        self, channels, width, nonlinearity=nn.ReLU, final_nonlinearity=Sine, depth=3
    ):
        super().__init__()
        self.input_scale = nn.Parameter(torch.randn(1, channels, 1) * 10)
        layers = []
        for i in range(depth):
            layers.append(
                nn.Conv1d(
                    channels if i == 0 else channels * width,
                    channels * width if i < depth - 1 else channels,
                    1,
                    groups=channels,
                )
            )
            layers.append(nonlinearity() if i < depth - 1 else final_nonlinearity())

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(self.input_scale * x)


@gin.configurable
class NEWT(nn.Module):
    def __init__(
        self,
        n_waveshapers: int,
        control_embedding_size: int,
        shaping_fn_size: int = 16,
        out_channels: int = 1,
    ):
        super().__init__()

        self.n_waveshapers = n_waveshapers

        self.mlp = TimeDistributedMLP(
            control_embedding_size, control_embedding_size, n_waveshapers * 4, depth=4
        )

        self.waveshaping_index = FiLM()
        self.shaping_fn = TrainableNonlinearity(
            n_waveshapers, shaping_fn_size, nonlinearity=Sine
        )
        self.normalising_coeff = FiLM()

        self.mixer = nn.Sequential(
            nn.Conv1d(n_waveshapers, out_channels, 1),
        )

    def forward(self, exciter, control_embedding):
        film_params = self.mlp(control_embedding)
        film_params = F.interpolate(
            film_params, exciter.shape[-1], mode="linear", align_corners=False
        )
        gamma_index, beta_index, gamma_norm, beta_norm = torch.split(
            film_params, self.n_waveshapers, 1
        )

        x = self.waveshaping_index(exciter, gamma_index, beta_index)
        x = self.shaping_fn(x)
        x = self.normalising_coeff(x, gamma_norm, beta_norm)

        return self.mixer(x)


class FastNEWT(NEWT):
    def __init__(
        self,
        newt: NEWT,
        table_size: int = 4096,
        table_min: float = -3.0,
        table_max: float = 3.0,
    ):
        super().__init__()
        self.table_size = table_size
        self.table_min = table_min
        self.table_max = table_max

        self.n_waveshapers = newt.n_waveshapers
        self.mlp = newt.mlp

        self.waveshaping_index = newt.waveshaping_index
        self.normalising_coeff = newt.normalising_coeff
        self.mixer = newt.mixer

        self.lookup_table = self._init_lookup_table(
            newt, table_size, self.n_waveshapers, table_min, table_max
        )
        self.to(next(iter(newt.parameters())).device)

    def _init_lookup_table(
        self,
        newt: NEWT,
        table_size: int,
        n_waveshapers: int,
        table_min: float,
        table_max: float,
    ):
        device = next(iter(newt.parameters())).device
        sample_values = torch.linspace(
            table_min, table_max, table_size, device=device
        ).expand(1, n_waveshapers, table_size)
        lookup_table = newt.shaping_fn(sample_values)[0]
        return nn.Parameter(lookup_table)

    def _lookup(self, idx):
        return torch.stack(
            [
                torch.stack(
                    [
                        self.lookup_table[shaper, idx[batch, shaper]]
                        for shaper in range(idx.shape[1])
                    ],
                    dim=0,
                )
                for batch in range(idx.shape[0])
            ],
            dim=0,
        )

    def shaping_fn(self, x):
        idx = self.table_size * (x - self.table_min) / (self.table_max - self.table_min)

        lower = torch.floor(idx).long()
        lower[lower < 0] = 0
        lower[lower >= self.table_size] = self.table_size - 1

        upper = lower + 1
        upper[upper >= self.table_size] = self.table_size - 1

        fract = idx - lower
        lower_v = self._lookup(lower)
        upper_v = self._lookup(upper)

        output = (upper_v - lower_v) * fract + lower_v
        return output
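The table lookup above can be checked in isolation; a NumPy sketch using tanh as a stand-in for the trained shaping function, with the standard (table_size - 1) index mapping so the table's endpoints land exactly on table_min and table_max:

```python
import numpy as np

# Build a lookup table on a uniform grid, then reconstruct intermediate
# points by linear interpolation between the two nearest entries.
table_min, table_max, table_size = -3.0, 3.0, 4096
grid = np.linspace(table_min, table_max, table_size)
table = np.tanh(grid)  # stand-in for a trained waveshaper

x = np.array([-0.5, 0.0, 1.25])
idx = (table_size - 1) * (x - table_min) / (table_max - table_min)
lower = np.clip(np.floor(idx).astype(int), 0, table_size - 1)
upper = np.clip(lower + 1, 0, table_size - 1)
fract = idx - lower
approx = (table[upper] - table[lower]) * fract + table[lower]
```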


@gin.configurable
class Reverb(nn.Module):
    def __init__(self, length_in_seconds, sr):
        super().__init__()
        self.ir = nn.Parameter(torch.randn(1, int(sr * length_in_seconds) - 1) * 1e-6)
        self.register_buffer("initial_zero", torch.zeros(1, 1))

    def forward(self, x):
        ir_ = torch.cat((self.initial_zero, self.ir), dim=-1)
        if x.shape[-1] > ir_.shape[-1]:
            ir_ = F.pad(ir_, (0, x.shape[-1] - ir_.shape[-1]))
            x_ = x
        else:
            x_ = F.pad(x, (0, ir_.shape[-1] - x.shape[-1]))
        return (
            x
            + torch.fft.irfft(torch.fft.rfft(x_) * torch.fft.rfft(ir_))[
                ..., : x.shape[-1]
            ]
        )


================================================
FILE: neural_waveshaping_synthesis/models/neural_waveshaping.py
================================================
import auraloss
import gin
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb

from .modules.dynamic import TimeDistributedMLP
from .modules.generators import FIRNoiseSynth, HarmonicOscillator
from .modules.shaping import NEWT, Reverb

gin.external_configurable(nn.GRU, module="torch.nn")
gin.external_configurable(nn.Conv1d, module="torch.nn")


@gin.configurable
class ControlModule(nn.Module):
    def __init__(self, control_size: int, hidden_size: int, embedding_size: int):
        super().__init__()
        self.gru = nn.GRU(control_size, hidden_size, batch_first=True)
        self.proj = nn.Conv1d(hidden_size, embedding_size, 1)

    def forward(self, x):
        x, _ = self.gru(x.transpose(1, 2))
        return self.proj(x.transpose(1, 2))


@gin.configurable
class NeuralWaveshaping(pl.LightningModule):
    def __init__(
        self,
        n_waveshapers: int,
        control_hop: int,
        sample_rate: float = 16000,
        learning_rate: float = 1e-3,
        lr_decay: float = 0.9,
        lr_decay_interval: int = 10000,
        log_audio: bool = False,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.lr_decay = lr_decay
        self.lr_decay_interval = lr_decay_interval
        self.control_hop = control_hop
        self.log_audio = log_audio

        self.sample_rate = sample_rate

        self.embedding = ControlModule()

        self.osc = HarmonicOscillator()
        self.harmonic_mixer = nn.Conv1d(self.osc.n_harmonics, n_waveshapers, 1)

        self.newt = NEWT()

        with gin.config_scope("noise_synth"):
            self.h_generator = TimeDistributedMLP()
            self.noise_synth = FIRNoiseSynth()

        self.reverb = Reverb()

    def render_exciter(self, f0):
        sig = self.osc(f0[:, 0])
        sig = self.harmonic_mixer(sig)
        return sig

    def get_embedding(self, control):
        # condition only on f0 and loudness, the first two control channels
        f0, other = control[:, 0:1], control[:, 1:2]
        control = torch.cat((f0, other), dim=1)
        return self.embedding(control)

    def forward(self, f0, control):
        f0_upsampled = F.interpolate(
            f0, f0.shape[-1] * self.control_hop, mode="linear"
        )
        x = self.render_exciter(f0_upsampled)

        control_embedding = self.get_embedding(control)

        x = self.newt(x, control_embedding)

        H = self.h_generator(control_embedding)
        noise = self.noise_synth(H)

        x = torch.cat((x, noise), dim=1)
        x = x.sum(1)

        x = self.reverb(x)

        return x

    def configure_optimizers(self):
        self.stft_loss = auraloss.freq.MultiResolutionSTFTLoss()

        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, self.lr_decay_interval, self.lr_decay
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

    def _run_step(self, batch):
        audio = batch["audio"].float()
        f0 = batch["f0"].float()
        control = batch["control"].float()

        recon = self(f0, control)

        loss = self.stft_loss(recon, audio)
        return loss, recon, audio

    def _log_audio(self, name, audio):
        wandb.log(
            {
                "audio/%s"
                % name: wandb.Audio(audio, sample_rate=self.sample_rate, caption=name)
            },
            commit=False,
        )

    def training_step(self, batch, batch_idx):
        loss, _, _ = self._run_step(batch)
        self.log(
            "train/loss",
            loss.item(),
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        loss, recon, audio = self._run_step(batch)
        self.log(
            "val/loss",
            loss.item(),
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        if batch_idx == 0 and self.log_audio:
            self._log_audio("original", audio[0].detach().cpu().squeeze())
            self._log_audio("recon", recon[0].detach().cpu().squeeze())
        return loss

    def test_step(self, batch, batch_idx):
        loss, recon, audio = self._run_step(batch)
        self.log(
            "test/loss",
            loss.item(),
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        if batch_idx == 0:
            self._log_audio("original", audio[0].detach().cpu().squeeze())
            self._log_audio("recon", recon[0].detach().cpu().squeeze())
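
`configure_optimizers` pairs Adam with a step-interval `StepLR`, so the effective learning rate decays geometrically: it is multiplied by `lr_decay` once every `lr_decay_interval` optimizer steps. A self-contained sketch of the resulting schedule, using the constructor defaults above:

```python
def lr_at_step(step, base_lr=1e-3, decay=0.9, interval=10000):
    # StepLR semantics: multiply by `decay` once per `interval` steps.
    return base_lr * decay ** (step // interval)

assert lr_at_step(0) == 1e-3                     # no decay before the first interval
assert lr_at_step(25000) == 1e-3 * 0.9 ** 2      # two full intervals elapsed
```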


================================================
FILE: neural_waveshaping_synthesis/utils/__init__.py
================================================
from .utils import *
from .seed_all import *

================================================
FILE: neural_waveshaping_synthesis/utils/seed_all.py
================================================
import numpy as np
import os
import random
import torch

def seed_all(seed):
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # covers all GPUs, not just the current device
    torch.backends.cudnn.deterministic = True
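
`seed_all` seeds every RNG the project touches. The property it relies on, shown here with the standard library alone (the same guarantee it extends to NumPy and PyTorch):

```python
import random

def draw(seed, n=3):
    # Re-seeding before drawing reproduces the identical stream.
    random.seed(seed)
    return [random.random() for _ in range(n)]

assert draw(42) == draw(42)
assert draw(42) != draw(43)
```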

================================================
FILE: neural_waveshaping_synthesis/utils/utils.py
================================================
import os
from typing import Any, Callable, Sequence


def apply(fn: Callable[[Any], Any], x: Sequence[Any]):
    if type(x) not in (tuple, list):
        raise TypeError("x must be a tuple or list.")
    return type(x)([fn(element) for element in x])


def apply_unpack(fn: Callable[..., Any], x: Sequence[Sequence[Any]]):
    if type(x) not in (tuple, list):
        raise TypeError("x must be a tuple or list.")
    return type(x)([fn(*element) for element in x])


def unzip(x: Sequence[Any]):
    return list(zip(*x))


def make_dir_if_not_exists(path):
    # exist_ok makes the separate existence check unnecessary and race-free
    os.makedirs(path, exist_ok=True)
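
A quick usage sketch of the container helpers, with the two smallest functions restated verbatim so the example is self-contained:

```python
# Restated from neural_waveshaping_synthesis/utils/utils.py:
def apply(fn, x):
    if type(x) not in (tuple, list):
        raise TypeError("x must be a tuple or list.")
    return type(x)([fn(element) for element in x])

def unzip(x):
    return list(zip(*x))

# apply preserves the container type; unzip transposes a list of pairs.
assert apply(str.upper, ("a", "b")) == ("A", "B")
assert unzip([(1, "a"), (2, "b")]) == [(1, 2), ("a", "b")]
```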


================================================
FILE: requirements.txt
================================================
auraloss==0.2.1
black==20.8b1
click==7.1.2
gin-config==0.4.0
librosa==0.8.0
numpy==1.20.1
pandas  # used by scripts/time_buffer_sizes.py
pytorch_lightning==1.1.2
resampy==0.2.2
scikit-learn  # used by data/utils/create_dataset.py
scipy==1.6.1
torch==1.7.1
torchcrepe==0.0.12
tqdm  # used by the dataset and benchmarking scripts
wandb  # used for audio logging in the LightningModule


================================================
FILE: scripts/create_dataset.py
================================================
import os

import click
import gin

from neural_waveshaping_synthesis.data.utils.create_dataset import create_dataset
from neural_waveshaping_synthesis.utils import seed_all


def get_filenames(directory):
    return [
        os.path.join(directory, f) for f in os.listdir(directory) if f.endswith(".wav")
    ]


@click.command()
@click.option("--gin-file", prompt="Gin config file")
@click.option("--data-directory", prompt="Data directory")
@click.option("--output-directory", prompt="Output directory")
@click.option("--seed", default=0)
@click.option("--device", default="cpu")
def main(gin_file, data_directory, output_directory, seed=0, device="cpu"):
    gin.constant("device", device)
    gin.parse_config_file(gin_file)

    seed_all(seed)

    files = get_filenames(data_directory)
    create_dataset(files, output_directory)


if __name__ == "__main__":
    main()

================================================
FILE: scripts/create_urmp_dataset.py
================================================
import os
from pathlib import Path

import click
import gin

from neural_waveshaping_synthesis.data.utils.create_dataset import create_dataset
from neural_waveshaping_synthesis.utils import seed_all

INSTRUMENTS = (
    "vn",
    "vc",
    "fl",
    "cl",
    "tpt",
    "sax",
    "tbn",
    "ob",
    "va",
    "bn",
    "hn",
    "db",
)


def get_instrument_file_list(instrument_string, directory):
    return [
        str(f)
        for f in Path(directory).glob(
            "**/*_%s_*/AuSep*_%s_*.wav" % (instrument_string, instrument_string)
        )
    ]


@click.command()
@click.option("--gin-file", prompt="Gin config file")
@click.option("--data-directory", prompt="Data directory")
@click.option("--output-directory", prompt="Output directory")
@click.option("--seed", default=0)
@click.option("--device", default="cpu")
def main(gin_file, data_directory, output_directory, seed=0, device="cpu"):
    gin.constant("device", device)
    gin.parse_config_file(gin_file)

    seed_all(seed)

    file_lists = {
        instrument: get_instrument_file_list(instrument, data_directory)
        for instrument in INSTRUMENTS
    }
    for instrument in file_lists:
        create_dataset(
            file_lists[instrument], os.path.join(output_directory, instrument)
        )


if __name__ == "__main__":
    main()
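
To see which files `get_instrument_file_list` selects, here is a sketch using `fnmatch` as a loose stand-in for `Path.glob` (`fnmatch`'s `*` crosses path separators, which is close enough for illustration). The file names below are hypothetical URMP-style names, not taken from the dataset itself:

```python
from fnmatch import fnmatch

paths = [
    "01_Jupiter_vn_vc/AuSep_1_vn_01_Jupiter.wav",   # separated violin stem
    "01_Jupiter_vn_vc/AuSep_2_vc_01_Jupiter.wav",   # separated cello stem
    "01_Jupiter_vn_vc/AuMix_01_Jupiter_vn_vc.wav",  # full mix
]
# Pattern for instrument "vn", mirroring the glob in the script above:
matches = [p for p in paths if fnmatch(p, "*_vn_*/AuSep*_vn_*.wav")]
assert matches == [paths[0]]  # only the separated violin stem survives
```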

================================================
FILE: scripts/resynthesise_dataset.py
================================================
import os

import click
import gin
from scipy.io import wavfile
from tqdm import tqdm
import torch

from neural_waveshaping_synthesis.data.urmp import URMPDataset
from neural_waveshaping_synthesis.models.modules.shaping import FastNEWT
from neural_waveshaping_synthesis.models.neural_waveshaping import NeuralWaveshaping
from neural_waveshaping_synthesis.utils import make_dir_if_not_exists


@click.command()
@click.option("--model-gin", prompt="Model .gin file")
@click.option("--model-checkpoint", prompt="Model checkpoint")
@click.option("--dataset-root", prompt="Dataset root directory")
@click.option("--dataset-split", default="test")
@click.option("--output-path", default="audio_output")
@click.option("--load-data-to-memory", is_flag=True)
@click.option("--device", default="cuda:0")
@click.option("--batch-size", default=8)
@click.option("--num-workers", default=16)
@click.option("--use-fastnewt", is_flag=True)
def main(
    model_gin,
    model_checkpoint,
    dataset_root,
    dataset_split,
    output_path,
    load_data_to_memory,
    device,
    batch_size,
    num_workers,
    use_fastnewt,
):
    gin.parse_config_file(model_gin)
    make_dir_if_not_exists(output_path)

    data = URMPDataset(dataset_root, dataset_split, load_data_to_memory)
    data_loader = torch.utils.data.DataLoader(
        data, batch_size=batch_size, num_workers=num_workers
    )

    device = torch.device(device)
    model = NeuralWaveshaping.load_from_checkpoint(model_checkpoint)
    model.eval()

    if use_fastnewt:
        model.newt = FastNEWT(model.newt)
    
    model = model.to(device)

    for i, batch in enumerate(tqdm(data_loader)):
        with torch.no_grad():
            f0 = batch["f0"].float().to(device)
            control = batch["control"].float().to(device)
            output = model(f0, control)

        target_audio = batch["audio"].float().numpy()
        output_audio = output.cpu().numpy()
        for j in range(output_audio.shape[0]):
            name = batch["name"][j]
            target_name = "%s.target.wav" % name
            output_name = "%s.output.wav" % name
            wavfile.write(
                os.path.join(output_path, target_name),
                model.sample_rate,
                target_audio[j],
            )
            wavfile.write(
                os.path.join(output_path, output_name),
                model.sample_rate,
                output_audio[j],
            )


if __name__ == "__main__":
    main()


================================================
FILE: scripts/time_buffer_sizes.py
================================================
import time

import click
import gin
import numpy as np
import pandas as pd
import torch
from tqdm import trange

from neural_waveshaping_synthesis.models.neural_waveshaping import NeuralWaveshaping
from neural_waveshaping_synthesis.models.modules.shaping import FastNEWT

BUFFER_SIZES = [256, 512, 1024, 2048, 4096, 8192, 16384, 32768]

@click.command()
@click.option("--gin-file", prompt="Model config gin file")
@click.option("--output-file", prompt="output file")
@click.option("--num-iters", default=100)
@click.option("--batch-size", default=1)
@click.option("--device", default="cpu")
@click.option("--length-in-seconds", default=4)
@click.option("--use-fast-newt", is_flag=True)
@click.option("--model-name", default="ours")
def main(
    gin_file,
    output_file,
    num_iters,
    batch_size,
    device,
    length_in_seconds,
    use_fast_newt,
    model_name,
):
    gin.parse_config_file(gin_file)
    model = NeuralWaveshaping()
    if use_fast_newt:
        model.newt = FastNEWT(model.newt)
    model.eval()
    model = model.to(device)

    # eliminate any lazy init costs
    with torch.no_grad():
        for i in range(10):
            model(
                torch.rand(4, 1, 250, device=device),
                torch.rand(4, 2, 250, device=device),
            )

    times = []
    with torch.no_grad():
        for bs in BUFFER_SIZES:
            dummy_control = torch.rand(
                batch_size,
                2,
                bs // 128,  # one control frame per 128 audio samples (control hop)
                device=device,
                requires_grad=False,
            )
            dummy_f0 = torch.rand(
                batch_size,
                1,
                bs // 128,
                device=device,
                requires_grad=False,
            )
            for i in trange(num_iters):
                start_time = time.time()
                model(dummy_f0, dummy_control)
                time_elapsed = time.time() - start_time
                times.append(
                    [model_name, device if device == "cpu" else "gpu", bs, time_elapsed]
                )

    df = pd.DataFrame(times, columns=["model", "device", "buffer_size", "time"])
    df.to_csv(output_file, index=False)


if __name__ == "__main__":
    main()
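
The `bs // 128` above reflects the model's control hop of 128 samples (set in `gin/models/newt.gin`): each audio buffer carries one control frame per 128 samples. Sketched as a helper:

```python
CONTROL_HOP = 128  # matches control_hop in gin/models/newt.gin

def control_frames(buffer_size, hop=CONTROL_HOP):
    # Number of control frames needed to render `buffer_size` audio samples.
    return buffer_size // hop

assert control_frames(256) == 2
assert control_frames(32768) == 256
```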

================================================
FILE: scripts/time_forward_pass.py
================================================
import time

import click
import gin
import numpy as np
from scipy.stats import describe
import torch
from tqdm import trange

from neural_waveshaping_synthesis.models.neural_waveshaping import NeuralWaveshaping
from neural_waveshaping_synthesis.models.modules.shaping import FastNEWT


@click.command()
@click.option("--gin-file", prompt="Model config gin file")
@click.option("--num-iters", default=100)
@click.option("--batch-size", default=1)
@click.option("--device", default="cpu")
@click.option("--length-in-seconds", default=4)
@click.option("--sample-rate", default=16000)
@click.option("--control-hop", default=128)
@click.option("--use-fast-newt", is_flag=True)
def main(
    gin_file,
    num_iters,
    batch_size,
    device,
    length_in_seconds,
    sample_rate,
    control_hop,
    use_fast_newt,
):
    gin.parse_config_file(gin_file)
    dummy_control = torch.rand(
        batch_size,
        2,
        sample_rate * length_in_seconds // control_hop,
        device=device,
        requires_grad=False,
    )
    dummy_f0 = torch.rand(
        batch_size,
        1,
        sample_rate * length_in_seconds // control_hop,
        device=device,
        requires_grad=False,
    )
    model = NeuralWaveshaping()
    if use_fast_newt:
        model.newt = FastNEWT(model.newt)
    model.eval()
    model = model.to(device)

    times = []
    with torch.no_grad():
        for i in trange(num_iters):
            start_time = time.time()
            model(dummy_f0, dummy_control)
            time_elapsed = time.time() - start_time
            times.append(time_elapsed)

    print(describe(times))
    rtfs = np.array(times) / length_in_seconds
    print("Mean RTF: %.4f" % np.mean(rtfs))
    print("90th percentile RTF: %.4f" % np.percentile(rtfs, 90))


if __name__ == "__main__":
    main()
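
The script's summary statistic is the real-time factor (RTF): wall-clock synthesis time divided by the duration of the audio produced, so an RTF below 1.0 means faster-than-real-time synthesis:

```python
def real_time_factor(elapsed_seconds, audio_seconds):
    # RTF < 1.0 => the model renders audio faster than real time.
    return elapsed_seconds / audio_seconds

# e.g. rendering 4 s of audio in 0.5 s of wall-clock time:
assert real_time_factor(0.5, 4.0) == 0.125
```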

================================================
FILE: scripts/train.py
================================================
import click
import gin
import pytorch_lightning as pl

from neural_waveshaping_synthesis.data.general import GeneralDataModule
from neural_waveshaping_synthesis.data.urmp import URMPDataModule
from neural_waveshaping_synthesis.models.neural_waveshaping import NeuralWaveshaping


@gin.configurable
def get_model(model, with_wandb):
    return model(log_audio=with_wandb)


@gin.configurable
def trainer_kwargs(**kwargs):
    return kwargs


@click.command()
@click.option("--gin-file", prompt="Gin config file")
@click.option("--dataset-path", prompt="Dataset root")
@click.option("--urmp", is_flag=True)
@click.option("--device", default="0")
@click.option("--instrument", default="vn")
@click.option("--load-data-to-memory", is_flag=True)
@click.option("--with-wandb", is_flag=True)
@click.option("--restore-checkpoint", default="")
def main(
    gin_file,
    dataset_path,
    urmp,
    device,
    instrument,
    load_data_to_memory,
    with_wandb,
    restore_checkpoint,
):
    gin.parse_config_file(gin_file)
    model = get_model(with_wandb=with_wandb)

    if urmp:
        data = URMPDataModule(
            dataset_path,
            instrument,
            load_to_memory=load_data_to_memory,
            num_workers=16,
            shuffle=True,
        )
    else:
        data = GeneralDataModule(
            dataset_path,
            load_to_memory=load_data_to_memory,
            num_workers=16,
            shuffle=True,
        )

    checkpointing = pl.callbacks.ModelCheckpoint(
        monitor="val/loss", save_top_k=1, save_last=True
    )
    callbacks = [checkpointing]
    if with_wandb:
        lr_logger = pl.callbacks.LearningRateMonitor(logging_interval="epoch")
        callbacks.append(lr_logger)
        logger = pl.loggers.WandbLogger(project="neural-waveshaping-synthesis")
        logger.watch(model, log="parameters")

    kwargs = trainer_kwargs()
    trainer = pl.Trainer(
        logger=logger if with_wandb else None,
        callbacks=callbacks,
        gpus=device,
        resume_from_checkpoint=restore_checkpoint if restore_checkpoint != "" else None,
        **kwargs
    )
    trainer.fit(model, data)


if __name__ == "__main__":
    main()

================================================
FILE: setup.py
================================================
from setuptools import setup, find_packages

setup(name="neural_waveshaping_synthesis", version="0.0.1", packages=find_packages())
Condensed preview — 46 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (79K chars).
[
  {
    "path": ".gitignore",
    "chars": 2539,
    "preview": "\n# Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks,vscode\n# Edit at https://www.topta"
  },
  {
    "path": "LICENSE",
    "chars": 16725,
    "preview": "Mozilla Public License Version 2.0\n==================================\n\n1. Definitions\n--------------\n\n1.1. \"Contributor\""
  },
  {
    "path": "README.md",
    "chars": 2498,
    "preview": "<h1 align=\"center\">neural waveshaping synthesis</h1>\n<h4 align=\"center\">real-time neural audio synthesis in the waveform"
  },
  {
    "path": "gin/data/urmp_4second_crepe.gin",
    "chars": 975,
    "preview": "sample_rate = 16000\ninterpolation = None\ncontrol_hop = 128\n\nextract_f0_with_crepe.sample_rate = %sample_rate\nextract_f0_"
  },
  {
    "path": "gin/models/newt.gin",
    "chars": 971,
    "preview": "sample_rate = 16000\n\ncontrol_embedding_size = 128\nn_waveshapers = 64\ncontrol_hop = 128\n\nHarmonicOscillator.n_harmonics ="
  },
  {
    "path": "gin/train/train_newt.gin",
    "chars": 324,
    "preview": "get_model.model = @NeuralWaveshaping\n\ninclude 'gin/models/newt.gin'\n\nURMPDataModule.batch_size = 8\n\nNeuralWaveshaping.le"
  },
  {
    "path": "neural_waveshaping_synthesis/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "neural_waveshaping_synthesis/data/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "neural_waveshaping_synthesis/data/general.py",
    "chars": 3152,
    "preview": "import os\n\nimport gin\nimport numpy as np\nimport pytorch_lightning as pl\nimport torch\n\n\nclass GeneralDataset(torch.utils."
  },
  {
    "path": "neural_waveshaping_synthesis/data/urmp.py",
    "chars": 471,
    "preview": "import os\n\nimport gin\n\nfrom .general import GeneralDataModule\n\n\n@gin.configurable\nclass URMPDataModule(GeneralDataModule"
  },
  {
    "path": "neural_waveshaping_synthesis/data/utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "neural_waveshaping_synthesis/data/utils/create_dataset.py",
    "chars": 5247,
    "preview": "import os\nimport shutil\nfrom typing import Sequence\n\nimport gin\nimport numpy as np\nfrom sklearn.model_selection import t"
  },
  {
    "path": "neural_waveshaping_synthesis/data/utils/f0_extraction.py",
    "chars": 2541,
    "preview": "from functools import partial\nfrom typing import Callable, Optional, Sequence, Union\n\nimport gin\nimport librosa\nimport n"
  },
  {
    "path": "neural_waveshaping_synthesis/data/utils/loudness_extraction.py",
    "chars": 2742,
    "preview": "from typing import Callable, Optional\nimport warnings\n\nimport gin\nimport librosa\nimport numpy as np\n\nfrom .upsampling im"
  },
  {
    "path": "neural_waveshaping_synthesis/data/utils/mfcc_extraction.py",
    "chars": 305,
    "preview": "import gin\nimport librosa\nimport numpy as np\n\n\n@gin.configurable\ndef extract_mfcc(\n    audio: np.ndarray, sample_rate: f"
  },
  {
    "path": "neural_waveshaping_synthesis/data/utils/preprocess_audio.py",
    "chars": 7311,
    "preview": "from functools import partial\nfrom typing import Callable, Sequence, Union\n\nimport gin\nimport librosa\nimport numpy as np"
  },
  {
    "path": "neural_waveshaping_synthesis/data/utils/upsampling.py",
    "chars": 2366,
    "preview": "from typing import Optional\n\nimport gin\nimport numpy as np\nimport scipy.interpolate\nimport scipy.signal.windows\n\n\ndef ge"
  },
  {
    "path": "neural_waveshaping_synthesis/models/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "neural_waveshaping_synthesis/models/modules/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "neural_waveshaping_synthesis/models/modules/dynamic.py",
    "chars": 1152,
    "preview": "import gin\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass FiLM(nn.Module):\n    def forward(self, x, gamma"
  },
  {
    "path": "neural_waveshaping_synthesis/models/modules/generators.py",
    "chars": 2278,
    "preview": "import math\nfrom typing import Callable\n\nimport gin\nimport torch\nimport torch.fft\nimport torch.nn as nn\nimport torch.nn."
  },
  {
    "path": "neural_waveshaping_synthesis/models/modules/shaping.py",
    "chars": 5104,
    "preview": "import gin\nimport torch\nimport torch.fft\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .dynamic import FiL"
  },
  {
    "path": "neural_waveshaping_synthesis/models/neural_waveshaping.py",
    "chars": 4864,
    "preview": "import auraloss\nimport gin\nimport pytorch_lightning as pl\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional "
  },
  {
    "path": "neural_waveshaping_synthesis/utils/__init__.py",
    "chars": 44,
    "preview": "from .utils import *\nfrom .seed_all import *"
  },
  {
    "path": "neural_waveshaping_synthesis/utils/seed_all.py",
    "chars": 275,
    "preview": "import numpy as np\nimport os\nimport random\nimport torch\n\ndef seed_all(seed):\n    np.random.seed(seed)\n    os.environ['PY"
  },
  {
    "path": "neural_waveshaping_synthesis/utils/utils.py",
    "chars": 630,
    "preview": "import os\nfrom typing import Callable, Sequence\n\n\ndef apply(fn: Callable[[any], any], x: Sequence[any]):\n    if type(x) "
  },
  {
    "path": "requirements.txt",
    "chars": 175,
    "preview": "auraloss==0.2.1\nblack==20.8b1\nclick==7.1.2\ngin-config==0.4.0\nlibrosa==0.8.0\nnumpy==1.20.1\npytorch_lightning==1.1.2\nresam"
  },
  {
    "path": "scripts/create_dataset.py",
    "chars": 856,
    "preview": "import os\n\nimport click\nimport gin\n\nfrom neural_waveshaping_synthesis.data.utils.create_dataset import create_dataset\nfr"
  },
  {
    "path": "scripts/create_urmp_dataset.py",
    "chars": 1328,
    "preview": "import os\nfrom pathlib import Path\n\nimport click\nimport gin\n\nfrom neural_waveshaping_synthesis.data.utils.create_dataset"
  },
  {
    "path": "scripts/resynthesise_dataset.py",
    "chars": 2477,
    "preview": "import os\n\nimport click\nimport gin\nfrom scipy.io import wavfile\nfrom tqdm import tqdm\nimport torch\n\nfrom neural_waveshap"
  },
  {
    "path": "scripts/time_buffer_sizes.py",
    "chars": 2166,
    "preview": "import time\n\nimport click\nimport gin\nimport numpy as np\nimport pandas as pd\nimport torch\nfrom tqdm import trange\n\nfrom n"
  },
  {
    "path": "scripts/time_forward_pass.py",
    "chars": 1789,
    "preview": "import time\n\nimport click\nimport gin\nimport numpy as np\nfrom scipy.stats import describe\nimport torch\nfrom tqdm import t"
  },
  {
    "path": "scripts/train.py",
    "chars": 2195,
    "preview": "import click\nimport gin\nimport pytorch_lightning as pl\n\nfrom neural_waveshaping_synthesis.data.general import GeneralDat"
  },
  {
    "path": "setup.py",
    "chars": 131,
    "preview": "from setuptools import setup, find_packages\n\nsetup(name=\"neural_waveshaping_synthesis\", version=\"0.0.1\", packages=find_p"
  }
]

// ... and 12 more files (previews truncated in this extraction)

About this extraction

This page contains the full source code of the ben-hayes/neural-waveshaping-synthesis GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 46 files (71.9 KB), approximately 18.7k tokens, and a symbol index with 102 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
