Full Code of yoavz/music_rnn for AI

Repository: yoavz/music_rnn
Branch: master
Commit: 54f0386664ff
Files: 20
Total size: 82.0 KB

Directory structure:
music_rnn/

├── .gitignore
├── README.md
├── css/
│   └── style.css
├── data_samples/
│   ├── alb_esp1.mid
│   ├── ashover_simple_chords_1.mid
│   ├── bach_chorale.mid
│   ├── koopa_troopa_beach.mid
│   └── reels_simple_chords_157.mid
├── index.html
├── install.sh
├── midi_util.py
├── model.py
├── nottingham_util.py
├── requirements.txt
├── rnn.py
├── rnn_sample.py
├── rnn_separate.py
├── rnn_test.py
├── sampling.py
└── util.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
data
2012code
models
python-midi
research
tensorflow

*.midi
*.pyc
img/.DS_Store


================================================
FILE: README.md
================================================
Overview
============
A project that trains an LSTM recurrent neural network over a dataset of MIDI files. More information can be found in the [writeup about this project](http://yoavz.com/music_rnn/) or in the [final report](http://yoavz.com/music_rnn_paper.pdf). *Warning: Some parts of this codebase are unfinished.*

Dependencies
============

* Python 2.7
* Anaconda
* Numpy (http://www.numpy.org/)
* Tensorflow (https://github.com/tensorflow/tensorflow) - 0.8
* Python Midi (https://github.com/vishnubob/python-midi.git)
* Mingus (https://github.com/bspaans/python-mingus)
* Matplotlib (http://matplotlib.org/)

Basic Usage
===========

1. Run `./install.sh` to create conda env, install dependencies and download data
2. `source activate music_rnn` to activate the conda environment
3. Run `python nottingham_util.py` to generate the sequences and chord mapping file at `data/nottingham.pickle` (a loading sketch follows these steps)
4. Run `python rnn.py --run_name YOUR_RUN_NAME_HERE` to start training the model. Use the grid object in `rnn.py` to edit hyperparameter
   configurations.
5. `source deactivate` to deactivate the conda environment
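
After step 3, the generated pickle can be sanity-checked with a short script like the one below (a sketch, not part of the repository; it assumes the default `data/nottingham.pickle` location):

```python
import cPickle

with open('data/nottingham.pickle', 'r') as f:
    store = cPickle.load(f)

# chord_to_idx maps chord shorthand (e.g. 'CM') to a harmony class index
print "Number of chord classes: {}".format(len(store['chord_to_idx']))
# each split is a list of [T x D] binary matrices (melody one-hot + harmony one-hot)
print "Train / valid / test sequences: {} / {} / {}".format(
    len(store['train']), len(store['valid']), len(store['test']))
print "First training sequence shape: {}".format(store['train'][0].shape)
```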


================================================
FILE: css/style.css
================================================
body {
  font-family: 'Maven Pro', sans-serif;
  margin: 20px;
}

.audio-wrapper {
  padding: 10px 0px;
}

.audio-wrapper p {
  font-size: 14px;
  text-align: center;
  color: gray;
  margin-top: 5px;
  margin-bottom: 0;
}

.img-wrapper {
  text-align: center;
}

.img-wrapper > img {
  max-width: 560px;
  width: 100%;
}

.img-wrapper > p {
  font-size: 14px;
  text-align: center;
  color: gray;
  margin-top: 0;
}

a {
  color:gray;
  text-decoration:none;
}

a:hover {
  color:lightgray;
}

.container {
  max-width: 700px;
  margin: 0 auto;
}

h2:after {
  margin-top: 5px;
  content: ' ';
  display: block;
  border: 1px solid black;
}

/* graph titles */
h3 {
  text-align: center;
}

/* LEGEND styling */

.legend {
  overflow: auto;
}

.legend .legend-title {
  text-align: left;
  margin-bottom: 8px;
  font-weight: bold;
  font-size: 90%;
}

.legend .legend-scale ul {
  margin-top: 10px;
  margin-bottom: 0px;
  padding: 0;
  float: left;
  list-style: none;
}

.legend .legend-scale ul li {
  display: block;
  float: left;
  width: 100px;
  text-align: center;
  font-size: 80%;
  list-style: none;
}

.legend ul.pie-legend li span {
  display: inline-block;
  height: 15px;
  width: 75px;
}

.legend ul.bar-legend li span {
  display: inline-block;
  height: 15px;
  width: 75px;
}


================================================
FILE: index.html
================================================
<!DOCTYPE html>
<html lang="en">
  <head>
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <meta name="description" content="Music Language Modeling with Recurrent Neural Networks">
    <meta name="author" content="Yoav Zimmerman">

    <title>Music Language Modeling with RNN's</title>

    <link href='http://fonts.googleapis.com/css?family=Maven+Pro:400,700' rel='stylesheet' type='text/css'>
    <link rel="stylesheet" type="text/css" href="css/mediaelementplayer.min.css" />
    <link rel="stylesheet" type="text/css" href="css/style.css">

    <!-- configure and load MathJax -->
    <script type="text/x-mathjax-config">
    MathJax.Hub.Config({
     extensions: ["tex2jax.js","TeX/AMSmath.js","TeX/AMSsymbols.js"],
     jax: ["input/TeX", "output/HTML-CSS"],
     tex2jax: {
         inlineMath: [ ['$','$'], ["\\(","\\)"] ],
         displayMath: [ ['$$','$$'], ["\\[","\\]"] ],
     },
     "HTML-CSS": { availableFonts: ["TeX"] }
    });
    </script>
    <script type="text/javascript" src="http://cdn.mathjax.org/mathjax/latest/MathJax.js"></script>

  </head>

  <body>

    <div class="container">

      <h1>Music Language Modeling with Recurrent Neural Networks</h1>

      <div class="section">
        <h2>TL;DR</h2>
        <p>I trained a <a href="https://en.wikipedia.org/wiki/Long_short-term_memory">Long Short-Term Memory (LSTM) Recurrent Neural Network</a> on a dataset of around 650 jigs and folk tunes. I sampled from this model to generate the following musical pieces: </p>
        <div class="audio-wrapper">
          <audio src="generated_music/mp3/5.mp3" width="100%" preload="none"></audio>
        </div>
        <div class="audio-wrapper">
          <audio src="generated_music/mp3/7.mp3" width="100%" preload="none"></audio>
        </div>
        <p>You can find 8 more pieces <a href="#track-marker">here</a>.</p>
      </div>

      <div class="section">
        <h2>Introduction</h2>
        <p>Neural Networks are all the rage these days, and with good reason. Microsoft Research's <a href="http://image-net.org/challenges/LSVRC/2015/results">winning model</a> in the 2015 ImageNet competition classifies images with a 3.57% error rate (human performance is 5.1%). Google used a variant to <a href="https://deepmind.com/alpha-go.html">crush</a> one of the world's best Go players 4-1. Crazy things are happening in the field, with no sign of slowing down. In this project, I've applied recurrent neural nets to learn a predictive model over symbolic sequences of music.</p>
        <i>Disclaimer: This post assumes familiarity with machine learning and neural networks. For a good overview of RNNs, I highly recommend <a href="http://karpathy.github.io/2015/05/21/rnn-effectiveness/">Andrej Karpathy's excellent blog post</a> for an in-depth explanation</i>.
      </div>
      
      <div class="section">
        <h2>Music Language Modeling</h2>
        <p>Music Language Modeling is the problem of modeling symbolic sequences of polyphonic music in a completely general piano roll representation. <i>Piano roll representation</i> is a key distinction here, meaning we're going to use the symbolic note sequences as represented by sheet music, as opposed to more complex, acoustically rich audio signals. MIDI files are perfect for this, as they encode all the note information exactly as it would be displayed on a piano roll.</p> 
        <div class="img-wrapper">
          <img src="img/encoding.png"></img>
        </div>
        <p>The most straightforward approach is to discretize a piece of music into uniform time steps. There are 88 possible pitches from A0 to C8 in a MIDI file, so every time step is encoded into an 88-dimensional binary vector as shown above. A value of 1 at index <i>i</i> indicates that pitch <i>i</i> is playing at a given time step. Then we plug this sequence of input vectors into an RNN architecture, where at each step the target is to predict the next time step of the sequence. A trained model outputs the conditional distribution of notes at a time step, given all the time steps that have occurred before it.</p>
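        <p><i>As a concrete sketch of this encoding (illustrative only, not taken from the codebase, using one possible A0-based indexing): a time step with a C major triad sounding would be encoded as follows.</i></p>
        <pre><code>import numpy as np

NUM_PITCHES = 88      # A0 (MIDI 21) through C8 (MIDI 108)
MIDI_OFFSET = 21      # index 0 corresponds to A0

def encode_time_step(active_midi_pitches):
    """Return an 88-dimensional binary vector for one time step."""
    vec = np.zeros(NUM_PITCHES)
    for pitch in active_midi_pitches:
        vec[pitch - MIDI_OFFSET] = 1
    return vec

print encode_time_step([60, 64, 67])   # C4, E4, G4 sounding at this step</code></pre>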

        <p>One problem with this naive formulation is that the number of potential note configurations is too large ($2^{N}$ for $N$ possible notes) to take the softmax classification approach normally used in image classification and language modeling. Instead, we need to use a sigmoid cross-entropy loss function to predict <i>separately</i> whether each note class is active or not. However, this approach does not make much sense for the complex joint distribution of notes usually found in a time step. For example, <i>C</i> is much more likely than <i>C#</i> to be playing when <i>E</i> and <i>G</i> are also active, but separate classification targets implicitly assume independence between note probabilities at the same time step. <a href="http://www-etud.iro.umontreal.ca/~boulanni/ICML2012.pdf">Modeling Temporal Dependencies in High-Dimensional Sequences (Boulanger-Lewandowski, 2012)</a>, perhaps the most successful research paper on MLM so far, attempts to solve this problem using energy-based generative models such as the Restricted Boltzmann Machine (RBM). They propose the combined RNN-RBM architecture, which achieves state-of-the-art performance on several music datasets.</p>
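        <p><i>For reference, here is a NumPy sketch (not from the codebase) of the per-note sigmoid cross-entropy used by this naive formulation; the TensorFlow version lives in <code>Model.init_loss</code> in <code>model.py</code>. Note that each note's probability is scored independently of the others.</i></p>
        <pre><code>import numpy as np

def sigmoid_cross_entropy(logits, targets):
    """Independent binary cross-entropy summed over all 88 note classes."""
    probs = 1.0 / (1.0 + np.exp(-logits))
    return -np.sum(targets * np.log(probs) + (1.0 - targets) * np.log(1.0 - probs))

logits = np.random.randn(88)          # raw network outputs for one time step
targets = np.zeros(88)
targets[[39, 43, 46]] = 1             # C4, E4, G4 under the A0-based indexing
print sigmoid_cross_entropy(logits, targets)</code></pre>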
      </div>

      <div class="section">
        <h2>Model</h2>
        <p> For my model, I decided to take the approach of introducing more musical structure into learning. Many musical pieces can be separated into two parts: a melody and a harmony. I make the following two assumptions about a piece of music for my model: First, the melody is <i>monophonic</i> (at most one note playing at every time step). Second, the harmony at each time step can be classified into a chord class. For example, a <i>C, E, G</i> active during a time step is classified as <i>C Major</i>. These are strong assumptions, but they lead to the nice property of exactly one active melody class and one active harmony class at every time step. This allows us to take the sum of two softmax functions as the loss function for our model.</p>
        <div class="img-wrapper">
          <img src="img/dual_softmax_fig.png"></img>
        </div>
        <p>My model works in the following way: for every time step I encode the melody note into a one-hot-encoding binary vector. I then use the notes playing in the harmony to infer the chord class, and turn that into a one-hot-encoding binary vector as well. The full input vector is a concatenation of the melody and harmony vectors. This input vector then passes through hidden layer(s) of LSTM cells. The loss function is the sum of two separate softmax loss functions over the respective melody and harmony parts of the output layer.</p>
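        <p><i>A minimal sketch (illustrative, not from the codebase) of building one input vector from a melody class and a harmony class; the real construction over whole sequences is the <code>combine</code> helper in <code>nottingham_util.py</code>.</i></p>
        <pre><code>import numpy as np

def encode_input(melody_class, harmony_class, num_melody, num_harmony):
    """Concatenation of a melody one-hot vector and a harmony one-hot vector."""
    vec = np.zeros(num_melody + num_harmony)
    vec[melody_class] = 1
    vec[num_melody + harmony_class] = 1
    return vec

print encode_input(melody_class=12, harmony_class=3, num_melody=34, num_harmony=32)</code></pre>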

        <p style="text-align: center">
          $L(z, m, h) =  \alpha \, \log \bigg( \frac{ e^{z_m} }{ \sum_{n=0}^{M-1}{ e^{z_n}}} \bigg) + (1 - \alpha) \, \log \bigg( \frac{ e^{z_{M+h}} }{ \sum_{n=M}^{M+H-1}{ e^{z_n}}} \bigg)$
        </p>

        <p>If we have $M$ melody classes and $H$ harmony classes, the function above describes the weighted log-likelihood at a time step given the output layer $z \in \mathbb{R}^{M+H}$, a target melody class $m$, and a target harmony class $h$; the training loss is its negative. $\alpha$ is what I call the <i>melody coefficient</i>, which controls how much the loss function is affected by its respective melody and harmony loss terms.</p>
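        <p><i>The following NumPy sketch (illustrative, not from the codebase) evaluates the formula above for one time step; as implemented in <code>NottinghamModel.init_loss</code> in <code>model.py</code>, training minimizes the negative of this quantity.</i></p>
        <pre><code>import numpy as np

def dual_softmax_log_likelihood(z, m, h, M, alpha=0.5):
    """Weighted log-likelihood of melody class m and harmony class h
    given the output layer z of length M + H."""
    melody_logp = z[m] - np.log(np.sum(np.exp(z[:M])))
    harmony_logp = z[M + h] - np.log(np.sum(np.exp(z[M:])))
    return alpha * melody_logp + (1.0 - alpha) * harmony_logp

z = np.random.randn(34 + 32)   # 34 melody classes + 32 harmony classes
print dual_softmax_log_likelihood(z, m=5, h=2, M=34)</code></pre>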
      </div>

      <div class="section">
        <h2>Experiments</h2>
        <p>The <a href="http://ifdo.ca/~seymour/nottingham/nottingham.html">Nottingham dataset</a> is a collection of 1200 jigs and folk tunes, most of which fit the assumptions specified above: they have a simple monophonic melody on top of recognizable chords. You can download all the Nottingham tunes as MIDI files <a href="http://www-etud.iro.umontreal.ca/~boulanni/icml2012">here</a>. I discretized each of these sequences into time steps of sixteenth notes (1/4 of a quarter note), and used the <a href="https://github.com/bspaans/python-mingus">mingus</a> python package to detect the chord classes in the harmonies. After filtering out some sequences that didn't fit the assumptions, I ended up with 32 chord classes and 34 possible melody notes (one class in each of these represented a rest) for a total input dimension of 66 over 997 sequences. The average length of a sequence was 516 time steps (roughly 32 measures in 4/4). Finally, all the sequences were split up into 65% training, 15% validation, and 15% testing.</p>
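        <p><i>The chord-labelling step works roughly as in the sketch below, which mirrors the logic in <code>nottingham_util.py</code>: the harmony pitches at a time step are reduced modulo 12 to pitch classes and handed to mingus for chord detection.</i></p>
        <pre><code>import mingus.core.notes
import mingus.core.chords

def label_chord(active_midi_pitches):
    """Collapse the harmony notes at one time step to a shorthand chord name."""
    names = [mingus.core.notes.int_to_note(p % 12) for p in active_midi_pitches]
    candidates = mingus.core.chords.determine(names, shorthand=True)
    return candidates[0] if candidates else 'NONE'

print label_chord([48, 52, 55])   # C, E, G -> 'CM'</code></pre>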

        <div class="audio-wrapper">
          <audio src="generated_music/mp3/nottingham_sample.mp3" width="100%" preload="none"></audio>
          <p>An example musical sequence from the Nottingham dataset</p>
        </div>

        <p>I used Google's <a href="https://www.tensorflow.org/">TensorFlow</a> library to implement my model. The architecture that I found worked best was 2 stacked hidden layers of 200 LSTM units each. I batched sequences by length, and used an unrolling length of 128 (8 measures in 4/4 time signature) for <a href="https://en.wikipedia.org/wiki/Backpropagation_through_time">Backpropagation through time (BPTT)</a>. I used RMSProp with a learning rate of 0.005 and decay rate of 0.9 for minibatch gradient descent. When searching over the hyperparameter space, I trained each model for 250 epochs, and saved the model with the lowest validation loss.</p>
        
        <!-- TODO: overfitting image -->
        <div class="img-wrapper">
          <img src="img/overfitting.png"></img>
          <p>Training and validation loss plotted over num epochs for a model with 2 stacked layers of 200 LSTM units, with 50% dropout on hidden layers and 80% dropout on input layers. Overfitting issues start showing up after about 20 epochs.</p>
        </div>

        <p>One big issue I ran into when training was extreme overfitting. Adding dropout on the non-recurrent connections helped some, but did not completely eliminate the issue. The best dropout configuration I found and ended up using was 50% on the hidden layers and 80% on the input layers.</p>
      </div>

      <div class="section">
        <h2>Results</h2>
        <p>The best model I found achieved an overall accuracy of 77.84% on the test set. One nice consequence of my model is that I can evaluate the melody and harmony accuracies separately, which ended up being 64.15% and 91.57% for the melody and harmony respectively. The higher harmony accuracy makes sense, because most of the pieces in the dataset hold chords for 8 or 16 time steps (a half or whole note in 4/4 time).</p>
        <p>Alright, enough numbers, let's get to the fun stuff. Once the model is trained, generating music from it is just a matter of sampling a melody and harmony from the probability distribution at each time step and plugging it back into the network. Rinse and repeat. I present to you 8 more pieces generated by my model below. I "primed" each with the starting 16 time steps (1 measure in 4/4 time) from a random test sequence, and then let them do their thing for 2048 time steps.</p>
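        <p><i>One sampling step looks roughly like the sketch below, mirroring <code>NottinghamSampler.sample_notes_dist</code> in <code>nottingham_util.py</code>; the full generation loop in <code>rnn_sample.py</code> additionally threads the RNN state from step to step.</i></p>
        <pre><code>import numpy as np

def sample_step(probs, melody_range):
    """Draw one melody class and one harmony class from the model's output
    distribution and re-encode them as the next input vector."""
    ps = np.array(probs, dtype=np.float64)
    melody = np.random.choice(melody_range, p=ps[:melody_range] / ps[:melody_range].sum())
    harmony = np.random.choice(len(ps) - melody_range, p=ps[melody_range:] / ps[melody_range:].sum())
    nxt = np.zeros(len(ps))
    nxt[melody] = 1
    nxt[melody_range + harmony] = 1
    return nxt

probs = np.random.dirichlet(np.ones(34)).tolist() + np.random.dirichlet(np.ones(32)).tolist()
print sample_step(probs, melody_range=34)</code></pre>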
      </div>
      <div class="audio-wrapper">
        <audio src="generated_music/mp3/3.mp3" width="100%" preload="none"></audio>
      </div>
      <div class="audio-wrapper" id="track-marker">
        <audio src="generated_music/mp3/1.mp3" width="100%" preload="none"></audio>
      </div>
      <div class="audio-wrapper">
        <audio src="generated_music/mp3/2.mp3" width="100%" preload="none"></audio>
      </div>
      <div class="audio-wrapper">
        <audio src="generated_music/mp3/4.mp3" width="100%" preload="none"></audio>
      </div>
      <div class="audio-wrapper">
        <audio src="generated_music/mp3/6.mp3" width="100%" preload="none"></audio>
      </div>
      <div class="audio-wrapper">
        <audio src="generated_music/mp3/8.mp3" width="100%" preload="none"></audio>
      </div>
      <div class="audio-wrapper">
        <audio src="generated_music/mp3/9.mp3" width="100%" preload="none"></audio>
      </div>
      <div class="audio-wrapper">
        <audio src="generated_music/mp3/10.mp3" width="100%" preload="none"></audio>
      </div>
      <p>Some turned out sounding better than others to my ears, but overall the model clearly does not produce human-level compositions. The lack of long-term structure such as repeated phrases and themes is especially revealing. However, for the most part, the model seems to play a melody in key with the harmony that it chooses. The melody also tends to stay in the same key signature for short-term phrases, and sometimes the harmony accompanies it with short chord progressions in that same key. There also seem to be small pieces of coherent rhythmic structure, although the "time signature" throughout a piece is sporadic overall.</p>
    
    <div class="section">
      <hr>
      <p>Many thanks go out to <a href="http://web.cs.ucla.edu/~feisha/">Fei Sha</a> for providing valuable advice; this work was completed as part of my final project for his research seminar. If you're interested in learning more, <a href="http://yoavz.com/music_rnn_paper.pdf">the final report</a> contains more about the model and a few more experimental results. The source code is also <a href="http://github.com/yoavz/music_rnn">available on github here</a> if you'd like to use my code to train your own models! (warning: messy code)</p>
    </div>

    <!-- Bootstrap core JavaScript
    ================================================== -->
    <!-- Placed at the end of the document so the pages load faster -->
    <script src="js/jquery-1.11.2.min.js"></script>
    <script src="js/mediaelement-and-player.min.js"></script>
    <script>
      $('audio').mediaelementplayer();
    </script>
  </body>
</html>


================================================
FILE: install.sh
================================================
conda create -n music_rnn python=2.7
source activate music_rnn

pip install -r requirements.txt

mkdir models

mkdir data
# http://www-etud.iro.umontreal.ca/~boulanni/icml2012
wget http://www-etud.iro.umontreal.ca/~boulanni/Nottingham.zip -O data/Nottingham.zip
unzip data/Nottingham.zip -d data/


================================================
FILE: midi_util.py
================================================
import sys, os
from collections import defaultdict
import numpy as np
import midi

RANGE = 128

def round_tick(tick, time_step):
    return int(round(tick/float(time_step)) * time_step)

def ingest_notes(track, verbose=False):

    notes = { n: [] for n in range(RANGE) }
    current_tick = 0

    for msg in track:
        # ignore all end of track events
        if isinstance(msg, midi.EndOfTrackEvent):
            continue

        if msg.tick > 0: 
            current_tick += msg.tick

        # velocity of 0 is equivalent to note off, so treat as such
        if isinstance(msg, midi.NoteOnEvent) and msg.get_velocity() != 0:
            if len(notes[msg.get_pitch()]) > 0 and \
               len(notes[msg.get_pitch()][-1]) != 2:
                if verbose:
                    print "Warning: double NoteOn encountered, deleting the first"
                    print msg
            else:
                notes[msg.get_pitch()] += [[current_tick]]
        elif isinstance(msg, midi.NoteOffEvent) or \
            (isinstance(msg, midi.NoteOnEvent) and msg.get_velocity() == 0):
            # sanity check: no notes end without being started
            if len(notes[msg.get_pitch()][-1]) != 1:
                if verbose:
                    print "Warning: skipping NoteOff Event with no corresponding NoteOn"
                    print msg
            else: 
                notes[msg.get_pitch()][-1] += [current_tick]

    return notes, current_tick

def round_notes(notes, track_ticks, time_step, R=None, O=None):
    if not R:
        R = RANGE
    if not O:
        O = 0

    sequence = np.zeros((track_ticks/time_step, R))
    disputed = { t: defaultdict(int) for t in range(track_ticks/time_step) }
    for note in notes:
        for (start, end) in notes[note]:
            start_t = round_tick(start, time_step) / time_step
            end_t = round_tick(end, time_step) / time_step
            # normal case where note is long enough
            if end - start > time_step/2 and start_t != end_t:
                sequence[start_t:end_t, note - O] = 1
            # cases where note is within bounds of time step 
            elif start > start_t * time_step:
                disputed[start_t][note] += (end - start)
            elif end <= end_t * time_step:
                disputed[end_t-1][note] += (end - start)
            # case where a note is on the border 
            else:
                before_border = start_t * time_step - start
                if before_border > 0:
                    disputed[start_t-1][note] += before_border
                after_border = end - start_t * time_step
                if after_border > 0 and end < track_ticks:
                    disputed[start_t][note] += after_border

    # solve disputed
    for seq_idx in range(sequence.shape[0]):
        if np.count_nonzero(sequence[seq_idx, :]) == 0 and len(disputed[seq_idx]) > 0:
            # print seq_idx, disputed[seq_idx]
            sorted_notes = sorted(disputed[seq_idx].items(),
                                  key=lambda x: x[1])
            max_val = max(x[1] for x in sorted_notes)
            top_notes = filter(lambda x: x[1] >= max_val, sorted_notes)
            for note, _ in top_notes:
                sequence[seq_idx, note - O] = 1

    return sequence

def parse_midi_to_sequence(input_filename, time_step, verbose=False):
    sequence = []
    pattern = midi.read_midifile(input_filename)

    if len(pattern) < 1:
        raise Exception("No pattern found in midi file")

    if verbose:
        print "Track resolution: {}".format(pattern.resolution)
        print "Number of tracks: {}".format(len(pattern))
        print "Time step: {}".format(time_step)

    # Track ingestion stage
    notes = { n: [] for n in range(RANGE) }
    track_ticks = 0
    for track in pattern:
        current_tick = 0
        for msg in track:
            # ignore all end of track events
            if isinstance(msg, midi.EndOfTrackEvent):
                continue

            if msg.tick > 0: 
                current_tick += msg.tick

            # velocity of 0 is equivalent to note off, so treat as such
            if isinstance(msg, midi.NoteOnEvent) and msg.get_velocity() != 0:
                if len(notes[msg.get_pitch()]) > 0 and \
                   len(notes[msg.get_pitch()][-1]) != 2:
                    if verbose:
                        print "Warning: double NoteOn encountered, deleting the first"
                        print msg
                else:
                    notes[msg.get_pitch()] += [[current_tick]]
            elif isinstance(msg, midi.NoteOffEvent) or \
                (isinstance(msg, midi.NoteOnEvent) and msg.get_velocity() == 0):
                # sanity check: no notes end without being started
                if len(notes[msg.get_pitch()][-1]) != 1:
                    if verbose:
                        print "Warning: skipping NoteOff Event with no corresponding NoteOn"
                        print msg
                else: 
                    notes[msg.get_pitch()][-1] += [current_tick]

        track_ticks = max(current_tick, track_ticks)

    track_ticks = round_tick(track_ticks, time_step)
    if verbose:
        print "Track ticks (rounded): {} ({} time steps)".format(track_ticks, track_ticks/time_step)

    sequence = round_notes(notes, track_ticks, time_step)

    return sequence

class MidiWriter(object):

    def __init__(self, verbose=False):
        self.verbose = verbose
        self.note_range = RANGE

    def note_off(self, val, tick):
        self.track.append(midi.NoteOffEvent(tick=tick, pitch=val))
        return 0

    def note_on(self, val, tick):
        self.track.append(midi.NoteOnEvent(tick=tick, pitch=val, velocity=70))
        return 0

    def dump_sequence_to_midi(self, sequence, output_filename, time_step, 
                              resolution, metronome=24):
        if self.verbose:
            print "Dumping sequence to MIDI file: {}".format(output_filename)
            print "Resolution: {}".format(resolution)
            print "Time Step: {}".format(time_step)

        pattern = midi.Pattern(resolution=resolution)
        self.track = midi.Track()

        # metadata track
        meta_track = midi.Track()
        time_sig = midi.TimeSignatureEvent()
        time_sig.set_numerator(4)
        time_sig.set_denominator(4)
        time_sig.set_metronome(metronome)
        time_sig.set_thirtyseconds(8)
        meta_track.append(time_sig)
        pattern.append(meta_track)

        # reshape to (SEQ_LENGTH X NUM_DIMS)
        sequence = np.reshape(sequence, [-1, self.note_range])

        time_steps = sequence.shape[0]
        if self.verbose:
            print "Total number of time steps: {}".format(time_steps)

        tick = time_step
        self.notes_on = { n: False for n in range(self.note_range) }
        # for seq_idx in range(188, 220):
        for seq_idx in range(time_steps):
            notes = np.nonzero(sequence[seq_idx, :])[0].tolist()

            # this tick will only be assigned to first NoteOn/NoteOff in
            # this time_step

            # NoteOffEvents come first so they'll have the tick value
            # go through all notes that are currently on and see if any
            # turned off
            for n in self.notes_on:
                if self.notes_on[n] and n not in notes:
                    tick = self.note_off(n, tick)
                    self.notes_on[n] = False

            # Turn on any notes that weren't previously on
            for note in notes:
                if not self.notes_on[note]:
                    tick = self.note_on(note, tick)
                    self.notes_on[note] = True

            tick += time_step

        # flush out notes
        for n in self.notes_on:
            if self.notes_on[n]:
                self.note_off(n, tick)
                tick = 0
                self.notes_on[n] = False

        pattern.append(self.track)
        midi.write_midifile(output_filename, pattern)

if __name__ == '__main__':
    pass


================================================
FILE: model.py
================================================
import os
import logging
import numpy as np
import tensorflow as tf    
from tensorflow.models.rnn import rnn_cell
from tensorflow.models.rnn import rnn, seq2seq

import nottingham_util

class Model(object):
    """ 
    Cross-Entropy Naive Formulation
    A single time step may have multiple notes active, so a sigmoid cross entropy loss
    is used to match targets.

    seq_input: a [ T x B x D ] matrix, where T is the time steps in the batch, B is the
               batch size, and D is the amount of dimensions
    """
    
    def __init__(self, config, training=False):
        self.config = config
        self.time_batch_len = time_batch_len = config.time_batch_len
        self.input_dim = input_dim = config.input_dim
        hidden_size = config.hidden_size
        num_layers = config.num_layers
        dropout_prob = config.dropout_prob
        input_dropout_prob = config.input_dropout_prob
        cell_type = config.cell_type

        self.seq_input = \
            tf.placeholder(tf.float32, shape=[self.time_batch_len, None, input_dim])

        if (dropout_prob <= 0.0 or dropout_prob > 1.0):
            raise Exception("Invalid dropout probability: {}".format(dropout_prob))

        if (input_dropout_prob <= 0.0 or input_dropout_prob > 1.0):
            raise Exception("Invalid input dropout probability: {}".format(input_dropout_prob))

        # setup variables
        with tf.variable_scope("rnnlstm"):
            output_W = tf.get_variable("output_w", [hidden_size, input_dim])
            output_b = tf.get_variable("output_b", [input_dim])
            self.lr = tf.constant(config.learning_rate, name="learning_rate")
            self.lr_decay = tf.constant(config.learning_rate_decay, name="learning_rate_decay")

        def create_cell(input_size):
            if cell_type == "vanilla":
                cell_class = rnn_cell.BasicRNNCell
            elif cell_type == "gru":
                cell_class = rnn_cell.BasicGRUCell
            elif cell_type == "lstm":
                cell_class = rnn_cell.BasicLSTMCell
            else:
                raise Exception("Invalid cell type: {}".format(cell_type))

            cell = cell_class(hidden_size, input_size = input_size)
            if training:
                return rnn_cell.DropoutWrapper(cell, output_keep_prob = dropout_prob)
            else:
                return cell

        if training:
            self.seq_input_dropout = tf.nn.dropout(self.seq_input, keep_prob = input_dropout_prob)
        else:
            self.seq_input_dropout = self.seq_input

        self.cell = rnn_cell.MultiRNNCell(
            [create_cell(input_dim)] + [create_cell(hidden_size) for i in range(1, num_layers)])

        batch_size = tf.shape(self.seq_input_dropout)[0]
        self.initial_state = self.cell.zero_state(batch_size, tf.float32)
        inputs_list = tf.unpack(self.seq_input_dropout)

        # rnn outputs a list of [batch_size x H] outputs
        outputs_list, self.final_state = rnn.rnn(self.cell, inputs_list, 
                                                 initial_state=self.initial_state)

        outputs = tf.pack(outputs_list)
        outputs_concat = tf.reshape(outputs, [-1, hidden_size])
        logits_concat = tf.matmul(outputs_concat, output_W) + output_b
        logits = tf.reshape(logits_concat, [self.time_batch_len, -1, input_dim])

        # probabilities of each note
        self.probs = self.calculate_probs(logits)
        self.loss = self.init_loss(logits, logits_concat)
        self.train_step = tf.train.RMSPropOptimizer(self.lr, decay = self.lr_decay) \
                            .minimize(self.loss)

    def init_loss(self, outputs, _):
        self.seq_targets = \
            tf.placeholder(tf.float32, [self.time_batch_len, None, self.input_dim])

        batch_size = tf.shape(self.seq_input_dropout)
        cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(outputs, self.seq_targets)
        return tf.reduce_sum(cross_ent) / self.time_batch_len / tf.to_float(batch_size)

    def calculate_probs(self, logits):
        return tf.sigmoid(logits)

    def get_cell_zero_state(self, session, batch_size):
        return self.cell.zero_state(batch_size, tf.float32).eval(session=session)

class NottinghamModel(Model):
    """ 
    Dual softmax formulation 

    A single time step should be a concatenation of two one-hot-encoding binary vectors.
    Loss function is a sum of two softmax loss functions over [:r] and [r:] respectively,
    where r is the number of melody classes
    """

    def init_loss(self, outputs, outputs_concat):
        self.seq_targets = \
            tf.placeholder(tf.int64, [self.time_batch_len, None, 2])
        batch_size = tf.shape(self.seq_targets)[1]

        with tf.variable_scope("rnnlstm"):
            self.melody_coeff = tf.constant(self.config.melody_coeff)

        r = nottingham_util.NOTTINGHAM_MELODY_RANGE
        targets_concat = tf.reshape(self.seq_targets, [-1, 2])

        melody_loss = tf.nn.sparse_softmax_cross_entropy_with_logits( \
            outputs_concat[:, :r], \
            targets_concat[:, 0])
        harmony_loss = tf.nn.sparse_softmax_cross_entropy_with_logits( \
            outputs_concat[:, r:], \
            targets_concat[:, 1])
        losses = tf.add(self.melody_coeff * melody_loss, (1 - self.melody_coeff) * harmony_loss)
        return tf.reduce_sum(losses) / self.time_batch_len / tf.to_float(batch_size)

    def calculate_probs(self, logits):
        steps = []
        for t in range(self.time_batch_len):
            melody_softmax = tf.nn.softmax(logits[t, :, :nottingham_util.NOTTINGHAM_MELODY_RANGE])
            harmony_softmax = tf.nn.softmax(logits[t, :, nottingham_util.NOTTINGHAM_MELODY_RANGE:])
            steps.append(tf.concat(1, [melody_softmax, harmony_softmax]))
        return tf.pack(steps)

    def assign_melody_coeff(self, session, melody_coeff):
        if melody_coeff < 0.0 or melody_coeff > 1.0:
            raise Exception("Invalid melody coeffecient")

        session.run(tf.assign(self.melody_coeff, melody_coeff))

class NottinghamSeparate(Model):
    """ 
    Single softmax formulation 
    
    Regular single classification formulation, used to train baseline models
    where the melody and harmony are trained separately
    """

    def init_loss(self, outputs, outputs_concat):
        self.seq_targets = \
            tf.placeholder(tf.int64, [self.time_batch_len, None])
        batch_size = tf.shape(self.seq_targets)[1]

        with tf.variable_scope("rnnlstm"):
            self.melody_coeff = tf.constant(self.config.melody_coeff)

        targets_concat = tf.reshape(self.seq_targets, [-1])
        losses = tf.nn.sparse_softmax_cross_entropy_with_logits( \
            outputs_concat, targets_concat)

        return tf.reduce_sum(losses) / self.time_batch_len / tf.to_float(batch_size)

    def calculate_probs(self, logits):
        steps = []
        for t in range(self.time_batch_len):
            softmax = tf.nn.softmax(logits[t, :, :])
            steps.append(softmax)
        return tf.pack(steps)


================================================
FILE: nottingham_util.py
================================================
import numpy as np
import os
import midi
import cPickle
from pprint import pprint

import midi_util
import mingus
import mingus.core.chords
import sampling

PICKLE_LOC = 'data/nottingham.pickle'
NOTTINGHAM_MELODY_MAX = 88
NOTTINGHAM_MELODY_MIN = 55
# add one to the range for silence in melody
NOTTINGHAM_MELODY_RANGE = NOTTINGHAM_MELODY_MAX - NOTTINGHAM_MELODY_MIN + 1 + 1
CHORD_BASE = 48
CHORD_BLACKLIST = ['major third', 'minor third', 'perfect fifth']
NO_CHORD = 'NONE'
SHARPS_TO_FLATS = {
    "A#": "Bb",
    "B#": "C",
    "C#": "Db",
    "D#": "Eb",
    "E#": "F",
    "F#": "Gb",
    "G#": "Ab",
}

def resolve_chord(chord):
    """
    Resolves rare chords to their closest common chord, to limit the total
    amount of chord classes.
    """
    if chord in CHORD_BLACKLIST:
        return None
    # take the first of dual chords
    if "|" in chord:
        chord = chord.split("|")[0]
    # remove 7ths, 11ths, 9s, 6th,
    if chord.endswith("11"):
        chord = chord[:-2]
    if chord.endswith("7") or chord.endswith("9") or chord.endswith("6"):
        chord = chord[:-1]
    # replace 'dim' with minor
    if chord.endswith("dim"):
        chord = chord[:-3] + "m"
    return chord

def prepare_nottingham_pickle(time_step, chord_cutoff=64, filename=PICKLE_LOC, verbose=False):
    """
    time_step: the time step to discretize all notes into
    chord_cutoff: if chords are seen less than this cutoff, they are ignored and marked as
                  as rests in the resulting dataset
    filename: the location where the pickle will be saved to
    """

    data = {}
    store = {}
    chords = {}
    max_seq = 0
    seq_lens = []
    
    for d in ["train", "test", "valid"]:
        print "Parsing {}...".format(d)
        parsed = parse_nottingham_directory("data/Nottingham/{}".format(d), time_step, verbose=False)
        metadata = [s[0] for s in parsed]
        seqs = [s[1] for s in parsed]
        data[d] = seqs
        data[d + '_metadata'] = metadata
        lens = [len(s[1]) for s in seqs]
        seq_lens += lens
        max_seq = max(max_seq, max(lens))
        
        for _, harmony in seqs:
            for h in harmony:
                if h not in chords:
                    chords[h] = 1
                else:
                    chords[h] += 1

    avg_seq = float(sum(seq_lens)) / len(seq_lens)

    chords = { c: i for c, i in chords.iteritems() if chords[c] >= chord_cutoff }
    chord_mapping = { c: i for i, c in enumerate(chords.keys()) }
    num_chords = len(chord_mapping)
    store['chord_to_idx'] = chord_mapping
    if verbose:
        pprint(chords)
        print "Number of chords: {}".format(num_chords)
        print "Max Sequence length: {}".format(max_seq)
        print "Avg Sequence length: {}".format(avg_seq)
        print "Num Sequences: {}".format(len(seq_lens))

    def combine(melody, harmony):
        full = np.zeros((melody.shape[0], NOTTINGHAM_MELODY_RANGE + num_chords))

        assert melody.shape[0] == len(harmony)

        # for all melody sequences that don't have any notes, add the empty melody marker (last one)
        for i in range(melody.shape[0]):
            if np.count_nonzero(melody[i, :]) == 0:
                melody[i, NOTTINGHAM_MELODY_RANGE-1] = 1

        # all melody encodings should now have exactly one 1
        for i in range(melody.shape[0]):
            assert np.count_nonzero(melody[i, :]) == 1

        # add all the melodies
        full[:, :melody.shape[1]] += melody

        harmony_idxs = [ chord_mapping[h] if h in chord_mapping else chord_mapping[NO_CHORD] \
                         for h in harmony ]
        harmony_idxs = [ NOTTINGHAM_MELODY_RANGE + h for h in harmony_idxs ]
        full[np.arange(len(harmony)), harmony_idxs] = 1

        # all full encodings should have exactly two 1's
        for i in range(full.shape[0]):
            assert np.count_nonzero(full[i, :]) == 2

        return full

    for d in ["train", "test", "valid"]:
        print "Combining {}".format(d)
        store[d] = [ combine(m, h) for m, h in data[d] ]
        store[d + '_metadata'] = data[d + '_metadata']

    with open(filename, 'w') as f:
        cPickle.dump(store, f, protocol=-1)

    return True

def parse_nottingham_directory(input_dir, time_step, verbose=False):
    """
    input_dir: a directory containing MIDI files

    returns a list of [T x D] matrices, where each matrix represents a 
    a sequence with T time steps over D dimensions
    """

    files = [ os.path.join(input_dir, f) for f in os.listdir(input_dir)
              if os.path.isfile(os.path.join(input_dir, f)) ] 
    sequences = [ \
        parse_nottingham_to_sequence(f, time_step=time_step, verbose=verbose) \
        for f in files ]

    if verbose:
        print "Total sequences: {}".format(len(sequences))
    
    # filter out the non 2-track MIDI's
    sequences = filter(lambda x: x[1] != None, sequences)

    if verbose:
        print "Total sequences left: {}".format(len(sequences))

    return sequences

def parse_nottingham_to_sequence(input_filename, time_step, verbose=False):
    """
    input_filename: a MIDI filename

    returns a [T x D] matrix representing a sequence with T time steps over
    D dimensions
    """
    sequence = []
    pattern = midi.read_midifile(input_filename)

    metadata = {
        "path": input_filename,
        "name": input_filename.split("/")[-1].split(".")[0]
    }

    # Most nottingham midi's have 3 tracks. metadata info, melody, harmony
    # throw away any tracks that don't fit this
    if len(pattern) != 3:
        if verbose:
            "Skipping track with {} tracks".format(len(pattern))
        return (metadata, None)

    ticks_per_quarter = -1  # default if no TimeSignatureEvent is found
    for msg in pattern[0]:
        if isinstance(msg, midi.TimeSignatureEvent):
            metadata["ticks_per_quarter"] = msg.get_metronome()
            ticks_per_quarter = msg.get_metronome()

    if verbose:
        print "{}".format(input_filename)
        print "Track resolution: {}".format(pattern.resolution)
        print "Number of tracks: {}".format(len(pattern))
        print "Time step: {}".format(time_step)
        print "Ticks per quarter: {}".format(ticks_per_quarter)

    # Track ingestion stage
    track_ticks = 0

    melody_notes, melody_ticks = midi_util.ingest_notes(pattern[1])
    harmony_notes, harmony_ticks = midi_util.ingest_notes(pattern[2])

    track_ticks = midi_util.round_tick(max(melody_ticks, harmony_ticks), time_step)
    if verbose:
        print "Track ticks (rounded): {} ({} time steps)".format(track_ticks, track_ticks/time_step)
    
    melody_sequence = midi_util.round_notes(melody_notes, track_ticks, time_step, 
                                  R=NOTTINGHAM_MELODY_RANGE, O=NOTTINGHAM_MELODY_MIN)

    for i in range(melody_sequence.shape[0]):
        if np.count_nonzero(melody_sequence[i, :]) > 1:
            if verbose:
                print "Double note found: {}: {} ({})".format(i, np.nonzero(melody_sequence[i, :]), input_filename)
            return (metadata, None)

    harmony_sequence = midi_util.round_notes(harmony_notes, track_ticks, time_step)

    harmonies = []
    for i in range(harmony_sequence.shape[0]):
        notes = np.where(harmony_sequence[i] == 1)[0]
        if len(notes) > 0:
            notes_shift = [ mingus.core.notes.int_to_note(h%12) for h in notes]
            chord = mingus.core.chords.determine(notes_shift, shorthand=True)
            if len(chord) == 0:
                # try flat combinations
                notes_shift = [ SHARPS_TO_FLATS[n] if n in SHARPS_TO_FLATS else n for n in notes_shift]
                chord = mingus.core.chords.determine(notes_shift, shorthand=True)
            if len(chord) == 0:
                if verbose:
                    print "Could not determine chord: {} ({}, {}), defaulting to last steps chord" \
                          .format(notes_shift, input_filename, i)
                if len(harmonies) > 0:
                    harmonies.append(harmonies[-1])
                else:
                    harmonies.append(NO_CHORD)
            else:
                resolved = resolve_chord(chord[0])
                if resolved:
                    harmonies.append(resolved)
                else:
                    harmonies.append(NO_CHORD)
        else:
            harmonies.append(NO_CHORD)

    return (metadata, (melody_sequence, harmonies))

class NottinghamMidiWriter(midi_util.MidiWriter):

    def __init__(self, chord_to_idx, verbose=False):
        super(NottinghamMidiWriter, self).__init__(verbose)
        self.idx_to_chord = { i: c for c, i in chord_to_idx.items() }
        self.note_range = NOTTINGHAM_MELODY_RANGE + len(self.idx_to_chord)

    def dereference_chord(self, idx):
        if idx not in self.idx_to_chord:
            raise Exception("No chord index found: {}".format(idx))
        shorthand = self.idx_to_chord[idx]
        if shorthand == NO_CHORD:
            return []
        chord = mingus.core.chords.from_shorthand(shorthand)
        return [ CHORD_BASE + mingus.core.notes.note_to_int(n) for n in chord ]

    def note_on(self, val, tick):
        if val >= NOTTINGHAM_MELODY_RANGE:
            notes = self.dereference_chord(val - NOTTINGHAM_MELODY_RANGE)
        else:
            # if note is the top of the range, then it stands for gap in melody
            if val == NOTTINGHAM_MELODY_RANGE - 1:
                notes = []
            else:
                notes = [NOTTINGHAM_MELODY_MIN + val]

        # print 'turning on {}'.format(notes)
        for note in notes:
            self.track.append(midi.NoteOnEvent(tick=tick, pitch=note, velocity=70))
            tick = 0 # notes that come right after each other should have zero tick

        return tick

    def note_off(self, val, tick):
        if val >= NOTTINGHAM_MELODY_RANGE:
            notes = self.dereference_chord(val - NOTTINGHAM_MELODY_RANGE)
        else:
            notes = [NOTTINGHAM_MELODY_MIN + val]

        # print 'turning off {}'.format(notes)
        for note in notes:
            self.track.append(midi.NoteOffEvent(tick=tick, pitch=note))
            tick = 0

        return tick

class NottinghamSampler(object):

    def __init__(self, chord_to_idx, method = 'sample', harmony_repeat_max = 16, melody_repeat_max = 16, verbose=False):
        self.verbose = verbose 
        self.idx_to_chord = { i: c for c, i in chord_to_idx.items() }
        self.method = method

        self.hlast = 0
        self.hcount = 0
        self.hrepeat = harmony_repeat_max

        self.mlast = 0
        self.mcount = 0
        self.mrepeat = melody_repeat_max 

    def visualize_probs(self, probs):
        if not self.verbose:
            return

        melodies = sorted(list(enumerate(probs[:NOTTINGHAM_MELODY_RANGE])), 
                     key=lambda x: x[1], reverse=True)[:4]
        harmonies = sorted(list(enumerate(probs[NOTTINGHAM_MELODY_RANGE:])), 
                     key=lambda x: x[1], reverse=True)[:4]
        harmonies = [(self.idx_to_chord[i], j) for i, j in harmonies]
        print 'Top Melody Notes: '
        pprint(melodies)
        print 'Top Harmony Notes: '
        pprint(harmonies)

    def sample_notes_static(self, probs):
        top_m = probs[:NOTTINGHAM_MELODY_RANGE].argsort()
        if top_m[-1] == self.mlast and self.mcount >= self.mrepeat:
            top_m = top_m[:-1]
            self.mcount = 0
        elif top_m[-1] == self.mlast:
            self.mcount += 1
        else:
            self.mcount = 0
        self.mlast = top_m[-1]
        top_melody = top_m[-1]

        top_h = probs[NOTTINGHAM_MELODY_RANGE:].argsort()
        if top_h[-1] == self.hlast and self.hcount >= self.hrepeat:
            top_h = top_h[:-1]
            self.hcount = 0
        elif top_h[-1] == self.hlast:
            self.hcount += 1
        else:
            self.hcount = 0
        self.hlast = top_h[-1]
        top_chord = top_h[-1] + NOTTINGHAM_MELODY_RANGE

        chord = np.zeros([len(probs)], dtype=np.int32)
        chord[top_melody] = 1.0
        chord[top_chord] = 1.0
        return chord

    def sample_notes_dist(self, probs):
        idxed = [(i, p) for i, p in enumerate(probs)]

        notes = [n[0] for n in idxed]
        ps = np.array([n[1] for n in idxed])
        r = NOTTINGHAM_MELODY_RANGE

        assert np.allclose(np.sum(ps[:r]), 1.0)
        assert np.allclose(np.sum(ps[r:]), 1.0)

        # renormalize so numpy doesn't complain
        ps[:r] = ps[:r] / ps[:r].sum()
        ps[r:] = ps[r:] / ps[r:].sum()

        melody = np.random.choice(notes[:r], p=ps[:r])
        harmony = np.random.choice(notes[r:], p=ps[r:])

        chord = np.zeros([len(probs)], dtype=np.int32)
        chord[melody] = 1.0
        chord[harmony] = 1.0
        return chord


    def sample_notes(self, probs):
        self.visualize_probs(probs)
        if self.method == 'static':
            return self.sample_notes_static(probs)
        elif self.method == 'sample':
            return self.sample_notes_dist(probs)

def accuracy(batch_probs, data, num_samples=1):
    """
    Batch Probs: { num_time_steps: [ time_step_1, time_step_2, ... ] }
    Data: [ 
        [ [ data ], [ target ] ], # batch with one time step
        [ [ data1, data2 ], [ target1, target2 ] ], # batch with two time steps
        ...
    ]
    """

    def calc_accuracy():
        total = 0
        melody_correct, harmony_correct = 0, 0
        melody_incorrect, harmony_incorrect = 0, 0
        for _, batch_targets in data:
            num_time_steps = len(batch_targets)
            for ts_targets, ts_probs in zip(batch_targets, batch_probs[num_time_steps]):

                assert ts_targets.shape[:2] == ts_probs.shape[:2]

                for seq_idx in range(ts_targets.shape[1]):
                    for step_idx in range(ts_targets.shape[0]):
                        idxed = [(n, p) for n, p in \
                                 enumerate(ts_probs[step_idx, seq_idx, :])]
                        notes = [n[0] for n in idxed]
                        ps = np.array([n[1] for n in idxed])
                        r = NOTTINGHAM_MELODY_RANGE

                        assert np.allclose(np.sum(ps[:r]), 1.0)
                        assert np.allclose(np.sum(ps[r:]), 1.0)

                        # renormalize so numpy doesn't complain
                        ps[:r] = ps[:r] / ps[:r].sum()
                        ps[r:] = ps[r:] / ps[r:].sum()

                        melody = np.random.choice(notes[:r], p=ps[:r])
                        harmony = np.random.choice(notes[r:], p=ps[r:])

                        melody_target = ts_targets[step_idx, seq_idx, 0]
                        if melody_target == melody:
                            melody_correct += 1
                        else:
                            melody_incorrect += 1

                        harmony_target = ts_targets[step_idx, seq_idx, 1] + r
                        if harmony_target == harmony:
                            harmony_correct += 1
                        else:
                            harmony_incorrect += 1

        return (melody_correct, melody_incorrect, harmony_correct, harmony_incorrect)

    maccs, haccs, taccs = [], [], []
    for i in range(num_samples):
        print "Sample {}".format(i)
        m, mi, h, hi = calc_accuracy()
        maccs.append( float(m) / float(m + mi))
        haccs.append( float(h) / float(h + hi))
        taccs.append( float(m + h) / float(m + h + mi + hi) )

    print "Melody Precision/Recall: {}".format(sum(maccs)/len(maccs))
    print "Harmony Precision/Recall: {}".format(sum(haccs)/len(haccs))
    print "Total Precision/Recall: {}".format(sum(taccs)/len(taccs))

def seperate_accuracy(batch_probs, data, num_samples=1):

    def calc_accuracy():
        total = 0
        total_correct, total_incorrect = 0, 0
        for _, batch_targets in data:
            num_time_steps = len(batch_targets)
            for ts_targets, ts_probs in zip(batch_targets, batch_probs[num_time_steps]):

                assert ts_targets.shape[:2] == ts_probs.shape[:2]

                for seq_idx in range(ts_targets.shape[1]):
                    for step_idx in range(ts_targets.shape[0]):

                        idxed = [(n, p) for n, p in \
                                 enumerate(ts_probs[step_idx, seq_idx, :])]
                        notes = [n[0] for n in idxed]
                        ps = np.array([n[1] for n in idxed])
                        r = NOTTINGHAM_MELODY_RANGE

                        assert np.allclose(np.sum(ps), 1.0)
                        ps = ps / ps.sum()
                        note = np.random.choice(notes, p=ps)

                        target = ts_targets[step_idx, seq_idx]
                        if target == note:
                            total_correct += 1
                        else:
                            total_incorrect += 1

        return (total_correct, total_incorrect)

    taccs = []
    for i in range(num_samples):
        print "Sample {}".format(i)
        c, ic = calc_accuracy()
        taccs.append( float(c) / float(c + ic))

    print "Precision/Recall: {}".format(sum(taccs)/len(taccs))

def i_vi_iv_v(chord_to_idx, repeats, input_dim):
    r = NOTTINGHAM_MELODY_RANGE

    i = np.zeros(input_dim)
    i[r + chord_to_idx['CM']] = 1

    vi = np.zeros(input_dim)
    vi[r + chord_to_idx['Am']] = 1

    iv = np.zeros(input_dim)
    iv[r + chord_to_idx['FM']] = 1

    v = np.zeros(input_dim)
    v[r + chord_to_idx['GM']] = 1

    full_seq = [i] * 16 + [vi] * 16 + [iv] * 16 + [v] * 16
    full_seq = full_seq * repeats
    
    return full_seq

if __name__ == '__main__':

    resolution = 480
    time_step = 120

    assert resolve_chord("GM7") == "GM"
    assert resolve_chord("G#dim|AM7") == "G#m"
    assert resolve_chord("Dm9") == "Dm"
    assert resolve_chord("AM11") == "AM"

    prepare_nottingham_pickle(time_step, verbose=True)


================================================
FILE: requirements.txt
================================================
matplotlib
mingus
numpy
git+https://github.com/vishnubob/python-midi#egg=midi
# Linux, Python 2.7, GPU
https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl


================================================
FILE: rnn.py
================================================
import os, sys
import argparse
import time
import itertools
import cPickle
import logging
import random
import string

import numpy as np
import tensorflow as tf    
import matplotlib.pyplot as plt

import nottingham_util
import util
from model import Model, NottinghamModel

def get_config_name(config):
    def replace_dot(s): return s.replace(".", "p")
    return "nl_" + str(config.num_layers) + "_hs_" + str(config.hidden_size) + \
            replace_dot("_mc_{}".format(config.melody_coeff)) + \
            replace_dot("_dp_{}".format(config.dropout_prob)) + \
            replace_dot("_idp_{}".format(config.input_dropout_prob)) + \
            replace_dot("_tb_{}".format(config.time_batch_len)) 

class DefaultConfig(object):
    # model parameters
    num_layers = 2
    hidden_size = 200
    melody_coeff = 0.5
    dropout_prob = 0.5
    input_dropout_prob = 0.8
    cell_type = 'lstm'

    # learning parameters
    max_time_batches = 9 
    time_batch_len = 128
    learning_rate = 5e-3
    learning_rate_decay = 0.9
    num_epochs = 250

    # metadata
    dataset = 'softmax'
    model_file = ''

    def __repr__(self):
        return """Num Layers: {}, Hidden Size: {}, Melody Coeff: {}, Dropout Prob: {}, Input Dropout Prob: {}, Cell Type: {}, Time Batch Len: {}, Learning Rate: {}, Decay: {}""".format(self.num_layers, self.hidden_size, self.melody_coeff, self.dropout_prob, self.input_dropout_prob, self.cell_type, self.time_batch_len, self.learning_rate, self.learning_rate_decay)
    
if __name__ == '__main__':
    np.random.seed()      

    parser = argparse.ArgumentParser(description='Script to train and save a model.')
    parser.add_argument('--dataset', type=str, default='softmax',
                        # choices = ['bach', 'nottingham', 'softmax'],
                        choices = ['softmax'])
    parser.add_argument('--model_dir', type=str, default='models')
    parser.add_argument('--run_name', type=str, default=time.strftime("%m%d_%H%M"))

    args = parser.parse_args()

    if args.dataset == 'softmax':
        resolution = 480
        time_step = 120
        model_class = NottinghamModel
        with open(nottingham_util.PICKLE_LOC, 'r') as f:
            pickle = cPickle.load(f)
            chord_to_idx = pickle['chord_to_idx']

        input_dim = pickle["train"][0].shape[1]
        print 'Finished loading data, input dim: {}'.format(input_dim)
    else:
        raise Exception("Other datasets not yet implemented")

    initializer = tf.random_uniform_initializer(-0.1, 0.1)

    best_config = None
    best_valid_loss = None

    # set up run dir
    run_folder = os.path.join(args.model_dir, args.run_name)
    if os.path.exists(run_folder):
        raise Exception("Run name {} already exists, choose a different one", format(run_folder))
    os.makedirs(run_folder)

    logger = logging.getLogger(__name__) 
    logger.setLevel(logging.INFO)
    logger.addHandler(logging.StreamHandler())
    logger.addHandler(logging.FileHandler(os.path.join(run_folder, "training.log")))

    grid = {
        "dropout_prob": [0.5],
        "input_dropout_prob": [0.8],
        "melody_coeff": [0.5],
        "num_layers": [2],
        "hidden_size": [200],
        "num_epochs": [250],
        "learning_rate": [5e-3],
        "learning_rate_decay": [0.9],
        "time_batch_len": [128],
    }

    # Generate product of hyperparams
    runs = list(list(itertools.izip(grid, x)) for x in itertools.product(*grid.itervalues()))
    logger.info("{} runs detected".format(len(runs)))

    for combination in runs:

        config = DefaultConfig()
        config.dataset = args.dataset
        config.model_name = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(12)) + '.model'
        for attr, value in combination:
            setattr(config, attr, value)

        if config.dataset == 'softmax':
            data = util.load_data('', time_step, config.time_batch_len, config.max_time_batches, nottingham=pickle)
            config.input_dim = data["input_dim"]
        else:
            raise Exception("Other datasets not yet implemented")

        logger.info(config)
        config_file_path = os.path.join(run_folder, get_config_name(config) + '.config')
        with open(config_file_path, 'w') as f: 
            cPickle.dump(config, f)

        with tf.Graph().as_default(), tf.Session() as session:
            with tf.variable_scope("model", reuse=None):
                train_model = model_class(config, training=True)
            with tf.variable_scope("model", reuse=True):
                valid_model = model_class(config, training=False)

            saver = tf.train.Saver(tf.all_variables(), max_to_keep=40)
            tf.initialize_all_variables().run()

            # training
            early_stop_best_loss = None
            start_saving = False
            saved_flag = False
            train_losses, valid_losses = [], []
            start_time = time.time()
            for i in range(config.num_epochs):
                loss = util.run_epoch(session, train_model, 
                    data["train"]["data"], training=True, testing=False)
                train_losses.append((i, loss))
                if i == 0:
                    continue

                logger.info('Epoch: {}, Train Loss: {}, Time Per Epoch: {}'.format(\
                        i, loss, (time.time() - start_time)/i))
                valid_loss = util.run_epoch(session, valid_model, data["valid"]["data"], training=False, testing=False)
                valid_losses.append((i, valid_loss))
                logger.info('Valid Loss: {}'.format(valid_loss))

                if early_stop_best_loss == None:
                    early_stop_best_loss = valid_loss
                elif valid_loss < early_stop_best_loss:
                    early_stop_best_loss = valid_loss
                    if start_saving:
                        logger.info('Best loss so far encountered, saving model.')
                        saver.save(session, os.path.join(run_folder, config.model_name))
                        saved_flag = True
                elif not start_saving:
                    start_saving = True 
                    logger.info('Valid loss increased for the first time, will start saving models')
                    saver.save(session, os.path.join(run_folder, config.model_name))
                    saved_flag = True

            if not saved_flag:
                saver.save(session, os.path.join(run_folder, config.model_name))

            # cap the loss axis so charts stay readable across runs
            axes = plt.gca()
            if config.dataset == 'softmax':
                axes.set_ylim([0, 2])
            else:
                axes.set_ylim([0, 100])
            plt.plot([t[0] for t in train_losses], [t[1] for t in train_losses])
            plt.plot([t[0] for t in valid_losses], [t[1] for t in valid_losses])
            plt.legend(['Train Loss', 'Validation Loss'])
            chart_file_path = os.path.join(run_folder, get_config_name(config) + '.png')
            plt.savefig(chart_file_path)
            plt.clf()

            logger.info("Config {}, Loss: {}".format(config, early_stop_best_loss))
            if best_valid_loss == None or early_stop_best_loss < best_valid_loss:
                logger.info("Found best new model!")
                best_valid_loss = early_stop_best_loss
                best_config = config

    logger.info("Best Config: {}, Loss: {}".format(best_config, best_valid_loss))


================================================
FILE: rnn_sample.py
================================================
import os, sys
import argparse
import time
import itertools
import cPickle

import numpy as np
import tensorflow as tf    

import util
import nottingham_util
from model import Model, NottinghamModel
from rnn import DefaultConfig

if __name__ == '__main__':
    np.random.seed()      

    parser = argparse.ArgumentParser(description='Script to generate a MIDI file sample from a trained model.')
    parser.add_argument('--config_file', type=str, required=True)
    parser.add_argument('--sample_melody', action='store_true', default=False)
    parser.add_argument('--sample_harmony', action='store_true', default=False)
    parser.add_argument('--sample_seq', type=str, default='random',
        choices = ['random', 'chords'])
    parser.add_argument('--conditioning', type=int, default=-1)
    parser.add_argument('--sample_length', type=int, default=512)

    args = parser.parse_args()

    with open(args.config_file, 'r') as f: 
        config = cPickle.load(f)

    if config.dataset == 'softmax':
        config.time_batch_len = 1
        config.max_time_batches = -1
        model_class = NottinghamModel
        with open(nottingham_util.PICKLE_LOC, 'r') as f:
            pickle = cPickle.load(f)
        chord_to_idx = pickle['chord_to_idx']

        time_step = 120
        resolution = 480

        # use time batch len of 1 so that every target is covered
        test_data = util.batch_data(pickle['test'], time_batch_len = 1, 
            max_time_batches = -1, softmax = True)
    else:
        raise Exception("Other datasets not yet implemented")

    print config

    with tf.Graph().as_default(), tf.Session() as session:
        with tf.variable_scope("model", reuse=None):
            sampling_model = model_class(config)

        saver = tf.train.Saver(tf.all_variables())
        model_path = os.path.join(os.path.dirname(args.config_file), 
            config.model_name)
        saver.restore(session, model_path)

        state = sampling_model.get_cell_zero_state(session, 1)
        if args.sample_seq == 'chords':
            # 16 - one measure, 64 - chord progression
            repeats = args.sample_length / 64
            sample_seq = nottingham_util.i_vi_iv_v(chord_to_idx, repeats, config.input_dim)
            print 'Sampling melody using a I, VI, IV, V progression'

        elif args.sample_seq == 'random':
            sample_index = np.random.choice(np.arange(len(pickle['test'])))
            sample_seq = [ pickle['test'][sample_index][i, :] 
                for i in range(pickle['test'][sample_index].shape[0]) ]

        chord = sample_seq[0]
        seq = [chord]

        if args.conditioning > 0:
            for i in range(1, args.conditioning):
                seq_input = np.reshape(chord, [1, 1, config.input_dim])
                feed = {
                    sampling_model.seq_input: seq_input,
                    sampling_model.initial_state: state,
                }
                state = session.run(sampling_model.final_state, feed_dict=feed)
                chord = sample_seq[i]
                seq.append(chord)

        if config.dataset == 'softmax':
            writer = nottingham_util.NottinghamMidiWriter(chord_to_idx, verbose=False)
            sampler = nottingham_util.NottinghamSampler(chord_to_idx, verbose=False)
        else:
            # writer = midi_util.MidiWriter()
            # sampler = sampling.Sampler(verbose=False)
            raise Exception("Other datasets not yet implemented")

        for i in range(max(args.sample_length - len(seq), 0)):
            seq_input = np.reshape(chord, [1, 1, config.input_dim])
            feed = {
                sampling_model.seq_input: seq_input,
                sampling_model.initial_state: state,
            }
            [probs, state] = session.run(
                [sampling_model.probs, sampling_model.final_state],
                feed_dict=feed)
            probs = np.reshape(probs, [config.input_dim])
            chord = sampler.sample_notes(probs)

            if config.dataset == 'softmax':
                r = nottingham_util.NOTTINGHAM_MELODY_RANGE
                if args.sample_melody:
                    chord[r:] = 0
                    chord[r:] = sample_seq[i][r:]
                elif args.sample_harmony:
                    chord[:r] = 0
                    chord[:r] = sample_seq[i][:r]

            seq.append(chord)

        writer.dump_sequence_to_midi(seq, "best.midi", 
            time_step=time_step, resolution=resolution)
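
# Editorial usage sketch (paths are hypothetical, not from the original repo):
#
#   python rnn_sample.py --config_file models/my_run/<name>.config \
#       --sample_seq chords --sample_length 512
#
# The script restores the checkpoint saved next to the config file and writes
# the generated sequence to best.midi in the current working directory.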


================================================
FILE: rnn_separate.py
================================================
import os, sys
import argparse
import time
import itertools
import cPickle
import logging
import random
import string
import pprint
 
import numpy as np
import tensorflow as tf    
import matplotlib.pyplot as plt

import midi_util
import nottingham_util
import sampling
import util
from rnn import get_config_name, DefaultConfig
from model import Model, NottinghamSeparate

if __name__ == '__main__':
    np.random.seed()      

    parser = argparse.ArgumentParser(description='Music RNN')
    parser.add_argument('--choice', type=str, default='melody',
                        choices = ['melody', 'harmony'])
    parser.add_argument('--dataset', type=str, default='softmax',
                        choices = ['bach', 'nottingham', 'softmax'])
    parser.add_argument('--model_dir', type=str, default='models')
    parser.add_argument('--run_name', type=str, default=time.strftime("%m%d_%H%M"))

    args = parser.parse_args()

    if args.dataset == 'softmax':
        resolution = 480
        time_step = 120
        model_class = NottinghamSeparate
        with open(nottingham_util.PICKLE_LOC, 'r') as f:
            pickle = cPickle.load(f)
            chord_to_idx = pickle['chord_to_idx']

        input_dim = pickle["train"][0].shape[1]
        print 'Finished loading data, input dim: {}'.format(input_dim)
    else:
        raise Exception("Other datasets not yet implemented")


    initializer = tf.random_uniform_initializer(-0.1, 0.1)

    best_config = None
    best_valid_loss = None

    # set up run dir
    run_folder = os.path.join(args.model_dir, args.run_name)
    if os.path.exists(run_folder):
        raise Exception("Run name {} already exists, choose a different one", format(run_folder))
    os.makedirs(run_folder)

    logger = logging.getLogger(__name__) 
    logger.setLevel(logging.INFO)
    logger.addHandler(logging.StreamHandler())
    logger.addHandler(logging.FileHandler(os.path.join(run_folder, "training.log")))

    # grid
    grid = {
        "dropout_prob": [0.65],
        "input_dropout_prob": [0.9],
        "num_layers": [1],
        "hidden_size": [100]
    }

    # Generate product of hyperparams
    runs = list(list(itertools.izip(grid, x)) for x in itertools.product(*grid.itervalues()))
    logger.info("{} runs detected".format(len(runs)))

    for combination in runs:

        config = DefaultConfig()
        config.dataset = args.dataset
        config.model_name = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(12)) + '.model'
        for attr, value in combination:
            setattr(config, attr, value)

        if config.dataset == 'softmax':
            data = util.load_data('', time_step, config.time_batch_len, config.max_time_batches, nottingham=pickle)
            config.input_dim = data["input_dim"]
        else:
            raise Exception("Other datasets not yet implemented")

        # cut away unnecessary parts
        r = nottingham_util.NOTTINGHAM_MELODY_RANGE
        if args.choice == 'melody':
            print "Using only melody"
            for d in ['train', 'test', 'valid']:
                new_data = []
                for batch_data, batch_targets in data[d]["data"]:
                    new_data.append(([tb[:, :, :r] for tb in batch_data],
                                     [tb[:, :, 0] for tb in batch_targets]))
                data[d]["data"] = new_data
        else:
            print "Using only harmony"
            for d in ['train', 'test', 'valid']:
                new_data = []
                for batch_data, batch_targets in data[d]["data"]:
                    new_data.append(([tb[:, :, r:] for tb in batch_data],
                                     [tb[:, :, 1] for tb in batch_targets]))
                data[d]["data"] = new_data

        input_dim = data["input_dim"] = data["train"]["data"][0][0][0].shape[2]
        config.input_dim = input_dim
        print "New input dim: {}".format(input_dim)

        logger.info(config)
        config_file_path = os.path.join(run_folder, get_config_name(config) + '.config')
        with open(config_file_path, 'w') as f: 
            cPickle.dump(config, f)

        with tf.Graph().as_default(), tf.Session() as session:
            with tf.variable_scope("model", reuse=None):
                train_model = model_class(config, training=True)
            with tf.variable_scope("model", reuse=True):
                valid_model = model_class(config, training=False)

            saver = tf.train.Saver(tf.all_variables())
            tf.initialize_all_variables().run()

            # training
            early_stop_best_loss = None
            start_saving = False
            saved_flag = False
            train_losses, valid_losses = [], []
            start_time = time.time()
            for i in range(config.num_epochs):
                loss = util.run_epoch(session, train_model, data["train"]["data"], training=True, testing=False)
                train_losses.append((i, loss))
                if i == 0:
                    continue

                valid_loss = util.run_epoch(session, valid_model, data["valid"]["data"], training=False, testing=False)
                valid_losses.append((i, valid_loss))

                logger.info('Epoch: {}, Train Loss: {}, Valid Loss: {}, Time Per Epoch: {}'.format(\
                        i, loss, valid_loss, (time.time() - start_time)/i))

                # if it's best validation loss so far, save it
                if early_stop_best_loss == None:
                    early_stop_best_loss = valid_loss
                elif valid_loss < early_stop_best_loss:
                    early_stop_best_loss = valid_loss
                    if start_saving:
                        logger.info('Best loss so far encountered, saving model.')
                        saver.save(session, os.path.join(run_folder, config.model_name))
                        saved_flag = True
                elif not start_saving:
                    start_saving = True 
                    logger.info('Valid loss increased for the first time, will start saving models')
                    saver.save(session, os.path.join(run_folder, config.model_name))
                    saved_flag = True

            if not saved_flag:
                saver.save(session, os.path.join(run_folder, config.model_name))

            # cap the loss axis so charts stay readable across runs
            axes = plt.gca()
            if config.dataset == 'softmax':
                axes.set_ylim([0, 2])
            else:
                axes.set_ylim([0, 100])
            plt.plot([t[0] for t in train_losses], [t[1] for t in train_losses])
            plt.plot([t[0] for t in valid_losses], [t[1] for t in valid_losses])
            plt.legend(['Train Loss', 'Validation Loss'])
            chart_file_path = os.path.join(run_folder, get_config_name(config) + '.png')
            plt.savefig(chart_file_path)
            plt.clf()

            logger.info("Config {}, Loss: {}".format(config, early_stop_best_loss))
            if best_valid_loss == None or early_stop_best_loss < best_valid_loss:
                logger.info("Found best new model!")
                best_valid_loss = early_stop_best_loss
                best_config = config

    logger.info("Best Config: {}, Loss: {}".format(best_config, best_valid_loss))


================================================
FILE: rnn_test.py
================================================
import os, sys
import argparse
import cPickle

import numpy as np
import tensorflow as tf    

import util
import nottingham_util
from model import Model, NottinghamModel, NottinghamSeparate
from rnn import DefaultConfig

if __name__ == '__main__':
    np.random.seed()      

    parser = argparse.ArgumentParser(description="Script to test a model's performance against the test set")
    parser.add_argument('--config_file', type=str, required=True)
    parser.add_argument('--num_samples', type=int, default=1)
    parser.add_argument('--seperate', action='store_true', default=False)
    parser.add_argument('--choice', type=str, default='melody',
                        choices = ['melody', 'harmony'])
    args = parser.parse_args()

    with open(args.config_file, 'r') as f: 
        config = cPickle.load(f)

    if config.dataset == 'softmax':
        config.time_batch_len = 1
        config.max_time_batches = -1
        with open(nottingham_util.PICKLE_LOC, 'r') as f:
            pickle = cPickle.load(f)
        if args.seperate:
            model_class = NottinghamSeparate
            test_data = util.batch_data(pickle['test'], time_batch_len = 1, 
                max_time_batches = -1, softmax = True)
            r = nottingham_util.NOTTINGHAM_MELODY_RANGE
            if args.choice == 'melody':
                print "Using only melody"
                new_data = []
                for batch_data, batch_targets in test_data:
                    new_data.append(([tb[:, :, :r] for tb in batch_data],
                                     [tb[:, :, 0] for tb in batch_targets]))
                test_data = new_data
            else:
                print "Using only harmony"
                new_data = []
                for batch_data, batch_targets in test_data:
                    new_data.append(([tb[:, :, r:] for tb in batch_data],
                                     [tb[:, :, 1] for tb in batch_targets]))
                test_data = new_data
        else:
            model_class = NottinghamModel
            # use time batch len of 1 so that every target is covered
            test_data = util.batch_data(pickle['test'], time_batch_len = 1, 
                max_time_batches = -1, softmax = True)
    else:
        raise Exception("Other datasets not yet implemented")
        
    print config

    with tf.Graph().as_default(), tf.Session() as session:
        with tf.variable_scope("model", reuse=None):
            test_model = model_class(config, training=False)

        saver = tf.train.Saver(tf.all_variables())
        model_path = os.path.join(os.path.dirname(args.config_file), 
            config.model_name)
        saver.restore(session, model_path)
        
        test_loss, test_probs = util.run_epoch(session, test_model, test_data, 
            training=False, testing=True)
        print 'Testing Loss: {}'.format(test_loss)

        if config.dataset == 'softmax':
            if args.seperate:
                nottingham_util.seperate_accuracy(test_probs, test_data, num_samples=args.num_samples)
            else:
                nottingham_util.accuracy(test_probs, test_data, num_samples=args.num_samples)

        else:
            util.accuracy(test_probs, test_data, num_samples=50)

    sys.exit(1)
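
# Editorial usage sketch (paths are hypothetical):
#
#   python rnn_test.py --config_file models/my_run/<name>.config --num_samples 10
#   python rnn_test.py --config_file models/my_run/<name>.config --seperate --choice harmony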


================================================
FILE: sampling.py
================================================
import numpy as np
from pprint import pprint

import midi_util


class Sampler(object):

    def __init__(self, min_prob=0.5, num_notes = 4, method = 'sample', verbose=False):
        self.min_prob = min_prob
        self.num_notes = num_notes
        self.method = method
        self.verbose = verbose

    def visualize_probs(self, probs):
        if not self.verbose:
            return
        print 'Highest four probs: '
        pprint(sorted(list(enumerate(probs)), key=lambda x: x[1], 
               reverse=True)[:4])

    def sample_notes_prob(self, probs, max_notes=-1):
        """ Samples all notes that are over a certain probability"""
        self.visualize_probs(probs)
        top_idxs = list()
        for idx in probs.argsort()[::-1]:
            if max_notes > 0 and len(top_idxs) >= max_notes:
                break
            if probs[idx] < self.min_prob:
                break
            top_idxs.append(idx)
        chord = np.zeros([len(probs)], dtype=np.int32)
        chord[top_idxs] = 1.0
        return chord

    def sample_notes_static(self, probs):
        top_idxs = probs.argsort()[-self.num_notes:][::-1]
        chord = np.zeros([len(probs)], dtype=np.int32)
        chord[top_idxs] = 1.0
        return chord

    def sample_notes_bernoulli(self, probs):
        chord = np.zeros([len(probs)], dtype=np.int32)
        for note, prob in enumerate(probs):
            if np.random.binomial(1, prob) > 0:
                chord[note] = 1
        return chord

    def sample_notes(self, probs):
        """ Samples a static amount of notes from probabilities by highest prob """
        self.visualize_probs(probs)
        if self.method == 'sample':
            return self.sample_notes_bernoulli(probs)
        elif self.method == 'static':
            return self.sample_notes_static(probs)
        elif self.method == 'min_prob':
            return self.sample_notes_prob(probs)
        else:
            raise Exception("Unrecognized method: {}".format(self.method))


================================================
FILE: util.py
================================================
import os
import math
import cPickle
from collections import defaultdict
from random import shuffle

import numpy as np
import tensorflow as tf    

import midi_util
import nottingham_util

def parse_midi_directory(input_dir, time_step):
    """ 
    input_dir: data directory full of midi files
    time_step: the number of ticks to use as a time step for discretization

    Returns a list of [T x D] matrices, where T is the amount of time steps
    and D is the range of notes.
    """
    files = [ os.path.join(input_dir, f) for f in os.listdir(input_dir)
              if os.path.isfile(os.path.join(input_dir, f)) ] 
    sequences = [ \
        (f, midi_util.parse_midi_to_sequence(f, time_step=time_step)) \
        for f in files ]

    return sequences

def batch_data(sequences, time_batch_len=128, max_time_batches=10,
               softmax=False, verbose=False):
    """
    sequences: a list of [T x D] matrices, each matrix representing a sequence
    time_batch_len: the unrolling length that will be used by BPTT
    max_time_batches: the max number of time batches to consider. Any sequence
                      longer than max_time_batches * time_batch_len will be ignored.
                      Can be set to -1 to use as many time batches as needed.
    softmax: flag that should be set to True if using the dual-softmax formulation

    returns [
        [ [ data ], [ target ] ], # batch with one time step
        [ [ data1, data2 ], [ target1, target2 ] ], # batch with two time steps
        ...
    ]
    """

    assert time_batch_len > 0

    dims = sequences[0].shape[1]
    sequence_lens = [s.shape[0] for s in sequences]

    if verbose:
        avg_seq_len = sum(sequence_lens) / len(sequences)
        print "Average Sequence Length: {}".format(avg_seq_len)
        print "Max Sequence Length: {}".format(time_batch_len)
        print "Number of sequences: {}".format(len(sequences))

    batches = defaultdict(list)
    for sequence in sequences:
        # -1 because we can't predict the first step
        num_time_steps = ((sequence.shape[0]-1) // time_batch_len) 
        if num_time_steps < 1:
            continue
        if max_time_batches > 0 and num_time_steps > max_time_batches:
            continue
        batches[num_time_steps].append(sequence)

    if verbose:
        print "Batch distribution:"
        print [(k, len(v)) for (k, v) in batches.iteritems()]

    def arrange_batch(sequences, num_time_steps):
        sequences = [s[:(num_time_steps*time_batch_len)+1, :] for s in sequences]
        stacked = np.dstack(sequences)
        # swap axes so that shape is (SEQ_LENGTH X BATCH_SIZE X INPUT_DIM)
        data = np.swapaxes(stacked, 1, 2)
        targets = np.roll(data, -1, axis=0)
        # cutoff final time step
        data = data[:-1, :, :]
        targets = targets[:-1, :, :]
        assert data.shape == targets.shape

        if softmax:
            r = nottingham_util.NOTTINGHAM_MELODY_RANGE
            labels = np.ones((targets.shape[0], targets.shape[1], 2), dtype=np.int32)
            assert np.all(np.sum(targets[:, :, :r], axis=2) == 1)
            assert np.all(np.sum(targets[:, :, r:], axis=2) == 1)
            labels[:, :, 0] = np.argmax(targets[:, :, :r], axis=2)
            labels[:, :, 1] = np.argmax(targets[:, :, r:], axis=2)
            targets = labels
            assert targets.shape[:2] == data.shape[:2]

        assert data.shape[0] == num_time_steps * time_batch_len

        # split them up into time batches
        tb_data = np.split(data, num_time_steps, axis=0)
        tb_targets = np.split(targets, num_time_steps, axis=0)

        assert len(tb_data) == len(tb_targets) == num_time_steps
        for i in range(len(tb_data)):
            assert tb_data[i].shape[0] == time_batch_len
            assert tb_targets[i].shape[0] == time_batch_len
            if softmax:
                assert np.all(np.sum(tb_data[i], axis=2) == 2)

        return (tb_data, tb_targets)

    return [ arrange_batch(b, n) for n, b in batches.iteritems() ]
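
# Editorial sketch (shapes are illustrative): the structure returned by
# batch_data() for the softmax formulation.
#
#   batches = batch_data(sequences, time_batch_len=128, max_time_batches=10, softmax=True)
#   tb_data, tb_targets = batches[0]   # one (data, targets) pair per batch
#   tb_data[0].shape                   # (128, batch_size, input_dim)
#   tb_targets[0].shape                # (128, batch_size, 2): melody and harmony class indices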
        
def load_data(data_dir, time_step, time_batch_len, max_time_batches, nottingham=None):
    """
    nottingham: The sequences object as created in prepare_nottingham_pickle
                (see nottingham_util for more). If None, parse all the MIDI
                files from data_dir
    time_step: the time_step used to parse midi files (only used if data_dir
               is provided)
    time_batch_len and max_time_batches: see batch_data()

    returns { 
        "train": {
            "data": [ batch_data() ],
            "metadata: { ... }
        },
        "valid": { ... }
        "test": { ... }
    }
    """

    data = {}
    for dataset in ['train', 'test', 'valid']:

        # For testing, use ALL the sequences
        if dataset == 'test':
            max_time_batches = -1

        # Softmax formulation is pre-parsed into sequences
        if nottingham:
            sequences = nottingham[dataset]
            metadata = nottingham[dataset + '_metadata']
        # Cross-entropy formulation needs to be parsed
        else:
            sf = parse_midi_directory(os.path.join(data_dir, dataset), time_step)
            sequences = [s[1] for s in sf]
            files = [s[0] for s in sf]
            metadata = [{
                'path': f,
                'name': f.split("/")[-1].split(".")[0]
            } for f in files]

        dataset_data = batch_data(sequences, time_batch_len, max_time_batches, softmax = True if nottingham else False)

        data[dataset] = {
            "data": dataset_data,
            "metadata": metadata,
        }

        data["input_dim"] = dataset_data[0][0][0].shape[2]

    return data
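
# Editorial sketch (assumes data/nottingham.pickle was generated by
# nottingham_util.py): loading the softmax-formulation dataset.
#
#   with open(nottingham_util.PICKLE_LOC, 'r') as f:
#       pickle = cPickle.load(f)
#   data = load_data('', time_step=120, time_batch_len=128,
#                    max_time_batches=10, nottingham=pickle)
#   data["input_dim"]       # dimensionality of each time-step vector
#   data["train"]["data"]   # list of (time_batch_data, time_batch_targets) pairs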


def run_epoch(session, model, batches, training=False, testing=False):
    """
    session: Tensorflow session object
    model: model object (see model.py)
    batches: the "data" list of batches produced by load_data() / batch_data()

    training: A backpropagation iteration will be performed on the dataset
    if this flag is active

    returns average loss per time step over all batches.
    if the testing flag is active, returns [ loss, probs ], where probs holds
        the probability values for each note
    """

    # shuffle batches
    shuffle(batches)

    target_tensors = [model.loss, model.final_state]
    if testing:
        target_tensors.append(model.probs)
        batch_probs = defaultdict(list)
    if training:
        target_tensors.append(model.train_step)

    losses = []
    for data, targets in batches:
        # save state over unrolling time steps
        batch_size = data[0].shape[1]
        num_time_steps = len(data)
        state = model.get_cell_zero_state(session, batch_size) 
        probs = list()

        for tb_data, tb_targets in zip(data, targets):
            if testing:
                tbd = tb_data
                tbt = tb_targets
            else:
                # shuffle all the batches of input, state, and target
                batches = tb_data.shape[1]
                permutations = np.random.permutation(batches)
                tbd = np.zeros_like(tb_data)
                tbd[:, np.arange(batches), :] = tb_data[:, permutations, :]
                tbt = np.zeros_like(tb_targets)
                tbt[:, np.arange(batches), :] = tb_targets[:, permutations, :]
                state[np.arange(batches)] = state[permutations]

            feed_dict = {
                model.initial_state: state,
                model.seq_input: tbd,
                model.seq_targets: tbt,
            }
            results = session.run(target_tensors, feed_dict=feed_dict)

            losses.append(results[0])
            state = results[1]
            if testing:
                batch_probs[num_time_steps].append(results[2])

    loss = sum(losses) / len(losses)

    if testing:
        return [loss, batch_probs]
    else:
        return loss

def accuracy(batch_probs, data, num_samples=20):
    """
    batch_probs: probs object returned from run_epoch
    data: data object passed into run_epoch
    num_samples: the number of times to sample each note (an average over all
    these samples will be used)

    returns the accuracy metric according to
    http://ismir2009.ismir.net/proceedings/PS2-21.pdf
    """

    false_positives, false_negatives, true_positives = 0, 0, 0 
    for batch_data, batch_targets in data:
        num_time_steps = len(batch_data)
        for ts_targets, ts_probs in zip(batch_targets, batch_probs[num_time_steps]):

            assert ts_targets.shape == ts_probs.shape

            for seq_idx in range(ts_targets.shape[1]):
                for step_idx in range(ts_targets.shape[0]):
                    for note_idx, prob in enumerate(ts_probs[step_idx, seq_idx, :]):
                        num_occurrences = np.random.binomial(num_samples, prob)
                        if ts_targets[step_idx, seq_idx, note_idx] == 0.0:
                            false_positives += num_occurrences
                        else:
                            false_negatives += (num_samples - num_occurrences)
                            true_positives += num_occurrences
                
    accuracy = (float(true_positives) / float(true_positives + false_positives + false_negatives)) 

    print "Precision: {}".format(float(true_positives) / (float(true_positives + false_positives)))
    print "Recall: {}".format(float(true_positives) / (float(true_positives + false_negatives)))
    print "Accuracy: {}".format(accuracy)

================================================
SYMBOL INDEX (54 symbols across 6 files)
================================================

FILE: midi_util.py
  function round_tick (line 8) | def round_tick(tick, time_step):
  function ingest_notes (line 11) | def ingest_notes(track, verbose=False):
  function round_notes (line 45) | def round_notes(notes, track_ticks, time_step, R=None, O=None):
  function parse_midi_to_sequence (line 87) | def parse_midi_to_sequence(input_filename, time_step, verbose=False):
  class MidiWriter (line 141) | class MidiWriter(object):
    method __init__ (line 143) | def __init__(self, verbose=False):
    method note_off (line 147) | def note_off(self, val, tick):
    method note_on (line 151) | def note_on(self, val, tick):
    method dump_sequence_to_midi (line 155) | def dump_sequence_to_midi(self, sequence, output_filename, time_step,

FILE: model.py
  class Model (line 10) | class Model(object):
    method __init__ (line 20) | def __init__(self, config, training=False):
    method init_loss (line 89) | def init_loss(self, outputs, _):
    method calculate_probs (line 97) | def calculate_probs(self, logits):
    method get_cell_zero_state (line 100) | def get_cell_zero_state(self, session, batch_size):
  class NottinghamModel (line 103) | class NottinghamModel(Model):
    method init_loss (line 112) | def init_loss(self, outputs, outputs_concat):
    method calculate_probs (line 132) | def calculate_probs(self, logits):
    method assign_melody_coeff (line 140) | def assign_melody_coeff(self, session, melody_coeff):
  class NottinghamSeparate (line 146) | class NottinghamSeparate(Model):
    method init_loss (line 154) | def init_loss(self, outputs, outputs_concat):
    method calculate_probs (line 168) | def calculate_probs(self, logits):

FILE: nottingham_util.py
  function resolve_chord (line 30) | def resolve_chord(chord):
  function prepare_nottingham_pickle (line 50) | def prepare_nottingham_pickle(time_step, chord_cutoff=64, filename=PICKL...
  function parse_nottingham_directory (line 133) | def parse_nottingham_directory(input_dir, time_step, verbose=False):
  function parse_nottingham_to_sequence (line 158) | def parse_nottingham_to_sequence(input_filename, time_step, verbose=False):
  class NottinghamMidiWriter (line 243) | class NottinghamMidiWriter(midi_util.MidiWriter):
    method __init__ (line 245) | def __init__(self, chord_to_idx, verbose=False):
    method dereference_chord (line 250) | def dereference_chord(self, idx):
    method note_on (line 259) | def note_on(self, val, tick):
    method note_off (line 276) | def note_off(self, val, tick):
  class NottinghamSampler (line 289) | class NottinghamSampler(object):
    method __init__ (line 291) | def __init__(self, chord_to_idx, method = 'sample', harmony_repeat_max...
    method visualize_probs (line 304) | def visualize_probs(self, probs):
    method sample_notes_static (line 318) | def sample_notes_static(self, probs):
    method sample_notes_dist (line 346) | def sample_notes_dist(self, probs):
    method sample_notes (line 369) | def sample_notes(self, probs):
  function accuracy (line 376) | def accuracy(batch_probs, data, num_samples=1):
  function seperate_accuracy (line 440) | def seperate_accuracy(batch_probs, data, num_samples=1):
  function i_vi_iv_v (line 480) | def i_vi_iv_v(chord_to_idx, repeats, input_dim):

FILE: rnn.py
  function get_config_name (line 18) | def get_config_name(config):
  class DefaultConfig (line 26) | class DefaultConfig(object):
    method __repr__ (line 46) | def __repr__(self):

FILE: sampling.py
  class Sampler (line 7) | class Sampler(object):
    method __init__ (line 9) | def __init__(self, min_prob=0.5, num_notes = 4, method = 'sample', ver...
    method visualize_probs (line 15) | def visualize_probs(self, probs):
    method sample_notes_prob (line 22) | def sample_notes_prob(self, probs, max_notes=-1):
    method sample_notes_static (line 36) | def sample_notes_static(self, probs):
    method sample_notes_bernoulli (line 42) | def sample_notes_bernoulli(self, probs):
    method sample_notes (line 49) | def sample_notes(self, probs):

FILE: util.py
  function parse_midi_directory (line 13) | def parse_midi_directory(input_dir, time_step):
  function batch_data (line 29) | def batch_data(sequences, time_batch_len=128, max_time_batches=10,
  function load_data (line 109) | def load_data(data_dir, time_step, time_batch_len, max_time_batches, not...
  function run_epoch (line 161) | def run_epoch(session, model, batches, training=False, testing=False):
  function accuracy (line 226) | def accuracy(batch_probs, data, num_samples=20):