Full Code of lipiji/SongNet for AI

Repository: lipiji/SongNet
Branch: master
Commit: f8c7064c21c7
Files: 20
Total size: 75.1 KB

Directory structure:
gitextract_oqyy3n7z/

├── .gitignore
├── LICENSE
├── README.md
├── adam.py
├── biglm.py
├── data.py
├── eval.py
├── eval.sh
├── label_smoothing.py
├── metrics.py
├── optim.py
├── polish.py
├── polish.sh
├── prepare_data.py
├── test.py
├── test.sh
├── train.py
├── train.sh
├── transformer.py
└── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
*.pyc
*.log
ckpt
/data*/*
/model*/*
/ckpt*/*
/result*/*


================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2021 Piji Li

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# SongNet
SongNet: SongCi + Song (Lyrics) + Sonnet + etc.

```
@inproceedings{li-etal-2020-rigid,
    title = "Rigid Formats Controlled Text Generation",
    author = "Li, Piji and Zhang, Haisong and Liu, Xiaojiang and Shi, Shuming",
    booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
    month = jul,
    year = "2020",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/2020.acl-main.68",
    doi = "10.18653/v1/2020.acl-main.68",
    pages = "742--751"
}
```

### Run
- python prepare_data.py (expects ./data/ci.txt and ./data/cipai.txt; see the data format sketch below)
- ./train.sh 
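
A hypothetical line of ./data/ci.txt (the format is inferred from prepare_data.py and data.parse_line; the poem excerpt is only an illustration):

```
苏轼<s1>水调歌头<s2>明月几时有,把酒问青天。</s>不知天上宫阙,今夕是何年。
```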

### Evaluation
- Modify test.py: set m_path to the checkpoint that performs best on the dev set
- ./test.sh
- python metrics.py

### Polish
- ./polish.sh (reads ./data/polish_tpl.txt; see the template sketch below)
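
A hypothetical line of ./data/polish_tpl.txt (format inferred from polish.py and data.parse_line_polish): each "_" marks a character to be rewritten, all other characters are kept fixed by the template.

```
苏轼<s1>水调歌头<s2>明月几时有,把酒问__。</s>不知天上宫阙,今夕是何年。
```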

### Download
- The pretrained Chinese Language Model: https://drive.google.com/file/d/1g2tGyUwPe86vPn2nub1vkQva5lwtZ6Rd/view 
- The finetuned SongCi model: https://drive.google.com/file/d/16A2AzuU7slf7xj2QdLcBAorUCCaCk650/view

#### Reference

- Guyu: https://github.com/lipiji/Guyu
- Pretraining: https://github.com/lipiji/big_tpl_zh_10_base


================================================
FILE: adam.py
================================================
# coding=utf-8
import torch
from torch.optim import Optimizer

class AdamWeightDecayOptimizer(Optimizer):
    """A basic Adam optimizer that includes "correct" L2 weight decay.
    https://github.com/google-research/bert/blob/master/optimization.py
    https://raw.githubusercontent.com/pytorch/pytorch/v1.0.0/torch/optim/adam.py"""
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(AdamWeightDecayOptimizer, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamWeightDecayOptimizer, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                update = (exp_avg / denom).add_(p.data, alpha=group['weight_decay'])
                p.data.add_(update, alpha=-group['lr'])
        return loss

================================================
FILE: biglm.py
================================================
import torch
from torch import nn
import torch.nn.functional as F

from utils import gelu, LayerNorm
from transformer import TransformerLayer, Embedding, LearnedPositionalEmbedding, SelfAttentionMask
from label_smoothing import LabelSmoothing 

class BIGLM(nn.Module):
    def __init__(self, local_rank, vocab, embed_dim, ff_embed_dim, num_heads, dropout, layers, smoothing_factor, approx=None):
        super(BIGLM, self).__init__()
        self.vocab = vocab
        self.embed_dim = embed_dim

        self.tok_embed = Embedding(self.vocab.size, embed_dim, self.vocab.padding_idx)
        self.pos_embed = LearnedPositionalEmbedding(embed_dim, device=local_rank)
        
        self.layers = nn.ModuleList()
        for i in range(layers):
            self.layers.append(TransformerLayer(embed_dim, ff_embed_dim, num_heads, dropout, with_external=True))
        self.emb_layer_norm = LayerNorm(embed_dim)
        self.one_more = nn.Linear(embed_dim, embed_dim)
        self.one_more_layer_norm = LayerNorm(embed_dim)
        self.out_proj = nn.Linear(embed_dim, self.vocab.size)
        
        self.attn_mask = SelfAttentionMask(device=local_rank)
        self.smoothing = LabelSmoothing(local_rank, self.vocab.size, self.vocab.padding_idx, smoothing_factor)
       
        self.dropout = dropout
        self.device = local_rank

        self.approx = approx
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.constant_(self.one_more.bias, 0.)
        nn.init.normal_(self.one_more.weight, std=0.02)
        nn.init.constant_(self.out_proj.bias, 0.)
        nn.init.normal_(self.out_proj.weight, std=0.02)
    
    def label_smoothing_loss(self, y_pred, y, y_mask, avg=True):
        seq_len, bsz = y.size()

        y_pred = torch.log(y_pred.clamp(min=1e-8))
        loss = self.smoothing(y_pred.view(seq_len * bsz, -1), y.view(seq_len * bsz, -1))
        if avg:
            return loss / torch.sum(y_mask)
        else:
            return loss / bsz

    def nll_loss(self, y_pred, y, y_mask, avg=True):
        cost = -torch.log(torch.gather(y_pred, 2, y.view(y.size(0), y.size(1), 1)))
        cost = cost.view(y.shape)
        y_mask = y_mask.view(y.shape)
        if avg:
            cost = torch.sum(cost * y_mask, 0) / torch.sum(y_mask, 0)
        else:
            cost = torch.sum(cost * y_mask, 0)
        cost = cost.view((y.size(1), -1))
        ppl = 2 ** cost
        return cost.sum().item(), ppl.sum().item()

    
    def work_incremental(self, enc, src_padding_mask, ys_inp, ys_tpl, ys_seg, ys_pos, incremental_state=None):
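        # Incremental decoding: on the first call (incremental_state is None) the whole
        # prefix is processed with a causal self-attention mask; on later calls only the
        # newest step is fed through and the cached keys/values are reused.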
        seq_len, bsz = ys_inp.size()
        x = self.tok_embed(ys_inp) + self.pos_embed(ys_inp) + self.tok_embed(ys_tpl) + self.tok_embed(ys_seg) + self.tok_embed(ys_pos)
        x = self.emb_layer_norm(x)
        padding_mask = torch.eq(ys_inp, self.vocab.padding_idx)
        if not padding_mask.any():
            padding_mask = None

        if incremental_state is None:
            self_attn_mask = self.attn_mask(seq_len)
            incremental_state = {}
        else:
            x = x[-1, :, :].unsqueeze(0)
            self_attn_mask = None

        for layer in self.layers:
            x, _ ,_ = layer.work_incremental(x, self_padding_mask=padding_mask, \
                                             self_attn_mask=self_attn_mask, \
                                             external_memories = enc, \
                                             external_padding_mask = src_padding_mask, \
                                             incremental_state = incremental_state)

        x = self.one_more_layer_norm(gelu(self.one_more(x)))
        probs = torch.softmax(self.out_proj(x), -1)

        _, pred_y = probs.max(-1)
        return probs, pred_y, incremental_state
 
    def work(self, enc, src_padding_mask, ys_inp, ys_tpl, ys_seg, ys_pos):
        seq_len, bsz = ys_inp.size()
        self_attn_mask = self.attn_mask(seq_len)
        x = self.tok_embed(ys_inp) + self.pos_embed(ys_inp) + self.tok_embed(ys_tpl) + self.tok_embed(ys_seg) + self.tok_embed(ys_pos)
        x = self.emb_layer_norm(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        padding_mask = torch.eq(ys_inp, self.vocab.padding_idx)
        if not padding_mask.any():
            padding_mask = None
        for layer in self.layers:
            x, _ ,_ = layer(x, self_padding_mask=padding_mask, \
                               self_attn_mask = self_attn_mask, \
                               external_memories = enc, \
                               external_padding_mask = src_padding_mask,)

        x = self.one_more_layer_norm(gelu(self.one_more(x)))
        probs = torch.softmax(self.out_proj(x), -1)
        
        _, pred_y = probs.max(-1)
        
        return probs, pred_y
    
    def encode(self, xs_tpl, xs_seg, xs_pos):
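        # Embed and normalize the global format context (xs_tpl/xs_seg/xs_pos); the
        # decoder layers attend to it as external memory.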
        padding_mask = torch.eq(xs_tpl, self.vocab.padding_idx)
        x = self.tok_embed(xs_tpl)  + self.tok_embed(xs_seg) + self.tok_embed(xs_pos)
        x = self.emb_layer_norm(x)
        return x, padding_mask
    
    def ppl(self, xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk):
        enc, src_padding_mask = self.encode(xs_tpl, xs_seg, xs_pos)
        seq_len, bsz = ys_inp.size()
        self_attn_mask = self.attn_mask(seq_len)
        x = self.tok_embed(ys_inp) + self.pos_embed(ys_inp) + self.tok_embed(ys_tpl) + self.tok_embed(ys_seg) + self.tok_embed(ys_pos)
        x = self.emb_layer_norm(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        padding_mask = torch.eq(ys_truth, self.vocab.padding_idx)
        if not padding_mask.any():
            padding_mask = None
        for layer in self.layers:
            x, _ ,_ = layer(x, self_padding_mask=padding_mask, \
                               self_attn_mask = self_attn_mask, \
                               external_memories = enc, \
                               external_padding_mask = src_padding_mask,)

        x = self.one_more_layer_norm(gelu(self.one_more(x)))
        pred = torch.softmax(self.out_proj(x), -1)
        nll, ppl = self.nll_loss(pred, ys_truth, msk)
        return nll, ppl, bsz
    
    def forward(self, xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk):
        enc, src_padding_mask = self.encode(xs_tpl, xs_seg, xs_pos)
        seq_len, bsz = ys_inp.size()
        self_attn_mask = self.attn_mask(seq_len)
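        # Decoder input: token and learned positional embeddings of the shifted targets,
        # summed with embeddings of the template, segment and intra-segment position symbols.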
        x = self.tok_embed(ys_inp) + self.pos_embed(ys_inp) + self.tok_embed(ys_tpl) + self.tok_embed(ys_seg) + self.tok_embed(ys_pos)
        x = self.emb_layer_norm(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        padding_mask = torch.eq(ys_truth, self.vocab.padding_idx)
        if not padding_mask.any():
            padding_mask = None
        for layer in self.layers:
            x, _ ,_ = layer(x, self_padding_mask=padding_mask, \
                               self_attn_mask = self_attn_mask, \
                               external_memories = enc, \
                               external_padding_mask = src_padding_mask,)

        x = self.one_more_layer_norm(gelu(self.one_more(x)))
        pred = torch.softmax(self.out_proj(x), -1)

        loss = self.label_smoothing_loss(pred, ys_truth, msk)
        
        _, pred_y = pred.max(-1)
        tot_tokens = msk.float().sum().item()
        acc = (torch.eq(pred_y, ys_truth).float() * msk).sum().item()
       
        nll, ppl = self.nll_loss(pred, ys_truth, msk)
        return (pred_y, ys_truth), loss, acc, nll, ppl, tot_tokens, bsz


================================================
FILE: data.py
================================================
import random
import torch
import numpy as np

PAD, UNK, BOS, EOS = '<pad>', '<unk>', '<bos>', '<eos>'
BOC, EOC = '<boc>', '<eoc>'
LS, RS, SP = '<s>', '</s>', ' '
CS = ['<c-1>'] + ['<c' + str(i) + '>' for i in range(32)] # content
SS = ['<s-1>'] + ['<s' + str(i) + '>' for i in range(512)] # segment
PS = ['<p-1>'] + ['<p' + str(i) + '>' for i in range(512)] # position
TS = ['<t-1>'] + ['<t' + str(i) + '>' for i in range(32)] # other types

PUNCS = set([",", ".", "?", "!", ":", ",", "。", "?", "!", ":"])

BUFSIZE = 4096000

def ListsToTensor(xs, vocab=None):
    max_len = max(len(x) for x in xs)
    ys = []
    for x in xs:
        if vocab is not None:
            y = vocab.token2idx(x) + [vocab.padding_idx]*(max_len - len(x))
        else:
            y = x + [0]*(max_len - len(x))
        ys.append(y)
    return ys

def _back_to_text_for_check(x, vocab):
    w = x.t().tolist()
    for sent in vocab.idx2token(w):
        print (' '.join(sent))
    
def batchify(data, vocab):
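    # Pad every field to the longest example in the batch and convert to
    # (seq_len, batch)-shaped tensors (the .t_() calls below transpose the batch-major lists).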
    xs_tpl, xs_seg, xs_pos, \
    ys_truth, ys_inp, \
    ys_tpl, ys_seg, ys_pos, msk = [], [], [], [], [], [], [], [], []
    for xs_tpl_i, xs_seg_i, xs_pos_i, ys_i, ys_tpl_i, ys_seg_i, ys_pos_i in data:
        xs_tpl.append(xs_tpl_i)
        xs_seg.append(xs_seg_i)
        xs_pos.append(xs_pos_i)
        
        ys_truth.append(ys_i)
        ys_inp.append([BOS] + ys_i[:-1])
        ys_tpl.append(ys_tpl_i)
        ys_seg.append(ys_seg_i)
        ys_pos.append(ys_pos_i)
        
        msk.append([1 for i in range(len(ys_i))])

    xs_tpl = torch.LongTensor(ListsToTensor(xs_tpl, vocab)).t_().contiguous()
    xs_seg = torch.LongTensor(ListsToTensor(xs_seg, vocab)).t_().contiguous()
    xs_pos = torch.LongTensor(ListsToTensor(xs_pos, vocab)).t_().contiguous()
    ys_truth = torch.LongTensor(ListsToTensor(ys_truth, vocab)).t_().contiguous()
    ys_inp = torch.LongTensor(ListsToTensor(ys_inp, vocab)).t_().contiguous()
    ys_tpl = torch.LongTensor(ListsToTensor(ys_tpl, vocab)).t_().contiguous()
    ys_seg = torch.LongTensor(ListsToTensor(ys_seg, vocab)).t_().contiguous()
    ys_pos = torch.LongTensor(ListsToTensor(ys_pos, vocab)).t_().contiguous()
    msk = torch.FloatTensor(ListsToTensor(msk)).t_().contiguous()
    return xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk

def s2t(strs, vocab):
    inp, msk = [], []
    for x in strs:
        inp.append(x)
        msk.append([1 for i in range(len(x))])

    inp = torch.LongTensor(ListsToTensor(inp, vocab)).t_().contiguous()
    msk = torch.FloatTensor(ListsToTensor(msk)).t_().contiguous()
    return inp, msk

def s2xy(lines, vocab, max_len, min_len):
    data = []
    for line in lines:
        res = parse_line(line, max_len, min_len)
        if not res:
            continue
        data.append(res)
    return  batchify(data, vocab)

def parse_line(line, max_len, min_len):
    line = line.strip()
    if not line:
        return None
    fs = line.split("<s2>")
    author, cipai = fs[0].split("<s1>")
    sents = fs[1].strip()
    if len(sents) > max_len:
        sents = sents[:max_len]
    if len(sents) < min_len:
        return None
    sents = sents.split("</s>")

    ys = []
    xs_tpl = []
    xs_seg = []
    xs_pos = []

    ctx = cipai
    ws = [w for w in ctx]
    xs_tpl = ws + [EOC]
    xs_seg = [SS[0] for w in ws] + [EOC]
    xs_pos = [SS[i+300] for i in range(len(ws))] + [EOC]

    ys_tpl = []
    ys_seg = []
    ys_pos = []
    for si, sent in enumerate(sents):
        ws = []
        sent = sent.strip()
        if not sent:
            continue
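        # Template construction: ordinary content characters become the generic slot
        # CS[2] ('<c1>'), punctuation becomes CS[1] ('<c0>'), and the sentence-final
        # content character is later overwritten with CS[3] ('<c2>'), the rhyme position.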
        for w in sent:
            ws.append(w)
            if w.strip() and w not in PUNCS:
                ys_tpl.append(CS[2])
            else:
                ys_tpl.append(CS[1])
        ys += ws + [RS]
        if ws[-1] in PUNCS:
            ys_tpl[-2] = CS[3]
        else:
            ys_tpl[-1] = CS[3]
        ys_tpl += [RS]
        ys_seg += [SS[si + 1] for w in ws] + [RS]
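        # Intra-segment positions count down to the sentence end, so every slot encodes
        # its distance from the (rhyming) final position.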
        ys_pos += [PS[len(ws) - i] for i in range(len(ws))] + [RS]

    ys += [EOS]
    ys_tpl += [EOS]
    ys_seg += [EOS]
    ys_pos += [EOS]

    xs_tpl += ys_tpl
    xs_seg += ys_seg
    xs_pos += ys_pos
    
    if len(ys) < min_len:
        return None
    return xs_tpl, xs_seg, xs_pos, ys, ys_tpl, ys_seg, ys_pos

def s2xy_polish(lines, vocab, max_len, min_len):
    data = []
    for line in lines:
        res = parse_line_polish(line, max_len, min_len)
        if not res:
            continue
        data.append(res)
    return batchify(data, vocab)

def parse_line_polish(line, max_len, min_len):
    line = line.strip()
    if not line:
        return None
    fs = line.split("<s2>")
    author, cipai = fs[0].split("<s1>")
    sents = fs[1].strip()
    if len(sents) > max_len:
        sents = sents[:max_len]
    if len(sents) < min_len:
        return None
    sents = sents.split("</s>")

    ys = []
    xs_tpl = []
    xs_seg = []
    xs_pos = []

    ctx = cipai
    ws = [w for w in ctx]
    xs_tpl = ws + [EOC]
    xs_seg = [SS[0] for w in ws] + [EOC]
    xs_pos = [SS[i+300] for i in range(len(ws))] + [EOC]

    ys_tpl = []
    ys_seg = []
    ys_pos = []
    for si, sent in enumerate(sents):
        ws = []
        sent = sent.strip()
        if not sent:
            continue
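        # Polishing template: "_" marks a slot to regenerate (CS[2], i.e. '<c1>');
        # every other character is kept verbatim in the template and will be copied.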
        for w in sent:
            ws.append(w)
            if w == "_":
                ys_tpl.append(CS[2])
            else:
                ys_tpl.append(w)
        ys += ws + [RS]
        ys_tpl += [RS]
        ys_seg += [SS[si + 1] for w in ws] + [RS]
        ys_pos += [PS[len(ws) - i] for i in range(len(ws))] + [RS]

    ys += [EOS]
    ys_tpl += [EOS]
    ys_seg += [EOS]
    ys_pos += [EOS]

    xs_tpl += ys_tpl
    xs_seg += ys_seg
    xs_pos += ys_pos
    
    if len(ys) < min_len:
        return None

    return xs_tpl, xs_seg, xs_pos, ys, ys_tpl, ys_seg, ys_pos

class DataLoader(object):
    def __init__(self, vocab, filename, batch_size, max_len_y, min_len_y):
        self.batch_size = batch_size
        self.vocab = vocab
        self.max_len_y = max_len_y
        self.min_len_y = min_len_y
        self.filename = filename
        self.stream = open(self.filename, encoding='utf8')
        self.epoch_id = 0

    def __iter__(self):
        
        lines = self.stream.readlines(BUFSIZE)

        if not lines:
            self.epoch_id += 1
            self.stream.close()
            self.stream = open(self.filename, encoding='utf8')
            lines = self.stream.readlines(BUFSIZE)

        data = []
        for line in lines[:-1]: # the last line may be incomplete
            res = parse_line(line, self.max_len_y, self.min_len_y)
            if not res:
                continue
            data.append(res)
        
        random.shuffle(data)
        
        idx = 0
        while idx < len(data):
            yield batchify(data[idx:idx+self.batch_size], self.vocab)
            idx += self.batch_size

class Vocab(object):
    def __init__(self, filename, min_occur_cnt, specials = None):
        idx2token = [PAD, UNK, BOS, EOS] + [BOC, EOC, LS, RS, SP] + CS + SS + PS + TS \
                    +  (specials if specials is not None else [])
        for line in open(filename, encoding='utf8').readlines():
            try: 
                token, cnt = line.strip().split()
            except:
                continue
            if int(cnt) >= min_occur_cnt:
                idx2token.append(token)
        self._token2idx = dict(zip(idx2token, range(len(idx2token))))
        self._idx2token = idx2token
        self._padding_idx = self._token2idx[PAD]
        self._unk_idx = self._token2idx[UNK]

    @property
    def size(self):
        return len(self._idx2token)
    
    @property
    def unk_idx(self):
        return self._unk_idx
    
    @property
    def padding_idx(self):
        return self._padding_idx
    
    def random_token(self):
        return self.idx2token(1 + np.random.randint(self.size-1))

    def idx2token(self, x):
        if isinstance(x, list):
            return [self.idx2token(i) for i in x]
        return self._idx2token[x]

    def token2idx(self, x):
        if isinstance(x, list):
            return [self.token2idx(i) for i in x]
        return self._token2idx.get(x, self.unk_idx)


================================================
FILE: eval.py
================================================
import sys
import torch
from torch import nn
import torch.nn.functional as F
import random
import numpy as np
import copy 
import time

from biglm import BIGLM
from data import Vocab, DataLoader, s2t, s2xy

gpu = int(sys.argv[2]) if len(sys.argv) > 2 else 0
def init_model(m_path, device, vocab):
    ckpt= torch.load(m_path, map_location='cpu')
    lm_args = ckpt['args']
    lm_vocab = Vocab(vocab, min_occur_cnt=lm_args.min_occur_cnt, specials=[])
    lm_model = BIGLM(device, lm_vocab, lm_args.embed_dim, lm_args.ff_embed_dim, lm_args.num_heads, lm_args.dropout, lm_args.layers, 0.1, lm_args.approx)
    lm_model.load_state_dict(ckpt['model'])
    lm_model = lm_model.cuda(device)
    lm_model.eval()
    return lm_model, lm_vocab, lm_args

#m_path = "./ckpt_d101_6/epoch5_batch_139999"
m_path = sys.argv[1] if len(sys.argv) > 1 else None
lm_model, lm_vocab, lm_args = init_model(m_path, gpu, "./data/vocab.txt")


ds = []
with open("./data/dev.txt", "r") as f:
    for line in f:
        line = line.strip()
        if line:
            ds.append(line)
print(len(ds))

local_rank = gpu
batch_size = 10
batches = round(len(ds) / batch_size)
idx = 0

avg_nll = 0.
avg_ppl = 0.
count = 0.
while idx < len(ds):
    
    cplb = ds[idx:idx + batch_size]
    xs_tpl, xs_seg, xs_pos, \
    ys_truth, ys_inp, \
    ys_tpl, ys_seg, ys_pos, msk = s2xy(cplb, lm_vocab, lm_args.max_len, 2)

    xs_tpl = xs_tpl.cuda(local_rank)
    xs_seg = xs_seg.cuda(local_rank)
    xs_pos = xs_pos.cuda(local_rank)
    ys_truth = ys_truth.cuda(local_rank)
    ys_inp = ys_inp.cuda(local_rank)
    ys_tpl = ys_tpl.cuda(local_rank)
    ys_seg = ys_seg.cuda(local_rank)
    ys_pos = ys_pos.cuda(local_rank)
    msk = msk.cuda(local_rank)

    nll, ppl, bsz = lm_model.ppl(xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk)
    
    avg_nll += nll
    avg_ppl += ppl
    count += bsz

    idx += batch_size
    if count % 200 == 0:
        print("nll=", avg_nll/count, "ppl=", avg_ppl/count, "count=", count)
    
print("nll=", avg_nll/count, "ppl=", avg_ppl/count, "count=", count)


================================================
FILE: eval.sh
================================================
#!/bin/bash
path=./ckpt/
FILES=$path/*
for f in $FILES; do
    echo "==========================" ${f##*/}
    python -u eval.py $path${f##*/} 1
done


================================================
FILE: label_smoothing.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F

class LabelSmoothing(nn.Module):
    "Implement label smoothing."
    def __init__(self, device, size, padding_idx, label_smoothing=0.0):
        super(LabelSmoothing, self).__init__()
        assert 0.0 < label_smoothing <= 1.0
        self.padding_idx = padding_idx
        self.size = size
        self.device = device

        self.smoothing_value = label_smoothing / (size - 2)
        self.one_hot = torch.full((1, size), self.smoothing_value).to(device)
        self.one_hot[0, self.padding_idx] = 0
        
        self.confidence = 1.0 - label_smoothing

    def forward(self, output, target):
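        # Build the smoothed target distribution (confidence on the gold token,
        # smoothing_value elsewhere, zero on padding) and return the summed KL
        # divergence against the model's log-probabilities.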
        real_size = output.size(1)
        if real_size > self.size:
            real_size -= self.size
        else:
            real_size = 0

        model_prob = self.one_hot.repeat(target.size(0), 1)
        if real_size > 0:
            ext_zeros = torch.full((model_prob.size(0), real_size), self.smoothing_value).to(self.device)
            model_prob = torch.cat((model_prob, ext_zeros), -1)
        model_prob.scatter_(1, target, self.confidence)
        model_prob.masked_fill_((target == self.padding_idx), 0.)

        return F.kl_div(output, model_prob, reduction='sum')


================================================
FILE: metrics.py
================================================
import os
import sys
import numpy as np
from pypinyin import Style, lazy_pinyin

from data import PUNCS

yunjiaos = {
            "0":["a", "ia", "ua", "va", "üa"],
            "1":["e", "o", "uo", "ie", "ue", "üe", "ve"],
            "2":["u"],
            "3":["i", "ü", "v"],
            "4":["ai", "uai"],
            "5":["ao", "iao"],
            "6":["ou", "iu", "iou"],
            "7":["an", "ian", "uan", "üan", "van"],
            "8":["en", "in", "un", "ün", "vn"],
            "9":["ang", "iang", "uang"],
            "10":["eng", "ing", "ueng", "ong", "iong"],
            "11":["er"],
            "12":["ei", "ui", "uei", "vei"],
           }

yun2id = {}
for yid, yws in yunjiaos.items():
    for w in yws:
        yun2id[w] = yid

def eval_tpl(sents1, sents2):
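    # Format match: a generated sentence (sents2) counts as correct when its length and
    # punctuation pattern equal the reference sentence (sents1) at the same index;
    # returns precision, recall, F1 and the raw counts.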
    n = 0.
    if len(sents1) > len(sents2):
        sents1 = sents1[:len(sents2)]
    for i, x in enumerate(sents1):
        y = sents2[i]
        if len(x) != len(y):
            continue
        px, py = [], []
        for w in x:
            if w in PUNCS:
                px.append(w)
        for w in y:
            if w in PUNCS:
                py.append(w)
        if px == py:
            n += 1
    p = n / len(sents2)
    r = n / len(sents1)
    f = 2 * p * r / (p + r + 1e-16)

    return p, r, f, n, len(sents1), len(sents2)


def rhythm_labelling(sents):
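    # Take the pinyin final of each sentence's last non-punctuation character, group the
    # sentences by rhyme class, and label members of the largest class with 1, others with -1.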
    rhys = []
    for sent in sents:
        w = sent[-1]
        if w in PUNCS and len(sent) > 1:
            w = sent[-2]
        yunmu = lazy_pinyin(w, style=Style.FINALS)
        rhys.append(yunmu[0])
    assert len(rhys) == len(sents)
    rhy_map = {}
    for i, r in enumerate(rhys):
        if r in yun2id:
            rid = yun2id[r]
            if rid in rhy_map:
                rhy_map[rid] += [i]
            else:
                rhy_map[rid] = [i]
        else:
            pass
    max_len_yuns = -1
    max_rid = ""
    for rid, yuns in rhy_map.items():
        if len(yuns) > max_len_yuns:
            max_len_yuns = len(yuns)
            max_rid = rid
    res = []
    for i in range(len(sents)):
        if max_rid in rhy_map and i in rhy_map[max_rid]:
            res.append(1)
        else:
            res.append(-1)
    return res

def eval_rhythm(sents1, sents2):
    n = 0.
    if len(sents1) > len(sents2):
        sents1 = sents1[:len(sents2)]
    rhys1 = rhythm_labelling(sents1)
    rhys2 = rhythm_labelling(sents2)
    
    n1, n2 = 0., 0.
    for v in rhys1:
        if v == 1:
            n1 += 1
    for v in rhys2:
        if v == 1:
            n2 += 1
    for i, v1 in enumerate(rhys1):
        v2 = rhys2[i]
        if v1 == 1 and v1 == v2:
            n += 1
    p = n / (n2 + 1e-16)
    r = n / (n1 + 1e-16)
    f1 = 2 * p * r / (p + r + 1e-16)
    return p, r, f1, n, n1, n2

def eval_endings(sents1, sents2):
    n = 0.
    if len(sents1) > len(sents2):
        sents1 = sents1[:len(sents2)]
   
    sents0 = []
    for si, sent1 in enumerate(sents1):
        sent2 = sents2[si]
        if len(sent2) <= len(sent1):
            sents0.append(sent2)
        else:
            sents0.append(sent2[:len(sent1) - 1] + sent1[-1])

    sent = "</s>".join(sents0)
    return sent


def eval(res_file, fid):
    docs = []
    with open(res_file) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            fs = line.split("\t")
            if len(fs) != 2:
                print("error", line)
                continue
            x, y = fs
            docs.append((x, y))


    print(len(docs))

    ugrams_ = []
    bigrams_ = []
    p_, r_, f1_ = 0., 0., 0.
    n0_, n1_, n2_ = 0., 0., 0.

    p__, r__, f1__ = 0., 0., 0.
    n0__, n1__, n2__ = 0., 0., 0.
    d1_, d2_ = 0., 0.
    d4ends = []

    for x, y in docs:
        topic, content = x.split("<s2>")
        author, topic = topic.split("<s1>")
        sents1 = content.split("</s>")
        y = y.replace("<bos>", "")
        sents2 = y.split("</s>")
        sents1_ = []
        for sent in sents1:
            sent = sent.strip()
            if sent:
                sents1_.append(sent)
        sents1 = sents1_
        sents2_ = []
        for sent in sents2:
            sent = sent.strip()
            if sent:
                sents2_.append(sent)
        sents2 = sents2_

        p, r, f1, n0, n1, n2 = eval_tpl(sents1, sents2)
        p_ += p
        r_ += r
        f1_ += f1
        n0_ += n0
        n1_ += n1
        n2_ += n2

        ugrams = [w for w in ''.join(sents2)]
        bigrams = []
        for bi in range(len(ugrams) - 1):
            bigrams.append(ugrams[bi] + ugrams[bi+1])
        d1_ += len(set(ugrams)) / float(len(ugrams))
        d2_ += len(set(bigrams)) / float(len(bigrams))
        ugrams_ += ugrams
        bigrams_ += bigrams

        p, r, f1, n0, n1, n2 = eval_rhythm(sents1, sents2)
        p__ += p
        r__ += r
        f1__ += f1
        n0__ += n0
        n1__ += n1
        n2__ += n2

        d4end = eval_endings(sents1, sents2)
        d4ends.append(author + "<s1>" + topic + "<s2>" + d4end)

    tpl_macro_p = p_ / len(docs)
    tpl_macro_r = r_ / len(docs)
    tpl_macro_f1 = 2 * tpl_macro_p * tpl_macro_r / (tpl_macro_p + tpl_macro_r)
    tpl_micro_p = n0_ / n2_
    tpl_micro_r = n0_ / n1_
    tpl_micro_f1 = 2 * tpl_micro_p * tpl_micro_r / (tpl_micro_p + tpl_micro_r)
    
    rhy_macro_p = p__ / len(docs)
    rhy_macro_r = r__ / len(docs)
    rhy_macro_f1 = 2 * rhy_macro_p * rhy_macro_r / (rhy_macro_p + rhy_macro_r)
    rhy_micro_p = n0__ / n2__
    rhy_micro_r = n0__ / n1__
    rhy_micro_f1 = 2 * rhy_micro_p * rhy_micro_r / (rhy_micro_p + rhy_micro_r)
    

    macro_dist1 = d1_ / len(docs)
    macro_dist2 = d2_ / len(docs)
    micro_dist1 = len(set(ugrams_)) / float(len(ugrams_))
    micro_dist2 = len(set(bigrams_)) / float(len(bigrams_))

    with open("./results_4ending/res4end" + str(fid) + ".txt", "w") as fo:
        for line in d4ends:
            fo.write(line + "\n")
    return tpl_macro_f1, tpl_micro_f1, rhy_macro_f1, rhy_micro_f1, macro_dist1, micro_dist1, macro_dist2, micro_dist2

tpl_macro_f1_, tpl_micro_f1_, rhy_macro_f1_, rhy_micro_f1_,  \
macro_dist1_, micro_dist1_, macro_dist2_, micro_dist2_ = [], [], [], [], [], [], [], []
abalation = "top-32"
for i in range(5):
    f_name = "./results/"+abalation+"/out" +str(i+1)+".txt"
    if not os.path.exists(f_name):
        continue
    tpl_macro_f1, tpl_micro_f1, rhy_macro_f1, rhy_micro_f1, macro_dist1, micro_dist1, macro_dist2, micro_dist2 = eval(f_name, i + 1)
    print(tpl_macro_f1, tpl_micro_f1, rhy_macro_f1, rhy_micro_f1, macro_dist1, micro_dist1, macro_dist2, micro_dist2)
    tpl_macro_f1_.append(tpl_macro_f1)
    tpl_micro_f1_.append(tpl_micro_f1)
    rhy_macro_f1_.append(rhy_macro_f1)
    rhy_micro_f1_.append(rhy_micro_f1)
    macro_dist1_.append(macro_dist1)
    micro_dist1_.append(micro_dist1)
    macro_dist2_.append(macro_dist2)
    micro_dist2_.append(micro_dist2)

print()
print("tpl_macro_f1", np.mean(tpl_macro_f1_), np.std(tpl_macro_f1_, ddof=1))
print("tpl_micro_f1", np.mean(tpl_micro_f1_), np.std(tpl_micro_f1_, ddof=1))
print("rhy_macro_f1", np.mean(rhy_macro_f1_), np.std(rhy_macro_f1_, ddof=1))
print("rhy_micro_f1", np.mean(rhy_micro_f1_), np.std(rhy_micro_f1_, ddof=1))
print("macro_dist1", np.mean(macro_dist1_), np.std(macro_dist1_, ddof=1))
print("micro_dist1", np.mean(micro_dist1_), np.std(micro_dist1_, ddof=1))
print("macro_dist2", np.mean(macro_dist2_), np.std(macro_dist2_, ddof=1))
print("micro_dist2", np.mean(micro_dist2_), np.std(micro_dist2_, ddof=1))







================================================
FILE: optim.py
================================================
# -*- coding: utf-8 -*-

class Optim:
    "Optim wrapper that implements rate."
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step = None):
        "Implement `lrate` above"
        if step is None:
            step = self._step
        return self.factor * (self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5)))

    def state_dict(self):
        return self.optimizer.state_dict()

    def load_state_dict(self, m):
        self.optimizer.load_state_dict(m)


================================================
FILE: polish.py
================================================
import torch
from torch import nn
import torch.nn.functional as F
import random
import numpy as np
import copy 
import time

from biglm import BIGLM
from data import Vocab, DataLoader, s2t, s2xy_polish

gpu = 0
def init_model(m_path, device, vocab):
    ckpt= torch.load(m_path, map_location='cpu')
    lm_args = ckpt['args']
    lm_vocab = Vocab(vocab, min_occur_cnt=lm_args.min_occur_cnt, specials=[])
    lm_model = BIGLM(device, lm_vocab, lm_args.embed_dim, lm_args.ff_embed_dim, lm_args.num_heads, lm_args.dropout, lm_args.layers, 0.1)
    lm_model.load_state_dict(ckpt['model'])
    lm_model = lm_model.cuda(device)
    lm_model.eval()
    return lm_model, lm_vocab, lm_args

m_path = "./model/songci.ckpt"
lm_model, lm_vocab, lm_args = init_model(m_path, gpu, "./model/vocab.txt")


MAX_LEN = 300
k = 32
def top_k_inc(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s):
    start = time.time()
    incremental_state = None
    inp_y, m = s2t(s, lm_vocab)
    inp_y = inp_y.cuda(gpu)
    res = []
    for l in range(inp_ys_tpl.size(0)):
        probs, pred, incremental_state = lm_model.work_incremental(enc, src_padding_mask, \
                                         inp_y, inp_ys_tpl[0:l+1,:], inp_ys_seg[0:l+1,:], inp_ys_pos[0:l+1,:],\
                                         incremental_state)
        next_tk = []
        for i in range(len(s)):
            ctk = lm_vocab.idx2token(inp_ys_tpl[l,i].item())
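            # Non-placeholder template tokens (</s>, <eos> and characters kept from the
            # original text) are copied verbatim; only the placeholder slots are sampled.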
            if ctk != "<c1>" and ctk != "<c2>" and ctk != "<c0>":
                next_tk.append(ctk)
                continue
            
            if l == 0:
                logits = probs[len(s[i]) - 1, i]
            else:
                logits = probs[0, i]
            ps, idx = torch.topk(logits, k=k)
            ps = ps / torch.sum(ps)
            sampled = torch.multinomial(ps, num_samples = 1)
            sampled_idx = idx[sampled]
            next_tk.append(lm_vocab.idx2token(sampled_idx.item()))
        
        s_ = []
        bidx = [1] * len(s)
        for idx, (sent, t) in enumerate(zip(s, next_tk)):
            if t == "<eos>":
                res.append(sent)
                bidx[idx] = 0
            else:
                s_.append(sent + [t])
        if not s_:
            break
        s = s_
        inp_y, m = s2t(s, lm_vocab)
        inp_y = inp_y.cuda(gpu)
        bidx = torch.BoolTensor(bidx).cuda(gpu)
        incremental_state["bidx"] = bidx
    res += s_
        
    #for i in res:
    #    print(''.join(i))
    print(time.time()-start)
    return res

def top_k(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s):
    inp_y, m = s2t(s, lm_vocab)
    inp_y = inp_y.cuda(gpu)

    start = time.time()
    res = []
    for l in range(inp_ys_tpl.size(0)):
        probs, pred = lm_model.work(enc, src_padding_mask, inp_y, inp_ys_tpl[0:l+1,:], inp_ys_seg[0:l+1,:], inp_ys_pos[0:l+1,:])
        next_tk = []
        for i in range(len(s)):
            ctk = lm_vocab.idx2token(inp_ys_tpl[l,i].item())
            if ctk != "<c1>" and ctk != "<c2>" and ctk != "<c0>":
                next_tk.append(ctk)
                continue
            logits = probs[len(s[i]) - 1, i]
            ps, idx = torch.topk(logits, k=k)
            ps = ps / torch.sum(ps)
            sampled = torch.multinomial(ps, num_samples = 1)
            sampled_idx = idx[sampled]
            next_tk.append(lm_vocab.idx2token(sampled_idx.item()))
        
        s_ = []
        for sent, t in zip(s, next_tk):
            if t == "<eos>":
                res.append(sent)
            else:
                s_.append(sent + [t])
        if not s_:
            break
        s = s_
        inp_y, m = s2t(s, lm_vocab)
        inp_y = inp_y.cuda(gpu)

    res += s_
        
    #for i in res:
    #    print(''.join(i))

    #print(time.time()-start)
    return res
   



ds = []
with open("./data/polish_tpl.txt", "r") as f:
    for line in f:
        line = line.strip()
        if line:
            ds.append(line)
print(len(ds))

local_rank = gpu
batch_size = 1
cp_size = 1
batches = round(len(ds) / batch_size)

for i in range(5):
    fo = open("./results/out"+str(i+1)+".txt", "w")     
    idx = 0
    while idx < len(ds):
        lb = ds[idx:idx + batch_size]
        cplb = []
        for line in lb:
            cplb += [line for i in range(cp_size)]
        print(cplb) 
        xs_tpl, xs_seg, xs_pos, \
        ys_truth, ys_inp, \
        ys_tpl, ys_seg, ys_pos, msk = s2xy_polish(cplb, lm_vocab, lm_args.max_len,2)

        xs_tpl = xs_tpl.cuda(local_rank)
        xs_seg = xs_seg.cuda(local_rank)
        xs_pos = xs_pos.cuda(local_rank)
        ys_tpl = ys_tpl.cuda(local_rank)
        ys_seg = ys_seg.cuda(local_rank)
        ys_pos = ys_pos.cuda(local_rank)

        enc, src_padding_mask = lm_model.encode(xs_tpl, xs_seg, xs_pos)
        s = [['<bos>']] * batch_size * cp_size   
        res = top_k_inc(enc, src_padding_mask, ys_tpl, ys_seg, ys_pos, s)

        for i, line in enumerate(cplb):
            r = ''.join(res[i])
            print(line)
            print(r)
            fo.write(line + "\t" + r + "\n")  

        idx += batch_size
    fo.close()


================================================
FILE: polish.sh
================================================
python3 -u polish.py


================================================
FILE: prepare_data.py
================================================
import sys, re
from collections import Counter
import random
cnt = Counter()

f_ci = "./data/ci.txt"
f_cipai = "./data/cipai.txt"
cipai = Counter()
with open(f_cipai) as f:
    for line in f:
        line = line.strip()
        fs = line.split()
        cipai.update(fs)

cipai = cipai.keys()

docs = {}
with open(f_ci) as f:
    for line in f:
        line = line.strip()
        fs = line.split("<s1>")
        author = fs[0]
        topic, content = fs[1].split("<s2>")
        if "・" in topic:
            t1, t2 = topic.split("・")
            if t1 == t2:
                topic = t1
            else:
                if t1 in cipai:
                    topic = t1
                elif t2 in cipai:
                    topic = t2
                else:
                    topic = t1
        content = content.replace("、", ",")
        sents = content.split("</s>")
        ws = [w for w in author + topic + ''.join(sents)]
        cnt.update(ws)
        if topic not in docs:
            docs[topic] = []
        docs[topic].append(author + "<s1>" + topic + "<s2>" + '</s>'.join(sents))


topics = list(docs.keys())

print(len(topics))
random.shuffle(topics)

topics_train = topics[:len(topics)-50]
topics_dev_test = topics[-50:]
topics_dev = topics_dev_test[:25]
topics_test = topics_dev_test[-25:]

docs_train = []
docs_dev = []
docs_test = []

for t in topics_train:
    docs_train.extend(docs[t])

for t in topics_dev:
    docs_dev.extend(docs[t])

for t in topics_test:
    docs_test.extend(docs[t])

random.shuffle(docs_train)
random.shuffle(docs_dev)
random.shuffle(docs_test)

print(len(docs_train), len(docs_dev), len(docs_test))
train_cps = []
dev_cps = []
test_cps = []


with open('./data/train.txt', 'w', encoding ='utf8') as f:
    for x in docs_train:
        s = x.split("<s2>")[0]
        train_cps.append(s.split("<s1>")[1])
        f.write(x + '\n')
    print(len(set(train_cps)))
with open('./data/dev.txt', 'w', encoding ='utf8') as f:
    for x in docs_dev:
        s = x.split("<s2>")[0]
        dev_cps.append(s.split("<s1>")[1])
        f.write(x + '\n')
    print(len(set(dev_cps)))
with open('./data/test.txt', 'w', encoding ='utf8') as f:
    for x in docs_test:
        s = x.split("<s2>")[0]
        test_cps.append(s.split("<s1>")[1])
        f.write(x + '\n')
    print(len(set(test_cps)))

print("vocab")
with open('./data/vocab.txt', 'w', encoding ='utf8') as f:
    for x, y in cnt.most_common():
        f.write(x + '\t' + str(y) + '\n')
print("done")


================================================
FILE: test.py
================================================
import torch
from torch import nn
import torch.nn.functional as F
import random
import numpy as np
import copy 
import time

from biglm import BIGLM
from data import Vocab, DataLoader, s2t, s2xy



def init_seeds():
    random.seed(123)
    torch.manual_seed(123)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(123)

#init_seeds()

gpu = 1
def init_model(m_path, device, vocab):
    ckpt= torch.load(m_path, map_location='cpu')
    lm_args = ckpt['args']
    lm_vocab = Vocab(vocab, min_occur_cnt=lm_args.min_occur_cnt, specials=[])
    lm_model = BIGLM(device, lm_vocab, lm_args.embed_dim, lm_args.ff_embed_dim, lm_args.num_heads, lm_args.dropout, lm_args.layers, 0.1)
    lm_model.load_state_dict(ckpt['model'])
    lm_model = lm_model.cuda(device)
    lm_model.eval()
    return lm_model, lm_vocab, lm_args

m_path = "./model/songci.ckpt"
lm_model, lm_vocab, lm_args = init_model(m_path, gpu, "./model/vocab.txt")


k = 32
def top_k_inc(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s):
    start = time.time()
    incremental_state = None
    inp_y, m = s2t(s, lm_vocab)
    inp_y = inp_y.cuda(gpu)
    res = []
    for l in range(inp_ys_tpl.size(0)):
        probs, pred, incremental_state = lm_model.work_incremental(enc, src_padding_mask, \
                                         inp_y, inp_ys_tpl[0:l+1,:], inp_ys_seg[0:l+1,:], inp_ys_pos[0:l+1,:],\
                                         incremental_state)
        next_tk = []
        for i in range(len(s)):
            ctk = lm_vocab.idx2token(inp_ys_tpl[l,i].item())
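            # Non-placeholder template symbols (</s>, <eos>) are copied straight to the
            # output; the <c0>/<c1>/<c2> content slots are filled by top-k sampling below.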
            if ctk != "<c1>" and ctk != "<c2>" and ctk != "<c0>":
                next_tk.append(ctk)
                continue
            
            if l == 0:
                logits = probs[len(s[i]) - 1, i]
            else:
                logits = probs[0, i]
            ps, idx = torch.topk(logits, k=k)
            ps = ps / torch.sum(ps)
            sampled = torch.multinomial(ps, num_samples = 1)
            sampled_idx = idx[sampled]
            next_tk.append(lm_vocab.idx2token(sampled_idx.item()))
        
        s_ = []
        bidx = [1] * len(s)
        for idx, (sent, t) in enumerate(zip(s, next_tk)):
            if t == "<eos>":
                res.append(sent)
                bidx[idx] = 0
            else:
                s_.append(sent + [t])
        if not s_:
            break
        s = s_
        inp_y, m = s2t(s, lm_vocab)
        inp_y = inp_y.cuda(gpu)
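        # Boolean mask of batch entries that have not yet emitted <eos>, passed along
        # with the incremental state.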
        bidx = torch.BoolTensor(bidx).cuda(gpu)
        incremental_state["bidx"] = bidx
    res += s_
        
    #for i in res:
    #    print(''.join(i))
    print(time.time()-start)
    return res

def top_k(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s):
    inp_y, m = s2t(s, lm_vocab)
    inp_y = inp_y.cuda(gpu)

    start = time.time()
    res = []
    for l in range(inp_ys_tpl.size(0)):
        probs, pred = lm_model.work(enc, src_padding_mask, inp_y, inp_ys_tpl[0:l+1,:], inp_ys_seg[0:l+1,:], inp_ys_pos[0:l+1,:])
        next_tk = []
        for i in range(len(s)):
            ctk = lm_vocab.idx2token(inp_ys_tpl[l,i].item())
            if ctk != "<c1>" and ctk != "<c2>" and ctk != "<c0>":
                next_tk.append(ctk)
                continue
            logits = probs[len(s[i]) - 1, i]
            ps, idx = torch.topk(logits, k=k)
            ps = ps / torch.sum(ps)
            sampled = torch.multinomial(ps, num_samples = 1)
            sampled_idx = idx[sampled]
            next_tk.append(lm_vocab.idx2token(sampled_idx.item()))
        
        s_ = []
        for sent, t in zip(s, next_tk):
            if t == "<eos>":
                res.append(sent)
            else:
                s_.append(sent + [t])
        if not s_:
            break
        s = s_
        inp_y, m = s2t(s, lm_vocab)
        inp_y = inp_y.cuda(gpu)

    res += s_
        
    #for i in res:
    #    print(''.join(i))

    #print(time.time()-start)
    return res
 
    
def greedy(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s):
    start = time.time()
    incremental_state = None
    inp_y, m = s2t(s, lm_vocab)
    inp_y = inp_y.cuda(gpu)
    res = []
    for l in range(inp_ys_tpl.size(0)):
        probs, pred, incremental_state = lm_model.work_incremental(enc, src_padding_mask, \
                                         inp_y, inp_ys_tpl[0:l+1,:], inp_ys_seg[0:l+1,:], inp_ys_pos[0:l+1,:],\
                                         incremental_state)
        next_tk = []
        for i in range(len(s)):
            ctk = lm_vocab.idx2token(inp_ys_tpl[l,i].item())
            if ctk != "<c1>" and ctk != "<c2>" and ctk != "<c0>":
                next_tk.append(ctk)
                continue
            
            if l == 0:
                pred_i = pred[len(s[i]) - 1, i]
            else:
                pred_i = pred[0, i]
            next_tk.append(lm_vocab.idx2token(pred_i.item()))
        
        s_ = []
        bidx = [1] * len(s)
        for idx, (sent, t) in enumerate(zip(s, next_tk)):
            if t == "<eos>":
                res.append(sent)
                bidx[idx] = 0
            else:
                s_.append(sent + [t])
        if not s_:
            break
        s = s_
        inp_y, m = s2t(s, lm_vocab)
        inp_y = inp_y.cuda(gpu)
        bidx = torch.BoolTensor(bidx).cuda(gpu)
        incremental_state["bidx"] = bidx
    res += s_
        
    #for i in res:
    #    print(''.join(i))
    print(time.time()-start)
    return res


def beam_decode(s, x, enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos):
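    # Template-constrained beam search: at every step expand the highest-scoring
    # continuations; hypotheses that emit <eos> move to the finished pool until
    # beam_size samples are collected or no live hypothesis remains.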
    beam_size = 5
    
    num_live = 1
    num_dead = 0 
    samples = []
    sample_scores = np.zeros(beam_size)

    last_traces = [[]]
    last_scores = torch.FloatTensor(np.zeros(1)).to(gpu)

    x = x.to(gpu)
    ys = x

    for l in range(inp_ys_tpl.size(0)):
        seq_len, bsz = ys.size()
        enc_ = enc.repeat(1, bsz, 1)
        src_padding_mask_ = src_padding_mask.repeat(1, bsz)
        inp_ys_tpl_ = inp_ys_tpl.repeat(1, bsz)
        inp_ys_seg_ = inp_ys_seg.repeat(1, bsz)
        inp_ys_pos_ = inp_ys_pos.repeat(1, bsz)

        y_pred, _ = lm_model.work(enc_, src_padding_mask_, ys, inp_ys_tpl_[0:l+1,:], inp_ys_seg_[0:l+1,:], inp_ys_pos_[0:l+1,:])

        dict_size = y_pred.shape[-1]
        y_pred = y_pred[-1, :, :] 

        cand_y_scores = last_scores + torch.log(y_pred) # larger is better
        cand_scores = cand_y_scores.flatten()
        idx_top_joint_scores = torch.topk(cand_scores, beam_size - num_dead)[1]
        
        '''
        ps, idx_top_joint_scores = torch.topk(cand_scores, 100)
        ps = F.softmax(ps)
        sampled = torch.multinomial(ps, num_samples = beam_size - num_dead)
        idx_top_joint_scores = idx_top_joint_scores[sampled]
        '''

        idx_last_traces = idx_top_joint_scores // dict_size
        idx_word_now = idx_top_joint_scores % dict_size
        top_joint_scores = cand_scores[idx_top_joint_scores]

        traces_now = []
        scores_now = np.zeros((beam_size - num_dead))
        ys_now = []
        for i, [j, k] in enumerate(zip(idx_last_traces, idx_word_now)):
            traces_now.append(last_traces[j] + [k])
            scores_now[i] = copy.copy(top_joint_scores[i])
            ys_now.append(copy.copy(ys[:,j]))


        num_live = 0  
        last_traces = []
        last_scores = []
        ys = []
        for i in range(len(traces_now)):
            w = lm_vocab.idx2token(traces_now[i][-1].item())
            if w == "<eos>":
                samples.append([str(e.item()) for e in traces_now[i][:-1]])
                sample_scores[num_dead] = scores_now[i] 
                num_dead += 1
            else:
                last_traces.append(traces_now[i])
                last_scores.append(scores_now[i])
                ys.append(ys_now[i])
                num_live += 1
        
        if num_live == 0 or num_dead >= beam_size:
            break
        ys = torch.stack(ys, dim = 1) 

        last_scores = torch.FloatTensor(np.array(last_scores).reshape((num_live, 1))).to(gpu)
        next_y = []
        for e in last_traces:
            eid = e[-1].item()
            next_y.append(eid)
        next_y = np.array(next_y).reshape((1, num_live))
        next_y = torch.LongTensor(next_y).to(gpu)
        
        ys = torch.cat([ys, next_y], dim=0)
       
        assert num_live + num_dead == beam_size 
        # end for loop

    if num_live > 0:
        for i in range(num_live):
            samples.append([str(e.item()) for e in last_traces[i]])
            sample_scores[num_dead] = last_scores[i]
            num_dead += 1  

    idx_sorted_scores = np.argsort(sample_scores) # ascending order

    sorted_samples = []
    sorted_scores = []
    filter_idx = []
    for e in idx_sorted_scores:
        if len(samples[e]) > 0:
            filter_idx.append(e)
    if len(filter_idx) == 0:
        filter_idx = idx_sorted_scores
    for e in filter_idx:
        sorted_samples.append(samples[e])
        sorted_scores.append(sample_scores[e])

    res = []
    dec_words = []
    for sample in sorted_samples[::-1]:
        for e in sample:
            e = int(e)
            dec_words.append(lm_vocab.idx2token(e))
        #r = ''.join(dec_words)
        #print(r)
        res.append(dec_words)
        dec_words = []

    return res


def beam_search(enc, src_padding_mask, ys_tpl, ys_seg, ys_pos, s):
    x, m = s2t(s, lm_vocab)
    return beam_decode(s[0], x, enc, src_padding_mask, ys_tpl, ys_seg, ys_pos)


ds = []
with open("./data/test.txt", "r") as f:
    for line in f:
        line = line.strip()
        if line:
            ds.append(line)
print(len(ds))

local_rank = gpu
batch_size = 1
cp_size = 1
batches = round(len(ds) / batch_size)

for i in range(5): 
    idx = 0
    fo = open("./results/top-"+str(k)+"/out"+str(i+1)+".txt", "w")
    while idx < len(ds):
        lb = ds[idx:idx + batch_size]
        cplb = []
        for line in lb:
            cplb += [line for i in range(cp_size)]
        print(cplb) 
        xs_tpl, xs_seg, xs_pos, \
        ys_truth, ys_inp, \
        ys_tpl, ys_seg, ys_pos, msk = s2xy(cplb, lm_vocab, lm_args.max_len, 2)

        xs_tpl = xs_tpl.cuda(local_rank)
        xs_seg = xs_seg.cuda(local_rank)
        xs_pos = xs_pos.cuda(local_rank)
        ys_tpl = ys_tpl.cuda(local_rank)
        ys_seg = ys_seg.cuda(local_rank)
        ys_pos = ys_pos.cuda(local_rank)

        enc, src_padding_mask = lm_model.encode(xs_tpl, xs_seg, xs_pos)
        s = [['<bos>']] * batch_size * cp_size   
        res = top_k_inc(enc, src_padding_mask, ys_tpl, ys_seg, ys_pos, s)

        for i, line in enumerate(cplb):
            r = ''.join(res[i])
            print(line)
            print(r)
    
            fo.write(line + "\t" + r + "\n")
    
        idx += batch_size
    
    fo.close()


================================================
FILE: test.sh
================================================
python3 -u test.py


================================================
FILE: train.py
================================================
# coding=utf-8
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp

from biglm import BIGLM
from data import Vocab, DataLoader, s2xy
from optim import Optim

import argparse, os
import random

def parse_config():
    parser = argparse.ArgumentParser()
    parser.add_argument('--embed_dim', type=int)
    parser.add_argument('--ff_embed_dim', type=int)
    parser.add_argument('--num_heads', type=int)
    parser.add_argument('--layers', type=int)
    parser.add_argument('--dropout', type=float)

    parser.add_argument('--train_data', type=str)
    parser.add_argument('--dev_data', type=str)
    parser.add_argument('--vocab', type=str)
    parser.add_argument('--min_occur_cnt', type=int)
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--warmup_steps', type=int)
    parser.add_argument('--lr', type=float)
    parser.add_argument('--smoothing', type=float)
    parser.add_argument('--weight_decay', type=float)
    parser.add_argument('--max_len', type=int)
    parser.add_argument('--min_len', type=int)
    parser.add_argument('--print_every', type=int)
    parser.add_argument('--save_every', type=int)
    parser.add_argument('--start_from', type=str, default=None)
    parser.add_argument('--save_dir', type=str)

    parser.add_argument('--world_size', type=int)
    parser.add_argument('--gpus', type=int)
    parser.add_argument('--MASTER_ADDR', type=str)
    parser.add_argument('--MASTER_PORT', type=str)
    parser.add_argument('--start_rank', type=int)
    parser.add_argument('--backend', type=str)

    return parser.parse_args()

def update_lr(optimizer, lr):
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
 
def average_gradients(model):
    """ Gradient averaging. """
    normal = True
    size = float(dist.get_world_size())
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= size
        else:
            normal = False
            break
    return normal

def eval_epoch(lm_args, model, lm_vocab, local_rank, label):
    print("validating...", flush=True)
    ds = []
    with open(lm_args.dev_data, "r") as f:
        for line in f:
            line = line.strip()
            if line:
                ds.append(line)

    batch_size = 10
    batches = round(len(ds) / batch_size)
    idx = 0
    avg_nll = 0.
    avg_ppl = 0.
    count = 0.
    while idx < len(ds):
        cplb = ds[idx:idx + batch_size]
        xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk = s2xy(cplb, lm_vocab, lm_args.max_len, lm_args.min_len)

        xs_tpl = xs_tpl.cuda(local_rank)
        xs_seg = xs_seg.cuda(local_rank)
        xs_pos = xs_pos.cuda(local_rank)
        ys_truth = ys_truth.cuda(local_rank)
        ys_inp = ys_inp.cuda(local_rank)
        ys_tpl = ys_tpl.cuda(local_rank)
        ys_seg = ys_seg.cuda(local_rank)
        ys_pos = ys_pos.cuda(local_rank)
        msk = msk.cuda(local_rank)

        nll, ppl, bsz = model.ppl(xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk)
    
        avg_nll += nll
        avg_ppl += ppl
        count += bsz

        idx += batch_size
    
    print(label, "nll=", avg_nll/count, "ppl=", avg_ppl/count, "count=", count, flush=True)

def run(args, local_rank):
    """ Distributed Synchronous """
    torch.manual_seed(1234)
    vocab = Vocab(args.vocab, min_occur_cnt=args.min_occur_cnt, specials=[])
    if (args.world_size == 1 or dist.get_rank() == 0):
        print ("vocab.size = " + str(vocab.size), flush=True)
    model = BIGLM(local_rank, vocab, args.embed_dim, args.ff_embed_dim,\
                  args.num_heads, args.dropout, args.layers, args.smoothing)
    if args.start_from is not None:
        ckpt = torch.load(args.start_from, map_location='cpu')
        model.load_state_dict(ckpt['model'])
    model = model.cuda(local_rank)
   
    optimizer = Optim(model.embed_dim, args.lr, args.warmup_steps, torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.998), eps=1e-9))

    if args.start_from is not None:
        optimizer.load_state_dict(ckpt['optimizer'])

    train_data = DataLoader(vocab, args.train_data, args.batch_size, args.max_len, args.min_len)
    batch_acm = 0
    acc_acm, nll_acm, ppl_acm, ntokens_acm, nxs, npairs_acm, loss_acm = 0., 0., 0., 0., 0., 0., 0.
    while True:
        model.train()
        if train_data.epoch_id > 30:
            break
        for xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk in train_data:
            batch_acm += 1
            xs_tpl = xs_tpl.cuda(local_rank)
            xs_seg = xs_seg.cuda(local_rank)
            xs_pos = xs_pos.cuda(local_rank)
            ys_truth = ys_truth.cuda(local_rank)
            ys_inp = ys_inp.cuda(local_rank)
            ys_tpl = ys_tpl.cuda(local_rank)
            ys_seg = ys_seg.cuda(local_rank)
            ys_pos = ys_pos.cuda(local_rank)
            msk = msk.cuda(local_rank)

            model.zero_grad()
            res, loss, acc, nll, ppl, ntokens, npairs = model(xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk)
            loss_acm += loss.item()
            acc_acm += acc
            nll_acm += nll
            ppl_acm += ppl
            ntokens_acm += ntokens
            npairs_acm += npairs
            nxs += npairs
            
            loss.backward()
            if args.world_size > 1:
                is_normal = average_gradients(model)
            else:
                is_normal = True
            if is_normal:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
            else:
                print("gradient: none, gpu: " + str(local_rank), flush=True)
                continue
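            # -1 % print_every == print_every - 1, so this logs once every print_every batches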
            if (args.world_size==1 or dist.get_rank() ==0) and batch_acm%args.print_every == -1%args.print_every:
                print ('batch_acm %d, loss %.3f, acc %.3f, nll %.3f, ppl %.3f, x_acm %d, lr %.6f'\
                        %(batch_acm, loss_acm/args.print_every, acc_acm/ntokens_acm, \
                        nll_acm/nxs, ppl_acm/nxs, npairs_acm, optimizer._rate), flush=True)
                acc_acm, nll_acm, ppl_acm, ntokens_acm, loss_acm, nxs = 0., 0., 0., 0., 0., 0.
            if (args.world_size==1 or dist.get_rank() ==0) and batch_acm%args.save_every == -1%args.save_every:
                if not os.path.exists(args.save_dir):
                    os.mkdir(args.save_dir)
                
                model.eval()
                eval_epoch(args, model, vocab, local_rank, "epoch-" + str(train_data.epoch_id) + "-acm-" + str(batch_acm))
                model.train()

                torch.save({'args':args, 'model':model.state_dict(), 'optimizer':optimizer.state_dict()}, '%s/epoch%d_batch_%d'%(args.save_dir, train_data.epoch_id, batch_acm))

def init_processes(args, local_rank, fn, backend='nccl'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = args.MASTER_ADDR
    os.environ['MASTER_PORT'] = args.MASTER_PORT
    dist.init_process_group(backend, rank=args.start_rank + local_rank, world_size=args.world_size)
    fn(args, local_rank)

if __name__ == "__main__":
    mp.set_start_method('spawn')
    args = parse_config()

    if args.world_size == 1:
        run(args, 0)
        exit(0)
    processes = []
    for rank in range(args.gpus):
        p = mp.Process(target=init_processes, args=(args, rank, run, args.backend))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()


================================================
FILE: train.sh
================================================
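# single-GPU fine-tuning configuration; with --world_size 1, train.py runs in one process
# and never initializes torch.distributed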
CUDA_VISIBLE_DEVICES=1 \
python3 -u train.py --embed_dim 768 \
                      --ff_embed_dim 3072 \
                      --num_heads 12 \
                      --layers 12 \
                      --dropout 0.2 \
                      --train_data ./data/train.txt \
                      --dev_data ./data/dev.txt \
                      --vocab ./data/vocab.txt \
                      --min_occur_cnt 1 \
                      --batch_size 32 \
                      --warmup_steps 8000 \
                      --lr 0.5 \
                      --weight_decay 0 \
                      --smoothing 0.1 \
                      --max_len 300 \
                      --min_len 10 \
                      --world_size 1 \
                      --gpus 1 \
                      --start_rank 0 \
                      --MASTER_ADDR localhost \
                      --MASTER_PORT 28512 \
                      --print_every 100 \
                      --save_every 1000 \
                      --save_dir ckpt \
                      --backend nccl


================================================
FILE: transformer.py
================================================
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F

from utils import gelu, LayerNorm, get_incremental_state, set_incremental_state
import math

class TransformerLayer(nn.Module):
    
    def __init__(self, embed_dim, ff_embed_dim, num_heads, dropout, with_external=False, weights_dropout = True):
        super(TransformerLayer, self).__init__()
        self.self_attn = MultiheadAttention(embed_dim, num_heads, dropout, weights_dropout)
        self.fc1 = nn.Linear(embed_dim, ff_embed_dim)
        self.fc2 = nn.Linear(ff_embed_dim, embed_dim)
        self.attn_layer_norm = LayerNorm(embed_dim)
        self.ff_layer_norm = LayerNorm(embed_dim)
        self.with_external = with_external
        self.dropout = dropout
        if self.with_external:
            self.external_attn = MultiheadAttention(embed_dim, num_heads, dropout, weights_dropout)
            self.external_layer_norm = LayerNorm(embed_dim)
        self.reset_parameters()
    
    def reset_parameters(self):
        nn.init.normal_(self.fc1.weight, std=0.02)
        nn.init.normal_(self.fc2.weight, std=0.02)
        nn.init.constant_(self.fc1.bias, 0.)
        nn.init.constant_(self.fc2.bias, 0.)

    def forward(self, x, kv = None,
                self_padding_mask = None, self_attn_mask = None,
                external_memories = None, external_padding_mask=None,
                need_weights = False):
        # x: seq_len x bsz x embed_dim
        residual = x
        if kv is None:
            x, self_attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=self_padding_mask, attn_mask=self_attn_mask, need_weights = need_weights)
        else:
            x, self_attn = self.self_attn(query=x, key=kv, value=kv, key_padding_mask=self_padding_mask, attn_mask=self_attn_mask, need_weights = need_weights)

        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.attn_layer_norm(residual + x)

        if self.with_external:
            residual = x
            x, external_attn = self.external_attn(query=x, key=external_memories, value=external_memories, key_padding_mask=external_padding_mask, need_weights = need_weights)
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = self.external_layer_norm(residual + x)
        else:
            external_attn = None

        residual = x
        x = gelu(self.fc1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.ff_layer_norm(residual + x)

        return x, self_attn, external_attn
    
    def work_incremental(self, x, self_padding_mask = None, self_attn_mask = None,
                         external_memories = None, external_padding_mask = None, incremental_state = None):
        # x: seq_len x bsz x embed_dim
        residual = x
        x, self_attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=self_padding_mask, attn_mask=self_attn_mask, incremental_state=incremental_state)
        x = self.attn_layer_norm(residual + x)

        if self.with_external:
            residual = x
            x, external_attn = self.external_attn(query=x, key=external_memories, value=external_memories, key_padding_mask=external_padding_mask)
            x = self.external_layer_norm(residual + x)
        else:
            external_attn = None
        residual = x
        x = gelu(self.fc1(x))
        x = self.fc2(x)
        x = self.ff_layer_norm(residual + x)

        return x, self_attn, external_attn

class MultiheadAttention(nn.Module):

    def __init__(self, embed_dim, num_heads, dropout=0., weights_dropout=True):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
        self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))

        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.weights_dropout = weights_dropout
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.in_proj_weight, std=0.02)
        nn.init.normal_(self.out_proj.weight, std=0.02)
        nn.init.constant_(self.in_proj_bias, 0.)
        nn.init.constant_(self.out_proj.bias, 0.)

    def forward(self, query, key, value, key_padding_mask=None, attn_mask=None, need_weights=False, incremental_state = None):
        """ Input shape: Time x Batch x Channel
            key_padding_mask: Time x batch
            attn_mask:  tgt_len x src_len
        """
        if incremental_state is not None: 
            saved_state = self._get_input_buffer(incremental_state)
            bidx = self._get_bidx(incremental_state)
        else:
            saved_state = None
            bidx = None
    
        qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
        kv_same = key.data_ptr() == value.data_ptr()
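        # comparing data pointers detects self-attention (query, key and value are the same
        # tensor) vs. encoder-decoder attention (key and value are the same tensor), so the
        # cheapest fused input projection can be chosen below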

        tgt_len, bsz, embed_dim = query.size()
        assert key.size() == value.size()

        if qkv_same:
            # self-attention
            q, k, v = self.in_proj_qkv(query)
        elif kv_same:
            # encoder-decoder attention
            q = self.in_proj_q(query)
            k, v = self.in_proj_kv(key)
        else:
            q = self.in_proj_q(query)
            k = self.in_proj_k(key)
            v = self.in_proj_v(value)
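        # scale queries by 1/sqrt(head_dim) before the dot products (scaled dot-product attention)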
        q = q * self.scaling
        
        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if 'prev_key' in saved_state:
                prev_key = saved_state['prev_key']
                if bidx is not None:
                    prev_key = prev_key[bidx]
                prev_key = prev_key.contiguous().view(bsz * self.num_heads, -1, self.head_dim)
                k = torch.cat((prev_key, k), dim=1)
            if 'prev_value' in saved_state:
                prev_value = saved_state['prev_value']
                if bidx is not None:
                    prev_value = prev_value[bidx]
                prev_value = prev_value.contiguous().view(bsz * self.num_heads, -1, self.head_dim)
                v = torch.cat((prev_value, v), dim=1)
            saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
            self._set_input_buffer(incremental_state, saved_state)	
        
        src_len = k.size(1)
        # k,v: bsz*heads x src_len x dim
        # q: bsz*heads x tgt_len x dim 

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_weights.masked_fill_(
                attn_mask.unsqueeze(0),
                float('-inf')
            )

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights.masked_fill_(
                key_padding_mask.transpose(0, 1).unsqueeze(1).unsqueeze(2),
                float('-inf')
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        
        attn_weights = F.softmax(attn_weights, dim=-1)
        
        if self.weights_dropout:
            attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

        attn = torch.bmm(attn_weights, v)
        if not self.weights_dropout:
            attn = F.dropout(attn, p=self.dropout, training=self.training)

        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]

        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        if need_weights:
            # maximum attention weight over heads 
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            
            attn_weights, _ = attn_weights.max(dim=1)
            attn_weights = attn_weights.transpose(0, 1)
        else:
            attn_weights = None

        return attn, attn_weights

    def in_proj_qkv(self, query):
        return self._in_proj(query).chunk(3, dim=-1)

    def in_proj_kv(self, key):
        return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)

    def in_proj_q(self, query):
        return self._in_proj(query, end=self.embed_dim)

    def in_proj_k(self, key):
        return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)

    def in_proj_v(self, value):
        return self._in_proj(value, start=2 * self.embed_dim)

    def _in_proj(self, input, start=0, end=None):
        weight = self.in_proj_weight
        bias = self.in_proj_bias
        weight = weight[start:end, :]
        if bias is not None:
            bias = bias[start:end]
        return F.linear(input, weight, bias)

    def _get_input_buffer(self, incremental_state):
       return get_incremental_state(
                self,
                incremental_state,
                'attn_state',
                ) or {}

    def _set_input_buffer(self, incremental_state, buffer):
        set_incremental_state(
                self,
                incremental_state,
                'attn_state',
                buffer,)

    def _get_bidx(self, incremental_state):
        if "bidx" in incremental_state:
            return incremental_state["bidx"]
        else:
            return None

def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, std=0.02)
    nn.init.constant_(m.weight[padding_idx], 0)
    return m

class SelfAttentionMask(nn.Module):
    def __init__(self, init_size = 100, device = 0):
        super(SelfAttentionMask, self).__init__()
        self.weights = SelfAttentionMask.get_mask(init_size)
        self.device = device
    
    @staticmethod
    def get_mask(size):
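        # boolean upper-triangular matrix (diagonal excluded): True marks future positions
        # that self-attention must not look at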
        weights = torch.triu(torch.ones((size, size), dtype = torch.bool), 1)
        return weights

    def forward(self, size):
        if self.weights is None or size > self.weights.size(0):
            self.weights = SelfAttentionMask.get_mask(size)
        res = self.weights[:size,:size].cuda(self.device).detach()
        return res

class LearnedPositionalEmbedding(nn.Module):
    """This module produces LearnedPositionalEmbedding.
    """
    def __init__(self, embedding_dim, init_size=1024, device=0):
        super(LearnedPositionalEmbedding, self).__init__()
        self.weights = nn.Embedding(init_size, embedding_dim)
        self.device= device
        self.reset_parameters()
    
    def reset_parameters(self):
        nn.init.normal_(self.weights.weight, std=0.02)

    def forward(self, input, offset=0):
        """Input is expected to be of size [seq_len x bsz]."""
        seq_len, bsz = input.size()
        positions = (offset + torch.arange(seq_len)).cuda(self.device)
        res = self.weights(positions).unsqueeze(1).expand(-1, bsz, -1)
        return res

class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.
    """
    def __init__(self, embedding_dim, init_size=1024, device=0):
        super(SinusoidalPositionalEmbedding, self).__init__()
        self.embedding_dim = embedding_dim
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size,
            embedding_dim
        )
        self.device= device

    @staticmethod
    def get_embedding(num_embeddings, embedding_dim):
        """Build sinusoidal embeddings.
        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        return emb

    def forward(self, input, offset=0):
        """Input is expected to be of size [seq_len x bsz]."""
        seq_len, bsz = input.size()
        mx_position = seq_len + offset
        if self.weights is None or mx_position > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                mx_position,
                self.embedding_dim,
            )

        positions = offset + torch.arange(seq_len)
        res = self.weights.index_select(0, positions).unsqueeze(1).expand(-1, bsz, -1).cuda(self.device).detach()
        return res
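
# Shape-check sketch (not part of the original file; illustrative only). TransformerLayer
# consumes inputs of shape seq_len x bsz x embed_dim and returns a tensor of the same shape
# plus attention weights (None unless need_weights=True):
#
#   layer = TransformerLayer(embed_dim=16, ff_embed_dim=64, num_heads=4, dropout=0.1)
#   x = torch.zeros(5, 2, 16)                                   # seq_len=5, bsz=2
#   causal = torch.triu(torch.ones(5, 5, dtype=torch.bool), 1)  # True = blocked position
#   y, attn, _ = layer(x, self_attn_mask=causal)                # y: 5 x 2 x 16, attn: None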


================================================
FILE: utils.py
================================================
import torch
from torch import nn
from torch.nn import Parameter
from collections import defaultdict

import math

def gelu(x):
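    # exact GELU: x * Phi(x), where Phi is the standard normal CDF computed via erf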
    cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
    return cdf*x

class LayerNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-12):
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.Tensor(hidden_size))
        self.bias = nn.Parameter(torch.Tensor(hidden_size))
        self.eps = eps
        self.reset_parameters()
    def reset_parameters(self):
        nn.init.constant_(self.weight, 1.)
        nn.init.constant_(self.bias, 0.)

    def forward(self, x):
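        # normalize over the last (hidden) dimension, then apply the learned scale and shift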
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        return self.weight * x + self.bias


INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)

def _get_full_incremental_state_key(module_instance, key):
    module_name = module_instance.__class__.__name__

    # assign a unique ID to each module instance, so that incremental state is
    # not shared across module instances
    if not hasattr(module_instance, '_guyu_instance_id'):
        INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
        module_instance._guyu_instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]

    return '{}.{}.{}'.format(module_name, module_instance._guyu_instance_id, key)

def get_incremental_state(module, incremental_state, key):
    """Helper for getting incremental state for an nn.Module."""
    full_key = _get_full_incremental_state_key(module, key)
    if incremental_state is None or full_key not in incremental_state:
        return None
    return incremental_state[full_key]

def set_incremental_state(module, incremental_state, key, value):
    """Helper for setting incremental state for an nn.Module."""
    if incremental_state is not None:
        full_key = _get_full_incremental_state_key(module, key)
        incremental_state[full_key] = value
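
# Usage sketch (not part of the original file; 'module', 'k' and 'v' are placeholders).
# incremental_state is a plain dict shared across decoding steps; each module instance
# stores its cache under a key derived from its class name and a per-instance id:
#
#   state = {}
#   set_incremental_state(module, state, 'attn_state', {'prev_key': k, 'prev_value': v})
#   cached = get_incremental_state(module, state, 'attn_state')   # returns the same dict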


================================================
SYMBOL INDEX (103 symbols across 12 files)
================================================

FILE: adam.py
  class AdamWeightDecayOptimizer (line 5) | class AdamWeightDecayOptimizer(Optimizer):
    method __init__ (line 9) | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
    method __setstate__ (line 23) | def __setstate__(self, state):
    method step (line 28) | def step(self, closure=None):

FILE: biglm.py
  class BIGLM (line 9) | class BIGLM(nn.Module):
    method __init__ (line 10) | def __init__(self, local_rank, vocab, embed_dim, ff_embed_dim, num_hea...
    method reset_parameters (line 35) | def reset_parameters(self):
    method label_smotthing_loss (line 41) | def label_smotthing_loss(self, y_pred, y, y_mask, avg=True):
    method nll_loss (line 51) | def nll_loss(self, y_pred, y, y_mask, avg=True):
    method work_incremental (line 64) | def work_incremental(self, enc, src_padding_mask, ys_inp, ys_tpl, ys_s...
    method work (line 92) | def work(self, enc, src_padding_mask, ys_inp, ys_tpl, ys_seg, ys_pos):
    method encode (line 114) | def encode(self, xs_tpl, xs_seg, xs_pos):
    method ppl (line 120) | def ppl(self, xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg...
    method forward (line 141) | def forward(self, xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys...

FILE: data.py
  function ListsToTensor (line 17) | def ListsToTensor(xs, vocab=None):
  function _back_to_text_for_check (line 28) | def _back_to_text_for_check(x, vocab):
  function batchify (line 33) | def batchify(data, vocab):
  function s2t (line 61) | def s2t(strs, vocab):
  function s2xy (line 71) | def s2xy(lines, vocab, max_len, min_len):
  function parse_line (line 80) | def parse_line(line, max_len, min_len):
  function s2xy_polish (line 140) | def s2xy_polish(lines, vocab, max_len, min_len):
  function parse_line_polish (line 147) | def parse_line_polish(line, max_len, min_len):
  class DataLoader (line 204) | class DataLoader(object):
    method __init__ (line 205) | def __init__(self, vocab, filename, batch_size, max_len_y, min_len_y):
    method __iter__ (line 214) | def __iter__(self):
  class Vocab (line 238) | class Vocab(object):
    method __init__ (line 239) | def __init__(self, filename, min_occur_cnt, specials = None):
    method size (line 255) | def size(self):
    method unk_idx (line 259) | def unk_idx(self):
    method padding_idx (line 263) | def padding_idx(self):
    method random_token (line 266) | def random_token(self):
    method idx2token (line 269) | def idx2token(self, x):
    method token2idx (line 274) | def token2idx(self, x):

FILE: eval.py
  function init_model (line 14) | def init_model(m_path, device, vocab):

FILE: label_smoothing.py
  class LabelSmoothing (line 5) | class LabelSmoothing(nn.Module):
    method __init__ (line 7) | def __init__(self, device, size, padding_idx, label_smoothing=0.0):
    method forward (line 20) | def forward(self, output, target):

FILE: metrics.py
  function eval_tpl (line 29) | def eval_tpl(sents1, sents2):
  function rhythm_labellig (line 53) | def rhythm_labellig(sents):
  function eval_rhythm (line 86) | def eval_rhythm(sents1, sents2):
  function eval_endings (line 109) | def eval_endings(sents1, sents2):
  function eval (line 126) | def eval(res_file, fid):

FILE: optim.py
  class Optim (line 3) | class Optim:
    method __init__ (line 5) | def __init__(self, model_size, factor, warmup, optimizer):
    method step (line 13) | def step(self):
    method rate (line 22) | def rate(self, step = None):
    method state_dict (line 28) | def state_dict(self):
    method load_state_dict (line 31) | def load_state_dict(self, m):

FILE: polish.py
  function init_model (line 13) | def init_model(m_path, device, vocab):
  function top_k_inc (line 29) | def top_k_inc(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos,...
  function top_k (line 78) | def top_k(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s):

FILE: test.py
  function init_seeds (line 14) | def init_seeds():
  function init_model (line 23) | def init_model(m_path, device, vocab):
  function top_k_inc (line 38) | def top_k_inc(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos,...
  function top_k (line 87) | def top_k(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s):
  function greedy (line 129) | def greedy(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s):
  function beam_decode (line 175) | def beam_decode(s, x, enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp...
  function beam_search (line 293) | def beam_search(enc, src_padding_mask, ys_tpl, ys_seg, ys_pos, s):

FILE: train.py
  function parse_config (line 15) | def parse_config():
  function update_lr (line 48) | def update_lr(optimizer, lr):
  function average_gradients (line 52) | def average_gradients(model):
  function eval_epoch (line 65) | def eval_epoch(lm_args, model, lm_vocab, local_rank, label):
  function run (line 104) | def run(args, local_rank):
  function init_processes (line 177) | def init_processes(args, local_rank, fn, backend='nccl'):

FILE: transformer.py
  class TransformerLayer (line 9) | class TransformerLayer(nn.Module):
    method __init__ (line 11) | def __init__(self, embed_dim, ff_embed_dim, num_heads, dropout, with_e...
    method reset_parameters (line 25) | def reset_parameters(self):
    method forward (line 31) | def forward(self, x, kv = None,
    method work_incremental (line 62) | def work_incremental(self, x, self_padding_mask = None, self_attn_mask...
  class MultiheadAttention (line 82) | class MultiheadAttention(nn.Module):
    method __init__ (line 84) | def __init__(self, embed_dim, num_heads, dropout=0., weights_dropout=T...
    method reset_parameters (line 100) | def reset_parameters(self):
    method forward (line 106) | def forward(self, query, key, value, key_padding_mask=None, attn_mask=...
    method in_proj_qkv (line 206) | def in_proj_qkv(self, query):
    method in_proj_kv (line 209) | def in_proj_kv(self, key):
    method in_proj_q (line 212) | def in_proj_q(self, query):
    method in_proj_k (line 215) | def in_proj_k(self, key):
    method in_proj_v (line 218) | def in_proj_v(self, value):
    method _in_proj (line 221) | def _in_proj(self, input, start=0, end=None):
    method _get_input_buffer (line 229) | def _get_input_buffer(self, incremental_state):
    method _set_input_buffer (line 236) | def _set_input_buffer(self, incremental_state, buffer):
    method _get_bidx (line 243) | def _get_bidx(self, incremental_state):
  function Embedding (line 249) | def Embedding(num_embeddings, embedding_dim, padding_idx):
  class SelfAttentionMask (line 255) | class SelfAttentionMask(nn.Module):
    method __init__ (line 256) | def __init__(self, init_size = 100, device = 0):
    method get_mask (line 262) | def get_mask(size):
    method forward (line 266) | def forward(self, size):
  class LearnedPositionalEmbedding (line 272) | class LearnedPositionalEmbedding(nn.Module):
    method __init__ (line 275) | def __init__(self, embedding_dim, init_size=1024, device=0):
    method reset_parameters (line 281) | def reset_parameters(self):
    method forward (line 284) | def forward(self, input, offset=0):
  class SinusoidalPositionalEmbedding (line 291) | class SinusoidalPositionalEmbedding(nn.Module):
    method __init__ (line 294) | def __init__(self, embedding_dim, init_size=1024, device=0):
    method get_embedding (line 304) | def get_embedding(num_embeddings, embedding_dim):
    method forward (line 318) | def forward(self, input, offset=0):

FILE: utils.py
  function gelu (line 8) | def gelu(x):
  class LayerNorm (line 12) | class LayerNorm(nn.Module):
    method __init__ (line 13) | def __init__(self, hidden_size, eps=1e-12):
    method reset_parameters (line 19) | def reset_parameters(self):
    method forward (line 23) | def forward(self, x):
  function _get_full_incremental_state_key (line 32) | def _get_full_incremental_state_key(module_instance, key):
  function get_incremental_state (line 43) | def get_incremental_state(module, incremental_state, key):
  function set_incremental_state (line 50) | def set_incremental_state(module, incremental_state, key, value):
