[
  {
    "path": ".gitignore",
    "content": "*.pyc\n*.log\nckpt\n/data*/*\n/model*/*\n/ckpt*/*\n/result*/*\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2021 Piji Li\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# SongNet\nSongNet: SongCi + Song (Lyrics) + Sonnet + etc.\n\n```\n@inproceedings{li-etal-2020-rigid,\n    title = \"Rigid Formats Controlled Text Generation\",\n    author = \"Li, Piji and Zhang, Haisong and Liu, Xiaojiang and Shi, Shuming\",\n    booktitle = \"Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics\",\n    month = jul,\n    year = \"2020\",\n    address = \"Online\",\n    publisher = \"Association for Computational Linguistics\",\n    url = \"https://www.aclweb.org/anthology/2020.acl-main.68\",\n    doi = \"10.18653/v1/2020.acl-main.68\",\n    pages = \"742--751\"\n}\n```\n\n### Run\n- python prepare_data.py\n- ./train.sh \n\n### Evaluation\n- Modify test.py: m_path = the best dev model\n- ./test.sh\n- python metrics.py\n\n### Polish\n- ./polish.sh\n\n### Download\n- The pretrained Chinese Language Model: https://drive.google.com/file/d/1g2tGyUwPe86vPn2nub1vkQva5lwtZ6Rd/view \n- The finetuned SongCi model: https://drive.google.com/file/d/16A2AzuU7slf7xj2QdLcBAorUCCaCk650/view\n\n#### Reference\n\n- Guyu: https://github.com/lipiji/Guyu\n- Pretraining：https://github.com/lipiji/big_tpl_zh_10_base\n"
  },
  {
    "path": "adam.py",
    "content": "# coding=utf-8\nimport torch\nfrom torch.optim import Optimizer\n\nclass AdamWeightDecayOptimizer(Optimizer):\n    \"\"\"A basic Adam optimizer that includes \"correct\" L2 weight decay.\n    https://github.com/google-research/bert/blob/master/optimization.py\n    https://raw.githubusercontent.com/pytorch/pytorch/v1.0.0/torch/optim/adam.py\"\"\"\n    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,\n                 weight_decay=0, amsgrad=False):\n        if not 0.0 <= lr:\n            raise ValueError(\"Invalid learning rate: {}\".format(lr))\n        if not 0.0 <= eps:\n            raise ValueError(\"Invalid epsilon value: {}\".format(eps))\n        if not 0.0 <= betas[0] < 1.0:\n            raise ValueError(\"Invalid beta parameter at index 0: {}\".format(betas[0]))\n        if not 0.0 <= betas[1] < 1.0:\n            raise ValueError(\"Invalid beta parameter at index 1: {}\".format(betas[1]))\n        defaults = dict(lr=lr, betas=betas, eps=eps,\n                        weight_decay=weight_decay, amsgrad=amsgrad)\n        super(AdamWeightDecayOptimizer, self).__init__(params, defaults)\n\n    def __setstate__(self, state):\n        super(AdamWeightDecayOptimizer, self).__setstate__(state)\n        for group in self.param_groups:\n            group.setdefault('amsgrad', False)\n\n    def step(self, closure=None):\n        \"\"\"Performs a single optimization step.\n\n        Arguments:\n            closure (callable, optional): A closure that reevaluates the model\n                and returns the loss.\n        \"\"\"\n        loss = None\n        if closure is not None:\n            loss = closure()\n\n        for group in self.param_groups:\n            for p in group['params']:\n                if p.grad is None:\n                    continue\n                grad = p.grad.data\n                if grad.is_sparse:\n                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam 
instead')\n                amsgrad = group['amsgrad']\n\n                state = self.state[p]\n\n                # State initialization\n                if len(state) == 0:\n                    state['step'] = 0\n                    # Exponential moving average of gradient values\n                    state['exp_avg'] = torch.zeros_like(p.data)\n                    # Exponential moving average of squared gradient values\n                    state['exp_avg_sq'] = torch.zeros_like(p.data)\n                    if amsgrad:\n                        # Maintains max of all exp. moving avg. of sq. grad. values\n                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)\n\n                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']\n                if amsgrad:\n                    max_exp_avg_sq = state['max_exp_avg_sq']\n                beta1, beta2 = group['betas']\n\n                state['step'] += 1\n\n                # Decay the first and second moment running average coefficient\n                exp_avg.mul_(beta1).add_(1 - beta1, grad)\n                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)\n                if amsgrad:\n                    # Maintains the maximum of all 2nd moment running avg. till now\n                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)\n                    # Use the max. for normalizing running avg. 
of gradient\n                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])\n                else:\n                    denom = exp_avg_sq.sqrt().add_(group['eps'])\n\n                # Just adding the square of the weights to the loss function is *not*\n                # the correct way of using L2 regularization/weight decay with Adam,\n                # since that will interact with the m and v parameters in strange ways.\n                #\n                # Instead we want to decay the weights in a manner that doesn't interact\n                # with the m/v parameters. This is equivalent to adding the square\n                # of the weights to the loss with plain (non-momentum) SGD.\n                update = (exp_avg/denom).add_(group['weight_decay'], p.data)\n                p.data.add_(-group['lr'], update)\n        return loss"
  },
  {
    "path": "biglm.py",
    "content": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\n\nfrom utils import gelu, LayerNorm\nfrom transformer import TransformerLayer, Embedding, LearnedPositionalEmbedding, SelfAttentionMask\nfrom label_smoothing import LabelSmoothing \n\nclass BIGLM(nn.Module):\n    def __init__(self, local_rank, vocab, embed_dim, ff_embed_dim, num_heads, dropout, layers, smoothing_factor, approx=None):\n        super(BIGLM, self).__init__()\n        self.vocab = vocab\n        self.embed_dim = embed_dim\n\n        self.tok_embed = Embedding(self.vocab.size, embed_dim, self.vocab.padding_idx)\n        self.pos_embed = LearnedPositionalEmbedding(embed_dim, device=local_rank)\n        \n        self.layers = nn.ModuleList()\n        for i in range(layers):\n            self.layers.append(TransformerLayer(embed_dim, ff_embed_dim, num_heads, dropout, with_external=True))\n        self.emb_layer_norm = LayerNorm(embed_dim)\n        self.one_more = nn.Linear(embed_dim, embed_dim)\n        self.one_more_layer_norm = LayerNorm(embed_dim)\n        self.out_proj = nn.Linear(embed_dim, self.vocab.size)\n        \n        self.attn_mask = SelfAttentionMask(device=local_rank)\n        self.smoothing = LabelSmoothing(local_rank, self.vocab.size, self.vocab.padding_idx, smoothing_factor)\n       \n        self.dropout = dropout\n        self.device = local_rank\n\n        self.approx = approx\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        nn.init.constant_(self.one_more.bias, 0.)\n        nn.init.normal_(self.one_more.weight, std=0.02)\n        nn.init.constant_(self.out_proj.bias, 0.)\n        nn.init.normal_(self.out_proj.weight, std=0.02)\n    \n    def label_smotthing_loss(self, y_pred, y, y_mask, avg=True):\n        seq_len, bsz = y.size()\n\n        y_pred = torch.log(y_pred.clamp(min=1e-8))\n        loss = self.smoothing(y_pred.view(seq_len * bsz, -1), y.view(seq_len * bsz, -1))\n        if avg:\n            return loss / 
torch.sum(y_mask)\n        else:\n            return loss / bsz\n\n    def nll_loss(self, y_pred, y, y_mask, avg=True):\n        cost = -torch.log(torch.gather(y_pred, 2, y.view(y.size(0), y.size(1), 1)))\n        cost = cost.view(y.shape)\n        y_mask = y_mask.view(y.shape)\n        if avg:\n            cost = torch.sum(cost * y_mask, 0) / torch.sum(y_mask, 0)\n        else:\n            cost = torch.sum(cost * y_mask, 0)\n        cost = cost.view((y.size(1), -1))\n        ppl = 2 ** cost\n        return cost.sum().item(), ppl.sum().item()\n\n    \n    def work_incremental(self, enc, src_padding_mask, ys_inp, ys_tpl, ys_seg, ys_pos, incremental_state=None):\n        seq_len, bsz = ys_inp.size()\n        x = self.tok_embed(ys_inp) + self.pos_embed(ys_inp) + self.tok_embed(ys_tpl) + self.tok_embed(ys_seg) + self.tok_embed(ys_pos)\n        x = self.emb_layer_norm(x)\n        padding_mask = torch.eq(ys_inp, self.vocab.padding_idx)\n        if not padding_mask.any():\n            padding_mask = None\n\n        if incremental_state is None:\n            self_attn_mask = self.attn_mask(seq_len)\n            incremental_state = {}\n        else:\n            x = x[-1, :, :].unsqueeze(0)\n            self_attn_mask = None\n\n        for layer in self.layers:\n            x, _ ,_ = layer.work_incremental(x, self_padding_mask=padding_mask, \\\n                                             self_attn_mask=self_attn_mask, \\\n                                             external_memories = enc, \\\n                                             external_padding_mask = src_padding_mask, \\\n                                             incremental_state = incremental_state)\n\n        x = self.one_more_layer_norm(gelu(self.one_more(x)))\n        probs = torch.softmax(self.out_proj(x), -1)\n\n        _, pred_y = probs.max(-1)\n        return probs, pred_y, incremental_state\n \n    def work(self, enc, src_padding_mask, ys_inp, ys_tpl, ys_seg, ys_pos):\n        seq_len, bsz = 
ys_inp.size()\n        self_attn_mask = self.attn_mask(seq_len)\n        x = self.tok_embed(ys_inp) + self.pos_embed(ys_inp) + self.tok_embed(ys_tpl) + self.tok_embed(ys_seg) + self.tok_embed(ys_pos)\n        x = self.emb_layer_norm(x)\n        x = F.dropout(x, p=self.dropout, training=self.training)\n        padding_mask = torch.eq(ys_inp, self.vocab.padding_idx)\n        if not padding_mask.any():\n            padding_mask = None\n        for layer in self.layers:\n            x, _ ,_ = layer(x, self_padding_mask=padding_mask, \\\n                               self_attn_mask = self_attn_mask, \\\n                               external_memories = enc, \\\n                               external_padding_mask = src_padding_mask,)\n\n        x = self.one_more_layer_norm(gelu(self.one_more(x)))\n        probs = torch.softmax(self.out_proj(x), -1)\n        \n        _, pred_y = probs.max(-1)\n        \n        return probs, pred_y\n    \n    def encode(self, xs_tpl, xs_seg, xs_pos):\n        padding_mask = torch.eq(xs_tpl, self.vocab.padding_idx)\n        x = self.tok_embed(xs_tpl)  + self.tok_embed(xs_seg) + self.tok_embed(xs_pos)\n        x = self.emb_layer_norm(x)\n        return x, padding_mask\n    \n    def ppl(self, xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk):\n        enc, src_padding_mask = self.encode(xs_tpl, xs_seg, xs_pos)\n        seq_len, bsz = ys_inp.size()\n        self_attn_mask = self.attn_mask(seq_len)\n        x = self.tok_embed(ys_inp) + self.pos_embed(ys_inp) + self.tok_embed(ys_tpl) + self.tok_embed(ys_seg) + self.tok_embed(ys_pos)\n        x = self.emb_layer_norm(x)\n        x = F.dropout(x, p=self.dropout, training=self.training)\n        padding_mask = torch.eq(ys_truth, self.vocab.padding_idx)\n        if not padding_mask.any():\n            padding_mask = None\n        for layer in self.layers:\n            x, _ ,_ = layer(x, self_padding_mask=padding_mask, \\\n                               self_attn_mask = 
self_attn_mask, \\\n                               external_memories = enc, \\\n                               external_padding_mask = src_padding_mask,)\n\n        x = self.one_more_layer_norm(gelu(self.one_more(x)))\n        pred = torch.softmax(self.out_proj(x), -1)\n        nll, ppl = self.nll_loss(pred, ys_truth, msk)\n        return nll, ppl, bsz\n    \n    def forward(self, xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk):\n        enc, src_padding_mask = self.encode(xs_tpl, xs_seg, xs_pos)\n        seq_len, bsz = ys_inp.size()\n        self_attn_mask = self.attn_mask(seq_len)\n        x = self.tok_embed(ys_inp) + self.pos_embed(ys_inp) + self.tok_embed(ys_tpl) + self.tok_embed(ys_seg) + self.tok_embed(ys_pos)\n        x = self.emb_layer_norm(x)\n        x = F.dropout(x, p=self.dropout, training=self.training)\n        padding_mask = torch.eq(ys_truth, self.vocab.padding_idx)\n        if not padding_mask.any():\n            padding_mask = None\n        for layer in self.layers:\n            x, _ ,_ = layer(x, self_padding_mask=padding_mask, \\\n                               self_attn_mask = self_attn_mask, \\\n                               external_memories = enc, \\\n                               external_padding_mask = src_padding_mask,)\n\n        x = self.one_more_layer_norm(gelu(self.one_more(x)))\n        pred = torch.softmax(self.out_proj(x), -1)\n\n        loss = self.label_smotthing_loss(pred, ys_truth, msk)\n        \n        _, pred_y = pred.max(-1)\n        tot_tokens = msk.float().sum().item()\n        acc = (torch.eq(pred_y, ys_truth).float() * msk).sum().item()\n       \n        nll, ppl = self.nll_loss(pred, ys_truth, msk)\n        return (pred_y, ys_truth), loss, acc, nll, ppl, tot_tokens, bsz\n"
  },
  {
    "path": "data.py",
    "content": "import random\nimport torch\nimport numpy as np\n\nPAD, UNK, BOS, EOS = '<pad>', '<unk>', '<bos>', '<eos>'\nBOC, EOC = '<boc>', '<eoc>'\nLS, RS, SP = '<s>', '</s>', ' '\nCS = ['<c-1>'] + ['<c' + str(i) + '>' for i in range(32)] # content\nSS = ['<s-1>'] + ['<s' + str(i) + '>' for i in range(512)] # segment\nPS = ['<p-1>'] + ['<p' + str(i) + '>' for i in range(512)] # position\nTS = ['<t-1>'] + ['<t' + str(i) + '>' for i in range(32)] # other types\n\nPUNCS = set([\",\", \".\", \"?\", \"!\", \":\", \"，\", \"。\", \"？\", \"！\", \"：\"])\n\nBUFSIZE = 4096000\n\ndef ListsToTensor(xs, vocab=None):\n    max_len = max(len(x) for x in xs)\n    ys = []\n    for x in xs:\n        if vocab is not None:\n            y = vocab.token2idx(x) + [vocab.padding_idx]*(max_len - len(x))\n        else:\n            y = x + [0]*(max_len - len(x))\n        ys.append(y)\n    return ys\n\ndef _back_to_text_for_check(x, vocab):\n    w = x.t().tolist()\n    for sent in vocab.idx2token(w):\n        print (' '.join(sent))\n    \ndef batchify(data, vocab):\n    xs_tpl, xs_seg, xs_pos, \\\n    ys_truth, ys_inp, \\\n    ys_tpl, ys_seg, ys_pos, msk = [], [], [], [], [], [], [], [], []\n    for xs_tpl_i, xs_seg_i, xs_pos_i, ys_i, ys_tpl_i, ys_seg_i, ys_pos_i in data:\n        xs_tpl.append(xs_tpl_i)\n        xs_seg.append(xs_seg_i)\n        xs_pos.append(xs_pos_i)\n        \n        ys_truth.append(ys_i)\n        ys_inp.append([BOS] + ys_i[:-1])\n        ys_tpl.append(ys_tpl_i)\n        ys_seg.append(ys_seg_i)\n        ys_pos.append(ys_pos_i)\n        \n        msk.append([1 for i in range(len(ys_i))])\n\n    xs_tpl = torch.LongTensor(ListsToTensor(xs_tpl, vocab)).t_().contiguous()\n    xs_seg = torch.LongTensor(ListsToTensor(xs_seg, vocab)).t_().contiguous()\n    xs_pos = torch.LongTensor(ListsToTensor(xs_pos, vocab)).t_().contiguous()\n    ys_truth = torch.LongTensor(ListsToTensor(ys_truth, vocab)).t_().contiguous()\n    ys_inp = torch.LongTensor(ListsToTensor(ys_inp, 
vocab)).t_().contiguous()\n    ys_tpl = torch.LongTensor(ListsToTensor(ys_tpl, vocab)).t_().contiguous()\n    ys_seg = torch.LongTensor(ListsToTensor(ys_seg, vocab)).t_().contiguous()\n    ys_pos = torch.LongTensor(ListsToTensor(ys_pos, vocab)).t_().contiguous()\n    msk = torch.FloatTensor(ListsToTensor(msk)).t_().contiguous()\n    return xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk\n\ndef s2t(strs, vocab):\n    inp, msk = [], []\n    for x in strs:\n        inp.append(x)\n        msk.append([1 for i in range(len(x))])\n\n    inp = torch.LongTensor(ListsToTensor(inp, vocab)).t_().contiguous()\n    msk = torch.FloatTensor(ListsToTensor(msk)).t_().contiguous()\n    return inp, msk\n\ndef s2xy(lines, vocab, max_len, min_len):\n    data = []\n    for line in lines:\n        res = parse_line(line, max_len, min_len)\n        if not res:\n            continue\n        data.append(res)\n    return  batchify(data, vocab)\n\ndef parse_line(line, max_len, min_len):\n    line = line.strip()\n    if not line:\n        return None\n    fs = line.split(\"<s2>\")\n    author, cipai = fs[0].split(\"<s1>\")\n    sents = fs[1].strip()\n    if len(sents) > max_len:\n        sents = sents[:max_len]\n    if len(sents) < min_len:\n        return None\n    sents = sents.split(\"</s>\")\n\n    ys = []\n    xs_tpl = []\n    xs_seg = []\n    xs_pos = []\n\n    ctx = cipai\n    ws = [w for w in ctx]\n    xs_tpl = ws + [EOC]\n    xs_seg = [SS[0] for w in ws] + [EOC]\n    xs_pos = [SS[i+300] for i in range(len(ws))] + [EOC]\n\n    ys_tpl = []\n    ys_seg = []\n    ys_pos = []\n    for si, sent in enumerate(sents):\n        ws = []\n        sent = sent.strip()\n        if not sent:\n            continue\n        for w in sent:\n            ws.append(w)\n            if w.strip() and w not in PUNCS:\n                ys_tpl.append(CS[2])\n            else:\n                ys_tpl.append(CS[1])\n        ys += ws + [RS]\n        if ws[-1] in PUNCS:\n            ys_tpl[-2] = 
CS[3]\n        else:\n            ys_tpl[-1] = CS[3]\n        ys_tpl += [RS]\n        ys_seg += [SS[si + 1] for w in ws] + [RS]\n        ys_pos += [PS[len(ws) - i] for i in range(len(ws))] + [RS]\n\n    ys += [EOS]\n    ys_tpl += [EOS]\n    ys_seg += [EOS]\n    ys_pos += [EOS]\n\n    xs_tpl += ys_tpl\n    xs_seg += ys_seg\n    xs_pos += ys_pos\n    \n    if len(ys) < min_len:\n        return None\n    return xs_tpl, xs_seg, xs_pos, ys, ys_tpl, ys_seg, ys_pos\n\ndef s2xy_polish(lines, vocab, max_len, min_len):\n    data = []\n    for line in lines:\n        res = parse_line_polish(line, max_len, min_len)\n        data.append(res)\n    return  batchify(data, vocab)\n\ndef parse_line_polish(line, max_len, min_len):\n    line = line.strip()\n    if not line:\n        return None\n    fs = line.split(\"<s2>\")\n    author, cipai = fs[0].split(\"<s1>\")\n    sents = fs[1].strip()\n    if len(sents) > max_len:\n        sents = sents[:max_len]\n    if len(sents) < min_len:\n        return None\n    sents = sents.split(\"</s>\")\n\n    ys = []\n    xs_tpl = []\n    xs_seg = []\n    xs_pos = []\n\n    ctx = cipai\n    ws = [w for w in ctx]\n    xs_tpl = ws + [EOC]\n    xs_seg = [SS[0] for w in ws] + [EOC]\n    xs_pos = [SS[i+300] for i in range(len(ws))] + [EOC]\n\n    ys_tpl = []\n    ys_seg = []\n    ys_pos = []\n    for si, sent in enumerate(sents):\n        ws = []\n        sent = sent.strip()\n        if not sent:\n            continue\n        for w in sent:\n            ws.append(w)\n            if w == \"_\":\n                ys_tpl.append(CS[2])\n            else:\n                ys_tpl.append(w)\n        ys += ws + [RS]\n        ys_tpl += [RS]\n        ys_seg += [SS[si + 1] for w in ws] + [RS]\n        ys_pos += [PS[len(ws) - i] for i in range(len(ws))] + [RS]\n\n    ys += [EOS]\n    ys_tpl += [EOS]\n    ys_seg += [EOS]\n    ys_pos += [EOS]\n\n    xs_tpl += ys_tpl\n    xs_seg += ys_seg\n    xs_pos += ys_pos\n    \n    if len(ys) < min_len:\n        return None\n\n 
   return xs_tpl, xs_seg, xs_pos, ys, ys_tpl, ys_seg, ys_pos\n\nclass DataLoader(object):\n    def __init__(self, vocab, filename, batch_size, max_len_y, min_len_y):\n        self.batch_size = batch_size\n        self.vocab = vocab\n        self.max_len_y = max_len_y\n        self.min_len_y = min_len_y\n        self.filename = filename\n        self.stream = open(self.filename, encoding='utf8')\n        self.epoch_id = 0\n\n    def __iter__(self):\n        \n        lines = self.stream.readlines(BUFSIZE)\n\n        if not lines:\n            self.epoch_id += 1\n            self.stream.close()\n            self.stream = open(self.filename, encoding='utf8')\n            lines = self.stream.readlines(BUFSIZE)\n\n        data = []\n        for line in lines[:-1]: # the last sent may be imcomplete\n            res = parse_line(line, self.max_len_y, self.min_len_y)\n            if not res:\n                continue\n            data.append(res)\n        \n        random.shuffle(data)\n        \n        idx = 0\n        while idx < len(data):\n            yield batchify(data[idx:idx+self.batch_size], self.vocab)\n            idx += self.batch_size\n\nclass Vocab(object):\n    def __init__(self, filename, min_occur_cnt, specials = None):\n        idx2token = [PAD, UNK, BOS, EOS] + [BOC, EOC, LS, RS, SP] + CS + SS + PS + TS \\\n                    +  (specials if specials is not None else [])\n        for line in open(filename, encoding='utf8').readlines():\n            try: \n                token, cnt = line.strip().split()\n            except:\n                continue\n            if int(cnt) >= min_occur_cnt:\n                idx2token.append(token)\n        self._token2idx = dict(zip(idx2token, range(len(idx2token))))\n        self._idx2token = idx2token\n        self._padding_idx = self._token2idx[PAD]\n        self._unk_idx = self._token2idx[UNK]\n\n    @property\n    def size(self):\n        return len(self._idx2token)\n    \n    @property\n    def unk_idx(self):\n 
       return self._unk_idx\n    \n    @property\n    def padding_idx(self):\n        return self._padding_idx\n    \n    def random_token(self):\n        return self.idx2token(1 + np.random.randint(self.size-1))\n\n    def idx2token(self, x):\n        if isinstance(x, list):\n            return [self.idx2token(i) for i in x]\n        return self._idx2token[x]\n\n    def token2idx(self, x):\n        if isinstance(x, list):\n            return [self.token2idx(i) for i in x]\n        return self._token2idx.get(x, self.unk_idx)\n"
  },
  {
    "path": "eval.py",
    "content": "import sys\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\nimport random\nimport numpy as np\nimport copy \nimport time\n\nfrom biglm import BIGLM\nfrom data import Vocab, DataLoader, s2t, s2xy\n\ngpu = int(sys.argv[2]) if len(sys.argv) > 2 else 0\ndef init_model(m_path, device, vocab):\n    ckpt= torch.load(m_path, map_location='cpu')\n    lm_args = ckpt['args']\n    lm_vocab = Vocab(vocab, min_occur_cnt=lm_args.min_occur_cnt, specials=[])\n    lm_model = BIGLM(device, lm_vocab, lm_args.embed_dim, lm_args.ff_embed_dim, lm_args.num_heads, lm_args.dropout, lm_args.layers, 0.1, lm_args.approx)\n    lm_model.load_state_dict(ckpt['model'])\n    lm_model = lm_model.cuda(device)\n    lm_model.eval()\n    return lm_model, lm_vocab, lm_args\n\n#m_path = \"./ckpt_d101_6/epoch5_batch_139999\"\nm_path = sys.argv[1] if len(sys.argv) > 1 else None\nlm_model, lm_vocab, lm_args = init_model(m_path, gpu, \"./data/vocab.txt\")\n\n\nds = []\nwith open(\"./data/dev.txt\", \"r\") as f:\n    for line in f:\n        line = line.strip()\n        if line:\n            ds.append(line)\nprint(len(ds))\n\nlocal_rank = gpu\nbatch_size = 10\nbatches = round(len(ds) / batch_size)\nidx = 0\n\navg_nll = 0.\navg_ppl = 0.\ncount = 0.\nwhile idx < len(ds):\n    \n    cplb = ds[idx:idx + batch_size]\n    xs_tpl, xs_seg, xs_pos, \\\n    ys_truth, ys_inp, \\\n    ys_tpl, ys_seg, ys_pos, msk = s2xy(cplb, lm_vocab, lm_args.max_len, 2)\n\n    xs_tpl = xs_tpl.cuda(local_rank)\n    xs_seg = xs_seg.cuda(local_rank)\n    xs_pos = xs_pos.cuda(local_rank)\n    ys_truth = ys_truth.cuda(local_rank)\n    ys_inp = ys_inp.cuda(local_rank)\n    ys_tpl = ys_tpl.cuda(local_rank)\n    ys_seg = ys_seg.cuda(local_rank)\n    ys_pos = ys_pos.cuda(local_rank)\n    msk = msk.cuda(local_rank)\n\n    nll, ppl, bsz = lm_model.ppl(xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk)\n    \n    avg_nll += nll\n    avg_ppl += ppl\n    count += bsz\n\n    idx += batch_size\n
    if count % 200 == 0:\n        print(\"nll=\", avg_nll/count, \"ppl=\", avg_ppl/count, \"count=\", count)\n    \nprint(\"nll=\", avg_nll/count, \"ppl=\", avg_ppl/count, \"count=\", count)\n"
  },
  {
    "path": "eval.sh",
    "content": "#!/bin/bash\npath=./ckpt/\nFILES=$path/*\nfor f in $FILES; do\n    echo \"==========================\" ${f##*/}\n    python -u eval.py $path${f##*/} 1\ndone\n"
  },
  {
    "path": "label_smoothing.py",
    "content": "import torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\n\r\nclass LabelSmoothing(nn.Module):\r\n    \"Implement label smoothing.\"\r\n    def __init__(self, device, size, padding_idx, label_smoothing=0.0):\r\n        super(LabelSmoothing, self).__init__()\r\n        assert 0.0 < label_smoothing <= 1.0\r\n        self.padding_idx = padding_idx\r\n        self.size = size\r\n        self.device = device\r\n\r\n        self.smoothing_value = label_smoothing / (size - 2)\r\n        self.one_hot = torch.full((1, size), self.smoothing_value).to(device)\r\n        self.one_hot[0, self.padding_idx] = 0\r\n        \r\n        self.confidence = 1.0 - label_smoothing\r\n\r\n    def forward(self, output, target):\r\n        real_size = output.size(1)\r\n        if real_size > self.size:\r\n            real_size -= self.size\r\n        else:\r\n            real_size = 0\r\n\r\n        model_prob = self.one_hot.repeat(target.size(0), 1)\r\n        if real_size > 0:\r\n            ext_zeros = torch.full((model_prob.size(0), real_size), self.smoothing_value).to(self.device)\r\n            model_prob = torch.cat((model_prob, ext_zeros), -1)\r\n        model_prob.scatter_(1, target, self.confidence)\r\n        model_prob.masked_fill_((target == self.padding_idx), 0.)\r\n\r\n        return F.kl_div(output, model_prob, reduction='sum')\r\n"
  },
  {
    "path": "metrics.py",
    "content": "import os\nimport sys\nimport numpy as np\nfrom pypinyin import Style, lazy_pinyin\n\nfrom data import PUNCS\n\nyunjiaos = {\n            \"0\":[\"a\", \"ia\", \"ua\", \"va\", \"üa\"],\n            \"1\":[\"e\", \"o\", \"uo\", \"ie\", \"ue\", \"üe\", \"ve\"],\n            \"2\":[\"u\"],\n            \"3\":[\"i\", \"ü\", \"v\"],\n            \"4\":[\"ai\", \"uai\"],\n            \"5\":[\"ao\", \"iao\"],\n            \"6\":[\"ou\", \"iu\", \"iou\"],\n            \"7\":[\"an\", \"ian\", \"uan\", \"üan\", \"van\"],\n            \"8\":[\"en\", \"in\", \"un\", \"ün\", \"vn\"],\n            \"9\":[\"ang\", \"iang\", \"uang\"],\n            \"10\":[\"eng\", \"ing\", \"ueng\", \"ong\", \"iong\"],\n            \"11\":[\"er\"],\n            \"12\":[\"ei\", \"ui\", \"uei\", \"vei\"],\n           }\n\nyun2id = {}\nfor yid, yws in yunjiaos.items():\n    for w in yws:\n        yun2id[w] = yid\n\ndef eval_tpl(sents1, sents2):\n    n = 0.\n    if len(sents1) > len(sents2):\n        sents1 = sents1[:len(sents2)]\n    for i, x in enumerate(sents1):\n        y = sents2[i]\n        if len(x) != len(y):\n            continue\n        px, py = [], []\n        for w in x:\n            if w in PUNCS:\n                px.append(w)\n        for w in y:\n            if w in PUNCS:\n                py.append(w)\n        if px == py:\n            n += 1\n    p = n / len(sents2)\n    r = n / len(sents1)\n    f = 2 * p * r / (p + r + 1e-16)\n\n    return p, r, f, n, len(sents1), len(sents2)\n\n\ndef rhythm_labellig(sents):\n    rhys = []\n    for sent in sents:\n        w = sent[-1]\n        if w in PUNCS and len(sent) > 1:\n            w = sent[-2]\n        yunmu = lazy_pinyin(w, style=Style.FINALS)\n        rhys.append(yunmu[0])\n    assert len(rhys) == len(sents)\n    rhy_map = {}\n    for i, r in enumerate(rhys):\n        if r in yun2id:\n            rid = yun2id[r]\n            if rid in rhy_map:\n                rhy_map[rid] += [i]\n            else:\n                
rhy_map[rid] = [i]\n        else:\n            pass\n    max_len_yuns = -1\n    max_rid = \"\"\n    for rid, yuns in rhy_map.items():\n        if len(yuns) > max_len_yuns:\n            max_len_yuns = len(yuns)\n            max_rid = rid\n    res = []\n    for i in range(len(sents)):\n        if max_rid in rhy_map and i in rhy_map[max_rid]:\n            res.append(1)\n        else:\n            res.append(-1)\n    return res\n\ndef eval_rhythm(sents1, sents2):\n    n = 0.\n    if len(sents1) > len(sents2):\n        sents1 = sents1[:len(sents2)]\n    rhys1 = rhythm_labellig(sents1)\n    rhys2 = rhythm_labellig(sents2)\n    \n    n1, n2 = 0., 0.\n    for v in rhys1:\n        if v == 1:\n            n1 += 1\n    for v in rhys2:\n        if v == 1:\n            n2 += 1\n    for i, v1 in enumerate(rhys1):\n        v2 = rhys2[i]\n        if v1 == 1 and v1 == v2:\n            n += 1\n    p = n / (n2 + 1e-16)\n    r = n / (n1 + 1e-16)\n    f1 = 2 * p * r / (p + r + 1e-16)\n    return p, r, f1, n, n1, n2\n\ndef eval_endings(sents1, sents2):\n    n = 0.\n    if len(sents1) > len(sents2):\n        sents1 = sents1[:len(sents2)]\n   \n    sents0 = []\n    for si, sent1 in enumerate(sents1):\n        sent2 = sents2[si]\n        if len(sent2) <= len(sent1):\n            sents0.append(sent2)\n        else:\n            sents0.append(sent2[:len(sent1) - 1] + sent1[-1])\n\n    sent = \"</s>\".join(sents0)\n    return sent\n\n\ndef eval(res_file, fid):\n    docs = []\n    with open(res_file) as f:\n        for line in f:\n            line = line.strip()\n            if not line:\n                continue\n            fs = line.split(\"\\t\")\n            if len(fs) != 2:\n                print(\"error\", line)\n                continue\n            x, y = fs\n            docs.append((x, y))\n\n\n    print(len(docs))\n\n    ugrams_ = []\n    bigrams_ = []\n    p_, r_, f1_ = 0., 0., 0.\n    n0_, n1_, n2_ = 0., 0., 0.\n\n    p__, r__, f1__ = 0., 0., 0.\n    n0__, n1__, n2__ = 0., 0., 
0.\n    d1_, d2_ = 0., 0.\n    d4ends = []\n\n    for x, y in docs:\n        topic, content = x.split(\"<s2>\")\n        author, topic = topic.split(\"<s1>\")\n        sents1 = content.split(\"</s>\")\n        y = y.replace(\"<bos>\", \"\")\n        sents2 = y.split(\"</s>\")\n        sents1_ = []\n        for sent in sents1:\n            sent = sent.strip()\n            if sent:\n                sents1_.append(sent)\n        sents1 = sents1_\n        sents2_ = []\n        for sent in sents2:\n            sent = sent.strip()\n            if sent:\n                sents2_.append(sent)\n        sents2 = sents2_\n\n        p, r, f1, n0, n1, n2 = eval_tpl(sents1, sents2)\n        p_ += p\n        r_ += r\n        f1_ += f1\n        n0_ += n0\n        n1_ += n1\n        n2_ += n2\n\n        ugrams = [w for w in ''.join(sents2)]\n        bigrams = []\n        for bi in range(len(ugrams) - 1):\n            bigrams.append(ugrams[bi] + ugrams[bi+1])\n        d1_ += len(set(ugrams)) / float(len(ugrams))\n        d2_ += len(set(bigrams)) / float(len(bigrams))\n        ugrams_ += ugrams\n        bigrams_ += bigrams\n\n        p, r, f1, n0, n1, n2 = eval_rhythm(sents1, sents2)\n        p__ += p\n        r__ += r\n        f1__ += f1\n        n0__ += n0\n        n1__ += n1\n        n2__ += n2\n\n        d4end = eval_endings(sents1, sents2)\n        d4ends.append(author + \"<s1>\" + topic + \"<s2>\" + d4end)\n\n    tpl_macro_p = p_ / len(docs)\n    tpl_macro_r = r_ / len(docs)\n    tpl_macro_f1 = 2 * tpl_macro_p * tpl_macro_r / (tpl_macro_p + tpl_macro_r)\n    tpl_micro_p = n0_ / n2_\n    tpl_micro_r = n0_ / n1_\n    tpl_micro_f1 = 2 * tpl_micro_p * tpl_micro_r / (tpl_micro_p + tpl_micro_r)\n    \n    rhy_macro_p = p__ / len(docs)\n    rhy_macro_r = r__ / len(docs)\n    rhy_macro_f1 = 2 * rhy_macro_p * rhy_macro_r / (rhy_macro_p + rhy_macro_r)\n    rhy_micro_p = n0__ / n2__\n    rhy_micro_r = n0__ / n1__\n    rhy_micro_f1 = 2 * rhy_micro_p * rhy_micro_r / (rhy_micro_p + 
rhy_micro_r)\n    \n\n    macro_dist1 = d1_ / len(docs)\n    macro_dist2 = d2_ / len(docs)\n    micro_dist1 = len(set(ugrams_)) / float(len(ugrams_))\n    micro_dist2 = len(set(bigrams_)) / float(len(bigrams_))\n\n    with open(\"./results_4ending/res4end\" + str(fid) + \".txt\", \"w\") as fo:\n        for line in d4ends:\n            fo.write(line + \"\\n\")\n    return tpl_macro_f1, tpl_micro_f1, rhy_macro_f1, rhy_micro_f1, macro_dist1, micro_dist1, macro_dist2, micro_dist2\n\ntpl_macro_f1_, tpl_micro_f1_, rhy_macro_f1_, rhy_micro_f1_,  \\\nmacro_dist1_, micro_dist1_, macro_dist2_, micro_dist2_ = [], [], [], [], [], [], [], []\nabalation = \"top-32\"\nfor i in range(5):\n    f_name = \"./results/\"+abalation+\"/out\" +str(i+1)+\".txt\"\n    if not os.path.exists(f_name):\n        continue\n    tpl_macro_f1, tpl_micro_f1, rhy_macro_f1, rhy_micro_f1, macro_dist1, micro_dist1, macro_dist2, micro_dist2 = eval(f_name, i + 1)\n    print(tpl_macro_f1, tpl_micro_f1, rhy_macro_f1, rhy_micro_f1, macro_dist1, micro_dist1, macro_dist2, micro_dist2)\n    tpl_macro_f1_.append(tpl_macro_f1)\n    tpl_micro_f1_.append(tpl_micro_f1)\n    rhy_macro_f1_.append(rhy_macro_f1)\n    rhy_micro_f1_.append(rhy_micro_f1)\n    macro_dist1_.append(macro_dist1)\n    micro_dist1_.append(micro_dist1)\n    macro_dist2_.append(macro_dist2)\n    micro_dist2_.append(micro_dist2)\n\nprint()\nprint(\"tpl_macro_f1\", np.mean(tpl_macro_f1_), np.std(tpl_macro_f1_, ddof=1))\nprint(\"tpl_micro_f1\", np.mean(tpl_micro_f1_), np.std(tpl_micro_f1_, ddof=1))\nprint(\"rhy_macro_f1\", np.mean(rhy_macro_f1_), np.std(rhy_macro_f1_, ddof=1))\nprint(\"rhy_micro_f1\", np.mean(rhy_micro_f1_), np.std(rhy_micro_f1_, ddof=1))\nprint(\"macro_dist1\", np.mean(macro_dist1_), np.std(macro_dist1_, ddof=1))\nprint(\"micro_dist1\", np.mean(micro_dist1_), np.std(micro_dist1_, ddof=1))\nprint(\"macro_dist2\", np.mean(macro_dist2_), np.std(macro_dist2_, ddof=1))\nprint(\"micro_dist2\", np.mean(micro_dist2_), np.std(micro_dist2_, 
ddof=1))\n\n\n\n\n\n"
  },
  {
    "path": "optim.py",
    "content": "# -*- coding: utf-8 -*-\n\nclass Optim:\n    \"Optim wrapper that implements rate.\"\n    def __init__(self, model_size, factor, warmup, optimizer):\n        self.optimizer = optimizer\n        self._step = 0\n        self.warmup = warmup\n        self.factor = factor\n        self.model_size = model_size\n        self._rate = 0\n\n    def step(self):\n        \"Update parameters and rate\"\n        self._step += 1\n        rate = self.rate()\n        for p in self.optimizer.param_groups:\n            p['lr'] = rate\n        self._rate = rate\n        self.optimizer.step()\n\n    def rate(self, step = None):\n        \"Implement `lrate` above\"\n        if step is None:\n            step = self._step\n        return self.factor * (self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5)))\n\n    def state_dict(self):\n        return self.optimizer.state_dict()\n\n    def load_state_dict(self, m):\n        self.optimizer.load_state_dict(m)\n"
  },
  {
    "path": "polish.py",
    "content": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\nimport random\nimport numpy as np\nimport copy \nimport time\n\nfrom biglm import BIGLM\nfrom data import Vocab, DataLoader, s2t, s2xy_polish\n\ngpu = 0\ndef init_model(m_path, device, vocab):\n    ckpt= torch.load(m_path, map_location='cpu')\n    lm_args = ckpt['args']\n    lm_vocab = Vocab(vocab, min_occur_cnt=lm_args.min_occur_cnt, specials=[])\n    lm_model = BIGLM(device, lm_vocab, lm_args.embed_dim, lm_args.ff_embed_dim, lm_args.num_heads, lm_args.dropout, lm_args.layers, 0.1)\n    lm_model.load_state_dict(ckpt['model'])\n    lm_model = lm_model.cuda(device)\n    lm_model.eval()\n    return lm_model, lm_vocab, lm_args\n\nm_path = \"./model/songci.ckpt\"\nlm_model, lm_vocab, lm_args = init_model(m_path, gpu, \"./model/vocab.txt\")\n\n\nMAX_LEN = 300\nk = 32\ndef top_k_inc(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s):\n    start = time.time()\n    incremental_state = None\n    inp_y, m = s2t(s, lm_vocab)\n    inp_y = inp_y.cuda(gpu)\n    res = []\n    for l in range(inp_ys_tpl.size(0)):\n        probs, pred, incremental_state = lm_model.work_incremental(enc, src_padding_mask, \\\n                                         inp_y, inp_ys_tpl[0:l+1,:], inp_ys_seg[0:l+1,:], inp_ys_pos[0:l+1,:],\\\n                                         incremental_state)\n        next_tk = []\n        for i in range(len(s)):\n            ctk = lm_vocab.idx2token(inp_ys_tpl[l,i].item())\n            if ctk != \"<c1>\" and ctk != \"<c2>\" and ctk != \"<c0>\":\n                next_tk.append(ctk)\n                continue\n            \n            if l == 0:\n                logits = probs[len(s[i]) - 1, i]\n            else:\n                logits = probs[0, i]\n            ps, idx = torch.topk(logits, k=k)\n            ps = ps / torch.sum(ps)\n            sampled = torch.multinomial(ps, num_samples = 1)\n            sampled_idx = idx[sampled]\n            
next_tk.append(lm_vocab.idx2token(sampled_idx.item()))\n        \n        s_ = []\n        bidx = [1] * len(s)\n        for idx, (sent, t) in enumerate(zip(s, next_tk)):\n            if t == \"<eos>\":\n                res.append(sent)\n                bidx[idx] = 0\n            else:\n                s_.append(sent + [t])\n        if not s_:\n            break\n        s = s_\n        inp_y, m = s2t(s, lm_vocab)\n        inp_y = inp_y.cuda(gpu)\n        bidx = torch.BoolTensor(bidx).cuda(gpu)\n        incremental_state[\"bidx\"] = bidx\n    res += s_\n        \n    #for i in res:\n    #    print(''.join(i))\n    print(time.time()-start)\n    return res\n\ndef top_k(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s):\n    inp_y, m = s2t(s, lm_vocab)\n    inp_y = inp_y.cuda(gpu)\n\n    start = time.time()\n    res = []\n    for l in range(inp_ys_tpl.size(0)):\n        probs, pred = lm_model.work(enc, src_padding_mask, inp_y, inp_ys_tpl[0:l+1,:], inp_ys_seg[0:l+1,:], inp_ys_pos[0:l+1,:])\n        next_tk = []\n        for i in range(len(s)):\n            ctk = lm_vocab.idx2token(inp_ys_tpl[l,i].item())\n            if ctk != \"<c1>\" and ctk != \"<c2>\" and ctk != \"<c0>\":\n                next_tk.append(ctk)\n                continue\n            logits = probs[len(s[i]) - 1, i]\n            ps, idx = torch.topk(logits, k=k)\n            ps = ps / torch.sum(ps)\n            sampled = torch.multinomial(ps, num_samples = 1)\n            sampled_idx = idx[sampled]\n            next_tk.append(lm_vocab.idx2token(sampled_idx.item()))\n        \n        s_ = []\n        for sent, t in zip(s, next_tk):\n            if t == \"<eos>\":\n                res.append(sent)\n            else:\n                s_.append(sent + [t])\n        if not s_:\n            break\n        s = s_\n        inp_y, m = s2t(s, lm_vocab)\n        inp_y = inp_y.cuda(gpu)\n\n    res += s_\n        \n    #for i in res:\n    #    print(''.join(i))\n\n    #print(time.time()-start)\n    
return res\n   \n\n\n\nds = []\nwith open(\"./data/polish_tpl.txt\", \"r\") as f:\n    for line in f:\n        line = line.strip()\n        if line:\n            ds.append(line)\nprint(len(ds))\n\nlocal_rank = gpu\nbatch_size = 1\ncp_size = 1\nbatches = round(len(ds) / batch_size)\n\nfor i in range(5):\n    fo = open(\"./results/out\"+str(i+1)+\".txt\", \"w\")     \n    idx = 0\n    while idx < len(ds):\n        lb = ds[idx:idx + batch_size]\n        cplb = []\n        for line in lb:\n            cplb += [line for i in range(cp_size)]\n        print(cplb) \n        xs_tpl, xs_seg, xs_pos, \\\n        ys_truth, ys_inp, \\\n        ys_tpl, ys_seg, ys_pos, msk = s2xy_polish(cplb, lm_vocab, lm_args.max_len,2)\n\n        xs_tpl = xs_tpl.cuda(local_rank)\n        xs_seg = xs_seg.cuda(local_rank)\n        xs_pos = xs_pos.cuda(local_rank)\n        ys_tpl = ys_tpl.cuda(local_rank)\n        ys_seg = ys_seg.cuda(local_rank)\n        ys_pos = ys_pos.cuda(local_rank)\n\n        enc, src_padding_mask = lm_model.encode(xs_tpl, xs_seg, xs_pos)\n        s = [['<bos>']] * batch_size * cp_size   \n        res = top_k_inc(enc, src_padding_mask, ys_tpl, ys_seg, ys_pos, s)\n\n        for i, line in enumerate(cplb):\n            r = ''.join(res[i])\n            print(line)\n            print(r)\n            fo.write(line + \"\\t\" + r + \"\\n\")  \n\n        idx += batch_size\n    fo.close()\n"
  },
  {
    "path": "polish.sh",
    "content": "python3 -u polish.py\n"
  },
  {
    "path": "prepare_data.py",
    "content": "import sys, re\nfrom collections import Counter\nimport random\ncnt = Counter()\n\nf_ci = \"./data/ci.txt\"\nf_cipai = \"./data/cipai.txt\"\ncipai = Counter()\nwith open(f_cipai) as f:\n    for line in f:\n        line = line.strip()\n        fs = line.split()\n        cipai.update(fs)\n\ncipai = cipai.keys()\n\ndocs = {}\nwith open(f_ci) as f:\n    for line in f:\n        line = line.strip()\n        fs = line.split(\"<s1>\")\n        author = fs[0]\n        topic, content = fs[1].split(\"<s2>\")\n        if \"・\" in topic:\n            t1, t2 = topic.split(\"・\")\n            if t1 == t2:\n                topic = t1\n            else:\n                if t1 in cipai:\n                    topic = t1\n                elif t2 in cipai:\n                    topic = t2\n                else:\n                    topic = t1\n        content = content.replace(\"、\", \"，\")\n        sents = content.split(\"</s>\")\n        ws = [w for w in author + topic + ''.join(sents)]\n        cnt.update(ws)\n        if topic not in docs:\n            docs[topic] = []\n        docs[topic].append(author + \"<s1>\" + topic + \"<s2>\" + '</s>'.join(sents))\n\n\ntopics = list(docs.keys())\n\nprint(len(topics))\nrandom.shuffle(topics)\n\ntopics_train = topics[:len(topics)-50]\ntopics_dev_test = topics[-50:]\ntopics_dev = topics_dev_test[:25]\ntopics_test = topics_dev_test[-25:]\n\ndocs_train = []\ndocs_dev = []\ndocs_test = []\n\nfor t in topics_train:\n    docs_train.extend(docs[t])\n\nfor t in topics_dev:\n    docs_dev.extend(docs[t])\n\nfor t in topics_test:\n    docs_test.extend(docs[t])\n\nrandom.shuffle(docs_train)\nrandom.shuffle(docs_dev)\nrandom.shuffle(docs_test)\n\nprint(len(docs_train), len(docs_dev), len(docs_test))\ntrain_cps = []\ndev_cps = []\ntest_cps = []\n\n\nwith open('./data/train.txt', 'w', encoding ='utf8') as f:\n    for x in docs_train:\n        s = x.split(\"<s2>\")[0]\n        train_cps.append(s.split(\"<s1>\")[1])\n        f.write(x + '\\n')\n   
 print(len(set(train_cps)))\nwith open('./data/dev.txt', 'w', encoding ='utf8') as f:\n    for x in docs_dev:\n        s = x.split(\"<s2>\")[0]\n        dev_cps.append(s.split(\"<s1>\")[1])\n        f.write(x + '\\n')\n    print(len(set(dev_cps)))\nwith open('./data/test.txt', 'w', encoding ='utf8') as f:\n    for x in docs_test:\n        s = x.split(\"<s2>\")[0]\n        test_cps.append(s.split(\"<s1>\")[1])\n        f.write(x + '\\n')\n    print(len(set(test_cps)))\n\nprint(\"vocab\")\nwith open('./data/vocab.txt', 'w', encoding ='utf8') as f:\n    for x, y in cnt.most_common():\n        f.write(x + '\\t' + str(y) + '\\n')\nprint(\"done\")\n"
  },
  {
    "path": "test.py",
    "content": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\nimport random\nimport numpy as np\nimport copy \nimport time\n\nfrom biglm import BIGLM\nfrom data import Vocab, DataLoader, s2t, s2xy\n\n\n\ndef init_seeds():\n    random.seed(123)\n    torch.manual_seed(123)\n    if torch.cuda.is_available():\n        torch.cuda.manual_seed_all(123)\n\n#init_seeds()\n\ngpu = 1\ndef init_model(m_path, device, vocab):\n    ckpt= torch.load(m_path, map_location='cpu')\n    lm_args = ckpt['args']\n    lm_vocab = Vocab(vocab, min_occur_cnt=lm_args.min_occur_cnt, specials=[])\n    lm_model = BIGLM(device, lm_vocab, lm_args.embed_dim, lm_args.ff_embed_dim, lm_args.num_heads, lm_args.dropout, lm_args.layers, 0.1)\n    lm_model.load_state_dict(ckpt['model'])\n    lm_model = lm_model.cuda(device)\n    lm_model.eval()\n    return lm_model, lm_vocab, lm_args\n\nm_path = \"./model/songci.ckpt\"\nlm_model, lm_vocab, lm_args = init_model(m_path, gpu, \"./model/vocab.txt\")\n\n\nk = 32\ndef top_k_inc(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s):\n    start = time.time()\n    incremental_state = None\n    inp_y, m = s2t(s, lm_vocab)\n    inp_y = inp_y.cuda(gpu)\n    res = []\n    for l in range(inp_ys_tpl.size(0)):\n        probs, pred, incremental_state = lm_model.work_incremental(enc, src_padding_mask, \\\n                                         inp_y, inp_ys_tpl[0:l+1,:], inp_ys_seg[0:l+1,:], inp_ys_pos[0:l+1,:],\\\n                                         incremental_state)\n        next_tk = []\n        for i in range(len(s)):\n            ctk = lm_vocab.idx2token(inp_ys_tpl[l,i].item())\n            if ctk != \"<c1>\" and ctk != \"<c2>\" and ctk != \"<c0>\":\n                next_tk.append(ctk)\n                continue\n            \n            if l == 0:\n                logits = probs[len(s[i]) - 1, i]\n            else:\n                logits = probs[0, i]\n            ps, idx = torch.topk(logits, k=k)\n            ps = ps / 
torch.sum(ps)\n            sampled = torch.multinomial(ps, num_samples = 1)\n            sampled_idx = idx[sampled]\n            next_tk.append(lm_vocab.idx2token(sampled_idx.item()))\n        \n        s_ = []\n        bidx = [1] * len(s)\n        for idx, (sent, t) in enumerate(zip(s, next_tk)):\n            if t == \"<eos>\":\n                res.append(sent)\n                bidx[idx] = 0\n            else:\n                s_.append(sent + [t])\n        if not s_:\n            break\n        s = s_\n        inp_y, m = s2t(s, lm_vocab)\n        inp_y = inp_y.cuda(gpu)\n        bidx = torch.BoolTensor(bidx).cuda(gpu)\n        incremental_state[\"bidx\"] = bidx\n    res += s_\n        \n    #for i in res:\n    #    print(''.join(i))\n    print(time.time()-start)\n    return res\n\ndef top_k(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s):\n    inp_y, m = s2t(s, lm_vocab)\n    inp_y = inp_y.cuda(gpu)\n\n    start = time.time()\n    res = []\n    for l in range(inp_ys_tpl.size(0)):\n        probs, pred = lm_model.work(enc, src_padding_mask, inp_y, inp_ys_tpl[0:l+1,:], inp_ys_seg[0:l+1,:], inp_ys_pos[0:l+1,:])\n        next_tk = []\n        for i in range(len(s)):\n            ctk = lm_vocab.idx2token(inp_ys_tpl[l,i].item())\n            if ctk != \"<c1>\" and ctk != \"<c2>\" and ctk != \"<c0>\":\n                next_tk.append(ctk)\n                continue\n            logits = probs[len(s[i]) - 1, i]\n            ps, idx = torch.topk(logits, k=k)\n            ps = ps / torch.sum(ps)\n            sampled = torch.multinomial(ps, num_samples = 1)\n            sampled_idx = idx[sampled]\n            next_tk.append(lm_vocab.idx2token(sampled_idx.item()))\n        \n        s_ = []\n        for sent, t in zip(s, next_tk):\n            if t == \"<eos>\":\n                res.append(sent)\n            else:\n                s_.append(sent + [t])\n        if not s_:\n            break\n        s = s_\n        inp_y, m = s2t(s, lm_vocab)\n        inp_y = 
inp_y.cuda(gpu)\n\n    res += s_\n        \n    #for i in res:\n    #    print(''.join(i))\n\n    #print(time.time()-start)\n    return res\n \n    \ndef greedy(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s):\n    start = time.time()\n    incremental_state = None\n    inp_y, m = s2t(s, lm_vocab)\n    inp_y = inp_y.cuda(gpu)\n    res = []\n    for l in range(inp_ys_tpl.size(0)):\n        probs, pred, incremental_state = lm_model.work_incremental(enc, src_padding_mask, \\\n                                         inp_y, inp_ys_tpl[0:l+1,:], inp_ys_seg[0:l+1,:], inp_ys_pos[0:l+1,:],\\\n                                         incremental_state)\n        next_tk = []\n        for i in range(len(s)):\n            ctk = lm_vocab.idx2token(inp_ys_tpl[l,i].item())\n            if ctk != \"<c1>\" and ctk != \"<c2>\" and ctk != \"<c0>\":\n                next_tk.append(ctk)\n                continue\n            \n            if l == 0:\n                pred = pred[len(s[i]) - 1, i]\n            else:\n                pred = pred[0, i]\n            next_tk.append(lm_vocab.idx2token(pred.item()))\n        \n        s_ = []\n        bidx = [1] * len(s)\n        for idx, (sent, t) in enumerate(zip(s, next_tk)):\n            if t == \"<eos>\":\n                res.append(sent)\n                bidx[idx] = 0\n            else:\n                s_.append(sent + [t])\n        if not s_:\n            break\n        s = s_\n        inp_y, m = s2t(s, lm_vocab)\n        inp_y = inp_y.cuda(gpu)\n        bidx = torch.BoolTensor(bidx).cuda(gpu)\n        incremental_state[\"bidx\"] = bidx\n    res += s_\n        \n    #for i in res:\n    #    print(''.join(i))\n    print(time.time()-start)\n    return res\n\n\ndef beam_decode(s, x, enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos):\n    beam_size = 5\n    \n    num_live = 1\n    num_dead = 0 \n    samples = []\n    sample_scores = np.zeros(beam_size)\n\n    last_traces = [[]]\n    last_scores = 
torch.FloatTensor(np.zeros(1)).to(gpu)\n\n    x = x.to(gpu)\n    ys = x\n\n    for l in range(inp_ys_tpl.size(0)):\n        seq_len, bsz = ys.size()\n        enc_ = enc.repeat(1, bsz, 1)\n        src_padding_mask_ = src_padding_mask.repeat(1, bsz)\n        inp_ys_tpl_ = inp_ys_tpl.repeat(1, bsz)\n        inp_ys_seg_ = inp_ys_seg.repeat(1, bsz)\n        inp_ys_pos_ = inp_ys_pos.repeat(1, bsz)\n\n        y_pred, _ = lm_model.work(enc_, src_padding_mask_, ys, inp_ys_tpl_[0:l+1,:], inp_ys_seg_[0:l+1,:], inp_ys_pos_[0:l+1,:])\n\n        dict_size = y_pred.shape[-1]\n        y_pred = y_pred[-1, :, :] \n\n        cand_y_scores = last_scores + torch.log(y_pred) # larger is better\n        cand_scores = cand_y_scores.flatten()\n        idx_top_joint_scores = torch.topk(cand_scores, beam_size - num_dead)[1]\n        \n        '''\n        ps, idx_top_joint_scores = torch.topk(cand_scores, 100)\n        ps = F.softmax(ps)\n        sampled = torch.multinomial(ps, num_samples = beam_size - num_dead)\n        idx_top_joint_scores = idx_top_joint_scores[sampled]\n        '''\n\n        idx_last_traces = idx_top_joint_scores / dict_size\n        idx_word_now = idx_top_joint_scores % dict_size\n        top_joint_scores = cand_scores[idx_top_joint_scores]\n\n        traces_now = []\n        scores_now = np.zeros((beam_size - num_dead))\n        ys_now = []\n        for i, [j, k] in enumerate(zip(idx_last_traces, idx_word_now)):\n            traces_now.append(last_traces[j] + [k])\n            scores_now[i] = copy.copy(top_joint_scores[i])\n            ys_now.append(copy.copy(ys[:,j]))\n\n\n        num_live = 0  \n        last_traces = []\n        last_scores = []\n        ys = []\n        for i in range(len(traces_now)):\n            w = lm_vocab.idx2token(traces_now[i][-1].item())\n            if w == \"<eos>\":\n                samples.append([str(e.item()) for e in traces_now[i][:-1]])\n                sample_scores[num_dead] = scores_now[i] \n                num_dead += 1\n      
      else:\n                last_traces.append(traces_now[i])\n                last_scores.append(scores_now[i])\n                ys.append(ys_now[i])\n                num_live += 1\n        \n        if num_live == 0 or num_dead >= beam_size:\n            break\n        ys = torch.stack(ys, dim = 1) \n\n        last_scores = torch.FloatTensor(np.array(last_scores).reshape((num_live, 1))).to(gpu)\n        next_y = []\n        for e in last_traces:\n            eid = e[-1].item()\n            next_y.append(eid)\n        next_y = np.array(next_y).reshape((1, num_live))\n        next_y = torch.LongTensor(next_y).to(gpu)\n        \n        ys = torch.cat([ys, next_y], dim=0)\n       \n        assert num_live + num_dead == beam_size \n        # end for loop\n\n    if num_live > 0:\n        for i in range(num_live):\n            samples.append([str(e.item()) for e in last_traces[i]])\n            sample_scores[num_dead] = last_scores[i]\n            num_dead += 1  \n\n    idx_sorted_scores = np.argsort(sample_scores) # ascending order\n\n    sorted_samples = []\n    sorted_scores = []\n    filter_idx = []\n    for e in idx_sorted_scores:\n        if len(samples[e]) > 0:\n            filter_idx.append(e)\n    if len(filter_idx) == 0:\n        filter_idx = idx_sorted_scores\n    for e in filter_idx:\n        sorted_samples.append(samples[e])\n        sorted_scores.append(sample_scores[e])\n\n    res = []\n    dec_words = []\n    for sample in sorted_samples[::-1]:\n        for e in sample:\n            e = int(e)\n            dec_words.append(lm_vocab.idx2token(e))\n        #r = ''.join(dec_words)\n        #print(r)\n        res.append(dec_words)\n        dec_words = []\n\n    return res\n\n\ndef beam_search(enc, src_padding_mask, ys_tpl, ys_seg, ys_pos, s):\n    x, m = s2t(s, lm_vocab)\n    return beam_decode(s[0], x, enc, src_padding_mask, ys_tpl, ys_seg, ys_pos)\n\n\nds = []\nwith open(\"./data/test.txt\", \"r\") as f:\n    for line in f:\n        line = line.strip()\n 
       if line:\n            ds.append(line)\nprint(len(ds))\n\nlocal_rank = gpu\nbatch_size = 1\ncp_size = 1\nbatches = round(len(ds) / batch_size)\n\nfor i in range(5): \n    idx = 0\n    fo = open(\"./results/top-\"+str(k)+\"/out\"+str(i+1)+\".txt\", \"w\")\n    while idx < len(ds):\n        lb = ds[idx:idx + batch_size]\n        cplb = []\n        for line in lb:\n            cplb += [line for i in range(cp_size)]\n        print(cplb) \n        xs_tpl, xs_seg, xs_pos, \\\n        ys_truth, ys_inp, \\\n        ys_tpl, ys_seg, ys_pos, msk = s2xy(cplb, lm_vocab, lm_args.max_len, 2)\n\n        xs_tpl = xs_tpl.cuda(local_rank)\n        xs_seg = xs_seg.cuda(local_rank)\n        xs_pos = xs_pos.cuda(local_rank)\n        ys_tpl = ys_tpl.cuda(local_rank)\n        ys_seg = ys_seg.cuda(local_rank)\n        ys_pos = ys_pos.cuda(local_rank)\n\n        enc, src_padding_mask = lm_model.encode(xs_tpl, xs_seg, xs_pos)\n        s = [['<bos>']] * batch_size * cp_size   \n        res = top_k_inc(enc, src_padding_mask, ys_tpl, ys_seg, ys_pos, s)\n\n        for i, line in enumerate(cplb):\n            r = ''.join(res[i])\n            print(line)\n            print(r)\n    \n            fo.write(line + \"\\t\" + r + \"\\n\")\n    \n        idx += batch_size\n    \n    fo.close()\n"
  },
  {
    "path": "test.sh",
    "content": "python3 -u test.py\n"
  },
  {
    "path": "train.py",
    "content": "# coding=utf-8\nimport torch\nimport torch.distributed as dist\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.multiprocessing as mp\n\nfrom biglm import BIGLM\nfrom data import Vocab, DataLoader, s2xy\nfrom optim import Optim\n\nimport argparse, os\nimport random\n\ndef parse_config():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--embed_dim', type=int)\n    parser.add_argument('--ff_embed_dim', type=int)\n    parser.add_argument('--num_heads', type=int)\n    parser.add_argument('--layers', type=int)\n    parser.add_argument('--dropout', type=float)\n\n    parser.add_argument('--train_data', type=str)\n    parser.add_argument('--dev_data', type=str)\n    parser.add_argument('--vocab', type=str)\n    parser.add_argument('--min_occur_cnt', type=int)\n    parser.add_argument('--batch_size', type=int)\n    parser.add_argument('--warmup_steps', type=int)\n    parser.add_argument('--lr', type=float)\n    parser.add_argument('--smoothing', type=float)\n    parser.add_argument('--weight_decay', type=float)\n    parser.add_argument('--max_len', type=int)\n    parser.add_argument('--min_len', type=int)\n    parser.add_argument('--print_every', type=int)\n    parser.add_argument('--save_every', type=int)\n    parser.add_argument('--start_from', type=str, default=None)\n    parser.add_argument('--save_dir', type=str)\n\n    parser.add_argument('--world_size', type=int)\n    parser.add_argument('--gpus', type=int)\n    parser.add_argument('--MASTER_ADDR', type=str)\n    parser.add_argument('--MASTER_PORT', type=str)\n    parser.add_argument('--start_rank', type=int)\n    parser.add_argument('--backend', type=str)\n\n    return parser.parse_args()\n\ndef update_lr(optimizer, lr):\n    for param_group in optimizer.param_groups:\n        param_group['lr'] = lr\n \ndef average_gradients(model):\n    \"\"\" Gradient averaging. 
\"\"\"\n    normal = True\n    size = float(dist.get_world_size())\n    for param in model.parameters():\n        if param.grad is not None:\n            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)\n            param.grad.data /= size\n        else:\n            normal = False\n            break\n    return normal\n\ndef eval_epoch(lm_args, model, lm_vocab, local_rank, label):\n    print(\"validating...\", flush=True)\n    ds = []\n    with open(lm_args.dev_data, \"r\") as f:\n        for line in f:\n            line = line.strip()\n            if line:\n                ds.append(line)\n\n    batch_size = 10\n    batches = round(len(ds) / batch_size)\n    idx = 0\n    avg_nll = 0.\n    avg_ppl = 0.\n    count = 0.\n    while idx < len(ds):\n        cplb = ds[idx:idx + batch_size]\n        xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk = s2xy(cplb, lm_vocab, lm_args.max_len, lm_args.min_len)\n\n        xs_tpl = xs_tpl.cuda(local_rank)\n        xs_seg = xs_seg.cuda(local_rank)\n        xs_pos = xs_pos.cuda(local_rank)\n        ys_truth = ys_truth.cuda(local_rank)\n        ys_inp = ys_inp.cuda(local_rank)\n        ys_tpl = ys_tpl.cuda(local_rank)\n        ys_seg = ys_seg.cuda(local_rank)\n        ys_pos = ys_pos.cuda(local_rank)\n        msk = msk.cuda(local_rank)\n\n        nll, ppl, bsz = model.ppl(xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk)\n    \n        avg_nll += nll\n        avg_ppl += ppl\n        count += bsz\n\n        idx += batch_size\n    \n    print(label, \"nll=\", avg_nll/count, \"ppl=\", avg_ppl/count, \"count=\", count, flush=True)\n\ndef run(args, local_rank):\n    \"\"\" Distributed Synchronous \"\"\"\n    torch.manual_seed(1234)\n    vocab = Vocab(args.vocab, min_occur_cnt=args.min_occur_cnt, specials=[])\n    if (args.world_size == 1 or dist.get_rank() == 0):\n        print (\"vocab.size = \" + str(vocab.size), flush=True)\n    model = BIGLM(local_rank, vocab, args.embed_dim, 
args.ff_embed_dim,\\\n                  args.num_heads, args.dropout, args.layers, args.smoothing)\n    if args.start_from is not None:\n        ckpt = torch.load(args.start_from, map_location='cpu')\n        model.load_state_dict(ckpt['model'])\n    model = model.cuda(local_rank)\n   \n    optimizer = Optim(model.embed_dim, args.lr, args.warmup_steps, torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.998), eps=1e-9))\n\n    if args.start_from is not None:\n        optimizer.load_state_dict(ckpt['optimizer'])\n\n    train_data = DataLoader(vocab, args.train_data, args.batch_size, args.max_len, args.min_len)\n    batch_acm = 0\n    acc_acm, nll_acm, ppl_acm, ntokens_acm, nxs, npairs_acm, loss_acm = 0., 0., 0., 0., 0., 0., 0.\n    while True:\n        model.train()\n        if train_data.epoch_id > 30:\n            break\n        for xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk in train_data:\n            batch_acm += 1\n            xs_tpl = xs_tpl.cuda(local_rank)\n            xs_seg = xs_seg.cuda(local_rank)\n            xs_pos = xs_pos.cuda(local_rank)\n            ys_truth = ys_truth.cuda(local_rank)\n            ys_inp = ys_inp.cuda(local_rank)\n            ys_tpl = ys_tpl.cuda(local_rank)\n            ys_seg = ys_seg.cuda(local_rank)\n            ys_pos = ys_pos.cuda(local_rank)\n            msk = msk.cuda(local_rank)\n\n            model.zero_grad()\n            res, loss, acc, nll, ppl, ntokens, npairs = model(xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk)\n            loss_acm += loss.item()\n            acc_acm += acc\n            nll_acm += nll\n            ppl_acm += ppl\n            ntokens_acm += ntokens\n            npairs_acm += npairs\n            nxs += npairs\n            \n            loss.backward()\n            if args.world_size > 1:\n                is_normal = average_gradients(model)\n            else:\n                is_normal = True\n            if is_normal:\n                
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n                optimizer.step()\n            else:\n                print(\"gradient: none, gpu: \" + str(local_rank), flush=True)\n                continue\n            if (args.world_size==1 or dist.get_rank() ==0) and batch_acm%args.print_every == -1%args.print_every:\n                print ('batch_acm %d, loss %.3f, acc %.3f, nll %.3f, ppl %.3f, x_acm %d, lr %.6f'\\\n                        %(batch_acm, loss_acm/args.print_every, acc_acm/ntokens_acm, \\\n                        nll_acm/nxs, ppl_acm/nxs, npairs_acm, optimizer._rate), flush=True)\n                acc_acm, nll_acm, ppl_acm, ntokens_acm, loss_acm, nxs = 0., 0., 0., 0., 0., 0.\n            if (args.world_size==1 or dist.get_rank() ==0) and batch_acm%args.save_every == -1%args.save_every:\n                if not os.path.exists(args.save_dir):\n                    os.mkdir(args.save_dir)\n                \n                model.eval()\n                eval_epoch(args, model, vocab, local_rank, \"epoch-\" + str(train_data.epoch_id) + \"-acm-\" + str(batch_acm))\n                model.train()\n\n                torch.save({'args':args, 'model':model.state_dict(), 'optimizer':optimizer.state_dict()}, '%s/epoch%d_batch_%d'%(args.save_dir, train_data.epoch_id, batch_acm))\n\ndef init_processes(args, local_rank, fn, backend='nccl'):\n    \"\"\" Initialize the distributed environment. 
\"\"\"\n    os.environ['MASTER_ADDR'] = args.MASTER_ADDR\n    os.environ['MASTER_PORT'] = args.MASTER_PORT\n    dist.init_process_group(backend, rank=args.start_rank + local_rank, world_size=args.world_size)\n    fn(args, local_rank)\n\nif __name__ == \"__main__\":\n    mp.set_start_method('spawn')\n    args = parse_config()\n\n    if args.world_size == 1:\n        run(args, 0)\n        exit(0)\n    processes = []\n    for rank in range(args.gpus):\n        p = mp.Process(target=init_processes, args=(args, rank, run, args.backend))\n        p.start()\n        processes.append(p)\n\n    for p in processes:\n        p.join()\n"
  },
  {
    "path": "train.sh",
    "content": "CUDA_VISIBLE_DEVICES=1 \\\npython3 -u train.py --embed_dim 768 \\\n                      --ff_embed_dim 3072 \\\n                      --num_heads 12 \\\n                      --layers 12 \\\n                      --dropout 0.2 \\\n                      --train_data ./data/train.txt \\\n                      --dev_data ./data/dev.txt \\\n                      --vocab ./data/vocab.txt \\\n                      --min_occur_cnt 1 \\\n                      --batch_size 32 \\\n                      --warmup_steps 8000 \\\n                      --lr 0.5 \\\n                      --weight_decay 0 \\\n                      --smoothing 0.1 \\\n                      --max_len 300 \\\n                      --min_len 10 \\\n                      --world_size 1 \\\n                      --gpus 1 \\\n                      --start_rank 0 \\\n                      --MASTER_ADDR localhost \\\n                      --MASTER_PORT 28512 \\\n                      --print_every 100 \\\n                      --save_every 1000 \\\n                      --save_dir ckpt \\\n                      --backend nccl\n"
  },
  {
    "path": "transformer.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import Parameter\nimport torch.nn.functional as F\n\nfrom utils import gelu, LayerNorm, get_incremental_state, set_incremental_state\nimport math\n\nclass TransformerLayer(nn.Module):\n    \n    def __init__(self, embed_dim, ff_embed_dim, num_heads, dropout, with_external=False, weights_dropout = True):\n        super(TransformerLayer, self).__init__()\n        self.self_attn = MultiheadAttention(embed_dim, num_heads, dropout, weights_dropout)\n        self.fc1 = nn.Linear(embed_dim, ff_embed_dim)\n        self.fc2 = nn.Linear(ff_embed_dim, embed_dim)\n        self.attn_layer_norm = LayerNorm(embed_dim)\n        self.ff_layer_norm = LayerNorm(embed_dim)\n        self.with_external = with_external\n        self.dropout = dropout\n        if self.with_external:\n            self.external_attn = MultiheadAttention(embed_dim, num_heads, dropout, weights_dropout)\n            self.external_layer_norm = LayerNorm(embed_dim)\n        self.reset_parameters()\n    \n    def reset_parameters(self):\n        nn.init.normal_(self.fc1.weight, std=0.02)\n        nn.init.normal_(self.fc2.weight, std=0.02)\n        nn.init.constant_(self.fc1.bias, 0.)\n        nn.init.constant_(self.fc2.bias, 0.)\n\n    def forward(self, x, kv = None,\n                self_padding_mask = None, self_attn_mask = None,\n                external_memories = None, external_padding_mask=None,\n                need_weights = False):\n        # x: seq_len x bsz x embed_dim\n        residual = x\n        if kv is None:\n            x, self_attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=self_padding_mask, attn_mask=self_attn_mask, need_weights = need_weights)\n        else:\n            x, self_attn = self.self_attn(query=x, key=kv, value=kv, key_padding_mask=self_padding_mask, attn_mask=self_attn_mask, need_weights = need_weights)\n\n        x = F.dropout(x, p=self.dropout, training=self.training)\n        x = 
self.attn_layer_norm(residual + x)\n\n        if self.with_external:\n            residual = x\n            x, external_attn = self.external_attn(query=x, key=external_memories, value=external_memories, key_padding_mask=external_padding_mask, need_weights = need_weights)\n            x = F.dropout(x, p=self.dropout, training=self.training)\n            x = self.external_layer_norm(residual + x)\n        else:\n            external_attn = None\n\n        residual = x\n        x = gelu(self.fc1(x))\n        x = F.dropout(x, p=self.dropout, training=self.training)\n        x = self.fc2(x)\n        x = F.dropout(x, p=self.dropout, training=self.training)\n        x = self.ff_layer_norm(residual + x)\n\n        return x, self_attn, external_attn\n    \n    def work_incremental(self, x, self_padding_mask = None, self_attn_mask = None,\n                         external_memories = None, external_padding_mask = None, incremental_state = None):\n        # x: seq_len x bsz x embed_dim\n        residual = x\n        x, self_attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=self_padding_mask, attn_mask=self_attn_mask, incremental_state=incremental_state)\n        x = self.attn_layer_norm(residual + x)\n\n        if self.with_external:\n            residual = x\n            x, external_attn = self.external_attn(query=x, key=external_memories, value=external_memories, key_padding_mask=external_padding_mask)\n            x = self.external_layer_norm(residual + x)\n        else:\n            external_attn = None\n        residual = x\n        x = gelu(self.fc1(x))\n        x = self.fc2(x)\n        x = self.ff_layer_norm(residual + x)\n\n        return x, self_attn, external_attn\n\nclass MultiheadAttention(nn.Module):\n\n    def __init__(self, embed_dim, num_heads, dropout=0., weights_dropout=True):\n        super(MultiheadAttention, self).__init__()\n        self.embed_dim = embed_dim\n        self.num_heads = num_heads\n        self.dropout = dropout\n        
self.head_dim = embed_dim // num_heads\n        assert self.head_dim * num_heads == self.embed_dim, \"embed_dim must be divisible by num_heads\"\n        self.scaling = self.head_dim ** -0.5\n\n        self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))\n        self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))\n\n        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)\n        self.weights_dropout = weights_dropout\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        nn.init.normal_(self.in_proj_weight, std=0.02)\n        nn.init.normal_(self.out_proj.weight, std=0.02)\n        nn.init.constant_(self.in_proj_bias, 0.)\n        nn.init.constant_(self.out_proj.bias, 0.)\n\n    def forward(self, query, key, value, key_padding_mask=None, attn_mask=None, need_weights=False, incremental_state = None):\n        \"\"\" Input shape: Time x Batch x Channel\n            key_padding_mask: Time x batch\n            attn_mask:  tgt_len x src_len\n        \"\"\"\n        if incremental_state is not None: \n            saved_state = self._get_input_buffer(incremental_state)\n            bidx = self._get_bidx(incremental_state)\n        else:\n            saved_state = None\n            bidx = None\n    \n        qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()\n        kv_same = key.data_ptr() == value.data_ptr()\n\n        tgt_len, bsz, embed_dim = query.size()\n        assert key.size() == value.size()\n\n        if qkv_same:\n            # self-attention\n            q, k, v = self.in_proj_qkv(query)\n        elif kv_same:\n            # encoder-decoder attention\n            q = self.in_proj_q(query)\n            k, v = self.in_proj_kv(key)\n        else:\n            q = self.in_proj_q(query)\n            k = self.in_proj_k(key)\n            v = self.in_proj_v(value)\n        q = q * self.scaling\n        \n        q = q.contiguous().view(tgt_len, bsz * self.num_heads, 
self.head_dim).transpose(0, 1)\n        k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)\n        v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)\n\n        if saved_state is not None:\n            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)\n            if 'prev_key' in saved_state:\n                prev_key = saved_state['prev_key']\n                if bidx is not None:\n                    prev_key = prev_key[bidx]\n                prev_key = prev_key.contiguous().view(bsz * self.num_heads, -1, self.head_dim)\n                k = torch.cat((prev_key, k), dim=1)\n            if 'prev_value' in saved_state:\n                prev_value = saved_state['prev_value']\n                if bidx is not None:\n                    prev_value = prev_value[bidx]\n                prev_value = prev_value.contiguous().view(bsz * self.num_heads, -1, self.head_dim)\n                v = torch.cat((prev_value, v), dim=1)\n            saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)\n            saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)\n            self._set_input_buffer(incremental_state, saved_state)\t\n        \n        src_len = k.size(1)\n        # k,v: bsz*heads x src_len x dim\n        # q: bsz*heads x tgt_len x dim \n\n        attn_weights = torch.bmm(q, k.transpose(1, 2))\n        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]\n\n        if attn_mask is not None:\n            attn_weights.masked_fill_(\n                attn_mask.unsqueeze(0),\n                float('-inf')\n            )\n\n        if key_padding_mask is not None:\n            # don't attend to padding symbols\n            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)\n            attn_weights.masked_fill_(\n                key_padding_mask.transpose(0, 1).unsqueeze(1).unsqueeze(2),\n                
float('-inf')\n            )\n            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)\n\n        \n        attn_weights = F.softmax(attn_weights, dim=-1)\n        \n        if self.weights_dropout:\n            attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)\n\n        attn = torch.bmm(attn_weights, v)\n        if not self.weights_dropout:\n            attn = F.dropout(attn, p=self.dropout, training=self.training)\n\n        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]\n\n        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)\n        attn = self.out_proj(attn)\n        if need_weights:\n            # maximum attention weight over heads \n            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)\n            \n            attn_weights, _ = attn_weights.max(dim=1)\n            attn_weights = attn_weights.transpose(0, 1)\n        else:\n            attn_weights = None\n\n        return attn, attn_weights\n\n    def in_proj_qkv(self, query):\n        return self._in_proj(query).chunk(3, dim=-1)\n\n    def in_proj_kv(self, key):\n        return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)\n\n    def in_proj_q(self, query):\n        return self._in_proj(query, end=self.embed_dim)\n\n    def in_proj_k(self, key):\n        return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)\n\n    def in_proj_v(self, value):\n        return self._in_proj(value, start=2 * self.embed_dim)\n\n    def _in_proj(self, input, start=0, end=None):\n        weight = self.in_proj_weight\n        bias = self.in_proj_bias\n        weight = weight[start:end, :]\n        if bias is not None:\n            bias = bias[start:end]\n        return F.linear(input, weight, bias)\n\n    def _get_input_buffer(self, incremental_state):\n       return get_incremental_state(\n                self,\n                incremental_state,\n            
    'attn_state',\n                ) or {}\n\n    def _set_input_buffer(self, incremental_state, buffer):\n        set_incremental_state(\n                self,\n                incremental_state,\n                'attn_state',\n                buffer,)\n\n    def _get_bidx(self, incremental_state):\n        if \"bidx\" in incremental_state:\n            return incremental_state[\"bidx\"]\n        else:\n            return None\n\ndef Embedding(num_embeddings, embedding_dim, padding_idx):\n    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)\n    nn.init.normal_(m.weight, std=0.02)\n    nn.init.constant_(m.weight[padding_idx], 0)\n    return m\n\nclass SelfAttentionMask(nn.Module):\n    def __init__(self, init_size = 100, device = 0):\n        super(SelfAttentionMask, self).__init__()\n        self.weights = SelfAttentionMask.get_mask(init_size)\n        self.device = device\n    \n    @staticmethod\n    def get_mask(size):\n        weights = torch.triu(torch.ones((size, size), dtype = torch.bool), 1)\n        return weights\n\n    def forward(self, size):\n        if self.weights is None or size > self.weights.size(0):\n            self.weights = SelfAttentionMask.get_mask(size)\n        res = self.weights[:size,:size].cuda(self.device).detach()\n        return res\n\nclass LearnedPositionalEmbedding(nn.Module):\n    \"\"\"This module produces LearnedPositionalEmbedding.\n    \"\"\"\n    def __init__(self, embedding_dim, init_size=1024, device=0):\n        super(LearnedPositionalEmbedding, self).__init__()\n        self.weights = nn.Embedding(init_size, embedding_dim)\n        self.device= device\n        self.reset_parameters()\n    \n    def reset_parameters(self):\n        nn.init.normal_(self.weights.weight, std=0.02)\n\n    def forward(self, input, offset=0):\n        \"\"\"Input is expected to be of size [seq_len x bsz].\"\"\"\n        seq_len, bsz = input.size()\n        positions = (offset + 
