Full Code of MarcusOlivecrona/REINVENT for AI

Repository: MarcusOlivecrona/REINVENT
Branch: master
Commit: 752935c29d46
Files: 20
Total size: 103.0 MB

Directory structure:
gitextract_jbrxke6h/

├── LICENSE
├── README.md
├── Vizard/
│   ├── main.py
│   ├── run.sh
│   ├── templates/
│   │   ├── index.html
│   │   └── styles.css
│   └── theme.yaml
├── data/
│   ├── ChEMBL_filtered
│   ├── Prior.ckpt
│   ├── Voc
│   └── clf.pkl
├── data_structs.py
├── main.py
├── model.py
├── multiprocess.py
├── scoring_functions.py
├── train_agent.py
├── train_prior.py
├── utils.py
└── vizard_logger.py

================================================
FILE CONTENTS
================================================

================================================
FILE: LICENSE
================================================
Copyright <2017> <Marcus Olivecrona>

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.


================================================
FILE: README.md
================================================

# REINVENT
## Molecular De Novo design using Recurrent Neural Networks and Reinforcement Learning

Searching chemical space as described in:

[Molecular De Novo Design through Deep Reinforcement Learning](https://arxiv.org/abs/1704.07555)

![Video demonstrating an Agent trained to generate analogues to Celecoxib](https://github.com/MarcusOlivecrona/REINVENT/blob/master/images/celecoxib_analogues.gif "Training an Agent to generate analogues of Celecoxib")


## Notes
The current version is a PyTorch implementation that differs in several ways from the original implementation described in the paper. This version works better in most situations and is better documented, but to reproduce the results from the paper, refer to [Release v1.0.1](https://github.com/MarcusOlivecrona/REINVENT/releases/tag/v1.0.1).

Differences from the implementation in the paper:
* Written in PyTorch/Python 3.6 rather than TensorFlow/Python 2.7.
* SMILES are encoded as token indices rather than as one-hot vectors of the indices. An embedding matrix is then used to transform each token index into a feature vector.
* Scores are in the range (0,1).
* A regularizer that penalizes high values of the total episodic likelihood is included.
* Sequences are only considered once, i.e., if the same sequence is generated twice in a batch, only the first instance contributes to the loss.
* These changes make the algorithm more robust against local minima and mean that much higher values of sigma can be used if needed.
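The switch from one-hot inputs to token indices can be sketched as follows. This is an illustrative example only, not repo code: the toy vocabulary and the 8-dimensional embedding are made up (the model itself uses a 128-dimensional embedding), but the two routes are mathematically equivalent.

```python
import numpy as np

vocab = {"C": 0, "O": 1, "N": 2, "EOS": 3}      # toy vocabulary (illustrative)
rng = np.random.default_rng(0)
embedding = rng.random((len(vocab), 8))          # 4 tokens -> 8-dim feature vectors

indices = np.array([vocab[t] for t in ["C", "O", "C", "EOS"]])

# One-hot route (v1.0.1): build a (seq_len, vocab_size) matrix, then project.
one_hot = np.eye(len(vocab))[indices]
features_onehot = one_hot @ embedding

# Index route (current version): look the embedding rows up directly.
features_index = embedding[indices]

assert np.allclose(features_onehot, features_index)
```

Since a one-hot vector times a matrix just selects one row, the index lookup produces the same features while avoiding the (seq_len, vocab_size) intermediate.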

## Requirements

This package requires:
* Python 3.6
* PyTorch 0.1.12 
* [RDKit](http://www.rdkit.org/docs/Install.html)
* Scikit-Learn (for QSAR scoring function)
* tqdm (for training Prior)
* pexpect

## Usage

To train a Prior starting with a SMILES file called mols.smi:

* First filter the SMILES and construct a vocabulary from the remaining sequences: `./data_structs.py mols.smi` will generate data/mols_filtered.smi and data/Voc. A filtered file containing around 1.1 million SMILES and the corresponding Voc is included in "data".

* Then use `./train_prior.py` to train the Prior. A pretrained Prior is included.

To train an Agent using our Prior, use the main.py script. For example:

* `./main.py --scoring-function activity_model --num-steps 1000`
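For scoring functions that take keyword arguments, pairs are passed via `--scoring-function-kwargs`. This usage sketch is adapted from the argparse help text in main.py; the query SMILES is the example given there:

```shell
# Train an Agent with the tanimoto scoring function against a query structure
./main.py --scoring-function tanimoto \
          --scoring-function-kwargs query_structure COc1ccccc1 \
          --num-steps 1000
```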

Training can be visualized using the Vizard Bokeh app. vizard_logger.py logs information (by default to data/logs) such as the structures generated, average scores, and network weights.

* `cd Vizard`
* `./run.sh ../data/logs`
* Open the browser at http://localhost:5006/Vizard
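The app expects a particular layout in the log directory. Inferring from the reader code in Vizard/main.py, a minimal compatible directory can be sketched like this (a temporary directory stands in for data/logs; all values are dummies):

```python
import os
import tempfile
import numpy as np

log_dir = tempfile.mkdtemp()  # vizard_logger.py writes to data/logs by default

# Scores.npy: row 0 holds the step numbers, row 1 the average scores
# (the app plots score[0] against score[1]).
steps = np.arange(10)
avg_scores = np.linspace(0.1, 0.9, 10)
np.save(os.path.join(log_dir, "Scores.npy"), np.stack([steps, avg_scores]))

# SMILES: whitespace-separated "smiles score" lines; the app draws the
# first six molecules that RDKit can parse.
with open(os.path.join(log_dir, "SMILES"), "w") as f:
    f.write("c1ccccc1 0.42\n")

# Each weight array needs an "init_"-prefixed snapshot alongside it,
# since the app plots current weights against their initial values.
w = np.random.rand(16, 16)
np.save(os.path.join(log_dir, "weight_gru_1.npy"), w)
np.save(os.path.join(log_dir, "init_weight_gru_1.npy"), w)
```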




================================================
FILE: Vizard/main.py
================================================
from bokeh.plotting import figure, ColumnDataSource, curdoc
from bokeh.models import CustomJS, Range1d
from bokeh.models.glyphs import Text
from bokeh.layouts import row, column, widgetbox, layout
from bokeh.models.widgets import Div
import bokeh.palettes
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit import rdBase

import sys
import os.path
import numpy as np
import math

"""Bokeh app that visualizes training progress for the De Novo design reinforcement learning.
   The app is updated dynamically using information that the train_agent.py script writes to a
   logging directory."""

rdBase.DisableLog('rdApp.error')

error_msg = """Need to provide valid log directory as first argument.
                     'bokeh serve . --args [log_dir]'"""
try:
    path = sys.argv[1]
except IndexError:
    raise IndexError(error_msg)
if not os.path.isdir(path):
    raise ValueError(error_msg)

score_source = ColumnDataSource(data=dict(x=[], y=[], y_mean=[]))
score_fig = figure(title="Scores", plot_width=600, plot_height=600)
score_fig.line('x', 'y', legend='Average score', source=score_source)
score_fig.line('x', 'y_mean', legend='Running average of average score', line_width=2, 
               color="firebrick", source=score_source)

score_fig.xaxis.axis_label = "Step"
score_fig.yaxis.axis_label = "Average Score"
score_fig.title.text_font_size = "20pt"
score_fig.legend.location = "bottom_right"
score_fig.css_classes = ["score_fig"]

img_fig = Div(text="", width=850, height=590)
img_fig.css_classes = ["img_outside"]

def downsample(data, max_len):
    np.random.seed(0)
    if len(data)>max_len:
        data = np.random.choice(data, size=max_len, replace=False)
    return data

def running_average(data, length):
    early_cumsum = np.cumsum(data[:length]) / np.arange(1, min(len(data), length) + 1)
    if len(data)>length:
        cumsum = np.cumsum(data) 
        cumsum =  (cumsum[length:] - cumsum[:-length]) / length
        cumsum = np.concatenate((early_cumsum, cumsum))
        return cumsum
    return early_cumsum

def create_bar_plot(init_data, title):
    init_data = downsample(init_data, 50)
    x = range(len(init_data))
    source = ColumnDataSource(data=dict(x= [], y=[]))
    fig = figure(title=title, plot_width=300, plot_height=300)
    fig.vbar(x=x, width=1, top=init_data, fill_alpha=0.05)
    fig.vbar('x', width=1, top='y', fill_alpha=0.3, source=source)
    fig.y_range = Range1d(min(0, 1.2 * min(init_data)), 1.2 * max(init_data))
    return fig, source

def create_hist_plot(init_data, title):
    source = ColumnDataSource(data=dict(hist=[], left_edge=[], right_edge=[]))
    init_hist, init_edge = np.histogram(init_data, density=True, bins=50)
    fig = figure(title=title, plot_width=300, plot_height=300)
    fig.quad(top=init_hist, bottom=0, left=init_edge[:-1], right=init_edge[1:],
            fill_alpha=0.05)
    fig.quad(top='hist', bottom=0, left='left_edge', right='right_edge',
            fill_alpha=0.3, source=source)
    return fig, source


weights = [f for f in os.listdir(path) if f.startswith("weight")]
weights = {w:{'init_weight': np.load(os.path.join(path, "init_" + w)).reshape(-1)} for w in weights}

for name, w in weights.items():
    w['bar_fig'], w['bar_source'] = create_bar_plot(w['init_weight'], name)
    w['hist_fig'], w['hist_source'] = create_hist_plot(w['init_weight'], name + "_histogram")

bar_plots = [w['bar_fig'] for name, w in weights.items()]
hist_plots = [w['hist_fig'] for name, w in weights.items()]

layout = layout([[img_fig, score_fig], bar_plots, hist_plots], sizing_mode="fixed")
curdoc().add_root(layout)

def update():
    score = np.load(os.path.join(path, "Scores.npy"))
    with open(os.path.join(path, "SMILES"), "r") as f:
        mols = []
        scores = []
        for line in f:
                line = line.split()
                mol = Chem.MolFromSmiles(line[0])
                if mol and len(mols)<6:
                    mols.append(mol)
                    scores.append(line[1])
    img = Draw.MolsToGridImage(mols, molsPerRow=3, legends=scores, subImgSize=(250,250), useSVG=True)
    img = img.replace("FFFFFF", "EDEDED")
    img_fig.text = '<h2>Generated Molecules</h2>' + '<div class="img_inside">' + img + '</div>'
    score_source.data = dict(x=score[0], y=score[1], y_mean=running_average(score[1], 50))

    for name, w in weights.items():
        current_weights = np.load(os.path.join(path, name)).reshape(-1)
        hist, edge = np.histogram(current_weights, density=True, bins=50)
        w['hist_source'].data = dict(hist=hist, left_edge=edge[:-1], right_edge=edge[1:])
        current_weights = downsample(current_weights, 50)
        w['bar_source'].data = dict(x=range(len(current_weights)), y=current_weights)

update()
curdoc().add_periodic_callback(update, 1000)



================================================
FILE: Vizard/run.sh
================================================
#!/bin/bash
if [ -z "$1" ]; then
    echo "Must supply path to a directory where vizard_logger is saving its information"
    exit 1
fi
bokeh serve . --args "$1"


================================================
FILE: Vizard/templates/index.html
================================================
<!DOCTYPE html>
<html lang="en">
    <head>
        {{ bokeh_css }}
        {{ bokeh_js }}
        <style>
             {% include 'styles.css' %}
        </style>
        <meta charset="utf-8">
        <title>MolExplorer</title>
    </head>
    <body>
    <div>
        <h1>Vizard</h1>
        {{ plot_div|indent(8) }}
    </div>
        {{ plot_script|indent(8) }}
    </body>
</html>


================================================
FILE: Vizard/templates/styles.css
================================================
html {
    background-color: #2F2F2F;
    display: table;
    margin: auto;
}

body {
    display: table-cell;
    vertical-align: middle;
    color: #fff;
}

.img_outside {
    position: relative;
}

.img_inside {
    background-color: #EDEDED;
    border: 7px solid #656565;
    position:absolute;
    left:50% ;
    margin-left: -375px;
    top:50% ;
    margin-top: -250px;
}

.score_fig{
    position: absolute;
    top: 10px;
}

h1 {
    margin: 0.5em 0 0.5em 0;
    color: #fff;
    font-family: 'Julius Sans One', sans-serif;
    font-size: 3em;
    text-transform: uppercase;
    text-align: center;
}

h2 {
    margin: 0 0 0 0;
    color: #fff;
    font-size: 20pt;
    text-align: center;
}

a:link {
    font-weight: bold;
    text-decoration: none;
    color: #0d8ba1;
}
a:visited {
    font-weight: bold;
    text-decoration: none;
    color: #1a5952;
}
a:hover, a:focus, a:active {
    text-decoration: underline;
    color: #9685BA;
}


================================================
FILE: Vizard/theme.yaml
================================================
attrs:
    Figure:
        background_fill_color: '#2F2F2F'
        border_fill_color: '#2F2F2F'
        outline_line_color: '#444444'
        min_border_top: 0
    Axis:
        axis_line_color: "#FFFFFF"
        axis_label_text_color: "#FFFFFF"
        axis_label_text_font_size: "10pt"
        axis_label_text_font_style: "normal"
        axis_label_standoff: 10
        major_label_text_color: "#FFFFFF"
        major_tick_line_color: "#FFFFFF"
        minor_tick_line_color: "#FFFFFF"
    Grid:
        grid_line_dash: [6, 4]
        grid_line_alpha: .3
    Title:
        text_color: "#FFFFFF"
        align: "center"


================================================
FILE: data/ChEMBL_filtered
================================================
[File too large to display: 53.3 MB]

================================================
FILE: data/Prior.ckpt
================================================
[File too large to display: 15.9 MB]

================================================
FILE: data/Voc
================================================
[S-]
9
(
S
c
[NH+]
3
[CH]
o
[NH3+]
[nH]
7
6
[N]
1
O
%
[N-]
5
-
[O+]
[n+]
[o+]
[nH+]
[NH2+]
[N+]
[O-]
[S+]
R
F
[n-]
[s+]
L
s
8
4
[SH]
2
=
n
)
[O]
N
#
[NH-]
C
[SH+]
0


================================================
FILE: data/clf.pkl
================================================
[File too large to display: 33.8 MB]

================================================
FILE: data_structs.py
================================================
import numpy as np
import random
import re
import pickle
from rdkit import Chem
import sys
import time
import torch
from torch.utils.data import Dataset

from utils import Variable

class Vocabulary(object):
    """A class for handling encoding/decoding from SMILES to an array of indices"""
    def __init__(self, init_from_file=None, max_length=140):
        self.special_tokens = ['EOS', 'GO']
        self.additional_chars = set()
        self.chars = self.special_tokens
        self.vocab_size = len(self.chars)
        self.vocab = dict(zip(self.chars, range(len(self.chars))))
        self.reversed_vocab = {v: k for k, v in self.vocab.items()}
        self.max_length = max_length
        if init_from_file: self.init_from_file(init_from_file)

    def encode(self, char_list):
        """Takes a list of characters (eg '[NH]') and encodes to array of indices"""
        smiles_matrix = np.zeros(len(char_list), dtype=np.float32)
        for i, char in enumerate(char_list):
            smiles_matrix[i] = self.vocab[char]
        return smiles_matrix

    def decode(self, matrix):
        """Takes an array of indices and returns the corresponding SMILES"""
        chars = []
        for i in matrix:
            if i == self.vocab['EOS']: break
            chars.append(self.reversed_vocab[i])
        smiles = "".join(chars)
        smiles = smiles.replace("L", "Cl").replace("R", "Br")
        return smiles

    def tokenize(self, smiles):
        """Takes a SMILES and return a list of characters/tokens"""
        regex = '(\[[^\[\]]{1,6}\])'
        smiles = replace_halogen(smiles)
        char_list = re.split(regex, smiles)
        tokenized = []
        for char in char_list:
            if char.startswith('['):
                tokenized.append(char)
            else:
                tokenized.extend(char)
        tokenized.append('EOS')
        return tokenized

    def add_characters(self, chars):
        """Adds characters to the vocabulary"""
        for char in chars:
            self.additional_chars.add(char)
        char_list = list(self.additional_chars)
        char_list.sort()
        self.chars = char_list + self.special_tokens
        self.vocab_size = len(self.chars)
        self.vocab = dict(zip(self.chars, range(len(self.chars))))
        self.reversed_vocab = {v: k for k, v in self.vocab.items()}

    def init_from_file(self, file):
        """Takes a file containing \n separated characters to initialize the vocabulary"""
        with open(file, 'r') as f:
            chars = f.read().split()
        self.add_characters(chars)

    def __len__(self):
        return len(self.chars)

    def __str__(self):
        return "Vocabulary containing {} tokens: {}".format(len(self), self.chars)

class MolData(Dataset):
    """Custom PyTorch Dataset that takes a file containing SMILES.

        Args:
                fname : path to a file containing newline-separated SMILES.
                voc   : a Vocabulary instance

        Returns:
                A custom PyTorch dataset for training the Prior.
    """
    def __init__(self, fname, voc):
        self.voc = voc
        self.smiles = []
        with open(fname, 'r') as f:
            for line in f:
                self.smiles.append(line.split()[0])

    def __getitem__(self, i):
        mol = self.smiles[i]
        tokenized = self.voc.tokenize(mol)
        encoded = self.voc.encode(tokenized)
        return Variable(encoded)

    def __len__(self):
        return len(self.smiles)

    def __str__(self):
        return "Dataset containing {} structures.".format(len(self))

    @classmethod
    def collate_fn(cls, arr):
        """Function to take a list of encoded sequences and turn them into a batch"""
        max_length = max([seq.size(0) for seq in arr])
        collated_arr = Variable(torch.zeros(len(arr), max_length))
        for i, seq in enumerate(arr):
            collated_arr[i, :seq.size(0)] = seq
        return collated_arr

class Experience(object):
    """Class for prioritized experience replay that remembers the highest scored sequences
       seen and samples from them with probabilities relative to their scores."""
    def __init__(self, voc, max_size=100):
        self.memory = []
        self.max_size = max_size
        self.voc = voc

    def add_experience(self, experience):
        """Experience should be a list of (smiles, score, prior likelihood) tuples"""
        self.memory.extend(experience)
        if len(self.memory)>self.max_size:
            # Remove duplicates
            idxs, smiles = [], []
            for i, exp in enumerate(self.memory):
                if exp[0] not in smiles:
                    idxs.append(i)
                    smiles.append(exp[0])
            self.memory = [self.memory[idx] for idx in idxs]
            # Retain highest scores
            self.memory.sort(key = lambda x: x[1], reverse=True)
            self.memory = self.memory[:self.max_size]
            print("\nBest score in memory: {:.2f}".format(self.memory[0][1]))

    def sample(self, n):
        """Sample a batch size n of experience"""
        if len(self.memory)<n:
            raise IndexError('Size of memory ({}) is less than requested sample ({})'.format(len(self), n))
        else:
            scores = [x[1] for x in self.memory]
            sample = np.random.choice(len(self), size=n, replace=False, p=scores/np.sum(scores))
            sample = [self.memory[i] for i in sample]
            smiles = [x[0] for x in sample]
            scores = [x[1] for x in sample]
            prior_likelihood = [x[2] for x in sample]
        tokenized = [self.voc.tokenize(smile) for smile in smiles]
        encoded = [Variable(self.voc.encode(tokenized_i)) for tokenized_i in tokenized]
        encoded = MolData.collate_fn(encoded)
        return encoded, np.array(scores), np.array(prior_likelihood)

    def initiate_from_file(self, fname, scoring_function, Prior):
        """Adds experience from a file with SMILES
           Needs a scoring function and an RNN to score the sequences.
           Using this feature means that the learning can be very biased
           and is typically advised against."""
        with open(fname, 'r') as f:
            smiles = []
            for line in f:
                smile = line.split()[0]
                if Chem.MolFromSmiles(smile):
                    smiles.append(smile)
        scores = scoring_function(smiles)
        tokenized = [self.voc.tokenize(smile) for smile in smiles]
        encoded = [Variable(self.voc.encode(tokenized_i)) for tokenized_i in tokenized]
        encoded = MolData.collate_fn(encoded)
        prior_likelihood, _ = Prior.likelihood(encoded.long())
        prior_likelihood = prior_likelihood.data.cpu().numpy()
        new_experience = zip(smiles, scores, prior_likelihood)
        self.add_experience(new_experience)

    def print_memory(self, path):
        """Prints the memory."""
        print("\n" + "*" * 80 + "\n")
        print("         Best recorded SMILES: \n")
        print("Score     Prior log P     SMILES\n")
        with open(path, 'w') as f:
            f.write("SMILES Score PriorLogP\n")
            for i, exp in enumerate(self.memory[:100]):
                if i < 50:
                    print("{:4.2f}   {:6.2f}        {}".format(exp[1], exp[2], exp[0]))
                    f.write("{} {:4.2f} {:6.2f}\n".format(*exp))
        print("\n" + "*" * 80 + "\n")

    def __len__(self):
        return len(self.memory)

def replace_halogen(string):
    """Regex to replace Br and Cl with single letters"""
    br = re.compile('Br')
    cl = re.compile('Cl')
    string = br.sub('R', string)
    string = cl.sub('L', string)

    return string

def tokenize(smiles):
    """Takes a SMILES string and returns a list of tokens.
    This will swap 'Cl' and 'Br' to 'L' and 'R' and treat
    '[xx]' as one token."""
    regex = r'(\[[^\[\]]{1,6}\])'
    smiles = replace_halogen(smiles)
    char_list = re.split(regex, smiles)
    tokenized = []
    for char in char_list:
        if char.startswith('['):
            tokenized.append(char)
        else:
            tokenized.extend(char)
    tokenized.append('EOS')
    return tokenized

def canonicalize_smiles_from_file(fname):
    """Reads a SMILES file and returns a list of RDKIT SMILES"""
    with open(fname, 'r') as f:
        smiles_list = []
        for i, line in enumerate(f):
            if i % 100000 == 0:
                print("{} lines processed.".format(i))
            smiles = line.split(" ")[0]
            mol = Chem.MolFromSmiles(smiles)
            if filter_mol(mol):
                smiles_list.append(Chem.MolToSmiles(mol))
        print("{} SMILES retrieved".format(len(smiles_list)))
        return smiles_list

def filter_mol(mol, max_heavy_atoms=50, min_heavy_atoms=10, element_list=[6,7,8,9,16,17,35]):
    """Filters molecules on number of heavy atoms and atom types"""
    if mol is not None:
        num_heavy = min_heavy_atoms<mol.GetNumHeavyAtoms()<max_heavy_atoms
        elements = all([atom.GetAtomicNum() in element_list for atom in mol.GetAtoms()])
        if num_heavy and elements:
            return True
        else:
            return False

def write_smiles_to_file(smiles_list, fname):
    """Write a list of SMILES to a file."""
    with open(fname, 'w') as f:
        for smiles in smiles_list:
            f.write(smiles + "\n")

def filter_on_chars(smiles_list, chars):
    """Filters SMILES on the characters they contain.
       Used to remove SMILES containing very rare/undesirable
       characters."""
    smiles_list_valid = []
    for smiles in smiles_list:
        tokenized = tokenize(smiles)
        if all([char in chars for char in tokenized][:-1]):
            smiles_list_valid.append(smiles)
    return smiles_list_valid

def filter_file_on_chars(smiles_fname, voc_fname):
    """Filters a SMILES file using a vocabulary file.
       Only SMILES containing nothing but the characters
       in the vocabulary will be retained."""
    smiles = []
    with open(smiles_fname, 'r') as f:
        for line in f:
            smiles.append(line.split()[0])
    print(smiles[:10])
    chars = []
    with open(voc_fname, 'r') as f:
        for line in f:
            chars.append(line.split()[0])
    print(chars)
    valid_smiles = filter_on_chars(smiles, chars)
    with open(smiles_fname + "_filtered", 'w') as f:
        for smiles in valid_smiles:
            f.write(smiles + "\n")

def combine_voc_from_files(fnames):
    """Combine two vocabularies"""
    chars = set()
    for fname in fnames:
        with open(fname, 'r') as f:
            for line in f:
                chars.add(line.split()[0])
    with open("_".join(fnames) + '_combined', 'w') as f:
        for char in chars:
            f.write(char + "\n")

def construct_vocabulary(smiles_list):
    """Returns all the characters present in a SMILES file.
       Uses regex to find characters/tokens of the format '[x]'."""
    add_chars = set()
    for i, smiles in enumerate(smiles_list):
        regex = r'(\[[^\[\]]{1,6}\])'
        smiles = replace_halogen(smiles)
        char_list = re.split(regex, smiles)
        for char in char_list:
            if char.startswith('['):
                add_chars.add(char)
            else:
                add_chars.update(char)

    print("Number of characters: {}".format(len(add_chars)))
    with open('data/Voc', 'w') as f:
        for char in add_chars:
            f.write(char + "\n")
    return add_chars

if __name__ == "__main__":
    smiles_file = sys.argv[1]
    print("Reading smiles...")
    smiles_list = canonicalize_smiles_from_file(smiles_file)
    print("Constructing vocabulary...")
    voc_chars = construct_vocabulary(smiles_list)
    write_smiles_to_file(smiles_list, "data/mols_filtered.smi")


================================================
FILE: main.py
================================================
#!/usr/bin/env python
import argparse
import time
import os
from train_agent import train_agent


parser = argparse.ArgumentParser(description="Main script for running the model")
parser.add_argument('--scoring-function', action='store', dest='scoring_function',
                    choices=['activity_model', 'tanimoto', 'no_sulphur'],
                    default='tanimoto',
                    help='What type of scoring function to use.')
parser.add_argument('--scoring-function-kwargs', action='store', dest='scoring_function_kwargs',
                    nargs="*",
                    help='Additional arguments for the scoring function. Should be supplied with a '\
                    'list of "keyword_name argument". For pharmacophoric and tanimoto '\
                    'the keyword is "query_structure" and requires a SMILES. ' \
                    'For activity_model it is "clf_path" '\
                    'pointing to a sklearn classifier. '\
                    'For example: "--scoring-function-kwargs query_structure COc1ccccc1".')
parser.add_argument('--learning-rate', action='store', dest='learning_rate',
                    type=float, default=0.0005)
parser.add_argument('--num-steps', action='store', dest='n_steps', type=int,
                    default=3000)
parser.add_argument('--batch-size', action='store', dest='batch_size', type=int,
                    default=64)
parser.add_argument('--sigma', action='store', dest='sigma', type=int,
                    default=20)
parser.add_argument('--experience', action='store', dest='experience_replay', type=int,
                    default=0, help='Number of experience sequences to sample each step. '\
                    '0 means no experience replay.')
parser.add_argument('--num-processes', action='store', dest='num_processes',
                    type=int, default=0,
                    help='Number of processes used to run the scoring function. "0" means ' \
                    'that the scoring function will be run in the main process.')
parser.add_argument('--prior', action='store', dest='restore_prior_from',
                    default='data/Prior.ckpt',
                    help='Path to an RNN checkpoint file to use as a Prior')
parser.add_argument('--agent', action='store', dest='restore_agent_from',
                    default='data/Prior.ckpt',
                    help='Path to an RNN checkpoint file to use as an Agent.')
parser.add_argument('--save-dir', action='store', dest='save_dir',
                    help='Path where results and model are saved. Default is data/results/run_<datetime>.')

if __name__ == "__main__":

    arg_dict = vars(parser.parse_args())

    if arg_dict['scoring_function_kwargs']:
        kwarg_list = arg_dict.pop('scoring_function_kwargs')
        if not len(kwarg_list) % 2 == 0:
            raise ValueError("Scoring function kwargs must be given as pairs, "\
                             "but got a list with odd length.")
        kwarg_dict = {i:j for i, j in zip(kwarg_list[::2], kwarg_list[1::2])}
        arg_dict['scoring_function_kwargs'] = kwarg_dict
    else:
        arg_dict['scoring_function_kwargs'] = dict()

    train_agent(**arg_dict)


================================================
FILE: model.py
================================================
#!/usr/bin/env python

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import Variable

class MultiGRU(nn.Module):
    """ Implements a three layer GRU cell including an embedding layer
       and an output linear layer back to the size of the vocabulary"""
    def __init__(self, voc_size):
        super(MultiGRU, self).__init__()
        self.embedding = nn.Embedding(voc_size, 128)
        self.gru_1 = nn.GRUCell(128, 512)
        self.gru_2 = nn.GRUCell(512, 512)
        self.gru_3 = nn.GRUCell(512, 512)
        self.linear = nn.Linear(512, voc_size)

    def forward(self, x, h):
        x = self.embedding(x)
        h_out = Variable(torch.zeros(h.size()))
        x = h_out[0] = self.gru_1(x, h[0])
        x = h_out[1] = self.gru_2(x, h[1])
        x = h_out[2] = self.gru_3(x, h[2])
        x = self.linear(x)
        return x, h_out

    def init_h(self, batch_size):
        # Initial cell state is zero
        return Variable(torch.zeros(3, batch_size, 512))

class RNN():
    """Implements the Prior and Agent RNN. Needs a Vocabulary instance in
    order to determine size of the vocabulary and index of the END token"""
    def __init__(self, voc):
        self.rnn = MultiGRU(voc.vocab_size)
        if torch.cuda.is_available():
            self.rnn.cuda()
        self.voc = voc

    def likelihood(self, target):
        """
            Retrieves the likelihood of a given sequence

            Args:
                target: (batch_size, sequence_length) A batch of sequences

            Outputs:
                log_probs : (batch_size) Log likelihood for each example
                entropy: (batch_size) The entropies for the sequences. Not
                                      currently used.
        """
        batch_size, seq_length = target.size()
        start_token = Variable(torch.zeros(batch_size, 1).long())
        start_token[:] = self.voc.vocab['GO']
        x = torch.cat((start_token, target[:, :-1]), 1)
        h = self.rnn.init_h(batch_size)

        log_probs = Variable(torch.zeros(batch_size))
        entropy = Variable(torch.zeros(batch_size))
        for step in range(seq_length):
            logits, h = self.rnn(x[:, step], h)
            log_prob = F.log_softmax(logits)
            prob = F.softmax(logits)
            log_probs += NLLLoss(log_prob, target[:, step])
            entropy += -torch.sum((log_prob * prob), 1)
        return log_probs, entropy

    def sample(self, batch_size, max_length=140):
        """
            Sample a batch of sequences

            Args:
                batch_size : Number of sequences to sample 
                max_length:  Maximum length of the sequences

            Outputs:
            seqs: (batch_size, seq_length) The sampled sequences.
            log_probs : (batch_size) Log likelihood for each sequence.
            entropy: (batch_size) The entropies for the sequences. Not
                                    currently used.
        """
        start_token = Variable(torch.zeros(batch_size).long())
        start_token[:] = self.voc.vocab['GO']
        h = self.rnn.init_h(batch_size)
        x = start_token

        sequences = []
        log_probs = Variable(torch.zeros(batch_size))
        finished = torch.zeros(batch_size).byte()
        entropy = Variable(torch.zeros(batch_size))
        if torch.cuda.is_available():
            finished = finished.cuda()

        for step in range(max_length):
            logits, h = self.rnn(x, h)
            prob = F.softmax(logits, dim=1)
            log_prob = F.log_softmax(logits, dim=1)
            x = torch.multinomial(prob, num_samples=1).view(-1)
            sequences.append(x.view(-1, 1))
            log_probs += NLLLoss(log_prob, x)
            entropy += -torch.sum((log_prob * prob), 1)

            x = Variable(x.data)
            EOS_sampled = (x == self.voc.vocab['EOS']).data
            finished = torch.ge(finished + EOS_sampled, 1)
            if torch.prod(finished) == 1: break

        sequences = torch.cat(sequences, 1)
        return sequences.data, log_probs, entropy

def NLLLoss(inputs, targets):
    """
        Custom Negative Log Likelihood loss that returns loss per example,
        rather than for the entire batch.

        Args:
            inputs : (batch_size, num_classes) *Log probabilities of each class*
            targets: (batch_size) *Target class index*

        Outputs:
            loss : (batch_size) *Loss for each example*
    """

    if torch.cuda.is_available():
        target_expanded = torch.zeros(inputs.size()).cuda()
    else:
        target_expanded = torch.zeros(inputs.size())

    target_expanded.scatter_(1, targets.contiguous().view(-1, 1).data, 1.0)
    loss = Variable(target_expanded) * inputs
    loss = torch.sum(loss, 1)
    return loss
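Despite its name, `NLLLoss` above returns the signed per-example log-likelihood (it is not negated or batch-averaged, unlike `torch.nn.NLLLoss`). A minimal numpy sketch of the same gather, with made-up probabilities:

```python
import numpy as np

# Per-example log-likelihood gather, mirroring what NLLLoss's one-hot scatter
# and sum compute: pick out log_probs[i, targets[i]] for each row i.
def per_example_loglik(log_probs, targets):
    return log_probs[np.arange(len(targets)), targets]

log_probs = np.log(np.array([[0.7, 0.3],
                             [0.2, 0.8]]))
out = per_example_loglik(log_probs, np.array([0, 1]))
print(np.allclose(out, [np.log(0.7), np.log(0.8)]))  # True
```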


================================================
FILE: multiprocess.py
================================================
#!/usr/bin/env python

import importlib
import sys

scoring_function = sys.argv[1]
func = getattr(importlib.import_module("scoring_functions"), scoring_function)()

while True:
    smile = sys.stdin.readline().rstrip()
    try:
        score = float(func(smile))
    except Exception:
        # Report a zero score on any scoring failure (e.g. an invalid SMILES),
        # but let KeyboardInterrupt and SystemExit propagate.
        score = 0.0
    sys.stdout.write("{} {}\n".format(smile, score))
    sys.stdout.flush()





================================================
FILE: scoring_functions.py
================================================
#!/usr/bin/env python
from __future__ import print_function, division
import numpy as np
from rdkit import Chem
from rdkit import rdBase
from rdkit.Chem import AllChem
from rdkit import DataStructs
from sklearn import svm
import time
import pickle
import re
import threading
import pexpect
rdBase.DisableLog('rdApp.error')

"""Scoring function should be a class where some tasks that are shared for every call
   can be reallocated to the __init__, and has a __call__ method which takes a single SMILES of
   argument and returns a float. A multiprocessing class will then spawn workers and divide the
   list of SMILES given between them.

   Passing *args and **kwargs through a subprocess call is slightly tricky because we need to know
   their types - everything will be a string once we have passed it. Therefor, we instead use class
   attributes which we can modify in place before any subprocess is created. Any **kwarg left over in
   the call to get_scoring_function will be checked against a list of (allowed) kwargs for the class
   and if a match is found the value of the item will be the new value for the class.

   If num_processes == 0, the scoring function will be run in the main process. Depending on how
   demanding the scoring function is and how well the OS handles the multiprocessing, this might
   be faster than multiprocessing in some cases."""
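The kwargs-as-class-attributes mechanism described above can be sketched in isolation (stub names below, not the real classes): leftover **kwargs that match a class's declared `kwargs` list overwrite its class attributes in place, before any subprocess is spawned.

```python
# Self-contained sketch of the mechanism: tanimoto_stub and apply_kwargs are
# illustrative stand-ins, not part of the repository.
class tanimoto_stub:
    kwargs = ["k", "query_structure"]
    k = 0.7

def apply_kwargs(cls, **kwargs):
    # Only keys the class explicitly allows are applied; others are ignored.
    for key, value in kwargs.items():
        if key in cls.kwargs:
            setattr(cls, key, value)
    return cls

apply_kwargs(tanimoto_stub, k=0.5, unknown="ignored")
print(tanimoto_stub.k)  # 0.5
```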

class no_sulphur():
    """Scores structures based on not containing sulphur."""

    kwargs = []

    def __init__(self):
        pass
    def __call__(self, smile):
        mol = Chem.MolFromSmiles(smile)
        if mol:
            has_sulphur = any(atom.GetAtomicNum() == 16 for atom in mol.GetAtoms())
            return float(not has_sulphur)
        return 0.0

class tanimoto():
    """Scores structures based on Tanimoto similarity to a query structure.
       Similarity is rewarded linearly up to a threshold k in (0, 1]; above k
       the maximum score of 1.0 is given."""

    kwargs = ["k", "query_structure"]
    k = 0.7
    query_structure = "Cc1ccc(cc1)c2cc(nn2c3ccc(cc3)S(=O)(=O)N)C(F)(F)F"

    def __init__(self):
        query_mol = Chem.MolFromSmiles(self.query_structure)
        self.query_fp = AllChem.GetMorganFingerprint(query_mol, 2, useCounts=True, useFeatures=True)

    def __call__(self, smile):
        mol = Chem.MolFromSmiles(smile)
        if mol:
            fp = AllChem.GetMorganFingerprint(mol, 2, useCounts=True, useFeatures=True)
            score = DataStructs.TanimotoSimilarity(self.query_fp, fp)
            score = min(score, self.k) / self.k
            return float(score)
        return 0.0
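The clipped reward in `tanimoto.__call__` can be checked numerically (k = 0.7 as in the class default):

```python
# Numeric check of the capped-similarity reward used in tanimoto.__call__:
# reward grows linearly with similarity up to k, then saturates at 1.0.
def capped_reward(similarity, k=0.7):
    return min(similarity, k) / k

print(round(capped_reward(0.35), 2))  # 0.5
print(capped_reward(0.9))             # 1.0
```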

class activity_model():
    """Scores based on an ECFP classifier for activity."""

    kwargs = ["clf_path"]
    clf_path = 'data/clf.pkl'

    def __init__(self):
        with open(self.clf_path, "rb") as f:
            self.clf = pickle.load(f)

    def __call__(self, smile):
        mol = Chem.MolFromSmiles(smile)
        if mol:
            fp = activity_model.fingerprints_from_mol(mol)
            score = self.clf.predict_proba(fp)[:, 1]
            return float(score)
        return 0.0

    @classmethod
    def fingerprints_from_mol(cls, mol):
        fp = AllChem.GetMorganFingerprint(mol, 3, useCounts=True, useFeatures=True)
        size = 2048
        nfp = np.zeros((1, size), np.int32)
        for idx,v in fp.GetNonzeroElements().items():
            nidx = idx%size
            nfp[0, nidx] += int(v)
        return nfp
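The modulo folding in `fingerprints_from_mol` can be shown standalone (the feature ids below are made up): unbounded Morgan feature ids are folded into a fixed-size count vector, so colliding ids add their counts together.

```python
import numpy as np

# Standalone sketch of the hash folding: idx % size maps every feature id into
# a slot of a fixed-size count vector; collisions accumulate.
def fold_counts(nonzero_elements, size=2048):
    nfp = np.zeros((1, size), np.int32)
    for idx, v in nonzero_elements.items():
        nfp[0, idx % size] += int(v)
    return nfp

nfp = fold_counts({5: 2, 2053: 1})  # 2053 % 2048 == 5, so both land in slot 5
print(int(nfp[0, 5]))  # 3
```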

class Worker():
    """A worker class for the Multiprocessing functionality. Spawns a subprocess
       that is listening for input SMILES and inserts the score into the given
       index in the given list."""
    def __init__(self, scoring_function=None):
        """The score_re is a regular expression that extracts the score from the
           stdout of the subprocess. This means only scoring functions with range
           0.0-1.0 will work, for other ranges this re has to be modified."""

        self.proc = pexpect.spawn('./multiprocess.py ' + scoring_function,
                                  encoding='utf-8')

        print(self.is_alive())

    def __call__(self, smile, index, result_list):
        self.proc.sendline(smile)
        output = self.proc.expect([re.escape(smile) + r" 1\.0+|0\.[0-9]+", 'None', pexpect.TIMEOUT])
        if output == 0:
            # The match ends in the score; take the last whitespace-separated field
            # (lstrip(smile + " ") is unsafe, since it strips characters, not a prefix)
            score = float(self.proc.after.split()[-1])
        elif output in [1, 2]:
            score = 0.0
        result_list[index] = score

    def is_alive(self):
        return self.proc.isalive()
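The score pattern described in `Worker.__init__` can be checked in isolation (score part only, without the `re.escape(smile)` prefix): it matches "1.0..." or "0.xxx" and nothing else, which is why only scoring functions with range 0.0-1.0 work without modifying the regex.

```python
import re

# Standalone check of the score regex: only scores written as 1.0+ or 0.x match.
score_re = re.compile(r"1\.0+|0\.[0-9]+")
print(bool(score_re.search("CCO 0.75")))  # True
print(bool(score_re.search("CCO 1.00")))  # True
print(bool(score_re.search("CCO 2.5")))   # False
```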

class Multiprocessing():
    """Class for handling multiprocessing of scoring functions. OEtoolkits cant be used with
       native multiprocessing (cant be pickled), so instead we spawn threads that create
       subprocesses."""
    def __init__(self, num_processes=None, scoring_function=None):
        self.n = num_processes
        self.workers = [Worker(scoring_function=scoring_function) for _ in range(num_processes)]

    def alive_workers(self):
        return [i for i, worker in enumerate(self.workers) if worker.is_alive()]

    def __call__(self, smiles):
        scores = [0 for _ in range(len(smiles))]
        smiles_copy = [smile for smile in smiles]
        while smiles_copy:
            alive_procs = self.alive_workers()
            if not alive_procs:
                raise RuntimeError("All subprocesses are dead, exiting.")
            # As long as we still have SMILES to score
            used_threads = []
            # Thread names correspond to the index of the worker, so here
            # we are actually checking which workers are busy
            for t in threading.enumerate():
                # Workers have numbers as names, while the main thread's name
                # can't be converted to an integer
                try:
                    n = int(t.name)
                    used_threads.append(n)
                except ValueError:
                    continue
            free_threads = [i for i in alive_procs if i not in used_threads]
            for n in free_threads:
                if smiles_copy:
                    # Send SMILES and what index in the result list the score should be inserted at
                    smile = smiles_copy.pop()
                    idx = len(smiles_copy)
                    t = threading.Thread(target=self.workers[n], name=str(n), args=(smile, idx, scores))
                    t.start()
            time.sleep(0.01)
        for t in threading.enumerate():
            try:
                n = int(t.name)
                t.join()
            except ValueError:
                continue
        return np.array(scores, dtype=np.float32)

class Singleprocessing():
    """Adds an option to not spawn new processes for the scoring functions, but rather
       run them in the main process."""
    def __init__(self, scoring_function=None):
        self.scoring_function = scoring_function()
    def __call__(self, smiles):
        scores = [self.scoring_function(smile) for smile in smiles]
        return np.array(scores, dtype=np.float32)

def get_scoring_function(scoring_function, num_processes=None, **kwargs):
    """Function that initializes and returns a scoring function by name"""
    scoring_function_classes = [no_sulphur, tanimoto, activity_model]
    scoring_functions = [f.__name__ for f in scoring_function_classes]

    # Validate the name before indexing, so an unknown name raises a helpful
    # ValueError instead of an IndexError
    if scoring_function not in scoring_functions:
        raise ValueError("Scoring function must be one of {}".format([f for f in scoring_functions]))
    scoring_function_class = [f for f in scoring_function_classes if f.__name__ == scoring_function][0]

    for k, v in kwargs.items():
        if k in scoring_function_class.kwargs:
            setattr(scoring_function_class, k, v)

    if num_processes == 0:
        return Singleprocessing(scoring_function=scoring_function_class)
    return Multiprocessing(scoring_function=scoring_function, num_processes=num_processes)


================================================
FILE: train_agent.py
================================================
#!/usr/bin/env python

import torch
import pickle
import numpy as np
import time
import os
from shutil import copyfile

from model import RNN
from data_structs import Vocabulary, Experience
from scoring_functions import get_scoring_function
from utils import Variable, seq_to_smiles, fraction_valid_smiles, unique
from vizard_logger import VizardLog

def train_agent(restore_prior_from='data/Prior.ckpt',
                restore_agent_from='data/Prior.ckpt',
                scoring_function='tanimoto',
                scoring_function_kwargs=None,
                save_dir=None, learning_rate=0.0005,
                batch_size=64, n_steps=3000,
                num_processes=0, sigma=60,
                experience_replay=0):

    voc = Vocabulary(init_from_file="data/Voc")

    start_time = time.time()

    Prior = RNN(voc)
    Agent = RNN(voc)

    logger = VizardLog('data/logs')

    # By default restore Agent to the same model as Prior, but we can restore from an already trained Agent too.
    # Saved models are partially on the GPU, but if we don't have cuda enabled we can remap these
    # to the CPU.
    if torch.cuda.is_available():
        Prior.rnn.load_state_dict(torch.load(restore_prior_from))
        Agent.rnn.load_state_dict(torch.load(restore_agent_from))
    else:
        Prior.rnn.load_state_dict(torch.load(restore_prior_from, map_location=lambda storage, loc: storage))
        Agent.rnn.load_state_dict(torch.load(restore_agent_from, map_location=lambda storage, loc: storage))

    # We don't need gradients with respect to Prior
    for param in Prior.rnn.parameters():
        param.requires_grad = False

    optimizer = torch.optim.Adam(Agent.rnn.parameters(), lr=learning_rate)

    # Scoring_function
    scoring_function = get_scoring_function(scoring_function=scoring_function, num_processes=num_processes,
                                            **scoring_function_kwargs)

    # For policy based RL, we normally train on-policy and correct for the fact that more likely actions
    # occur more often (which means the agent can get biased towards them). Using experience replay is
    # therefore not as theoretically sound as it is for value based RL, but it seems to work well.
    experience = Experience(voc)

    # Log some network weights that can be dynamically plotted with the Vizard bokeh app
    logger.log(Agent.rnn.gru_2.weight_ih.cpu().data.numpy()[::100], "init_weight_GRU_layer_2_w_ih")
    logger.log(Agent.rnn.gru_2.weight_hh.cpu().data.numpy()[::100], "init_weight_GRU_layer_2_w_hh")
    logger.log(Agent.rnn.embedding.weight.cpu().data.numpy()[::30], "init_weight_GRU_embedding")
    logger.log(Agent.rnn.gru_2.bias_ih.cpu().data.numpy(), "init_weight_GRU_layer_2_b_ih")
    logger.log(Agent.rnn.gru_2.bias_hh.cpu().data.numpy(), "init_weight_GRU_layer_2_b_hh")

    # Information for the logger
    step_score = [[], []]

    print("Model initialized, starting training...")

    for step in range(n_steps):

        # Sample from Agent
        seqs, agent_likelihood, entropy = Agent.sample(batch_size)

        # Remove duplicates, ie only consider unique seqs
        unique_idxs = unique(seqs)
        seqs = seqs[unique_idxs]
        agent_likelihood = agent_likelihood[unique_idxs]
        entropy = entropy[unique_idxs]

        # Get prior likelihood and score
        prior_likelihood, _ = Prior.likelihood(Variable(seqs))
        smiles = seq_to_smiles(seqs, voc)
        score = scoring_function(smiles)

        # Calculate augmented likelihood
        augmented_likelihood = prior_likelihood + sigma * Variable(score)
        loss = torch.pow((augmented_likelihood - agent_likelihood), 2)

        # Experience Replay
        # First sample
        if experience_replay and len(experience)>4:
            exp_seqs, exp_score, exp_prior_likelihood = experience.sample(4)
            exp_agent_likelihood, exp_entropy = Agent.likelihood(exp_seqs.long())
            exp_augmented_likelihood = exp_prior_likelihood + sigma * exp_score
            exp_loss = torch.pow((Variable(exp_augmented_likelihood) - exp_agent_likelihood), 2)
            loss = torch.cat((loss, exp_loss), 0)
            agent_likelihood = torch.cat((agent_likelihood, exp_agent_likelihood), 0)

        # Then add new experience
        prior_likelihood = prior_likelihood.data.cpu().numpy()
        new_experience = zip(smiles, score, prior_likelihood)
        experience.add_experience(new_experience)

        # Calculate loss
        loss = loss.mean()

        # Add regularizer that penalizes high likelihood for the entire sequence
        loss_p = - (1 / agent_likelihood).mean()
        loss += 5 * 1e3 * loss_p

        # Calculate gradients and make an update to the network weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Convert to numpy arrays so that we can print them
        augmented_likelihood = augmented_likelihood.data.cpu().numpy()
        agent_likelihood = agent_likelihood.data.cpu().numpy()

        # Print some information for this step
        time_elapsed = (time.time() - start_time) / 3600
        time_left = (time_elapsed * ((n_steps - step) / (step + 1)))
        print("\n       Step {}   Fraction valid SMILES: {:4.1f}  Time elapsed: {:.2f}h Time left: {:.2f}h".format(
              step, fraction_valid_smiles(smiles) * 100, time_elapsed, time_left))
        print("  Agent    Prior   Target   Score             SMILES")
        for i in range(10):
            print(" {:6.2f}   {:6.2f}  {:6.2f}  {:6.2f}     {}".format(agent_likelihood[i],
                                                                       prior_likelihood[i],
                                                                       augmented_likelihood[i],
                                                                       score[i],
                                                                       smiles[i]))
        # Need this for Vizard plotting
        step_score[0].append(step + 1)
        step_score[1].append(np.mean(score))

        # Log some weights
        logger.log(Agent.rnn.gru_2.weight_ih.cpu().data.numpy()[::100], "weight_GRU_layer_2_w_ih")
        logger.log(Agent.rnn.gru_2.weight_hh.cpu().data.numpy()[::100], "weight_GRU_layer_2_w_hh")
        logger.log(Agent.rnn.embedding.weight.cpu().data.numpy()[::30], "weight_GRU_embedding")
        logger.log(Agent.rnn.gru_2.bias_ih.cpu().data.numpy(), "weight_GRU_layer_2_b_ih")
        logger.log(Agent.rnn.gru_2.bias_hh.cpu().data.numpy(), "weight_GRU_layer_2_b_hh")
        logger.log("\n".join([smiles + "\t" + str(round(score, 2)) for smiles, score in zip \
                            (smiles[:12], score[:12])]), "SMILES", dtype="text", overwrite=True)
        logger.log(np.array(step_score), "Scores")

    # If the entire training finishes, we create a new folder where we save this python file
    # as well as some sampled sequences and the contents of the experience (which are the highest
    # scoring sequences seen during training)
    if not save_dir:
        save_dir = 'data/results/run_' + time.strftime("%Y-%m-%d-%H_%M_%S", time.localtime())
    os.makedirs(save_dir)
    copyfile('train_agent.py', os.path.join(save_dir, "train_agent.py"))

    experience.print_memory(os.path.join(save_dir, "memory"))
    torch.save(Agent.rnn.state_dict(), os.path.join(save_dir, 'Agent.ckpt'))

    seqs, agent_likelihood, entropy = Agent.sample(256)
    prior_likelihood, _ = Prior.likelihood(Variable(seqs))
    prior_likelihood = prior_likelihood.data.cpu().numpy()
    smiles = seq_to_smiles(seqs, voc)
    score = scoring_function(smiles)
    with open(os.path.join(save_dir, "sampled"), 'w') as f:
        f.write("SMILES Score PriorLogP\n")
        for smiles, score, prior_likelihood in zip(smiles, score, prior_likelihood):
            f.write("{} {:5.2f} {:6.2f}\n".format(smiles, score, prior_likelihood))

if __name__ == "__main__":
    train_agent()
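The core update in `train_agent` can be summarized with a minimal numeric sketch: the target is an augmented log-likelihood log P_prior(seq) + sigma * score(seq), and the loss is the squared difference to the agent's log-likelihood (sigma = 60 as in the defaults; the numbers below are made up).

```python
import numpy as np

# Minimal numeric sketch of the REINVENT policy update: pull the agent's
# log-likelihood toward prior log-likelihood plus sigma-scaled score.
def reinvent_loss(prior_loglik, agent_loglik, score, sigma=60):
    augmented = prior_loglik + sigma * score
    return float(np.mean((augmented - agent_loglik) ** 2))

# An agent that already matches the augmented likelihood contributes zero loss:
# -40 + 60 * 1.0 == 20, which equals the agent's log-likelihood here.
print(reinvent_loss(np.array([-40.0]), np.array([20.0]), np.array([1.0])))  # 0.0
```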


================================================
FILE: train_prior.py
================================================
#!/usr/bin/env python

import torch
from torch.utils.data import DataLoader
import pickle
from rdkit import Chem
from rdkit import rdBase
from tqdm import tqdm

from data_structs import MolData, Vocabulary
from model import RNN
from utils import Variable, decrease_learning_rate
rdBase.DisableLog('rdApp.error')

def pretrain(restore_from=None):
    """Trains the Prior RNN"""

    # Read vocabulary from a file
    voc = Vocabulary(init_from_file="data/Voc")

    # Create a Dataset from a SMILES file
    moldata = MolData("data/mols_filtered.smi", voc)
    data = DataLoader(moldata, batch_size=128, shuffle=True, drop_last=True,
                      collate_fn=MolData.collate_fn)

    Prior = RNN(voc)

    # Can restore from a saved RNN
    if restore_from:
        Prior.rnn.load_state_dict(torch.load(restore_from))

    optimizer = torch.optim.Adam(Prior.rnn.parameters(), lr = 0.001)
    for epoch in range(1, 6):
        # When training on a few million compounds, this model converges
        # in a few epochs or even faster. If the model size is increased,
        # it's probably a good idea to check the loss against an external
        # set of validation SMILES to make sure we don't overfit too much.
        for step, batch in tqdm(enumerate(data), total=len(data)):

            # Sample from DataLoader
            seqs = batch.long()

            # Calculate loss
            log_p, _ = Prior.likelihood(seqs)
            loss = - log_p.mean()

            # Calculate gradients and take a step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Every 500 steps we decrease learning rate and print some information
            if step % 500 == 0 and step != 0:
                decrease_learning_rate(optimizer, decrease_by=0.03)
                tqdm.write("*" * 50)
                tqdm.write("Epoch {:3d}   step {:3d}    loss: {:5.2f}\n".format(epoch, step, loss.data[0]))
                seqs, likelihood, _ = Prior.sample(128)
                valid = 0
                for i, seq in enumerate(seqs.cpu().numpy()):
                    smile = voc.decode(seq)
                    if Chem.MolFromSmiles(smile):
                        valid += 1
                    if i < 5:
                        tqdm.write(smile)
                tqdm.write("\n{:>4.1f}% valid SMILES".format(100 * valid / len(seqs)))
                tqdm.write("*" * 50 + "\n")
                torch.save(Prior.rnn.state_dict(), "data/Prior.ckpt")

        # Save the Prior
        torch.save(Prior.rnn.state_dict(), "data/Prior.ckpt")

if __name__ == "__main__":
    pretrain()


================================================
FILE: utils.py
================================================
import torch
import numpy as np
from rdkit import Chem

def Variable(tensor):
    """Wrapper for torch.autograd.Variable that also accepts
       numpy arrays directly and automatically moves the result
       to the GPU when available. Be aware in case some operations
       are better left to the CPU."""
    if isinstance(tensor, np.ndarray):
        tensor = torch.from_numpy(tensor)
    if torch.cuda.is_available():
        return torch.autograd.Variable(tensor).cuda()
    return torch.autograd.Variable(tensor)

def decrease_learning_rate(optimizer, decrease_by=0.01):
    """Multiplies the learning rate of the optimizer by 1 - decrease_by"""
    for param_group in optimizer.param_groups:
        param_group['lr'] *= (1 - decrease_by)
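A quick numeric check of the multiplicative schedule implemented above: after n calls with `decrease_by` d, the learning rate is lr0 * (1 - d) ** n.

```python
# Simulate 10 calls of decrease_learning_rate with decrease_by=0.03 on lr0=0.001.
lr = 0.001
for _ in range(10):
    lr *= (1 - 0.03)
print(abs(lr - 0.001 * 0.97 ** 10) < 1e-12)  # True
```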

def seq_to_smiles(seqs, voc):
    """Takes an output sequence from the RNN and returns the
       corresponding SMILES."""
    smiles = []
    for seq in seqs.cpu().numpy():
        smiles.append(voc.decode(seq))
    return smiles

def fraction_valid_smiles(smiles):
    """Takes a list of SMILES and returns fraction valid."""
    i = 0
    for smile in smiles:
        if Chem.MolFromSmiles(smile):
            i += 1
    return i / len(smiles)

def unique(arr):
    # Finds unique rows in arr and returns their indices
    arr = arr.cpu().numpy()
    arr_ = np.ascontiguousarray(arr).view(np.dtype((np.void, arr.dtype.itemsize * arr.shape[1])))
    _, idxs = np.unique(arr_, return_index=True)
    if torch.cuda.is_available():
        return torch.LongTensor(np.sort(idxs)).cuda()
    return torch.LongTensor(np.sort(idxs))
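The row-uniqueness trick in `unique` works by viewing each row as a single opaque void scalar, so `np.unique` deduplicates whole rows; `np.sort(idxs)` then restores the original sampling order. A standalone sketch:

```python
import numpy as np

# View each row as one void scalar of row-width bytes, deduplicate, and return
# first-occurrence indices in original order.
def unique_row_indices(arr):
    arr = np.ascontiguousarray(arr)
    void_view = arr.view(np.dtype((np.void, arr.dtype.itemsize * arr.shape[1])))
    _, idxs = np.unique(void_view, return_index=True)
    return np.sort(idxs)

rows = np.array([[1, 2], [3, 4], [1, 2]])  # third row duplicates the first
print(unique_row_indices(rows).tolist())   # [0, 1]
```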


================================================
FILE: vizard_logger.py
================================================
import numpy as np
import os

class VizardLog():
    def __init__(self, log_dir):
        self.log_dir = log_dir
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)

        # List of variables to log
        self.logged_vars = []
        # Dict of {name_of_variable : time_since_last_logged}
        self.last_logged = {}
        # Dict of {name_of_variable : log_every}
        self.log_every = {}
        self.overwrite = {}

    def log(self, data, name, dtype="array", log_every=1, overwrite=False):
        if name not in self.logged_vars:
            self.logged_vars.append(name)
            self.last_logged[name] = 1
            self.log_every[name] = log_every
            if overwrite:
                self.overwrite[name] = 'w'
            else:
                self.overwrite[name] = 'a'

        if self.last_logged[name] == self.log_every[name]:
            # Reset the counter and write the data
            self.last_logged[name] = 1
            out_f = os.path.join(self.log_dir, name)
            if dtype == "text":
                with open(out_f, self.overwrite[name]) as f:
                    f.write(data)
            elif dtype == "array":
                np.save(out_f, data)
            elif dtype == "hist":
                np.save(out_f, np.histogram(data, density=True, bins=50))
        else:
            self.last_logged[name] += 1
SYMBOL INDEX (77 symbols across 8 files)

FILE: Vizard/main.py
  function downsample (line 46) | def downsample(data, max_len):
  function running_average (line 52) | def running_average(data, length):
  function create_bar_plot (line 61) | def create_bar_plot(init_data, title):
  function create_hist_plot (line 71) | def create_hist_plot(init_data, title):
  function update (line 95) | def update():

FILE: data_structs.py
  class Vocabulary (line 13) | class Vocabulary(object):
    method __init__ (line 15) | def __init__(self, init_from_file=None, max_length=140):
    method encode (line 25) | def encode(self, char_list):
    method decode (line 32) | def decode(self, matrix):
    method tokenize (line 42) | def tokenize(self, smiles):
    method add_characters (line 57) | def add_characters(self, chars):
    method init_from_file (line 68) | def init_from_file(self, file):
    method __len__ (line 74) | def __len__(self):
    method __str__ (line 77) | def __str__(self):
  class MolData (line 80) | class MolData(Dataset):
    method __init__ (line 90) | def __init__(self, fname, voc):
    method __getitem__ (line 97) | def __getitem__(self, i):
    method __len__ (line 103) | def __len__(self):
    method __str__ (line 106) | def __str__(self):
    method collate_fn (line 110) | def collate_fn(cls, arr):
  class Experience (line 118) | class Experience(object):
    method __init__ (line 121) | def __init__(self, voc, max_size=100):
    method add_experience (line 126) | def add_experience(self, experience):
    method sample (line 142) | def sample(self, n):
    method initiate_from_file (line 158) | def initiate_from_file(self, fname, scoring_function, Prior):
    method print_memory (line 178) | def print_memory(self, path):
    method __len__ (line 191) | def __len__(self):
  function replace_halogen (line 194) | def replace_halogen(string):
  function tokenize (line 203) | def tokenize(smiles):
  function canonicalize_smiles_from_file (line 220) | def canonicalize_smiles_from_file(fname):
  function filter_mol (line 234) | def filter_mol(mol, max_heavy_atoms=50, min_heavy_atoms=10, element_list...
  function write_smiles_to_file (line 244) | def write_smiles_to_file(smiles_list, fname):
  function filter_on_chars (line 250) | def filter_on_chars(smiles_list, chars):
  function filter_file_on_chars (line 261) | def filter_file_on_chars(smiles_fname, voc_fname):
  function combine_voc_from_files (line 280) | def combine_voc_from_files(fnames):
  function construct_vocabulary (line 291) | def construct_vocabulary(smiles_list):

FILE: model.py
  class MultiGRU (line 10) | class MultiGRU(nn.Module):
    method __init__ (line 13) | def __init__(self, voc_size):
    method forward (line 21) | def forward(self, x, h):
    method init_h (line 30) | def init_h(self, batch_size):
  class RNN (line 34) | class RNN():
    method __init__ (line 37) | def __init__(self, voc):
    method likelihood (line 43) | def likelihood(self, target):
    method sample (line 71) | def sample(self, batch_size, max_length=140):
  function NLLLoss (line 114) | def NLLLoss(inputs, targets):

FILE: scoring_functions.py
  class no_sulphur (line 31) | class no_sulphur():
    method __init__ (line 36) | def __init__(self):
    method __call__ (line 38) | def __call__(self, smile):
  class tanimoto (line 45) | class tanimoto():
    method __init__ (line 53) | def __init__(self):
    method __call__ (line 57) | def __call__(self, smile):
  class activity_model (line 66) | class activity_model():
    method __init__ (line 72) | def __init__(self):
    method __call__ (line 76) | def __call__(self, smile):
    method fingerprints_from_mol (line 85) | def fingerprints_from_mol(cls, mol):
  class Worker (line 94) | class Worker():
    method __init__ (line 98) | def __init__(self, scoring_function=None):
    method __call__ (line 108) | def __call__(self, smile, index, result_list):
    method is_alive (line 117) | def is_alive(self):
  class Multiprocessing (line 120) | class Multiprocessing():
    method __init__ (line 124) | def __init__(self, num_processes=None, scoring_function=None):
    method alive_workers (line 128) | def alive_workers(self):
    method __call__ (line 131) | def __call__(self, smiles):
  class Singleprocessing (line 167) | class Singleprocessing():
    method __init__ (line 170) | def __init__(self, scoring_function=None):
    method __call__ (line 172) | def __call__(self, smiles):
  function get_scoring_function (line 176) | def get_scoring_function(scoring_function, num_processes=None, **kwargs):

FILE: train_agent.py
  function train_agent (line 16) | def train_agent(restore_prior_from='data/Prior.ckpt',

FILE: train_prior.py
  function pretrain (line 15) | def pretrain(restore_from=None):

FILE: utils.py
  function Variable (line 5) | def Variable(tensor):
  function decrease_learning_rate (line 16) | def decrease_learning_rate(optimizer, decrease_by=0.01):
  function seq_to_smiles (line 21) | def seq_to_smiles(seqs, voc):
  function fraction_valid_smiles (line 29) | def fraction_valid_smiles(smiles):
  function unique (line 37) | def unique(arr):

FILE: vizard_logger.py
  class VizardLog (line 4) | class VizardLog():
    method __init__ (line 5) | def __init__(self, log_dir):
    method log (line 18) | def log(self, data, name, dtype="array", log_every=1, overwrite=False):