[
  {
    "path": "LICENSE",
    "content": "Copyright <2017> <Marcus Olivecrona>\n\nPermission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the \"Software\"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "\n# REINVENT\n## Molecular De Novo design using Recurrent Neural Networks and Reinforcement Learning\n\nSearching chemical space as described in:\n\n[Molecular De Novo Design through Deep Reinforcement Learning](https://arxiv.org/abs/1704.07555)\n\n![Video demonstrating an Agent trained to generate analogues to Celecoxib](https://github.com/MarcusOlivecrona/REINVENT/blob/master/images/celecoxib_analogues.gif \"Training an Agent to generate analogues of Celecoxib\")\n\n\n## Notes\nThe current version is a PyTorch implementation that differs in several ways from the original implementation described in the paper. This version works better in most situations and is better documented, but for the purpose of reproducing results from the paper refer to [Release v1.0.1](https://github.com/MarcusOlivecrona/REINVENT/releases/tag/v1.0.1)\n\nDifferences from implmentation in the paper:\n* Written in PyTorch/Python3.6 rather than TF/Python2.7\n* SMILES are encoded with token index rather than as a onehot of the index. An embedding matrix is then used to transform the token index to a feature vector.\n* Scores are in the range (0,1).\n* A regularizer that penalizes high values of total episodic likelihood is included.\n* Sequences are only considered once, ie if the same sequence is generated twice in a batch only the first instance contributes to the loss.\n* These changes makes the algorithm more robust towards local minima, means much higher values of sigma can be used if needed.\n\n## Requirements\n\nThis package requires:\n* Python 3.6\n* PyTorch 0.1.12 \n* [RDkit](http://www.rdkit.org/docs/Install.html)\n* Scikit-Learn (for QSAR scoring function)\n* tqdm (for training Prior)\n* pexpect\n\n## Usage\n\nTo train a Prior starting with a SMILES file called mols.smi:\n\n* First filter the SMILES and construct a vocabulary from the remaining sequences. `./data_structs.py mols.smi`   - Will generate data/mols_filtered.smi and data/Voc. 
A filtered file containing around 1.1 million SMILES and the corresponding Voc is contained in \"data\".\n\n* Then use `./train_prior.py` to train the Prior. A pretrained Prior is included.\n\nTo train an Agent using our Prior, use the main.py script. For example:\n\n* `./main.py --scoring-function activity_model --num-steps 1000`\n\nTraining can be visualized using the Vizard bokeh app. The vizard_logger.py is used to log information (by default to data/logs) such as structures generated, average score, and network weights.\n\n* `cd Vizard`\n* `./run.sh ../data/logs`\n* Open the browser at http://localhost:5006/Vizard\n\n\n"
  },
  {
    "path": "Vizard/main.py",
    "content": "from bokeh.plotting import figure, ColumnDataSource, curdoc\nfrom bokeh.models import CustomJS, Range1d\nfrom bokeh.models.glyphs import Text\nfrom bokeh.layouts import row, column, widgetbox, layout\nfrom bokeh.models.widgets import Div\nimport bokeh.palettes\nfrom rdkit import Chem\nfrom rdkit.Chem import Draw\nfrom rdkit import rdBase\n\nimport sys\nimport os.path\nimport numpy as np\nimport math\n\n\"\"\"Bokeh app that visualizes training progress for the De Novo design reinforcement learning.\n   The app is updated dynamically using information that the train_agent.py script writes to a\n   logging directory.\"\"\"\n\nrdBase.DisableLog('rdApp.error')\n\nerror_msg = \"\"\"Need to provide valid log directory as first argument.\n                     'bokeh serve . --args [log_dir]'\"\"\"\ntry:\n    path = sys.argv[1]\nexcept IndexError:\n    raise IndexError(error_msg)\nif not os.path.isdir(path):\n    raise ValueError(error_msg)\n\nscore_source = ColumnDataSource(data=dict(x=[], y=[], y_mean=[]))\nscore_fig = figure(title=\"Scores\", plot_width=600, plot_height=600)\nscore_fig.line('x', 'y', legend='Average score', source=score_source)\nscore_fig.line('x', 'y_mean', legend='Running average of average score', line_width=2, \n               color=\"firebrick\", source=score_source)\n\nscore_fig.xaxis.axis_label = \"Step\"\nscore_fig.yaxis.axis_label = \"Average Score\"\nscore_fig.title.text_font_size = \"20pt\"\nscore_fig.legend.location = \"bottom_right\"\nscore_fig.css_classes = [\"score_fig\"]\n\nimg_fig = Div(text=\"\", width=850, height=590)\nimg_fig.css_classes = [\"img_outside\"]\n\ndef downsample(data, max_len):\n    np.random.seed(0)\n    if len(data)>max_len:\n        data = np.random.choice(data, size=max_len, replace=False)\n    return data\n\ndef running_average(data, length):\n    early_cumsum = np.cumsum(data[:length]) / np.arange(1, min(len(data), length) + 1)\n    if len(data)>length:\n        cumsum = np.cumsum(data) \n        
cumsum =  (cumsum[length:] - cumsum[:-length]) / length\n        cumsum = np.concatenate((early_cumsum, cumsum))\n        return cumsum\n    return early_cumsum\n\ndef create_bar_plot(init_data, title):\n    init_data = downsample(init_data, 50)\n    x = range(len(init_data))\n    source = ColumnDataSource(data=dict(x= [], y=[]))\n    fig = figure(title=title, plot_width=300, plot_height=300)\n    fig.vbar(x=x, width=1, top=init_data, fill_alpha=0.05)\n    fig.vbar('x', width=1, top='y', fill_alpha=0.3, source=source)\n    fig.y_range = Range1d(min(0, 1.2 * min(init_data)), 1.2 * max(init_data))\n    return fig, source\n\ndef create_hist_plot(init_data, title):\n    source = ColumnDataSource(data=dict(hist=[], left_edge=[], right_edge=[]))\n    init_hist, init_edge = np.histogram(init_data, density=True, bins=50)\n    fig = figure(title=title, plot_width=300, plot_height=300)\n    fig.quad(top=init_hist, bottom=0, left=init_edge[:-1], right=init_edge[1:],\n            fill_alpha=0.05)\n    fig.quad(top='hist', bottom=0, left='left_edge', right='right_edge',\n            fill_alpha=0.3, source=source)\n    return fig, source\n\n\nweights = [f for f in os.listdir(path) if f.startswith(\"weight\")]\nweights = {w:{'init_weight': np.load(os.path.join(path, \"init_\" + w)).reshape(-1)} for w in weights}\n\nfor name, w in weights.items():\n    w['bar_fig'], w['bar_source'] = create_bar_plot(w['init_weight'], name)\n    w['hist_fig'], w['hist_source'] = create_hist_plot(w['init_weight'], name + \"_histogram\")\n\nbar_plots = [w['bar_fig'] for name, w in weights.items()]\nhist_plots = [w['hist_fig'] for name, w in weights.items()]\n\nlayout = layout([[img_fig, score_fig], bar_plots, hist_plots], sizing_mode=\"fixed\")\ncurdoc().add_root(layout)\n\ndef update():\n    score = np.load(os.path.join(path, \"Scores.npy\"))\n    with open(os.path.join(path, \"SMILES\"), \"r\") as f:\n        mols = []\n        scores = []\n        for line in f:\n                line = 
line.split()\n                mol = Chem.MolFromSmiles(line[0])\n                if mol and len(mols)<6:\n                    mols.append(mol)\n                    scores.append(line[1])\n    img = Draw.MolsToGridImage(mols, molsPerRow=3, legends=scores, subImgSize=(250,250), useSVG=True)\n    img = img.replace(\"FFFFFF\", \"EDEDED\")\n    img_fig.text = '<h2>Generated Molecules</h2>' + '<div class=\"img_inside\">' + img + '</div>'\n    score_source.data = dict(x=score[0], y=score[1], y_mean=running_average(score[1], 50))\n\n    for name, w in weights.items():\n        current_weights = np.load(os.path.join(path, name)).reshape(-1)\n        hist, edge = np.histogram(current_weights, density=True, bins=50)\n        w['hist_source'].data = dict(hist=hist, left_edge=edge[:-1], right_edge=edge[1:])\n        current_weights = downsample(current_weights, 50)\n        w['bar_source'].data = dict(x=range(len(current_weights)), y=current_weights)\n\nupdate()\ncurdoc().add_periodic_callback(update, 1000)\n\n"
  },
  {
    "path": "Vizard/run.sh",
    "content": "#!/bin/bash\nif [ -z \"$1\" ];\n    then echo \"Must supply path to a directory where vizard_logger is saving its information\";\n    exit 0\nfi\nbokeh serve . --args $1\n"
  },
  {
    "path": "Vizard/templates/index.html",
    "content": "<!DOCTYPE html>\n<html lang=\"en\">\n    <head>\n        {{ bokeh_css }}\n        {{ bokeh_js }}\n        <style>\n             {% include 'styles.css' %}\n        </style>\n        <meta charset=\"utf-8\">\n        <title>MolExplorer</title>\n    </head>\n    <body>\n    <div>\n        <h1>Vizard</h1>\n        {{ plot_div|indent(8) }}\n    </div>\n        {{ plot_script|indent(8) }}\n    </body>\n</html>\n"
  },
  {
    "path": "Vizard/templates/styles.css",
    "content": "html {\n    background-color: #2F2F2F;\n    display: table;\n    margin: auto;\n}\n\nbody {\n    display: table-cell;\n    vertical-align: middle;\n    color: #fff;\n}\n\n.img_outside {\n    position: relative;\n}\n\n.img_inside {\n    background-color: #EDEDED;\n    border: 7px solid #656565;\n    position:absolute;\n    left:50% ;\n    margin-left: -375px;\n    top:50% ;\n    margin-top: -250px;\n}\n\n.score_fig{\n    position: absolute;\n    top: 10px;\n}\n\nh1 {\n    margin: 0.5em 0 0.5em 0;\n    color: #fff;\n    font-family: 'Julius Sans One', sans-serif;\n    font-size: 3em;\n    text-transform: uppercase;\n    text-align: center;\n}\n\nh2 {\n    margin: 0 0 0 0;\n    color: #fff;\n    font-size: 20pt;\n    text-align: center;\n}\n\na:link {\n    font-weight: bold;\n    text-decoration: none;\n    color: #0d8ba1;\n}\na:visited {\n    font-weight: bold;\n    text-decoration: none;\n    color: #1a5952;\n}\na:hover, a:focus, a:active {\n    text-decoration: underline;\n    color: #9685BA;\n}\n"
  },
  {
    "path": "Vizard/theme.yaml",
    "content": "attrs:\n    Figure:\n        background_fill_color: '#2F2F2F'\n        border_fill_color: '#2F2F2F'\n        outline_line_color: '#444444'\n        min_border_top: 0\n    Axis:\n        axis_line_color: \"#FFFFFF\"\n        axis_label_text_color: \"#FFFFFF\"\n        axis_label_text_font_size: \"10pt\"\n        axis_label_text_font_style: \"normal\"\n        axis_label_standoff: 10\n        major_label_text_color: \"#FFFFFF\"\n        major_tick_line_color: \"#FFFFFF\"\n        minor_tick_line_color: \"#FFFFFF\"\n        minor_tick_line_color: \"#FFFFFF\"\n    Grid:\n        grid_line_dash: [6, 4]\n        grid_line_alpha: .3\n    Title:\n        text_color: \"#FFFFFF\"\n        align: \"center\"\n"
  },
  {
    "path": "data/Voc",
    "content": "[S-]\n9\n(\nS\nc\n[NH+]\n3\n[CH]\no\n[NH3+]\n[nH]\n7\n6\n[N]\n1\nO\n%\n[N-]\n5\n-\n[O+]\n[n+]\n[o+]\n[nH+]\n[NH2+]\n[N+]\n[O-]\n[S+]\nR\nF\n[n-]\n[s+]\nL\ns\n8\n4\n[SH]\n2\n=\nn\n)\n[O]\nN\n#\n[NH-]\nC\n[SH+]\n0\n"
  },
  {
    "path": "data_structs.py",
    "content": "import numpy as np\nimport random\nimport re\nimport pickle\nfrom rdkit import Chem\nimport sys\nimport time\nimport torch\nfrom torch.utils.data import Dataset\n\nfrom utils import Variable\n\nclass Vocabulary(object):\n    \"\"\"A class for handling encoding/decoding from SMILES to an array of indices\"\"\"\n    def __init__(self, init_from_file=None, max_length=140):\n        self.special_tokens = ['EOS', 'GO']\n        self.additional_chars = set()\n        self.chars = self.special_tokens\n        self.vocab_size = len(self.chars)\n        self.vocab = dict(zip(self.chars, range(len(self.chars))))\n        self.reversed_vocab = {v: k for k, v in self.vocab.items()}\n        self.max_length = max_length\n        if init_from_file: self.init_from_file(init_from_file)\n\n    def encode(self, char_list):\n        \"\"\"Takes a list of characters (eg '[NH]') and encodes to array of indices\"\"\"\n        smiles_matrix = np.zeros(len(char_list), dtype=np.float32)\n        for i, char in enumerate(char_list):\n            smiles_matrix[i] = self.vocab[char]\n        return smiles_matrix\n\n    def decode(self, matrix):\n        \"\"\"Takes an array of indices and returns the corresponding SMILES\"\"\"\n        chars = []\n        for i in matrix:\n            if i == self.vocab['EOS']: break\n            chars.append(self.reversed_vocab[i])\n        smiles = \"\".join(chars)\n        smiles = smiles.replace(\"L\", \"Cl\").replace(\"R\", \"Br\")\n        return smiles\n\n    def tokenize(self, smiles):\n        \"\"\"Takes a SMILES and return a list of characters/tokens\"\"\"\n        regex = '(\\[[^\\[\\]]{1,6}\\])'\n        smiles = replace_halogen(smiles)\n        char_list = re.split(regex, smiles)\n        tokenized = []\n        for char in char_list:\n            if char.startswith('['):\n                tokenized.append(char)\n            else:\n                chars = [unit for unit in char]\n                [tokenized.append(unit) for unit in 
chars]\n        tokenized.append('EOS')\n        return tokenized\n\n    def add_characters(self, chars):\n        \"\"\"Adds characters to the vocabulary\"\"\"\n        for char in chars:\n            self.additional_chars.add(char)\n        char_list = list(self.additional_chars)\n        char_list.sort()\n        self.chars = char_list + self.special_tokens\n        self.vocab_size = len(self.chars)\n        self.vocab = dict(zip(self.chars, range(len(self.chars))))\n        self.reversed_vocab = {v: k for k, v in self.vocab.items()}\n\n    def init_from_file(self, file):\n        \"\"\"Takes a file containing \\n separated characters to initialize the vocabulary\"\"\"\n        with open(file, 'r') as f:\n            chars = f.read().split()\n        self.add_characters(chars)\n\n    def __len__(self):\n        return len(self.chars)\n\n    def __str__(self):\n        return \"Vocabulary containing {} tokens: {}\".format(len(self), self.chars)\n\nclass MolData(Dataset):\n    \"\"\"Custom PyTorch Dataset that takes a file containing SMILES.\n\n        Args:\n                fname : path to a file containing \\n separated SMILES.\n                voc   : a Vocabulary instance\n\n        Returns:\n                A custom PyTorch dataset for training the Prior.\n    \"\"\"\n    def __init__(self, fname, voc):\n        self.voc = voc\n        self.smiles = []\n        with open(fname, 'r') as f:\n            for line in f:\n                self.smiles.append(line.split()[0])\n\n    def __getitem__(self, i):\n        mol = self.smiles[i]\n        tokenized = self.voc.tokenize(mol)\n        encoded = self.voc.encode(tokenized)\n        return Variable(encoded)\n\n    def __len__(self):\n        return len(self.smiles)\n\n    def __str__(self):\n        return \"Dataset containing {} structures.\".format(len(self))\n\n    @classmethod\n    def collate_fn(cls, arr):\n        \"\"\"Function to take a list of encoded sequences and turn them into a batch\"\"\"\n        
max_length = max([seq.size(0) for seq in arr])\n        collated_arr = Variable(torch.zeros(len(arr), max_length))\n        for i, seq in enumerate(arr):\n            collated_arr[i, :seq.size(0)] = seq\n        return collated_arr\n\nclass Experience(object):\n    \"\"\"Class for prioritized experience replay that remembers the highest scored sequences\n       seen and samples from them with probabilities relative to their scores.\"\"\"\n    def __init__(self, voc, max_size=100):\n        self.memory = []\n        self.max_size = max_size\n        self.voc = voc\n\n    def add_experience(self, experience):\n        \"\"\"Experience should be a list of (smiles, score, prior likelihood) tuples\"\"\"\n        self.memory.extend(experience)\n        if len(self.memory)>self.max_size:\n            # Remove duplicates\n            idxs, smiles = [], []\n            for i, exp in enumerate(self.memory):\n                if exp[0] not in smiles:\n                    idxs.append(i)\n                    smiles.append(exp[0])\n            self.memory = [self.memory[idx] for idx in idxs]\n            # Retain highest scores\n            self.memory.sort(key = lambda x: x[1], reverse=True)\n            self.memory = self.memory[:self.max_size]\n            print(\"\\nBest score in memory: {:.2f}\".format(self.memory[0][1]))\n\n    def sample(self, n):\n        \"\"\"Sample a batch size n of experience\"\"\"\n        if len(self.memory)<n:\n            raise IndexError('Size of memory ({}) is less than requested sample ({})'.format(len(self), n))\n        else:\n            scores = [x[1] for x in self.memory]\n            sample = np.random.choice(len(self), size=n, replace=False, p=scores/np.sum(scores))\n            sample = [self.memory[i] for i in sample]\n            smiles = [x[0] for x in sample]\n            scores = [x[1] for x in sample]\n            prior_likelihood = [x[2] for x in sample]\n        tokenized = [self.voc.tokenize(smile) for smile in smiles]\n        
encoded = [Variable(self.voc.encode(tokenized_i)) for tokenized_i in tokenized]\n        encoded = MolData.collate_fn(encoded)\n        return encoded, np.array(scores), np.array(prior_likelihood)\n\n    def initiate_from_file(self, fname, scoring_function, Prior):\n        \"\"\"Adds experience from a file with SMILES\n           Needs a scoring function and an RNN to score the sequences.\n           Using this feature means that the learning can be very biased\n           and is typically advised against.\"\"\"\n        with open(fname, 'r') as f:\n            smiles = []\n            for line in f:\n                smile = line.split()[0]\n                if Chem.MolFromSmiles(smile):\n                    smiles.append(smile)\n        scores = scoring_function(smiles)\n        tokenized = [self.voc.tokenize(smile) for smile in smiles]\n        encoded = [Variable(self.voc.encode(tokenized_i)) for tokenized_i in tokenized]\n        encoded = MolData.collate_fn(encoded)\n        prior_likelihood, _ = Prior.likelihood(encoded.long())\n        prior_likelihood = prior_likelihood.data.cpu().numpy()\n        new_experience = zip(smiles, scores, prior_likelihood)\n        self.add_experience(new_experience)\n\n    def print_memory(self, path):\n        \"\"\"Prints the memory.\"\"\"\n        print(\"\\n\" + \"*\" * 80 + \"\\n\")\n        print(\"         Best recorded SMILES: \\n\")\n        print(\"Score     Prior log P     SMILES\\n\")\n        with open(path, 'w') as f:\n            f.write(\"SMILES Score PriorLogP\\n\")\n            for i, exp in enumerate(self.memory[:100]):\n                if i < 50:\n                    print(\"{:4.2f}   {:6.2f}        {}\".format(exp[1], exp[2], exp[0]))\n                    f.write(\"{} {:4.2f} {:6.2f}\\n\".format(*exp))\n        print(\"\\n\" + \"*\" * 80 + \"\\n\")\n\n    def __len__(self):\n        return len(self.memory)\n\ndef replace_halogen(string):\n    \"\"\"Regex to replace Br and Cl with single letters\"\"\"\n    
br = re.compile('Br')\n    cl = re.compile('Cl')\n    string = br.sub('R', string)\n    string = cl.sub('L', string)\n\n    return string\n\ndef tokenize(smiles):\n    \"\"\"Takes a SMILES string and returns a list of tokens.\n    This will swap 'Cl' and 'Br' to 'L' and 'R' and treat\n    '[xx]' as one token.\"\"\"\n    regex = r'(\[[^\[\]]{1,6}\])'\n    smiles = replace_halogen(smiles)\n    char_list = re.split(regex, smiles)\n    tokenized = []\n    for char in char_list:\n        if char.startswith('['):\n            tokenized.append(char)\n        else:\n            chars = [unit for unit in char]\n            [tokenized.append(unit) for unit in chars]\n    tokenized.append('EOS')\n    return tokenized\n\ndef canonicalize_smiles_from_file(fname):\n    \"\"\"Reads a SMILES file and returns a list of RDKit SMILES\"\"\"\n    with open(fname, 'r') as f:\n        smiles_list = []\n        for i, line in enumerate(f):\n            if i % 100000 == 0:\n                print(\"{} lines processed.\".format(i))\n            smiles = line.split(\" \")[0]\n            mol = Chem.MolFromSmiles(smiles)\n            if filter_mol(mol):\n                smiles_list.append(Chem.MolToSmiles(mol))\n        print(\"{} SMILES retrieved\".format(len(smiles_list)))\n        return smiles_list\n\ndef filter_mol(mol, max_heavy_atoms=50, min_heavy_atoms=10, element_list=[6,7,8,9,16,17,35]):\n    \"\"\"Filters molecules on number of heavy atoms and atom types\"\"\"\n    if mol is not None:\n        num_heavy = min_heavy_atoms<mol.GetNumHeavyAtoms()<max_heavy_atoms\n        elements = all([atom.GetAtomicNum() in element_list for atom in mol.GetAtoms()])\n        if num_heavy and elements:\n            return True\n        else:\n            return False\n\ndef write_smiles_to_file(smiles_list, fname):\n    \"\"\"Write a list of SMILES to a file.\"\"\"\n    with open(fname, 'w') as f:\n        for smiles in smiles_list:\n            f.write(smiles + \"\\n\")\n\ndef 
filter_on_chars(smiles_list, chars):\n    \"\"\"Filters SMILES on the characters they contain.\n       Used to remove SMILES containing very rare/undesirable\n       characters.\"\"\"\n    smiles_list_valid = []\n    for smiles in smiles_list:\n        tokenized = tokenize(smiles)\n        if all([char in chars for char in tokenized][:-1]):\n            smiles_list_valid.append(smiles)\n    return smiles_list_valid\n\ndef filter_file_on_chars(smiles_fname, voc_fname):\n    \"\"\"Filters a SMILES file using a vocabulary file.\n       Only SMILES containing nothing but the characters\n       in the vocabulary will be retained.\"\"\"\n    smiles = []\n    with open(smiles_fname, 'r') as f:\n        for line in f:\n            smiles.append(line.split()[0])\n    print(smiles[:10])\n    chars = []\n    with open(voc_fname, 'r') as f:\n        for line in f:\n            chars.append(line.split()[0])\n    print(chars)\n    valid_smiles = filter_on_chars(smiles, chars)\n    with open(smiles_fname + \"_filtered\", 'w') as f:\n        for smiles in valid_smiles:\n            f.write(smiles + \"\\n\")\n\ndef combine_voc_from_files(fnames):\n    \"\"\"Combine two vocabularies\"\"\"\n    chars = set()\n    for fname in fnames:\n        with open(fname, 'r') as f:\n            for line in f:\n                chars.add(line.split()[0])\n    with open(\"_\".join(fnames) + '_combined', 'w') as f:\n        for char in chars:\n            f.write(char + \"\\n\")\n\ndef construct_vocabulary(smiles_list):\n    \"\"\"Returns all the characters present in a SMILES file.\n       Uses regex to find characters/tokens of the format '[x]'.\"\"\"\n    add_chars = set()\n    for i, smiles in enumerate(smiles_list):\n        regex = '(\\[[^\\[\\]]{1,6}\\])'\n        smiles = replace_halogen(smiles)\n        char_list = re.split(regex, smiles)\n        for char in char_list:\n            if char.startswith('['):\n                add_chars.add(char)\n            else:\n                chars = 
[unit for unit in char]\n                [add_chars.add(unit) for unit in chars]\n\n    print(\"Number of characters: {}\".format(len(add_chars)))\n    with open('data/Voc', 'w') as f:\n        for char in add_chars:\n            f.write(char + \"\\n\")\n    return add_chars\n\nif __name__ == \"__main__\":\n    smiles_file = sys.argv[1]\n    print(\"Reading smiles...\")\n    smiles_list = canonicalize_smiles_from_file(smiles_file)\n    print(\"Constructing vocabulary...\")\n    voc_chars = construct_vocabulary(smiles_list)\n    write_smiles_to_file(smiles_list, \"data/mols_filtered.smi\")\n"
  },
  {
    "path": "main.py",
    "content": "#!/usr/bin/env python\nimport argparse\nimport time\nimport os\nfrom train_agent import train_agent\n\n\nparser = argparse.ArgumentParser(description=\"Main script for running the model\")\nparser.add_argument('--scoring-function', action='store', dest='scoring_function',\n                    choices=['activity_model', 'tanimoto', 'no_sulphur'],\n                    default='tanimoto',\n                    help='What type of scoring function to use.')\nparser.add_argument('--scoring-function-kwargs', action='store', dest='scoring_function_kwargs',\n                    nargs=\"*\",\n                    help='Additional arguments for the scoring function. Should be supplied with a '\\\n                    'list of \"keyword_name argument\". For pharmacophoric and tanimoto '\\\n                    'the keyword is \"query_structure\" and requires a SMILES. ' \\\n                    'For activity_model it is \"clf_path \" '\\\n                    'pointing to a sklearn classifier. '\\\n                    'For example: \"--scoring-function-kwargs query_structure COc1ccccc1\".')\nparser.add_argument('--learning-rate', action='store', dest='learning_rate',\n                    type=float, default=0.0005)\nparser.add_argument('--num-steps', action='store', dest='n_steps', type=int,\n                    default=3000)\nparser.add_argument('--batch-size', action='store', dest='batch_size', type=int,\n                    default=64)\nparser.add_argument('--sigma', action='store', dest='sigma', type=int,\n                    default=20)\nparser.add_argument('--experience', action='store', dest='experience_replay', type=int,\n                    default=0, help='Number of experience sequences to sample each step. 
'\\\n                    '0 means no experience replay.')\nparser.add_argument('--num-processes', action='store', dest='num_processes',\n                    type=int, default=0,\n                    help='Number of processes used to run the scoring function. \"0\" means ' \\\n                    'that the scoring function will be run in the main process.')\nparser.add_argument('--prior', action='store', dest='restore_prior_from',\n                    default='data/Prior.ckpt',\n                    help='Path to an RNN checkpoint file to use as a Prior')\nparser.add_argument('--agent', action='store', dest='restore_agent_from',\n                    default='data/Prior.ckpt',\n                    help='Path to an RNN checkpoint file to use as a Agent.')\nparser.add_argument('--save-dir', action='store', dest='save_dir',\n                    help='Path where results and model are saved. Default is data/results/run_<datetime>.')\n\nif __name__ == \"__main__\":\n\n    arg_dict = vars(parser.parse_args())\n\n    if arg_dict['scoring_function_kwargs']:\n        kwarg_list = arg_dict.pop('scoring_function_kwargs')\n        if not len(kwarg_list) % 2 == 0:\n            raise ValueError(\"Scoring function kwargs must be given as pairs, \"\\\n                             \"but got a list with odd length.\")\n        kwarg_dict = {i:j for i, j in zip(kwarg_list[::2], kwarg_list[1::2])}\n        arg_dict['scoring_function_kwargs'] = kwarg_dict\n    else:\n        arg_dict['scoring_function_kwargs'] = dict()\n\n    train_agent(**arg_dict)\n"
  },
  {
    "path": "model.py",
    "content": "#!/usr/bin/env python\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom utils import Variable\n\nclass MultiGRU(nn.Module):\n    \"\"\" Implements a three layer GRU cell including an embedding layer\n       and an output linear layer back to the size of the vocabulary\"\"\"\n    def __init__(self, voc_size):\n        super(MultiGRU, self).__init__()\n        self.embedding = nn.Embedding(voc_size, 128)\n        self.gru_1 = nn.GRUCell(128, 512)\n        self.gru_2 = nn.GRUCell(512, 512)\n        self.gru_3 = nn.GRUCell(512, 512)\n        self.linear = nn.Linear(512, voc_size)\n\n    def forward(self, x, h):\n        x = self.embedding(x)\n        h_out = Variable(torch.zeros(h.size()))\n        x = h_out[0] = self.gru_1(x, h[0])\n        x = h_out[1] = self.gru_2(x, h[1])\n        x = h_out[2] = self.gru_3(x, h[2])\n        x = self.linear(x)\n        return x, h_out\n\n    def init_h(self, batch_size):\n        # Initial cell state is zero\n        return Variable(torch.zeros(3, batch_size, 512))\n\nclass RNN():\n    \"\"\"Implements the Prior and Agent RNN. Needs a Vocabulary instance in\n    order to determine size of the vocabulary and index of the END token\"\"\"\n    def __init__(self, voc):\n        self.rnn = MultiGRU(voc.vocab_size)\n        if torch.cuda.is_available():\n            self.rnn.cuda()\n        self.voc = voc\n\n    def likelihood(self, target):\n        \"\"\"\n            Retrieves the likelihood of a given sequence\n\n            Args:\n                target: (batch_size * sequence_lenght) A batch of sequences\n\n            Outputs:\n                log_probs : (batch_size) Log likelihood for each example*\n                entropy: (batch_size) The entropies for the sequences. 
Not\n                                      currently used.\n        \"\"\"\n        batch_size, seq_length = target.size()\n        start_token = Variable(torch.zeros(batch_size, 1).long())\n        start_token[:] = self.voc.vocab['GO']\n        x = torch.cat((start_token, target[:, :-1]), 1)\n        h = self.rnn.init_h(batch_size)\n\n        log_probs = Variable(torch.zeros(batch_size))\n        entropy = Variable(torch.zeros(batch_size))\n        for step in range(seq_length):\n            logits, h = self.rnn(x[:, step], h)\n            log_prob = F.log_softmax(logits)\n            prob = F.softmax(logits)\n            log_probs += NLLLoss(log_prob, target[:, step])\n            entropy += -torch.sum((log_prob * prob), 1)\n        return log_probs, entropy\n\n    def sample(self, batch_size, max_length=140):\n        \"\"\"\n            Sample a batch of sequences\n\n            Args:\n                batch_size : Number of sequences to sample \n                max_length:  Maximum length of the sequences\n\n            Outputs:\n            seqs: (batch_size, seq_length) The sampled sequences.\n            log_probs : (batch_size) Log likelihood for each sequence.\n            entropy: (batch_size) The entropies for the sequences. 
Not\n                                    currently used.\n        \"\"\"\n        start_token = Variable(torch.zeros(batch_size).long())\n        start_token[:] = self.voc.vocab['GO']\n        h = self.rnn.init_h(batch_size)\n        x = start_token\n\n        sequences = []\n        log_probs = Variable(torch.zeros(batch_size))\n        finished = torch.zeros(batch_size).byte()\n        entropy = Variable(torch.zeros(batch_size))\n        if torch.cuda.is_available():\n            finished = finished.cuda()\n\n        for step in range(max_length):\n            logits, h = self.rnn(x, h)\n            prob = F.softmax(logits)\n            log_prob = F.log_softmax(logits)\n            x = torch.multinomial(prob).view(-1)\n            sequences.append(x.view(-1, 1))\n            log_probs +=  NLLLoss(log_prob, x)\n            entropy += -torch.sum((log_prob * prob), 1)\n\n            x = Variable(x.data)\n            EOS_sampled = (x == self.voc.vocab['EOS']).data\n            finished = torch.ge(finished + EOS_sampled, 1)\n            if torch.prod(finished) == 1: break\n\n        sequences = torch.cat(sequences, 1)\n        return sequences.data, log_probs, entropy\n\ndef NLLLoss(inputs, targets):\n    \"\"\"\n        Custom Negative Log Likelihood loss that returns loss per example,\n        rather than for the entire batch.\n\n        Args:\n            inputs : (batch_size, num_classes) *Log probabilities of each class*\n            targets: (batch_size) *Target class index*\n\n        Outputs:\n            loss : (batch_size) *Loss for each example*\n    \"\"\"\n\n    if torch.cuda.is_available():\n        target_expanded = torch.zeros(inputs.size()).cuda()\n    else:\n        target_expanded = torch.zeros(inputs.size())\n\n    target_expanded.scatter_(1, targets.contiguous().view(-1, 1).data, 1.0)\n    loss = Variable(target_expanded) * inputs\n    loss = torch.sum(loss, 1)\n    return loss\n"
  },
  {
    "path": "multiprocess.py",
    "content": "#!/usr/bin/env python\n\nimport importlib\nimport sys\n\nscoring_function = sys.argv[1]\nfunc = getattr(importlib.import_module(\"scoring_functions\"), scoring_function)()\n\nwhile True:\n    smile = sys.stdin.readline().rstrip()\n    try:\n        score = float(func(smile))\n    except:\n        score = 0.0\n    sys.stdout.write(\" \".join([smile, str(score), \"\\n\"]))\n    sys.stdout.flush()\n\n\n\n"
  },
  {
    "path": "scoring_functions.py",
    "content": "#!/usr/bin/env python\nfrom __future__ import print_function, division\nimport numpy as np\nfrom rdkit import Chem\nfrom rdkit import rdBase\nfrom rdkit.Chem import AllChem\nfrom rdkit import DataStructs\nfrom sklearn import svm\nimport time\nimport pickle\nimport re\nimport threading\nimport pexpect\nrdBase.DisableLog('rdApp.error')\n\n\"\"\"Scoring function should be a class where some tasks that are shared for every call\n   can be reallocated to the __init__, and has a __call__ method which takes a single SMILES of\n   argument and returns a float. A multiprocessing class will then spawn workers and divide the\n   list of SMILES given between them.\n\n   Passing *args and **kwargs through a subprocess call is slightly tricky because we need to know\n   their types - everything will be a string once we have passed it. Therefor, we instead use class\n   attributes which we can modify in place before any subprocess is created. Any **kwarg left over in\n   the call to get_scoring_function will be checked against a list of (allowed) kwargs for the class\n   and if a match is found the value of the item will be the new value for the class.\n\n   If num_processes == 0, the scoring function will be run in the main process. 
Depending on how\n   demanding the scoring function is and how well the OS handles the multiprocessing, this might\n   be faster than multiprocessing in some cases.\"\"\"\n\nclass no_sulphur():\n    \"\"\"Scores structures based on not containing sulphur.\"\"\"\n\n    kwargs = []\n\n    def __init__(self):\n        pass\n    def __call__(self, smile):\n        mol = Chem.MolFromSmiles(smile)\n        if mol:\n            has_sulphur = any(atom.GetAtomicNum() == 16 for atom in mol.GetAtoms())\n            return float(not has_sulphur)\n        return 0.0\n\nclass tanimoto():\n    \"\"\"Scores structures based on Tanimoto similarity to a query structure.\n       Scores are only scaled up to k=(0,1), after which no more reward is given.\"\"\"\n\n    kwargs = [\"k\", \"query_structure\"]\n    k = 0.7\n    query_structure = \"Cc1ccc(cc1)c2cc(nn2c3ccc(cc3)S(=O)(=O)N)C(F)(F)F\"\n\n    def __init__(self):\n        query_mol = Chem.MolFromSmiles(self.query_structure)\n        self.query_fp = AllChem.GetMorganFingerprint(query_mol, 2, useCounts=True, useFeatures=True)\n\n    def __call__(self, smile):\n        mol = Chem.MolFromSmiles(smile)\n        if mol:\n            fp = AllChem.GetMorganFingerprint(mol, 2, useCounts=True, useFeatures=True)\n            score = DataStructs.TanimotoSimilarity(self.query_fp, fp)\n            score = min(score, self.k) / self.k\n            return float(score)\n        return 0.0\n\nclass activity_model():\n    \"\"\"Scores based on an ECFP classifier for activity.\"\"\"\n\n    kwargs = [\"clf_path\"]\n    clf_path = 'data/clf.pkl'\n\n    def __init__(self):\n        with open(self.clf_path, \"rb\") as f:\n            self.clf = pickle.load(f)\n\n    def __call__(self, smile):\n        mol = Chem.MolFromSmiles(smile)\n        if mol:\n            fp = activity_model.fingerprints_from_mol(mol)\n            score = self.clf.predict_proba(fp)[:, 1]\n            return float(score)\n        return 0.0\n\n    @classmethod\n    def 
fingerprints_from_mol(cls, mol):\n        fp = AllChem.GetMorganFingerprint(mol, 3, useCounts=True, useFeatures=True)\n        size = 2048\n        nfp = np.zeros((1, size), np.int32)\n        for idx, v in fp.GetNonzeroElements().items():\n            nidx = idx % size\n            nfp[0, nidx] += int(v)\n        return nfp\n\nclass Worker():\n    \"\"\"A worker class for the Multiprocessing functionality. Spawns a subprocess\n       that is listening for input SMILES and inserts the score into the given\n       index in the given list.\"\"\"\n    def __init__(self, scoring_function=None):\n        \"\"\"The regular expression in __call__ extracts the score from the\n           stdout of the subprocess. This means only scoring functions with range\n           0.0-1.0 will work; for other ranges the re has to be modified.\"\"\"\n\n        self.proc = pexpect.spawn('./multiprocess.py ' + scoring_function,\n                                  encoding='utf-8')\n\n        print(self.is_alive())\n\n    def __call__(self, smile, index, result_list):\n        self.proc.sendline(smile)\n        output = self.proc.expect([re.escape(smile) + \" 1\\\.0+|[0]\\\.[0-9]+\", 'None', pexpect.TIMEOUT])\n        if output == 0:\n            score = float(self.proc.after.lstrip(smile + \" \"))\n        elif output in [1, 2]:\n            score = 0.0\n        result_list[index] = score\n\n    def is_alive(self):\n        return self.proc.isalive()\n\nclass Multiprocessing():\n    \"\"\"Class for handling multiprocessing of scoring functions. OEtoolkits can't be used with\n       native multiprocessing (they can't be pickled), so instead we spawn threads that create\n       subprocesses.\"\"\"\n    def __init__(self, num_processes=None, scoring_function=None):\n        self.n = num_processes\n        self.workers = [Worker(scoring_function=scoring_function) for _ in range(num_processes)]\n\n    def alive_workers(self):\n        return [i for i, worker in enumerate(self.workers) if worker.is_alive()]\n\n    def __call__(self, smiles):\n        scores = [0 for _ in range(len(smiles))]\n        smiles_copy = [smile for smile in smiles]\n        # As long as we still have SMILES to score\n        while smiles_copy:\n            alive_procs = self.alive_workers()\n            if not alive_procs:\n                raise RuntimeError(\"All subprocesses are dead, exiting.\")\n            used_threads = []\n            # Thread names correspond to worker indices, so here\n            # we are actually checking which workers are busy\n            for t in threading.enumerate():\n                # Workers have numbers as names, while the main thread's name\n                # can't be converted to an integer\n                try:\n                    n = int(t.name)\n                    used_threads.append(n)\n                except ValueError:\n                    continue\n            free_threads = [i for i in alive_procs if i not in used_threads]\n            for n in free_threads:\n                if smiles_copy:\n                    # Send SMILES and what index in the result list the score should be inserted at\n                    smile = smiles_copy.pop()\n                    idx = len(smiles_copy)\n                    t = threading.Thread(target=self.workers[n], name=str(n), args=(smile, idx, scores))\n                    t.start()\n            time.sleep(0.01)\n        for t in threading.enumerate():\n            try:\n                n = int(t.name)\n                t.join()\n            except ValueError:\n                continue\n        return np.array(scores, dtype=np.float32)\n\nclass Singleprocessing():\n    \"\"\"Adds an option to not spawn new processes for the scoring functions, but rather\n       run them in the main process.\"\"\"\n    def __init__(self, scoring_function=None):\n        self.scoring_function = scoring_function()\n    def __call__(self, smiles):\n        scores = [self.scoring_function(smile) for smile in smiles]\n        return np.array(scores, dtype=np.float32)\n\ndef get_scoring_function(scoring_function, num_processes=None, **kwargs):\n    \"\"\"Function that initializes and returns a scoring function by name\"\"\"\n    scoring_function_classes = [no_sulphur, tanimoto, activity_model]\n    scoring_functions = [f.__name__ for f in scoring_function_classes]\n\n    # Validate the name before indexing, so an unknown name raises a clear error\n    if scoring_function not in scoring_functions:\n        raise ValueError(\"Scoring function must be one of {}\".format(scoring_functions))\n\n    scoring_function_class = [f for f in scoring_function_classes if f.__name__ == scoring_function][0]\n\n    for k, v in kwargs.items():\n        if k in scoring_function_class.kwargs:\n            setattr(scoring_function_class, k, v)\n\n    if num_processes == 0:\n        return Singleprocessing(scoring_function=scoring_function_class)\n    return Multiprocessing(scoring_function=scoring_function, num_processes=num_processes)\n"
  },
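  {
    "path": "example_scoring_function.py",
    "content": "#!/usr/bin/env python\n\"\"\"Illustrative sketch, not part of the original release: a minimal custom scoring\n   function following the contract described in scoring_functions.py - shared setup\n   in __init__, a __call__ that takes a single SMILES and returns a float in (0, 1),\n   and tunable parameters as class attributes listed in kwargs. The class name and\n   target_mw parameter are made-up examples; to use the class with\n   get_scoring_function it would also have to be added to scoring_function_classes\n   in scoring_functions.py.\"\"\"\nfrom rdkit import Chem\nfrom rdkit.Chem import Descriptors\n\nclass mol_weight_target():\n    \"\"\"Scores structures by closeness to a target molecular weight.\"\"\"\n\n    kwargs = [\"target_mw\"]\n    target_mw = 350.0\n\n    def __init__(self):\n        pass\n\n    def __call__(self, smile):\n        mol = Chem.MolFromSmiles(smile)\n        if mol:\n            mw = Descriptors.MolWt(mol)\n            # Score decays linearly with relative distance from the target weight\n            return max(0.0, 1.0 - abs(mw - self.target_mw) / self.target_mw)\n        return 0.0\n"
  },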
  {
    "path": "train_agent.py",
    "content": "#!/usr/bin/env python\n\nimport torch\nimport pickle\nimport numpy as np\nimport time\nimport os\nfrom shutil import copyfile\n\nfrom model import RNN\nfrom data_structs import Vocabulary, Experience\nfrom scoring_functions import get_scoring_function\nfrom utils import Variable, seq_to_smiles, fraction_valid_smiles, unique\nfrom vizard_logger import VizardLog\n\ndef train_agent(restore_prior_from='data/Prior.ckpt',\n                restore_agent_from='data/Prior.ckpt',\n                scoring_function='tanimoto',\n                scoring_function_kwargs=None,\n                save_dir=None, learning_rate=0.0005,\n                batch_size=64, n_steps=3000,\n                num_processes=0, sigma=60,\n                experience_replay=0):\n\n    voc = Vocabulary(init_from_file=\"data/Voc\")\n\n    start_time = time.time()\n\n    Prior = RNN(voc)\n    Agent = RNN(voc)\n\n    logger = VizardLog('data/logs')\n\n    # By default restore Agent to same model as Prior, but can restore from already trained Agent too.\n    # Saved models are partially on the GPU, but if we dont have cuda enabled we can remap these\n    # to the CPU.\n    if torch.cuda.is_available():\n        Prior.rnn.load_state_dict(torch.load('data/Prior.ckpt'))\n        Agent.rnn.load_state_dict(torch.load(restore_agent_from))\n    else:\n        Prior.rnn.load_state_dict(torch.load('data/Prior.ckpt', map_location=lambda storage, loc: storage))\n        Agent.rnn.load_state_dict(torch.load(restore_agent_from, map_location=lambda storage, loc: storage))\n\n    # We dont need gradients with respect to Prior\n    for param in Prior.rnn.parameters():\n        param.requires_grad = False\n\n    optimizer = torch.optim.Adam(Agent.rnn.parameters(), lr=0.0005)\n\n    # Scoring_function\n    scoring_function = get_scoring_function(scoring_function=scoring_function, num_processes=num_processes,\n                                            **scoring_function_kwargs)\n\n    # For policy based RL, 
we normally train on-policy and correct for the fact that more likely actions\n    # occur more often (which means the agent can get biased towards them). Using experience replay is\n    # therefor not as theoretically sound as it is for value based RL, but it seems to work well.\n    experience = Experience(voc)\n\n    # Log some network weights that can be dynamically plotted with the Vizard bokeh app\n    logger.log(Agent.rnn.gru_2.weight_ih.cpu().data.numpy()[::100], \"init_weight_GRU_layer_2_w_ih\")\n    logger.log(Agent.rnn.gru_2.weight_hh.cpu().data.numpy()[::100], \"init_weight_GRU_layer_2_w_hh\")\n    logger.log(Agent.rnn.embedding.weight.cpu().data.numpy()[::30], \"init_weight_GRU_embedding\")\n    logger.log(Agent.rnn.gru_2.bias_ih.cpu().data.numpy(), \"init_weight_GRU_layer_2_b_ih\")\n    logger.log(Agent.rnn.gru_2.bias_hh.cpu().data.numpy(), \"init_weight_GRU_layer_2_b_hh\")\n\n    # Information for the logger\n    step_score = [[], []]\n\n    print(\"Model initialized, starting training...\")\n\n    for step in range(n_steps):\n\n        # Sample from Agent\n        seqs, agent_likelihood, entropy = Agent.sample(batch_size)\n\n        # Remove duplicates, ie only consider unique seqs\n        unique_idxs = unique(seqs)\n        seqs = seqs[unique_idxs]\n        agent_likelihood = agent_likelihood[unique_idxs]\n        entropy = entropy[unique_idxs]\n\n        # Get prior likelihood and score\n        prior_likelihood, _ = Prior.likelihood(Variable(seqs))\n        smiles = seq_to_smiles(seqs, voc)\n        score = scoring_function(smiles)\n\n        # Calculate augmented likelihood\n        augmented_likelihood = prior_likelihood + sigma * Variable(score)\n        loss = torch.pow((augmented_likelihood - agent_likelihood), 2)\n\n        # Experience Replay\n        # First sample\n        if experience_replay and len(experience)>4:\n            exp_seqs, exp_score, exp_prior_likelihood = experience.sample(4)\n            exp_agent_likelihood, 
exp_entropy = Agent.likelihood(exp_seqs.long())\n            exp_augmented_likelihood = exp_prior_likelihood + sigma * exp_score\n            exp_loss = torch.pow((Variable(exp_augmented_likelihood) - exp_agent_likelihood), 2)\n            loss = torch.cat((loss, exp_loss), 0)\n            agent_likelihood = torch.cat((agent_likelihood, exp_agent_likelihood), 0)\n\n        # Then add new experience\n        prior_likelihood = prior_likelihood.data.cpu().numpy()\n        new_experience = zip(smiles, score, prior_likelihood)\n        experience.add_experience(new_experience)\n\n        # Calculate loss\n        loss = loss.mean()\n\n        # Add regularizer that penalizes high likelihood for the entire sequence\n        loss_p = - (1 / agent_likelihood).mean()\n        loss += 5 * 1e3 * loss_p\n\n        # Calculate gradients and make an update to the network weights\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        # Convert to numpy arrays so that we can print them\n        augmented_likelihood = augmented_likelihood.data.cpu().numpy()\n        agent_likelihood = agent_likelihood.data.cpu().numpy()\n\n        # Print some information for this step\n        time_elapsed = (time.time() - start_time) / 3600\n        time_left = (time_elapsed * ((n_steps - step) / (step + 1)))\n        print(\"\\n       Step {}   Fraction valid SMILES: {:4.1f}  Time elapsed: {:.2f}h Time left: {:.2f}h\".format(\n              step, fraction_valid_smiles(smiles) * 100, time_elapsed, time_left))\n        print(\"  Agent    Prior   Target   Score             SMILES\")\n        for i in range(10):\n            print(\" {:6.2f}   {:6.2f}  {:6.2f}  {:6.2f}     {}\".format(agent_likelihood[i],\n                                                                       prior_likelihood[i],\n                                                                       augmented_likelihood[i],\n                                                                      
 score[i],\n                                                                       smiles[i]))\n        # Need this for Vizard plotting\n        step_score[0].append(step + 1)\n        step_score[1].append(np.mean(score))\n\n        # Log some weights\n        logger.log(Agent.rnn.gru_2.weight_ih.cpu().data.numpy()[::100], \"weight_GRU_layer_2_w_ih\")\n        logger.log(Agent.rnn.gru_2.weight_hh.cpu().data.numpy()[::100], \"weight_GRU_layer_2_w_hh\")\n        logger.log(Agent.rnn.embedding.weight.cpu().data.numpy()[::30], \"weight_GRU_embedding\")\n        logger.log(Agent.rnn.gru_2.bias_ih.cpu().data.numpy(), \"weight_GRU_layer_2_b_ih\")\n        logger.log(Agent.rnn.gru_2.bias_hh.cpu().data.numpy(), \"weight_GRU_layer_2_b_hh\")\n        logger.log(\"\\n\".join([smiles + \"\\t\" + str(round(score, 2)) for smiles, score in zip(smiles[:12], score[:12])]), \"SMILES\", dtype=\"text\", overwrite=True)\n        logger.log(np.array(step_score), \"Scores\")\n\n    # If the entire training finishes, we create a new folder where we save this python file\n    # as well as some sampled sequences and the contents of the experience (which are the highest\n    # scored sequences seen during training)\n    if not save_dir:\n        save_dir = 'data/results/run_' + time.strftime(\"%Y-%m-%d-%H_%M_%S\", time.localtime())\n    os.makedirs(save_dir)\n    copyfile('train_agent.py', os.path.join(save_dir, \"train_agent.py\"))\n\n    experience.print_memory(os.path.join(save_dir, \"memory\"))\n    torch.save(Agent.rnn.state_dict(), os.path.join(save_dir, 'Agent.ckpt'))\n\n    seqs, agent_likelihood, entropy = Agent.sample(256)\n    prior_likelihood, _ = Prior.likelihood(Variable(seqs))\n    prior_likelihood = prior_likelihood.data.cpu().numpy()\n    smiles = seq_to_smiles(seqs, voc)\n    score = scoring_function(smiles)\n    with open(os.path.join(save_dir, \"sampled\"), 'w') as f:\n        f.write(\"SMILES Score PriorLogP\\n\")\n        for smiles, 
score, prior_likelihood in zip(smiles, score, prior_likelihood):\n            f.write(\"{} {:5.2f} {:6.2f}\\n\".format(smiles, score, prior_likelihood))\n\nif __name__ == \"__main__\":\n    train_agent()\n"
  },
  {
    "path": "train_prior.py",
    "content": "#!/usr/bin/env python\n\nimport torch\nfrom torch.utils.data import DataLoader\nimport pickle\nfrom rdkit import Chem\nfrom rdkit import rdBase\nfrom tqdm import tqdm\n\nfrom data_structs import MolData, Vocabulary\nfrom model import RNN\nfrom utils import Variable, decrease_learning_rate\nrdBase.DisableLog('rdApp.error')\n\ndef pretrain(restore_from=None):\n    \"\"\"Trains the Prior RNN\"\"\"\n\n    # Read vocabulary from a file\n    voc = Vocabulary(init_from_file=\"data/Voc\")\n\n    # Create a Dataset from a SMILES file\n    moldata = MolData(\"data/mols_filtered.smi\", voc)\n    data = DataLoader(moldata, batch_size=128, shuffle=True, drop_last=True,\n                      collate_fn=MolData.collate_fn)\n\n    Prior = RNN(voc)\n\n    # Can restore from a saved RNN\n    if restore_from:\n        Prior.rnn.load_state_dict(torch.load(restore_from))\n\n    optimizer = torch.optim.Adam(Prior.rnn.parameters(), lr = 0.001)\n    for epoch in range(1, 6):\n        # When training on a few million compounds, this model converges\n        # in a few of epochs or even faster. 
If the model size is increased\n        # it's probably a good idea to check loss against an external set of\n        # validation SMILES to make sure we don't overfit too much.\n        for step, batch in tqdm(enumerate(data), total=len(data)):\n\n            # Sample from DataLoader\n            seqs = batch.long()\n\n            # Calculate loss\n            log_p, _ = Prior.likelihood(seqs)\n            loss = - log_p.mean()\n\n            # Calculate gradients and take a step\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            # Every 500 steps we decrease learning rate and print some information\n            if step % 500 == 0 and step != 0:\n                decrease_learning_rate(optimizer, decrease_by=0.03)\n                tqdm.write(\"*\" * 50)\n                tqdm.write(\"Epoch {:3d}   step {:3d}    loss: {:5.2f}\\n\".format(epoch, step, loss.data[0]))\n                seqs, likelihood, _ = Prior.sample(128)\n                valid = 0\n                for i, seq in enumerate(seqs.cpu().numpy()):\n                    smile = voc.decode(seq)\n                    if Chem.MolFromSmiles(smile):\n                        valid += 1\n                    if i < 5:\n                        tqdm.write(smile)\n                tqdm.write(\"\\n{:>4.1f}% valid SMILES\".format(100 * valid / len(seqs)))\n                tqdm.write(\"*\" * 50 + \"\\n\")\n                torch.save(Prior.rnn.state_dict(), \"data/Prior.ckpt\")\n\n        # Save the Prior\n        torch.save(Prior.rnn.state_dict(), \"data/Prior.ckpt\")\n\nif __name__ == \"__main__\":\n    pretrain()\n"
  },
  {
    "path": "utils.py",
    "content": "import torch\nimport numpy as np\nfrom rdkit import Chem\n\ndef Variable(tensor):\n    \"\"\"Wrapper for torch.autograd.Variable that also accepts\n       numpy arrays directly and automatically assigns it to\n       the GPU. Be aware in case some operations are better\n       left to the CPU.\"\"\"\n    if isinstance(tensor, np.ndarray):\n        tensor = torch.from_numpy(tensor)\n    if torch.cuda.is_available():\n        return torch.autograd.Variable(tensor).cuda()\n    return torch.autograd.Variable(tensor)\n\ndef decrease_learning_rate(optimizer, decrease_by=0.01):\n    \"\"\"Multiplies the learning rate of the optimizer by 1 - decrease_by\"\"\"\n    for param_group in optimizer.param_groups:\n        param_group['lr'] *= (1 - decrease_by)\n\ndef seq_to_smiles(seqs, voc):\n    \"\"\"Takes an output sequence from the RNN and returns the\n       corresponding SMILES.\"\"\"\n    smiles = []\n    for seq in seqs.cpu().numpy():\n        smiles.append(voc.decode(seq))\n    return smiles\n\ndef fraction_valid_smiles(smiles):\n    \"\"\"Takes a list of SMILES and returns fraction valid.\"\"\"\n    i = 0\n    for smile in smiles:\n        if Chem.MolFromSmiles(smile):\n            i += 1\n    return i / len(smiles)\n\ndef unique(arr):\n    # Finds unique rows in arr and return their indices\n    arr = arr.cpu().numpy()\n    arr_ = np.ascontiguousarray(arr).view(np.dtype((np.void, arr.dtype.itemsize * arr.shape[1])))\n    _, idxs = np.unique(arr_, return_index=True)\n    if torch.cuda.is_available():\n        return torch.LongTensor(np.sort(idxs)).cuda()\n    return torch.LongTensor(np.sort(idxs))\n"
  },
  {
    "path": "vizard_logger.py",
    "content": "import numpy as np\nimport os\n\nclass VizardLog():\n    def __init__(self, log_dir):\n        self.log_dir = log_dir\n        if not os.path.exists(log_dir):\n            os.makedirs(log_dir)\n\n        # List of variables to log\n        self.logged_vars = []\n        # Dict of {name_of_variable : time_since_last_logged}\n        self.last_logged = {}\n        # Dict of [name_of_variable : log_every}\n        self.log_every = {}\n        self.overwrite = {}\n\n    def log(self, data, name, dtype=\"array\", log_every=1, overwrite=False):\n        if name not in self.logged_vars:\n            self.logged_vars.append(name)\n            self.last_logged[name] = 1\n            self.log_every[name] = log_every\n            if overwrite:\n                self.overwrite[name] = 'w'\n            else:\n                self.overwrite[name] = 'a'\n\n        if self.last_logged[name] == self.log_every[name]:\n            out_f = os.path.join(self.log_dir, name)\n            if dtype==\"text\":\n                with open(out_f, self.overwrite[name]) as f:\n                    f.write(data)\n            elif dtype==\"array\":\n                np.save(out_f, data)\n            elif dtype==\"hist\":\n                np.save(out_f, np.histogram(data, density=True, bins=50))\n"
  }
]