torch.arange(seq_len)).cuda(self.device)\n        res = self.weights(positions).unsqueeze(1).expand(-1, bsz, -1)\n        return res\n\nclass SinusoidalPositionalEmbedding(nn.Module):\n    \"\"\"This module produces sinusoidal positional embeddings of any length.\n    \"\"\"\n    def __init__(self, embedding_dim, init_size=1024, device=0):\n        super(SinusoidalPositionalEmbedding, self).__init__()\n        self.embedding_dim = embedding_dim\n        self.weights = SinusoidalPositionalEmbedding.get_embedding(\n            init_size,\n            embedding_dim\n        )\n        self.device= device\n\n    @staticmethod\n    def get_embedding(num_embeddings, embedding_dim):\n        \"\"\"Build sinusoidal embeddings.\n        This matches the implementation in tensor2tensor, but differs slightly\n        from the description in Section 3.5 of \"Attention Is All You Need\".\n        \"\"\"\n        half_dim = embedding_dim // 2\n        emb = math.log(10000) / (half_dim - 1)\n        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)\n        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)\n        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)\n        if embedding_dim % 2 == 1:\n            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)\n        return emb\n\n    def forward(self, input, offset=0):\n        \"\"\"Input is expected to be of size [seq_len x bsz].\"\"\"\n        seq_len, bsz = input.size()\n        mx_position = seq_len + offset\n        if self.weights is None or mx_position > self.weights.size(0):\n            # recompute/expand embeddings if needed\n            self.weights = SinusoidalPositionalEmbedding.get_embedding(\n                mx_position,\n                self.embedding_dim,\n            )\n\n        positions = offset + torch.arange(seq_len)\n        res = self.weights.index_select(0, positions).unsqueeze(1).expand(-1, bsz, 
-1).cuda(self.device).detach()\n        return res\n"
  },
  {
    "path": "utils.py",
    "content": "import torch\nfrom torch import nn\nfrom torch.nn import Parameter\nfrom collections import defaultdict\n\nimport math\n\ndef gelu(x):\n    cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))\n    return cdf*x\n\nclass LayerNorm(nn.Module):\n    def __init__(self, hidden_size, eps=1e-12):\n        super(LayerNorm, self).__init__()\n        self.weight = nn.Parameter(torch.Tensor(hidden_size))\n        self.bias = nn.Parameter(torch.Tensor(hidden_size))\n        self.eps = eps\n        self.reset_parameters()\n    def reset_parameters(self):\n        nn.init.constant_(self.weight, 1.)\n        nn.init.constant_(self.bias, 0.)\n\n    def forward(self, x):\n        u = x.mean(-1, keepdim=True)\n        s = (x - u).pow(2).mean(-1, keepdim=True)\n        x = (x - u) / torch.sqrt(s + self.eps)\n        return self.weight * x + self.bias\n\n\nINCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)\n\ndef _get_full_incremental_state_key(module_instance, key):\n    module_name = module_instance.__class__.__name__\n\n    # assign a unique ID to each module instance, so that incremental state is\n    # not shared across module instances\n    if not hasattr(module_instance, '_guyu_instance_id'):\n        INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1\n        module_instance._guyu_instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]\n\n    return '{}.{}.{}'.format(module_name, module_instance._guyu_instance_id, key)\n\ndef get_incremental_state(module, incremental_state, key):\n    \"\"\"Helper for getting incremental state for an nn.Module.\"\"\"\n    full_key = _get_full_incremental_state_key(module, key)\n    if incremental_state is None or full_key not in incremental_state:\n        return None\n    return incremental_state[full_key]\n\ndef set_incremental_state(module, incremental_state, key, value):\n    \"\"\"Helper for setting incremental state for an nn.Module.\"\"\"\n    if incremental_state is not None:\n        full_key = 
_get_full_incremental_state_key(module, key)\n        incremental_state[full_key] = value\n"
  }
]