[
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2017 Alexis David Jacq\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# Pytorch-Sketch-RNN\nA pytorch implementation of https://arxiv.org/abs/1704.03477\n\nIn order to draw other things than cats, you will find more drawing data here: https://github.com/googlecreativelab/quickdraw-dataset\n\nepoch 1900:\n\n![epoch_1900](images/1900_output_.jpg)\n\nepoch 2400:\n\n![epoch_2600](images/2600_output_.jpg)\n\nepoch 3400\n\n![epoch_3400](images/3400_output_.jpg)\n\nDefault hyperparameters for training has been found here: https://github.com/tensorflow/magenta/blob/master/magenta/models/sketch_rnn/README.md\n"
  },
  {
    "path": "sketch_rnn.py",
    "content": "import numpy as np\nimport matplotlib.pyplot as plt\nimport PIL\n\nimport torch\nimport torch.nn as nn\nfrom torch import optim\nimport torch.nn.functional as F\n\nuse_cuda = torch.cuda.is_available()\n\n###################################### hyperparameters\nclass HParams():\n    def __init__(self):\n        self.data_location = 'cat.npz'\n        self.enc_hidden_size = 256\n        self.dec_hidden_size = 512\n        self.Nz = 128\n        self.M = 20\n        self.dropout = 0.9\n        self.batch_size = 100\n        self.eta_min = 0.01\n        self.R = 0.99995\n        self.KL_min = 0.2\n        self.wKL = 0.5\n        self.lr = 0.001\n        self.lr_decay = 0.9999\n        self.min_lr = 0.00001\n        self.grad_clip = 1.\n        self.temperature = 0.4\n        self.max_seq_length = 200\n\nhp = HParams()\n\n################################# load and prepare data\ndef max_size(data):\n    \"\"\"larger sequence length in the data set\"\"\"\n    sizes = [len(seq) for seq in data]\n    return max(sizes)\n\ndef purify(strokes):\n    \"\"\"removes to small or too long sequences + removes large gaps\"\"\"\n    data = []\n    for seq in strokes:\n        if seq.shape[0] <= hp.max_seq_length and seq.shape[0] > 10:\n            seq = np.minimum(seq, 1000)\n            seq = np.maximum(seq, -1000)\n            seq = np.array(seq, dtype=np.float32)\n            data.append(seq)\n    return data\n\ndef calculate_normalizing_scale_factor(strokes):\n    \"\"\"Calculate the normalizing factor explained in appendix of sketch-rnn.\"\"\"\n    data = []\n    for i in range(len(strokes)):\n        for j in range(len(strokes[i])):\n            data.append(strokes[i][j, 0])\n            data.append(strokes[i][j, 1])\n    data = np.array(data)\n    return np.std(data)\n\ndef normalize(strokes):\n    \"\"\"Normalize entire dataset (delta_x, delta_y) by the scaling factor.\"\"\"\n    data = []\n    scale_factor = calculate_normalizing_scale_factor(strokes)\n    for 
seq in strokes:\n        seq[:, 0:2] /= scale_factor\n        data.append(seq)\n    return data\n\ndataset = np.load(hp.data_location, encoding='latin1')\ndata = dataset['train']\ndata = purify(data)\ndata = normalize(data)\nNmax = max_size(data)\n\n############################## function to generate a batch:\ndef make_batch(batch_size):\n    batch_idx = np.random.choice(len(data),batch_size)\n    batch_sequences = [data[idx] for idx in batch_idx]\n    strokes = []\n    lengths = []\n    indice = 0\n    for seq in batch_sequences:\n        len_seq = len(seq[:,0])\n        new_seq = np.zeros((Nmax,5))\n        new_seq[:len_seq,:2] = seq[:,:2]\n        new_seq[:len_seq-1,2] = 1-seq[:-1,2]\n        new_seq[:len_seq,3] = seq[:,2]\n        new_seq[(len_seq-1):,4] = 1\n        new_seq[len_seq-1,2:4] = 0\n        lengths.append(len(seq[:,0]))\n        strokes.append(new_seq)\n        indice += 1\n\n    if use_cuda:\n        batch = Variable(torch.from_numpy(np.stack(strokes,1)).cuda().float())\n    else:\n        batch = Variable(torch.from_numpy(np.stack(strokes,1)).float())\n    return batch, lengths\n\n################################ adaptive lr\ndef lr_decay(optimizer):\n    \"\"\"Decay learning rate by a factor of lr_decay\"\"\"\n    for param_group in optimizer.param_groups:\n        if param_group['lr']>hp.min_lr:\n            param_group['lr'] *= hp.lr_decay\n    return optimizer\n\n################################# encoder and decoder modules\nclass EncoderRNN(nn.Module):\n    def __init__(self):\n        super(EncoderRNN, self).__init__()\n        # bidirectional lstm:\n        self.lstm = nn.LSTM(5, hp.enc_hidden_size, \\\n            dropout=hp.dropout, bidirectional=True)\n        # create mu and sigma from lstm's last output:\n        self.fc_mu = nn.Linear(2*hp.enc_hidden_size, hp.Nz)\n        self.fc_sigma = nn.Linear(2*hp.enc_hidden_size, hp.Nz)\n        # active dropout:\n        self.train()\n\n    def forward(self, inputs, batch_size, 
hidden_cell=None):\n        if hidden_cell is None:\n            # then must init with zeros\n            if use_cuda:\n                hidden = torch.zeros(2, batch_size, hp.enc_hidden_size).cuda()\n                cell = torch.zeros(2, batch_size, hp.enc_hidden_size).cuda()\n            else:\n                hidden = torch.zeros(2, batch_size, hp.enc_hidden_size)\n                cell = torch.zeros(2, batch_size, hp.enc_hidden_size)\n            hidden_cell = (hidden, cell)\n        _, (hidden,cell) = self.lstm(inputs.float(), hidden_cell)\n        # hidden is (2, batch_size, hidden_size), we want (batch_size, 2*hidden_size):\n        hidden_forward, hidden_backward = torch.split(hidden,1,0)\n        hidden_cat = torch.cat([hidden_forward.squeeze(0), hidden_backward.squeeze(0)],1)\n        # mu and sigma:\n        mu = self.fc_mu(hidden_cat)\n        sigma_hat = self.fc_sigma(hidden_cat)\n        sigma = torch.exp(sigma_hat/2.)\n        # N ~ N(0,1)\n        z_size = mu.size()\n        if use_cuda:\n            N = torch.normal(torch.zeros(z_size),torch.ones(z_size)).cuda()\n        else:\n            N = torch.normal(torch.zeros(z_size),torch.ones(z_size))\n        z = mu + sigma*N\n        # mu and sigma_hat are needed for LKL loss\n        return z, mu, sigma_hat\n\nclass DecoderRNN(nn.Module):\n    def __init__(self):\n        super(DecoderRNN, self).__init__()\n        # to init hidden and cell from z:\n        self.fc_hc = nn.Linear(hp.Nz, 2*hp.dec_hidden_size)\n        # unidirectional lstm:\n        self.lstm = nn.LSTM(hp.Nz+5, hp.dec_hidden_size, dropout=hp.dropout)\n        # create proba distribution parameters from hiddens:\n        self.fc_params = nn.Linear(hp.dec_hidden_size,6*hp.M+3)\n\n    def forward(self, inputs, z, hidden_cell=None):\n        if hidden_cell is None:\n            # then we must init from z\n            hidden,cell = torch.split(F.tanh(self.fc_hc(z)),hp.dec_hidden_size,1)\n            hidden_cell = 
(hidden.unsqueeze(0).contiguous(), cell.unsqueeze(0).contiguous())\n        outputs,(hidden,cell) = self.lstm(inputs, hidden_cell)\n        # in training we feed the lstm with the whole input in one shot\n        # and use all outputs contained in 'outputs', while in generate\n        # mode we just feed with the last generated sample:\n        if self.training:\n            y = self.fc_params(outputs.view(-1, hp.dec_hidden_size))\n        else:\n            y = self.fc_params(hidden.view(-1, hp.dec_hidden_size))\n        # separate pen and mixture params:\n        params = torch.split(y,6,1)\n        params_mixture = torch.stack(params[:-1]) # trajectory\n        params_pen = params[-1] # pen up/down\n        # identify mixture params:\n        pi,mu_x,mu_y,sigma_x,sigma_y,rho_xy = torch.split(params_mixture,1,2)\n        # preprocess params::\n        if self.training:\n            len_out = Nmax+1\n        else:\n            len_out = 1\n                                   \n        pi = F.softmax(pi.transpose(0,1).squeeze()).view(len_out,-1,hp.M)\n        sigma_x = torch.exp(sigma_x.transpose(0,1).squeeze()).view(len_out,-1,hp.M)\n        sigma_y = torch.exp(sigma_y.transpose(0,1).squeeze()).view(len_out,-1,hp.M)\n        rho_xy = torch.tanh(rho_xy.transpose(0,1).squeeze()).view(len_out,-1,hp.M)\n        mu_x = mu_x.transpose(0,1).squeeze().contiguous().view(len_out,-1,hp.M)\n        mu_y = mu_y.transpose(0,1).squeeze().contiguous().view(len_out,-1,hp.M)\n        q = F.softmax(params_pen).view(len_out,-1,3)\n        return pi,mu_x,mu_y,sigma_x,sigma_y,rho_xy,q,hidden,cell\n\nclass Model():\n    def __init__(self):\n        if use_cuda:\n            self.encoder = EncoderRNN().cuda()\n            self.decoder = DecoderRNN().cuda()\n        else:\n            self.encoder = EncoderRNN()\n            self.decoder = DecoderRNN()\n        self.encoder_optimizer = optim.Adam(self.encoder.parameters(), hp.lr)\n        self.decoder_optimizer = 
optim.Adam(self.decoder.parameters(), hp.lr)\n        self.eta_step = hp.eta_min\n\n    def make_target(self, batch, lengths):\n        if use_cuda:\n            eos = torch.stack([torch.Tensor([0,0,0,0,1])]*batch.size()[1]).cuda().unsqueeze(0)\n        else:\n            eos = torch.stack([torch.Tensor([0,0,0,0,1])]*batch.size()[1]).unsqueeze(0)\n        batch = torch.cat([batch, eos], 0)\n        mask = torch.zeros(Nmax+1, batch.size()[1])\n        for indice,length in enumerate(lengths):\n            mask[:length,indice] = 1\n        if use_cuda:\n            mask = mask.cuda()\n        dx = torch.stack([batch.data[:,:,0]]*hp.M,2)\n        dy = torch.stack([batch.data[:,:,1]]*hp.M,2)\n        p1 = batch.data[:,:,2]\n        p2 = batch.data[:,:,3]\n        p3 = batch.data[:,:,4]\n        p = torch.stack([p1,p2,p3],2)\n        return mask,dx,dy,p\n\n    def train(self, epoch):\n        self.encoder.train()\n        self.decoder.train()\n        batch, lengths = make_batch(hp.batch_size)\n        # encode:\n        z, self.mu, self.sigma = self.encoder(batch, hp.batch_size)\n        # create start of sequence:\n        if use_cuda:\n            sos = torch.stack([torch.Tensor([0,0,1,0,0])]*hp.batch_size).cuda().unsqueeze(0)\n        else:\n            sos = torch.stack([torch.Tensor([0,0,1,0,0])]*hp.batch_size).unsqueeze(0)\n        # had sos at the begining of the batch:\n        batch_init = torch.cat([sos, batch],0)\n        # expend z to be ready to concatenate with inputs:\n        z_stack = torch.stack([z]*(Nmax+1))\n        # inputs is concatenation of z and batch_inputs\n        inputs = torch.cat([batch_init, z_stack],2)\n        # decode:\n        self.pi, self.mu_x, self.mu_y, self.sigma_x, self.sigma_y, \\\n            self.rho_xy, self.q, _, _ = self.decoder(inputs, z)\n        # prepare targets:\n        mask,dx,dy,p = self.make_target(batch, lengths)\n        # prepare optimizers:\n        self.encoder_optimizer.zero_grad()\n        
self.decoder_optimizer.zero_grad()\n        # update eta for LKL:\n        self.eta_step = 1-(1-self.eta_step)*hp.R\n        # compute losses:\n        LKL = self.kullback_leibler_loss()\n        LR = self.reconstruction_loss(mask,dx,dy,p,epoch)\n        loss = LR + LKL\n        # gradient step\n        loss.backward()\n        # gradient clipping\n        nn.utils.clip_grad_norm(self.encoder.parameters(), hp.grad_clip)\n        nn.utils.clip_grad_norm(self.decoder.parameters(), hp.grad_clip)\n        # optim step\n        self.encoder_optimizer.step()\n        self.decoder_optimizer.step()\n        # some print and save:\n        if epoch%1==0:\n            print('epoch',epoch,'loss',loss.data[0],'LR',LR.data[0],'LKL',LKL.data[0])\n            self.encoder_optimizer = lr_decay(self.encoder_optimizer)\n            self.decoder_optimizer = lr_decay(self.decoder_optimizer)\n        if epoch%100==0:\n            #self.save(epoch)\n            self.conditional_generation(epoch)\n\n    def bivariate_normal_pdf(self, dx, dy):\n        z_x = ((dx-self.mu_x)/self.sigma_x)**2\n        z_y = ((dy-self.mu_y)/self.sigma_y)**2\n        z_xy = (dx-self.mu_x)*(dy-self.mu_y)/(self.sigma_x*self.sigma_y)\n        z = z_x + z_y -2*self.rho_xy*z_xy\n        exp = torch.exp(-z/(2*(1-self.rho_xy**2)))\n        norm = 2*np.pi*self.sigma_x*self.sigma_y*torch.sqrt(1-self.rho_xy**2)\n        return exp/norm\n\n    def reconstruction_loss(self, mask, dx, dy, p, epoch):\n        pdf = self.bivariate_normal_pdf(dx, dy)\n        LS = -torch.sum(mask*torch.log(1e-5+torch.sum(self.pi * pdf, 2)))\\\n            /float(Nmax*hp.batch_size)\n        LP = -torch.sum(p*torch.log(self.q))/float(Nmax*hp.batch_size)\n        return LS+LP\n\n    def kullback_leibler_loss(self):\n        LKL = -0.5*torch.sum(1+self.sigma-self.mu**2-torch.exp(self.sigma))\\\n            /float(hp.Nz*hp.batch_size)\n        if use_cuda:\n            KL_min = Variable(torch.Tensor([hp.KL_min]).cuda()).detach()\n        else:\n     
       KL_min = Variable(torch.Tensor([hp.KL_min])).detach()\n        return hp.wKL*self.eta_step * torch.max(LKL,KL_min)\n\n    def save(self, epoch):\n        sel = np.random.rand()\n        torch.save(self.encoder.state_dict(), \\\n            'encoderRNN_sel_%3f_epoch_%d.pth' % (sel,epoch))\n        torch.save(self.decoder.state_dict(), \\\n            'decoderRNN_sel_%3f_epoch_%d.pth' % (sel,epoch))\n\n    def load(self, encoder_name, decoder_name):\n        saved_encoder = torch.load(encoder_name)\n        saved_decoder = torch.load(decoder_name)\n        self.encoder.load_state_dict(saved_encoder)\n        self.decoder.load_state_dict(saved_decoder)\n\n    def conditional_generation(self, epoch):\n        batch,lengths = make_batch(1)\n        # should remove dropouts:\n        self.encoder.train(False)\n        self.decoder.train(False)\n        # encode:\n        z, _, _ = self.encoder(batch, 1)\n        if use_cuda:\n            sos = Variable(torch.Tensor([0,0,1,0,0]).view(1,1,-1).cuda())\n        else:\n            sos = Variable(torch.Tensor([0,0,1,0,0]).view(1,1,-1))\n        s = sos\n        seq_x = []\n        seq_y = []\n        seq_z = []\n        hidden_cell = None\n        for i in range(Nmax):\n            input = torch.cat([s,z.unsqueeze(0)],2)\n            # decode:\n            self.pi, self.mu_x, self.mu_y, self.sigma_x, self.sigma_y, \\\n                self.rho_xy, self.q, hidden, cell = \\\n                    self.decoder(input, z, hidden_cell)\n            hidden_cell = (hidden, cell)\n            # sample from parameters:\n            s, dx, dy, pen_down, eos = self.sample_next_state()\n            #------\n            seq_x.append(dx)\n            seq_y.append(dy)\n            seq_z.append(pen_down)\n            if eos:\n                print(i)\n                break\n        # visualize result:\n        x_sample = np.cumsum(seq_x, 0)\n        y_sample = np.cumsum(seq_y, 0)\n        z_sample = np.array(seq_z)\n        sequence = 
