Repository: Nintorac/NeuralDX7
Branch: master
Commit: 327844cea18a
Files: 66
Total size: 197.0 KB

Directory structure:
gitextract_l201xfp8/

├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── neuralDX7/
│   ├── __init__.py
│   ├── constants.py
│   ├── datasets/
│   │   ├── __init__.py
│   │   └── dx7_sysex_dataset.py
│   ├── models/
│   │   ├── __init__.py
│   │   ├── attention/
│   │   │   ├── __init__.py
│   │   │   ├── attention.py
│   │   │   ├── attention_encoder.py
│   │   │   ├── attention_layer.py
│   │   │   └── conditional_attention_encoder.py
│   │   ├── dx7_cnp.py
│   │   ├── dx7_np.py
│   │   ├── dx7_nsp.py
│   │   ├── dx7_vae.py
│   │   ├── general/
│   │   │   ├── __init__.py
│   │   │   └── gelu_ff.py
│   │   ├── stochastic_nodes/
│   │   │   ├── __init__.py
│   │   │   ├── normal.py
│   │   │   └── triangular_sylvester.py
│   │   └── utils.py
│   ├── solvers/
│   │   ├── __init__.py
│   │   ├── dx7_np.py
│   │   ├── dx7_nsp.py
│   │   ├── dx7_patch_process.py
│   │   ├── dx7_vae.py
│   │   └── utils.py
│   └── utils.py
├── projects/
│   ├── dx7_np/
│   │   ├── evaluate.py
│   │   ├── experiment.py
│   │   ├── features.py
│   │   ├── interpoalte.py
│   │   └── live.py
│   ├── dx7_nsp/
│   │   ├── evaluate.py
│   │   ├── experiment.py
│   │   ├── features.py
│   │   ├── interpoalte.py
│   │   └── live.py
│   ├── dx7_patch_neural_process/
│   │   ├── evaluate.py
│   │   ├── features_analysis.py
│   │   └── ray_train.py
│   ├── dx7_vae/
│   │   ├── duplicate_test.py
│   │   ├── evaluate.py
│   │   ├── experiment.py
│   │   ├── features.py
│   │   ├── interpoalte.py
│   │   └── live.py
│   └── mnist_neural_process/
│       └── experiment.py
├── requirements.txt
├── scratch/
│   ├── dx7-sysexformat.md
│   ├── dx7_constants.py
│   ├── dx7_syx.py
│   ├── fm-param-analysis.py
│   ├── fm_param_ae.py
│   ├── fm_param_agoge_vae_rnn.py
│   ├── fm_param_rnn_decoder.py
│   ├── fm_param_vae.py
│   ├── fm_param_vae_rnn.py
│   ├── syx_parser.py
│   └── syx_write.py
├── setup.cfg
├── setup.py
└── version

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
.vscode/
__pycache__/
.empty
*.pyc
*.egg-info
dist/

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) [year] [fullname]

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: MANIFEST.in
================================================
include version requirements.txt

================================================
FILE: README.md
================================================
# FM Synth Parameter Generator

Random machine learning experiments related to the classic Yamaha DX7 synthesizer.

## Dexed

Dexed is an open-source DX7 emulator that runs on Linux and was used heavily in testing this project.

## SYX

The DX7 and similar instruments are programmable via SysEx (MIDI system exclusive) messages. The SysEx format is documented [here](https://github.com/asb2m10/dexed/tree/master/Documentation) under `sysexformat.txt`.
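A minimal sketch of how a single-voice SysEx message is framed, following the "Bulk Data for 1 Voice" layout documented in `neuralDX7/constants.py`: the start byte, five header bytes, 155 parameter bytes, a 7-bit checksum and the end byte. `single_voice_sysex` and `dx7_checksum` are hypothetical helper names; the checksum expression is equivalent to `checksum()` in `constants.py`.

```python
def dx7_checksum(data):
    """Masked two's complement of the sum of the data bytes (low 7 bits)."""
    return (-sum(data)) & 0x7F


def single_voice_sysex(voice, channel=0):
    """Frame 155 unpacked voice parameters as a DX7 single-voice bulk dump.

    Hypothetical helper: returns the raw bytes including the F0 ... F7 framing.
    """
    if len(voice) != 155:
        raise ValueError("a single-voice dump carries exactly 155 parameter bytes")
    if any(not 0 <= v <= 127 for v in voice):
        raise ValueError("sysex data bytes must be 7-bit values (0..127)")
    header = [
        0x43,                     # ID: Yamaha
        0x00 | (channel & 0x0F),  # sub-status 0, MIDI channel (0 = ch 1)
        0x00,                     # format 0: single voice
        0x01, 0x1B,               # byte count, 7-bit MS/LS: 1 * 128 + 27 = 155
    ]
    return bytes([0xF0, *header, *voice, dx7_checksum(voice), 0xF7])


msg = single_voice_sysex([i % 100 for i in range(155)])
print(len(msg))  # 163 = 1 + 5 + 155 + 1 + 1
```

If you send this with mido, pass `data=msg[1:-1]` to `mido.Message('sysex', ...)`, since mido adds the `F0`/`F7` framing itself (this matches what `DX7Single.to_syx` does).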

## Dataset

Big thanks to Bobby Blues for collecting [these](http://bobbyblues.recup.ch/yamaha_dx7/dx7_patches.html) DX7 patches. This was the only data source used.

# Directory structure
```
neuralDX7/
├── constants.py
├── datasets                    # modules to interface with preprocessed datasets 
│   ├── dx7_sysex_dataset.py
│   └── __init__.py
├── __init__.py
├── models
│   ├── attention               # modules implementing transformer stack based on Attention Is All You Need
│   │   ├── attention_encoder.py
│   │   ├── attention_layer.py
│   │   ├── attention.py
│   │   ├── conditional_attention_encoder.py
│   │   └── __init__.py
│   ├── general
│   │   ├── gelu_ff.py          # two-layer feedforward block using the GELU non-linearity
│   │   └── __init__.py
│   ├── __init__.py
│   ├── stochastic_nodes        # layers implementing stochastic transformations
│   │   ├── __init__.py
│   │   ├── normal.py
│   │   └── triangular_sylvester.py
│   ├── dx7_cnp.py              # experimental modules
│   ├── dx7_np.py               # experimental modules
│   ├── dx7_nsp.py              # experimental modules
│   ├── dx7_vae.py              # working model used in production of thisdx7cartdoesnotexist.com
│   └── utils.py
├── solvers
│   ├── dx7_np.py               # experimental modules
│   ├── dx7_nsp.py              # experimental modules
│   ├── dx7_patch_process.py    # experimental modules
│   ├── dx7_vae.py              # working model used in production of thisdx7cartdoesnotexist.com
│   ├── __init__.py
│   └── utils.py
└── utils.py
```


# thisdx7cartdoesnotexist.com

If you've found your way here through [thisdx7cartdoesnotexist.com](https://www.thisdx7cartdoesnotexist.com/) then this section will give an overview of how that site generates patches.


The model itself is defined in `neuralDX7/models/dx7_vae.py`. It is a simple VAE with triangular Sylvester flows, implemented with attention layers over the parameters of the DX7.
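The flow code itself lives in `neuralDX7/models/stochastic_nodes/triangular_sylvester.py`, which isn't shown here. As a rough illustration of what one triangular Sylvester step computes (following van den Berg et al., 2018, "Sylvester Normalizing Flows for Variational Inference"), here is a NumPy sketch of z' = z + Q R tanh(R̃ Qᵀ z + b), where R and R̃ are upper triangular and Q is either the identity or the reversal permutation. Because everything reduces to triangular matrices, the log-determinant of the Jacobian is a sum over the diagonals. The function name is hypothetical, and the sketch omits the diagonal constraints and amortised (encoder-produced) flow parameters that a practical implementation typically adds.

```python
import numpy as np


def triangular_sylvester_step(z, R, R_tilde, b, reverse=False):
    """One triangular Sylvester flow step: z' = z + Q R tanh(R_tilde Q^T z + b).

    R and R_tilde are upper triangular; Q is the identity or the reversal
    permutation. Returns z' and log|det dz'/dz|.
    """
    D = z.shape[-1]
    order = np.arange(D)[::-1] if reverse else np.arange(D)
    Q = np.eye(D)[:, order]              # permutation matrix (orthogonal)
    h = np.tanh(R_tilde @ (Q.T @ z) + b)
    z_new = z + Q @ (R @ h)
    # The Jacobian is I + Q R diag(h') R_tilde Q^T; by Sylvester's determinant
    # identity its determinant equals det(I + diag(h') R_tilde R), and that
    # matrix is upper triangular, so only the diagonals matter.
    dh = 1.0 - h ** 2                    # tanh'(x) = 1 - tanh(x)^2
    log_det = np.sum(np.log(np.abs(1.0 + dh * np.diag(R_tilde) * np.diag(R))))
    return z_new, log_det


# two steps, alternating the permutation, accumulating the log-determinant
rng = np.random.default_rng(0)
D = 8
z = rng.normal(size=D)
total_log_det = 0.0
for reverse in (False, True):
    R = np.triu(rng.normal(size=(D, D))) * 0.3
    R_tilde = np.triu(rng.normal(size=(D, D))) * 0.3
    b = rng.normal(size=D)
    z, log_det = triangular_sylvester_step(z, R, R_tilde, b, reverse=reverse)
    total_log_det += log_det
print(z.shape, total_log_det)
```

Alternating `reverse=False` and `reverse=True` between steps lets each latent dimension influence the others in both directions, which is the purpose of the permutation in the triangular variant.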

The training code can be found in `neuralDX7/solvers/dx7_vae.py` and is a fairly standard VAE+Flow optimisation setup.
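The solver isn't shown here, so the following is only a sketch of the objective a "standard VAE+Flow" setup usually optimises: a single-sample estimate of the negative ELBO, where the flow's accumulated log-determinant corrects the density of the transformed latent. It assumes a standard normal prior and a diagonal Gaussian base posterior; for this model `recon_log_lik` would be the summed categorical log-likelihood of the 155 patch parameters. All names are hypothetical.

```python
import numpy as np

LOG_2PI = np.log(2.0 * np.pi)


def log_normal_diag(z, mu, logvar):
    """Log density of a diagonal Gaussian, summed over the last axis."""
    return -0.5 * np.sum(LOG_2PI + logvar + (z - mu) ** 2 / np.exp(logvar), axis=-1)


def flow_vae_loss(recon_log_lik, z0, zK, mu, logvar, sum_log_det):
    """Single-sample negative ELBO for a VAE whose posterior is refined by a flow.

    z0 ~ q0(z|x) = N(mu, exp(logvar)) is pushed through the flow to zK, and
    log qK(zK|x) = log q0(z0|x) - sum_k log|det J_k|.
    """
    log_q_zK = log_normal_diag(z0, mu, logvar) - sum_log_det
    log_p_zK = log_normal_diag(zK, np.zeros_like(zK), np.zeros_like(zK))  # N(0, I) prior
    kl_estimate = log_q_zK - log_p_zK  # single-sample Monte Carlo KL estimate
    return -(recon_log_lik - kl_estimate)
```

A larger log-determinant lowers the estimated density of zK under the posterior, which is how the flow lets the posterior spread out beyond a diagonal Gaussian.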

Finally, the training script itself, along with various other scripts for working with the trained model, can be found under `projects/dx7_vae`. The scripts in there do the following:

`duplicate_test.py` - samples randomly from the prior and counts the number of identical patches. Around 99.9% of the sampled patches were found to be unique.
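`duplicate_test.py` isn't reproduced here. A minimal sketch of the counting it describes, assuming each sampled patch is a sequence of integer parameter values, is to hash the patches as tuples; `uniqueness_report` is a hypothetical helper.

```python
from collections import Counter


def uniqueness_report(patches):
    """Return (fraction of distinct patches, number of samples sharing their patch with another sample)."""
    counts = Counter(tuple(p) for p in patches)
    n_with_duplicate = sum(c for c in counts.values() if c > 1)
    return len(counts) / len(patches), n_with_duplicate


fraction_unique, n_dup = uniqueness_report([[1, 2, 3], [1, 2, 3], [4, 5, 6], [7, 8, 9]])
print(fraction_unique, n_dup)  # 0.75 2
```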

`evaluate.py` - a script designed to create a single cartridge; it was used during development to check that the model was outputting valid parameter configurations.
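As a hedged sketch of the kind of check `evaluate.py` supports, the helper below lists parameters that fall outside their valid ranges. It works like `verify()` in `neuralDX7/constants.py` but reports the offending names instead of a boolean. `EXAMPLE_RANGES` is a small hand-copied subset of `VOICE_PARAMETER_RANGES`, and `invalid_parameters` is a hypothetical name. The check matters because the decoders in `dx7_cnp.py` and `dx7_np.py` produce logits over `MAX_VALUE` (128) classes for every parameter, so a value such as `FB = 8` is representable even though feedback only accepts 0 to 7.

```python
# Hand-copied subset of VOICE_PARAMETER_RANGES from neuralDX7/constants.py
EXAMPLE_RANGES = {
    'ALG': range(0, 31 + 1),
    'FB': range(0, 7 + 1),
    '0_OL': range(0, 99 + 1),
    '0_FC': range(0, 31 + 1),
}


def invalid_parameters(voice, ranges):
    """Return the names of parameters whose values fall outside their valid range."""
    assert set(voice) == set(ranges), 'parameter names do not match'
    return [name for name, value in voice.items() if value not in ranges[name]]


print(invalid_parameters({'ALG': 31, 'FB': 8, '0_OL': 99, '0_FC': 0}, EXAMPLE_RANGES))  # ['FB']
```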

`experiment.py` - contains the code to run the experiment. If you want to train your own version then start here.

`features.py` - simple feature extractor for the dataset, used to calculate the model posterior.

`interpoalte.py` (sic) - takes two samples from the dataset and produces a cartridge that moves between the two in latent space over 32 steps.
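The interpolation script isn't shown here. Assuming plain linear interpolation between the two latent codes (the script may do it differently), the 32 steps can be generated as below; 32 matches the number of voices in a cartridge (`N_VOICES` in `constants.py`). `interpolate_latents` is a hypothetical name, and the latent size of 8 follows the README example.

```python
import numpy as np


def interpolate_latents(z_a, z_b, n_steps=32):
    """Linearly interpolate from z_a to z_b in n_steps, both endpoints included."""
    t = np.linspace(0.0, 1.0, n_steps)[:, None]  # shape (n_steps, 1)
    return (1.0 - t) * z_a[None, :] + t * z_b[None, :]


z_a, z_b = np.zeros(8), np.ones(8)
path = interpolate_latents(z_a, z_b)
print(path.shape)  # (32, 8)
```

The resulting `(32, 8)` array could then be decoded the same way as the prior samples in the README example, e.g. `model.generate(torch.from_numpy(path).float())`.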

`live.py` - uses JACK (jackd) to provide a real-time interface to the model. This is really fun to play with: it lets you hook up a MIDI controller to control the latent variables of the model and update the parameters of your FM synthesizer in real time.

To use it you will need jackd installed. Add both your controller and your FM synthesizer to the JACK graph, update the controller and FM synth names in the `live.py` script, and then run it. Each of the first 8 MIDI CC controls received is mapped to a latent dimension of the model.
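A minimal sketch of the CC-to-latent mapping described above, with no JACK or MIDI I/O: the first 8 distinct controller numbers claim latent dimensions in order of arrival, and the 7-bit CC value (0 to 127) is rescaled linearly to `[-z_range, z_range]`. The ±3 default range is an assumption chosen because the prior is N(0, 1); `live.py` may scale differently, and `CCLatentMapper` is a hypothetical class. With mido, you would call `handle_cc(msg.control, msg.value)` for messages where `msg.type == 'control_change'`.

```python
class CCLatentMapper:
    """Map the first `latent_dim` distinct MIDI CC numbers onto latent dimensions."""

    def __init__(self, latent_dim=8, z_range=3.0):
        self.latent_dim = latent_dim
        self.z_range = z_range        # assumed latent range: [-z_range, z_range]
        self.cc_to_dim = {}           # controller number -> latent dimension
        self.z = [0.0] * latent_dim   # start at the prior mean

    def handle_cc(self, control, value):
        """Update z from a CC message; return the new z, or None if the CC is unmapped."""
        if control not in self.cc_to_dim:
            if len(self.cc_to_dim) >= self.latent_dim:
                return None  # every latent dimension is already claimed
            self.cc_to_dim[control] = len(self.cc_to_dim)
        dim = self.cc_to_dim[control]
        # rescale the 7-bit CC value (0..127) linearly to [-z_range, z_range]
        self.z[dim] = (value / 127.0) * 2 * self.z_range - self.z_range
        return list(self.z)


mapper = CCLatentMapper()
print(mapper.handle_cc(74, 127)[0])  # 3.0
```

Each returned latent vector would then be decoded and sent to the synth as a single-voice SysEx message.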


```python
import torch
import mido
from agoge import InferenceWorker
from neuralDX7.utils import dx7_bulk_pack

# load the trained model (the weights should download automatically)
model = InferenceWorker('hasty-copper-dogfish', 'dx7-vae', with_data=False).model

# draw 32 latent vectors (one per cartridge voice) from the prior N(0, 1)
z = torch.randn(32, 8)

# decode the latents to a distribution over parameter values
p_x = model.generate(z)

# take the most likely value for each parameter (argmax rather than random sampling)
sample = p_x.logits.argmax(-1)

# convert the PyTorch tensor to a sysex message
msg = dx7_bulk_pack(sample.cpu().numpy().tolist())

mido.write_syx_file('path/to/save.syx', [msg])
```

================================================
FILE: neuralDX7/__init__.py
================================================



from agoge import DEFAULTS, defaults_f


DEFAULTS = defaults_f({
    # key must match DEFAULTS['ARTIFACTS_ROOT'], which neuralDX7.datasets reads
    'ARTIFACTS_ROOT': '~/agoge/artifacts'
}, DEFAULTS)

================================================
FILE: neuralDX7/constants.py
================================================
from pathlib import Path
import bitstruct
import mido


def take(take_from, n):
    for _ in range(n):
        yield next(take_from)

N_OSC = 6
N_VOICES = 32

def checksum(data):
    return (128-sum(data)&127)%128

GLOBAL_VALID_RANGES = {
    'PR1':  range(0, 99+1),
    'PR2':  range(0, 99+1),
    'PR3':  range(0, 99+1),
    'PR4':  range(0, 99+1),
    'PL1':  range(0, 99+1),
    'PL2':  range(0, 99+1),
    'PL3':  range(0, 99+1),
    'PL4':  range(0, 99+1),
    'ALG':  range(0, 31+1),
    'OKS':  range(0, 1+1),
    'FB':   range(0, 7+1),
    'LFS':  range(0, 99+1),
    'LFD':  range(0, 99+1),
    'LPMD':  range(0, 99+1),
    'LAMD':  range(0, 99+1),
    'LPMS': range(0, 7+1),
    'LFW':  range(0, 5+1),
    'LKS':  range(0, 1+1),
    'TRNSP':  range(0, 48+1),
    'NAME CHAR 1': range(128),
    'NAME CHAR 2': range(128),
    'NAME CHAR 3': range(128),
    'NAME CHAR 4': range(128),
    'NAME CHAR 5': range(128),
    'NAME CHAR 6': range(128),
    'NAME CHAR 7': range(128),
    'NAME CHAR 8': range(128),
    'NAME CHAR 9': range(128),
    'NAME CHAR 10': range(128),
 }

OSCILLATOR_VALID_RANGES = {
    'R1':  range(0, 99+1),
    'R2':  range(0, 99+1),
    'R3':  range(0, 99+1),
    'R4':  range(0, 99+1),
    'L1':  range(0, 99+1),
    'L2':  range(0, 99+1),
    'L3':  range(0, 99+1),
    'L4':  range(0, 99+1),
    'BP':  range(0, 99+1),
    'LD':  range(0, 99+1),
    'RD':  range(0, 99+1),
    'RC':  range(0, 3+1),
    'LC':  range(0, 3+1),
    'DET': range(0, 14+1),
    'RS':  range(0, 7+1),
    'KVS': range(0, 7+1),
    'AMS': range(0, 3+1),
    'OL':  range(0, 99+1),
    'FC':  range(0, 31+1),
    'M':   range(0, 1+1),
    'FF':  range(0, 99+1),
}

VOICE_PARAMETER_RANGES = {f'{i}_{key}': value for key, value in OSCILLATOR_VALID_RANGES.items() for i in range(N_OSC)}
VOICE_PARAMETER_RANGES.update(GLOBAL_VALID_RANGES)

def verify(actual, ranges):
    assert set(actual.keys())==set(ranges.keys()), "Params don't match"
    for key in actual:
        if not actual[key] in ranges[key]:
            return False
    return True


HEADER_KEYS = [
    'ID',
    'Sub-status',
    'format number',
    'byte count',
    'byte count',
]

GENERAL_KEYS = [
    'PR1',
    'PR2',
    'PR3',
    'PR4',
    'PL1',
    'PL2',
    'PL3',
    'PL4',
    'ALG',
    'OKS',
    'FB',
    'LFS',
    'LFD',
    'LPMD',
    'LAMD',
    'LPMS',
    'LFW',
    'LKS',
    'TRNSP',
    'NAME CHAR 1',
    'NAME CHAR 2',
    'NAME CHAR 3',
    'NAME CHAR 4',
    'NAME CHAR 5',
    'NAME CHAR 6',
    'NAME CHAR 7',
    'NAME CHAR 8',
    'NAME CHAR 9',
    'NAME CHAR 10',
]

OSC_KEYS = [
    'R1',
    'R2',
    'R3',
    'R4',
    'L1',
    'L2',
    'L3',
    'L4',
    'BP',
    'LD',
    'RD',
    'RC',
    'LC',
    'DET',
    'RS',
    'KVS',
    'AMS',
    'OL',
    'FC',
    'M',
    'FF',
]

FOOTER_KEYS = ['checksum']


VOICE_KEYS = [f'{i}_{key}' for i in range(6) for key in OSC_KEYS] + \
        GENERAL_KEYS 

KEYS =  HEADER_KEYS + \
        list(VOICE_KEYS * N_VOICES) + \
        FOOTER_KEYS



header_bytes = [
    'p1u7',             # ID # (i=67; Yamaha)
    'p1u7',             # Sub-status (s=0) & channel number (n=0; ch 1)
    'p1u7',             # format number (f=9; 32 voices)
    'p1u7',             # byte count MS byte
    'p1u7',             # byte count LS byte (b=4096; 32 voices)
]




general_parameter_bytes = [ 
    'p1u7',             # PR1
    'p1u7',             # PR2
    'p1u7',             # PR3
    'p1u7',             # PR4
    'p1u7',             # PL1
    'p1u7',             # PL2
    'p1u7',             # PL3
    'p1u7',             # PL4
    'p3u5',             # ALG
    'p4u1u3',           # OKS|    FB
    'p1u7',             # LFS
    'p1u7',             # LFD
    'p1u7',             # LPMD
    'p1u7',             # LAMD
    'p1u3u3u1',         # LPMS |      LFW      |LKS
    'p1u7',             # TRNSP
    'p1u7',             # NAME CHAR 1
    'p1u7',             # NAME CHAR 2
    'p1u7',             # NAME CHAR 3
    'p1u7',             # NAME CHAR 4
    'p1u7',             # NAME CHAR 5
    'p1u7',             # NAME CHAR 6
    'p1u7',             # NAME CHAR 7
    'p1u7',             # NAME CHAR 8
    'p1u7',             # NAME CHAR 9
    'p1u7',             # NAME CHAR 10
]

osc_parameter_bytes = [
    'p1u7',         # R1
    'p1u7',         # R2
    'p1u7',         # R3
    'p1u7',         # R4
    'p1u7',         # L1
    'p1u7',         # L2
    'p1u7',         # L3
    'p1u7',         # L4
    'p1u7',         # BP
    'p1u7',         # LD
    'p1u7',         # RD
    'p4u2u2',       # RC | LC 
    'p1u4u3',       # DET | RS
    'p3u3u2',       # KVS | AMS
    'p1u7',         # OL
    'p2u5u1',       # FC | M
    'p1u7'          # FF
]

voice_bytes = (osc_parameter_bytes * N_OSC) + general_parameter_bytes

tail_bytes = [
    'p1u7',         # checksum
]


full_string = ''.join(header_bytes + osc_parameter_bytes * 6 + general_parameter_bytes)
dx7_struct = bitstruct.compile(full_string)

voice_struct = bitstruct.compile(''.join(voice_bytes), names=VOICE_KEYS)
header_struct = bitstruct.compile(''.join(header_bytes))

N_PARAMS = len(VOICE_PARAMETER_RANGES)
MAX_VALUE = max([max(i) for i in VOICE_PARAMETER_RANGES.values()]) + 1


"""
SYSEX Message: Bulk Data for 1 Voice
------------------------------------
       bits    hex  description

     11110000  F0   Status byte - start sysex
     0iiiiiii  43   ID # (i=67; Yamaha)
     0sssnnnn  00   Sub-status (s=0) & channel number (n=0; ch 1)
     0fffffff  00   format number (f=0; 1 voice)
     0bbbbbbb  01   byte count MS byte
     0bbbbbbb  1B   byte count LS byte (b=155; 1 voice)
     0ddddddd  **   data byte 1

        |       |       |

     0ddddddd  **   data byte 155
     0eeeeeee  **   checksum (masked 2's complement of sum of 155 bytes)
     11110111  F7   Status - end sysex



///////////////////////////////////////////////////////////
"""
class DX7Single():
    # Yamaha ID, sub-status/channel 1, format 0 (1 voice), byte count 155 (MS, LS)
    HEADER = (0x43, 0x00, 0x00, 0x01, 0x1B)
    
    GENERAL_KEYS = [
        'PR1',
        'PR2',
        'PR3',
        'PR4',
        'PL1',
        'PL2',
        'PL3',
        'PL4',
        'ALG',
        'FB',
        'OKS',
        'LFS',
        'LFD',
        'LPMD',
        'LAMD',
        'LKS',
        'LFW',
        'LPMS',
        'TRNSP',
        'NAME CHAR 1',
        'NAME CHAR 2',
        'NAME CHAR 3',
        'NAME CHAR 4',
        'NAME CHAR 5',
        'NAME CHAR 6',
        'NAME CHAR 7',
        'NAME CHAR 8',
        'NAME CHAR 9',
        'NAME CHAR 10',
    ]

    OSC_KEYS = [
        'R1',
        'R2',
        'R3',
        'R4',
        'L1',
        'L2',
        'L3',
        'L4',
        'BP',
        'LD',
        'RD',
        'LC',
        'RC',
        'RS',
        'AMS',
        'KVS',
        'OL',
        'M',
        'FC',
        'FF',
        'DET',
    ]

    @staticmethod
    def keys():

        osc_keys = DX7Single.OSC_KEYS
        osc_params = [f'{i}_{param}' for i in range(N_OSC) for param in osc_keys]
        # print(osc_params)
        all = osc_params + DX7Single.GENERAL_KEYS
        return all

    @staticmethod
    def struct():
        return bitstruct.compile('p1u7'*155, names=DX7Single.keys())


    @staticmethod
    def to_syx(voices):


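        assert len(voices) == 1, 'DX7Single.to_syx expects exactly one voice'
        voice = voices[0]
        # `voice` is ordered like VOICE_KEYS (bulk-dump order); packing by name
        # reorders the values into the single-voice order of DX7Single.keys()
        voices_bytes = DX7Single.struct().pack(dict(zip(VOICE_KEYS, voice)))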
        
        patch_checksum = [checksum(voices_bytes)]

        data = bytes(DX7Single.HEADER) \
            + voices_bytes \
            + bytes(patch_checksum)


        return mido.Message('sysex', data=data)
        

def consume_syx(path):

    path = Path(path).expanduser()
    try:
        preset = mido.read_syx_file(path.as_posix())[0]
    except (IndexError, ValueError):
        # IndexError: the file contains no messages; ValueError: it could not be parsed
        return None
    if len(preset.data) == 0:
        return None

    def get_voice(data):
        
        unpacked = voice_struct.unpack(data)

        if not verify(unpacked, VOICE_PARAMETER_RANGES):
            return None
        
        return unpacked

    get_header = header_struct.unpack
    sysex_iter = iter(preset.data)
    
    try:
        header = get_header(bytes(take(sysex_iter, len(header_bytes))))
        yield from (get_voice(bytes(take(sysex_iter, len(voice_bytes)))) for _ in range(N_VOICES))
    except RuntimeError:
        return None

if __name__=="__main__":
    print(VOICE_KEYS)

    # print(DX7Single.to_syx(n)

================================================
FILE: neuralDX7/datasets/__init__.py
================================================
from .dx7_sysex_dataset import DX7SysexDataset

================================================
FILE: neuralDX7/datasets/dx7_sysex_dataset.py
================================================
from pathlib import Path
import numpy as np
import torch
from neuralDX7 import DEFAULTS




class DX7SysexDataset():
    """
    PyTorch Dataset module that provides access to preprocessed DX7 patch data
    """
    

    def __init__(self, data_file='dx7.npy', root=DEFAULTS['ARTIFACTS_ROOT'], data_size=1.):
        """
        data_file - the name of the preprocessed data file
        root - the root directory for data
        data_size - fraction of the data to use (between 0 and 1); useful for development
        """

        assert data_size <= 1
        self.data_size = data_size

        # initialise path handler
        if not isinstance(root, Path):
            root = Path(root).expanduser()

        # load data into memory
        self.data = np.load(root.joinpath(data_file)) 

    def __getitem__(self, index):
        
        # turn the data item into a tensor and return
        item = torch.tensor(self.data[index].item()).long()

        return {'X': item}
    
    def __len__(self):
        return int(len(self.data) * self.data_size)


if __name__ == "__main__":
    

    dataset = DX7SysexDataset()

    # randint's upper bound is exclusive, so len(dataset) keeps the last item reachable
    print([dataset[i] for i in np.random.randint(0, len(dataset), 20)])

================================================
FILE: neuralDX7/models/__init__.py
================================================
from .dx7_cnp import DX7PatchProcess
from .dx7_np import DX7NeuralProcess
from .dx7_nsp import DX7NeuralSylvesterProcess
from .dx7_vae import DX7VAE

================================================
FILE: neuralDX7/models/attention/__init__.py
================================================
from .attention import Attention
from .attention_layer import AttentionLayer
from .attention_encoder import ResidualAttentionEncoder
from .conditional_attention_encoder import CondtionalResidualAttentionEncoder

================================================
FILE: neuralDX7/models/attention/attention.py
================================================
import torch
from torch import nn

class Attention(nn.Module):



    def __init__(self, n_features, n_hidden, n_heads=8, inf=1e9):
        """
        n_features - number of input features
        n_hidden - hidden dim per head
        n_heads - number of heads
        """

        super().__init__()

        self.QKV = nn.Linear(n_features, n_hidden * 3 * n_heads)
        self.n_heads = n_heads
        self._inf = inf

    @property
    def inf(self):
        # if self.training:
            return self._inf
        # return float('inf')

    def forward(self, X, A):

        *input_shape, _ = X.shape

        # calculate the query key and value vectors for all data points
        QKV = self.QKV(X).reshape(*input_shape, -1, 3, self.n_heads)

        # permute the heads and qkv vectors to the first dimensions
        n_dims = len(QKV.shape)
        permuter = torch.arange(n_dims).roll(2)
        Q, K, V = QKV.permute(*permuter)

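        # calculate the attention values
        # NOTE: scores are scaled by sqrt(n_heads); the original Transformer scales by
        # sqrt(d_k), i.e. the per-head hidden size (n_hidden here)
        qk_t = (Q @ K.transpose(-1, -2)) / (self.n_heads**(1/2))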
        qk_t_masked =  qk_t.masked_fill(~A, -self.inf)
        
        # apply the attention values to the values
        Y = qk_t_masked.softmax(-1) @ V

        # restore heads to the final dimension and flatten (effectively concatenating them)
        n_dims = len(Y.shape)
        permuter = torch.arange(n_dims).roll(-1)
        Y = Y.permute(*permuter).flatten(-2, -1)
        return Y



if __name__=="__main__":



    model = Attention(100, 20)
    X = torch.randn(3, 25, 100)
    A = torch.rand(3, 25, 25)>0.5
    Y = model(X, A)
    print(Y.shape)


================================================
FILE: neuralDX7/models/attention/attention_encoder.py
================================================
import torch
from os import environ

from torch import nn

from agoge import AbstractModel
from neuralDX7.models.attention import AttentionLayer
from neuralDX7.models.utils import position_encoding_init



class ResidualAttentionEncoder(AbstractModel):
    """
    Residual attention stacks based on the Attention Is All You Need paper
    """

    def __init__(self, features, attention_layer, max_len=200, n_layers=3):
        """
        features - the number of features per parameter
        attention_layer - a dictionary containing instantiation parameters for the AttentionLayer module
        max_len - the maximum needed size of the positional encodings
        n_layers - number of layers for the module to use
        """
        super().__init__()

        # create the layers
        self.layers = nn.ModuleList(
            map(lambda x: AttentionLayer(**attention_layer), range(n_layers))
        )

        # pre generate the positional encodings
        positional_encoding = position_encoding_init(max_len, features)
        self.register_buffer('positional_encoding', positional_encoding)

        self.p2x = nn.Linear(features, features * 2)


    def forward(self, X, A):
        """
        X - data tensor, torch.FloatTensor(batch_size, num_parameters, features)
        A - connection mask, torch.BoolTensor(batch_size, num_parameters, num_parameters)
        """

        # generate FiLM parameters from positional encodings for conditioning
        gamma, beta = self.p2x(self.positional_encoding).chunk(2, -1)
        gamma, beta = torch.sigmoid(gamma), torch.tanh(beta)

        # Apply the data through the layers adding the positioning information in at each layer
        for layer in self.layers:
            X = layer(gamma * X + beta, A)

        return X
        


if __name__=='__main__':

    layer_features = 100
    n_heads = 4

    head_features = layer_features // n_heads

    attention = {
        'n_features': layer_features,
        'n_hidden': head_features,
        'n_heads': n_heads
    }
    
    attention_layer = {
        'attention': attention,
        'features': layer_features,
        'hidden_dim': 555
    }

    
    max_len = 25
    
    
    model = ResidualAttentionEncoder(layer_features, attention_layer, max_len=max_len)
    A = torch.rand(3, max_len, max_len) > 0.5
    # the encoder expects already-embedded float features, not raw parameter indices
    X = torch.randn(3, max_len, layer_features)

    print(model(X, A).shape)



================================================
FILE: neuralDX7/models/attention/attention_layer.py
================================================
import torch
from torch import nn

from neuralDX7.models.attention import Attention



class AttentionLayer(nn.Module):
    """
    Post-norm layer based on the original Attention Is All You Need paper. Because
    connectivity is given by an explicit mask, it is also usable in graph network setups.
    """

    def __init__(self, features, hidden_dim, attention):
        """
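        features - the number of features the layer has at input and output
        hidden_dim - the hidden dimension of the feedforward network
        attention - keyword arguments for the Attention module; n_hidden * n_heads must equal features for the residual connection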
        """

        super().__init__()

        self.attention = Attention(**attention)
        self.feedforward = nn.Sequential(
            nn.Linear(features, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, features),
        )

        self.attention_norm = nn.LayerNorm(features)
        # no learnable affine here: scale and shift come from FiLM conditioning in the encoders
        self.feedforward_norm = nn.LayerNorm(features, elementwise_affine=False)

    

    def forward(self, X, A):
        """
        X - data tensor, torch.FloatTensor(batch_size, num_parameters, features)
        A - connection mask, torch.BoolTensor(batch_size, num_parameters, num_parameters); A[b, i, j] is True if parameter i may attend to parameter j
        """
        X = self.attention_norm(self.attention(X, A) + X)
        X = self.feedforward_norm(self.feedforward(X) + X)

        return X

if __name__=="__main__":

    attention = {
        'n_features': 100,
        'n_hidden': 25,
        'n_heads': 4
    }
    
    model = AttentionLayer(100, 250, attention)

    X = torch.randn(3, 25, 100)
    A = torch.rand(3, 25, 25)>0.5
    Y = model(X, A)

    print(Y.shape)

================================================
FILE: neuralDX7/models/attention/conditional_attention_encoder.py
================================================
import torch
from os import environ

from torch import nn

from agoge import AbstractModel
from neuralDX7.models.attention import AttentionLayer
from neuralDX7.models.general import FeedForwardGELU
from neuralDX7.models.utils import position_encoding_init



class CondtionalResidualAttentionEncoder(AbstractModel):
    """
    Similar to ResidualAttentionEncoder, but additionally applies FiLM conditioning from a
    side input c before every attention layer.
    """
    def __init__(self, features, c_features, attention_layer, max_len=200, n_layers=3):
        """
        features - the number of features per parameter
        c_features - the number of side conditioning features per batch item
        attention_layer - a dictionary containing instantiation parameters for the AttentionLayer module
        max_len - the maximum needed size of the positional encodings
        n_layers - number of layers for the module to use
        """
        super().__init__()


        self.layers = nn.ModuleList(
            map(lambda x: AttentionLayer(**attention_layer), range(n_layers))
        )

        positional_encoding = position_encoding_init(max_len, features)
        self.c_layers = nn.ModuleList(
            map(lambda x: FeedForwardGELU(c_features, features*2), range(n_layers))
        )

        self.p2x = nn.Linear(features, features * 2)
        self.register_buffer('positional_encoding', positional_encoding)


    def forward(self, X, A, c):
        """
        X - data tensor, torch.FloatTensor(batch_size, num_parameters, features)
        A - connection mask, torch.BoolTensor(batch_size, num_parameters, num_parameters)
        c - side conditioning, torch.FloatTensor(batch_size, num_parameters, c_features)
        """

        # generate FiLM parameters from positional encodings for conditioning
        gamma_p, beta_p = self.p2x(self.positional_encoding).chunk(2, -1)
        gamma_p, beta_p = torch.sigmoid(gamma_p), torch.tanh(beta_p)

        X = gamma_p * X + beta_p

        for layer, c_layer in zip(self.layers, self.c_layers):

            gamma_c, beta_c = c_layer(c).chunk(2, -1)
            gamma_c, beta_c = torch.sigmoid(gamma_c), torch.tanh(beta_c)

            X = layer(gamma_c * X + beta_c, A)

        return X
        


if __name__=='__main__':

    layer_features = 100
    n_heads = 4

    head_features = layer_features // n_heads

    attention = {
        'n_features': layer_features,
        'n_hidden': head_features,
        'n_heads': n_heads
    }
    
    attention_layer = {
        'attention': attention,
        'features': layer_features,
        'hidden_dim': 555
    }

    
    max_len = 25
    
    
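    c_features = 8

    model = CondtionalResidualAttentionEncoder(layer_features, c_features, attention_layer, max_len=max_len)
    A = torch.rand(3, max_len, max_len) > 0.5
    X = torch.randn(3, max_len, layer_features)
    # one conditioning vector per parameter, as produced by z_to_c in dx7_np.py
    c = torch.randn(3, max_len, c_features)

    print(model(X, A, c).shape)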



================================================
FILE: neuralDX7/models/dx7_cnp.py
================================================
import torch
from torch import nn

from agoge import AbstractModel
from neuralDX7.models.attention import ResidualAttentionEncoder
from neuralDX7.constants import MAX_VALUE, N_PARAMS
from neuralDX7.utils import mask_parameters


class DX7PatchProcess(AbstractModel):
    """
    EXPERIMENTAL AND UNTESTED

    
    """

    def __init__(self, features, encoder):
        
        super().__init__()

        self.embedder = nn.Embedding(MAX_VALUE, features)
        self.encoder = ResidualAttentionEncoder(**encoder)

        self.logits = nn.Linear(features, MAX_VALUE)

    def forward(self, X):
        # print(X.shape, )

        # generate random masks
        batch_p = torch.rand(X.shape[0]) # decide p value for each item in batch
        item_logits = torch.rand(X.shape) # random value for each param
        X_a = batch_p.unsqueeze(-1) <= item_logits # active params in X
        X_a = X_a.to(self.device)

        A = (~X_a.unsqueeze(-1)) & (X_a.unsqueeze(-2))
        eye = torch.eye(A.shape[-1]).bool().to(self.device) & (~X_a.unsqueeze(-2))
        A = A | eye
        # 1/0

        X = self.embedder(X) * X_a.unsqueeze(-1).float()
        # X = self.embedder(X) 

        X = self.encoder(X, A)
        # X_hat = mask_parameters(self.logits(X))
        X_hat = self.logits(X)
        # print(X_hat.max(), X_hat.min())

        return X_hat, X_a

    @torch.no_grad()
    def features(self, X):

        X_a = torch.ones_like(X).bool()
        A = X_a.unsqueeze(-1) & X_a.unsqueeze(-2)
        X = self.embedder(X)
        # X = self.embedder(X) 

        X = self.encoder(X, A)

        return X

    @torch.no_grad()
    def generate(self, X, X_a):
        
    
        A = (~X_a.unsqueeze(-1)) & (X_a.unsqueeze(-2))
        eye = torch.eye(A.shape[-1]).bool().to(self.device) & (~X_a.unsqueeze(-2))
        A = A | eye

        X = self.embedder(X) * X_a.unsqueeze(-1).float()
        X = self.encoder(X, A)
        X_hat = mask_parameters(self.logits(X))

        X_hat = torch.distributions.Categorical(logits=X_hat)

        return X_hat



if __name__=='__main__':

    layer_features = 100
    n_heads = 4
    N_PARAMS = 8

    head_features = layer_features // n_heads

    attention = {
        'n_features': layer_features,
        'n_hidden': head_features,
        'n_heads': n_heads
    }
    
    attention_layer = {
        'attention': attention,
        'features': layer_features,
        'hidden_dim': 555
    }

    encoder = {
        'features': layer_features,
        'attention_layer': attention_layer,
        'max_len': N_PARAMS
    }
        
    
    model = DX7PatchProcess(layer_features, encoder=encoder)
    X = torch.distributions.Categorical(torch.ones(128)).sample((3, N_PARAMS))

    # forward returns the logits together with the random context mask
    X_hat, X_a = model(X)
    print(X_hat.shape)
    print(X_hat[0])





================================================
FILE: neuralDX7/models/dx7_np.py
================================================
import torch
from torch import nn
from torch.nn import functional as F

from agoge import AbstractModel
from neuralDX7.models.attention import ResidualAttentionEncoder, CondtionalResidualAttentionEncoder
from neuralDX7.models.general import FeedForwardGELU
from neuralDX7.models.stochastic_nodes import NormalNode
from neuralDX7.constants import MAX_VALUE, N_PARAMS
from neuralDX7.utils import mask_parameters


class DX7NeuralProcess(AbstractModel):
    """
    EXPERIMENTAL AND UNTESTED

    """

    def __init__(self, features, latent_dim, encoder, decoder, deterministic_path_drop_rate=0.5):
        
        super().__init__()

        self.embedder = nn.Embedding(MAX_VALUE, features)
        self.encoder = ResidualAttentionEncoder(**encoder)
        self._latent_encoder = nn.ModuleList([
            ResidualAttentionEncoder(**encoder),
            NormalNode(features, latent_dim)]
        )
        self.z_to_c = nn.Linear(latent_dim, latent_dim*155)
        self.decoder = CondtionalResidualAttentionEncoder(**decoder)
        self.logits = FeedForwardGELU(features, MAX_VALUE)
        self.drop = nn.Dropout(deterministic_path_drop_rate)

    def latent_encoder(self,  X, A, mean=False):

        encoder, q_x = self._latent_encoder

        return q_x(encoder(X, A).mean(-2))


    def forward(self, X):

        # generate random masks
        batch_p = torch.rand(X.shape[0]) # decide p value for each item in batch
        item_logits = torch.rand(X.shape) # random value for each param
        X_a = batch_p.unsqueeze(-1) <= item_logits # active params in X
        X_a = X_a.to(self.device)

        A = (~X_a.unsqueeze(-1)) & (X_a.unsqueeze(-2))
        eye = torch.eye(A.shape[-1]).bool().to(self.device) & (~X_a.unsqueeze(-2))
        A = A | eye
        # A = A | True

        X_target = self.embedder(X)
        
        X_context = X_target * X_a.unsqueeze(-1).float()

        q_context = self.latent_encoder(X_context, A)
        q_target = self.latent_encoder(X_target, A | (~X_a.unsqueeze(-1)))

        # r = self.drop(self.encoder(X_context, A))
        # X_encoded = F.drop out
        z = q_target.rsample()
        # z_context = q_context.rsample()
        # mask = (torch.rand_like(z_target[...,[0]]) > 0.5).float()
        # z = (z_target * mask) + (z_context * (1-mask))

        c = self.z_to_c(z).view(z.shape[0], 155, -1)
        # c = z.unsqueeze(-2)

        X_dec = self.decoder(X_context, A, c)
        X_hat = self.logits(X_dec)

        return X_hat, X_a, q_context, q_target, z

    @torch.no_grad()
    def features(self, X, X_a):

        A = (~X_a.unsqueeze(-1)) & (X_a.unsqueeze(-2))
        eye = torch.eye(A.shape[-1]).bool().to(self.device) & (~X_a.unsqueeze(-2))
        A = A | eye

        X = self.embedder(X) * X_a.unsqueeze(-1).float()
        q = self.latent_encoder(X, A)

        return q

    @torch.no_grad()
    def generate_z(self, X, X_a, z, t=1.):


        A = (~X_a.unsqueeze(-1)) & (X_a.unsqueeze(-2))
        eye = torch.eye(A.shape[-1]).bool().to(self.device) & (~X_a.unsqueeze(-2))
        A = A | eye

        X = self.embedder(X)
        X = X * X_a.unsqueeze(-1).float()       

        c = self.z_to_c(z).view(z.shape[0], 155, -1)
        X_dec = self.decoder(X, A, c)
        X_hat = mask_parameters(self.logits(X_dec))
        X_hat = torch.distributions.Categorical(logits=X_hat/t)

        return X_hat

    @torch.no_grad()
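    def generate(self, X, X_a, sample=True, t=1.):
        """
        Infer a latent from the context parameters and decode it

        X - the array of dx7 voices, torch.LongTensor(batch_size, 155)
        X_a - mask of active (context) parameters, torch.BoolTensor(batch_size, 155)
        sample - sample z from the latent distribution if True, otherwise use its mean
        t - temperature of the output distribution, requires t>0
        """

        q = self.features(X, X_a)
        z = q.sample() if sample else q.mean

        # decode with the conditional decoder, as in forward() and generate_z()
        return self.generate_z(X, X_a, z, t=t)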



if __name__=='__main__':

    layer_features = 100
    n_heads = 4

    head_features = layer_features // n_heads

    attention = {
        'n_features': layer_features,
        'n_hidden': head_features,
        'n_heads': n_heads
    }
    
    attention_layer = {
        'attention': attention,
        'features': layer_features,
        'hidden_dim': 555
    }

    encoder = {
        'features': layer_features,
        'attention_layer': attention_layer,
        'max_len': N_PARAMS
    }
        
    
    model = DX7PatchProcess(layer_features, encoder=encoder)
    X = torch.distributions.Categorical(torch.ones(128)).sample((3, N_PARAMS))

    logits = model(X)
    print(logits.shape)
    print(logits[0])
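`DX7NeuralProcess.forward` samples its context masks in two stages. It first draws one threshold `p` per voice, then one uniform value `u` per parameter, and marks a parameter active when `p <= u`. For a fixed `p`, the expected fraction of context parameters is therefore `1 - p`. Averaged over `p ~ U(0, 1)`, that fraction is one half. The pure-Python sketch below mirrors this procedure. The helper name and the optional `p` override, which exists only to make the extremes easy to check, are hypothetical.

```python
import random


def sample_context_masks(batch_size, n_params, rng, p=None):
    """Mirror of the mask sampling in DX7NeuralProcess.forward (hypothetical helper):
    one threshold p per batch item, one uniform draw u per parameter,
    and parameter j is active (context) when p <= u."""
    masks = []
    for _ in range(batch_size):
        p_item = rng.random() if p is None else p
        masks.append([p_item <= rng.random() for _ in range(n_params)])
    return masks


masks = sample_context_masks(4, 155, random.Random(0))
print([sum(m) for m in masks])  # number of context parameters per voice
```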


================================================
FILE: neuralDX7/models/dx7_nsp.py
================================================
import torch
from torch import nn
from torch.nn import functional as F

from agoge import AbstractModel
from neuralDX7.models.attention import ResidualAttentionEncoder, CondtionalResidualAttentionEncoder
from neuralDX7.models.general import FeedForwardGELU
from neuralDX7.models.stochastic_nodes import TriangularSylvesterFlow
from neuralDX7.constants import MAX_VALUE, N_PARAMS
from neuralDX7.utils import mask_parameters


class DX7NeuralSylvesterProcess(AbstractModel):
    """
    EXPERIMENTAL AND UNTESTED

    """

    def __init__(self, features, latent_dim, encoder, decoder, deterministic_path_drop_rate=0.5,  num_flows=3):
        
        super().__init__()

        self.embedder = nn.Embedding(MAX_VALUE, features)
        self.encoder = ResidualAttentionEncoder(**encoder)
        self._latent_encoder = nn.ModuleList([
            ResidualAttentionEncoder(**encoder),
            TriangularSylvesterFlow(features, latent_dim, num_flows)]
        )
        self.z_to_c = nn.Linear(latent_dim, latent_dim*155)
        self.decoder = CondtionalResidualAttentionEncoder(**decoder)
        self.logits = FeedForwardGELU(features, MAX_VALUE)
        self.drop = nn.Dropout(deterministic_path_drop_rate)

    def latent_encoder(self,  X, A, z=None, flow=True):

        encoder, q_x = self._latent_encoder

        return q_x(encoder(X, A).mean(-2), z, flow)


    def forward(self, X):

        batch_size = X.shape[0]

        # generate random masks
        batch_p = torch.rand(batch_size) # decide p value for each item in batch
        item_logits = torch.rand(X.shape) # random value for each param
        X_a = batch_p.unsqueeze(-1) <= item_logits # active params in X
        X_a = X_a.to(self.device)

        A = (~X_a.unsqueeze(-1)) & (X_a.unsqueeze(-2))
        eye = torch.eye(A.shape[-1]).bool().to(self.device) & (~X_a.unsqueeze(-2))
        A = A | eye
        # A = A | True

        X_target = self.embedder(X)
        
        X_context = X_target * X_a.unsqueeze(-1).float()

        flow_target = self.latent_encoder(X_target, A | (~X_a.unsqueeze(-1)), flow=False)
        flow_context = self.latent_encoder(X_context, A)
        
        c = self.z_to_c(flow_target.z_0).view(batch_size, 155, -1)
        # c = z.unsqueeze(-2)

        X_dec = self.decoder(X_context, A, c)
        X_hat = self.logits(X_dec)

        return {
            'X_hat': X_hat,
            'X_a': X_a,
            'flow_context': flow_context,
            'flow_target': flow_target,
        }


    @torch.no_grad()
    def features(self, X, X_a):

        A = (~X_a.unsqueeze(-1)) & (X_a.unsqueeze(-2))
        eye = torch.eye(A.shape[-1]).bool().to(self.device) & (~X_a.unsqueeze(-2))
        A = A | eye

        X = self.embedder(X) * X_a.unsqueeze(-1).float()
        q = self.latent_encoder(X, A)

        return q

    @torch.no_grad()
    def generate_z(self, X, X_a, z, t=1.):


        A = (~X_a.unsqueeze(-1)) & (X_a.unsqueeze(-2))
        eye = torch.eye(A.shape[-1]).bool().to(self.device) & (~X_a.unsqueeze(-2))
        A = A | eye

        X = self.embedder(X)
        X = X * X_a.unsqueeze(-1).float()       

        c = self.z_to_c(z).view(z.shape[0], 155, -1)
        X_dec = self.decoder(X, A, c)
        X_hat = mask_parameters(self.logits(X_dec))
        X_hat = torch.distributions.Categorical(logits=X_hat/t)

        return X_hat

    @torch.no_grad()
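    def generate(self, X, X_a, sample=True, t=1.):
        """
        Infer a latent from the context parameters and decode it

        X - the array of dx7 voices, torch.LongTensor(batch_size, 155)
        X_a - mask of active (context) parameters, torch.BoolTensor(batch_size, 155)
        sample - use the flowed posterior sample if True, otherwise flow the base mean
        t - temperature of the output distribution, requires t>0
        """

        flow = self.features(X, X_a)

        # z_k (the context posterior pushed through the flow) is what the solver matches
        # against the target posterior that conditions the decoder in forward()
        z = flow.z_k if sample else flow.flow(flow.q_z.mean)[0]

        return self.generate_z(X, X_a, z, t=t)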



if __name__=='__main__':

    layer_features = 100
    n_heads = 4

    head_features = layer_features // n_heads

    attention = {
        'n_features': layer_features,
        'n_hidden': head_features,
        'n_heads': n_heads
    }
    
    attention_layer = {
        'attention': attention,
        'features': layer_features,
        'hidden_dim': 555
    }

    encoder = {
        'features': layer_features,
        'attention_layer': attention_layer,
        'max_len': N_PARAMS
    }
        
    
    model = DX7PatchProcess(layer_features, encoder=encoder)
    X = torch.distributions.Categorical(torch.ones(128)).sample((3, N_PARAMS))

    logits = model(X)
    print(logits.shape)
    print(logits[0])
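In `forward` and `generate_z`, `self.z_to_c` maps each latent vector of size `latent_dim` to `latent_dim * 155` numbers. `.view(batch_size, 155, -1)` then splits that flat vector into one conditioning vector per DX7 parameter. The output of `nn.Linear` is contiguous and `view` is row-major, so parameter `i` receives the slice `[i * latent_dim, (i + 1) * latent_dim)`. The sketch below reproduces that split in pure Python for one voice. `split_per_parameter` is a hypothetical name.

```python
def split_per_parameter(flat, n_params):
    """Row-major split of one flat conditioning vector into n_params chunks,
    matching tensor.view(n_params, -1) on a contiguous 1-D tensor (hypothetical helper)."""
    if len(flat) % n_params != 0:
        raise ValueError("length must be divisible by n_params")
    width = len(flat) // n_params
    return [flat[i * width:(i + 1) * width] for i in range(n_params)]


latent_dim, n_params = 4, 155
flat = list(range(latent_dim * n_params))  # stand-in for z_to_c(z) of a single voice
c = split_per_parameter(flat, n_params)
print(len(c), c[0], c[-1])
```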


================================================
FILE: neuralDX7/models/dx7_vae.py
================================================
import torch
from torch import nn
from torch.nn import functional as F

from agoge import AbstractModel
from neuralDX7.models.attention import ResidualAttentionEncoder, CondtionalResidualAttentionEncoder
from neuralDX7.models.general import FeedForwardGELU
from neuralDX7.models.stochastic_nodes import TriangularSylvesterFlow
from neuralDX7.constants import MAX_VALUE, N_PARAMS
from neuralDX7.utils import mask_parameters


class DX7VAE(AbstractModel):
    """
    Variational Auto Encoder for a single DX7 patch. 
    
    Uses a triangular Sylvester flow to transform the encoder output into the decoder input
    """

    def __init__(self, features, latent_dim, encoder, decoder, num_flows=3):
        """
        features - number of features in the model
        latent_dim - the latent dimension of the model
        encoder - dictionary containing instantiation parameters for ResidualAttentionEncoder module
        decoder - dictionary containing instantiation parameters for CondtionalResidualAttentionEncoder module
        num_flows - the number of flows for the TriangularSylvesterFlow module
        """
        
        super().__init__()

        self.embedder = nn.Embedding(MAX_VALUE, features)
        self.encoder = ResidualAttentionEncoder(**encoder)
        self._latent_encoder = nn.ModuleList([
            ResidualAttentionEncoder(**encoder),
            TriangularSylvesterFlow(features, latent_dim, num_flows)]
        )
        self.z_to_c = nn.Linear(latent_dim, latent_dim*155)
        self.decoder = CondtionalResidualAttentionEncoder(**decoder)
        self.logits = FeedForwardGELU(features, MAX_VALUE)

        self.n_features = features

    def latent_encoder(self,  X, A, z=None, mean=False):
        """
        Calculate the latent distribution

        X - data tensor, torch.FloatTensor(batch_size, 155, features)
        A - connection mask, torch.BoolTensor(batch_size, 155, 155)
        z - a presampled latent; if None, z is sampled using the reparameterization trick
        mean - currently unused; the latent is always sampled or taken from z
        """
        
        encoder, q_x = self._latent_encoder

        return q_x(encoder(X, A).mean(-2), z)


    def forward(self, X):
        """
        Autoencodes the input through the variational latent layer

        X - the array of dx7 voices, torch.LongTensor(batch_size, 155)
        """

        batch_size = X.shape[0]

        A = torch.ones_like(X).bool()
        A = A[...,None] | A[...,None,:]

        X_emb = self.embedder(X)
        
        flow = self.latent_encoder(X_emb, A)
        
        c = self.z_to_c(flow.z_k).view(batch_size, 155, -1)

        X_dec = self.decoder(torch.ones_like(X_emb), A, c)
        X_hat = self.logits(X_dec)

        return {
            'X_hat': X_hat,
            'flow': flow,
        }


    @torch.no_grad()
    def features(self, X):
        """
        Get the latent distributions for a set of voices

        X - the array of dx7 voices, torch.LongTensor(batch_size, 155)

        """

        A = torch.ones_like(X).bool()
        A = A[...,None] | A[...,None,:]

        X = self.embedder(X)
        q = self.latent_encoder(X, A)

        return q.q_z

    @torch.no_grad()
    def generate(self, z, t=1.):
        """
        Given a sample from the latent distribution, reprojects it back to data space

        z - latent samples, torch.FloatTensor(batch_size, latent_dim)
        t - the temperature of the output distribution; approaches deterministic as t -> 0 and uniform as t -> infinity, requires t > 0
        """
        # all-True connection mask and constant all-ones decoder input, mirroring forward(),
        # which decodes from torch.ones_like(X_emb); z.new(...) would be uninitialized memory
        A = z.new_ones((z.size(0), 155, 155), dtype=torch.bool)
        X = z.new_ones((z.size(0), 155, self.n_features))

        c = self.z_to_c(z).view(z.shape[0], 155, -1)
        X_dec = self.decoder(X, A, c)
        X_hat = mask_parameters(self.logits(X_dec))
        X_hat = torch.distributions.Categorical(logits=X_hat/t)

        return X_hat

if __name__=='__main__':

    layer_features = 100
    n_heads = 4

    head_features = layer_features // n_heads

    attention = {
        'n_features': layer_features,
        'n_hidden': head_features,
        'n_heads': n_heads
    }
    
    attention_layer = {
        'attention': attention,
        'features': layer_features,
        'hidden_dim': 555
    }

    encoder = {
        'features': layer_features,
        'attention_layer': attention_layer,
        'max_len': N_PARAMS
    }
        
    
    model = DX7PatchProcess(layer_features, encoder=encoder)
    X = torch.distributions.Categorical(torch.ones(128)).sample((3, N_PARAMS))

    logits = model(X)
    print(logits.shape)
    print(logits[0])
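`DX7VAE.generate` divides the logits by the temperature `t` before building the `Categorical`. Per its docstring, sampling becomes deterministic as `t -> 0` and uniform as `t -> infinity`. A minimal pure-Python sketch of `softmax(logits / t)` makes that behaviour concrete. The helper name is hypothetical. `mask_parameters` is applied before the division in the model, and its implementation is not shown here.

```python
import math


def tempered_probs(logits, t=1.0):
    """softmax(logits / t), the distribution behind Categorical(logits=X_hat / t)."""
    if t <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / t for x in logits]
    m = max(scaled)  # subtract the max before exponentiating for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]
for t in (0.05, 1.0, 100.0):
    print(t, [round(p, 3) for p in tempered_probs(logits, t)])
```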


================================================
FILE: neuralDX7/models/general/__init__.py
================================================
from .gelu_ff import FeedForwardGELU

================================================
FILE: neuralDX7/models/general/gelu_ff.py
================================================
import torch
from torch import nn

class FeedForwardGELU(nn.Module):
    """
    Simple wrapper for a two-layer projection with a GELU nonlinearity

    """

    def __init__(self, features, out_features=None, exapnsion_factor=3):
        """
        features - the number of input features
        out_features - the number of output features; if None, copies the input dimension
        exapnsion_factor - the size of the hidden dimension as a multiple of the input features
        """

        super().__init__()
        out_features = features if out_features is None else out_features

        self.net = nn.Sequential(
            nn.Linear(features, features*exapnsion_factor),
            nn.GELU(),
            nn.Linear(features*exapnsion_factor, out_features)
        )

    def forward(self, x):

        return self.net(x)
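`FeedForwardGELU` is `Linear(f, f*k) -> GELU -> Linear(f*k, out)`. The sketch below is pure Python and not the module itself. It shows the exact (erf-based) GELU, which is what `torch.nn.GELU()` uses by default, and counts the module's weights and biases. The helper names are hypothetical, and `128` output classes is only an illustrative value.

```python
import math


def gelu(x):
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def feedforward_gelu_param_count(features, out_features=None, expansion_factor=3):
    """Weights plus biases of Linear(f, f*k) -> GELU -> Linear(f*k, out) (hypothetical helper)."""
    out_features = features if out_features is None else out_features
    hidden = features * expansion_factor
    return (features * hidden + hidden) + (hidden * out_features + out_features)


print(gelu(-3.0), gelu(0.0), gelu(3.0))
print(feedforward_gelu_param_count(100, 128))
```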

================================================
FILE: neuralDX7/models/stochastic_nodes/__init__.py
================================================
from .normal import NormalNode
from .triangular_sylvester import TriangularSylvesterFlow

================================================
FILE: neuralDX7/models/stochastic_nodes/normal.py
================================================
from torch import nn
from torch.distributions import Normal



class NormalNode(nn.Module):
    """
    Simple module that creates a normally distributed stochastic node, a la VAEs.

    This node computes the function
    ```
        p(x) = N(mu(x), sigma(x)I)
    ```
    """

    def __init__(self, in_features, latent_dim, hidden_dim=None):
        """
        in_features - number of input features
        latent_dim - number of normals in the output
        hidden_dim - the inner dimension of the nonlinear feedforward network, 2x the input dimension if None
        """
        super().__init__()

        if hidden_dim is None:

            hidden_dim = in_features * 2

        self.net = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, latent_dim * 2)
        )

    def forward(self, x, *args, **kwargs):
        """
        x - the input vector, torch.FloatTensor(..., f)
        """

        # calculate the parameters of the distribution
        mu, log_sigma = self.net(x).chunk(2, -1)

        # treat the second half as a log-variance: sigma = exp(0.5 * log_sigma),
        # with the resulting log-std clamped to [-5, 4] for numerical stability
        sigma = (log_sigma*0.5).clamp(-5, 4).exp()

        return Normal(mu, sigma)
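In `NormalNode.forward`, the second half of the network output (named `log_sigma`) is halved before exponentiation, so it behaves as a log-variance. The clamp bounds the standard deviation to `[e^-5, e^4]`. `Normal.rsample` then draws `mu + sigma * eps` with `eps ~ N(0, 1)`, which keeps the sample differentiable in `mu` and `sigma`. Below is a per-coordinate pure-Python sketch of both steps. The helper names are hypothetical.

```python
import math
import random


def normal_node_params(mu, log_var):
    """One coordinate of NormalNode.forward: sigma = exp(clamp(0.5 * log_var, -5, 4))."""
    log_std = min(max(0.5 * log_var, -5.0), 4.0)
    return mu, math.exp(log_std)


def rsample(mu, sigma, rng):
    """Reparameterized draw z = mu + sigma * eps with eps ~ N(0, 1)."""
    return mu + sigma * rng.gauss(0.0, 1.0)


mu, sigma = normal_node_params(0.5, 0.0)
print(sigma, rsample(mu, sigma, random.Random(0)))
```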





================================================
FILE: neuralDX7/models/stochastic_nodes/triangular_sylvester.py
================================================
#%%
from collections import namedtuple
from itertools import count
import torch
from torch import nn
from neuralDX7.models.stochastic_nodes import NormalNode

"""
This code is modified from the reference implementation provided by the authors
https://github.com/riannevdberg/sylvester-flows
"""



class TriangularSylvester(nn.Module):
    """
    Sylvester normalizing flow with Q=P or Q=I.
    """

    def __init__(self, z_size):

        super(TriangularSylvester, self).__init__()

        self.z_size = z_size
        self.h = nn.Tanh()

        # diag_idx = torch.arange(0, z_size).long()
        # self.register_buffer('diag_idx', diag_idx)

    def der_h(self, x):
        return self.der_tanh(x)

    def der_tanh(self, x):
        return 1 - self.h(x) ** 2

    def forward(self, zk, r1, r2, b, permute_z=None, sum_ldj=True):
        """
        All flow parameters are amortized. Conditions on the diagonals of R1 and R2 need to be satisfied
        outside of this function.
        Computes the following transformation:
        z' = z + QR1 h( R2Q^T z + b)
        or actually
        z'^T = z^T + h(z^T Q R2^T + b^T)R1^T Q^T
        with Q = P a permutation matrix (equal to identity matrix if permute_z=None)
        :param zk: shape: (batch_size, z_size)
        :param r1: shape: (batch_size, num_ortho_vecs, num_ortho_vecs).
        :param r2: shape: (batch_size, num_ortho_vecs, num_ortho_vecs).
        :param b: shape: (batch_size, 1, self.z_size)
        :return: z, log_det_j
        """
        # Amortized flow parameters
        zk = zk.unsqueeze(1)

        # Save diagonals for log_det_j
        # diag_r1 = r1[:, self.diag_idx, self.diag_idx]
        diag_r1 = torch.diagonal(r1, 0, -1, -2)
        # diag_r2 = r2[:, self.diag_idx, self.diag_idx]
        diag_r2 = torch.diagonal(r2, 0, -1, -2)

        if permute_z is not None:
            # permute order of z
            z_per = zk[:, :, permute_z]
        else:
            z_per = zk

        r2qzb = z_per @ r2.transpose(2, 1) + b
        z = self.h(r2qzb) @ r1.transpose(2, 1)

        if permute_z is not None:
            # permute z back to its original order
            z = z[:, :, permute_z]

        z += zk
        z = z.squeeze(1)

        # Compute log|det J|
        # Output log_det_j in shape (batch_size) instead of (batch_size,1)
        diag_j = diag_r1 * diag_r2
        diag_j = self.der_h(r2qzb).squeeze(1) * diag_j
        diag_j += 1.
        log_diag_j = (diag_j.abs()+1e-8).log()

        if sum_ldj:
            log_det_j = log_diag_j.sum(-1)
        else:
            log_det_j = log_diag_j

        return z, log_det_j

class TriangularSylvesterFlow(nn.Module):
    """
    Amortized triangular Sylvester flow on top of a diagonal Gaussian base distribution. Alternates between setting
    the orthogonal matrix equal to a permutation matrix and the identity matrix for each flow.
    """

    def __init__(self, in_features, latent_dim, num_flows):

        super().__init__()
        # Initialize log-det-jacobian to zero
        self.log_det_j = 0.

        # Flow parameters
        self.num_flows = num_flows
        self.latent_dim = latent_dim

        # permuting indices corresponding to Q=P (permutation matrix) for every other flow
        flip_idx = torch.arange(latent_dim - 1, -1, -1).long()
        self.register_buffer('flip_idx', flip_idx)

        # self.amor_b = nn.Linear(self.q_z_nn_output_dim, self.num_flows * latent_dim)
        self.q_z = NormalNode(in_features, latent_dim)
        self._flow_params = nn.Linear(in_features,
                self.num_flows * latent_dim * latent_dim + \
                self.num_flows * latent_dim + \
                self.num_flows * latent_dim + \
                self.num_flows * latent_dim
        )
        self.flows = nn.ModuleList([
            TriangularSylvester(latent_dim) for k in range(self.num_flows)
        ])

    def flow_params(self, h):
        """
        Compute the amortized flow parameters (R1, R2, b) of every flow from the features h
        """

        batch_size = h.size(0)

        params = self._flow_params(h)
        params = params.reshape(batch_size, self.num_flows, self.latent_dim, -1)
        params = params.transpose(0,1) # batch x flows x z x (z+3) -> flows x batch x z x (z+3)
        
        diag1 = torch.tanh(params[...,0])
        diag2 = torch.tanh(params[...,1])
        b = params[...,2].unsqueeze(-2)
        full_d = params[...,3:]

        r1 = torch.triu(full_d, diagonal=1)
        r2 = torch.triu(full_d.transpose(-1, -2), diagonal=1)
        r1 = diag1.diag_embed(0) + r1
        r2 = diag2.diag_embed(0) + r2

        return r1, r2, b

    def forward(self, h, z=None, flow=True):
        """
        Forward pass with orthogonal flows for the transformation z_0 -> z_1 -> ... -> z_k.
        Log determinant is computed as log_det_j = N E_q_z0[\sum_k log |det dz_k/dz_k-1| ].
        """
        Flow = namedtuple('Flow', ('q_z', 'log_det', 'z_0', 'z_k', 'flow'))

        q_z = self.q_z(h)
        z_0 = z_k = q_z.rsample() if z is None else z

        if not flow:
            return Flow(q_z, None, z_0, None, None)

        r1, r2, b = self.flow_params(h)

        # apply the flows z_0 -> z_1 -> ... -> z_k, accumulating log|det J|
        def flow_f(z_k):
            log_det_j = 0.

            # Normalizing flows
            for k, flow_k, r1_k, r2_k, b_k in zip(count(), self.flows, r1, r2, b):

                if k % 2 == 1:
                    # Alternate with reordering z for triangular flow
                    permute_z = self.flip_idx
                else:
                    permute_z = None

                z_k, log_det_jacobian = flow_k(z_k, r1_k, r2_k, b_k, permute_z, sum_ldj=True)

                log_det_j += log_det_jacobian
            
            return z_k, log_det_j
        z_k, log_det_j = flow_f(z_0)
        return Flow(q_z, log_det_j, z_0, z_k, flow_f)


if __name__=="__main__":
    
    num_ortho_vecs = z_size = 6
    batch_size = 12
    in_features = 64
    
    h = torch.randn(batch_size, in_features)
    # zk = torch.randn(batch_size, z_size)
    # r1 = torch.randn(batch_size, num_ortho_vecs, num_ortho_vecs)
    # r2 = torch.randn(batch_size, num_ortho_vecs, num_ortho_vecs)
    # b = torch.randn(batch_size, 1, z_size)

    f = TriangularSylvesterFlow(in_features, z_size, 3)

    f(h)

# %%
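With `Q = I`, one triangular Sylvester step is `z' = z + R1 h(R2 z + b)`. Its Jacobian is `I + R1 diag(h') R2`. Because `R1` and `R2` are upper triangular, that product is upper triangular too. The log-determinant therefore reduces to `sum_i log|1 + R1_ii R2_ii h'(.)_i|`, which is what `TriangularSylvester.forward` accumulates via `diag_j`. The pure-Python sketch below checks that identity against a finite-difference Jacobian in 2-D. The helper names and numeric values are illustrative. The `1e-8` stabilizer the module adds inside the log is omitted.

```python
import math


def matvec(m, v):
    return [sum(m[i][j] * v[j] for j in range(len(v))) for i in range(len(m))]


def sylvester_step(z, r1, r2, b):
    """One triangular Sylvester step with Q = I: z' = z + R1 tanh(R2 z + b)."""
    pre = [p + bi for p, bi in zip(matvec(r2, z), b)]
    h = [math.tanh(p) for p in pre]
    return [zi + di for zi, di in zip(z, matvec(r1, h))], pre


def analytic_log_det(z, r1, r2, b):
    """sum_i log|1 + R1_ii * R2_ii * tanh'(pre_i)|, as in TriangularSylvester.forward."""
    _, pre = sylvester_step(z, r1, r2, b)
    total = 0.0
    for i, p in enumerate(pre):
        der_tanh = 1.0 - math.tanh(p) ** 2
        total += math.log(abs(1.0 + r1[i][i] * r2[i][i] * der_tanh))
    return total


def numeric_log_det_2d(z, r1, r2, b, eps=1e-6):
    """log|det J| from a central finite-difference Jacobian (2-D only)."""
    jac = [[0.0, 0.0], [0.0, 0.0]]
    for j in range(2):
        z_plus, z_minus = list(z), list(z)
        z_plus[j] += eps
        z_minus[j] -= eps
        f_plus, _ = sylvester_step(z_plus, r1, r2, b)
        f_minus, _ = sylvester_step(z_minus, r1, r2, b)
        for i in range(2):
            jac[i][j] = (f_plus[i] - f_minus[i]) / (2 * eps)
    return math.log(abs(jac[0][0] * jac[1][1] - jac[0][1] * jac[1][0]))


z = [0.3, -0.7]
r1 = [[0.5, 1.2], [0.0, -0.4]]  # upper triangular; |diagonal| < 1 as with the tanh in flow_params
r2 = [[0.8, -0.3], [0.0, 0.6]]
b = [0.1, -0.2]
print(analytic_log_det(z, r1, r2, b), numeric_log_det_2d(z, r1, r2, b))
```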


================================================
FILE: neuralDX7/models/utils.py
================================================

import torch
import numpy as np

def position_encoding_init(n_position, emb_dim):
    ''' Init the sinusoid position encoding table '''

    # keep dim 0 for padding token position encoding zero vector
    position_enc = np.array([
        [pos / np.power(10000, 2 * (j // 2) / emb_dim) for j in range(emb_dim)]
        if pos != 0 else np.zeros(emb_dim) for pos in range(n_position)])
    

    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2]) # apply sin on 0th,2nd,4th...emb_dim
    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2]) # apply cos on 1st,3rd,5th...emb_dim
    return torch.from_numpy(position_enc).type(torch.FloatTensor)
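`position_encoding_init` fills row `pos` with angles `pos / 10000^(2*(j//2)/emb_dim)`. It then applies `sin` to even columns and `cos` to odd columns, and keeps row 0 as an all-zero padding vector. A pure-Python equivalent is sketched below for readers without NumPy at hand. The function name is hypothetical, and the original returns a float32 tensor rather than nested lists.

```python
import math


def position_encoding(n_position, emb_dim):
    """Pure-Python version of position_encoding_init (hypothetical helper):
    row 0 stays zero; other rows use sin on even columns and cos on odd columns."""
    table = [[0.0] * emb_dim] if n_position > 0 else []
    for pos in range(1, n_position):
        row = []
        for j in range(emb_dim):
            angle = pos / 10000 ** (2 * (j // 2) / emb_dim)
            row.append(math.sin(angle) if j % 2 == 0 else math.cos(angle))
        table.append(row)
    return table


table = position_encoding(4, 6)
print(table[1])
```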


================================================
FILE: neuralDX7/solvers/__init__.py
================================================
from .dx7_patch_process import DX7PatchProcess
from .dx7_np import DX7NeuralProcess
from .dx7_nsp import DX7NeuralSylvesterProcess
from .dx7_vae import DX7VAE

================================================
FILE: neuralDX7/solvers/dx7_np.py
================================================
import torch
from torch.nn import functional as F
from importlib import import_module
from torch.optim import AdamW
from torch.distributions.kl import kl_divergence

from agoge import AbstractSolver

from .utils import sigmoidal_annealing

class DX7NeuralProcess(AbstractSolver):
    """
    EXPERIMENTAL AND UNTESTED
    """

    def __init__(self, model,
        Optim=AdamW, optim_opts=dict(lr= 1e-4),
        max_beta=0.5,
        beta_temp=1e-4,
        **kwargs):

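        if isinstance(Optim, str):
            # resolve a dotted path such as 'torch.optim.AdamW' to the optimizer class
            module_name, class_name = Optim.rsplit('.', 1)
            Optim = getattr(import_module(module_name), class_name)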


        self.optim = Optim(params=model.parameters(), **optim_opts)
        self.max_beta = max_beta
        self.model = model

        self.iter = 0
        self.beta_temp = beta_temp

    def loss(self, x, x_hat, x_a, q_context, q_target, z):

        valid_predictions = (~x_a).nonzero().t()

        valid_x_hat = x_hat[(*valid_predictions,)]
        valid_x = x[(*valid_predictions,)]

 
        # kl = kl_divergence(q_target, q_context)#[(*valid_predictions,)]
        # kl = kl_divergence(Normal(torch.zeros_like()), q_context)#[(*valid_predictions,)]
        kl = q_target.log_prob(z) - q_context.log_prob(z)
        kl = kl.sum(-1).mean()
        entropy = q_target.entropy().mean()
        beta = sigmoidal_annealing(self.iter, self.beta_temp).item()

        reconstruction_loss = F.cross_entropy(valid_x_hat, valid_x)
        accuracy = (valid_x_hat.argmax(-1)==valid_x).float().mean()

        loss = reconstruction_loss + 0.25 * beta * kl

        return loss, {
            'accuracy': accuracy,
            'reconstruction_loss': reconstruction_loss,
            'kl': kl,
            'entropy': entropy,
            'beta': beta
        }
        

    def solve(self, x, **kwargs):
        
        x_hat, x_a, q_context, q_target, z  = self.model(x)
        loss, L = self.loss(x, x_hat, x_a, q_context, q_target, z)

        if loss != loss:
            raise ValueError('Nan Values detected')

        if self.model.training:
            self.iter += 1
            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

        return L

    
    def step(self):

        pass


    def state_dict(self):
        
        state_dict = {
            'optim': self.optim.state_dict(),
            'iter': self.iter
        }

        return state_dict

    def load_state_dict(self, state_dict):
        
        self.optim.load_state_dict(state_dict['optim'])
        self.iter = state_dict['iter']
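In this solver, `kl = q_target.log_prob(z) - q_context.log_prob(z)` uses the `z` that `DX7NeuralProcess.forward` draws from `q_target`. That makes it a single-sample Monte Carlo estimate of `KL(q_target || q_context)`, summed over latent dimensions and averaged over the batch. The 1-D pure-Python sketch below shows that averaging many such samples recovers the closed-form Gaussian KL. The helper names and distribution parameters are illustrative.

```python
import math
import random


def normal_log_prob(x, mu, sigma):
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def mc_kl(mu_q, s_q, mu_p, s_p, n, rng):
    """Mean of log q(z) - log p(z) with z ~ q: an unbiased estimate of KL(q || p)."""
    total = 0.0
    for _ in range(n):
        z = rng.gauss(mu_q, s_q)
        total += normal_log_prob(z, mu_q, s_q) - normal_log_prob(z, mu_p, s_p)
    return total / n


def closed_form_kl(mu_q, s_q, mu_p, s_p):
    return math.log(s_p / s_q) + (s_q ** 2 + (mu_q - mu_p) ** 2) / (2 * s_p ** 2) - 0.5


print(mc_kl(0.0, 1.0, 1.0, 2.0, 100_000, random.Random(0)), closed_form_kl(0.0, 1.0, 1.0, 2.0))
```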

================================================
FILE: neuralDX7/solvers/dx7_nsp.py
================================================
import torch
from torch.nn import functional as F
from importlib import import_module
from torch.optim import AdamW
from torch.distributions.kl import kl_divergence

from agoge import AbstractSolver

from .utils import sigmoidal_annealing

class DX7NeuralSylvesterProcess(AbstractSolver):
    """
    EXPERIMENTAL AND UNTESTED
    """
 
    def __init__(self, model,
        Optim=AdamW, optim_opts=dict(lr= 1e-4),
        max_beta=0.5,
        beta_temp=1e-4,
        **kwargs):

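        if isinstance(Optim, str):
            # resolve a dotted path such as 'torch.optim.AdamW' to the optimizer class
            module_name, class_name = Optim.rsplit('.', 1)
            Optim = getattr(import_module(module_name), class_name)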

        self.optim = Optim(params=model.parameters(), **optim_opts)
        self.max_beta = max_beta
        self.model = model

        self.iter = 0
        self.beta_temp = beta_temp

    def loss(self, X, X_hat, X_a, flow_context, flow_target):

        valid_predictions = (~X_a).nonzero().t()

        valid_x_hat = X_hat[(*valid_predictions,)]
        valid_x = X[(*valid_predictions,)]

        p_z = flow_target.q_z.log_prob(flow_context.z_k).sum(-1)
        q_z = flow_context.q_z.log_prob(flow_context.z_0).sum(-1)
        kl = (q_z-p_z-flow_context.log_det).mean() / flow_context.z_k.shape[-1]
        beta = sigmoidal_annealing(self.iter, self.beta_temp).item()

        reconstruction_loss = F.cross_entropy(valid_x_hat, valid_x)
        accuracy = (valid_x_hat.argmax(-1)==valid_x).float().mean()

        loss = reconstruction_loss + self.max_beta * beta * kl

        return loss, {
            'accuracy': accuracy,
            'reconstruction_loss': reconstruction_loss,
            'kl': kl,
            'beta': beta,
            'q_log_det': flow_context.log_det.mean(),
            # 'p_log_det': flow_target.log_det.mean(),
            'q_z': q_z.mean(),
            'p_z': p_z.mean()
        }
        

    def solve(self, X, **kwargs):
        
        Y = self.model(**X)
        loss, L = self.loss(**X, **Y)

        if loss != loss:
            raise ValueError('Nan Values detected')

        if self.model.training:
            self.iter += 1
            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

        return L

    
    def step(self):

        pass


    def state_dict(self):
        
        state_dict = {
            'optim': self.optim.state_dict(),
            'iter': self.iter
        }

        return state_dict

    def load_state_dict(self, state_dict):
        
        self.optim.load_state_dict(state_dict['optim'])
        self.iter = state_dict['iter']
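The flow-based KL in this solver (and in the `DX7VAE` solver) rests on the change-of-variables rule `log q_K(z_k) = log q_0(z_0) - log|det dz_k/dz_0|`. That is why `log_det` is subtracted from `q_z.log_prob(z_0)`. A minimal pure-Python check uses a 1-D affine flow `z_k = a * z_0 + c`, whose pushed-forward density is known exactly to be `N(a*mu + c, |a|*sigma)`. The affine flow is an illustrative stand-in, not the Sylvester flow itself.

```python
import math


def normal_log_prob(x, mu, sigma):
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def flowed_log_density(z0, mu, sigma, a):
    """log q_K(z_k) for z_k = a * z0 + c via change of variables: log q_0(z0) - log|a|."""
    return normal_log_prob(z0, mu, sigma) - math.log(abs(a))


mu, sigma, a, c, z0 = 0.2, 0.7, -1.5, 0.4, 0.9
zk = a * z0 + c
direct = normal_log_prob(zk, a * mu + c, abs(a) * sigma)  # exact density of the pushed-forward Gaussian
print(flowed_log_density(z0, mu, sigma, a), direct)
```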

================================================
FILE: neuralDX7/solvers/dx7_patch_process.py
================================================
import torch
from torch.nn import functional as F
from importlib import import_module
from torch.optim import AdamW

from agoge import AbstractSolver



class DX7PatchProcess(AbstractSolver):
    """
    EXPERIMENTAL AND UNTESTED
    """

    def __init__(self, model,
        Optim=AdamW, optim_opts=dict(lr= 1e-4),
        max_beta=0.5,
        **kwargs):

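        if isinstance(Optim, str):
            # resolve a dotted path such as 'torch.optim.AdamW' to the optimizer class
            module_name, class_name = Optim.rsplit('.', 1)
            Optim = getattr(import_module(module_name), class_name)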


        self.optim = Optim(params=model.parameters(), **optim_opts)
        self.max_beta = max_beta
        self.model = model

    def loss(self, x, x_hat, x_a):

        valid_predictions = (~x_a).nonzero().t()

        valid_x_hat = x_hat[(*valid_predictions,)]
        valid_x = x[(*valid_predictions,)]

        reconstruction_loss = F.cross_entropy(valid_x_hat, valid_x)
        accuracy = (valid_x_hat.argmax(-1)==valid_x).float().mean()

        return reconstruction_loss, {
            'accuracy': accuracy,
            'reconstruction_loss': reconstruction_loss,
        }
        

    def solve(self, x, **kwargs):
        
        x_hat, x_a  = self.model(x)
        loss, L = self.loss(x, x_hat, x_a)

        if loss != loss:
            raise ValueError('Nan Values detected')

        if self.model.training:

            self.optim.zero_grad()
            loss.backward()
            self.optim.step()
        
        return L

    
    def step(self):

        pass


    def state_dict(self):
        
        state_dict = {
            'optim': self.optim.state_dict()
        }

        return state_dict

    def load_state_dict(self, state_dict):
        
        self.optim.load_state_dict(state_dict['optim'])
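Each solver accepts `Optim` either as a class or as a string. When it is a string, it has to be turned into a callable class. `importlib.import_module` only imports modules, so a dotted path to a class needs its last component looked up with `getattr`. The stdlib-only sketch below shows that pattern. `resolve_dotted` is a hypothetical helper, and standard-library paths are used in the demo to avoid depending on PyTorch.

```python
from importlib import import_module


def resolve_dotted(path):
    """Turn 'package.module.Name' into the object it names (hypothetical helper).

    import_module('torch.optim') would return the module rather than the class,
    so the final component is fetched with getattr."""
    module_name, _, attr = path.rpartition('.')
    if not module_name:
        raise ValueError(f"expected a dotted path, got {path!r}")
    return getattr(import_module(module_name), attr)


print(resolve_dotted('collections.OrderedDict'))
```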

================================================
FILE: neuralDX7/solvers/dx7_vae.py
================================================
import torch
from torch.nn import functional as F
from importlib import import_module
from torch.optim import AdamW
from torch.distributions.kl import kl_divergence
from torch.distributions import Normal
from agoge import AbstractSolver

from .utils import sigmoidal_annealing

class DX7VAE(AbstractSolver):
    """
    Solver used to train DX7VAE model
    """

    def __init__(self, model,
        Optim=AdamW, optim_opts=dict(lr= 1e-4),
        max_beta=0.5,
        beta_temp=1e-4,
        **kwargs):

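        if isinstance(Optim, str):
            # resolve a dotted path such as 'torch.optim.AdamW' to the optimizer class
            module_name, class_name = Optim.rsplit('.', 1)
            Optim = getattr(import_module(module_name), class_name)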

        self.optim = Optim(params=model.parameters(), **optim_opts)
        self.max_beta = max_beta
        self.model = model

        self.iter = 0
        self.beta_temp = beta_temp

    def loss(self, X, X_hat, flow):
        """
        Computes the VAE loss objective and collects some training statistics

        X - data tensor, torch.LongTensor(batch_size, num_parameters=155)
        X_hat - predicted logits, torch.FloatTensor(batch_size, num_parameters=155, max_value=128)
        flow - the namedtuple returned by TriangularSylvesterFlow
        
        for reference, the namedtuple is ('Flow', ('q_z', 'log_det', 'z_0', 'z_k', 'flow'))
        """
    
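        # Single-sample estimate of KL(q(z_k) || p(z_k)) through the flow,
        # using log q(z_k) = log q(z_0) - log_det, averaged over the batch
        # and normalised by the latent dimension
        p_z_k = Normal(0,1).log_prob(flow.z_k).sum(-1)
        q_z_0 = flow.q_z.log_prob(flow.z_0).sum(-1)
        kl = (q_z_0-p_z_k-flow.log_det).mean() / flow.z_k.shape[-1]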

        beta = sigmoidal_annealing(self.iter, self.beta_temp).item()

        reconstruction_loss = F.cross_entropy(X_hat.transpose(-1, -2), X)
        accuracy = (X_hat.argmax(-1)==X).float().mean()

        loss = reconstruction_loss + self.max_beta * beta * kl

        return loss, {
            'accuracy': accuracy,
            'reconstruction_loss': reconstruction_loss,
            'kl': kl,
            'beta': beta,
            'log_det': flow.log_det.mean(),
            'p_z_k': p_z_k.mean(),
            'q_z_0': q_z_0.mean(),
        }

    def solve(self, X, **kwargs):
        """
        Compute the loss for a batch and, when the model is in training mode,
        take a gradient step

        X - dict of model inputs; its 'X' entry is the data tensor,
            torch.LongTensor(batch_size, num_parameters=155)
        """
        
        Y = self.model(**X)
        loss, L = self.loss(**X, **Y)

        if loss != loss:
            raise ValueError('Nan Values detected')

        if self.model.training:
            self.iter += 1
            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

        return L

    def step(self):

        pass

    def state_dict(self):
        
        state_dict = {
            'optim': self.optim.state_dict(),
            'iter': self.iter
        }

        return state_dict

    def load_state_dict(self, state_dict):
        
        self.optim.load_state_dict(state_dict['optim'])
        self.iter = state_dict['iter']
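

# Illustrative sketch (hypothetical helper, not called by DX7VAE): a sanity check
# for the KL term in DX7VAE.loss. With an identity flow (z_k == z_0, log_det == 0)
# the single-sample estimate  log q(z_0) - log p(z_k) - log_det  should average to
# the closed-form KL(q || N(0, 1)). The toy posterior below is made up.
import torch
from torch.distributions import Normal
from torch.distributions.kl import kl_divergence


def flow_kl_sanity_check(n_samples=100_000, latent_dim=8, seed=0):
    torch.manual_seed(seed)
    q_z = Normal(torch.full((latent_dim,), 0.5), torch.full((latent_dim,), 0.8))
    z_0 = q_z.sample((n_samples,))               # (n_samples, latent_dim)
    z_k, log_det = z_0, torch.zeros(n_samples)   # identity flow: no volume change
    p_z_k = Normal(0., 1.).log_prob(z_k).sum(-1)
    q_z_0 = q_z.log_prob(z_0).sum(-1)
    mc_kl = (q_z_0 - p_z_k - log_det).mean()     # Monte Carlo estimate
    exact_kl = kl_divergence(q_z, Normal(0., 1.)).sum()  # closed form, summed over dims
    return mc_kl.item(), exact_kl.item()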

================================================
FILE: neuralDX7/solvers/utils.py
================================================
import torch


def sigmoidal_annealing(iter_nb, t=1e-4, s=-6):
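    """
    Sigmoidal annealing schedule that rises from 0 at iter_nb=0 towards 1.

    iter_nb - number of parameter updates completed
    t - rate: how far along the sigmoid each update moves
    s - offset: the sigmoid input at iter_nb=0 (more negative starts flatter)
    """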
    
    t, s = torch.tensor(t), torch.tensor(s).float()
    x0 = torch.sigmoid(s)
    value = (torch.sigmoid(iter_nb*t + s) - x0)/(1-x0) 

    return value
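

# Illustrative sketch (hypothetical helper, not used by the solvers): the same
# schedule in plain Python, handy for checking which beta a given iteration gets.
#   beta(n) = (sigmoid(n * t + s) - sigmoid(s)) / (1 - sigmoid(s))
# so beta(0) == 0 and beta(n) -> 1 as n grows. With the defaults t=1e-4, s=-6,
# the sigmoid's input crosses zero (beta ~ 0.5) at n = -s / t = 60000 updates.
import math


def sigmoidal_annealing_py(iter_nb, t=1e-4, s=-6.0):
    sigmoid = lambda v: 1.0 / (1.0 + math.exp(-v))
    x0 = sigmoid(s)  # value at iter_nb=0, subtracted so the schedule starts at 0
    return (sigmoid(iter_nb * t + s) - x0) / (1 - x0)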

================================================
FILE: neuralDX7/utils.py
================================================

import mido
import torch
import numpy as np
from pathlib import Path
from itertools import chain
from neuralDX7.constants import VOICE_KEYS, VOICE_PARAMETER_RANGES, MAX_VALUE, checksum, voice_struct
import bitstruct


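def mask_parameters(x, voice_keys=VOICE_KEYS, inf=1e9):
    """
    Push the logits of out-of-range values to -inf.

    x - logits, torch.FloatTensor(..., len(voice_keys), MAX_VALUE)

    For each parameter, every value above the maximum of its range in
    VOICE_PARAMETER_RANGES is filled with -inf (here -1e9), so softmax gives
    it zero probability and sampling only yields valid settings.
    """
    device = x.device
    mask_item_f = lambda value_range: torch.arange(MAX_VALUE).to(device) > max(value_range)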
    mapper = map(mask_item_f, map(VOICE_PARAMETER_RANGES.get, voice_keys))

    mask = torch.stack(list(mapper))
    
    return torch.masked_fill(x, mask, -inf)
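

# Illustrative sketch with hypothetical toy ranges (not the real
# VOICE_PARAMETER_RANGES): the same masking idea as mask_parameters, written
# against an explicit list of ranges so it can be checked in isolation. Values
# above each parameter's maximum get a logit of -1e9 (effectively -inf), so
# softmax assigns them zero probability and sampling only returns in-range values.
import torch


def mask_logits_by_range(logits, param_ranges, inf=1e9):
    """logits: (..., n_params, n_values); param_ranges: one range per parameter."""
    values = torch.arange(logits.shape[-1], device=logits.device)
    mask = torch.stack([values > max(r) for r in param_ranges])  # (n_params, n_values)
    return logits.masked_fill(mask, -inf)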


# %%


def consume_syx(path):

    path = Path(path).expanduser()
    try:
        preset = mido.read_syx_file(path.as_posix())[0]
    except (IndexError, ValueError):
        # no sysex message in the file, or data mido cannot parse
        return None
    if len(preset.data) == 0:
        return None

    def get_voice(data):
        
        unpacked = voice_struct.unpack(data)

        if not verify(unpacked, VOICE_PARAMETER_RANGES):
            return None
        
        return unpacked

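    # NOTE: header_struct, header_bytes, voice_bytes, N_VOICES, take and verify
    # are not defined or imported in this module, so iterating this generator
    # raises NameError; the projects/*/live.py scripts import consume_syx from
    # neuralDX7.constants instead.
    get_header = header_struct.unpack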
    sysex_iter = iter(preset.data)
    
    try:
        header = get_header(bytes(take(sysex_iter, len(header_bytes))))
        yield from (get_voice(bytes(take(sysex_iter, len(voice_bytes)))) for _ in range(N_VOICES))
    except RuntimeError:
        return None

def dx7_bulk_pack(voices):

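    # DX7 bulk dump header (mido adds the 0xF0 / 0xF7 framing):
    # Yamaha ID 0x43, sub-status/MIDI channel 0x00, format 0x09 (32 voices),
    # byte count 0x20 0x00 (4096 bytes)
    HEADER = 0x43, 0x00, 0x09, 0x20, 0x00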
    assert len(voices)==32
    voices_bytes = bytes()
    for voice in voices:
        voice_bytes = voice_struct.pack(dict(zip(VOICE_KEYS, voice)))
        voices_bytes += voice_bytes
    
    
    patch_checksum = [checksum(voices_bytes)]

    data = bytes(HEADER) + voices_bytes + bytes(patch_checksum)

    return mido.Message('sysex', data=data)
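

# Illustrative sketch, assuming the standard DX7 32-voice bulk dump layout and
# checksum; the real `checksum` helper lives in neuralDX7.constants and is not
# shown here, and these function names are hypothetical. The payload is 32
# packed voices of 128 bytes (4096 bytes, matching the 0x20 0x00 byte count in
# the header), and the checksum is the 7-bit two's complement of the data sum,
# so (sum(data) + checksum) & 0x7F == 0.


def dx7_checksum_sketch(data: bytes) -> int:
    return (-sum(data)) & 0x7F


def frame_bulk_dump_sketch(voices_bytes: bytes) -> bytes:
    """Header + voice data + checksum; mido adds the 0xF0 / 0xF7 framing."""
    assert len(voices_bytes) == 32 * 128, "expected 32 packed 128-byte voices"
    header = bytes([0x43, 0x00, 0x09, 0x20, 0x00])
    return header + voices_bytes + bytes([dx7_checksum_sketch(voices_bytes)])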



def generate_syx(patch_list):
    # not implemented; dx7_bulk_pack(voices) builds a 32-voice bulk sysex message
    raise NotImplementedError

================================================
FILE: projects/dx7_np/evaluate.py
================================================
# %%
from agoge import InferenceWorker
import torch
from tqdm import tqdm
from matplotlib import pyplot as plt
worker = InferenceWorker('/home/nintorac/agoge/artifacts/craggy-goldenrod-catfish_0_2020-04-28_02-22-57m8eftq1b/checkpoint_410/model.box', with_data=True)

model = worker.model
data = worker.dataset
loader = data.loaders.test

n_samples = 32

loader.batch_sampler.batch_size = n_samples
# %%
# batch = next(iter(loader))['x']

# X_a = torch.rand_like(batch.float()) > torch.linspace(0, 1, n_samples).unsqueeze(-1)

# logits = model.generate(batch, X_a)

# # %%
# from matplotlib import pyplot as plt
# plt.imshow(X_a)

# # %%
# plt.scatter(torch.arange(n_samples), logits.log_prob(batch).mean(-1))

#     # %%
# plt.imshow(logits.log_prob(batch))

# %%

from itertools import count 
from neuralDX7.utils import dx7_bulk_pack, mask_parameters
import mido
iter_X = iter(loader)
for n in range(10):
    X = next(iter_X)['x']
    # syx = dx7_bulk_pack(X.numpy().tolist())
    # mido.write_syx_file('/home/nintorac/.local/share/DigitalSuburban/Dexed/Cartridges/neuralDX7/OG.syx', [syx])

    X_d = torch.distributions.Categorical(logits=mask_parameters(torch.zeros(32, 155, 128)))

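    # condition on the last 10 parameters (the 10-character voice name) and resample the rest;
    # alternative kept for reference: X_a = torch.rand_like(X.float()) < 0.3
    X_a = torch.ones_like(X).bool()
    X_a[:,:-10] = 0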

    X = X[[0]*32]
    X[~X_a] = X_d.sample()[~X_a]

    max_to_sample = max((~X_a).sum(-1))

    # for i in tqdm(range(max_to_sample)):

    logits = model.generate(X, X_a)
    samples = logits.sample()

    has_unsampled = ~X_a.all(-1)

    batch_idxs, sample_idx = (~X_a).nonzero().t()

    X[batch_idxs, sample_idx] = samples[batch_idxs, sample_idx]
    X_a[batch_idxs, sample_idx] = 1
   

    syx = dx7_bulk_pack(X.numpy().tolist())
    mido.write_syx_file(f'/home/nintorac/.local/share/DigitalSuburban/Dexed/Cartridges/neuralDX7/np_{n}.syx', [syx])
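

# Illustrative sketch (hypothetical helper) of the fill-in step in the loop
# above: entries where X_a is True are kept as conditioning, every other entry
# is replaced by the model's sample. Scattering through (~X_a).nonzero(), as the
# loop does, is equivalent to torch.where(X_a, X, samples).
import torch


def fill_unconditioned(X, X_a, samples):
    X = X.clone()
    batch_idxs, sample_idx = (~X_a).nonzero().t()  # coordinates of unconditioned entries
    X[batch_idxs, sample_idx] = samples[batch_idxs, sample_idx]
    return X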

# # %%


# %%


================================================
FILE: projects/dx7_np/experiment.py
================================================
#%%
from os import environ
environ['MLFLOW_TRACKING_URI'] = 'http://tracking.olympus.nintorac.dev:9001/'

from neuralDX7.constants import N_PARAMS, MAX_VALUE
from agoge.utils import trial_name_creator
from neuralDX7 import DEFAULTS
from agoge import TrainWorker as Worker
from ray import tune
from neuralDX7.models import DX7NeuralProcess as Model
from neuralDX7.solvers import DX7NeuralProcess as Solver
from neuralDX7.datasets import DX7SysexDataset as Dataset

def config(experiment_name, trial_name, 
        n_heads=8, n_features=32, 
        batch_size=16, data_size=1.,
        latent_dim=8,
        **kwargs):
    


    data_handler = {
        'Dataset': Dataset,
        'dataset_opts': {
            'data_size': data_size
        },
        'loader_opts': {
            'batch_size': batch_size,
        },
    }

    ### MODEL FEATURES
    layer_features = n_heads * n_features

    head_features = layer_features // n_heads

    attention = {
        'n_features': layer_features,
        'n_hidden': head_features,
        'n_heads': n_heads
    }
    
    attention_layer = {
        'attention': attention,
        'features': layer_features,
        'hidden_dim': layer_features * 3
    }

    encoder = {
        'features': layer_features,
        'attention_layer': attention_layer,
        'max_len': N_PARAMS,
        'n_layers': 1
    }
    

    model = {
        'Model': Model,
        'features': layer_features,
        'latent_dim': latent_dim,
        'encoder': encoder,
        'decoder': {
            'c_features': latent_dim,
            'features': layer_features,
            'attention_layer': attention_layer,
            'max_len': N_PARAMS,
            'n_layers': 1
        },
        'deterministic_path_drop_rate': 0.8
    }

    solver = {
        'Solver': Solver,
        'beta_temp': 1e-4
    }

    tracker = {
        'metrics': ['reconstruction_loss', 'accuracy', 'kl', 'beta', 'entropy'],
        'experiment_name': experiment_name,
        'trial_name': trial_name
    }

    return {
        'data_handler': data_handler,
        'model': model,
        'solver': solver,
        'tracker': tracker,
    }

if __name__=='__main__':
    # from ray import ray
    import sys
    postfix = sys.argv[1] if len(sys.argv)==2 else ''
    # ray.init()
    # from ray.tune.utils import validate_save_restore
    # validate_save_restore(Worker)
    # client = MlflowClient(tracking_uri='localhost:5000')
    experiment_name = f'dx7-np-{postfix}'#+experiment_name_creator()
    # experiment_id = client.create_experiment(experiment_name)


    experiment_metrics = dict(metric="loss/accuracy", mode="max")

    tune.run(Worker, 
    config={
        'config_generator': config,
        'experiment_name': experiment_name,
        'points_per_epoch': 10
    },
    trial_name_creator=trial_name_creator,
    resources_per_trial={
        'gpu': 1
    },
    checkpoint_freq=2,
    checkpoint_at_end=True,
    keep_checkpoints_num=1,
    # search_alg=bohb_search, 
    # scheduler=bohb_hyperband,
    num_samples=1,
    verbose=0,
    local_dir=DEFAULTS['ARTIFACTS_ROOT']
    # webui_host='127.0.0.1' ## suppresses an error
        # stop={'loss/loss': 0}
    )
# points_per_epoch

================================================
FILE: projects/dx7_np/features.py
================================================
# %%
from agoge import InferenceWorker
import torch
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt
worker = InferenceWorker('/home/nintorac/agoge/artifacts/bluesy-chestnut-forest_0_2020-04-30_11-11-00x02v59be/checkpoint_220/model.box', with_data=True)

model = worker.model
data = worker.dataset
loader = data.loaders.test

n_samples = 32
n_latents = 8
loader.batch_sampler.batch_size = n_samples
features_all = []
features_half = []
for x in map(lambda x: x['x'], tqdm(loader)):
    q = model.features(x, torch.ones_like(x.float()).bool())
    features_all += [(q.mean.numpy(), q.stddev.numpy())]
    # features_half += [model.features(x, torch.rand_like(x.float())>torch.linspace(0, 1, 32).unsqueeze(-1)).mean.numpy()]

# for item in loader:
# %%


================================================
FILE: projects/dx7_np/interpoalte.py
================================================
# %%
from agoge import InferenceWorker
import torch
from tqdm import tqdm
from matplotlib import pyplot as plt
worker = InferenceWorker('/home/nintorac/agoge/artifacts/craggy-goldenrod-catfish_0_2020-04-28_02-22-57m8eftq1b/checkpoint_410/model.box', with_data=True)

model = worker.model
data = worker.dataset
loader = data.loaders.test

n_samples = 32
n_latents = 8
loader.batch_sampler.batch_size = n_samples
# %%
# batch = next(iter(loader))['x']

# X_a = torch.rand_like(batch.float()) > torch.linspace(0, 1, n_samples).unsqueeze(-1)

# logits = model.generate(batch, X_a)

# # %%
# from matplotlib import pyplot as plt
# plt.imshow(X_a)

# # %%
# plt.scatter(torch.arange(n_samples), logits.log_prob(batch).mean(-1))

#     # %%
# plt.imshow(logits.log_prob(batch))

# %%

from itertools import count 
from neuralDX7.utils import dx7_bulk_pack, mask_parameters
import mido
iter_X = iter(loader)
X_og = next(iter_X)['x']
for n in range(n_latents):
    # syx = dx7_bulk_pack(X.numpy().tolist())
    # mido.write_syx_file('/home/nintorac/.local/share/DigitalSuburban/Dexed/Cartridges/neuralDX7/OG.syx', [syx])

    X_l = X_og[[0]].clone()

    X_a = torch.ones(1,155).bool()
    X_a[:,0] = 0
    q_l = model.features(X_l, X_a)

    z = q_l.mean[:,[0]*155][[0]*32]
    z[:,:,n] = torch.linspace(-4, 4, 32).unsqueeze(-1)

    X = model.generate_z(z).sample()
    # label each step by overwriting the last name character with ASCII 48 + i ('0', '1', ...)
    X[...,-1] = 48 + torch.arange(32)
    syx = dx7_bulk_pack(X.numpy().tolist())
    mido.write_syx_file(f'/home/nintorac/.local/share/DigitalSuburban/Dexed/Cartridges/neuralDX7/np_interp_{n}.syx', [syx])
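

# Illustrative sketch (hypothetical helper) of how the loop above builds its
# latent traversal: one latent code of shape (1, seq_len, latent_dim) is
# repeated for every step, then dimension `dim` is overwritten with an evenly
# spaced sweep, so step i of the batch decodes with z[..., dim] = linspace(lo, hi)[i].
import torch


def latent_sweep(z_base, dim, n_steps=32, lo=-4.0, hi=4.0):
    """z_base: (1, seq_len, latent_dim) -> (n_steps, seq_len, latent_dim)."""
    z = z_base.repeat(n_steps, 1, 1)  # copy, so z_base is left untouched
    z[:, :, dim] = torch.linspace(lo, hi, n_steps).unsqueeze(-1)
    return z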

# # %%


# %%


================================================
FILE: projects/dx7_np/live.py
================================================
# %%
from agoge import InferenceWorker
import threading
import torch
import mido
import time
import numpy as np
from tqdm import tqdm
import jack
from matplotlib import pyplot as plt
from itertools import cycle
worker = InferenceWorker('/home/nintorac/agoge/artifacts/squeaky-green-mist_0_2020-04-30_12-50-391i661hyj/checkpoint_100/model.box', with_data=True)

model = worker.model
data = worker.dataset
loader = data.loaders.test

n_samples = 32
n_latents = 8
loader.batch_sampler.batch_size = n_samples


from uuid import uuid4
uuid = lambda: uuid4().hex


client = jack.Client('DX7Parameteriser')
port = client.midi_outports.register('output')
inport = client.midi_inports.register('input')
event = threading.Event()
fs = None  # sampling rate
offset = 0
from neuralDX7.constants import DX7Single, consume_syx
import torch

name = torch.tensor([i for i in "horns     ".encode('ascii')])

X = torch.zeros(1, 155).long()
X[:,-(len(name)):] = name
X_a = torch.zeros(1, 155).bool()
X_a[:,-(len(name)):] = 1
q = model.features(X, X_a)

X_a[:,-29:] = 1
# X_a = X_a & 0
iter_X = iter(loader)
X_a
syx = list(consume_syx('/home/nintorac/.local/share/DigitalSuburban/Dexed/Cartridges/SynprezFM/SynprezFM_01.syx'))
syx = torch.from_numpy(np.array([list(i.values()) for i in syx]))
# m1, m2 = q.mean[:,0]
# syx_iter = cycle(syx)
def slerp(val, low, high):
    omega = np.arccos(np.clip(np.dot(low/np.linalg.norm(low), high/np.linalg.norm(high)), -1, 1))
    so = np.sin(omega)
    if so == 0:
        return (1.0-val) * low + val * high # L'Hopital's rule/LERP
    return np.sin((1.0-val)*omega) / so * low + np.sin(val*omega) / so * high
x_iter = cycle([*torch.linspace(0, 1, 7)[1:],  *torch.linspace(1, 0, 7)[1:]])
#%%
i=0

mu, std = \
(np.array([ 4.1554513,  4.1125965, -1.9699959,  2.8919716, -6.056072 ,
        -2.407577 , -5.1152377,  2.811712 ], dtype=np.float32),
 np.array([0.3197436 , 0.23426053, 0.25944906, 0.17878139, 0.3019972 ,
        0.34080952, 0.32115632, 0.34698236], dtype=np.float32))

vals = torch.from_numpy(mu + np.linspace(-3, 3, 128)[:,None] * std).float()
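
# Illustrative sketch (hypothetical helper, toy inputs only) of the knob-to-latent
# lookup in process(): `vals` starts as a (128, n_latents) table whose row v is
# mu + linspace(-3, 3, 128)[v] * std, and gathering along dim 0 with a
# (1, n_latents) LongTensor of CC values picks row latent[0, j] of column j, so
# each encoder knob (0-127) moves one latent across mu +/- 3 std.
import numpy as np
import torch


def cc_to_latent(cc_values, mu, std, n_steps=128, spread=3.0):
    table = torch.from_numpy(mu + np.linspace(-spread, spread, n_steps)[:, None] * std).float()
    return table.gather(0, cc_values)  # out[0, j] = table[cc_values[0, j], j]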

controller_map = {}

latent = torch.full((1, 8), 64).long()
patch_no = 0

from neuralDX7.utils import mask_parameters
@client.set_process_callback
def process(frames):
    global offset, i
    global msg
    global syx_iter
    global controller_map, patch_no, vals, latent
    port.clear_buffer()
    needs_update = False
    X = syx[[patch_no]]
    a = X_a

    for offset, data in inport.incoming_midi_events():
        msg = mido.parse(bytes(data))

        if msg.type=='note_on':
            # print(msg.__dir__())
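            patch_no = msg.note%32
            print(f"patch set to {patch_no}")
            needs_update = True
            X = syx[[patch_no]]  # use the newly selected patch from here on

            a = X_a[[0]]
            q = model.features(X, torch.ones_like(a))  # encode with every parameter observed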
            vals = q.mean + torch.linspace(-4, 4, 128)[:,None] * q.stddev 
        
        if msg.type!='control_change':
            continue


        if msg.control not in controller_map:
            if len(controller_map) == 8:
                continue
            print(f"latent {len(controller_map)} set to encoder {msg.control}")
            controller_map[msg.control] = len(controller_map)
        l_i = list(controller_map).index(msg.control)
        print(f'Latent: {latent}')
        latent[:, controller_map[msg.control]] =  msg.value
        needs_update = True
        
        # print("{0}: 0x{1}".format(client.last_frame_time + offset,
        #                           binascii.hexlify(data).decode()))
    # print(time.time()-offset)
    inport.clear_buffer()
    if (needs_update):
        offset = time.time()

        # X = next(iter_X)['x'][[0]]
        # X_d = torch.distributions.Categorical(logits=mask_parameters(torch.zeros(1, 155, 128)))

        # X_a = torch.rand_like(X.float()) < 0.3
        # X_a = torch.ones_like(X).bool()
        # X_a[:,:-10] = 0
        


        # X[~X_a] = X_d.sample()[~X_a]

        # max_to_sample = max((~X_a).sum(-1))
        # # X = X[[0]*1]

        # # for i in tqdm(range(max_to_sample)):

        # logits = model.generate(X, X_a)
        # samples = logits.sample()
            
        # batch_idxs, sample_idx = (~X_a).nonzero().t()

        # X[batch_idxs, sample_idx] = samples[batch_idxs, sample_idx]
        # X_a[batch_idxs, sample_idx] = 1

        # z = slerp(next(x_iter), m1, m2).unsqueeze(-2)[...,[0]*155,:].unsqueeze(0)
        # z = q.mean
        # z = torch.from_numpy(latent).float()
        # print(q.stddev)
        # z = torch.randn_like(z)
        # z = z.mean().unsqueeze
        # z = q.mean# + torch.randn(q.mean.shape) * 0.1 + q.stddev
        # val = torch.from_numpy(vals).float()
        # print(latent)
        # print(time.time())
        z = vals.gather(0, latent)
        # print(time.time())
        # print(a)
        msg = model.generate_z(X, a, z, t=0.001).sample()
        # print(time.time())
        msg[a] = X[a]
        # msg = X if i%2 else msg
        # msg[:,-(len(name)):] = name

        # msg = DX7Single.to_syx([list(next(syx_iter).values())])
        msg = DX7Single.to_syx(msg.numpy().tolist())
        # import mido
        # print([int(i, 2)mido.Message('control_change', control=123).bytes()])
        port.write_midi_event(0, msg.bytes())
        # CC 123 is "All Notes Off"; it is sent several times after the patch dump
        port.write_midi_event(1, mido.Message('control_change', control=123).bytes())
        port.write_midi_event(2, mido.Message('control_change', control=123).bytes())
        port.write_midi_event(3, mido.Message('control_change', control=123).bytes())
        port.write_midi_event(4, mido.Message('control_change', control=123).bytes())

    


@client.set_samplerate_callback
def samplerate(samplerate):
    global fs
    fs = samplerate


@client.set_shutdown_callback
def shutdown(status, reason):
    print('JACK shutdown:', reason, status)
    event.set()

capture_port = 'a2j:Arturia BeatStep [24] (capture): Arturia BeatStep MIDI 1'
playback_port = 'Carla:Dexed:events-in' 

with client:
    # print(client.get_ports())
    offset = time.time()
    # if connect_to:
    port.connect(playback_port)
    inport.connect(capture_port)

    # print('Playing', repr(filename), '... press Ctrl+C to stop')
    try:
        event.wait()
    except KeyboardInterrupt:
        print('\nInterrupted by user')


# %%


================================================
FILE: projects/dx7_nsp/evaluate.py
================================================
# %%
from agoge import InferenceWorker
import torch
from tqdm import tqdm
from matplotlib import pyplot as plt
worker = InferenceWorker('/home/nintorac/agoge/artifacts/craggy-goldenrod-catfish_0_2020-04-28_02-22-57m8eftq1b/checkpoint_410/model.box', with_data=True)

model = worker.model
data = worker.dataset
loader = data.loaders.test

n_samples = 32

loader.batch_sampler.batch_size = n_samples
# %%
# batch = next(iter(loader))['x']

# X_a = torch.rand_like(batch.float()) > torch.linspace(0, 1, n_samples).unsqueeze(-1)

# logits = model.generate(batch, X_a)

# # %%
# from matplotlib import pyplot as plt
# plt.imshow(X_a)

# # %%
# plt.scatter(torch.arange(n_samples), logits.log_prob(batch).mean(-1))

#     # %%
# plt.imshow(logits.log_prob(batch))

# %%

from itertools import count 
from neuralDX7.utils import dx7_bulk_pack, mask_parameters
import mido
iter_X = iter(loader)
for n in range(10):
    X = next(iter_X)['x']
    # syx = dx7_bulk_pack(X.numpy().tolist())
    # mido.write_syx_file('/home/nintorac/.local/share/DigitalSuburban/Dexed/Cartridges/neuralDX7/OG.syx', [syx])

    X_d = torch.distributions.Categorical(logits=mask_parameters(torch.zeros(32, 155, 128)))

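    # condition on the last 10 parameters (the 10-character voice name) and resample the rest;
    # alternative kept for reference: X_a = torch.rand_like(X.float()) < 0.3
    X_a = torch.ones_like(X).bool()
    X_a[:,:-10] = 0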

    X = X[[0]*32]
    X[~X_a] = X_d.sample()[~X_a]

    max_to_sample = max((~X_a).sum(-1))

    # for i in tqdm(range(max_to_sample)):

    logits = model.generate(X, X_a)
    samples = logits.sample()

    has_unsampled = ~X_a.all(-1)

    batch_idxs, sample_idx = (~X_a).nonzero().t()

    X[batch_idxs, sample_idx] = samples[batch_idxs, sample_idx]
    X_a[batch_idxs, sample_idx] = 1
   

    syx = dx7_bulk_pack(X.numpy().tolist())
    mido.write_syx_file(f'/home/nintorac/.local/share/DigitalSuburban/Dexed/Cartridges/neuralDX7/np_{n}.syx', [syx])

# # %%


# %%


================================================
FILE: projects/dx7_nsp/experiment.py
================================================
#%%
from os import environ
environ['MLFLOW_TRACKING_URI'] = 'http://tracking.olympus.nintorac.dev:9001/'

from neuralDX7.constants import N_PARAMS, MAX_VALUE
from agoge.utils import trial_name_creator
from neuralDX7 import DEFAULTS
from agoge import TrainWorker as Worker
from ray import tune
from neuralDX7.models import DX7NeuralSylvesterProcess as Model
from neuralDX7.solvers import DX7NeuralSylvesterProcess as Solver
from neuralDX7.datasets import DX7SysexDataset as Dataset

def config(experiment_name, trial_name, 
        n_heads=8, n_features=32, 
        batch_size=16, data_size=1.,
        latent_dim=8, num_flows=1,
        **kwargs):
    


    data_handler = {
        'Dataset': Dataset,
        'dataset_opts': {
            'data_size': data_size
        },
        'loader_opts': {
            'batch_size': batch_size,
        },
    }

    ### MODEL FEATURES
    layer_features = n_heads * n_features

    head_features = layer_features // n_heads

    attention = {
        'n_features': layer_features,
        'n_hidden': head_features,
        'n_heads': n_heads
    }
    
    attention_layer = {
        'attention': attention,
        'features': layer_features,
        'hidden_dim': layer_features * 3
    }

    encoder = {
        'features': layer_features,
        'attention_layer': attention_layer,
        'max_len': N_PARAMS,
        'n_layers': 1
    }
    

    model = {
        'Model': Model,
        'features': layer_features,
        'latent_dim': latent_dim,
        'encoder': encoder,
        'decoder': {
            'c_features': latent_dim,
            'features': layer_features,
            'attention_layer': attention_layer,
            'max_len': N_PARAMS,
            'n_layers': 1
        },
        'num_flows': num_flows,
        'deterministic_path_drop_rate': 0.8
    }

    solver = {
        'Solver': Solver,
        'beta_temp': 1e-3,
        'max_beta': 1
    }

    tracker = {
        'metrics': [
            'reconstruction_loss', 
            'accuracy', 
            'kl', 
            'beta', 
            'q_log_det', 
            'q_z',
            'p_z',
        ],
        'experiment_name': experiment_name,
        'trial_name': trial_name
    }

    return {
        'data_handler': data_handler,
        'model': model,
        'solver': solver,
        'tracker': tracker,
    }

if __name__=='__main__':
    # from ray import ray
    import sys
    postfix = sys.argv[1] if len(sys.argv)==2 else ''
    # ray.init()
    # from ray.tune.utils import validate_save_restore
    # validate_save_restore(Worker)
    # client = MlflowClient(tracking_uri='localhost:5000')
    experiment_name = f'dx7-nsp-00'#+experiment_name_creator()
    # experiment_id = client.create_experiment(experiment_name)


    experiment_metrics = dict(metric="loss/accuracy", mode="max")

    tune.run(Worker, 
    config={
        'config_generator': config,
        'experiment_name': experiment_name,
        'points_per_epoch': 10
    },
    trial_name_creator=trial_name_creator,
    resources_per_trial={
        'gpu': 1
    },
    checkpoint_freq=2,
    checkpoint_at_end=True,
    keep_checkpoints_num=1,
    # search_alg=bohb_search, 
    # scheduler=bohb_hyperband,
    num_samples=1,
    verbose=0,
    local_dir=DEFAULTS['ARTIFACTS_ROOT']
    # webui_host='127.0.0.1' ## suppresses an error
        # stop={'loss/loss': 0}
    )
# points_per_epoch

================================================
FILE: projects/dx7_nsp/features.py
================================================
# %%
from agoge import InferenceWorker
import torch
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt
worker = InferenceWorker('~/agoge/artifacts/dx7-nsp/leaky-burgundy-coati.box', with_data=True)

model = worker.model
data = worker.dataset
loader = data.loaders.test

n_samples = 32
n_latents = 8
loader.batch_sampler.batch_size = n_samples
features_all = []
features_half = []
for x in map(lambda x: x['X'], tqdm(loader)):
    q = model.features(x, torch.ones_like(x.float()).bool()).q_z
    features_all += [(q.mean.numpy(), q.stddev.numpy())]
    # features_half += [model.features(x, torch.rand_like(x.float())>torch.linspace(0, 1, 32).unsqueeze(-1)).mean.numpy()]

mus, stds = map(np.concatenate, zip(*features_all))  # posterior means and standard deviations

# for item in loader:
# %%


================================================
FILE: projects/dx7_nsp/interpoalte.py
================================================
# %%
from agoge import InferenceWorker
import torch
from tqdm import tqdm
from matplotlib import pyplot as plt
worker = InferenceWorker('/home/nintorac/agoge/artifacts/craggy-goldenrod-catfish_0_2020-04-28_02-22-57m8eftq1b/checkpoint_410/model.box', with_data=True)

model = worker.model
data = worker.dataset
loader = data.loaders.test

n_samples = 32
n_latents = 8
loader.batch_sampler.batch_size = n_samples
# %%
# batch = next(iter(loader))['x']

# X_a = torch.rand_like(batch.float()) > torch.linspace(0, 1, n_samples).unsqueeze(-1)

# logits = model.generate(batch, X_a)

# # %%
# from matplotlib import pyplot as plt
# plt.imshow(X_a)

# # %%
# plt.scatter(torch.arange(n_samples), logits.log_prob(batch).mean(-1))

#     # %%
# plt.imshow(logits.log_prob(batch))

# %%

from itertools import count 
from neuralDX7.utils import dx7_bulk_pack, mask_parameters
import mido
iter_X = iter(loader)
X_og = next(iter_X)['x']
for n in range(n_latents):
    # syx = dx7_bulk_pack(X.numpy().tolist())
    # mido.write_syx_file('/home/nintorac/.local/share/DigitalSuburban/Dexed/Cartridges/neuralDX7/OG.syx', [syx])

    X_l = X_og[[0]].clone()

    X_a = torch.ones(1,155).bool()
    X_a[:,0] = 0
    q_l = model.features(X_l, X_a)

    z = q_l.mean[:,[0]*155][[0]*32]
    z[:,:,n] = torch.linspace(-4, 4, 32).unsqueeze(-1)

    X = model.generate_z(z).sample()
    # label each step by overwriting the last name character with ASCII 48 + i ('0', '1', ...)
    X[...,-1] = 48 + torch.arange(32)
    syx = dx7_bulk_pack(X.numpy().tolist())
    mido.write_syx_file(f'/home/nintorac/.local/share/DigitalSuburban/Dexed/Cartridges/neuralDX7/np_interp_{n}.syx', [syx])

# # %%


# %%


================================================
FILE: projects/dx7_nsp/live.py
================================================
# %%
from agoge import InferenceWorker
import threading
import torch
import mido
import time
import numpy as np
from tqdm import tqdm
import jack
from matplotlib import pyplot as plt
from itertools import cycle
worker = InferenceWorker('~/agoge/artifacts/dx7-nsp/leaky-burgundy-coati.box', with_data=True)

model = worker.model
data = worker.dataset
loader = data.loaders.test

n_samples = 32
n_latents = 8
loader.batch_sampler.batch_size = n_samples


from uuid import uuid4
uuid = lambda: uuid4().hex


client = jack.Client('DX7Parameteriser')
port = client.midi_outports.register('output')
inport = client.midi_inports.register('input')
event = threading.Event()
fs = None  # sampling rate
offset = 0
from neuralDX7.constants import DX7Single, consume_syx
import torch

name = torch.tensor([i for i in "horns     ".encode('ascii')])

X = torch.zeros(1, 155).long()
X[:,-(len(name)):] = name
X_a = torch.zeros(1, 155).bool()
X_a[:,-(len(name)):] = 1
q = model.features(X, X_a)
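# %%
# DX7 voice names are 10 ASCII characters, stored as the last 10 of the 155
# voice parameters. That is why `name` above is padded with spaces and written
# into X[:, -10:]. `encode_voice_name` is a hypothetical helper (not part of
# neuralDX7) that pads or truncates any string to that width first.
import torch

def encode_voice_name(text, width=10):
    """Return a (width,) long tensor of ASCII codes, space-padded or truncated."""
    return torch.tensor(list(text[:width].ljust(width).encode('ascii')))

print(encode_voice_name("horns"))          # ASCII codes for 'horns     '
print(encode_voice_name("brass section"))  # truncated to 'brass sect'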

X_a[:,-29:] = 1
# X_a = X_a & 0
iter_X = iter(loader)
X_a
syx = list(consume_syx('/home/nintorac/.local/share/DigitalSuburban/Dexed/Cartridges/SynprezFM/SynprezFM_01.syx'))
syx = torch.from_numpy(np.array([list(i.values()) for i in syx]))
# m1, m2 = q.mean[:,0]
# syx_iter = cycle(syx)
def slerp(val, low, high):
    omega = np.arccos(np.clip(np.dot(low/np.linalg.norm(low), high/np.linalg.norm(high)), -1, 1))
    so = np.sin(omega)
    if so == 0:
        return (1.0-val) * low + val * high # L'Hopital's rule/LERP
    return np.sin((1.0-val)*omega) / so * low + np.sin(val*omega) / so * high
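# %%
# Sanity check for the slerp helper above (a standalone numpy sketch). Spherical
# interpolation follows the great circle between two vectors, so interpolating
# between unit vectors stays on the unit sphere, whereas plain linear
# interpolation cuts through the interior. The function is re-declared under an
# illustrative name (`slerp_check`) so this cell runs on its own.
import numpy as np

def slerp_check(val, low, high):
    # angle between the normalised endpoints
    omega = np.arccos(np.clip(np.dot(low / np.linalg.norm(low), high / np.linalg.norm(high)), -1, 1))
    so = np.sin(omega)
    if so == 0:
        # parallel endpoints: fall back to linear interpolation
        return (1.0 - val) * low + val * high
    return np.sin((1.0 - val) * omega) / so * low + np.sin(val * omega) / so * high

e1, e2 = np.array([1.0, 0.0]), np.array([0.0, 1.0])
mid_slerp = slerp_check(0.5, e1, e2)
mid_lerp = 0.5 * e1 + 0.5 * e2
print(np.linalg.norm(mid_slerp), np.linalg.norm(mid_lerp))  # ~1.0 vs ~0.707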
x_iter = cycle([*torch.linspace(0, 1, 7)[1:],  *torch.linspace(1, 0, 7)[1:]])
#%%
i=0

mu, std = \
(np.array([ 4.1554513,  4.1125965, -1.9699959,  2.8919716, -6.056072 ,
        -2.407577 , -5.1152377,  2.811712 ], dtype=np.float32),
 np.array([0.3197436 , 0.23426053, 0.25944906, 0.17878139, 0.3019972 ,
        0.34080952, 0.32115632, 0.34698236], dtype=np.float32))

vals = torch.from_numpy(mu + np.linspace(-3, 3, 128)[:,None] * std).float()
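# %%
# How a MIDI controller value becomes a latent coordinate (illustrative sketch).
# `vals` is a 128-row lookup table: row k holds mu + s_k * std, where s_k runs
# linearly from -3 to +3 over the 128 possible CC values. (In `process` it is
# rebuilt over +/-4 std from the encoder's q once a note selects a patch.)
# `latent` stores one CC value per latent dimension, and gather along dim 0
# picks row latent[0, j] for every column j. The mu/std here are placeholders.
import numpy as np
import torch

mu_demo = np.zeros(8, dtype=np.float32)
std_demo = np.ones(8, dtype=np.float32)
table = torch.from_numpy(mu_demo + np.linspace(-3, 3, 128)[:, None] * std_demo).float()  # (128, 8)

cc_values = torch.full((1, 8), 64).long()
cc_values[0, 0] = 0      # knob 0 fully down -> mu - 3*std
cc_values[0, 1] = 127    # knob 1 fully up   -> mu + 3*std
z_demo = table.gather(0, cc_values)  # (1, 8): one latent value per knob
print(z_demo)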

controller_map = {}

latent = torch.full((1, 8), 64).long()
patch_no = 0
flow=None
from neuralDX7.utils import mask_parameters
@client.set_process_callback
def process(frames):
    global offset, i
    global msg
    global syx_iter
    global controller_map, patch_no, vals, latent, flow
    port.clear_buffer()
    needs_update = False
    X = syx[[patch_no]]
    a = X_a

    for offset, data in inport.incoming_midi_events():
        msg = mido.parse(bytes(data))

        if msg.type=='note_on':
            # print(msg.__dir__())
            patch_no = msg.note%32
            print(f"patch set to {patch_no}")
            needs_update = True

            X = syx[[patch_no]]  # re-select the patch that was just chosen
            a = X_a[[0]]
            flow = model.features(X, a|1)
            q = flow.q_z
            vals = q.mean + torch.linspace(-4, 4, 128)[:,None] * q.stddev 
        
        if msg.type!='control_change':
            continue


        if msg.control not in controller_map:
            if len(controller_map) == 8:
                continue
            print(f"latent {len(controller_map)} set to encoder {msg.control}")
            controller_map[msg.control] = len(controller_map)
        l_i = list(controller_map).index(msg.control)
        print(f'Latent: {latent}')
        latent[:, controller_map[msg.control]] =  msg.value
        needs_update = True
        
        # print("{0}: 0x{1}".format(client.last_frame_time + offset,
        #                           binascii.hexlify(data).decode()))
    # print(time.time()-offset)
    inport.clear_buffer()
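    # generation needs `flow`, which only exists once a note_on has selected a patch
    if needs_update and flow is not None: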
        offset = time.time()

        # X = next(iter_X)['x'][[0]]
        # X_d = torch.distributions.Categorical(logits=mask_parameters(torch.zeros(1, 155, 128)))

        # X_a = torch.rand_like(X.float()) < 0.3
        # X_a = torch.ones_like(X).bool()
        # X_a[:,:-10] = 0
        


        # X[~X_a] = X_d.sample()[~X_a]

        # max_to_sample = max((~X_a).sum(-1))
        # # X = X[[0]*1]

        # # for i in tqdm(range(max_to_sample)):

        # logits = model.generate(X, X_a)
        # samples = logits.sample()
            
        # batch_idxs, sample_idx = (~X_a).nonzero().t()

        # X[batch_idxs, sample_idx] = samples[batch_idxs, sample_idx]
        # X_a[batch_idxs, sample_idx] = 1

        # z = slerp(next(x_iter), m1, m2).unsqueeze(-2)[...,[0]*155,:].unsqueeze(0)
        # z = q.mean
        # z = torch.from_numpy(latent).float()
        # print(q.stddev)
        # z = torch.randn_like(z)
        # z = z.mean().unsqueeze
        # z = q.mean# + torch.randn(q.mean.shape) * 0.1 + q.stddev
        # val = torch.from_numpy(vals).float()
        # print(latent)
        # print(time.time())
        z, _ = flow.flow(vals.gather(0, latent))
        # print(time.time())
        # print(a)
        msg = model.generate_z(X, a, z, t=0.001).sample()
        # print(time.time())
        msg[a] = X[a]
        # msg = X if i%2 else msg
        # msg[:,-(len(name)):] = name

        # msg = DX7Single.to_syx([list(next(syx_iter).values())])
        msg = DX7Single.to_syx(msg.numpy().tolist())
        # import mido
        # print([int(i, 2)mido.Message('control_change', control=123).bytes()])
        port.write_midi_event(0, msg.bytes())
        # send CC 123 (All Notes Off) at frame offsets 1-4, after the new voice
        all_notes_off = mido.Message('control_change', control=123).bytes()
        for frame_offset in range(1, 5):
            port.write_midi_event(frame_offset, all_notes_off)

    


@client.set_samplerate_callback
def samplerate(samplerate):
    global fs
    fs = samplerate


@client.set_shutdown_callback
def shutdown(status, reason):
    print('JACK shutdown:', reason, status)
    event.set()
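# %%
# The controller-mapping logic from `process`, isolated as a pure function.
# `assign_controller` is a hypothetical helper, not part of neuralDX7. The first
# eight distinct CC numbers seen are bound, in arrival order, to latent indices
# 0..7; any further CC numbers are ignored.
def assign_controller(cc_map, control, max_latents=8):
    """Return the latent index for `control`, or None if all slots are taken."""
    if control not in cc_map:
        if len(cc_map) == max_latents:
            return None
        cc_map[control] = len(cc_map)
    return cc_map[control]

demo_map = {}
print([assign_controller(demo_map, cc) for cc in [10, 74, 10, 71]])  # [0, 1, 0, 2]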

capture_port = 'a2j:Arturia BeatStep [24] (capture): Arturia BeatStep MIDI 1'
playback_port = 'Carla:Dexed:events-in' 

with client:
    # print(client.get_ports())
    offset = time.time()
    # if connect_to:
    port.connect(playback_port)
    inport.connect(capture_port)

    # print('Playing', repr(filename), '... press Ctrl+C to stop')
    try:
        event.wait()
    except KeyboardInterrupt:
        print('\nInterrupted by user')


# %%


================================================
FILE: projects/dx7_patch_neural_process/evaluate.py
================================================
# %%
from agoge import InferenceWorker
import torch
from tqdm import tqdm
from matplotlib import pyplot as plt
worker = InferenceWorker('/home/nintorac/agoge/artifacts/Worker/messy-firebrick-barracuda_0_2020-04-14_08-59-40ryp93n44/checkpoint_4/model.box', with_data=True)

model = worker.model
data = worker.dataset
loader = data.loaders.train

n_samples = 32

loader.batch_sampler.batch_size = n_samples
# %%
# batch = next(iter(loader))['x']

# X_a = torch.rand_like(batch.float()) > torch.linspace(0, 1, n_samples).unsqueeze(-1)

# logits = model.generate(batch, X_a)

# # %%
# from matplotlib import pyplot as plt
# plt.imshow(X_a)

# # %%
# plt.scatter(torch.arange(n_samples), logits.log_prob(batch).mean(-1))

#     # %%
# plt.imshow(logits.log_prob(batch))

# %%

from itertools import count 
from neuralDX7.utils import dx7_bulk_pack, mask_parameters
import mido
iter_X = iter(loader)
for n in range(10):
    X = next(iter_X)['x']
    # syx = dx7_bulk_pack(X.numpy().tolist())
    # mido.write_syx_file('/home/nintorac/.local/share/DigitalSuburban/Dexed/Cartridges/neuralDX7/OG.syx', [syx])

    X_d = torch.distributions.Categorical(logits=mask_parameters(torch.zeros(32, 155, 128)))

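    # Alternative: reveal a random ~30% of the parameters instead
    # X_a = torch.rand_like(X.float()) < 0.3
    # Condition only on the last 10 parameters (the voice name); everything else is sampled
    X_a = torch.ones_like(X).bool()
    X_a[:,:-10] = 0

    X = X[[0]*32]  # replicate the first voice across all 32 slots
    X[~X_a] = X_d.sample()[~X_a]

    max_to_sample = int((~X_a).sum(-1).max())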

    for i in tqdm(range(max_to_sample)):

        logits = model.generate(X, X_a)
        samples = logits.sample()

        has_unsampled = ~X_a.all(-1)

        sample_idx = (torch.rand_like(X.float()) * (~X_a).float()).argmax(-1)[has_unsampled]
        batch_idxs = torch.arange(X.shape[0])[has_unsampled]


        X[batch_idxs, sample_idx] = samples[batch_idxs, sample_idx]
        X_a[batch_idxs, sample_idx] = 1
        # X_a = X_a | new_mask

        if X_a.all():
            break

    syx = dx7_bulk_pack(X.numpy().tolist())
    mido.write_syx_file(f'/home/nintorac/.local/share/DigitalSuburban/Dexed/Cartridges/neuralDX7/gen_{n}.syx', [syx])
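# %%
# A toy, model-free sketch of the iterative sampling loop above. Each step draws
# a full sample from the model's categorical output but commits only ONE randomly
# chosen still-unknown position per row, then re-conditions on the grown mask.
# `toy_generate` is a hypothetical stand-in for model.generate: it ignores its
# inputs and returns uniform logits over 128 values.
import torch

def toy_generate(x, known):
    return torch.distributions.Categorical(logits=torch.zeros(*x.shape, 128))

torch.manual_seed(0)
x = torch.randint(0, 128, (4, 12))
known = torch.zeros_like(x, dtype=torch.bool)
known[:, -3:] = True                      # the last three parameters are given
x_given = x.clone()

while not known.all():
    samples = toy_generate(x, known).sample()
    rows = (~known.all(-1)).nonzero(as_tuple=True)[0]       # rows still incomplete
    # random score on unknown slots, 0 on known ones -> argmax picks an unknown slot
    cols = (torch.rand(x.shape) * (~known).float()).argmax(-1)[rows]
    x[rows, cols] = samples[rows, cols]
    known[rows, cols] = True

print(known.all().item(), torch.equal(x[:, -3:], x_given[:, -3:]))  # True True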

# # %%
# from neuralDX7.constants import voice_struct, VOICE_KEYS, checksum
# def dx7_bulk_pack(voices):

#     HEADER = int('0x43', 0), int('0x00', 0), int('0x09', 0), int('0x20', 0), int('0x00', 0)
#     assert len(voices)==32
#     voices_bytes = bytes()
#     for voice in voices:
#         voice_bytes = voice_struct.pack(dict(zip(VOICE_KEYS, voice)))
#         voices_bytes += voice_bytes
    
    
#     patch_checksum = [checksum(voices_bytes)]

# #     data = bytes(HEADER) + voices_bytes + bytes(patch_checksum)

# #     return mido.Message('sysex', data=data)

# # %%

# from neuralDX7.constants import VOICE_KEYS, MAX_VALUE, VOICE_PARAMETER_RANGES
# def mask_parameters(x, voice_keys=VOICE_KEYS, inf=1e9):
#     device = x.device
#     mask_item_f = lambda x: torch.arange(MAX_VALUE).to(device) > max(x) 
#     mapper = map(mask_item_f, map(VOICE_PARAMETER_RANGES.get, voice_keys))

#     mask = torch.stack(list(mapper))
    
#     return torch.masked_fill(x, mask, -inf)

# plt.imshow(mask_parameters(torch.randn(10, 155, 128))[0])
# # %%
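# %%
# Standalone sketch of the masking idea behind mask_parameters (commented out
# above). Every parameter shares a 128-way output, so logits for values above
# that parameter's maximum are pushed to -1e9, and a Categorical built from the
# masked logits only samples in-range values. The three ranges are invented for
# illustration; the real ones come from VOICE_PARAMETER_RANGES.
import torch

MAX_VALUE_DEMO = 128
demo_ranges = [range(0, 100), range(0, 4), range(0, 32)]  # hypothetical ranges

def mask_logits(x, ranges, inf=1e9):
    mask = torch.stack([torch.arange(MAX_VALUE_DEMO) > max(r) for r in ranges])  # (P, 128)
    return x.masked_fill(mask, -inf)  # mask broadcasts over the batch dimension

logits = mask_logits(torch.zeros(1000, len(demo_ranges), MAX_VALUE_DEMO), demo_ranges)
draws = torch.distributions.Categorical(logits=logits).sample()  # (1000, 3)
print(draws.max(0).values)  # never above [99, 3, 31]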


# %%


================================================
FILE: projects/dx7_patch_neural_process/features_analysis.py
================================================
# %%
from agoge import InferenceWorker
import torch
from tqdm import tqdm
from matplotlib import pyplot as plt
worker = InferenceWorker('/home/nintorac/agoge/artifacts/Worker/messy-firebrick-barracuda_0_2020-04-14_08-59-40ryp93n44/checkpoint_4/model.box', with_data=True)

model = worker.model
data = worker.dataset
loader = data.loaders.test


Xs = []
features = []
for X in tqdm(loader):
    Xs += [X['x']]
    features += [model.features(X['x'])]
features = torch.cat(features).flatten(-2, -1)
X = torch.cat(Xs)
    
#%%
### --------TSNE-----------

import numpy as np
from sklearn.manifold import TSNE
X_embedded = TSNE(n_components=2).fit_transform(features)
X_embedded.shape
#%%
plt.figure(figsize=(20,30))
plt.scatter(*zip(*X_embedded), linewidths=0.1)

# %%
### ---------K-means---------------

from sklearn.cluster import KMeans
from neuralDX7.utils import dx7_bulk_pack
import mido
n_clusters = 8
kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(features)
labels = kmeans.labels_
tasting = []

out_template = '/home/nintorac/.local/share/DigitalSuburban/Dexed/Cartridges/neuralDX7/group_{}.syx'

for i in range(n_clusters):
    in_cluster, = (labels == i).nonzero()
    if len(in_cluster) < 32:
        continue

    # sample 32 distinct voices from this cluster
    choices = np.random.choice(in_cluster, 32, replace=False)
    voices = X[choices]

    patch_message = dx7_bulk_pack(voices.numpy().tolist())

    mido.write_syx_file(out_template.format(i), [patch_message])
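# %%
# Why replace=False matters when filling a cartridge from a cluster (numpy-only
# sketch with made-up labels). np.random.choice samples WITH replacement by
# default, so a 32-slot cartridge could contain the same voice several times.
# Clusters with fewer than 32 members are skipped, as in the loop above.
import numpy as np

rng = np.random.default_rng(0)
demo_labels = rng.integers(0, 3, size=200)   # hypothetical cluster assignments
cartridges = {}
for c in range(3):
    members = np.flatnonzero(demo_labels == c)
    if len(members) < 32:
        continue
    cartridges[c] = rng.choice(members, 32, replace=False)

print({c: len(set(v.tolist())) for c, v in cartridges.items()})  # 32 distinct voices each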



# %%


================================================
FILE: projects/dx7_patch_neural_process/ray_train.py
================================================
#%%
from os import environ
environ['MLFLOW_TRACKING_URI'] = 'http://tracking.olympus.nintorac.dev:9001/'
# environ['MLFLOW_TRACKING_URI'] = 'http://localhost:9001/'
#environ['ARTIFACTS_ROOT'] = '/content/gdrive/My Drive/audio/artifacts'
# ARTIFACTS_ROOT='/content/gdrive/My Drive/audio/artifacts'
from neuralDX7.constants import N_PARAMS, MAX_VALUE
from agoge.utils import trial_name_creator
from neuralDX7 import DEFAULTS
from agoge import TrainWorker as Worker
from ray import tune
from neuralDX7.models import DX7PatchProcess as Model
from neuralDX7.solvers import DX7PatchProcess as Solver
from neuralDX7.datasets import DX7SysexDataset as Dataset

def config(experiment_name, trial_name, 
        n_heads=8, n_features=32, 
        batch_size=16, data_size=0.05,
        **kwargs):
    


    data_handler = {
        'Dataset': Dataset,
        'dataset_opts': {
            'data_size': data_size
        },
        'loader_opts': {
            'batch_size': batch_size,
        },
    }

    ### MODEL FEATURES
    layer_features = n_heads * n_features

    head_features = layer_features // n_heads

    attention = {
        'n_features': layer_features,
        'n_hidden': head_features,
        'n_heads': n_heads
    }
    
    attention_layer = {
        'attention': attention,
        'features': layer_features,
        'hidden_dim': layer_features * 2
    }

    encoder = {
        'features': layer_features,
        'attention_layer': attention_layer,
        'max_len': N_PARAMS,
        'n_layers': 12
    }
    

    model = {
        'Model': Model,
        'features': layer_features,
        'encoder': encoder
    }

    solver = {
        'Solver': Solver,
        'lr': 1e-3,
    }

    tracker = {
        'metrics': ['reconstruction_loss', 'accuracy'],
        'experiment_name': experiment_name,
        'trial_name': trial_name
    }

    return {
        'data_handler': data_handler,
        'model': model,
        'solver': solver,
        'tracker': tracker,
    }
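# %%
# The attention sizes that `config` derives, worked through for its defaults.
# `attention_dims` is an illustrative plain-Python helper, not part of agoge or
# neuralDX7. Since head_features = (n_heads * n_features) // n_heads, it always
# equals n_features, so n_features is effectively the per-head width.
def attention_dims(n_heads=8, n_features=32, ff_mult=2):
    layer_features = n_heads * n_features
    return {
        'layer_features': layer_features,            # model width
        'head_features': layer_features // n_heads,  # per-head width
        'hidden_dim': layer_features * ff_mult,      # feed-forward width
    }

print(attention_dims())  # {'layer_features': 256, 'head_features': 32, 'hidden_dim': 512}
# dx7_vae/experiment.py uses n_features=64 and a x3 feed-forward: attention_dims(8, 64, 3)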

if __name__=='__main__':
    # from ray import ray
    import sys
    import mlflow
    from mlflow.tracking import MlflowClient
    postfix = sys.argv[1] if len(sys.argv)==2 else ''

    # ray.init()
    # from ray.tune.utils import validate_save_restore
    # validate_save_restore(Worker)
    client = MlflowClient()
    experiment_name = f'dx7-vae-{postfix}'#+experiment_name_creator()
    resume=False
    try:
        experiment_id = client.create_experiment(experiment_name)
    except mlflow.exceptions.RestException:
        resume = True

    experiment_metrics = dict(metric="loss/accuracy", mode="max")
    import torch
    gpus = 0.5 if torch.cuda.is_available() else 0
    gpus = 1
    
    # import ray

    # ray.init()
    # ray.tune.utils.validate_save_restore(Worker)

    tune.run(Worker, 
    config={
        'config_generator': config,
        'experiment_name': experiment_name,
        'points_per_epoch': 10
    },
    trial_name_creator=trial_name_creator,
    resources_per_trial={
        # 'gpu': gpus,
        'cpu': 6
    },
    checkpoint_freq=2,
    checkpoint_at_end=True,
    keep_checkpoints_num=1,
    # search_alg=bohb_search, 
    # scheduler=bohb_hyperband,
    num_samples=1,
    verbose=1,
    local_dir=DEFAULTS['ARTIFACTS_ROOT'],
    resume=resume
    # webui_host='127.0.0.1' ## suppresses an error
        # stop={'loss/loss': 0}
    )
# points_per_epoch


================================================
FILE: projects/dx7_vae/duplicate_test.py
================================================
# %%
from agoge import InferenceWorker
import threading
import torch
import mido
import time
import numpy as np
from tqdm import tqdm
import jack
from matplotlib import pyplot as plt
from itertools import cycle
from numpy import array
worker = InferenceWorker('hasty-copper-dogfish', 'dx7-vae', with_data=True)
float32='float32'
model = worker.model
# data = worker.dataset
# loader = data.loaders.test

n_samples = 32
n_latents = 8
# loader.batch_sampler.batch_size = n_samples


# randoms = torch.cat([model.generate(torch.randn(2**11, 8)).logits.argmax(-1) for _ in tqdm(range(2**5))])

# %%
from matplotlib import pyplot as plt
import torch

rand = torch.rand(100)
randn = torch.randn(100)
plt.scatter(rand, torch.sigmoid(rand)+0.5)
plt.scatter(randn, torch.sigmoid(randn)-0.5)

# %%


================================================
FILE: projects/dx7_vae/evaluate.py
================================================
# %%
from agoge import InferenceWorker
import torch
from tqdm import tqdm
from matplotlib import pyplot as plt
worker = InferenceWorker('/home/nintorac/agoge/artifacts/craggy-goldenrod-catfish_0_2020-04-28_02-22-57m8eftq1b/checkpoint_410/model.box', with_data=True)

model = worker.model
data = worker.dataset
loader = data.loaders.test

n_samples = 32

loader.batch_sampler.batch_size = n_samples
# %%
# batch = next(iter(loader))['x']

# X_a = torch.rand_like(batch.float()) > torch.linspace(0, 1, n_samples).unsqueeze(-1)

# logits = model.generate(batch, X_a)

# # %%
# from matplotlib import pyplot as plt
# plt.imshow(X_a)

# # %%
# plt.scatter(torch.arange(n_samples), logits.log_prob(batch).mean(-1))

#     # %%
# plt.imshow(logits.log_prob(batch))

# %%

from itertools import count 
from neuralDX7.utils import dx7_bulk_pack, mask_parameters
import mido
iter_X = iter(loader)
for n in range(10):
    X = next(iter_X)['x']
    # syx = dx7_bulk_pack(X.numpy().tolist())
    # mido.write_syx_file('/home/nintorac/.local/share/DigitalSuburban/Dexed/Cartridges/neuralDX7/OG.syx', [syx])

    X_d = torch.distributions.Categorical(logits=mask_parameters(torch.zeros(32, 155, 128)))

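    # Alternative: reveal a random ~30% of the parameters instead
    # X_a = torch.rand_like(X.float()) < 0.3
    # Condition only on the last 10 parameters (the voice name); everything else is sampled
    X_a = torch.ones_like(X).bool()
    X_a[:,:-10] = 0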

    X = X[[0]*32]
    X[~X_a] = X_d.sample()[~X_a]

    max_to_sample = max((~X_a).sum(-1))

    # for i in tqdm(range(max_to_sample)):

    logits = model.generate(X, X_a)
    samples = logits.sample()

    has_unsampled = ~X_a.all(-1)

    batch_idxs, sample_idx = (~X_a).nonzero().t()

    X[batch_idxs, sample_idx] = samples[batch_idxs, sample_idx]
    X_a[batch_idxs, sample_idx] = 1
   

    syx = dx7_bulk_pack(X.numpy().tolist())
    mido.write_syx_file(f'/home/nintorac/.local/share/DigitalSuburban/Dexed/Cartridges/neuralDX7/np_{n}.syx', [syx])
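# %%
# The one-shot fill above in miniature (toy tensors). (~X_a).nonzero().t()
# yields the batch and position indices of every unknown entry, and
# fancy-indexed assignment copies the model's samples into exactly those slots.
# For a boolean mask this is equivalent to torch.where(X_a, X, samples).
import torch

known_demo = torch.tensor([[True, False, True], [False, False, True]])
x_demo = torch.tensor([[1, 2, 3], [4, 5, 6]])
samples_demo = torch.tensor([[10, 20, 30], [40, 50, 60]])

b_idx, s_idx = (~known_demo).nonzero().t()   # b_idx=[0, 1, 1], s_idx=[1, 0, 1]
filled = x_demo.clone()
filled[b_idx, s_idx] = samples_demo[b_idx, s_idx]
print(filled.tolist())  # [[1, 20, 3], [40, 50, 6]]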

# # %%
# from neuralDX7.constants import voice_struct, VOICE_KEYS, checksum
# def dx7_bulk_pack(voices):

#     HEADER = int('0x43', 0), int('0x00', 0), int('0x09', 0), int('0x20', 0), int('0x00', 0)
#     assert len(voices)==32
#     voices_bytes = bytes()
#     for voice in voices:
#         voice_bytes = voice_struct.pack(dict(zip(VOICE_KEYS, voice)))
#         voices_bytes += voice_bytes
    
    
#     patch_checksum = [checksum(voices_bytes)]

# #     data = bytes(HEADER) + voices_bytes + bytes(patch_checksum)

# #     return mido.Message('sysex', data=data)

# # %%

# from neuralDX7.constants import VOICE_KEYS, MAX_VALUE, VOICE_PARAMETER_RANGES
# def mask_parameters(x, voice_keys=VOICE_KEYS, inf=1e9):
#     device = x.device
#     mask_item_f = lambda x: torch.arange(MAX_VALUE).to(device) > max(x) 
#     mapper = map(mask_item_f, map(VOICE_PARAMETER_RANGES.get, voice_keys))

#     mask = torch.stack(list(mapper))
    
#     return torch.masked_fill(x, mask, -inf)

# plt.imshow(mask_parameters(torch.randn(10, 155, 128))[0])
# # %%


# %%


================================================
FILE: projects/dx7_vae/experiment.py
================================================
#%%
from os import environ
environ['MLFLOW_TRACKING_URI'] = 'http://tracking.olympus.nintorac.dev:9001/'

from neuralDX7.constants import N_PARAMS, MAX_VALUE
from agoge.utils import trial_name_creator
from neuralDX7 import DEFAULTS
from agoge import TrainWorker as Worker
from ray import tune
from neuralDX7.models import DX7VAE as Model
from neuralDX7.solvers import DX7VAE as Solver
from neuralDX7.datasets import DX7SysexDataset as Dataset

def config(experiment_name, trial_name, 
        n_heads=8, n_features=64, 
        batch_size=16, data_size=1.,
        latent_dim=8, num_flows=16,
        **kwargs):
    


    data_handler = {
        'Dataset': Dataset,
        'dataset_opts': {
            'data_size': data_size
        },
        'loader_opts': {
            'batch_size': batch_size,
        },
    }

    ### MODEL FEATURES
    layer_features = n_heads * n_features

    head_features = layer_features // n_heads

    attention = {
        'n_features': layer_features,
        'n_hidden': head_features,
        'n_heads': n_heads
    }
    
    attention_layer = {
        'attention': attention,
        'features': layer_features,
        'hidden_dim': layer_features * 3
    }

    encoder = {
        'features': layer_features,
        'attention_layer': attention_layer,
        'max_len': N_PARAMS,
        'n_layers': 12
    }
    

    model = {
        'Model': Model,
        'features': layer_features,
        'latent_dim': latent_dim,
        'encoder': encoder,
        'decoder': {
            'c_features': latent_dim,
            'features': layer_features,
            'attention_layer': attention_layer,
            'max_len': N_PARAMS,
            'n_layers': 12
        },
        'num_flows': num_flows,
        'deterministic_path_drop_rate': 0.8
    }

    solver = {
        'Solver': Solver,
        'beta_temp': 6e-5,
        'max_beta': 0.5
    }

    tracker = {
        'metrics': [
            'reconstruction_loss', 
            'accuracy', 
            'kl', 
            'beta', 
            'log_det', 
            'q_z_0',
            'p_z_k',
        ],
        'experiment_name': experiment_name,
        'trial_name': trial_name
    }

    return {
        'data_handler': data_handler,
        'model': model,
        'solver': solver,
        'tracker': tracker,
    }

if __name__=='__main__':
    # from ray import ray
    import sys
    postfix = sys.argv[1] if len(sys.argv)==2 else ''
    # ray.init()
    # from ray.tune.utils import validate_save_restore
    # validate_save_restore(Worker)
    # client = MlflowClient(tracking_uri='localhost:5000')
    experiment_name = f'dx7-vae-dev'#+experiment_name_creator()
    # experiment_id = client.create_experiment(experiment_name)


    experiment_metrics = dict(metric="loss/accuracy", mode="max")

    tune.run(Worker, 
    config={
        'config_generator': config,
        'experiment_name': experiment_name,
        'points_per_epoch': 10
    },
    trial_name_creator=trial_name_creator,
    resources_per_trial={
        # 'gpu': 1
        # 'cpu': 5
    },
    checkpoint_freq=2,
    checkpoint_at_end=True,
    keep_checkpoints_num=1,
    # search_alg=bohb_search, 
    # scheduler=bohb_hyperband,
    num_samples=1,
    verbose=0,
    local_dir=DEFAULTS['ARTIFACTS_ROOT']
    # webui_host='127.0.0.1' ## suppresses an error
        # stop={'loss/loss': 0}
    )
# points_per_epoch

# %%


================================================
FILE: projects/dx7_vae/features.py
================================================
# %%
from agoge import InferenceWorker
import threading
import torch
import time
import numpy as np
from tqdm import tqdm
import jack
from matplotlib import pyplot as plt
from itertools import cycle
worker = InferenceWorker('~/agoge/artifacts/dx7-vae/hasty-copper-dogfish_0_2020-05-06_10-46-27o654hmde/checkpoint_204/model.box', with_data=True)

model = worker.model
data = worker.dataset
loader = data.loaders.test

n_samples = 32
n_latents = 8
loader.batch_sampler.batch_size = n_samples
features_all = []
features_half = []
for x in map(lambda x: x['X'], tqdm(loader)):
    q = model.features(x)
    features_all += [(q.mean.numpy(), q.stddev.numpy())]
    # features_half += [model.features(x, torch.rand_like(x.float())>torch.linspace(0, 1, 32).unsqueeze(-1)).mean.numpy()]

mus, stds = map(np.concatenate, zip(*features_all))  # per-example posterior means / std devs
print(mus.mean(0), stds.mean(0))
# for item in loader:
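# %%
# How the per-batch (mean, std) pairs collected above are stacked into two
# arrays. zip(*pairs) regroups them into a tuple of means and a tuple of stds,
# and np.concatenate joins each along the batch axis. The shapes are made up
# (three batches of 32 examples, 8 latent dimensions).
import numpy as np

pairs = [(np.zeros((32, 8)), np.ones((32, 8))) for _ in range(3)]
batch_means, batch_stds = map(np.concatenate, zip(*pairs))
print(batch_means.shape, batch_stds.shape)  # (96, 8) (96, 8)
print(batch_stds.mean(0).shape)             # (8,): one value per latent dimension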
# %%
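from timeit import timeit

# timeit() is a function, not a context manager: time 10 generations and print the mean (s)
print(timeit(lambda: model.generate(q.sample()[[0]]), number=10) / 10)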

# %%
from time import time
s_time = time()

for i in range(100):
    model.generate(q.sample()[[0]])

print((time()-s_time)/100)
# %%


================================================
FILE: projects/dx7_vae/interpoalte.py
================================================
# %%
from agoge import InferenceWorker
import torch
from tqdm import tqdm
from matplotlib import pyplot as plt
worker = InferenceWorker('/home/nintorac/agoge/artifacts/craggy-goldenrod-catfish_0_2020-04-28_02-22-57m8eftq1b/checkpoint_410/model.box', with_data=True)

model = worker.model
data = worker.dataset
loader = data.loaders.test

n_samples = 32
n_latents = 8
loader.batch_sampler.batch_size = n_samples
# %%
# batch = next(iter(loader))['x']

# X_a = torch.rand_like(batch.float()) > torch.linspace(0, 1, n_samples).unsqueeze(-1)

# logits = model.generate(batch, X_a)

# # %%
# from matplotlib import pyplot as plt
# plt.imshow(X_a)

# # %%
# plt.scatter(torch.arange(n_samples), logits.log_prob(batch).mean(-1))

#     # %%
# plt.imshow(logits.log_prob(batch))

# %%

from itertools import count 
from neuralDX7.utils import dx7_bulk_pack, mask_parameters
import mido
iter_X = iter(loader)
X_og = next(iter_X)['x']
for n in range(n_latents):
    # syx = dx7_bulk_pack(X.numpy().tolist())
    # mido.write_syx_file('/home/nintorac/.local/share/DigitalSuburban/Dexed/Cartridges/neuralDX7/OG.syx', [syx])

    X_l = X_og[[0]].clone()

    X_a = torch.ones(1,155).bool()
    X_a[:,0] = 0
    q_l = model.features(X_l, X_a)

    z = q_l.mean[:,[0]*155][[0]*32]
    z[:,:,n] = torch.linspace(-4, 4, 32).unsqueeze(-1)

    X = model.generate_z(z).sample()
    X[...,-1] = 48 + torch.arange(32)
    syx = dx7_bulk_pack(X.numpy().tolist())
    mido.write_syx_file(f'/home/nintorac/.local/share/DigitalSuburban/Dexed/Cartridges/neuralDX7/np_interp_{n}.syx', [syx])
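# %%
# Shape walk-through of the latent traversal above, using a random stand-in for
# q_l.mean (assumed here to be (1, T, 8); only position 0 is used). The mean at
# position 0 is repeated over 155 parameter positions and 32 cartridge slots,
# then latent dimension n is swept from -4 to 4 across the slots. The last name
# character is then set to ASCII 48 + slot, i.e. '0', '1', ..., 'O'.
import torch

n_demo = 3
mean_demo = torch.randn(1, 155, 8)
z_demo = mean_demo[:, [0] * 155][[0] * 32]                  # (32, 155, 8), a copy
z_demo[:, :, n_demo] = torch.linspace(-4, 4, 32).unsqueeze(-1)
suffixes = ''.join(chr(48 + i) for i in range(32))
print(z_demo.shape, suffixes[:10], suffixes[-1])  # torch.Size([32, 155, 8]) 0123456789 O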

# # %%
# from neuralDX7.constants import voice_struct, VOICE_KEYS, checksum
# def dx7_bulk_pack(voices):

#     HEADER = int('0x43', 0), int('0x00', 0), int('0x09', 0), int('0x20', 0), int('0x00', 0)
#     assert len(voices)==32
#     voices_bytes = bytes()
#     for voice in voices:
#         voice_bytes = voice_struct.pack(dict(zip(VOICE_KEYS, voice)))
#         voices_bytes += voice_bytes
    
    
#     patch_checksum = [checksum(voices_bytes)]

# #     data = bytes(HEADER) + voices_bytes + bytes(patch_checksum)

# #     return mido.Message('sysex', data=data)

# # %%

# from neuralDX7.constants import VOICE_KEYS, MAX_VALUE, VOICE_PARAMETER_RANGES
# def mask_parameters(x, voice_keys=VOICE_KEYS, inf=1e9):
#     device = x.device
#     mask_item_f = lambda x: torch.arange(MAX_VALUE).to(device) > max(x) 
#     mapper = map(mask_item_f, map(VOICE_PARAMETER_RANGES.get, voice_keys))

#     mask = torch.stack(list(mapper))
    
#     return torch.masked_fill(x, mask, -inf)

# plt.imshow(mask_parameters(torch.randn(10, 155, 128))[0])
# # %%


# %%


================================================
FILE: projects/dx7_vae/live.py
================================================
# %%
from agoge import InferenceWorker
import threading
import torch
import mido
import time
import numpy as np
from tqdm import tqdm
import jack
from matplotlib import pyplot as plt
from itertools import cycle
from numpy import array
worker = InferenceWorker('~/agoge/artifacts/dx7-vae/hasty-copper-dogfish_0_2020-05-06_10-46-27o654hmde/checkpoint_204/model.box', with_data=False)
float32='float32'
model = worker.model
# data = worker.dataset
# loader = data.loaders.test

n_samples = 32
n_latents = 8
# loader.batch_sampler.batch_size = n_samples


from uuid import uuid4
uuid = lambda: uuid4().hex  # random 32-character hex identifier
#     self._event.set()


client = jack.Client('DX7Parameteriser')
port = client.midi_outports.register('output')
inport = client.midi_inports.register('input')
event = threading.Event()
fs = None  # sampling rate
offset = 0
from neuralDX7.constants import DX7Single, consume_syx
import torch

name = torch.tensor([i for i in "horns     ".encode('ascii')])


def slerp(val, low, high):
    omega = np.arccos(np.clip(np.dot(low/np.linalg.norm(low), high/np.linalg.norm(high)), -1, 1))
    so = np.sin(omega)
    if so == 0:
        return (1.0-val) * low + val * high # L'Hopital's rule/LERP
    return np.sin((1.0-val)*omega) / so * low + np.sin(val*omega) / so * high
x_iter = cycle([*torch.linspace(0, 1, 7)[1:],  *torch.linspace(1, 0, 7)[1:]])
#%%
i=0

mu, std = \
(array([ 5.5626068e-02,  7.9248362e-04, -8.0890575e-04,  1.6684370e-01,
         1.6537485e-01, -6.2455550e-02,  9.4467170e-05, -7.5367272e-02],
       dtype=float32),
 array([0.35453376, 0.3556142 , 0.35896832, 0.341505  , 0.3299536 ,
        0.33990443, 0.3350083 , 0.339214  ], dtype=float32))
vals = torch.from_numpy(mu + np.linspace(-3, 3, 128)[:,None] * std).float()

controller_map = {}

latent = torch.full((1, 8), 64).long()
patch_no = 0

from neuralDX7.utils import mask_parameters
@client.set_process_callback
def process(frames):
    global offset, i
    global msg
    global syx_iter
    global controller_map, patch_no, vals, latent
    port.clear_buffer()
    needs_update = False


    for offset, data in inport.incoming_midi_events():
        msg = mido.parse(bytes(data))

        if msg.type=='note_on':
            port.write_midi_event(0, mido.Message('control_change', control=123).bytes())   
        if msg.type!='control_change':
            continue


        if msg.control not in controller_map:
            if len(controller_map) == 8:
                continue
            print(f"latent {len(controller_map)} set to encoder {msg.control}")
            controller_map[msg.control] = len(controller_map)
        l_i = list(controller_map).index(msg.control)
        print(f'Latent: {latent}')
        latent[:, controller_map[msg.control]] =  msg.value
        needs_update = True
        
        # print("{0}: 0x{1}".format(client.last_frame_time + offset,
        #                           binascii.hexlify(data).decode()))
    # print(time.time()-offset)
    inport.clear_buffer()
    if (needs_update):
        offset = time.time()

        z = vals.gather(0, latent)
        msg = model.generate(z, t=0.001).sample()

        msg = DX7Single.to_syx(msg.numpy().tolist())

        port.write_midi_event(1, msg.bytes())
        # file I/O inside the JACK process callback can cause xruns
        mido.write_syx_file('example_single_voice.syx', [msg])
        # port.write_midi_event(1, mido.Message('control_change', control=123).bytes())
        # port.write_midi_event(2, mido.Message('control_change', control=123).bytes())
        # port.write_midi_event(3, mido.Message('control_change', control=123).bytes())
        # port.write_midi_event(4, mido.Message('control_change', control=123).bytes())
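# %%
# What the control_change(control=123) messages in these live scripts encode:
# CC 123 is the MIDI "All Notes Off" channel-mode message. This dependency-free
# sketch (`control_change_bytes` is a hypothetical helper) builds the same three
# raw bytes mido emits for its default channel 0 and value 0:
# status 0xB0 | channel, controller number, value.
def control_change_bytes(control, value=0, channel=0):
    return [0xB0 | (channel & 0x0F), control & 0x7F, value & 0x7F]

ALL_NOTES_OFF = control_change_bytes(123)
print(ALL_NOTES_OFF)  # [176, 123, 0]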

    


@client.set_samplerate_callback
def samplerate(samplerate):
    global fs
    fs = samplerate


@client.set_shutdown_callback
def shutdown(status, reason):
    print('JACK shutdown:', reason, status)
    event.set()

capture_port = 'a2j:Arturia BeatStep [24] (capture): Arturia BeatStep MIDI 1'
playback_port = 'Carla:Dexed:events-in' 

with client:
    # print(client.get_ports())
    offset = time.time()
    # if connect_to:
    port.connect(playback_port)
    inport.connect(capture_port)

    # print('Playing', repr(filename), '... press Ctrl+C to stop')
    try:
        event.wait()
    except KeyboardInterrupt:
        print('\nInterrupted by user')


# %%


================================================
FILE: projects/mnist_neural_process/experiment.py
================================================
#%%
from os import environ
environ['MLFLOW_TRACKING_URI'] = 'http://tracking.olympus.nintorac.dev:9001/'
# environ['MLFLOW_TRACKING_URI'] = 'http://localhost:9001/'
#environ['ARTIFACTS_ROOT'] = '/content/gdrive/My Drive/audio/artifacts'
# ARTIFACTS_ROOT='/content/gdrive/My Drive/audio/artifacts'
from neuralDX7.constants import N_PARAMS, MAX_VALUE
from agoge.utils import trial_name_creator
from neuralDX7 import DEFAULTS
from agoge import Worker
from ray import tune
from neuralDX7.models import DX7PatchProcess as Model
from neuralDX7.solvers import DX7PatchProcess as Solver
from neuralDX7.datasets import DX7SysexDataset as Dataset

from pathlib import Path
from torchvision import datasets, transforms
from torch.utils.data import ConcatDataset
from agoge import DEFAULTS

import numpy as np 
import torch

K=2

class MNISTDataset():

    def __init__(self, data_path=DEFAULTS['ARTIFACTS_ROOT'], transform=None):

        if transform is None:
            # default MNIST normalisation (dataset mean 0.1307, std 0.3081)
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])

        if not isinstance(data_path, Path):
            data_path = Path(data_path).expanduser()

        train_dataset = datasets.MNIST(data_path.as_posix(), train=True, download=True,transform=transform)
        test_dataset = datasets.MNIST(data_path.as_posix(), train=False, download=True, transform=transform)

        self.dataset = ConcatDataset((train_dataset, test_dataset))

    def __getitem__(self, i):
        
        item = self.dataset[i][0]
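        # crop the central (2K x 2K) patch and flatten it to (2K)**2 tokens
        sub_item = item[:, 14-K:14+K, 14-K:14+K].flatten()
        # quantise normalised pixel intensities into 17 integer bins (0..16)
        digitized = np.digitize(sub_item, np.linspace(-4, 4, 16))

        return {'x': torch.from_numpy(digitized)}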

    def __len__(self):

        return len(self.dataset)

MNISTDataset()[0]
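# %%
# What __getitem__ produces, checked on hand-picked values. With 16 edges from
# -4 to 4, np.digitize returns an integer in 0..16 (17 tokens), where 0 means
# "below -4" and 16 means ">= 4". With K=2 the crop is 4x4 pixels, i.e.
# (2K)**2 = 16 tokens, which matches the encoder's max_len.
import numpy as np
import torch

edges = np.linspace(-4, 4, 16)
print(np.digitize([-5.0, -4.0, 0.0, 3.9, 4.0, 5.0], edges))  # [ 0  1  8 15 16 16]

K_demo = 2
crop = torch.zeros(1, 28, 28)[:, 14-K_demo:14+K_demo, 14-K_demo:14+K_demo].flatten()
print(crop.shape)  # torch.Size([16])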

#%%
Dataset = MNISTDataset


def config(experiment_name, trial_name, 
        n_heads=8, n_features=32, 
        batch_size=16, data_size=0.9999,
        **kwargs):
    


    data_handler = {
        'Dataset': Dataset,
        'dataset_opts': {
            # 'data_size': data_size
        },
        'loader_opts': {
            'batch_size': batch_size,
        },
    }

    ### MODEL FEATURES
    layer_features = n_heads * n_features

    head_features = layer_features // n_heads

    attention = {
        'n_features': layer_features,
        'n_hidden': head_features,
        'n_heads': n_heads
    }
    
    attention_layer = {
        'attention': attention,
        'features': layer_features,
        'hidden_dim': layer_features * 2
    }

    encoder = {
        'features': layer_features,
        'attention_layer': attention_layer,
        'max_len': (K*2)**2,
        'n_layers': 24
    }
    

    model = {
        'Model': Model,
        'features': layer_features,
        'encoder': encoder
    }

    solver = {
        'Solver': Solver,
        'lr': 1e-3,
    }

    tracker = {
        'metrics': ['reconstruction_loss', 'accuracy'],
        'experiment_name': experiment_name,
        'trial_name': trial_name
    }

    return {
        'data_handler': data_handler,
        'model': model,
        'solver': solver,
        'tracker': tracker,
    }

if __name__=='__main__':
    # from ray import ray
    import sys
    import mlflow
    from mlflow.tracking import MlflowClient
    postfix = sys.argv[1] if len(sys.argv)==2 else ''

    # ray.init()
    # from ray.tune.utils import validate_save_restore
    # validate_save_restore(Worker)
    client = MlflowClient()
    experiment_name = f'dx7-vae-{postfix}'#+experiment_name_creator()
    resume=False
    try:
        experiment_id = client.create_experiment(experiment_name)
    except mlflow.exceptions.RestException:
        resume = True

    experiment_metrics = dict(metric="loss/accuracy", mode="max")
    gpus = 1 if torch.cuda.is_available() else 0


    tune.run(Worker, 
    config={
        'config_generator': config,
        'experiment_name': experiment_name,
        'points_per_epoch': 10
    },
    trial_name_creator=trial_name_creator,
    resources_per_trial={
        'gpu': gpus,
        'cpu': 1
    },
    checkpoint_freq=2,
    checkpoint_at_end=True,
    keep_checkpoints_num=1,
    # search_alg=bohb_search, 
    # scheduler=bohb_hyperband,
    num_samples=1,
    verbose=1,
    local_dir=DEFAULTS['ARTIFACTS_ROOT'],
    resume=resume
    # webui_host='127.0.0.1' ## suppresses an error
        # stop={'loss/loss': 0}
    )
# points_per_epoch


================================================
FILE: requirements.txt
================================================
bitstruct==8.9.0
agoge==0.0.6
mido==1.2.9

================================================
FILE: scratch/dx7-sysexformat.md
================================================
Sysex Documentation 
===================

Date: Sat, 25 Sep 1993 16:13:40 -0500 (CDT)
From: "Ewan A. Macpherson" <MACPHERSON@waisman.wisc.edu>
Subject: DX7 Data Format
To: xeno@iastate.edu
Organization: Waisman Center, University of Wisconsin-Madison

Gary:

I don't know anything about the differences between the DX7 and DX7s, but this
DX7 info may be useful.  I posted this to r.m.s. before xmas.

I've seen many requests for public domain / shareware DX editors, but I've
never seen a definitive reply.  They're usually along the lines of "I was
roaching around on CompuServe last month, and I think I remember seeing one..."

Anyway, hope this helps ... 

=========================================================================

For those interested in unpacking the uscd.edu DX7 patch data, here is
DX7 data format information.

     compiled from - the DX7 MIDI Data Format Sheet
                   - article by Steve DeFuria (Keyboard Jan 87)
                   - looking at what my DX7 spits out

I have kept the kinda weird notation used in the DX7 Data Sheet to reduce
typing errors. Where it doesn't quite make sense to me I've added comments.
(And I will not be liable for errors etc ....)

Contents: A: SYSEX Message: Bulk Data for 1 Voice
          B: SYSEX Message: Bulk Data for 32 Voices
          C: SYSEX Message: Parameter Change
          D: Data Structure: Single Voice Dump & Voice Parameter #'s
          E: Function Parameter #'s
          F: Data Structure: Bulk Dump Packed Format

////////////////////////////////////////////////////////////
A:
SYSEX Message: Bulk Data for 1 Voice
------------------------------------
       bits    hex  description

     11110000  F0   Status byte - start sysex
     0iiiiiii  43   ID # (i=67; Yamaha)
     0sssnnnn  00   Sub-status (s=0) & channel number (n=0; ch 1)
     0fffffff  00   format number (f=0; 1 voice)
     0bbbbbbb  01   byte count MS byte
     0bbbbbbb  1B   byte count LS byte (b=155; 1 voice)
     0ddddddd  **   data byte 1

        |       |       |

     0ddddddd  **   data byte 155
     0eeeeeee  **   checksum (masked 2's complement of sum of 155 bytes)
     11110111  F7   Status - end sysex
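
As a rough sketch of the checksum rule above in Python (the function names are hypothetical, and `data` is assumed to hold only the data bytes, without the F0 header bytes or the trailing F7): the checksum is the negated byte sum masked to 7 bits, so data bytes plus checksum sum to 0 modulo 128.

```python
def dx7_checksum(data: bytes) -> int:
    """Masked two's complement of the sum of the data bytes (7 bits)."""
    return -sum(data) & 0x7F


def checksum_ok(data: bytes, checksum: int) -> bool:
    """Data bytes plus checksum should sum to 0 modulo 128."""
    return (sum(data) + checksum) & 0x7F == 0


# 155 data bytes all set to 1 sum to 155; 155 % 128 = 27, so checksum = 128 - 27
voice = bytes([1] * 155)
print(dx7_checksum(voice))  # 101
```

The same function works for the 4096-byte bulk dump, since only the sum of the data bytes matters.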



///////////////////////////////////////////////////////////
B:
SYSEX Message: Bulk Data for 32 Voices
--------------------------------------
       bits    hex  description

     11110000  F0   Status byte - start sysex
     0iiiiiii  43   ID # (i=67; Yamaha)
     0sssnnnn  00   Sub-status (s=0) & channel number (n=0; ch 1)
     0fffffff  09   format number (f=9; 32 voices)
     0bbbbbbb  20   byte count MS byte
     0bbbbbbb  00   byte count LS byte (b=4096; 32 voices)
     0ddddddd  **   data byte 1

        |       |       |

     0ddddddd  **   data byte 4096  (there are 128 bytes / voice)
     0eeeeeee  **   checksum (masked 2's comp. of sum of 4096 bytes)
     11110111  F7   Status - end sysex
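
The two byte-count bytes carry 7 bits each, MS first. A minimal sketch (hypothetical helper names) that recombines them and splits the 4096 data bytes of a 32-voice dump into per-voice chunks:

```python
def byte_count(ms: int, ls: int) -> int:
    """Combine the two 7-bit byte-count bytes (MS first) into one integer."""
    return ((ms & 0x7F) << 7) | (ls & 0x7F)


def split_bulk_dump(data: bytes, voice_size: int = 128) -> list:
    """Split the data bytes of a 32-voice dump into 128-byte voice chunks."""
    if len(data) % voice_size:
        raise ValueError(f"expected a multiple of {voice_size} bytes, got {len(data)}")
    return [data[i:i + voice_size] for i in range(0, len(data), voice_size)]


print(byte_count(0x20, 0x00))  # 4096 (32 voices)
print(byte_count(0x01, 0x1B))  # 155 (1 voice)
```

Both header values from sections A and B check out under this reading: 0x20 * 128 = 4096 and 0x01 * 128 + 0x1B = 155.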


/////////////////////////////////////////////////////////////
C:
SYSEX MESSAGE: Parameter Change
-------------------------------
       bits    hex  description

     11110000  F0   Status byte - start sysex
     0iiiiiii  43   ID # (i=67; Yamaha)
     0sssnnnn  10   Sub-status (s=1) & channel number (n=0; ch 1)
     0gggggpp  **   parameter group # (g=0; voice, g=2; function)
     0ppppppp  **   parameter # (these are listed in next section)
                     Note that voice parameter #'s can go above 127, so
                     the pp bits in the group byte are either 00 for
                     par# 0-127 or 01 for par# 128-155. In the latter case
                     you add 128 to the 0ppppppp byte to compute par#. 
     0ddddddd  **   data byte
     11110111  F7   Status - end sysex
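
A sketch of assembling a parameter change message from the layout above, assuming channel 0 means MIDI ch 1 and that function parameters (section E) use group 2; `parameter_change` is a hypothetical helper, not part of any library:

```python
def parameter_change(param: int, value: int, channel: int = 0, group: int = 0) -> bytes:
    """Build a DX7 parameter change sysex (group 0 = voice, 2 = function)."""
    if not 0 <= channel <= 15:
        raise ValueError("channel must be 0-15 (0 = MIDI ch 1)")
    if not 0 <= param <= 155:
        raise ValueError("parameter number must be 0-155")
    high, low = divmod(param, 128)   # the pp bits hold 01 for par# 128-155
    return bytes([
        0xF0,                        # status byte, start sysex
        0x43,                        # ID # 67, Yamaha
        0x10 | channel,              # sub-status s=1 and channel number n
        (group << 2) | high,         # 0gggggpp
        low,                         # 0ppppppp
        value & 0x7F,                # data byte
        0xF7,                        # status, end sysex
    ])


# ALGORITHM # (par# 134) set to 4 on MIDI channel 1
print(parameter_change(134, 4).hex(" "))  # f0 43 10 01 06 04 f7
```

Note that `bytes.hex` with a separator argument needs Python 3.8 or newer.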


//////////////////////////////////////////////////////////////

D:
Data Structure: Single Voice Dump & Parameter #'s (single voice format, g=0)
-------------------------------------------------------------------------

Parameter
 Number    Parameter                  Value Range
---------  ---------                  -----------
  0        OP6 EG rate 1              0-99
  1         "  "  rate 2               "
  2         "  "  rate 3               "
  3         "  "  rate 4               "
  4         "  " level 1               "
  5         "  " level 2               "
  6         "  " level 3               "
  7         "  " level 4               "
  8        OP6 KBD LEV SCL BRK PT      "        C3= $27
  9         "   "   "   "  LFT DEPTH   "
 10         "   "   "   "  RHT DEPTH   "
 11         "   "   "   "  LFT CURVE  0-3       0=-LIN, -EXP, +EXP, +LIN
 12         "   "   "   "  RHT CURVE   "            "    "    "    "  
 13        OP6 KBD RATE SCALING       0-7
 14        OP6 AMP MOD SENSITIVITY    0-3
 15        OP6 KEY VEL SENSITIVITY    0-7
 16        OP6 OPERATOR OUTPUT LEVEL  0-99
 17        OP6 OSC MODE (fixed/ratio) 0-1        0=ratio
 18        OP6 OSC FREQ COARSE        0-31
 19        OP6 OSC FREQ FINE          0-99
 20        OP6 OSC DETUNE             0-14       0: det=-7
 21 \
  |  > repeat above for OSC 5, OSC 4,  ... OSC 1
125 /
126        PITCH EG RATE 1            0-99
127          "    " RATE 2              "
128          "    " RATE 3              "
129          "    " RATE 4              "
130          "    " LEVEL 1             "
131          "    " LEVEL 2             "
132          "    " LEVEL 3             "
133          "    " LEVEL 4             "
134        ALGORITHM #                 0-31
135        FEEDBACK                    0-7
136        OSCILLATOR SYNC             0-1
137        LFO SPEED                   0-99
138         "  DELAY                    "
139         "  PITCH MOD DEPTH          "
140         "  AMP   MOD DEPTH          "
141        LFO SYNC                    0-1
142         "  WAVEFORM                0-5, (data sheet claims 9-4 ?!?)
                                       0:TR, 1:SD, 2:SU, 3:SQ, 4:SI, 5:SH
143        PITCH MOD SENSITIVITY       0-7
144        TRANSPOSE                   0-48   12 = C2
145        VOICE NAME CHAR 1           ASCII
146        VOICE NAME CHAR 2           ASCII
147        VOICE NAME CHAR 3           ASCII
148        VOICE NAME CHAR 4           ASCII
149        VOICE NAME CHAR 5           ASCII
150        VOICE NAME CHAR 6           ASCII
151        VOICE NAME CHAR 7           ASCII
152        VOICE NAME CHAR 8           ASCII
153        VOICE NAME CHAR 9           ASCII
154        VOICE NAME CHAR 10          ASCII
155        OPERATOR ON/OFF
              bit6 = 0 / bit 5: OP1 / ... / bit 0: OP6

Note that there are actually 156 parameters listed here, one more than in 
a single voice dump. The OPERATOR ON/OFF parameter is not stored with the
voice, and is only transmitted or received while editing a voice. So it
only shows up in parameter change SYS-EX's.
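
Operator parameters repeat in blocks of 21 starting from OP6, so a parameter number can be computed as `(6 - op) * 21 + offset`. A sketch, where the short field labels in `OP_FIELDS` are made up for illustration and follow the order of parameters 0-20:

```python
# Hypothetical short labels, in the order of parameters 0-20 (OP6) above
OP_FIELDS = [
    "EG R1", "EG R2", "EG R3", "EG R4", "EG L1", "EG L2", "EG L3", "EG L4",
    "BRK PT", "LFT DEPTH", "RHT DEPTH", "LFT CURVE", "RHT CURVE",
    "RATE SCALING", "AMS", "KVS", "OUTPUT LEVEL", "OSC MODE",
    "FREQ COARSE", "FREQ FINE", "DETUNE",
]


def voice_param_number(op: int, field: str) -> int:
    """Single-voice parameter number for operator `op` (1-6)."""
    if not 1 <= op <= 6:
        raise ValueError("operator must be 1-6")
    return (6 - op) * len(OP_FIELDS) + OP_FIELDS.index(field)


print(voice_param_number(6, "EG R1"))   # 0
print(voice_param_number(1, "DETUNE"))  # 125
```

The result for OP1 DETUNE lands on 125, the last number in the "repeat above" range of the table.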


////////////////////////////////////////////////////////

E:
Function Parameters: (g=2)
-------------------------

Parameter
Number        Parameter           Range
---------     ----------          ------
64         MONO/POLY MODE CHANGE  0-1      0=POLY
65         PITCH BEND RANGE       0-12
66           "    "   STEP        0-12
67         PORTAMENTO MODE        0-1      0=RETAIN 1=FOLLOW
68              "     GLISS       0-1
69              "     TIME        0-99
70         MOD WHEEL RANGE        0-99
71          "    "   ASSIGN       0-7     b0: pitch,  b1:amp, b2: EG bias
72         FOOT CONTROL RANGE     0-99
73          "     "     ASSIGN    0-7           "
74         BREATH CONT RANGE      0-99
75           "     "   ASSIGN     0-7           "
76         AFTERTOUCH RANGE       0-99
77             "      ASSIGN      0-7           "
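
The ASSIGN values are 3-bit masks. A small decoding sketch, with destination labels taken from the b0/b1/b2 note in the table (`decode_assign` is a hypothetical name):

```python
ASSIGN_BITS = ("pitch", "amp", "EG bias")  # b0, b1, b2


def decode_assign(value: int) -> list:
    """List the destinations enabled in a 0-7 ASSIGN value."""
    if not 0 <= value <= 7:
        raise ValueError("ASSIGN must be 0-7")
    return [name for bit, name in enumerate(ASSIGN_BITS) if (value >> bit) & 1]


print(decode_assign(5))  # ['pitch', 'EG bias']
```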

///////////////////////////////////////////////////////////////

F:
Data Structure: Bulk Dump Packed Format
---------------------------------------

OK, now the tricky bit. For a bulk dump, each voice's 155 parameters are
packed into one 128 byte chunk, and the 32 chunks follow each other as follows ...

byte             bit #
 #     6   5   4   3   2   1   0   param A       range  param B       range
----  --- --- --- --- --- --- ---  ------------  -----  ------------  -----
  0                R1              OP6 EG R1      0-99
  1                R2              OP6 EG R2      0-99
  2                R3              OP6 EG R3      0-99
  3                R4              OP6 EG R4      0-99
  4                L1              OP6 EG L1      0-99
  5                L2              OP6 EG L2      0-99
  6                L3              OP6 EG L3      0-99
  7                L4              OP6 EG L4      0-99
  8                BP              LEV SCL BRK PT 0-99
  9                LD              SCL LEFT DEPTH 0-99
 10                RD              SCL RGHT DEPTH 0-99
 11    0   0   0 |  RC   |   LC  | SCL LEFT CURVE 0-3   SCL RGHT CURVE 0-3
 12  |      DET      |     RS    | OSC DETUNE     0-14  OSC RATE SCALE 0-7
 13    0   0 |    KVS    |  AMS  | KEY VEL SENS   0-7   AMP MOD SENS   0-3
 14                OL              OP6 OUTPUT LEV 0-99
 15    0 |         FC        | M | FREQ COARSE    0-31  OSC MODE       0-1
 16                FF              FREQ FINE      0-99
 17 \
  |  > these 17 bytes for OSC 5
 33 /
 34 \
  |  > these 17 bytes for OSC 4
 50 /
 51 \
  |  > these 17 bytes for OSC 3
 67 /
 68 \
  |  > these 17 bytes for OSC 2
 84 /
 85 \
  |  > these 17 bytes for OSC 1
101 /

byte             bit #
 #     6   5   4   3   2   1   0   param A       range  param B       range
----  --- --- --- --- --- --- ---  ------------  -----  ------------  -----
102               PR1              PITCH EG R1   0-99
103               PR2              PITCH EG R2   0-99
104               PR3              PITCH EG R3   0-99
105               PR4              PITCH EG R4   0-99
106               PL1              PITCH EG L1   0-99
107               PL2              PITCH EG L2   0-99
108               PL3              PITCH EG L3   0-99
109               PL4              PITCH EG L4   0-99
110    0   0 |        ALG        | ALGORITHM     0-31
111    0   0   0 |OKS|    FB     | OSC KEY SYNC  0-1    FEEDBACK      0-7
112               LFS              LFO SPEED     0-99
113               LFD              LFO DELAY     0-99
114               LPMD             LF PT MOD DEP 0-99
115               LAMD             LF AM MOD DEP 0-99
116  |  LPMS |      LFW      |LKS| LF PT MOD SNS 0-7   WAVE 0-5,  SYNC 0-1
117              TRNSP             TRANSPOSE     0-48
118          NAME CHAR 1           VOICE NAME 1  ASCII
119          NAME CHAR 2           VOICE NAME 2  ASCII
120          NAME CHAR 3           VOICE NAME 3  ASCII
121          NAME CHAR 4           VOICE NAME 4  ASCII
122          NAME CHAR 5           VOICE NAME 5  ASCII
123          NAME CHAR 6           VOICE NAME 6  ASCII
124          NAME CHAR 7           VOICE NAME 7  ASCII
125          NAME CHAR 8           VOICE NAME 8  ASCII
126          NAME CHAR 9           VOICE NAME 9  ASCII
127          NAME CHAR 10          VOICE NAME 10 ASCII
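
A sketch of unpacking one operator's 17 packed bytes using shifts and masks that follow the bit diagram above (field names borrowed from the table; the functions are hypothetical). Operators appear in the order OP6 first, OP1 last:

```python
def unpack_operator(block: bytes) -> dict:
    """Unpack the 17 packed bytes of one operator (bytes 0-16 above)."""
    if len(block) != 17:
        raise ValueError("an operator block is 17 bytes")
    names = ("R1", "R2", "R3", "R4", "L1", "L2", "L3", "L4", "BP", "LD", "RD")
    op = dict(zip(names, block[:11]))        # bytes 0-10: one value per byte
    op["LC"] = block[11] & 0x03              # byte 11, bits 1-0
    op["RC"] = (block[11] >> 2) & 0x03       # byte 11, bits 3-2
    op["RS"] = block[12] & 0x07              # byte 12, bits 2-0
    op["DET"] = (block[12] >> 3) & 0x0F      # byte 12, bits 6-3
    op["AMS"] = block[13] & 0x03             # byte 13, bits 1-0
    op["KVS"] = (block[13] >> 2) & 0x07      # byte 13, bits 4-2
    op["OL"] = block[14]                     # byte 14
    op["M"] = block[15] & 0x01               # byte 15, bit 0
    op["FC"] = (block[15] >> 1) & 0x1F       # byte 15, bits 5-1
    op["FF"] = block[16]                     # byte 16
    return op


def unpack_voice_operators(voice: bytes) -> list:
    """Return [OP6, OP5, ..., OP1] from a 128-byte packed voice."""
    return [unpack_operator(voice[i * 17:(i + 1) * 17]) for i in range(6)]
```

For example, a byte 12 of 58 (0b0111010) decodes to DET = 7 (no detune) and RS = 2. The global bytes 102-127 would be handled the same way.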

/////////////////////////////////////////////////////////////////////

And that's it.

Hope this is useful.

ewan.


================================================
FILE: scratch/dx7_constants.py
================================================
from pathlib import Path
import bitstruct

ARTIFACTS_ROOT = Path('/content/gdrive/My Drive/audio/artifacts').expanduser()

def take(take_from, n):
    for _ in range(n):
        yield next(take_from)

N_OSC = 6
N_VOICES = 32

def checksum(data):
    # masked two's complement of the byte sum, so (sum(data) + checksum) % 128 == 0
    return -sum(data) & 0x7F

GLOBAL_VALID_RANGES = {
    'PR1':  range(0, 99+1),
    'PR2':  range(0, 99+1),
    'PR3':  range(0, 99+1),
    'PR4':  range(0, 99+1),
    'PL1':  range(0, 99+1),
    'PL2':  range(0, 99+1),
    'PL3':  range(0, 99+1),
    'PL4':  range(0, 99+1),
    'ALG':  range(0, 31+1),
    'OKS':  range(0, 1+1),
    'FB':   range(0, 7+1),
    'LFS':  range(0, 99+1),
    'LFD':  range(0, 99+1),
    'LPMD':  range(0, 99+1),
    'LAMD':  range(0, 99+1),
    'LPMS': range(0, 7+1),
    'LFW':  range(0, 5+1),
    'LKS':  range(0, 1+1),
    'TRNSP':  range(0, 48+1),
    'NAME CHAR 1': range(128),
    'NAME CHAR 2': range(128),
    'NAME CHAR 3': range(128),
    'NAME CHAR 4': range(128),
    'NAME CHAR 5': range(128),
    'NAME CHAR 6': range(128),
    'NAME CHAR 7': range(128),
    'NAME CHAR 8': range(128),
    'NAME CHAR 9': range(128),
    'NAME CHAR 10': range(128),
 }

OSCILLATOR_VALID_RANGES = {
    'R1':  range(0, 99+1),
    'R2':  range(0, 99+1),
    'R3':  range(0, 99+1),
    'R4':  range(0, 99+1),
    'L1':  range(0, 99+1),
    'L2':  range(0, 99+1),
    'L3':  range(0, 99+1),
    'L4':  range(0, 99+1),
    'BP':  range(0, 99+1),
    'LD':  range(0, 99+1),
    'RD':  range(0, 99+1),
    'RC':  range(0, 3+1),
    'LC':  range(0, 3+1),
    'DET': range(0, 14+1),
    'RS':  range(0, 7+1),
    'KVS': range(0, 7+1),
    'AMS': range(0, 3+1),
    'OL':  range(0, 99+1),
    'FC':  range(0, 31+1),
    'M':   range(0, 1+1),
    'FF':  range(0, 99+1),
}

VOICE_PARAMETER_RANGES = {f'{i}_{key}': value for key, value in OSCILLATOR_VALID_RANGES.items() for i in range(N_OSC)}
VOICE_PARAMETER_RANGES.update(GLOBAL_VALID_RANGES)

def verify(actual, ranges):
    assert set(actual.keys())==set(ranges.keys()), 'Params dont match'
    for key in actual:
        if not actual[key] in ranges[key]:
            return False
    return True


HEADER_KEYS = [
    'ID',
    'Sub-status',
    'format number',
    'byte count MS',
    'byte count LS',
]

GENERAL_KEYS = [
    'PR1',
    'PR2',
    'PR3',
    'PR4',
    'PL1',
    'PL2',
    'PL3',
    'PL4',
    'ALG',
    'OKS',
    'FB',
    'LFS',
    'LFD',
    'LPMD',
    'LAMD',
    'LPMS',
    'LFW',
    'LKS',
    'TRNSP',
    'NAME CHAR 1',
    'NAME CHAR 2',
    'NAME CHAR 3',
    'NAME CHAR 4',
    'NAME CHAR 5',
    'NAME CHAR 6',
    'NAME CHAR 7',
    'NAME CHAR 8',
    'NAME CHAR 9',
    'NAME CHAR 10',
]

OSC_KEYS = [
    'R1',
    'R2',
    'R3',
    'R4',
    'L1',
    'L2',
    'L3',
    'L4',
    'BP',
    'LD',
    'RD',
    'RC',
    'LC',
    'DET',
    'RS',
    'KVS',
    'AMS',
    'OL',
    'FC',
    'M',
    'FF',
]

FOOTER_KEYS = ['checksum']


VOICE_KEYS = [f'{i}_{key}' for i in range(6) for key in OSC_KEYS] + \
        GENERAL_KEYS 

KEYS =  HEADER_KEYS + \
        list(VOICE_KEYS * N_VOICES) + \
        FOOTER_KEYS



header_bytes = [
    'p1u7',             # ID # (i=67; Yamaha)
    'p1u7',             # Sub-status (s=0) & channel number (n=0; ch 1)
    'p1u7',             # format number (f=9; 32 voices)
    'p1u7',             # byte count MS byte
    'p1u7',             # byte count LS byte (b=4096; 32 voices)
]




general_parameter_bytes = [ 
    'p1u7',             # PR1
    'p1u7',             # PR2
    'p1u7',             # PR3
    'p1u7',             # PR4
    'p1u7',             # PL1
    'p1u7',             # PL2
    'p1u7',             # PL3
    'p1u7',             # PL4
    'p3u5',             # ALG
    'p4u1u3',           # OKS|    FB
    'p1u7',             # LFS
    'p1u7',             # LFD
    'p1u7',             # LPMD
    'p1u7',             # LAMD
    'p1u3u3u1',         # LPMS |      LFW      |LKS
    'p1u7',             # TRNSP
    'p1u7',             # NAME CHAR 1
    'p1u7',             # NAME CHAR 2
    'p1u7',             # NAME CHAR 3
    'p1u7',             # NAME CHAR 4
    'p1u7',             # NAME CHAR 5
    'p1u7',             # NAME CHAR 6
    'p1u7',             # NAME CHAR 7
    'p1u7',             # NAME CHAR 8
    'p1u7',             # NAME CHAR 9
    'p1u7',             # NAME CHAR 10
]

osc_parameter_bytes = [
    'p1u7',         # R1
    'p1u7',         # R2
    'p1u7',         # R3
    'p1u7',         # R4
    'p1u7',         # L1
    'p1u7',         # L2
    'p1u7',         # L3
    'p1u7',         # L4
    'p1u7',         # BP
    'p1u7',         # LD
    'p1u7',         # RD
    'p4u2u2',       # RC | LC 
    'p1u4u3',       # DET | RS
    'p3u3u2',       # KVS | AMS
    'p1u7',         # OL
    'p2u5u1',       # FC | M
    'p1u7'          # FF
]

voice_bytes = (osc_parameter_bytes * N_OSC) + general_parameter_bytes

tail_bytes = [
    'p1u7',         # checksum
]


full_string = ''.join(header_bytes + voice_bytes)
dx7_struct = bitstruct.compile(full_string)

voice_struct = bitstruct.compile(''.join(voice_bytes), names=VOICE_KEYS)
header_struct = bitstruct.compile(''.join(header_bytes))

if __name__=="__main__":
    print(VOICE_KEYS)

================================================
FILE: scratch/dx7_syx.py
================================================
#%%
import bitstruct
import mido
from pathlib import Path
from itertools import chain

from dx7_constants import voice_struct, verify, VOICE_PARAMETER_RANGES, header_struct,\
    header_bytes, voice_bytes, take, VOICE_KEYS, ARTIFACTS_ROOT, N_VOICES, N_OSC

# %%


def consume_syx(path):

    path = Path(path).expanduser()
    try:
        preset = mido.read_syx_file(path.as_posix())[0]
    except (IndexError, ValueError):
        # empty or malformed .syx file
        return None
    if len(preset.data) == 0:
        return None

    def get_voice(data):
        
        unpacked = voice_struct.unpack(data)

        if not verify(unpacked, VOICE_PARAMETER_RANGES):
            return None
        
        return unpacked

    get_header = header_struct.unpack
    sysex_iter = iter(preset.data)
    
    try:
        header = get_header(bytes(take(sysex_iter, len(header_bytes))))
        yield from (get_voice(bytes(take(sysex_iter, len(voice_bytes)))) for _ in range(N_VOICES))
    except RuntimeError:
        return None
#%%
from tqdm import tqdm as tqdm
from functools import reduce
import numpy as np
if __name__=='__main__':
    DEV = False
    dataset_file = 'dx7.npy'
# PRESETS_ROOT = '~/audio/artifacts/dx7-patches'
# PRESETS_ROOT = '~/audio/artifacts/dx7-patches-dev'
    preset_root = ARTIFACTS_ROOT.joinpath('dx7-patches').expanduser()

    preset_paths = iter(tqdm(sorted(preset_root.glob('**/*.syx'))))
    if DEV:
        dataset_file = f'dev-{dataset_file}'
        preset_paths = take(preset_paths, 10)
    preset_paths = iter(filter(lambda preset_path: preset_path.is_file(), preset_paths))

    consume_chain = chain.from_iterable(map(consume_syx, preset_paths))
    consume_chain = filter(lambda x: x is not None, consume_chain)

    arr_dtype = list(zip(VOICE_KEYS, ['u8']*len(VOICE_KEYS)))
    to_arr = lambda voice_dict: np.array([tuple(voice_dict.values())], dtype=arr_dtype)
    arr = np.concatenate(list(map(to_arr, consume_chain)))
    arr = np.unique(arr)
    np.save(ARTIFACTS_ROOT.joinpath(dataset_file).as_posix(), arr)
    # print(to_arr(next(iter(consume_chain))))
    # print(type(consume_chain))
    # 1/0
    # # for i in list(map(list, outputs)):
    # #     print(len(i))
    # paths, ns, data = zip(*outputs)
    # data = np.array(list(set(data)))
    # paths = np.array([path.as_posix() for path in paths])
    # ns = np.array(ns)
    # np.savez(ARTIFACTS_ROOT.joinpath('dx7.npy'), outputs, paths, ns)
    # # patch_map = dict(tqdm((iter(consume_chain))))
    # # random_key, *_ = patch_map.keys()
    # # # print(global_config['NAME'])

    # # name = items.pop('')
    # # print(global_config)
    # # print(json.dumps(oscillator_configs, indent=4))
    # # print(list(sysex_iter))





# %%


# %%


================================================
FILE: scratch/fm-param-analysis.py
================================================
#%%
import torch
from torch.utils.data import Subset as DataSubset
from sklearn.model_selection import train_test_split
from fm_param_vae import Net, DX7Dataset, ARTIFACTS_ROOT, VOICE_KEYS
from dx7_constants import VOICE_PARAMETER_RANGES

import numpy as np

dataset = DX7Dataset()
train_idxs, test_idxs = train_test_split(range(len(dataset)), random_state=42)
train_dataset = DataSubset(dataset, train_idxs)
test_dataset = DataSubset(dataset, test_idxs)

model = Net()
model.load_state_dict(torch.load(ARTIFACTS_ROOT.joinpath('fm-param-vae-8.pt')))


# %%
with torch.no_grad():
    data = torch.stack([test_dataset[i] for i in range(len(test_dataset))])
    results, *_ = model(data)
# %%

correct = (results.argmax(-1) == data)
# %%

p_z = torch.distributions.Normal(0, 1)
z = p_z.sample((32, 8))

x_hat = model.generate(z, 0.3)


# %%


================================================
FILE: scratch/fm_param_ae.py
================================================
#%%
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from pathlib import Path
from torch.utils.data import Subset as DataSubset
from sklearn.model_selection import train_test_split

from dx7_constants import VOICE_PARAMETER_RANGES, ARTIFACTS_ROOT, VOICE_KEYS
import numpy as np
N_PARAMS = len(VOICE_PARAMETER_RANGES)
MAX_VALUE = max([max(i) for i in VOICE_PARAMETER_RANGES.values()]) + 1
#%%

# class DataHandler()
#     def __init__(self, data_file, root=ARTIFACTS_ROOT):

#         if not isinstance(root, Path):
#             root = Path(root).expanduser()

#         data = np.load(ARTIFACTS_ROOT.joinpath(patch_file))






class DX7Dataset():
    

    def __init__(self, data_file='dx7.npy', root=ARTIFACTS_ROOT):

        if not isinstance(root, Path):
            root = Path(root).expanduser()

        self.data = np.load(root.joinpath(data_file)) 
        

    def __getitem__(self, index):

        item = torch.tensor(self.data[index].item()).long()

        return item
    def __len__(self):
        return len(self.data)


#%%
class Net(nn.Module):
    def __init__(self, latent_dim=16, n_params=N_PARAMS, max_value=MAX_VALUE):
        super(Net, self).__init__()

        self.n_params = n_params
        self.max_value = max_value

        self.embedder = nn.Embedding(max_value, 8)

        self.enc = nn.Sequential(
            nn.Linear(8*n_params, 512),
            nn.GELU(),
            nn.Dropout(0.4),
            nn.Linear(512, latent_dim),
        )

        self.dec = nn.Sequential(
            nn.Linear(latent_dim, 512),
            nn.GELU(),
            nn.Dropout(0.4),
            nn.Linear(512, max_value*n_params),
        )

        self.register_buffer('mask', self.generate_mask(max_value))

    @staticmethod
    def generate_mask(max_value=MAX_VALUE):
        # True where a token value is within the valid range of each parameter
        mask_item_f = lambda x: torch.arange(max_value) <= max(x)
        mapper = map(mask_item_f, map(VOICE_PARAMETER_RANGES.get, VOICE_KEYS))

        return torch.stack(list(mapper))

    def forward(self, x):
        
        x = self.embedder(x)
        x = x.flatten(-2, -1)
        z = self.enc(x)

        x_hat = self.dec(z)
        x_hat = x_hat.reshape(-1, self.n_params, self.max_value)

        x_hat = torch.masked_fill(x_hat, ~self.mask, -1e9)
        return x_hat

#%%
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output.transpose(-1,-2), data)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output.transpose(-1,-2), data, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=-1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(data.view_as(pred)).sum().item() / N_PARAMS  # fraction of parameters predicted correctly

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

if __name__=="__main__":
    # Training settings
    use_cuda = False
    batch_size = 32
    lr = 1
    gamma = 0.7
    epochs = 100

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    
    dataset = DX7Dataset()
    train_idxs, test_idxs = train_test_split(range(len(dataset)), random_state=42)
    train_dataset = DataSubset(dataset, train_idxs)
    test_dataset = DataSubset(dataset, test_idxs)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size, shuffle=True, **kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()

    # if args.save_model:
    #     torch.save(model.state_dict(), "mnist_cnn.pt")


# if __name__ == '__main__':
#     main()

# %%


# %%


# %%


# %%


================================================
FILE: scratch/fm_param_agoge_vae_rnn.py
================================================
#%%
import os
os.environ['MLFLOW_TRACKING_URI'] = 'http://localhost:9001'
import torch
torch.randn(10,10).cuda()
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
from pathlib import Path
from torch.utils.data import Subset as DataSubset
from ray import tune
# from sklearn.model_selection import train_test_split
import numpy as np
from dx7_constants import VOICE_PARAMETER_RANGES, ARTIFACTS_ROOT, VOICE_KEYS

from ray.tune.schedulers import HyperBandForBOHB
from ray.tune.suggest.bohb import TuneBOHB
import ConfigSpace as CS

from agoge import AbstractModel, AbstractSolver, Worker
from agoge.utils import uuid, trial_name_creator, experiment_name_creator, get_logger
from itertools import starmap

import numpy as np
N_PARAMS = len(VOICE_PARAMETER_RANGES)
MAX_VALUE = max([max(i) for i in VOICE_PARAMETER_RANGES.values()]) + 1
#%%

# class DataHandler()
#     def __init__(self, data_file, root=ARTIFACTS_ROOT):

#         if not isinstance(root, Path):
#             root = Path(root).expanduser()

#         data = np.load(ARTIFACTS_ROOT.joinpath(patch_file))



class DX7Dataset():
    

    def __init__(self, data_file='dx7.npy', root=ARTIFACTS_ROOT, data_size=1.):

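        # fraction of the dataset to expose; the default of 1. uses all of it
        assert 0 < data_size <= 1, 'data_size must be a fraction in (0, 1]'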

        self.data_size = data_size

        if not isinstance(root, Path):
            root = Path(root).expanduser()

        self.data = np.load(root.joinpath(data_file)) 

    def __getitem__(self, index):

        item = torch.tensor(self.data[index].item()).long()

        return {'x': item}
    
    def __len__(self):
        return int(len(self.data) * self.data_size)


#%%

#%%
class DX7RecurrentVAE(AbstractModel):
    def __init__(self, latent_dim=8, n_params=N_PARAMS, max_value=MAX_VALUE, hidden_dim=128, params_ordering=None):
        super().__init__()

        self.n_params = n_params
        self.max_value = max_value

        self.embedder = nn.Embedding(max_value, hidden_dim)

        self.enc = nn.ModuleList(
            [nn.LSTM(hidden_dim, hidden_dim, batch_first=True, bidirectional=True),
            nn.GELU(),
            nn.Dropout(0.4),
            nn.LSTM(hidden_dim, hidden_dim, batch_first=True, bidirectional=True),]
        )

        self.q_z = nn.Linear(hidden_dim, 2*latent_dim)
        self.z2x = nn.Linear(hidden_dim+latent_dim, hidden_dim)
        self.logits = nn.Linear(hidden_dim, max_value)

        self.dec = nn.ModuleList(
            [nn.LSTM(hidden_dim, hidden_dim, batch_first=True),
            nn.GELU(),
            nn.Dropout(0.4),
            nn.LSTM(hidden_dim, hidden_dim, batch_first=True),
            ]
        )

        self.ordering = params_ordering
        if params_ordering is not None:
            assert len(params_ordering) == len(VOICE_KEYS), 'more or less params than expected'
            self.ordering = np.argsort(params_ordering)
            self.reverse_ordering = np.argsort(self.ordering)


        self.register_buffer('mask', self.generate_mask(self.ordering))


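    def network(self, x, network):
        # Both LSTMs read the same input x in parallel; lstm2 does not consume
        # the output of lstm.
        lstm, gelu, drop, lstm2 = network

        x_1, (h_1, _) = lstm(x)
        if lstm.bidirectional == True:
            # Encoder: average the final forward/backward hidden states,
            # giving one (batch, hidden_dim) summary of the whole patch.
            x_1 = h_1.mean(0)
        x_1 = drop(gelu(x_1))

        x_2, (h_2, _) = lstm2(x)

        if lstm2.bidirectional == True:
            x_2 = h_2.mean(0)
            # Encoder: neutralise the gate below, so the result is x_1 + x_2.
            x = torch.ones_like(x_2)

        x_2 = drop(x_2)

        # Decoder: the per-step output x_1 gates the input, then x_2 is added.
        x = x_1 * x + x_2

        return x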

    @staticmethod
    def generate_mask(ordering=None):
        """
        ordering the index ordering of the parameters based on the index in the dx7_constants.VOICE_KEYS
        """
        
        mask_item_f = lambda x: torch.arange(MAX_VALUE) <= max(x) 
        mapper = map(mask_item_f, map(VOICE_PARAMETER_RANGES.get, VOICE_KEYS))

        mask = torch.stack(list(mapper))

        if ordering is not None:
            return mask[ordering]
        return mask

    def forward(self, x):
        if self.ordering is not None:
            x = x[:, self.ordering]

        x = self.embedder(x)
        theta_z = self.network(x, self.enc)

        q_z_mu, q_z_std = self.q_z(theta_z).chunk(2, -1)

        q_z = torch.distributions.Normal(q_z_mu, (0.5*q_z_std.clamp(-5, 3)).exp())

        z = q_z.rsample()  # reparameterised so reconstruction gradients reach the encoder
        z_in = z.unsqueeze(-2) + torch.zeros_like(x[...,0]).unsqueeze(-1)

        # x_endcut = x
        x_prepad = torch.cat([torch.zeros_like(x[:,[0]]), x], dim=-2)
        x_endcut = x_prepad[:,:-1]
        
        x_dec_in = torch.cat([x_endcut, z_in], dim=-1)
        x_dec_in = self.z2x(x_dec_in)
        x_hat = self.network(x_dec_in, self.dec)

        x_hat = self.logits(x_hat)

        x_hat = torch.masked_fill(x_hat, ~self.mask, -1e9)

        if self.ordering is not None:
            x_hat = x_hat[:, self.reverse_ordering]

        return x_hat, q_z, z

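    @torch.no_grad()
    def generate(self, z, t=1.):
        # self.dec is a ModuleList (not callable), so sample autoregressively:
        # parameter i is drawn from logits conditioned on z and parameters < i.
        # Call model.eval() first so dropout is disabled. Returns sampled values.
        x = torch.zeros(z.shape[0], self.n_params, dtype=torch.long, device=z.device)
        z_in = z.unsqueeze(-2).expand(-1, self.n_params, -1)
        for i in range(self.n_params):
            emb = self.embedder(x)
            x_prev = torch.cat([torch.zeros_like(emb[:, [0]]), emb[:, :-1]], dim=-2)
            h = self.network(self.z2x(torch.cat([x_prev, z_in], dim=-1)), self.dec)
            logits = self.logits(h)[:, i].masked_fill(~self.mask[i], -float('inf'))
            x[:, i] = torch.distributions.Categorical(logits=logits / t).sample()
        if self.ordering is not None:
            # samples are in the model's internal order; undo the permutation
            x = x[:, self.reverse_ordering]
        return x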
#%%
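DX7RecurrentVAE.forward builds the decoder input by shifting the embedded parameters one step to the right (teacher forcing) and repeating the per-patch latent z at every step before concatenating the two. A minimal NumPy sketch of just those two tensor manipulations, with made-up sizes, shows the resulting shapes; `shift_right` and `broadcast_latent` are hypothetical helper names, not part of the repository.

```python
import numpy as np


def shift_right(emb):
    """Prepend a zero step and drop the last one, so the decoder input at
    position i is the embedding of parameter i - 1 (zeros at position 0)."""
    pad = np.zeros_like(emb[:, :1])
    return np.concatenate([pad, emb], axis=1)[:, :-1]


def broadcast_latent(z, n_steps):
    """Repeat a per-patch latent of shape (batch, latent) at every step,
    giving (batch, n_steps, latent), like z.unsqueeze(-2) + zeros(..., 1)."""
    return z[:, None, :] + np.zeros((z.shape[0], n_steps, 1))


batch, n_params, hidden, latent = 2, 5, 3, 4
emb = np.arange(batch * n_params * hidden, dtype=float).reshape(batch, n_params, hidden)
z = np.ones((batch, latent))

dec_in = np.concatenate([shift_right(emb), broadcast_latent(z, n_params)], axis=-1)
print(dec_in.shape)  # (2, 5, 7)
```

In the model, the concatenated tensor then passes through `z2x` to return to `hidden_dim` before the decoder LSTMs.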

class DX7RecurrentVAESolver(AbstractSolver):

    def __init__(self, model,
        Optim=AdamW, optim_opts=dict(lr= 1e-4),
        max_beta=0.5,
        **kwargs):

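        if isinstance(Optim, str):
            # e.g. 'torch.optim.AdamW' -> the AdamW class
            from importlib import import_module
            module_name, _, class_name = Optim.rpartition('.')
            Optim = getattr(import_module(module_name), class_name)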


        self.optim = Optim(params=model.parameters(), **optim_opts)
        self.schedule = self.scheduler()
        self.max_beta = max_beta
        self.model = model
        self._beta = self.schedule()

    @property
    def beta(self):

        return self._beta * self.max_beta

    @staticmethod
    def scheduler():
        n_steps  = 0
        beta_steps = 37400

        def schedule():
            nonlocal n_steps
            n_steps += 1

            step = (n_steps)/beta_steps
            step = min(1, step)

            return 0.5 * (1 + np.sin((step*np.pi)-(np.pi/2)))
        return schedule
      
    def loss(self, x, x_hat, q_z, z):

        reconstruction_loss = F.cross_entropy(x_hat.transpose(-1,-2), x)

        p_z = torch.distributions.Normal(0, 1)

        log_q_z = q_z.log_prob(z)
        log_p_z = p_z.log_prob(z)

        kl = (log_q_z - log_p_z).mean()

        kl_tempered = kl * self.beta
        
        loss = reconstruction_loss + kl_tempered


        accuracy = (x_hat.argmax(-1)==x).float().mean()

        return loss, {
            'log_q_z': log_q_z.mean(),
            'log_p_z': log_p_z.mean(),
            'kl': kl,
            'reconstruction_loss': reconstruction_loss,
            'beta': self.beta,
            'accuracy': accuracy
        }
        

    def solve(self, x, **kwargs):
        
        x_hat, q_z, z  = self.model(x)
        loss, L = self.loss(x, x_hat, q_z, z)

        if torch.isnan(loss):
            raise ValueError('NaN values detected in the loss')

        if self.model.training:

            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

            self._beta = self.schedule()
        
        return L

    
    def step(self):

        pass


    def state_dict(self):
        
        state_dict = {
            'optim': self.optim.state_dict()
        }

        return state_dict

    def load_state_dict(self, state_dict):
        
        load_component = lambda component, state: getattr(self, component).load_state_dict(state)
        list(starmap(load_component, state_dict.items()))
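DX7RecurrentVAESolver anneals the KL weight with a closure-based schedule: each call advances a step counter and returns 0.5 * (1 + sin(pi*s - pi/2)), which equals 0.5 * (1 - cos(pi*s)), a half-cosine ramp from 0 to 1 over `beta_steps` calls (37400 here); the effective weight is that value times `max_beta`. The sketch below restates the schedule as a standalone function. `make_beta_schedule` and its `warmup` argument are hypothetical; `warmup=1000` is meant to mirror the variant in fm_param_vae_rnn.py that holds the weight at zero for its first steps.

```python
import math


def make_beta_schedule(beta_steps, warmup=0):
    """Hypothetical standalone version of the solver's KL-weight schedule.

    Each call advances a step counter n and returns
    0.5 * (1 + sin(pi*s - pi/2)) == 0.5 * (1 - cos(pi*s)),
    with s = min(1, (n - warmup) / beta_steps). Calls with n < warmup return 0.
    """
    n_steps = 0

    def schedule():
        nonlocal n_steps
        n_steps += 1
        if n_steps < warmup:
            return 0.0
        s = min(1.0, (n_steps - warmup) / beta_steps)
        return 0.5 * (1 - math.cos(math.pi * s))

    return schedule


schedule = make_beta_schedule(beta_steps=4)
print([round(schedule(), 3) for _ in range(6)])  # [0.146, 0.5, 0.854, 1.0, 1.0, 1.0]
```

The ramp starts and ends with zero slope, so the KL weight changes slowly at both ends and fastest halfway through.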



def config(experiment_name, trial_name, batch_size=16, **kwargs):
    
    voice_params = {key.split('..')[-1]: value for key, value in kwargs.items() if 'VOICE..' in key}
    params_ordering = list(map(voice_params.get, VOICE_KEYS))


    data_handler = {
        'Dataset': DX7Dataset,
        'dataset_opts': {
            'data_size': 0.2
        },
        'loader_opts': {
            'batch_size': batch_size,
        },
    }

    model = {
        'Model': DX7RecurrentVAE,
        'params_ordering': params_ordering
        # 'conv1': (1, 32, 3, 1)
    }

    solver = {
        'Solver': DX7RecurrentVAESolver
    }

    tracker = {
        'metrics': ['reconstruction_loss', 'log_q_z', 'log_p_z', 'kl', 'beta', 'accuracy'],
        'experiment_name': experiment_name,
        'trial_name': trial_name
    }

    return {
        'data_handler': data_handler,
        'model': model,
        'solver': solver,
        'tracker': tracker,
    }
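`config` turns BOHB's continuous `VOICE..<key>` samples into `params_ordering`, and DX7RecurrentVAE argsorts those scores into a permutation (a random-key style encoding of orderings). A second argsort gives the inverse permutation, which restores the original parameter order on the way out. The toy example below uses four made-up scores in place of one per voice parameter.

```python
import numpy as np

# Made-up ordering scores, one per parameter, standing in for the BOHB samples
# of the 'VOICE..<key>' hyperparameters.
scores = np.array([0.9, 0.1, 0.5, 0.3])

ordering = np.argsort(scores)            # internal position -> original index: [1 3 2 0]
reverse_ordering = np.argsort(ordering)  # inverse permutation: [3 0 2 1]

x = np.array([[10, 11, 12, 13]])          # one patch, parameters in VOICE_KEYS order
x_internal = x[:, ordering]               # order the model's LSTMs see: [[11 13 12 10]]
restored = x_internal[:, reverse_ordering]
print(restored)                           # [[10 11 12 13]]
```

Because any score vector maps to a valid permutation, the search space stays a simple box of floats in [0, 1].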

from mlflow.tracking import MlflowClient
if __name__=='__main__':
    # from ray import ray
    import sys
    postfix = sys.argv[1]
    # ray.init()
    # from ray.tune.utils import validate_save_restore
    # validate_save_restore(Worker)
    # client = MlflowClient(tracking_uri='localhost:5000')
    experiment_name = f'dx7-vae-{postfix}'#+experiment_name_creator()
    # experiment_id = client.create_experiment(experiment_name)


    experiment_metrics = dict(metric="loss/accuracy", mode="max")

    config_space = CS.ConfigurationSpace()
    # one ordering score in [0, 1] per voice parameter; see DX7RecurrentVAE
    for key in VOICE_KEYS:
        config_space.add_hyperparameter(
            CS.UniformFloatHyperparameter(f'VOICE..{key}', lower=0., upper=1))
    bohb_hyperband = HyperBandForBOHB(
        time_attr="training_iteration", max_t=16, **experiment_metrics)
    bohb_search = TuneBOHB(
        config_space, max_concurrent=1, **experiment_metrics)


    tune.run(
        Worker,
        config={
            'config_generator': config,
            'experiment_name': experiment_name,
            'points_per_epoch': 2,
        },
        trial_name_creator=trial_name_creator,
        resources_per_trial={'gpu': 1},
        checkpoint_freq=2,
        checkpoint_at_end=True,
        keep_checkpoints_num=1,
        search_alg=bohb_search,
        scheduler=bohb_hyperband,
        num_samples=4096,
        verbose=0,
        local_dir='~/ray_results',
        # webui_host='127.0.0.1',  # suppresses an error
        # stop={'loss/loss': 0},
    )
# points_per_epoch
# %%

# #################################################################################
  

# if __name__=="__main__":
#     # Training settings
#     use_cuda = True
#     batch_size = 32
#     lr = 1e-4
#     gamma = 1.
#     epochs = 100
#     beta = 0.5
#     beta_steps = 37400
    
#     schedule = scheduler()
#     device = torch.device("cuda" if use_cuda else "cpu")

#     kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    
#     dataset = DX7Dataset()
#     train_idxs, test_idxs = train_test_split(range(len(dataset)), random_state=42)
#     train_dataset = DataSubset(dataset, train_idxs)
#     test_dataset = DataSubset(dataset, test_idxs)

#     train_loader = torch.utils.data.DataLoader(
#         train_dataset,
#         batch_size=batch_size, shuffle=True, **kwargs)
#     test_loader = torch.utils.data.DataLoader(
#         test_dataset,
#         batch_size=batch_size, shuffle=True, **kwargs)

#     model = Net().to(device)
#     optimizer = optim.AdamW(model.parameters(), lr=lr)

#     scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
#     for epoch in range(1, epochs + 1):
#         train(model, device, train_loader, optimizer, epoch)
#         test(model, device, test_loader)
#         # scheduler.step()

#     # if args.save_model:
#     #     torch.save(model.state_dict(), "mnist_cnn.pt")

#     torch.save(model.state_dict(), ARTIFACTS_ROOT.joinpath('fm-param-vae-8.pt'))
# # if __name__ == '__main__':
# #     main()

# # %%


# # %%


# # %%


# # %%


================================================
FILE: scratch/fm_param_rnn_decoder.py
================================================
#%%
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from pathlib import Path
from torch.utils.data import Subset as DataSubset
from sklearn.model_selection import train_test_split
import numpy as np
from dx7_constants import VOICE_PARAMETER_RANGES, ARTIFACTS_ROOT, VOICE_KEYS
import numpy as np
N_PARAMS = len(VOICE_PARAMETER_RANGES)
MAX_VALUE = max([max(i) for i in VOICE_PARAMETER_RANGES.values()]) + 1
#%%

# class DataHandler()
#     def __init__(self, data_file, root=ARTIFACTS_ROOT):

#         if not isinstance(root, Path):
#             root = Path(root).expanduser()

#         data = np.load(ARTIFACTS_ROOT.joinpath(patch_file))



class DX7Dataset():
    

    def __init__(self, data_file='dx7.npy', root=ARTIFACTS_ROOT):

        if not isinstance(root, Path):
            root = Path(root).expanduser()

        self.data = np.load(root.joinpath(data_file)) 
        

    def __getitem__(self, index):

        item = torch.tensor(self.data[index].item()).long()

        return item
    def __len__(self):
        return len(self.data)


#%%

#%%
class Net(nn.Module):
    def __init__(self, latent_dim=8, n_params=N_PARAMS, max_value=MAX_VALUE):
        super(Net, self).__init__()

        self.n_params = n_params
        self.max_value = max_value

        self.embedder = nn.Embedding(max_value, 128)

        self.dec = nn.ModuleList(
            [nn.LSTM(128, 128, batch_first=True),
            nn.GELU(),
            nn.Dropout(0.4),
            nn.LSTM(128, 128, batch_first=True),]
        )

        self.register_buffer('mask', self.generate_mask())

    def network(self, x, network):

        lstm, gelu, drop, lstm2 = network

        x_1, _ = lstm(x)
        if lstm.bidirectional == True:
            x_1 = torch.stack(x_1.chunk(2, -1)).sum(0)

        x_in = gelu(drop(x_1)) + x
        x_2, _ = lstm2(x)

        if lstm2.bidirectional == True:
            x_2 = torch.stack(x_2.chunk(2, -1)).sum(0)

        x = gelu(drop(x_2)) * x_in

        return x

    @staticmethod
    def generate_mask():
        
        mask_item_f = lambda x: torch.arange(MAX_VALUE) <= max(x) 
        mapper = map(mask_item_f, map(VOICE_PARAMETER_RANGES.get, VOICE_KEYS))

        return torch.stack(list(mapper))

    def forward(self, x):
        
        x = self.embedder(x)
        x_z_sub_1 = torch.cat([x[:,[0]]*0, x[:,:-1]], dim=-2)


        x_hat = self.network(x_z_sub_1, self.dec)

        x_hat = torch.masked_fill(x_hat, ~self.mask, -1e9)
        return x_hat

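    @torch.no_grad()
    def generate(self, n, t=1.):
        # This decoder has no latent and self.dec is a ModuleList (not callable),
        # so sample n patches autoregressively. forward() shifts its input right,
        # so position i of its output depends only on x[:, :i].
        x = torch.zeros(n, self.n_params, dtype=torch.long, device=self.mask.device)
        for i in range(self.n_params):
            logits = self.forward(x)[:, i].masked_fill(~self.mask[i], -float('inf'))
            x[:, i] = torch.distributions.Categorical(logits=logits / t).sample()
        return x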
#%%
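Net.forward in this file shifts the embedded sequence right by one step before two unidirectional LSTMs, so the output at position i can only depend on parameters before i; that is what makes it usable as an autoregressive decoder. A small check of that property, using a toy embedding and a single LSTM of assumed sizes rather than the full Net, perturbs one parameter and confirms that earlier positions are unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab, n_params, hidden = 16, 6, 8  # toy sizes, not the real model's
embed = nn.Embedding(vocab, hidden)
lstm = nn.LSTM(hidden, hidden, batch_first=True)


def shifted_outputs(x):
    """Toy stand-in for Net.forward: embed, shift right one step, then run a
    unidirectional LSTM. Output position i should depend only on x[:, :i]."""
    e = embed(x)
    e = torch.cat([torch.zeros_like(e[:, :1]), e[:, :-1]], dim=1)
    out, _ = lstm(e)
    return out


x = torch.randint(0, vocab, (1, n_params))
k = 3
x_changed = x.clone()
x_changed[:, k] = (x[:, k] + 1) % vocab  # perturb parameter k only

with torch.no_grad():
    a, b = shifted_outputs(x), shifted_outputs(x_changed)

# Positions 0..k never see x[:, k], so they should match.
print(torch.allclose(a[:, :k + 1], b[:, :k + 1]))  # True
```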
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output.transpose(-1,-2), data)

        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            output = model(data)
            loss = F.cross_entropy(output.transpose(-1,-2), data)
            
            test_loss += loss
            pred = output.argmax(dim=-1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(data.view_as(pred)).sum().item() / data.size(-1)  # per-parameter accuracy
    
    test_loss /= len(test_loader)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

if __name__=="__main__":
    # Training settings
    use_cuda = True
    batch_size = 32
    lr = 1e-4
    gamma = 1.
    epochs = 100
    beta = 0.5
    beta_steps = 37400
    def scheduler():
        n_steps  = 0
        def schedule():
            nonlocal n_steps
            n_steps += 1

            if n_steps < 1000:

                return 0

            step = (n_steps-1000)/beta_steps
            step = min(1, step)

            return 0.5 * (1 + np.sin((step*np.pi)-(np.pi/2)))
        return schedule
    schedule = scheduler()
    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    
    dataset = DX7Dataset()
    train_idxs, test_idxs = train_test_split(range(len(dataset)), random_state=42)
    train_dataset = DataSubset(dataset, train_idxs)
    test_dataset = DataSubset(dataset, test_idxs)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size, shuffle=True, **kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        # scheduler.step()

    # if args.save_model:
    #     torch.save(model.state_dict(), "mnist_cnn.pt")

    torch.save(model.state_dict(), ARTIFACTS_ROOT.joinpath('fm-param-vae-8.pt'))
# if __name__ == '__main__':
#     main()

# %%


# %%


# %%


# %%


================================================
FILE: scratch/fm_param_vae.py
================================================
#%%
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from pathlib import Path
from torch.utils.data import Subset as DataSubset
from sklearn.model_selection import train_test_split

from dx7_constants import VOICE_PARAMETER_RANGES, ARTIFACTS_ROOT, VOICE_KEYS
import numpy as np
N_PARAMS = len(VOICE_PARAMETER_RANGES)
MAX_VALUE = max([max(i) for i in VOICE_PARAMETER_RANGES.values()]) + 1
#%%

# class DataHandler()
#     def __init__(self, data_file, root=ARTIFACTS_ROOT):

#         if not isinstance(root, Path):
#             root = Path(root).expanduser()

#         data = np.load(ARTIFACTS_ROOT.joinpath(patch_file))



class DX7Dataset():
    

    def __init__(self, data_file='dx7.npy', root=ARTIFACTS_ROOT):

        if not isinstance(root, Path):
            root = Path(root).expanduser()

        self.data = np.load(root.joinpath(data_file)) 
        

    def __getitem__(self, index):

        item = torch.tensor(self.data[index].item()).long()

        return item
    def __len__(self):
        return len(self.data)


#%%

#%%
class Net(nn.Module):
    def __init__(self, latent_dim=8, n_params=N_PARAMS, max_value=MAX_VALUE):
        super(Net, self).__init__()

        self.n_params = n_params
        self.max_value = max_value

        self.embedder = nn.Embedding(max_value, 8)

        self.enc = nn.Sequential(
            nn.Linear(8*n_params, 512),
            nn.GELU(),
            nn.Dropout(0.4),
            nn.Linear(512, latent_dim*2),
        )

        self.dec = nn.Sequential(
            nn.Linear(latent_dim, 512),
            nn.GELU(),
            nn.Dropout(0.4),
            nn.Linear(512, max_value*n_params),
        )

        self.register_buffer('mask', self.generate_mask())

    @staticmethod
    def generate_mask():
        
        mask_item_f = lambda x: torch.arange(MAX_VALUE) <= max(x) 
        mapper = map(mask_item_f, map(VOICE_PARAMETER_RANGES.get, VOICE_KEYS))

        return torch.stack(list(mapper))

    def forward(self, x):
        
        x = self.embedder(x)
        x = x.flatten(-2, -1)
        q_z_mu, q_z_std = self.enc(x).chunk(2, -1)

        q_z = torch.distributions.Normal(q_z_mu, q_z_std.clamp(-3, 2).exp())

        z = q_z.rsample()  # reparameterised so reconstruction gradients reach the encoder

        x_hat = self.dec(z)
        x_hat = x_hat.reshape(-1, self.n_params, self.max_value)

        x_hat = torch.masked_fill(x_hat, ~self.mask, -1e9)
        return x_hat, q_z, z

    def generate(self, z, t=1.):

        x_hat = self.dec(z)
        x_hat = x_hat.reshape(-1, self.n_params, self.max_value)
        x_hat = torch.masked_fill(x_hat, ~self.mask, -float('inf'))

        x_hat = torch.distributions.Categorical(logits=x_hat / t)

        return x_hat
#%%
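The encoder in `forward` is only trained by the reconstruction loss if z is drawn in a differentiable way. In `torch.distributions`, `Normal.sample()` runs without gradient tracking, while `rsample()` uses the reparameterisation z = mu + std * eps and stays differentiable in mu and std. A minimal demonstration with made-up tensors:

```python
import torch

mu = torch.zeros(3, requires_grad=True)
log_std = torch.zeros(3, requires_grad=True)
q_z = torch.distributions.Normal(mu, log_std.exp())

z_sample = q_z.sample()    # drawn without gradient tracking
z_rsample = q_z.rsample()  # mu + std * eps, differentiable in mu and log_std

print(z_sample.requires_grad, z_rsample.requires_grad)  # False True

# A stand-in for a reconstruction loss: gradients reach mu and log_std only via rsample.
(z_rsample ** 2).sum().backward()
print(mu.grad is not None, log_std.grad is not None)  # True True
```

With `sample()`, the encoder would still receive gradients from the KL term through `q_z.log_prob(z)`, but not from the decoder's cross-entropy.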
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        output, q_z, z = model(data)
        loss = F.cross_entropy(output.transpose(-1,-2), data)

        p_z = torch.distributions.Normal(0, 1)

        loss = loss + (q_z.log_prob(z) - p_z.log_prob(z)).mean() * beta

        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
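`train` estimates the KL term with a single-sample Monte Carlo estimate, log q(z) - log p(z) at a z drawn from q, averaged over the batch and latent dimensions. For a diagonal Gaussian posterior and a standard normal prior the KL also has a closed form, and `torch.distributions.kl_divergence(q_z, p_z)` should compute it for two Normal distributions. The NumPy sketch below, for one latent dimension with assumed values mu = 1 and std = 0.5, checks that averaging the Monte Carlo estimator over many samples approaches the closed form.

```python
import numpy as np


def log_normal(z, mu, std):
    """Log density of N(mu, std^2) at z."""
    return -0.5 * ((z - mu) / std) ** 2 - np.log(std) - 0.5 * np.log(2 * np.pi)


def kl_to_standard_normal(mu, std):
    """Closed-form KL(N(mu, std^2) || N(0, 1)) for a single dimension."""
    return 0.5 * (mu ** 2 + std ** 2 - 1.0) - np.log(std)


rng = np.random.default_rng(0)
mu, std = 1.0, 0.5  # assumed posterior parameters for the example
z = rng.normal(mu, std, size=200_000)

# Same estimator as train(): log q(z) - log p(z) at samples z ~ q, then averaged.
mc_estimate = np.mean(log_normal(z, mu, std) - log_normal(z, 0.0, 1.0))
exact = kl_to_standard_normal(mu, std)
print(exact, mc_estimate)  # exact is 0.125 + log(2), about 0.818
```

The estimator is unbiased but noisy per sample; the closed form removes that noise when both distributions are Gaussian.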


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            output, q_z, z = model(data)
            loss = F.cross_entropy(output.transpose(-1,-2), data)
            p_z = torch.distributions.Normal(0, 1)
            loss = loss + (q_z.log_prob(z) - p_z.log_prob(z)).mean() * beta
            
            test_loss += loss
            pred = output.argmax(dim=-1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(data.view_as(pred)).sum().item() / data.size(-1)  # per-parameter accuracy
    test_loss /= len(test_loader)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

if __name__=="__main__":
    # Training settings
    use_cuda = True
    batch_size = 32
    lr = 0.01
    gamma = 0.7
    epochs = 100
    beta = 0.5

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    
    dataset = DX7Dataset()
    train_idxs, test_idxs = train_test_split(range(len(dataset)), random_state=42)
    train_dataset = DataSubset(dataset, train_idxs)
    test_dataset = DataSubset(dataset, test_idxs)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size, shuffle=True, **kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        # scheduler.step()

    # if args.save_model:
    #     torch.save(model.state_dict(), "mnist_cnn.pt")

    torch.save(model.state_dict(), ARTIFACTS_ROOT.joinpath('fm-param-vae-8.pt'))
# if __name__ == '__main__':
#     main()

# %%


# %%


# %%


# %%


================================================
FILE: scratch/fm_param_vae_rnn.py
================================================
#%%
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from pathlib import Path
from torch.utils.data import Subset as DataSubset
from sklearn.model_selection import train_test_split
import numpy as np
from dx7_constants import VOICE_PARAMETER_RANGES, ARTIFACTS_ROOT, VOICE_KEYS
import numpy as np
N_PARAMS = len(VOICE_PARAMETER_RANGES)
MAX_VALUE = max([max(i) for i in VOICE_PARAMETER_RANGES.values()]) + 1
#%%

# class DataHandler()
#     def __init__(self, data_file, root=ARTIFACTS_ROOT):

#         if not isinstance(root, Path):
#             root = Path(root).expanduser()

#         data = np.load(ARTIFACTS_ROOT.joinpath(patch_file))



class DX7Dataset():
    

    def __init__(self, data_file='dx7.npy', root=ARTIFACTS_ROOT):

        if not isinstance(root, Path):
            root = Path(root).expanduser()

        self.data = np.load(root.joinpath(data_file)) 
        

    def __getitem__(self, index):

        item = torch.tensor(self.data[index].item()).long()

        return item
    def __len__(self):
        return len(self.data)


#%%

#%%
class Net(nn.Module):
    def __init__(self, latent_dim=8, n_params=N_PARAMS, max_value=MAX_VALUE, hidden_dim=128):
        super(Net, self).__init__()

        self.n_params = n_params
        self.max_value = max_value

        self.embedder = nn.Embedding(max_value, hidden_dim)

        self.enc = nn.ModuleList(
            [nn.LSTM(hidden_dim, hidden_dim, batch_first=True, bidirectional=True),
            nn.GELU(),
            nn.Dropout(0.4),
            nn.LSTM(hidden_dim, hidden_dim, batch_first=True, bidirectional=True),]
        )

        self.q_z = nn.Linear(hidden_dim, 2*latent_dim)
        self.z2x = nn.Linear(hidden_dim+latent_dim, hidden_dim)
        self.logits = nn.Linear(hidden_dim, max_value)

        self.dec = nn.ModuleList(
            [nn.LSTM(hidden_dim, hidden_dim, batch_first=True),
            nn.GELU(),
            nn.Dropout(0.4),
            nn.LSTM(hidden_dim, hidden_dim, batch_first=True),
            ]
        )

        self.register_buffer('mask', self.generate_mask())

    def network(self, x, network):

        lstm, gelu, drop, lstm2 = network

        x_1, (h_1, _) = lstm(x)
        if lstm.bidirectional == True:
            x_1 = h_1.mean(0)
        x_1 = drop(gelu(x_1))

        x_2, (h_2, _) = lstm2(x)

        if lstm2.bidirectional == True:
            x_2 = h_2.mean(0)
            x = torch.ones_like(x_2)

        x_2 = drop(x_2)

        x = x_1 * x + x_2

        return x

    @staticmethod
    def generate_mask():
        
        mask_item_f = lambda x: torch.arange(MAX_VALUE) <= max(x) 
        mapper = map(mask_item_f, map(VOICE_PARAMETER_RANGES.get, VOICE_KEYS))

        return torch.stack(list(mapper))

    def forward(self, x):
        
        x = self.embedder(x)
        theta_z = self.network(x, self.enc)

        q_z_mu, q_z_std = self.q_z(theta_z).chunk(2, -1)

        q_z = torch.distributions.Normal(q_z_mu, (0.5*q_z_std.clamp(-5, 3)).exp())

        z = q_z.rsample()  # reparameterised so reconstruction gradients reach the encoder
        z_in = z.unsqueeze(-2) + torch.zeros_like(x[...,0]).unsqueeze(-1)

        # x_endcut = x
        x_prepad = torch.cat([torch.zeros_like(x[:,[0]]), x], dim=-2)
        x_endcut = x_prepad[:,:-1]
        
        x_dec_in = torch.cat([x_endcut, z_in], dim=-1)
        x_dec_in = self.z2x(x_dec_in)
        x_hat = self.network(x_dec_in, self.dec)

        x_hat = self.logits(x_hat)

        x_hat = torch.masked_fill(x_hat, ~self.mask, -1e9)
        return x_hat, q_z, z

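    @torch.no_grad()
    def generate(self, z, t=1.):
        # self.dec is a ModuleList (not callable), so sample autoregressively:
        # parameter i is drawn from logits conditioned on z and parameters < i.
        # Call model.eval() first so dropout is disabled. Returns sampled values.
        x = torch.zeros(z.shape[0], self.n_params, dtype=torch.long, device=z.device)
        z_in = z.unsqueeze(-2).expand(-1, self.n_params, -1)
        for i in range(self.n_params):
            emb = self.embedder(x)
            x_prev = torch.cat([torch.zeros_like(emb[:, [0]]), emb[:, :-1]], dim=-2)
            h = self.network(self.z2x(torch.cat([x_prev, z_in], dim=-1)), self.dec)
            logits = self.logits(h)[:, i].masked_fill(~self.mask[i], -float('inf'))
            x[:, i] = torch.distributions.Categorical(logits=logits / t).sample()
        return x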
#%%
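Both `forward` and `generate` restrict each parameter's distribution to its valid values by filling masked logits with a large negative number (-1e9) or -inf, and `generate` divides the logits by a temperature `t` before building the Categorical. The NumPy sketch below reproduces that masked, temperature-scaled softmax for a single hypothetical parameter whose valid range is 0..3 inside an 8-value vocabulary: invalid values get probability 0, a low temperature sharpens the distribution, and a high one flattens it.

```python
import numpy as np


def masked_softmax(logits, valid, t=1.0):
    """Softmax of logits / t where entries with valid == False get probability 0,
    like masked_fill(~mask, -inf) followed by Categorical(logits=logits / t)."""
    z = np.where(valid, logits / t, -np.inf)
    z = z - z.max(axis=-1, keepdims=True)  # subtract the max valid logit for stability
    p = np.exp(z)                           # exp(-inf) == 0 for invalid entries
    return p / p.sum(axis=-1, keepdims=True)


MAX_VALUE = 8                       # toy vocabulary size
valid = np.arange(MAX_VALUE) <= 3   # hypothetical parameter with valid range 0..3
logits = np.array([1.0, 2.0, 0.5, 0.0, 5.0, 5.0, 5.0, 5.0])

p_cold = masked_softmax(logits, valid, t=0.5)  # sharper
p_hot = masked_softmax(logits, valid, t=2.0)   # flatter
print(p_cold.round(3), p_hot.round(3))
```

The large invalid logits (5.0) have no effect once masked, which is exactly why the mask is applied before sampling.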
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        output, q_z, z = model(data)
        loss = F.cross_entropy(output.transpose(-1,-2), data)

        p_z = torch.distributions.Normal(0, 1)

        kl = (q_z.log_prob(z) - p_z.log_prob(z)).mean()    
        kl_tempered = kl * beta * schedule()
        loss = loss + kl_tempered

        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tKL: {:.3f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item(), kl.item()))


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data in test_loader:
            data = data.to(device)
            output, q_z, z = model(data)
            loss = F.cross_entropy(output.transpose(-1,-2), data)
            p_z = torch.distributions.Normal(0, 1)
            kl = (q_z.log_prob(z) - p_z.log_prob(z)).mean()
            loss = loss + kl * beta
            
            test_loss += loss
            pred = output.argmax(dim=-1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(data.view_as(pred)).sum().item() / data.size(-1)  # per-parameter accuracy
    
    test_loss /= len(test_loader)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

if __name__=="__main__":
    # Training settings
    use_cuda = True
    batch_size = 32
    lr = 1e-4
    gamma = 1.
    epochs = 100
    beta = 0.5
    beta_steps = 37400
    def scheduler():
        n_steps  = 0
        def schedule():
            nonlocal n_steps
            n_steps += 1

            if n_steps < 1000:

                return 0

            step = (n_steps-1000)/beta_steps
            step = min(1, step)

            return 0.5 * (1 + np.sin((step*np.pi)-(np.pi/2)))
        return schedule
    schedule = scheduler()
    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    
    dataset = DX7Dataset()
    train_idxs, test_idxs = train_test_split(range(len(dataset)), random_state=42)
    train_dataset = DataSubset(dataset, train_idxs)
    test_dataset = DataSubset(dataset, test_idxs)

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size, shuffle=True, **kwargs)

    model = Net().to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        # scheduler.step()

    # if args.save_model:
    #     torch.save(model.state_dict(), "mnist_cnn.pt")

    torch.save(model.state_dict(), ARTIFACTS_ROOT.joinpath('fm-param-vae-8.pt'))
# if __name__ == '__main__':
#     main()

# %%


# %%


# %%


# %%


================================================
FILE: scratch/syx_parser.py
================================================
#%%
import mido
from pathlib import Path
import json
from uuid import uuid4
from itertools import chain
from tqdm import tqdm as tqdm
import numpy as np

PARAMETER_ORDER = ['PR1', 'PR2', 'PR3', 'PR4', 'PL1', 'PL2', 'PL3', 'PL4', 'ALG', 'OKS', 'FB', 'LFS', 'LFD', 'LPMD', 'LAMD', 'LPMS', 'LFW', 'LKS', 'TRNSP', 'NAME_0', 'NAME_1', 'NAME_2', 'NAME_3', 'NAME_4', 'NAME_5', 'NAME_6', 'NAME_7', 'NAME_8', 'NAME_9', '0_R1', '0_R2', '0_R3', '0_R4', '0_L1', '0_L2', '0_L3', '0_L4', '0_BP', '0_LD', '0_RD', '0_RC', '0_LC', '0_DET', '0_RS', '0_KVS', '0_AMS', '0_OL', '0_FC', '0_M', '0_FF', '1_R1', '1_R2', '1_R3', '1_R4', '1_L1', '1_L2', '1_L3', '1_L4', '1_BP', '1_LD', '1_RD', '1_RC', '1_LC', '1_DET', '1_RS', '1_KVS', '1_AMS', '1_OL', '1_FC', '1_M', '1_FF', '2_R1', '2_R2', '2_R3', '2_R4', '2_L1', '2_L2', '2_L3', '2_L4', '2_BP', '2_LD', '2_RD', '2_RC', '2_LC', '2_DET', '2_RS', '2_KVS', '2_AMS', '2_OL', '2_FC', '2_M', '2_FF', '3_R1', '3_R2', '3_R3', '3_R4', '3_L1', '3_L2', '3_L3', '3_L4', '3_BP', '3_LD', '3_RD', '3_RC', '3_LC', '3_DET', '3_RS', '3_KVS', '3_AMS', '3_OL', '3_FC', '3_M', '3_FF', '4_R1', '4_R2', '4_R3', '4_R4', '4_L1', '4_L2', '4_L3', '4_L4', '4_BP', '4_LD', '4_RD', '4_RC', '4_LC', '4_DET', '4_RS', '4_KVS', '4_AMS', '4_OL', '4_FC', '4_M', '4_FF', '5_R1', '5_R2', '5_R3', '5_R4', '5_L1', '5_L2', '5_L3', '5_L4', '5_BP', '5_LD', '5_RD', '5_RC', '5_LC', '5_DET', '5_RS', '5_KVS', '5_AMS', '5_OL', '5_FC', '5_M', '5_FF']

def uuid():

    return uuid4().hex

ARTIFACTS_ROOT = Path('~/audio/artifacts').expanduser()



GLOBAL_VALID_RANGES = {
    'PR1':  range(0, 99+1),
    'PR2':  range(0, 99+1),
    'PR3':  range(0, 99+1),
    'PR4':  range(0, 99+1),
    'PL1':  range(0, 99+1),
    'PL2':  range(0, 99+1),
    'PL3':  range(0, 99+1),
    'PL4':  range(0, 99+1),
    'ALG':  range(0, 31+1),
    'OKS':  range(0, 1+1),
    'FB':   range(0, 7+1),
    'LFS':  range(0, 99+1),
    'LFD':  range(0, 99+1),
    'LPMD':  range(0, 99+1),
    'LAMD':  range(0, 99+1),
    'LPMS': range(0, 7+1),
    'LFW':  range(0, 5+1),
    'LKS':  range(0, 1+1),
    'TRNSP':  range(0, 48+1),
    'NAME_0': range(128),
    'NAME_1': range(128),
    'NAME_2': range(128),
    'NAME_3': range(128),
    'NAME_4': range(128),
    'NAME_5': range(128),
    'NAME_6': range(128),
    'NAME_7': range(128),
    'NAME_8': range(128),
    'NAME_9': range(128),
}

OSCILLATOR_VALID_RANGES = {
    'R1':  range(0, 99+1),
    'R2':  range(0, 99+1),
    'R3':  range(0, 99+1),
    'R4':  range(0, 99+1),
    'L1':  range(0, 99+1),
    'L2':  range(0, 99+1),
    'L3':  range(0, 99+1),
    'L4':  range(0, 99+1),
    'BP':  range(0, 99+1),
    'LD':  range(0, 99+1),
    'RD':  range(0, 99+1),
    'RC':  range(0, 3+1),
    'LC':  range(0, 3+1),
    'DET': range(0, 14+1),
    'RS':  range(0, 7+1),
    'KVS': range(0, 7+1),
    'AMS': range(0, 3+1),
    'OL':  range(0, 99+1),
    'FC':  range(0, 31+1),
    'M':   range(0, 1+1),
    'FF':  range(0, 99+1),
}

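def verify(actual, ranges, prefix=None):
    """Return True if every value in `actual` lies in its range in `ranges`.

    Both dicts must have the same keys. `prefix` is accepted but unused.
    """
    assert set(actual.keys()) == set(ranges.keys()), 'Params dont match'

    for key in actual:
        if actual[key] not in ranges[key]:
            # print(f'{key} value {actual[key]} should be in {ranges[key]}')
            return False
    return True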
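

# Hypothetical diagnostic helper (not part of the original module): a sketch of
# a verify() variant that reports *which* parameters are out of range instead of
# returning a bare bool, in the spirit of the commented-out print above.
# Unlike verify(), it does not assert that the two key sets match.
def out_of_range(actual, ranges):
    """Return {key: value} for every value in `actual` not in `ranges[key]`."""
    return {key: value for key, value in actual.items() if value not in ranges[key]}
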

# # %%
# presets = [mido.read_syx_file(patch.as_posix()) for patch in iter(dexed_presets)]

N_OSC = 6
N_VOICE = 32
# # %%
# preset = map(lambda x: x[0], presets[3][0]
# %%
def consume_head(sysex_iter):
    """
    ///////////////////////////////////////////////////////////
    B:
    SYSEX Message: Bulk Data for 32 Voices
    --------------------------------------
        bits    hex  description

        11110000  F0   Status byte - start sysex
        0iiiiiii  43   ID # (i=67; Yamaha)
        0sssnnnn  00   Sub-status (s=0) & channel number (n=0; ch 1)
        0fffffff  09   format number (f=9; 32 voices)
        0bbbbbbb  20   byte count MS byte
        0bbbbbbb  00   byte count LS byte (b=4096; 32 voices)
        0ddddddd  **   data byte 1

            |       |       |

        0ddddddd  **   data byte 4096  (there are 128 bytes / voice)
        0eeeeeee  **   checksum (masked 2's comp. of sum of 4096 bytes)
        11110111  F7   Status - end sysex


    /////////////////////////////////////////////////////////////

    """

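    # mido strips the F0/F7 framing bytes from SysEx data, so the iterator
    # starts at the Yamaha ID byte. Only 32-voice dumps on channel 1 are
    # accepted; the checksum byte after the voice data is never read or checked.
    expected = [0x43,  # Yamaha ID
                0x00,  # sub-status 0, channel 1
                0x09,  # format 9: 32-voice bulk dump
                0x20,  # byte count MS (4096 = 0x20 << 7)
                0x00]  # byte count LS

    for byte in expected:
        assert byte == next(sysex_iter), 'unexpected header'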

# consume_head(sysex_iter)
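

# Hypothetical helpers (not part of the original module): a sketch of the
# checksum described in consume_head's docstring, "masked 2's comp. of sum of
# 4096 bytes". The sum of the voice-data bytes plus the checksum is 0 modulo
# 128. `header_len` and `n_bytes` are illustrative parameters that assume
# mido-style data (no F0/F7) with the 5-byte header checked above.
def dx7_checksum(voice_data):
    """Return the 7-bit two's-complement checksum of `voice_data`."""
    return (-sum(voice_data)) & 0x7F


def has_valid_checksum(sysex_data, header_len=5, n_bytes=4096):
    """Check the checksum byte that follows the voice data in a bulk dump."""
    if len(sysex_data) < header_len + n_bytes + 1:
        return False  # too short to contain header, voice data and checksum
    voice_data = sysex_data[header_len:header_len + n_bytes]
    return dx7_checksum(voice_data) == sysex_data[header_len + n_bytes]
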
#%%
def consume_osc(sysex_iter):

    """
    byte             bit #
    #     6   5   4   3   2   1   0   param A       range  param B       range
    ----  --- --- --- --- --- --- ---  ------------  -----  ------------  -----
    0                R1              OP6 EG R1      0-99
    1                R2              OP6 EG R2      0-99
    2                R3              OP6 EG R3      0-99
    3                R4              OP6 EG R4      0-99
    4                L1              OP6 EG L1      0-99
    5                L2              OP6 EG L2      0-99
    6                L3              OP6 EG L3      0-99
    7                L4              OP6 EG L4      0-99
    8                BP              LEV SCL BRK PT 0-99
    9                LD              SCL LEFT DEPTH 0-99
    10                RD              SCL RGHT DEPTH 0-99
    11    0   0   0 |  RC   |   LC  | SCL LEFT CURVE 0-3   SCL RGHT CURVE 0-3
    12  |      DET      |     RS    | OSC DETUNE     0-14  OSC RATE SCALE 0-7
    13    0   0 |    KVS    |  AMS  | KEY VEL SENS   0-7   AMP MOD SENS   0-3
    14                OL              OP6 OUTPUT LEV 0-99
    15    0 |         FC        | M | FREQ COARSE    0-31  OSC MODE       0-1
    16                FF              FREQ FINE      0-99
    """

    def process_byte(this_byte):
        # sysex data bytes carry 7 bits; drop anything above bit 6
        return int(this_byte) & 0b1111111

    int_sysex_iter = map(process_byte, sysex_iter)

    R1 = next(int_sysex_iter)
    R2 = next(int_sysex_iter)
    R3 = next(int_sysex_iter)
    R4 = next(int_sysex_iter)
    L1 = next(int_sysex_iter)
    L2 = next(int_sysex_iter)
    L3 = next(int_sysex_iter)
    L4 = next(int_sysex_iter)
    BP = next(int_sysex_iter)
    LD = next(int_sysex_iter)
    RD = next(int_sysex_iter)

    # byte 11: 0 0 0 | RC (2 bits) | LC (2 bits)
    _RC_LC = next(int_sysex_iter) & 0b1111
    RC = _RC_LC >> 2
    LC = _RC_LC & 0b11

    # byte 12: DET (4 bits, 0-14) | RS (3 bits); DET occupies bits 6-3, so it
    # needs all 7 bits of the byte and a shift of 3
    _DET_RS = next(int_sysex_iter) & 0b1111111
    DET = _DET_RS >> 3
    RS = _DET_RS & 0b111

    # byte 13: 0 0 | KVS (3 bits) | AMS (2 bits)
    _KVS_AMS = next(int_sysex_iter) & 0b11111
    KVS = _KVS_AMS >> 2
    AMS = _KVS_AMS & 0b11

    # byte 14: operator output level
    OL = next(int_sysex_iter)

    # byte 15: 0 | FC (5 bits) | M (1 bit)
    _FC_M = next(int_sysex_iter) & 0b111111
    FC = _FC_M >> 1
    M = _FC_M & 0b1

    FF = next(int_sysex_iter)


    oscilattor_config = {
        'R1': R1,
        'R2': R2,
        'R3': R3,
        'R4': R4,
        'L1': L1,
        'L2': L2,
        'L3': L3,
        'L4': L4,
        'BP': BP,
        'LD': LD,
        'RD': RD,
        'RC': RC,
        'LC': LC,
        'DET': DET,
        'RS': RS,
        'KVS': KVS,
        'AMS': AMS,
        'OL': OL,
        'FC': FC,
        'M': M,
        'FF': FF,
    }

    return oscilattor_config
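

# Hypothetical helper (not part of the original module): a sketch of the
# inverse of consume_osc, packing one operator's parameter dict back into the
# 17-byte layout documented in consume_osc's docstring. Packing and then
# unpacking a known operator is a quick way to check the shift/mask arithmetic.
def pack_osc(params):
    """Pack one operator's parameters into 17 bulk-dump bytes."""
    plain = ['R1', 'R2', 'R3', 'R4', 'L1', 'L2', 'L3', 'L4', 'BP', 'LD', 'RD']
    return bytes([params[key] for key in plain] + [
        (params['RC'] << 2) | params['LC'],    # byte 11: 0 0 0 | RC | LC
        (params['DET'] << 3) | params['RS'],   # byte 12: DET (4 bits) | RS (3 bits)
        (params['KVS'] << 2) | params['AMS'],  # byte 13: 0 0 | KVS | AMS
        params['OL'],                          # byte 14
        (params['FC'] << 1) | params['M'],     # byte 15: 0 | FC | M
        params['FF'],                          # byte 16
    ])
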
#%%



def consume_global(sysex_iter):
    """
        byte             bit #
        #     6   5   4   3   2   1   0   param A       range  param B       range
        ----  --- --- --- --- --- --- ---  ------------  -----  ------------  -----
        102               PR1              PITCH EG R1   0-99
        103               PR2              PITCH EG R2   0-99
        104               PR3              PITCH EG R3   0-99
        105               PR4              PITCH EG R4   0-99
        106               PL1              PITCH EG L1   0-99
        107               PL2              PITCH EG L2   0-99
        108               PL3              PITCH EG L3   0-99
        109               PL4              PITCH EG L4   0-99
        110    0   0 |        ALG        | ALGORITHM     0-31
        111    0   0   0 |OKS|    FB     | OSC KEY SYNC  0-1    FEEDBACK      0-7
        112               LFS              LFO SPEED     0-99
        113               LFD              LFO DELAY     0-99
        114               LPMD             LF PT MOD DEP 0-99
        115               LAMD             LF AM MOD DEP 0-99
        116  |  LPMS |      LFW      |LKS| LF PT MOD SNS 0-7   WAVE 0-5,  SYNC 0-1
        117              TRNSP             TRANSPOSE     0-48
        118          NAME CHAR 1           VOICE NAME 1  ASCII
        119          NAME CHAR 2           VOICE NAME 2  ASCII
        120          NAME CHAR 3           VOICE NAME 3  ASCII
        121          NAME CHAR 4           VOICE NAME 4  ASCII
        122          NAME CHAR 5           VOICE NAME 5  ASCII
        123          NAME CHAR 6           VOICE NAME 6  ASCII
        124          NAME CHAR 7           VOICE NAME 7  ASCII
        125          NAME CHAR 8           VOICE NAME 8  ASCII
        126          NAME CHAR 9           VOICE NAME 9  ASCII
        127          NAME CHAR 10          VOICE NAME 10 ASCII
    """

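    def process_byte(this_byte):
        # Sysex data bytes are 7 bits wide. Masking to 6 bits would corrupt any
        # value of 64 or more (e.g. 0-99 envelope values, lowercase name chars).
        return int(this_byte) & 0b1111111

    sysex_iter = map(process_byte, sysex_iter)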


    PR1 = int(next(sysex_iter))
    PR2 = int(next(sysex_iter))
    PR3 = int(next(sysex_iter))
    PR4 = int(next(sysex_iter))
    PL1 = int(next(sysex_iter))
    PL2 = int(next(sysex_iter))
    PL3 = int(next(sysex_iter))
    PL4 = int(next(sysex_iter))

    ALG = int(next(sysex_iter)) & int('0b11111', 0)

    OKS_FB = int(next(sysex_iter))
    OKS = (OKS_FB & int('0b1000', 0)) >> 3
    FB = (OKS_FB & int('0b111', 0))

    LFS = int(next(sysex_iter))
    LFD = int(next(sysex_iter))
    LPMD = int(next(sysex_iter))
    LAMD = int(next(sysex_iter))

    LPMS_LFW_LKS = int(next(sysex_iter))
    LPMS = LPMS_LFW_LKS >> 4
    LFW = (LPMS_LFW_LKS >> 1) & int('0b111', 0)
    LKS = LPMS_LFW_LKS & int('0b1', 0)

    TRNSP = int(next(sysex_iter))

    # Voice name: ten 7-bit character codes, kept as integers.
    NAME = [(f'NAME_{i}', int(next(sysex_iter))) for i in range(10)]

    global_config = {
        'PR1': PR1,
        'PR2': PR2,
        'PR3': PR3,
        'PR4': PR4,
        'PL1': PL1,
        'PL2': PL2,
        'PL3': PL3,
        'PL4': PL4,
        'ALG': ALG,
        'OKS': OKS,
        'FB': FB,
        'LFS': LFS,
        'LFD': LFD,
        'LPMD': LPMD,
        'LAMD': LAMD,
        'LPMS': LPMS,
        'LFW': LFW,
        'LKS': LKS,
        'TRNSP': TRNSP,
    }
    global_config.update(NAME)

    return uuid(), global_config
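

# Hypothetical helper (not part of the original module): turn the ten NAME_i
# codes returned by consume_global back into a readable string. This sketch
# treats the codes as plain ASCII; the DX7's own character set may render a few
# codes differently. Trailing space padding is stripped.
def patch_name(global_config):
    """Join NAME_0 .. NAME_9 from a parsed global config into a string."""
    return ''.join(chr(global_config[f'NAME_{i}']) for i in range(10)).rstrip()
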

def consume_syx(path):
    """Parse a 32-voice DX7 bulk-dump .syx file.

    Yields ((path, voice_index), patch_config) for each valid voice and a bare
    None for each voice with out-of-range values. Yields nothing if the file
    cannot be read or its header does not match (see consume_head).
    """
    path = Path(path).expanduser()

    try:
        preset = mido.read_syx_file(path.as_posix())[0]
    except (IndexError, ValueError):
        # no sysex message in the file, or the file could not be parsed
        return

    sysex_iter = iter(preset.data)
    try:
        consume_head(sysex_iter)
    except AssertionError:
        return

    for i in range(N_VOICE):
        # Read all six operators before validating, so that one bad operator
        # cannot leave the remaining operator bytes unread and shift every
        # later field of this and the following voices out of alignment.
        oscillators = [consume_osc(sysex_iter) for _ in range(N_OSC)]
        has_error = not all(verify(osc, OSCILLATOR_VALID_RANGES) for osc in oscillators)

        patch_config = {
            f'{n}_{key}': value
            for n, osc in enumerate(oscillators)
            for key, value in osc.items()
        }

        # consume_global also returns a random uuid, which is not used here
        _, global_config = consume_global(sysex_iter)

        if not verify(global_config, GLOBAL_VALID_RANGES) or has_error:
            yield
            continue

        patch_config.update(global_config)

        yield ((path, i), patch_config)
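

# Hypothetical helper (not part of the original module): a sketch that stacks
# the valid patches yielded by consume_syx into a uint8 array whose columns
# follow a fixed parameter order (e.g. PARAMETER_ORDER). How the original
# __main__ block stores patches is not shown here, so this is an assumption.
import numpy as np  # already imported at the top of this module


def patches_to_array(patches, order):
    """Return an (n_valid_patches, len(order)) uint8 array of parameter values."""
    rows = []
    for item in patches:
        if item is None:  # consume_syx yields None for rejected voices
            continue
        _, config = item
        rows.append([config[key] for key in order])
    # reshape keeps a 2-D shape even when no patch is valid
    return np.array(rows, dtype=np.uint8).reshape(len(rows), len(order))
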
#%%

if __name__=='__
SYMBOL INDEX (202 symbols across 37 files)

FILE: neuralDX7/constants.py
  function take (line 6) | def take(take_from, n):
  function checksum (line 13) | def checksum(data):
  function verify (line 75) | def verify(actual, ranges):
  class DX7Single (line 259) | class DX7Single():
    method keys (line 319) | def keys():
    method struct (line 328) | def struct():
    method to_syx (line 333) | def to_syx(voices):
  function consume_syx (line 351) | def consume_syx(path):

FILE: neuralDX7/datasets/dx7_sysex_dataset.py
  class DX7SysexDataset (line 9) | class DX7SysexDataset():
    method __init__ (line 15) | def __init__(self, data_file='dx7.npy', root=DEFAULTS['ARTIFACTS_ROOT'...
    method __getitem__ (line 32) | def __getitem__(self, index):
    method __len__ (line 39) | def __len__(self):

FILE: neuralDX7/models/attention/attention.py
  class Attention (line 4) | class Attention(nn.Module):
    method __init__ (line 8) | def __init__(self, n_features, n_hidden, n_heads=8, inf=1e9):
    method inf (line 22) | def inf(self):
    method forward (line 27) | def forward(self, X, A):

FILE: neuralDX7/models/attention/attention_encoder.py
  class ResidualAttentionEncoder (line 12) | class ResidualAttentionEncoder(AbstractModel):
    method __init__ (line 17) | def __init__(self, features, attention_layer, max_len=200, n_layers=3):
    method forward (line 39) | def forward(self, X, A):

FILE: neuralDX7/models/attention/attention_layer.py
  class AttentionLayer (line 8) | class AttentionLayer(nn.Module):
    method __init__ (line 14) | def __init__(self, features, hidden_dim, attention):
    method forward (line 35) | def forward(self, X, A):

FILE: neuralDX7/models/attention/conditional_attention_encoder.py
  class CondtionalResidualAttentionEncoder (line 13) | class CondtionalResidualAttentionEncoder(AbstractModel):
    method __init__ (line 18) | def __init__(self, features, c_features, attention_layer, max_len=200,...
    method forward (line 42) | def forward(self, X, A, c):

FILE: neuralDX7/models/dx7_cnp.py
  class DX7PatchProcess (line 10) | class DX7PatchProcess(AbstractModel):
    method __init__ (line 17) | def __init__(self, features, encoder):
    method forward (line 26) | def forward(self, X):
    method features (line 51) | def features(self, X):
    method generate (line 63) | def generate(self, X, X_a):

FILE: neuralDX7/models/dx7_np.py
  class DX7NeuralProcess (line 13) | class DX7NeuralProcess(AbstractModel):
    method __init__ (line 19) | def __init__(self, features, latent_dim, encoder, decoder, determinist...
    method latent_encoder (line 34) | def latent_encoder(self,  X, A, mean=False):
    method forward (line 41) | def forward(self, X):
    method features (line 77) | def features(self, X, X_a):
    method generate_z (line 89) | def generate_z(self, X, X_a, z, t=1.):
    method generate (line 107) | def generate(self, X, X_a, sample=True, t=1.):

FILE: neuralDX7/models/dx7_nsp.py
  class DX7NeuralSylvesterProcess (line 13) | class DX7NeuralSylvesterProcess(AbstractModel):
    method __init__ (line 19) | def __init__(self, features, latent_dim, encoder, decoder, determinist...
    method latent_encoder (line 34) | def latent_encoder(self,  X, A, z=None, flow=True):
    method forward (line 41) | def forward(self, X):
    method features (line 78) | def features(self, X, X_a):
    method generate_z (line 90) | def generate_z(self, X, X_a, z, t=1.):
    method generate (line 108) | def generate(self, X, X_a, sample=True, t=1.):

FILE: neuralDX7/models/dx7_vae.py
  class DX7VAE (line 13) | class DX7VAE(AbstractModel):
    method __init__ (line 20) | def __init__(self, features, latent_dim, encoder, decoder, num_flows=3):
    method latent_encoder (line 43) | def latent_encoder(self,  X, A, z=None, mean=False):
    method forward (line 58) | def forward(self, X):
    method features (line 86) | def features(self, X):
    method generate (line 103) | def generate(self, z, t=1.):

FILE: neuralDX7/models/general/gelu_ff.py
  class FeedForwardGELU (line 4) | class FeedForwardGELU(nn.Module):
    method __init__ (line 10) | def __init__(self, features, out_features=None, exapnsion_factor=3):
    method forward (line 26) | def forward(self, x):

FILE: neuralDX7/models/stochastic_nodes/normal.py
  class NormalNode (line 6) | class NormalNode(nn.Module):
    method __init__ (line 16) | def __init__(self, in_features, latent_dim, hidden_dim=None):
    method forward (line 34) | def forward(self, x, *args, **kwargs):

FILE: neuralDX7/models/stochastic_nodes/triangular_sylvester.py
  class TriangularSylvester (line 15) | class TriangularSylvester(nn.Module):
    method __init__ (line 20) | def __init__(self, z_size):
    method der_h (line 30) | def der_h(self, x):
    method der_tanh (line 33) | def der_tanh(self, x):
    method forward (line 36) | def forward(self, zk, r1, r2, b, permute_z=None, sum_ldj=True):
  class TriangularSylvesterFlow (line 90) | class TriangularSylvesterFlow(nn.Module):
    method __init__ (line 96) | def __init__(self, in_features, latent_dim, num_flows):
    method flow_params (line 122) | def flow_params(self, h):
    method forward (line 145) | def forward(self, h, z=None, flow=True):

FILE: neuralDX7/models/utils.py
  function position_encoding_init (line 5) | def position_encoding_init(n_position, emb_dim):

FILE: neuralDX7/solvers/dx7_np.py
  class DX7NeuralProcess (line 11) | class DX7NeuralProcess(AbstractSolver):
    method __init__ (line 16) | def __init__(self, model,
    method loss (line 33) | def loss(self, x, x_hat, x_a, q_context, q_target, z):
    method solve (line 62) | def solve(self, x, **kwargs):
    method step (line 79) | def step(self):
    method state_dict (line 84) | def state_dict(self):
    method load_state_dict (line 93) | def load_state_dict(self, state_dict):

FILE: neuralDX7/solvers/dx7_nsp.py
  class DX7NeuralSylvesterProcess (line 11) | class DX7NeuralSylvesterProcess(AbstractSolver):
    method __init__ (line 16) | def __init__(self, model,
    method loss (line 32) | def loss(self, X, X_hat, X_a, flow_context, flow_target):
    method solve (line 61) | def solve(self, X, **kwargs):
    method step (line 78) | def step(self):
    method state_dict (line 83) | def state_dict(self):
    method load_state_dict (line 92) | def load_state_dict(self, state_dict):

FILE: neuralDX7/solvers/dx7_patch_process.py
  class DX7PatchProcess (line 10) | class DX7PatchProcess(AbstractSolver):
    method __init__ (line 15) | def __init__(self, model,
    method loss (line 28) | def loss(self, x, x_hat, x_a):
    method solve (line 44) | def solve(self, x, **kwargs):
    method step (line 61) | def step(self):
    method state_dict (line 66) | def state_dict(self):
    method load_state_dict (line 74) | def load_state_dict(self, state_dict):

FILE: neuralDX7/solvers/dx7_vae.py
  class DX7VAE (line 11) | class DX7VAE(AbstractSolver):
    method __init__ (line 16) | def __init__(self, model,
    method loss (line 32) | def loss(self, X, X_hat, flow):
    method solve (line 65) | def solve(self, X, **kwargs):
    method step (line 86) | def step(self):
    method state_dict (line 90) | def state_dict(self):
    method load_state_dict (line 99) | def load_state_dict(self, state_dict):

FILE: neuralDX7/solvers/utils.py
  function sigmoidal_annealing (line 4) | def sigmoidal_annealing(iter_nb, t=1e-4, s=-6):

FILE: neuralDX7/utils.py
  function mask_parameters (line 12) | def mask_parameters(x, voice_keys=VOICE_KEYS, inf=1e9):
  function consume_syx (line 25) | def consume_syx(path):
  function dx7_bulk_pack (line 55) | def dx7_bulk_pack(voices):
  function generate_syx (line 73) | def generate_syx(patch_list):

FILE: projects/dx7_np/experiment.py
  function config (line 14) | def config(experiment_name, trial_name,

FILE: projects/dx7_np/live.py
  function slerp (line 53) | def slerp(val, low, high):
  function process (line 78) | def process(frames):
  function samplerate (line 179) | def samplerate(samplerate):
  function shutdown (line 185) | def shutdown(status, reason):

FILE: projects/dx7_nsp/experiment.py
  function config (line 14) | def config(experiment_name, trial_name,

FILE: projects/dx7_nsp/live.py
  function slerp (line 53) | def slerp(val, low, high):
  function process (line 78) | def process(frames):
  function samplerate (line 180) | def samplerate(samplerate):
  function shutdown (line 186) | def shutdown(status, reason):

FILE: projects/dx7_patch_neural_process/ray_train.py
  function config (line 16) | def config(experiment_name, trial_name,

FILE: projects/dx7_vae/experiment.py
  function config (line 14) | def config(experiment_name, trial_name,

FILE: projects/dx7_vae/live.py
  function slerp (line 41) | def slerp(val, low, high):
  function process (line 66) | def process(frames):
  function samplerate (line 117) | def samplerate(samplerate):
  function shutdown (line 123) | def shutdown(status, reason):

FILE: projects/mnist_neural_process/experiment.py
  class MNISTDataset (line 26) | class MNISTDataset():
    method __init__ (line 28) | def __init__(self, data_path=DEFAULTS['ARTIFACTS_ROOT'], transform=None):
    method __getitem__ (line 43) | def __getitem__(self, i):
    method __len__ (line 51) | def __len__(self):
  function config (line 61) | def config(experiment_name, trial_name,

FILE: scratch/dx7_constants.py
  function take (line 6) | def take(take_from, n):
  function checksum (line 13) | def checksum(data):
  function verify (line 75) | def verify(actual, ranges):

FILE: scratch/dx7_syx.py
  function consume_syx (line 13) | def consume_syx(path):

FILE: scratch/fm_param_ae.py
  class DX7Dataset (line 30) | class DX7Dataset():
    method __init__ (line 33) | def __init__(self, data_file='dx7.npy', root=ARTIFACTS_ROOT):
    method __getitem__ (line 41) | def __getitem__(self, index):
    method __len__ (line 46) | def __len__(self):
  class Net (line 51) | class Net(nn.Module):
    method __init__ (line 52) | def __init__(self, latent_dim=16, n_params=N_PARAMS, max_value=MAX_VAL...
    method generate_mask (line 77) | def generate_mask():
    method forward (line 84) | def forward(self, x):
  function train (line 97) | def train(model, device, train_loader, optimizer, epoch):
  function test (line 112) | def test(model, device, test_loader):

FILE: scratch/fm_param_agoge_vae_rnn.py
  class DX7Dataset (line 40) | class DX7Dataset():
    method __init__ (line 43) | def __init__(self, data_file='dx7.npy', root=ARTIFACTS_ROOT, data_size...
    method __getitem__ (line 54) | def __getitem__(self, index):
    method __len__ (line 60) | def __len__(self):
  class DX7RecurrentVAE (line 67) | class DX7RecurrentVAE(AbstractModel):
    method __init__ (line 68) | def __init__(self, latent_dim=8, n_params=N_PARAMS, max_value=MAX_VALU...
    method network (line 105) | def network(self, x, network):
    method generate_mask (line 127) | def generate_mask(ordering=None):
    method forward (line 141) | def forward(self, x):
    method generate (line 172) | def generate(self, z, t=1.):
  class DX7RecurrentVAESolver (line 183) | class DX7RecurrentVAESolver(AbstractSolver):
    method __init__ (line 185) | def __init__(self, model,
    method beta (line 201) | def beta(self):
    method scheduler (line 206) | def scheduler():
    method loss (line 220) | def loss(self, x, x_hat, q_z, z):
    method solve (line 248) | def solve(self, x, **kwargs):
    method step (line 267) | def step(self):
    method state_dict (line 272) | def state_dict(self):
    method load_state_dict (line 280) | def load_state_dict(self, state_dict):
  function config (line 287) | def config(experiment_name, trial_name, batch_size=16, **kwargs):

FILE: scratch/fm_param_rnn_decoder.py
  class DX7Dataset (line 27) | class DX7Dataset():
    method __init__ (line 30) | def __init__(self, data_file='dx7.npy', root=ARTIFACTS_ROOT):
    method __getitem__ (line 38) | def __getitem__(self, index):
    method __len__ (line 43) | def __len__(self):
  class Net (line 50) | class Net(nn.Module):
    method __init__ (line 51) | def __init__(self, latent_dim=8, n_params=N_PARAMS, max_value=MAX_VALUE):
    method network (line 68) | def network(self, x, network):
    method generate_mask (line 87) | def generate_mask():
    method forward (line 94) | def forward(self, x):
    method generate (line 105) | def generate(self, z, t=1.):
  function train (line 115) | def train(model, device, train_loader, optimizer, epoch):
  function test (line 131) | def test(model, device, test_loader):
  function scheduler (line 160) | def scheduler():

FILE: scratch/fm_param_vae.py
  class DX7Dataset (line 27) | class DX7Dataset():
    method __init__ (line 30) | def __init__(self, data_file='dx7.npy', root=ARTIFACTS_ROOT):
    method __getitem__ (line 38) | def __getitem__(self, index):
    method __len__ (line 43) | def __len__(self):
  class Net (line 50) | class Net(nn.Module):
    method __init__ (line 51) | def __init__(self, latent_dim=8, n_params=N_PARAMS, max_value=MAX_VALUE):
    method generate_mask (line 76) | def generate_mask():
    method forward (line 83) | def forward(self, x):
    method generate (line 99) | def generate(self, z, t=1.):
  function train (line 109) | def train(model, device, train_loader, optimizer, epoch):
  function test (line 129) | def test(model, device, test_loader):

FILE: scratch/fm_param_vae_rnn.py
  class DX7Dataset (line 27) | class DX7Dataset():
    method __init__ (line 30) | def __init__(self, data_file='dx7.npy', root=ARTIFACTS_ROOT):
    method __getitem__ (line 38) | def __getitem__(self, index):
    method __len__ (line 43) | def __len__(self):
  class Net (line 50) | class Net(nn.Module):
    method __init__ (line 51) | def __init__(self, latent_dim=8, n_params=N_PARAMS, max_value=MAX_VALU...
    method network (line 80) | def network(self, x, network):
    method generate_mask (line 102) | def generate_mask():
    method forward (line 109) | def forward(self, x):
    method generate (line 134) | def generate(self, z, t=1.):
  function train (line 144) | def train(model, device, train_loader, optimizer, epoch):
  function test (line 166) | def test(model, device, test_loader):
  function scheduler (line 198) | def scheduler():

FILE: scratch/syx_parser.py
  function uuid (line 12) | def uuid():
  function verify (line 76) | def verify(actual, ranges, prefix=None):
  function consume_head (line 94) | def consume_head(sysex_iter):
  function consume_osc (line 132) | def consume_osc(sysex_iter):
  function consume_global (line 229) | def consume_global(sysex_iter):
  function consume_syx (line 327) | def consume_syx(path):

FILE: scratch/syx_write.py
  function checksum (line 19) | def checksum(data):
  function encode_head (line 24) | def encode_head():
  function encode_osc (line 61) | def encode_osc(params, n):
  function encode_global (line 127) | def encode_global(params):
  function encode_syx (line 201) | def encode_syx(params_list):
Condensed preview: 66 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".gitignore",
    "chars": 51,
    "preview": ".vscode/\n__pycache__/\n.empty\n*.pyc\n*.egg-info\ndist/"
  },
  {
    "path": "LICENSE",
    "chars": 1069,
    "preview": "MIT License\n\nCopyright (c) [year] [fullname]\n\nPermission is hereby granted, free of charge, to any person obtaining a co"
  },
  {
    "path": "MANIFEST.in",
    "chars": 32,
    "preview": "include version requirements.txt"
  },
  {
    "path": "README.md",
    "chars": 4387,
    "preview": "# FM Synth Parameter Generator\n\nRandom machine learning experiments related to the classic Yamaha DX7\n\n## Dexed\n\nDexed i"
  },
  {
    "path": "neuralDX7/__init__.py",
    "chars": 121,
    "preview": "\n\n\nfrom agoge import DEFAULTS, defaults_f\n\n\nDEFAULTS = defaults_f({\n    'ARTIFACT_ROOT': '~/agoge/artifacts'\n}, DEFAULTS"
  },
  {
    "path": "neuralDX7/constants.py",
    "chars": 8538,
    "preview": "from pathlib import Path\nimport bitstruct\nimport mido\n\n\ndef take(take_from, n):\n    for _ in range(n):\n        yield nex"
  },
  {
    "path": "neuralDX7/datasets/__init__.py",
    "chars": 46,
    "preview": "from .dx7_sysex_dataset import DX7SysexDataset"
  },
  {
    "path": "neuralDX7/datasets/dx7_sysex_dataset.py",
    "chars": 1177,
    "preview": "from pathlib import Path\nimport numpy as np\nimport torch\nfrom neuralDX7 import DEFAULTS\n\n\n\n\nclass DX7SysexDataset():\n   "
  },
  {
    "path": "neuralDX7/models/__init__.py",
    "chars": 148,
    "preview": "from .dx7_cnp import DX7PatchProcess\nfrom .dx7_np import DX7NeuralProcess\nfrom .dx7_nsp import DX7NeuralSylvesterProcess"
  },
  {
    "path": "neuralDX7/models/attention/__init__.py",
    "chars": 210,
    "preview": "from .attention import Attention\nfrom .attention_layer import AttentionLayer\nfrom .attention_encoder import ResidualAtte"
  },
  {
    "path": "neuralDX7/models/attention/attention.py",
    "chars": 1609,
    "preview": "import torch\nfrom torch import nn\n\nclass Attention(nn.Module):\n\n\n\n    def __init__(self, n_features, n_hidden, n_heads=8"
  },
  {
    "path": "neuralDX7/models/attention/attention_encoder.py",
    "chars": 2486,
    "preview": "import torch\nfrom os import environ\n\nfrom torch import nn\n\nfrom agoge import AbstractModel\nfrom neuralDX7.models.attenti"
  },
  {
    "path": "neuralDX7/models/attention/attention_layer.py",
    "chars": 1534,
    "preview": "import torch\nfrom torch import nn\n\nfrom neuralDX7.models.attention import Attention\n\n\n\nclass AttentionLayer(nn.Module):\n"
  },
  {
    "path": "neuralDX7/models/attention/conditional_attention_encoder.py",
    "chars": 2747,
    "preview": "import torch\nfrom os import environ\n\nfrom torch import nn\n\nfrom agoge import AbstractModel\nfrom neuralDX7.models.attenti"
  },
  {
    "path": "neuralDX7/models/dx7_cnp.py",
    "chars": 2805,
    "preview": "import torch\nfrom torch import nn\n\nfrom agoge import AbstractModel\nfrom neuralDX7.models.attention import ResidualAttent"
  },
  {
    "path": "neuralDX7/models/dx7_np.py",
    "chars": 4452,
    "preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom agoge import AbstractModel\nfrom neuralDX7.m"
  },
  {
    "path": "neuralDX7/models/dx7_nsp.py",
    "chars": 4419,
    "preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom agoge import AbstractModel\nfrom neuralDX7.m"
  },
  {
    "path": "neuralDX7/models/dx7_vae.py",
    "chars": 4665,
    "preview": "import torch\nfrom torch import nn\nfrom torch.nn import functional as F\n\nfrom agoge import AbstractModel\nfrom neuralDX7.m"
  },
  {
    "path": "neuralDX7/models/general/__init__.py",
    "chars": 36,
    "preview": "from .gelu_ff import FeedForwardGELU"
  },
  {
    "path": "neuralDX7/models/general/gelu_ff.py",
    "chars": 830,
    "preview": "import torch\nfrom torch import nn\n\nclass FeedForwardGELU(nn.Module):\n    \"\"\"\n    Simple wrapper for two layer projection"
  },
  {
    "path": "neuralDX7/models/stochastic_nodes/__init__.py",
    "chars": 88,
    "preview": "from .normal import NormalNode\nfrom .triangular_sylvester import TriangularSylvesterFlow"
  },
  {
    "path": "neuralDX7/models/stochastic_nodes/normal.py",
    "chars": 1211,
    "preview": "from torch import nn\nfrom torch.distributions import Normal\n\n\n\nclass NormalNode(nn.Module):\n    \"\"\"\n    Simple module to"
  },
  {
    "path": "neuralDX7/models/stochastic_nodes/triangular_sylvester.py",
    "chars": 6246,
    "preview": "#%%\nfrom collections import namedtuple\nfrom itertools import count\nimport torch\nfrom torch import nn\nfrom neuralDX7.mode"
  },
  {
    "path": "neuralDX7/models/utils.py",
    "chars": 653,
    "preview": "\nimport torch\nimport numpy as np\n\ndef position_encoding_init(n_position, emb_dim):\n    ''' Init the sinusoid position en"
  },
  {
    "path": "neuralDX7/solvers/__init__.py",
    "chars": 158,
    "preview": "from .dx7_patch_process import DX7PatchProcess\nfrom .dx7_np import DX7NeuralProcess\nfrom .dx7_nsp import DX7NeuralSylves"
  },
  {
    "path": "neuralDX7/solvers/dx7_np.py",
    "chars": 2495,
    "preview": "import torch\nfrom torch.nn import functional as F\nfrom importlib import import_module\nfrom torch.optim import AdamW\nfrom"
  },
  {
    "path": "neuralDX7/solvers/dx7_nsp.py",
    "chars": 2504,
    "preview": "import torch\nfrom torch.nn import functional as F\nfrom importlib import import_module\nfrom torch.optim import AdamW\nfrom"
  },
  {
    "path": "neuralDX7/solvers/dx7_patch_process.py",
    "chars": 1675,
    "preview": "import torch\nfrom torch.nn import functional as F\nfrom importlib import import_module\nfrom torch.optim import AdamW\n\nfro"
  },
  {
    "path": "neuralDX7/solvers/dx7_vae.py",
    "chars": 2875,
    "preview": "import torch\nfrom torch.nn import functional as F\nfrom importlib import import_module\nfrom torch.optim import AdamW\nfrom"
  },
  {
    "path": "neuralDX7/solvers/utils.py",
    "chars": 335,
    "preview": "import torch\n\n\ndef sigmoidal_annealing(iter_nb, t=1e-4, s=-6):\n    \"\"\"\n\n    iter_nb - number of parameter updates comple"
  },
  {
    "path": "neuralDX7/utils.py",
    "chars": 1883,
    "preview": "\nimport mido\nimport torch\nimport numpy as np\nfrom pathlib import Path\nfrom itertools import chain\nfrom neuralDX7.constan"
  },
  {
    "path": "projects/dx7_np/evaluate.py",
    "chars": 2904,
    "preview": "# %%\nfrom agoge import InferenceWorker\nimport torch\nfrom tqdm import tqdm\nfrom matplotlib import pyplot as plt\nworker = "
  },
  {
    "path": "projects/dx7_np/experiment.py",
    "chars": 3222,
    "preview": "#%%\nfrom os import environ\nenviron['MLFLOW_TRACKING_URI'] = 'http://tracking.olympus.nintorac.dev:9001/'\n\nfrom neuralDX7"
  },
  {
    "path": "projects/dx7_np/features.py",
    "chars": 849,
    "preview": "# %%\nfrom agoge import InferenceWorker\nimport threading\nimport torch\nimport time\nimport numpy as np\nfrom tqdm import tqd"
  },
  {
    "path": "projects/dx7_np/interpoalte.py",
    "chars": 2626,
    "preview": "# %%\nfrom agoge import InferenceWorker\nimport torch\nfrom tqdm import tqdm\nfrom matplotlib import pyplot as plt\nworker = "
  },
  {
    "path": "projects/dx7_np/live.py",
    "chars": 6293,
    "preview": "# %%\nfrom agoge import InferenceWorker\nimport threading\nimport torch\nimport mido\nimport time\nimport numpy as np\nfrom tqd"
  },
  {
    "path": "projects/dx7_nsp/evaluate.py",
    "chars": 2904,
    "preview": "# %%\nfrom agoge import InferenceWorker\nimport torch\nfrom tqdm import tqdm\nfrom matplotlib import pyplot as plt\nworker = "
  },
  {
    "path": "projects/dx7_nsp/experiment.py",
    "chars": 3418,
    "preview": "#%%\nfrom os import environ\nenviron['MLFLOW_TRACKING_URI'] = 'http://tracking.olympus.nintorac.dev:9001/'\n\nfrom neuralDX7"
  },
  {
    "path": "projects/dx7_nsp/features.py",
    "chars": 848,
    "preview": "# %%\nfrom agoge import InferenceWorker\nimport threading\nimport torch\nimport time\nimport numpy as np\nfrom tqdm import tqd"
  },
  {
    "path": "projects/dx7_nsp/interpoalte.py",
    "chars": 2626,
    "preview": "# %%\nfrom agoge import InferenceWorker\nimport torch\nfrom tqdm import tqdm\nfrom matplotlib import pyplot as plt\nworker = "
  },
  {
    "path": "projects/dx7_nsp/live.py",
    "chars": 6296,
    "preview": "# %%\nfrom agoge import InferenceWorker\nimport threading\nimport torch\nimport mido\nimport time\nimport numpy as np\nfrom tqd"
  },
  {
    "path": "projects/dx7_patch_neural_process/evaluate.py",
    "chars": 3103,
    "preview": "# %%\nfrom agoge import InferenceWorker\nimport torch\nfrom tqdm import tqdm\nfrom matplotlib import pyplot as plt\nworker = "
  },
  {
    "path": "projects/dx7_patch_neural_process/features_analysis.py",
    "chars": 1421,
    "preview": "# %%\nfrom agoge import InferenceWorker\nimport torch\nfrom tqdm import tqdm\nfrom matplotlib import pyplot as plt\nworker = "
  },
  {
    "path": "projects/dx7_patch_neural_process/ray_train.py",
    "chars": 3407,
    "preview": "#%%\nfrom os import environ\nenviron['MLFLOW_TRACKING_URI'] = 'http://tracking.olympus.nintorac.dev:9001/'\n# environ['MLFL"
  },
  {
    "path": "projects/dx7_vae/duplicate_test.py",
    "chars": 788,
    "preview": "# %%\nfrom agoge import InferenceWorker\nimport threading\nimport torch\nimport mido\nimport time\nimport numpy as np\nfrom tqd"
  },
  {
    "path": "projects/dx7_vae/evaluate.py",
    "chars": 2904,
    "preview": "# %%\nfrom agoge import InferenceWorker\nimport torch\nfrom tqdm import tqdm\nfrom matplotlib import pyplot as plt\nworker = "
  },
  {
    "path": "projects/dx7_vae/experiment.py",
    "chars": 3416,
    "preview": "#%%\nfrom os import environ\nenviron['MLFLOW_TRACKING_URI'] = 'http://tracking.olympus.nintorac.dev:9001/'\n\nfrom neuralDX7"
  },
  {
    "path": "projects/dx7_vae/features.py",
    "chars": 1105,
    "preview": "# %%\nfrom agoge import InferenceWorker\nimport threading\nimport torch\nimport time\nimport numpy as np\nfrom tqdm import tqd"
  },
  {
    "path": "projects/dx7_vae/interpoalte.py",
    "chars": 2626,
    "preview": "# %%\nfrom agoge import InferenceWorker\nimport torch\nfrom tqdm import tqdm\nfrom matplotlib import pyplot as plt\nworker = "
  },
  {
    "path": "projects/dx7_vae/live.py",
    "chars": 4323,
    "preview": "# %%\nfrom agoge import InferenceWorker\nimport threading\nimport torch\nimport mido\nimport time\nimport numpy as np\nfrom tqd"
  },
  {
    "path": "projects/mnist_neural_process/experiment.py",
    "chars": 4531,
    "preview": "#%%\nfrom os import environ\nenviron['MLFLOW_TRACKING_URI'] = 'http://tracking.olympus.nintorac.dev:9001/'\n# environ['MLFL"
  },
  {
    "path": "requirements.txt",
    "chars": 41,
    "preview": "bitstruct==8.9.0\nagoge==0.0.6\nmido==1.2.9"
  },
  {
    "path": "scratch/dx7-sysexformat.md",
    "chars": 14138,
    "preview": "Sysex Documentation \n===================\n\n(Message GUS:472)\nReceived: from mailhub.iastate.edu by po-3.iastate.edu \n\tid "
  },
  {
    "path": "scratch/dx7_constants.py",
    "chars": 5228,
    "preview": "from pathlib import Path\nimport bitstruct\n\nARTIFACTS_ROOT = Path('/content/gdrive/My Drive/audio/artifacts').expanduser("
  },
  {
    "path": "scratch/dx7_syx.py",
    "chars": 2750,
    "preview": "#%%\nimport bitstruct\nimport mido\nfrom pathlib import Path\nfrom itertools import chain\n\nfrom dx7_constants import voice_s"
  },
  {
    "path": "scratch/fm-param-analysis.py",
    "chars": 838,
    "preview": "#%%\nimport torch\nfrom torch.utils.data import Subset as DataSubset\nfrom sklearn.model_selection import train_test_split\n"
  },
  {
    "path": "scratch/fm_param_ae.py",
    "chars": 4877,
    "preview": "#%%\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom torch.optim.lr_s"
  },
  {
    "path": "scratch/fm_param_agoge_vae_rnn.py",
    "chars": 11402,
    "preview": "#%%\nimport os\nos.environ['MLFLOW_TRACKING_URI'] = 'http://localhost:9001'\nimport torch\ntorch.randn(10,10).cuda()\nimport "
  },
  {
    "path": "scratch/fm_param_rnn_decoder.py",
    "chars": 5887,
    "preview": "#%%\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom torch.optim.lr_s"
  },
  {
    "path": "scratch/fm_param_vae.py",
    "chars": 5673,
    "preview": "#%%\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom torch.optim.lr_s"
  },
  {
    "path": "scratch/fm_param_vae_rnn.py",
    "chars": 7256,
    "preview": "#%%\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom torch.optim.lr_s"
  },
  {
    "path": "scratch/syx_parser.py",
    "chars": 13708,
    "preview": "#%%\nimport mido\nfrom pathlib import Path\nimport json\nfrom uuid import uuid4\nfrom itertools import chain\nfrom tqdm import"
  },
  {
    "path": "scratch/syx_write.py",
    "chars": 7447,
    "preview": "#%%\nimport mido\nfrom pathlib import Path\nimport json\nfrom uuid import uuid4\nfrom itertools import chain\nfrom tqdm import"
  },
  {
    "path": "setup.cfg",
    "chars": 39,
    "preview": "[metadata]\ndescription-file = README.md"
  },
  {
    "path": "setup.py",
    "chars": 1314,
    "preview": "#!/usr/bin/env python\nfrom setuptools import setup, find_packages\nfrom pathlib import Path\n\nroot_path = Path(__file__).p"
  },
  {
    "path": "version",
    "chars": 5,
    "preview": "0.0.8"
  }
]

About this extraction

This page contains the full source code of the Nintorac/NeuralDX7 GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 66 files (197.0 KB), approximately 58.8k tokens, and a symbol index with 202 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract.