np.stack([x_sample,y_sample,z_sample]).T\n        make_image(sequence, epoch)\n\n    def sample_next_state(self):\n\n        def adjust_temp(pi_pdf):\n            pi_pdf = np.log(pi_pdf)/hp.temperature\n            pi_pdf -= pi_pdf.max()\n            pi_pdf = np.exp(pi_pdf)\n            pi_pdf /= pi_pdf.sum()\n            return pi_pdf\n\n        # get mixture indice:\n        pi = self.pi.data[0,0,:].cpu().numpy()\n        pi = adjust_temp(pi)\n        pi_idx = np.random.choice(hp.M, p=pi)\n        # get pen state:\n        q = self.q.data[0,0,:].cpu().numpy()\n        q = adjust_temp(q)\n        q_idx = np.random.choice(3, p=q)\n        # get mixture params:\n        mu_x = self.mu_x.data[0,0,pi_idx]\n        mu_y = self.mu_y.data[0,0,pi_idx]\n        sigma_x = self.sigma_x.data[0,0,pi_idx]\n        sigma_y = self.sigma_y.data[0,0,pi_idx]\n        rho_xy = self.rho_xy.data[0,0,pi_idx]\n        x,y = sample_bivariate_normal(mu_x,mu_y,sigma_x,sigma_y,rho_xy,greedy=False)\n        next_state = torch.zeros(5)\n        next_state[0] = x\n        next_state[1] = y\n        next_state[q_idx+2] = 1\n        if use_cuda:\n            return Variable(next_state.cuda()).view(1,1,-1),x,y,q_idx==1,q_idx==2\n        else:\n            return Variable(next_state).view(1,1,-1),x,y,q_idx==1,q_idx==2\n\ndef sample_bivariate_normal(mu_x,mu_y,sigma_x,sigma_y,rho_xy, greedy=False):\n    # inputs must be floats\n    if greedy:\n      return mu_x,mu_y\n    mean = [mu_x, mu_y]\n    sigma_x *= np.sqrt(hp.temperature)\n    sigma_y *= np.sqrt(hp.temperature)\n    cov = [[sigma_x * sigma_x, rho_xy * sigma_x * sigma_y],\\\n        [rho_xy * sigma_x * sigma_y, sigma_y * sigma_y]]\n    x = np.random.multivariate_normal(mean, cov, 1)\n    return x[0][0], x[0][1]\n\ndef make_image(sequence, epoch, name='_output_'):\n    \"\"\"plot drawing with separated strokes\"\"\"\n    strokes = np.split(sequence, np.where(sequence[:,2]>0)[0]+1)\n    fig = plt.figure()\n    ax1 = fig.add_subplot(111)\n    for 
s in strokes:\n        plt.plot(s[:,0],-s[:,1])\n    canvas = plt.get_current_fig_manager().canvas\n    canvas.draw()\n    pil_image = PIL.Image.frombytes('RGB', canvas.get_width_height(),\n                 canvas.tostring_rgb())\n    name = str(epoch)+name+'.jpg'\n    pil_image.save(name,\"JPEG\")\n    plt.close(\"all\")\n\nif __name__==\"__main__\":\n    model = Model()\n    for epoch in range(50001):\n        model.train(epoch)\n\n    '''\n    model.load('encoder.pth','decoder.pth')\n    model.conditional_generation(0)\n    #'''\n\n"
  }
]