Repository: lipiji/SongNet
Branch: master
Commit: f8c7064c21c7
Files: 20
Total size: 75.1 KB
Directory structure:
gitextract_oqyy3n7z/
├── .gitignore
├── LICENSE
├── README.md
├── adam.py
├── biglm.py
├── data.py
├── eval.py
├── eval.sh
├── label_smoothing.py
├── metrics.py
├── optim.py
├── polish.py
├── polish.sh
├── prepare_data.py
├── test.py
├── test.sh
├── train.py
├── train.sh
├── transformer.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
*.pyc
*.log
ckpt
/data*/*
/model*/*
/ckpt*/*
/result*/*
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2021 Piji Li
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# SongNet
SongNet: SongCi + Song (Lyrics) + Sonnet + etc.
```
@inproceedings{li-etal-2020-rigid,
title = "Rigid Formats Controlled Text Generation",
author = "Li, Piji and Zhang, Haisong and Liu, Xiaojiang and Shi, Shuming",
booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
month = jul,
year = "2020",
address = "Online",
publisher = "Association for Computational Linguistics",
url = "https://www.aclweb.org/anthology/2020.acl-main.68",
doi = "10.18653/v1/2020.acl-main.68",
pages = "742--751"
}
```
### Run
- python prepare_data.py
- ./train.sh
### Evaluation
- Modify test.py: m_path = the best dev model
- ./test.sh
- python metrics.py
### Polish
- ./polish.sh
### Download
- The pretrained Chinese Language Model: https://drive.google.com/file/d/1g2tGyUwPe86vPn2nub1vkQva5lwtZ6Rd/view
- The finetuned SongCi model: https://drive.google.com/file/d/16A2AzuU7slf7xj2QdLcBAorUCCaCk650/view
#### Reference
- Guyu: https://github.com/lipiji/Guyu
- Pretraining: https://github.com/lipiji/big_tpl_zh_10_base
================================================
FILE: adam.py
================================================
# coding=utf-8
import torch
from torch.optim import Optimizer
class AdamWeightDecayOptimizer(Optimizer):
    """Adam with decoupled ("correct") L2 weight decay, as in BERT.

    Unlike folding an L2 penalty into the loss, the decay term is added
    directly to the parameter update, so it does not interact with the
    first/second moment estimates (AdamW-style decoupling).

    References:
        https://github.com/google-research/bert/blob/master/optimization.py
        https://raw.githubusercontent.com/pytorch/pytorch/v1.0.0/torch/optim/adam.py
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(AdamWeightDecayOptimizer, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamWeightDecayOptimizer, self).__setstate__(state)
        for group in self.param_groups:
            # Checkpoints saved before amsgrad existed lack the key.
            group.setdefault('amsgrad', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                amsgrad = group['amsgrad']
                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']
                state['step'] += 1
                # Decay the first and second moment running average coefficient.
                # Use the keyword (alpha=/value=) forms: the positional-scalar
                # overloads add_(scalar, tensor) / addcmul_(scalar, t1, t2)
                # are deprecated and removed in current PyTorch.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                update = (exp_avg / denom).add_(p.data, alpha=group['weight_decay'])
                p.data.add_(update, alpha=-group['lr'])
        return loss
================================================
FILE: biglm.py
================================================
import torch
from torch import nn
import torch.nn.functional as F
from utils import gelu, LayerNorm
from transformer import TransformerLayer, Embedding, LearnedPositionalEmbedding, SelfAttentionMask
from label_smoothing import LabelSmoothing
class BIGLM(nn.Module):
    """Transformer language model used by SongNet.

    The "encoder" side embeds the format/template token streams (xs_*) and
    feeds the result to every decoder layer as external memories; the
    decoder side sums token, learned-position, template, segment and
    intra-segment-position embeddings (ys_*) and predicts the next token.
    """

    def __init__(self, local_rank, vocab, embed_dim, ff_embed_dim, num_heads, dropout, layers, smoothing_factor, approx=None):
        super(BIGLM, self).__init__()
        self.vocab = vocab
        self.embed_dim = embed_dim
        self.tok_embed = Embedding(self.vocab.size, embed_dim, self.vocab.padding_idx)
        self.pos_embed = LearnedPositionalEmbedding(embed_dim, device=local_rank)
        self.layers = nn.ModuleList()
        for i in range(layers):
            # with_external=True: each decoder layer also attends over the
            # encoded template memories.
            self.layers.append(TransformerLayer(embed_dim, ff_embed_dim, num_heads, dropout, with_external=True))
        self.emb_layer_norm = LayerNorm(embed_dim)
        self.one_more = nn.Linear(embed_dim, embed_dim)
        self.one_more_layer_norm = LayerNorm(embed_dim)
        self.out_proj = nn.Linear(embed_dim, self.vocab.size)
        self.attn_mask = SelfAttentionMask(device=local_rank)
        self.smoothing = LabelSmoothing(local_rank, self.vocab.size, self.vocab.padding_idx, smoothing_factor)
        self.dropout = dropout
        self.device = local_rank
        self.approx = approx  # NOTE(review): stored but not used in this file
        self.reset_parameters()

    def reset_parameters(self):
        # BERT-style init: zero biases, N(0, 0.02) weights for the two heads.
        nn.init.constant_(self.one_more.bias, 0.)
        nn.init.normal_(self.one_more.weight, std=0.02)
        nn.init.constant_(self.out_proj.bias, 0.)
        nn.init.normal_(self.out_proj.weight, std=0.02)

    def label_smotthing_loss(self, y_pred, y, y_mask, avg=True):
        # Label-smoothed KL loss. (Method name keeps the historical typo
        # "smotthing" because callers use it.)
        seq_len, bsz = y.size()
        # y_pred holds probabilities; convert to log-probs for the KL loss,
        # clamping to avoid log(0).
        y_pred = torch.log(y_pred.clamp(min=1e-8))
        loss = self.smoothing(y_pred.view(seq_len * bsz, -1), y.view(seq_len * bsz, -1))
        if avg:
            return loss / torch.sum(y_mask)  # per-token average
        else:
            return loss / bsz  # per-sequence average

    def nll_loss(self, y_pred, y, y_mask, avg=True):
        # Gather the probability of each gold token along the vocab axis.
        cost = -torch.log(torch.gather(y_pred, 2, y.view(y.size(0), y.size(1), 1)))
        cost = cost.view(y.shape)
        y_mask = y_mask.view(y.shape)
        if avg:
            # Mean NLL per sequence (sum over time / token count).
            cost = torch.sum(cost * y_mask, 0) / torch.sum(y_mask, 0)
        else:
            cost = torch.sum(cost * y_mask, 0)
        cost = cost.view((y.size(1), -1))
        # NOTE(review): cost is a natural-log NLL but perplexity is taken
        # base 2 here, so this is not the usual e-based PPL — confirm intended.
        ppl = 2 ** cost
        return cost.sum().item(), ppl.sum().item()

    def work_incremental(self, enc, src_padding_mask, ys_inp, ys_tpl, ys_seg, ys_pos, incremental_state=None):
        """One decoding step with cached layer states (fast sampling path)."""
        seq_len, bsz = ys_inp.size()
        # Sum of the five embedding streams (token, position, template,
        # segment, intra-segment position).
        x = self.tok_embed(ys_inp) + self.pos_embed(ys_inp) + self.tok_embed(ys_tpl) + self.tok_embed(ys_seg) + self.tok_embed(ys_pos)
        x = self.emb_layer_norm(x)
        padding_mask = torch.eq(ys_inp, self.vocab.padding_idx)
        if not padding_mask.any():
            padding_mask = None
        if incremental_state is None:
            # First call: run the full prefix with a causal mask.
            self_attn_mask = self.attn_mask(seq_len)
            incremental_state = {}
        else:
            # Subsequent calls: feed only the newest position.
            x = x[-1, :, :].unsqueeze(0)
            self_attn_mask = None
        for layer in self.layers:
            x, _, _ = layer.work_incremental(x, self_padding_mask=padding_mask,
                                             self_attn_mask=self_attn_mask,
                                             external_memories=enc,
                                             external_padding_mask=src_padding_mask,
                                             incremental_state=incremental_state)
        x = self.one_more_layer_norm(gelu(self.one_more(x)))
        probs = torch.softmax(self.out_proj(x), -1)
        _, pred_y = probs.max(-1)
        return probs, pred_y, incremental_state

    def work(self, enc, src_padding_mask, ys_inp, ys_tpl, ys_seg, ys_pos):
        """Full (non-incremental) forward pass used during sampling."""
        seq_len, bsz = ys_inp.size()
        self_attn_mask = self.attn_mask(seq_len)
        x = self.tok_embed(ys_inp) + self.pos_embed(ys_inp) + self.tok_embed(ys_tpl) + self.tok_embed(ys_seg) + self.tok_embed(ys_pos)
        x = self.emb_layer_norm(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        padding_mask = torch.eq(ys_inp, self.vocab.padding_idx)
        if not padding_mask.any():
            padding_mask = None
        for layer in self.layers:
            x, _, _ = layer(x, self_padding_mask=padding_mask,
                            self_attn_mask=self_attn_mask,
                            external_memories=enc,
                            external_padding_mask=src_padding_mask,)
        x = self.one_more_layer_norm(gelu(self.one_more(x)))
        probs = torch.softmax(self.out_proj(x), -1)
        _, pred_y = probs.max(-1)
        return probs, pred_y

    def encode(self, xs_tpl, xs_seg, xs_pos):
        """Embed the template streams; returns (memories, padding_mask)."""
        padding_mask = torch.eq(xs_tpl, self.vocab.padding_idx)
        x = self.tok_embed(xs_tpl) + self.tok_embed(xs_seg) + self.tok_embed(xs_pos)
        x = self.emb_layer_norm(x)
        return x, padding_mask

    def ppl(self, xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk):
        """Evaluation pass: returns (summed NLL, summed PPL, batch size)."""
        enc, src_padding_mask = self.encode(xs_tpl, xs_seg, xs_pos)
        seq_len, bsz = ys_inp.size()
        self_attn_mask = self.attn_mask(seq_len)
        x = self.tok_embed(ys_inp) + self.pos_embed(ys_inp) + self.tok_embed(ys_tpl) + self.tok_embed(ys_seg) + self.tok_embed(ys_pos)
        x = self.emb_layer_norm(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        # NOTE(review): the padding mask is derived from ys_truth here (and in
        # forward), while work() derives it from ys_inp — confirm intended.
        padding_mask = torch.eq(ys_truth, self.vocab.padding_idx)
        if not padding_mask.any():
            padding_mask = None
        for layer in self.layers:
            x, _, _ = layer(x, self_padding_mask=padding_mask,
                            self_attn_mask=self_attn_mask,
                            external_memories=enc,
                            external_padding_mask=src_padding_mask,)
        x = self.one_more_layer_norm(gelu(self.one_more(x)))
        pred = torch.softmax(self.out_proj(x), -1)
        nll, ppl = self.nll_loss(pred, ys_truth, msk)
        return nll, ppl, bsz

    def forward(self, xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk):
        """Training pass: returns predictions, smoothed loss, token accuracy,
        NLL/PPL sums, token count and batch size."""
        enc, src_padding_mask = self.encode(xs_tpl, xs_seg, xs_pos)
        seq_len, bsz = ys_inp.size()
        self_attn_mask = self.attn_mask(seq_len)
        x = self.tok_embed(ys_inp) + self.pos_embed(ys_inp) + self.tok_embed(ys_tpl) + self.tok_embed(ys_seg) + self.tok_embed(ys_pos)
        x = self.emb_layer_norm(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        padding_mask = torch.eq(ys_truth, self.vocab.padding_idx)
        if not padding_mask.any():
            padding_mask = None
        for layer in self.layers:
            x, _, _ = layer(x, self_padding_mask=padding_mask,
                            self_attn_mask=self_attn_mask,
                            external_memories=enc,
                            external_padding_mask=src_padding_mask,)
        x = self.one_more_layer_norm(gelu(self.one_more(x)))
        pred = torch.softmax(self.out_proj(x), -1)
        loss = self.label_smotthing_loss(pred, ys_truth, msk)
        _, pred_y = pred.max(-1)
        tot_tokens = msk.float().sum().item()
        # Masked token-level accuracy numerator.
        acc = (torch.eq(pred_y, ys_truth).float() * msk).sum().item()
        nll, ppl = self.nll_loss(pred, ys_truth, msk)
        return (pred_y, ys_truth), loss, acc, nll, ppl, tot_tokens, bsz
================================================
FILE: data.py
================================================
import random
import torch
import numpy as np
# Special vocabulary symbols.
PAD, UNK, BOS, EOS = '<pad>', '<unk>', '<bos>', '<eos>'
BOC, EOC = '<boc>', '<eoc>'  # begin/end of context (the cipai title block)
LS, RS, SP = '<s>', '</s>', ' '  # sentence open/close markers and space
CS = ['<c-1>'] + ['<c' + str(i) + '>' for i in range(32)] # content
SS = ['<s-1>'] + ['<s' + str(i) + '>' for i in range(512)] # segment
PS = ['<p-1>'] + ['<p' + str(i) + '>' for i in range(512)] # position
TS = ['<t-1>'] + ['<t' + str(i) + '>' for i in range(32)] # other types
# ASCII plus full-width CJK punctuation treated as sentence punctuation.
PUNCS = set([",", ".", "?", "!", ":", ",", "。", "?", "!", ":"])
BUFSIZE = 4096000  # bytes per readlines() chunk in DataLoader
def ListsToTensor(xs, vocab=None):
    """Right-pad a batch of token lists to a common length.

    With a vocab, tokens are mapped to ids and padded with
    vocab.padding_idx; without one, the rows are assumed numeric and
    padded with 0. Returns a list of equal-length lists.
    """
    width = max(len(row) for row in xs)
    padded = []
    for row in xs:
        fill = width - len(row)
        if vocab is None:
            padded.append(row + [0] * fill)
        else:
            padded.append(vocab.token2idx(row) + [vocab.padding_idx] * fill)
    return padded
def _back_to_text_for_check(x, vocab):
    """Debug helper: print each (time-major) id sequence as text."""
    # x is (seq_len, batch); transpose to batch-major before mapping.
    for tokens in vocab.idx2token(x.t().tolist()):
        print(' '.join(tokens))
def batchify(data, vocab):
xs_tpl, xs_seg, xs_pos, \
ys_truth, ys_inp, \
ys_tpl, ys_seg, ys_pos, msk = [], [], [], [], [], [], [], [], []
for xs_tpl_i, xs_seg_i, xs_pos_i, ys_i, ys_tpl_i, ys_seg_i, ys_pos_i in data:
xs_tpl.append(xs_tpl_i)
xs_seg.append(xs_seg_i)
xs_pos.append(xs_pos_i)
ys_truth.append(ys_i)
ys_inp.append([BOS] + ys_i[:-1])
ys_tpl.append(ys_tpl_i)
ys_seg.append(ys_seg_i)
ys_pos.append(ys_pos_i)
msk.append([1 for i in range(len(ys_i))])
xs_tpl = torch.LongTensor(ListsToTensor(xs_tpl, vocab)).t_().contiguous()
xs_seg = torch.LongTensor(ListsToTensor(xs_seg, vocab)).t_().contiguous()
xs_pos = torch.LongTensor(ListsToTensor(xs_pos, vocab)).t_().contiguous()
ys_truth = torch.LongTensor(ListsToTensor(ys_truth, vocab)).t_().contiguous()
ys_inp = torch.LongTensor(ListsToTensor(ys_inp, vocab)).t_().contiguous()
ys_tpl = torch.LongTensor(ListsToTensor(ys_tpl, vocab)).t_().contiguous()
ys_seg = torch.LongTensor(ListsToTensor(ys_seg, vocab)).t_().contiguous()
ys_pos = torch.LongTensor(ListsToTensor(ys_pos, vocab)).t_().contiguous()
msk = torch.FloatTensor(ListsToTensor(msk)).t_().contiguous()
return xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk
def s2t(strs, vocab):
    """Tensorize raw token lists: returns (ids, mask), both time-major."""
    masks = [[1] * len(tokens) for tokens in strs]
    inp = torch.LongTensor(ListsToTensor(list(strs), vocab)).t_().contiguous()
    msk = torch.FloatTensor(ListsToTensor(masks)).t_().contiguous()
    return inp, msk
def s2xy(lines, vocab, max_len, min_len):
    """Parse raw lines and collate them; unparsable lines are dropped."""
    parsed = [r for r in (parse_line(ln, max_len, min_len) for ln in lines) if r]
    return batchify(parsed, vocab)
def parse_line(line, max_len, min_len):
    """Parse one training line into template/segment/position token streams.

    Input format: "author<s1>cipai<s2>sent1</s>sent2</s>..." where cipai is
    the tune title. Returns (xs_tpl, xs_seg, xs_pos, ys, ys_tpl, ys_seg,
    ys_pos) token lists, or None when the line is blank or the body is
    shorter than min_len.
    """
    line = line.strip()
    if not line:
        return None
    fs = line.split("<s2>")
    author, cipai = fs[0].split("<s1>")
    sents = fs[1].strip()
    if len(sents) > max_len:
        sents = sents[:max_len]  # character-level truncation of long bodies
    if len(sents) < min_len:
        return None
    sents = sents.split("</s>")
    ys = []
    xs_tpl = []
    xs_seg = []
    xs_pos = []
    # Context = the cipai (tune title), one character per token.
    ctx = cipai
    ws = [w for w in ctx]
    xs_tpl = ws + [EOC]
    xs_seg = [SS[0] for w in ws] + [EOC]  # segment id 0 marks the title
    # NOTE(review): positions for the title are drawn from SS (not PS) at
    # offset 300, presumably to keep them disjoint from body segment ids —
    # confirm intended.
    xs_pos = [SS[i + 300] for i in range(len(ws))] + [EOC]
    ys_tpl = []
    ys_seg = []
    ys_pos = []
    for si, sent in enumerate(sents):
        ws = []
        sent = sent.strip()
        if not sent:
            continue
        for w in sent:
            ws.append(w)
            if w.strip() and w not in PUNCS:
                ys_tpl.append(CS[2])  # ordinary content slot
            else:
                ys_tpl.append(CS[1])  # punctuation / whitespace slot
        ys += ws + [RS]
        # Mark the rhyme position: the char before trailing punctuation,
        # or the last char when the sentence has no trailing punctuation.
        if ws[-1] in PUNCS:
            ys_tpl[-2] = CS[3]
        else:
            ys_tpl[-1] = CS[3]
        ys_tpl += [RS]
        ys_seg += [SS[si + 1] for w in ws] + [RS]  # 1-based sentence index
        ys_pos += [PS[len(ws) - i] for i in range(len(ws))] + [RS]  # countdown to sentence end
    ys += [EOS]
    ys_tpl += [EOS]
    ys_seg += [EOS]
    ys_pos += [EOS]
    # The encoder input is the title context followed by the body template.
    xs_tpl += ys_tpl
    xs_seg += ys_seg
    xs_pos += ys_pos
    if len(ys) < min_len:
        return None
    return xs_tpl, xs_seg, xs_pos, ys, ys_tpl, ys_seg, ys_pos
def s2xy_polish(lines, vocab, max_len, min_len):
    """Parse polishing-template lines and collate them.

    Mirrors s2xy: lines that fail to parse (blank, or shorter than
    min_len) are skipped. The original appended parse_line_polish's None
    results, which made batchify raise a TypeError while unpacking.
    """
    data = []
    for line in lines:
        res = parse_line_polish(line, max_len, min_len)
        if not res:
            continue
        data.append(res)
    return batchify(data, vocab)
def parse_line_polish(line, max_len, min_len):
    """Parse one polishing line: like parse_line, but the body is a template
    where "_" marks a blank to fill in and any other character is fixed.

    Fixed characters are copied into ys_tpl verbatim; blanks become CS[2]
    slots for the model to generate. Returns None for blank / too-short
    lines.
    """
    line = line.strip()
    if not line:
        return None
    fs = line.split("<s2>")
    author, cipai = fs[0].split("<s1>")
    sents = fs[1].strip()
    if len(sents) > max_len:
        sents = sents[:max_len]
    if len(sents) < min_len:
        return None
    sents = sents.split("</s>")
    ys = []
    xs_tpl = []
    xs_seg = []
    xs_pos = []
    # Context = the cipai (tune title), one character per token.
    ctx = cipai
    ws = [w for w in ctx]
    xs_tpl = ws + [EOC]
    xs_seg = [SS[0] for w in ws] + [EOC]  # segment id 0 marks the title
    xs_pos = [SS[i + 300] for i in range(len(ws))] + [EOC]
    ys_tpl = []
    ys_seg = []
    ys_pos = []
    for si, sent in enumerate(sents):
        ws = []
        sent = sent.strip()
        if not sent:
            continue
        for w in sent:
            ws.append(w)
            if w == "_":
                ys_tpl.append(CS[2])  # blank slot to be generated
            else:
                ys_tpl.append(w)  # fixed character copied into the template
        ys += ws + [RS]
        ys_tpl += [RS]
        ys_seg += [SS[si + 1] for w in ws] + [RS]  # 1-based sentence index
        ys_pos += [PS[len(ws) - i] for i in range(len(ws))] + [RS]  # countdown to sentence end
    ys += [EOS]
    ys_tpl += [EOS]
    ys_seg += [EOS]
    ys_pos += [EOS]
    # Encoder input = title context followed by the body template.
    xs_tpl += ys_tpl
    xs_seg += ys_seg
    xs_pos += ys_pos
    if len(ys) < min_len:
        return None
    return xs_tpl, xs_seg, xs_pos, ys, ys_tpl, ys_seg, ys_pos
class DataLoader(object):
    """Streaming batch iterator over a training file.

    The file is consumed in BUFSIZE-byte chunks; when the stream runs dry,
    it is reopened and epoch_id is incremented. Each call to __iter__
    parses one chunk, shuffles it, and yields batchify()-ed mini-batches.
    """

    def __init__(self, vocab, filename, batch_size, max_len_y, min_len_y):
        self.batch_size = batch_size
        self.vocab = vocab
        self.max_len_y = max_len_y
        self.min_len_y = min_len_y
        self.filename = filename
        self.stream = open(self.filename, encoding='utf8')
        self.epoch_id = 0

    def __iter__(self):
        lines = self.stream.readlines(BUFSIZE)
        if not lines:
            # End of file: start a new epoch from the top.
            self.epoch_id += 1
            self.stream.close()
            self.stream = open(self.filename, encoding='utf8')
            lines = self.stream.readlines(BUFSIZE)
        data = []
        # The last line of the chunk is dropped as possibly incomplete.
        # NOTE(review): readlines() only returns whole lines, so this drops
        # one valid example per chunk — confirm intended.
        for line in lines[:-1]:
            res = parse_line(line, self.max_len_y, self.min_len_y)
            if not res:
                continue
            data.append(res)
        random.shuffle(data)
        idx = 0
        while idx < len(data):
            yield batchify(data[idx:idx + self.batch_size], self.vocab)
            idx += self.batch_size
class Vocab(object):
    """Token <-> id vocabulary.

    Built from a frequency file of "token count" lines. Special symbols are
    placed first so their ids are stable across vocab files; tokens below
    min_occur_cnt are dropped and unknown tokens map to <unk>.
    """

    def __init__(self, filename, min_occur_cnt, specials=None):
        idx2token = [PAD, UNK, BOS, EOS] + [BOC, EOC, LS, RS, SP] + CS + SS + PS + TS \
            + (specials if specials is not None else [])
        # Context manager releases the file handle (the original left it
        # open) and iterating the handle avoids materializing all lines.
        with open(filename, encoding='utf8') as f:
            for line in f:
                try:
                    token, cnt = line.strip().split()
                except ValueError:
                    # Malformed line: blank, or not exactly two fields.
                    continue
                if int(cnt) >= min_occur_cnt:
                    idx2token.append(token)
        self._token2idx = dict(zip(idx2token, range(len(idx2token))))
        self._idx2token = idx2token
        self._padding_idx = self._token2idx[PAD]
        self._unk_idx = self._token2idx[UNK]

    @property
    def size(self):
        """Total number of tokens, specials included."""
        return len(self._idx2token)

    @property
    def unk_idx(self):
        return self._unk_idx

    @property
    def padding_idx(self):
        return self._padding_idx

    def random_token(self):
        # Uniform over ids 1..size-1, i.e. never returns PAD (id 0).
        return self.idx2token(1 + np.random.randint(self.size - 1))

    def idx2token(self, x):
        """Map an id or (nested) list of ids back to tokens."""
        if isinstance(x, list):
            return [self.idx2token(i) for i in x]
        return self._idx2token[x]

    def token2idx(self, x):
        """Map a token or (nested) list of tokens to ids (<unk> fallback)."""
        if isinstance(x, list):
            return [self.token2idx(i) for i in x]
        return self._token2idx.get(x, self.unk_idx)
================================================
FILE: eval.py
================================================
import sys
import torch
from torch import nn
import torch.nn.functional as F
import random
import numpy as np
import copy
import time
from biglm import BIGLM
from data import Vocab, DataLoader, s2t, s2xy
# GPU device id from argv[2], defaulting to 0. The guard must require THREE
# argv entries (script, model path, gpu id): the original checked `> 1`,
# which raised IndexError whenever only the model path was supplied.
gpu = int(sys.argv[2]) if len(sys.argv) > 2 else 0
def init_model(m_path, device, vocab):
    """Load a BIGLM checkpoint onto `device` in eval mode.

    `vocab` is the path to the vocab frequency file; model hyper-parameters
    are restored from the args saved inside the checkpoint.
    """
    ckpt = torch.load(m_path, map_location='cpu')
    lm_args = ckpt['args']
    lm_vocab = Vocab(vocab, min_occur_cnt=lm_args.min_occur_cnt, specials=[])
    # Smoothing factor 0.1 is irrelevant at eval time but required by the ctor.
    lm_model = BIGLM(device, lm_vocab, lm_args.embed_dim, lm_args.ff_embed_dim, lm_args.num_heads, lm_args.dropout, lm_args.layers, 0.1, lm_args.approx)
    lm_model.load_state_dict(ckpt['model'])
    lm_model = lm_model.cuda(device)
    lm_model.eval()
    return lm_model, lm_vocab, lm_args
#m_path = "./ckpt_d101_6/epoch5_batch_139999"
m_path = sys.argv[1] if len(sys.argv) > 1 else None  # checkpoint path (argv[1])
lm_model, lm_vocab, lm_args = init_model(m_path, gpu, "./data/vocab.txt")

# Load the dev set, one example per non-empty line.
ds = []
with open("./data/dev.txt", "r") as f:
    for line in f:
        line = line.strip()
        if line:
            ds.append(line)
print(len(ds))

local_rank = gpu
batch_size = 10
batches = round(len(ds) / batch_size)  # NOTE(review): computed but unused
idx = 0
avg_nll = 0.
avg_ppl = 0.
count = 0.
# Accumulate corpus-level NLL and PPL batch by batch.
while idx < len(ds):
    cplb = ds[idx:idx + batch_size]
    xs_tpl, xs_seg, xs_pos, \
    ys_truth, ys_inp, \
    ys_tpl, ys_seg, ys_pos, msk = s2xy(cplb, lm_vocab, lm_args.max_len, 2)
    xs_tpl = xs_tpl.cuda(local_rank)
    xs_seg = xs_seg.cuda(local_rank)
    xs_pos = xs_pos.cuda(local_rank)
    ys_truth = ys_truth.cuda(local_rank)
    ys_inp = ys_inp.cuda(local_rank)
    ys_tpl = ys_tpl.cuda(local_rank)
    ys_seg = ys_seg.cuda(local_rank)
    ys_pos = ys_pos.cuda(local_rank)
    msk = msk.cuda(local_rank)
    nll, ppl, bsz = lm_model.ppl(xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk)
    avg_nll += nll
    avg_ppl += ppl
    count += bsz
    idx += batch_size
    # Periodic progress report every 200 processed examples.
    if count % 200 == 0:
        print("nll=", avg_nll / count, "ppl=", avg_ppl / count, "count=", count)
print("nll=", avg_nll / count, "ppl=", avg_ppl / count, "count=", count)
================================================
FILE: eval.sh
================================================
#!/bin/bash
# Evaluate every checkpoint under ./ckpt/ with eval.py (dev-set perplexity).
path=./ckpt/
FILES=$path/*
for f in $FILES; do
# Print the bare checkpoint name as a section header.
echo "==========================" ${f##*/}
# Arguments: checkpoint path, then GPU id 1.
python -u eval.py $path${f##*/} 1
done
================================================
FILE: label_smoothing.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
class LabelSmoothing(nn.Module):
"Implement label smoothing."
def __init__(self, device, size, padding_idx, label_smoothing=0.0):
super(LabelSmoothing, self).__init__()
assert 0.0 < label_smoothing <= 1.0
self.padding_idx = padding_idx
self.size = size
self.device = device
self.smoothing_value = label_smoothing / (size - 2)
self.one_hot = torch.full((1, size), self.smoothing_value).to(device)
self.one_hot[0, self.padding_idx] = 0
self.confidence = 1.0 - label_smoothing
def forward(self, output, target):
real_size = output.size(1)
if real_size > self.size:
real_size -= self.size
else:
real_size = 0
model_prob = self.one_hot.repeat(target.size(0), 1)
if real_size > 0:
ext_zeros = torch.full((model_prob.size(0), real_size), self.smoothing_value).to(self.device)
model_prob = torch.cat((model_prob, ext_zeros), -1)
model_prob.scatter_(1, target, self.confidence)
model_prob.masked_fill_((target == self.padding_idx), 0.)
return F.kl_div(output, model_prob, reduction='sum')
================================================
FILE: metrics.py
================================================
import os
import sys
import numpy as np
from pypinyin import Style, lazy_pinyin
from data import PUNCS
# Chinese rhyme classes: each id groups pinyin finals that rhyme together.
yunjiaos = {
    "0": ["a", "ia", "ua", "va", "üa"],
    "1": ["e", "o", "uo", "ie", "ue", "üe", "ve"],
    "2": ["u"],
    "3": ["i", "ü", "v"],
    "4": ["ai", "uai"],
    "5": ["ao", "iao"],
    "6": ["ou", "iu", "iou"],
    "7": ["an", "ian", "uan", "üan", "van"],
    "8": ["en", "in", "un", "ün", "vn"],
    "9": ["ang", "iang", "uang"],
    "10": ["eng", "ing", "ueng", "ong", "iong"],
    "11": ["er"],
    "12": ["ei", "ui", "uei", "vei"],
}
# Inverted index: pinyin final -> rhyme-class id.
yun2id = {}
for yid, yws in yunjiaos.items():
    for w in yws:
        yun2id[w] = yid
def eval_tpl(sents1, sents2):
    """Sentence-format agreement between reference and hypothesis.

    A pair counts as a hit when both sentences have equal length and an
    identical punctuation sequence. Returns (precision, recall, f1, hits,
    len(sents1), len(sents2)); precision normalizes by the hypothesis count.
    """
    if len(sents1) > len(sents2):
        sents1 = sents1[:len(sents2)]
    hits = 0.
    for ref, hyp in zip(sents1, sents2):
        if len(ref) != len(hyp):
            continue
        ref_punc = [ch for ch in ref if ch in PUNCS]
        hyp_punc = [ch for ch in hyp if ch in PUNCS]
        if ref_punc == hyp_punc:
            hits += 1
    p = hits / len(sents2)
    r = hits / len(sents1)
    f = 2 * p * r / (p + r + 1e-16)
    return p, r, f, hits, len(sents1), len(sents2)
def rhythm_labellig(sents):
    """Label each sentence by whether it carries the poem's dominant rhyme.

    The rhyme unit is the pinyin final of the last character (skipping one
    trailing punctuation mark). Sentences whose final falls in the most
    frequent rhyme class get 1, all others -1.
    """
    rhys = []
    for sent in sents:
        w = sent[-1]
        if w in PUNCS and len(sent) > 1:
            w = sent[-2]  # skip a single trailing punctuation mark
        yunmu = lazy_pinyin(w, style=Style.FINALS)
        rhys.append(yunmu[0])
    assert len(rhys) == len(sents)
    # Group sentence indices by rhyme-class id.
    rhy_map = {}
    for i, r in enumerate(rhys):
        if r in yun2id:
            rid = yun2id[r]
            if rid in rhy_map:
                rhy_map[rid] += [i]
            else:
                rhy_map[rid] = [i]
        else:
            pass  # final not in any rhyme class: ignored
    # Pick the dominant (largest) rhyme class.
    max_len_yuns = -1
    max_rid = ""
    for rid, yuns in rhy_map.items():
        if len(yuns) > max_len_yuns:
            max_len_yuns = len(yuns)
            max_rid = rid
    res = []
    for i in range(len(sents)):
        if max_rid in rhy_map and i in rhy_map[max_rid]:
            res.append(1)
        else:
            res.append(-1)
    return res
def eval_rhythm(sents1, sents2):
    """Rhyme agreement: P/R/F1 of lines that carry the dominant rhyme in
    both the reference (sents1) and the hypothesis (sents2)."""
    if len(sents1) > len(sents2):
        sents1 = sents1[:len(sents2)]
    rhys1 = rhythm_labellig(sents1)
    rhys2 = rhythm_labellig(sents2)
    n1 = float(sum(1 for v in rhys1 if v == 1))  # rhymed lines in reference
    n2 = float(sum(1 for v in rhys2 if v == 1))  # rhymed lines in hypothesis
    # Lines rhymed on both sides (rhys1 is no longer than rhys2 here).
    n = float(sum(1 for a, b in zip(rhys1, rhys2) if a == 1 and b == 1))
    p = n / (n2 + 1e-16)
    r = n / (n1 + 1e-16)
    f1 = 2 * p * r / (p + r + 1e-16)
    return p, r, f1, n, n1, n2
def eval_endings(sents1, sents2):
    """Force each hypothesis line to end with the reference line's last
    character (truncating longer hypotheses), then join with '</s>'.

    Used to produce the "corrected-ending" outputs for the line-ending
    ablation study.
    """
    if len(sents1) > len(sents2):
        sents1 = sents1[:len(sents2)]
    fixed = []
    for ref, hyp in zip(sents1, sents2):
        if len(hyp) <= len(ref):
            fixed.append(hyp)
        else:
            fixed.append(hyp[:len(ref) - 1] + ref[-1])
    return "</s>".join(fixed)
def eval(res_file, fid):
    """Score one result file of "input<TAB>generation" lines.

    Computes macro/micro F1 for format (eval_tpl) and rhyme (eval_rhythm)
    plus distinct-1/2 diversity, and writes ending-corrected generations to
    ./results_4ending/res4end{fid}.txt. Returns the eight summary metrics.
    """
    docs = []
    with open(res_file) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            fs = line.split("\t")
            if len(fs) != 2:
                print("error", line)
                continue
            x, y = fs
            docs.append((x, y))
    print(len(docs))
    ugrams_ = []   # corpus-level unigrams (micro distinct-1)
    bigrams_ = []  # corpus-level bigrams (micro distinct-2)
    # Accumulators: single trailing underscore = template metrics,
    # double underscore = rhyme metrics.
    p_, r_, f1_ = 0., 0., 0.
    n0_, n1_, n2_ = 0., 0., 0.
    p__, r__, f1__ = 0., 0., 0.
    n0__, n1__, n2__ = 0., 0., 0.
    d1_, d2_ = 0., 0.
    d4ends = []
    for x, y in docs:
        topic, content = x.split("<s2>")
        author, topic = topic.split("<s1>")
        sents1 = content.split("</s>")
        y = y.replace("<bos>", "")
        sents2 = y.split("</s>")
        # Drop empty sentences on both sides.
        sents1_ = []
        for sent in sents1:
            sent = sent.strip()
            if sent:
                sents1_.append(sent)
        sents1 = sents1_
        sents2_ = []
        for sent in sents2:
            sent = sent.strip()
            if sent:
                sents2_.append(sent)
        sents2 = sents2_
        # Format (template) agreement.
        p, r, f1, n0, n1, n2 = eval_tpl(sents1, sents2)
        p_ += p
        r_ += r
        f1_ += f1
        n0_ += n0
        n1_ += n1
        n2_ += n2
        # Distinct-1/2 on the generated text.
        ugrams = [w for w in ''.join(sents2)]
        bigrams = []
        for bi in range(len(ugrams) - 1):
            bigrams.append(ugrams[bi] + ugrams[bi + 1])
        d1_ += len(set(ugrams)) / float(len(ugrams))
        d2_ += len(set(bigrams)) / float(len(bigrams))
        ugrams_ += ugrams
        bigrams_ += bigrams
        # Rhyme agreement.
        p, r, f1, n0, n1, n2 = eval_rhythm(sents1, sents2)
        p__ += p
        r__ += r
        f1__ += f1
        n0__ += n0
        n1__ += n1
        n2__ += n2
        # Ending-corrected generation for the ablation output file.
        d4end = eval_endings(sents1, sents2)
        d4ends.append(author + "<s1>" + topic + "<s2>" + d4end)
    # Macro = averaged per-document scores; micro = pooled counts.
    tpl_macro_p = p_ / len(docs)
    tpl_macro_r = r_ / len(docs)
    tpl_macro_f1 = 2 * tpl_macro_p * tpl_macro_r / (tpl_macro_p + tpl_macro_r)
    tpl_micro_p = n0_ / n2_
    tpl_micro_r = n0_ / n1_
    tpl_micro_f1 = 2 * tpl_micro_p * tpl_micro_r / (tpl_micro_p + tpl_micro_r)
    rhy_macro_p = p__ / len(docs)
    rhy_macro_r = r__ / len(docs)
    rhy_macro_f1 = 2 * rhy_macro_p * rhy_macro_r / (rhy_macro_p + rhy_macro_r)
    rhy_micro_p = n0__ / n2__
    rhy_micro_r = n0__ / n1__
    rhy_micro_f1 = 2 * rhy_micro_p * rhy_micro_r / (rhy_micro_p + rhy_micro_r)
    macro_dist1 = d1_ / len(docs)
    macro_dist2 = d2_ / len(docs)
    micro_dist1 = len(set(ugrams_)) / float(len(ugrams_))
    micro_dist2 = len(set(bigrams_)) / float(len(bigrams_))
    with open("./results_4ending/res4end" + str(fid) + ".txt", "w") as fo:
        for line in d4ends:
            fo.write(line + "\n")
    return tpl_macro_f1, tpl_micro_f1, rhy_macro_f1, rhy_micro_f1, macro_dist1, micro_dist1, macro_dist2, micro_dist2
# Score up to five sampled result files and report mean / sample std of
# each metric across runs.
tpl_macro_f1_, tpl_micro_f1_, rhy_macro_f1_, rhy_micro_f1_, \
macro_dist1_, micro_dist1_, macro_dist2_, micro_dist2_ = [], [], [], [], [], [], [], []
abalation = "top-32"  # result subdirectory to score (variable name kept as-is)
for i in range(5):
    f_name = "./results/" + abalation + "/out" + str(i + 1) + ".txt"
    if not os.path.exists(f_name):
        continue
    tpl_macro_f1, tpl_micro_f1, rhy_macro_f1, rhy_micro_f1, macro_dist1, micro_dist1, macro_dist2, micro_dist2 = eval(f_name, i + 1)
    print(tpl_macro_f1, tpl_micro_f1, rhy_macro_f1, rhy_micro_f1, macro_dist1, micro_dist1, macro_dist2, micro_dist2)
    tpl_macro_f1_.append(tpl_macro_f1)
    tpl_micro_f1_.append(tpl_micro_f1)
    rhy_macro_f1_.append(rhy_macro_f1)
    rhy_micro_f1_.append(rhy_micro_f1)
    macro_dist1_.append(macro_dist1)
    micro_dist1_.append(micro_dist1)
    macro_dist2_.append(macro_dist2)
    micro_dist2_.append(micro_dist2)
print()
# ddof=1: sample standard deviation across runs.
print("tpl_macro_f1", np.mean(tpl_macro_f1_), np.std(tpl_macro_f1_, ddof=1))
print("tpl_micro_f1", np.mean(tpl_micro_f1_), np.std(tpl_micro_f1_, ddof=1))
print("rhy_macro_f1", np.mean(rhy_macro_f1_), np.std(rhy_macro_f1_, ddof=1))
print("rhy_micro_f1", np.mean(rhy_micro_f1_), np.std(rhy_micro_f1_, ddof=1))
print("macro_dist1", np.mean(macro_dist1_), np.std(macro_dist1_, ddof=1))
print("micro_dist1", np.mean(micro_dist1_), np.std(micro_dist1_, ddof=1))
print("macro_dist2", np.mean(macro_dist2_), np.std(macro_dist2_, ddof=1))
print("micro_dist2", np.mean(micro_dist2_), np.std(micro_dist2_, ddof=1))
================================================
FILE: optim.py
================================================
# -*- coding: utf-8 -*-
class Optim:
    """Noam learning-rate schedule wrapped around a base optimizer.

    lr(step) = factor * d_model^-0.5 * min(step^-0.5, step * warmup^-1.5):
    linear warmup for `warmup` steps, then inverse-sqrt decay.
    """

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def step(self):
        """Advance one step: recompute lr, push it into every param group,
        then step the wrapped optimizer."""
        self._step += 1
        new_rate = self.rate()
        for group in self.optimizer.param_groups:
            group['lr'] = new_rate
        self._rate = new_rate
        self.optimizer.step()

    def rate(self, step=None):
        """Return the learning rate for `step` (defaults to current step)."""
        step = self._step if step is None else step
        return self.factor * (self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5)))

    def state_dict(self):
        return self.optimizer.state_dict()

    def load_state_dict(self, m):
        self.optimizer.load_state_dict(m)
================================================
FILE: polish.py
================================================
import torch
from torch import nn
import torch.nn.functional as F
import random
import numpy as np
import copy
import time
from biglm import BIGLM
from data import Vocab, DataLoader, s2t, s2xy_polish
gpu = 0  # CUDA device id used throughout this script

def init_model(m_path, device, vocab):
    """Load a BIGLM checkpoint onto `device` in eval mode.

    NOTE(review): unlike eval.py, `lm_args.approx` is not forwarded here,
    so the model's approx falls back to its default None — confirm intended.
    """
    ckpt = torch.load(m_path, map_location='cpu')
    lm_args = ckpt['args']
    lm_vocab = Vocab(vocab, min_occur_cnt=lm_args.min_occur_cnt, specials=[])
    lm_model = BIGLM(device, lm_vocab, lm_args.embed_dim, lm_args.ff_embed_dim, lm_args.num_heads, lm_args.dropout, lm_args.layers, 0.1)
    lm_model.load_state_dict(ckpt['model'])
    lm_model = lm_model.cuda(device)
    lm_model.eval()
    return lm_model, lm_vocab, lm_args
# Finetuned SongCi checkpoint and its vocabulary (see README "Download").
m_path = "./model/songci.ckpt"
lm_model, lm_vocab, lm_args = init_model(m_path, gpu, "./model/vocab.txt")
MAX_LEN = 300  # NOTE(review): defined but not referenced in this file
k = 32  # top-k cutoff for sampling
def top_k_inc(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s):
    """Incremental top-k sampling decoder for polishing.

    enc / src_padding_mask: encoded template memories. inp_ys_* are the
    time-major template/segment/position id tensors driving generation; s
    is the list of partial token sequences (one per batch row). Returns
    the finished sequences.
    """
    start = time.time()
    incremental_state = None
    inp_y, m = s2t(s, lm_vocab)
    inp_y = inp_y.cuda(gpu)
    res = []
    for l in range(inp_ys_tpl.size(0)):
        probs, pred, incremental_state = lm_model.work_incremental(enc, src_padding_mask, \
            inp_y, inp_ys_tpl[0:l + 1, :], inp_ys_seg[0:l + 1, :], inp_ys_pos[0:l + 1, :], \
            incremental_state)
        next_tk = []
        for i in range(len(s)):
            ctk = lm_vocab.idx2token(inp_ys_tpl[l, i].item())
            # Template slots other than <c0>/<c1>/<c2> are fixed characters:
            # copy them through instead of sampling.
            if ctk != "<c1>" and ctk != "<c2>" and ctk != "<c0>":
                next_tk.append(ctk)
                continue
            if l == 0:
                # First step feeds the whole prefix: read its last position.
                logits = probs[len(s[i]) - 1, i]
            else:
                # Later steps feed one token, so probs has time length 1.
                logits = probs[0, i]
            ps, idx = torch.topk(logits, k=k)
            ps = ps / torch.sum(ps)  # renormalize the top-k mass
            sampled = torch.multinomial(ps, num_samples=1)
            sampled_idx = idx[sampled]
            next_tk.append(lm_vocab.idx2token(sampled_idx.item()))
        s_ = []
        bidx = [1] * len(s)  # 1 = sequence still alive after this step
        for idx, (sent, t) in enumerate(zip(s, next_tk)):
            if t == "<eos>":
                res.append(sent)  # finished sequence
                bidx[idx] = 0
            else:
                s_.append(sent + [t])
        if not s_:
            break
        s = s_
        inp_y, m = s2t(s, lm_vocab)
        inp_y = inp_y.cuda(gpu)
        bidx = torch.BoolTensor(bidx).cuda(gpu)
        # Tell the model which batch rows survive so cached states shrink.
        incremental_state["bidx"] = bidx
    # Sequences still unfinished when the template ran out.
    res += s_
    #for i in res:
    #    print(''.join(i))
    print(time.time() - start)
    return res
def top_k(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s):
    """Top-k sampling WITHOUT incremental caching.

    Re-runs the decoder over the full prefix at every step; slower than
    top_k_inc but produces the same kind of output.  Arguments and
    return value mirror top_k_inc.
    """
    inp_y, m = s2t(s, lm_vocab)
    inp_y = inp_y.cuda(gpu)
    start = time.time()
    res = []
    for l in range(inp_ys_tpl.size(0)):
        probs, pred = lm_model.work(enc, src_padding_mask, inp_y, inp_ys_tpl[0:l+1,:], inp_ys_seg[0:l+1,:], inp_ys_pos[0:l+1,:])
        next_tk = []
        for i in range(len(s)):
            ctk = lm_vocab.idx2token(inp_ys_tpl[l,i].item())
            # Copy template tokens other than the <c*> placeholders.
            if ctk != "<c1>" and ctk != "<c2>" and ctk != "<c0>":
                next_tk.append(ctk)
                continue
            # Distribution at the last prefix position of hypothesis i.
            logits = probs[len(s[i]) - 1, i]
            ps, idx = torch.topk(logits, k=k)
            ps = ps / torch.sum(ps)  # renormalise over the top-k mass
            sampled = torch.multinomial(ps, num_samples = 1)
            sampled_idx = idx[sampled]
            next_tk.append(lm_vocab.idx2token(sampled_idx.item()))
        s_ = []
        for sent, t in zip(s, next_tk):
            if t == "<eos>":
                res.append(sent)  # finished hypothesis
            else:
                s_.append(sent + [t])
        if not s_:
            break
        s = s_
        inp_y, m = s2t(s, lm_vocab)
        inp_y = inp_y.cuda(gpu)
    res += s_
    #for i in res:
    #    print(''.join(i))
    #print(time.time()-start)
    return res
# ---- polishing driver: read template lines, generate, dump results ----
ds = []
with open("./data/polish_tpl.txt", "r") as f:
    for line in f:
        line = line.strip()
        if line:
            ds.append(line)
print(len(ds))

local_rank = gpu
batch_size = 1  # outputs are matched back to inputs by position — keep at 1
cp_size = 1     # copies generated per input line
batches = round(len(ds) / batch_size)  # NOTE(review): unused below
for i in range(5):  # five independent sampling runs -> out1.txt .. out5.txt
    fo = open("./results/out"+str(i+1)+".txt", "w")
    idx = 0
    while idx < len(ds):
        lb = ds[idx:idx + batch_size]
        cplb = []
        for line in lb:
            # Replicate each input cp_size times to sample several outputs.
            cplb += [line for i in range(cp_size)]
        print(cplb)
        xs_tpl, xs_seg, xs_pos, \
        ys_truth, ys_inp, \
        ys_tpl, ys_seg, ys_pos, msk = s2xy_polish(cplb, lm_vocab, lm_args.max_len,2)
        xs_tpl = xs_tpl.cuda(local_rank)
        xs_seg = xs_seg.cuda(local_rank)
        xs_pos = xs_pos.cuda(local_rank)
        ys_tpl = ys_tpl.cuda(local_rank)
        ys_seg = ys_seg.cuda(local_rank)
        ys_pos = ys_pos.cuda(local_rank)
        enc, src_padding_mask = lm_model.encode(xs_tpl, xs_seg, xs_pos)
        # Shared inner list is safe: decoders never mutate it in place.
        s = [['<bos>']] * batch_size * cp_size
        res = top_k_inc(enc, src_padding_mask, ys_tpl, ys_seg, ys_pos, s)
        # NOTE(review): this inner `i` shadows the run index above; harmless
        # only because the outer for rebinds `i` on its next iteration.
        for i, line in enumerate(cplb):
            r = ''.join(res[i])
            print(line)
            print(r)
            fo.write(line + "\t" + r + "\n")
        idx += batch_size
    fo.close()
================================================
FILE: polish.sh
================================================
python3 -u polish.py
================================================
FILE: prepare_data.py
================================================
import sys, re
from collections import Counter
import random

# Character-frequency counter used to build the vocabulary file.
cnt = Counter()
f_ci = "./data/ci.txt"        # corpus lines: author<s1>topic<s2>sent</s>sent...
f_cipai = "./data/cipai.txt"  # whitespace-separated list of known topic names
                              # (presumably cipai/tune names — verify with data)

# Collect the set of known names for topic normalisation.
cipai = Counter()
with open(f_cipai) as f:
    for line in f:
        line = line.strip()
        fs = line.split()
        cipai.update(fs)
cipai = cipai.keys()

# Group poems by normalised topic while counting characters.
docs = {}
with open(f_ci) as f:
    for line in f:
        line = line.strip()
        fs = line.split("<s1>")
        author = fs[0]
        topic, content = fs[1].split("<s2>")
        if "・" in topic:
            # Compound title "A・B": prefer the half that is a known name,
            # falling back to the first half.
            t1, t2 = topic.split("・")
            if t1 == t2:
                topic = t1
            else:
                if t1 in cipai:
                    topic = t1
                elif t2 in cipai:
                    topic = t2
                else:
                    topic = t1
        content = content.replace("、", ",")
        sents = content.split("</s>")
        ws = [w for w in author + topic + ''.join(sents)]
        cnt.update(ws)
        if topic not in docs:
            docs[topic] = []
        docs[topic].append(author + "<s1>" + topic + "<s2>" + '</s>'.join(sents))

# Split by topic so dev/test topics never appear in training.
topics = list(docs.keys())
print(len(topics))
random.shuffle(topics)
topics_train = topics[:len(topics)-50]
topics_dev_test = topics[-50:]
topics_dev = topics_dev_test[:25]
topics_test = topics_dev_test[-25:]
docs_train = []
docs_dev = []
docs_test = []
for t in topics_train:
    docs_train.extend(docs[t])
for t in topics_dev:
    docs_dev.extend(docs[t])
for t in topics_test:
    docs_test.extend(docs[t])
random.shuffle(docs_train)
random.shuffle(docs_dev)
random.shuffle(docs_test)
print(len(docs_train), len(docs_dev), len(docs_test))

# Write the three splits; track distinct topics per split as a sanity check.
train_cps = []
dev_cps = []
test_cps = []
with open('./data/train.txt', 'w', encoding ='utf8') as f:
    for x in docs_train:
        s = x.split("<s2>")[0]
        train_cps.append(s.split("<s1>")[1])
        f.write(x + '\n')
print(len(set(train_cps)))
with open('./data/dev.txt', 'w', encoding ='utf8') as f:
    for x in docs_dev:
        s = x.split("<s2>")[0]
        dev_cps.append(s.split("<s1>")[1])
        f.write(x + '\n')
print(len(set(dev_cps)))
with open('./data/test.txt', 'w', encoding ='utf8') as f:
    for x in docs_test:
        s = x.split("<s2>")[0]
        test_cps.append(s.split("<s1>")[1])
        f.write(x + '\n')
print(len(set(test_cps)))

# Emit the vocabulary sorted by frequency: "<char>\t<count>".
print("vocab")
with open('./data/vocab.txt', 'w', encoding ='utf8') as f:
    for x, y in cnt.most_common():
        f.write(x + '\t' + str(y) + '\n')
print("done")
================================================
FILE: test.py
================================================
import torch
from torch import nn
import torch.nn.functional as F
import random
import numpy as np
import copy
import time
from biglm import BIGLM
from data import Vocab, DataLoader, s2t, s2xy
def init_seeds():
    """Seed Python and PyTorch RNGs (all CUDA devices too) for reproducibility."""
    seed = 123
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
#init_seeds()
gpu = 1  # CUDA device index used by every .cuda() call in this script
def init_model(m_path, device, vocab):
    """Load a BIGLM checkpoint from disk and prepare it for inference.

    Args:
        m_path: checkpoint path holding 'args' and 'model' entries.
        device: CUDA device index to place the model on.
        vocab: vocabulary file path.

    Returns:
        (eval-mode model, Vocab, checkpoint's saved args)
    """
    checkpoint = torch.load(m_path, map_location='cpu')
    args = checkpoint['args']
    vocabulary = Vocab(vocab, min_occur_cnt=args.min_occur_cnt, specials=[])
    net = BIGLM(device, vocabulary, args.embed_dim, args.ff_embed_dim,
                args.num_heads, args.dropout, args.layers, 0.1)
    net.load_state_dict(checkpoint['model'])
    net = net.cuda(device)
    net.eval()  # disable dropout for generation
    return net, vocabulary, args
# Module-level setup: restore the trained model once at import time.
m_path = "./model/songci.ckpt"
lm_model, lm_vocab, lm_args = init_model(m_path, gpu, "./model/vocab.txt")
k = 32  # top-k sampling cutoff used by the decoders below
def top_k_inc(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s):
    """Top-k sampling with incremental attention caching.

    Walks the template one position at a time: non-placeholder template
    tokens are copied verbatim, while <c0>/<c1>/<c2> positions are
    filled by sampling from the renormalised top-k distribution.

    Args:
        enc: encoder output from lm_model.encode().
        src_padding_mask: padding mask matching `enc`.
        inp_ys_tpl / inp_ys_seg / inp_ys_pos: template, segment and
            position id tensors of shape (tgt_len, bsz).
        s: list of partial hypotheses (token lists), one per batch row.

    Returns:
        Finished hypotheses (token lists), ordered by finish time.
    """
    start = time.time()
    incremental_state = None
    inp_y, m = s2t(s, lm_vocab)
    inp_y = inp_y.cuda(gpu)
    res = []
    for l in range(inp_ys_tpl.size(0)):
        # Attention layers cache past keys/values in incremental_state.
        probs, pred, incremental_state = lm_model.work_incremental(enc, src_padding_mask, \
            inp_y, inp_ys_tpl[0:l+1,:], inp_ys_seg[0:l+1,:], inp_ys_pos[0:l+1,:],\
            incremental_state)
        next_tk = []
        for i in range(len(s)):
            ctk = lm_vocab.idx2token(inp_ys_tpl[l,i].item())
            # Pass non-placeholder template tokens straight through.
            if ctk != "<c1>" and ctk != "<c2>" and ctk != "<c0>":
                next_tk.append(ctk)
                continue
            if l == 0:
                # First step fed the whole prefix: use its last position.
                logits = probs[len(s[i]) - 1, i]
            else:
                # Subsequent steps feed a single token: row 0 is newest.
                logits = probs[0, i]
            # `logits` holds probabilities; renormalise the top-k mass.
            ps, idx = torch.topk(logits, k=k)
            ps = ps / torch.sum(ps)
            sampled = torch.multinomial(ps, num_samples = 1)
            sampled_idx = idx[sampled]
            next_tk.append(lm_vocab.idx2token(sampled_idx.item()))
        s_ = []
        bidx = [1] * len(s)  # 1 = hypothesis still alive after this step
        for idx, (sent, t) in enumerate(zip(s, next_tk)):
            if t == "<eos>":
                res.append(sent)
                bidx[idx] = 0
            else:
                s_.append(sent + [t])
        if not s_:
            break
        s = s_
        inp_y, m = s2t(s, lm_vocab)
        inp_y = inp_y.cuda(gpu)
        # Inform the attention cache which batch rows survived.
        bidx = torch.BoolTensor(bidx).cuda(gpu)
        incremental_state["bidx"] = bidx
    res += s_
    #for i in res:
    #    print(''.join(i))
    print(time.time()-start)
    return res
def top_k(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s):
    """Top-k sampling without incremental caching.

    Functionally like top_k_inc but re-runs the decoder over the whole
    prefix at every step, so it is slower.  Same arguments and return.
    """
    inp_y, m = s2t(s, lm_vocab)
    inp_y = inp_y.cuda(gpu)
    start = time.time()
    res = []
    for l in range(inp_ys_tpl.size(0)):
        probs, pred = lm_model.work(enc, src_padding_mask, inp_y, inp_ys_tpl[0:l+1,:], inp_ys_seg[0:l+1,:], inp_ys_pos[0:l+1,:])
        next_tk = []
        for i in range(len(s)):
            ctk = lm_vocab.idx2token(inp_ys_tpl[l,i].item())
            # Template tokens other than the <c*> placeholders are copied.
            if ctk != "<c1>" and ctk != "<c2>" and ctk != "<c0>":
                next_tk.append(ctk)
                continue
            # Distribution at the last prefix position of hypothesis i.
            logits = probs[len(s[i]) - 1, i]
            ps, idx = torch.topk(logits, k=k)
            ps = ps / torch.sum(ps)  # renormalise over the top-k mass
            sampled = torch.multinomial(ps, num_samples = 1)
            sampled_idx = idx[sampled]
            next_tk.append(lm_vocab.idx2token(sampled_idx.item()))
        s_ = []
        for sent, t in zip(s, next_tk):
            if t == "<eos>":
                res.append(sent)  # finished hypothesis
            else:
                s_.append(sent + [t])
        if not s_:
            break
        s = s_
        inp_y, m = s2t(s, lm_vocab)
        inp_y = inp_y.cuda(gpu)
    res += s_
    #for i in res:
    #    print(''.join(i))
    #print(time.time()-start)
    return res
def greedy(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s):
    """Greedy (argmax) decoding with incremental attention caching.

    Non-placeholder template tokens are copied through; <c0>/<c1>/<c2>
    positions take the model's argmax prediction.

    Args:
        enc: encoder output from lm_model.encode().
        src_padding_mask: padding mask matching `enc`.
        inp_ys_tpl / inp_ys_seg / inp_ys_pos: template, segment and
            position id tensors of shape (tgt_len, bsz).
        s: list of partial hypotheses (token lists), one per batch row.

    Returns:
        Finished hypotheses (token lists), ordered by finish time.
    """
    start = time.time()
    incremental_state = None
    inp_y, m = s2t(s, lm_vocab)
    inp_y = inp_y.cuda(gpu)
    res = []
    for l in range(inp_ys_tpl.size(0)):
        probs, pred, incremental_state = lm_model.work_incremental(enc, src_padding_mask, \
            inp_y, inp_ys_tpl[0:l+1,:], inp_ys_seg[0:l+1,:], inp_ys_pos[0:l+1,:],\
            incremental_state)
        next_tk = []
        for i in range(len(s)):
            ctk = lm_vocab.idx2token(inp_ys_tpl[l,i].item())
            if ctk != "<c1>" and ctk != "<c2>" and ctk != "<c0>":
                next_tk.append(ctk)
                continue
            # BUG FIX: the original rebound `pred` itself
            # (`pred = pred[..., i]`), destroying the full prediction
            # tensor for every later batch row in this inner loop — any
            # second placeholder row then indexed a 0-d tensor.  Index
            # into a fresh local instead.
            if l == 0:
                # First step fed the whole prefix: take its last position.
                pred_i = pred[len(s[i]) - 1, i]
            else:
                # Later steps feed one token, so row 0 is the newest.
                pred_i = pred[0, i]
            next_tk.append(lm_vocab.idx2token(pred_i.item()))
        s_ = []
        bidx = [1] * len(s)  # 1 = hypothesis still alive after this step
        for idx, (sent, t) in enumerate(zip(s, next_tk)):
            if t == "<eos>":
                res.append(sent)
                bidx[idx] = 0
            else:
                s_.append(sent + [t])
        if not s_:
            break
        s = s_
        inp_y, m = s2t(s, lm_vocab)
        inp_y = inp_y.cuda(gpu)
        # Inform the attention cache which batch rows survived.
        bidx = torch.BoolTensor(bidx).cuda(gpu)
        incremental_state["bidx"] = bidx
    res += s_
    print(time.time()-start)
    return res
def beam_decode(s, x, enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos):
    """Beam-search decoding over the template positions.

    Args:
        s: seed hypothesis token list (interface only; not used below).
        x: LongTensor prefix of shape (seq_len, 1) from s2t.
        enc / src_padding_mask: encoder outputs and padding mask.
        inp_ys_tpl / inp_ys_seg / inp_ys_pos: template/segment/position ids.

    Returns:
        Decoded hypotheses as lists of tokens, best-scoring first.
    """
    beam_size = 5
    num_live = 1
    num_dead = 0
    samples = []
    sample_scores = np.zeros(beam_size)
    last_traces = [[]]
    last_scores = torch.FloatTensor(np.zeros(1)).to(gpu)
    x = x.to(gpu)
    ys = x
    for l in range(inp_ys_tpl.size(0)):
        seq_len, bsz = ys.size()
        # Tile encoder state and template over the live beams.
        enc_ = enc.repeat(1, bsz, 1)
        src_padding_mask_ = src_padding_mask.repeat(1, bsz)
        inp_ys_tpl_ = inp_ys_tpl.repeat(1, bsz)
        inp_ys_seg_ = inp_ys_seg.repeat(1, bsz)
        inp_ys_pos_ = inp_ys_pos.repeat(1, bsz)
        y_pred, _ = lm_model.work(enc_, src_padding_mask_, ys, inp_ys_tpl_[0:l+1,:], inp_ys_seg_[0:l+1,:], inp_ys_pos_[0:l+1,:])
        dict_size = y_pred.shape[-1]
        y_pred = y_pred[-1, :, :]
        cand_y_scores = last_scores + torch.log(y_pred)  # larger is better
        cand_scores = cand_y_scores.flatten()
        idx_top_joint_scores = torch.topk(cand_scores, beam_size - num_dead)[1]
        # BUG FIX: '/' on an integer tensor is true division in modern
        # PyTorch, yielding floats that cannot index `last_traces`.
        # Use integer floor division to recover each candidate's beam row.
        idx_last_traces = idx_top_joint_scores // dict_size
        idx_word_now = idx_top_joint_scores % dict_size
        top_joint_scores = cand_scores[idx_top_joint_scores]
        traces_now = []
        scores_now = np.zeros((beam_size - num_dead))
        ys_now = []
        # `row` = beam the candidate extends, `tok` = new token id.
        for i, (row, tok) in enumerate(zip(idx_last_traces, idx_word_now)):
            traces_now.append(last_traces[row] + [tok])
            scores_now[i] = copy.copy(top_joint_scores[i])
            ys_now.append(copy.copy(ys[:, row]))
        num_live = 0
        last_traces = []
        last_scores = []
        ys = []
        for i in range(len(traces_now)):
            w = lm_vocab.idx2token(traces_now[i][-1].item())
            if w == "<eos>":
                # Finished hypothesis: store it without the trailing <eos>.
                samples.append([str(e.item()) for e in traces_now[i][:-1]])
                sample_scores[num_dead] = scores_now[i]
                num_dead += 1
            else:
                last_traces.append(traces_now[i])
                last_scores.append(scores_now[i])
                ys.append(ys_now[i])
                num_live += 1
        if num_live == 0 or num_dead >= beam_size:
            break
        ys = torch.stack(ys, dim = 1)
        last_scores = torch.FloatTensor(np.array(last_scores).reshape((num_live, 1))).to(gpu)
        # Append each surviving beam's newest token to its prefix.
        next_y = []
        for e in last_traces:
            next_y.append(e[-1].item())
        next_y = np.array(next_y).reshape((1, num_live))
        next_y = torch.LongTensor(next_y).to(gpu)
        ys = torch.cat([ys, next_y], dim=0)
        assert num_live + num_dead == beam_size
    # Flush still-live beams as finished samples.
    if num_live > 0:
        for i in range(num_live):
            samples.append([str(e.item()) for e in last_traces[i]])
            sample_scores[num_dead] = last_scores[i]
            num_dead += 1
    idx_sorted_scores = np.argsort(sample_scores)  # ascending order
    sorted_samples = []
    sorted_scores = []
    # Prefer non-empty samples; fall back to everything if all are empty.
    filter_idx = []
    for e in idx_sorted_scores:
        if len(samples[e]) > 0:
            filter_idx.append(e)
    if len(filter_idx) == 0:
        filter_idx = idx_sorted_scores
    for e in filter_idx:
        sorted_samples.append(samples[e])
        sorted_scores.append(sample_scores[e])
    # Best-first: reverse the ascending order, map ids back to tokens.
    res = []
    for sample in sorted_samples[::-1]:
        res.append([lm_vocab.idx2token(int(e)) for e in sample])
    return res
def beam_search(enc, src_padding_mask, ys_tpl, ys_seg, ys_pos, s):
    """Convert the seed hypotheses to tensors and run beam_decode."""
    prefix, _ = s2t(s, lm_vocab)
    return beam_decode(s[0], prefix, enc, src_padding_mask, ys_tpl, ys_seg, ys_pos)
# ---- test driver: decode every test line five times, dump results ----
ds = []
with open("./data/test.txt", "r") as f:
    for line in f:
        line = line.strip()
        if line:
            ds.append(line)
print(len(ds))

local_rank = gpu
batch_size = 1  # outputs are matched back to inputs by position — keep at 1
cp_size = 1     # copies generated per input line
batches = round(len(ds) / batch_size)  # NOTE(review): unused below
for i in range(5):  # five independent sampling runs -> out1.txt .. out5.txt
    idx = 0
    fo = open("./results/top-"+str(k)+"/out"+str(i+1)+".txt", "w")
    while idx < len(ds):
        lb = ds[idx:idx + batch_size]
        cplb = []
        for line in lb:
            # Replicate each input cp_size times to sample several outputs.
            cplb += [line for i in range(cp_size)]
        print(cplb)
        xs_tpl, xs_seg, xs_pos, \
        ys_truth, ys_inp, \
        ys_tpl, ys_seg, ys_pos, msk = s2xy(cplb, lm_vocab, lm_args.max_len, 2)
        xs_tpl = xs_tpl.cuda(local_rank)
        xs_seg = xs_seg.cuda(local_rank)
        xs_pos = xs_pos.cuda(local_rank)
        ys_tpl = ys_tpl.cuda(local_rank)
        ys_seg = ys_seg.cuda(local_rank)
        ys_pos = ys_pos.cuda(local_rank)
        enc, src_padding_mask = lm_model.encode(xs_tpl, xs_seg, xs_pos)
        # Shared inner list is safe: decoders never mutate it in place.
        s = [['<bos>']] * batch_size * cp_size
        res = top_k_inc(enc, src_padding_mask, ys_tpl, ys_seg, ys_pos, s)
        # NOTE(review): this inner `i` shadows the run index above; harmless
        # only because the outer for rebinds `i` on its next iteration.
        for i, line in enumerate(cplb):
            r = ''.join(res[i])
            print(line)
            print(r)
            fo.write(line + "\t" + r + "\n")
        idx += batch_size
    fo.close()
================================================
FILE: test.sh
================================================
python3 -u test.py
================================================
FILE: train.py
================================================
# coding=utf-8
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
from biglm import BIGLM
from data import Vocab, DataLoader, s2xy
from optim import Optim
import argparse, os
import random
def parse_config():
    """Build and parse the training command-line arguments.

    All options default to None except --start_from (explicitly None);
    the parse result is also serialised into checkpoints.
    """
    parser = argparse.ArgumentParser()
    int_opts = ('embed_dim', 'ff_embed_dim', 'num_heads', 'layers',
                'min_occur_cnt', 'batch_size', 'warmup_steps',
                'max_len', 'min_len', 'print_every', 'save_every',
                'world_size', 'gpus', 'start_rank')
    float_opts = ('dropout', 'lr', 'smoothing', 'weight_decay')
    str_opts = ('train_data', 'dev_data', 'vocab', 'save_dir',
                'MASTER_ADDR', 'MASTER_PORT', 'backend')
    for name in int_opts:
        parser.add_argument('--' + name, type=int)
    for name in float_opts:
        parser.add_argument('--' + name, type=float)
    for name in str_opts:
        parser.add_argument('--' + name, type=str)
    # Optional checkpoint to resume from.
    parser.add_argument('--start_from', type=str, default=None)
    return parser.parse_args()
def update_lr(optimizer, lr):
    """Set the learning rate of every parameter group to `lr`."""
    for group in optimizer.param_groups:
        group['lr'] = lr
def average_gradients(model):
    """Gradient averaging across all distributed workers.

    All-reduces every parameter's gradient and divides by world size.
    Returns False (stopping at the first offender) when a parameter has
    no gradient, signalling the caller to skip this optimizer step.
    """
    world_size = float(dist.get_world_size())
    for param in model.parameters():
        if param.grad is None:
            return False
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data /= world_size
    return True
def eval_epoch(lm_args, model, lm_vocab, local_rank, label):
    """Evaluate NLL/perplexity on the dev set and print one summary line.

    Args:
        lm_args: parsed training args (supplies dev_data/max_len/min_len).
        model: BIGLM exposing .ppl().
        lm_vocab: vocabulary for tensorisation.
        local_rank: CUDA device index.
        label: tag printed with the metrics (e.g. epoch/batch marker).
    """
    print("validating...", flush=True)
    with open(lm_args.dev_data, "r") as f:
        ds = [ln.strip() for ln in f if ln.strip()]
    batch_size = 10
    total_nll, total_ppl, total_cnt = 0., 0., 0.
    for beg in range(0, len(ds), batch_size):
        chunk = ds[beg:beg + batch_size]
        (xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp,
         ys_tpl, ys_seg, ys_pos, msk) = s2xy(chunk, lm_vocab, lm_args.max_len, lm_args.min_len)
        xs_tpl = xs_tpl.cuda(local_rank)
        xs_seg = xs_seg.cuda(local_rank)
        xs_pos = xs_pos.cuda(local_rank)
        ys_truth = ys_truth.cuda(local_rank)
        ys_inp = ys_inp.cuda(local_rank)
        ys_tpl = ys_tpl.cuda(local_rank)
        ys_seg = ys_seg.cuda(local_rank)
        ys_pos = ys_pos.cuda(local_rank)
        msk = msk.cuda(local_rank)
        nll, ppl, bsz = model.ppl(xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk)
        total_nll += nll
        total_ppl += ppl
        total_cnt += bsz
    print(label, "nll=", total_nll/total_cnt, "ppl=", total_ppl/total_cnt, "count=", total_cnt, flush=True)
def run(args, local_rank):
    """Distributed synchronous training loop for one worker.

    Trains BIGLM on this rank's GPU for up to 30 data epochs.  When
    world_size > 1, gradients are averaged across workers each step;
    rank 0 handles all logging, evaluation and checkpointing.
    """
    torch.manual_seed(1234)
    vocab = Vocab(args.vocab, min_occur_cnt=args.min_occur_cnt, specials=[])
    if (args.world_size == 1 or dist.get_rank() == 0):
        print ("vocab.size = " + str(vocab.size), flush=True)
    model = BIGLM(local_rank, vocab, args.embed_dim, args.ff_embed_dim,\
            args.num_heads, args.dropout, args.layers, args.smoothing)
    if args.start_from is not None:
        # Resume: restore model weights now, optimizer state below.
        ckpt = torch.load(args.start_from, map_location='cpu')
        model.load_state_dict(ckpt['model'])
    model = model.cuda(local_rank)
    # Noam-style wrapper around Adam; base lr 0 is scheduled by Optim.
    optimizer = Optim(model.embed_dim, args.lr, args.warmup_steps, torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.998), eps=1e-9))
    if args.start_from is not None:
        optimizer.load_state_dict(ckpt['optimizer'])
    train_data = DataLoader(vocab, args.train_data, args.batch_size, args.max_len, args.min_len)
    batch_acm = 0
    # Running accumulators, reset after every log line.
    acc_acm, nll_acm, ppl_acm, ntokens_acm, nxs, npairs_acm, loss_acm = 0., 0., 0., 0., 0., 0., 0.
    while True:
        model.train()
        if train_data.epoch_id > 30:
            break
        for xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk in train_data:
            batch_acm += 1
            xs_tpl = xs_tpl.cuda(local_rank)
            xs_seg = xs_seg.cuda(local_rank)
            xs_pos = xs_pos.cuda(local_rank)
            ys_truth = ys_truth.cuda(local_rank)
            ys_inp = ys_inp.cuda(local_rank)
            ys_tpl = ys_tpl.cuda(local_rank)
            ys_seg = ys_seg.cuda(local_rank)
            ys_pos = ys_pos.cuda(local_rank)
            msk = msk.cuda(local_rank)
            model.zero_grad()
            res, loss, acc, nll, ppl, ntokens, npairs = model(xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk)
            loss_acm += loss.item()
            acc_acm += acc
            nll_acm += nll
            ppl_acm += ppl
            ntokens_acm += ntokens
            npairs_acm += npairs
            nxs += npairs
            loss.backward()
            if args.world_size > 1:
                # Synchronous SGD: average gradients across all workers.
                is_normal = average_gradients(model)
            else:
                is_normal = True
            if is_normal:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
            else:
                # A parameter had no gradient; skip this optimizer step.
                print("gradient: none, gpu: " + str(local_rank), flush=True)
                continue
            # batch_acm % print_every == print_every - 1  -> fires once
            # every print_every batches (same trick for save_every below).
            if (args.world_size==1 or dist.get_rank() ==0) and batch_acm%args.print_every == -1%args.print_every:
                print ('batch_acm %d, loss %.3f, acc %.3f, nll %.3f, ppl %.3f, x_acm %d, lr %.6f'\
                        %(batch_acm, loss_acm/args.print_every, acc_acm/ntokens_acm, \
                        nll_acm/nxs, ppl_acm/nxs, npairs_acm, optimizer._rate), flush=True)
                acc_acm, nll_acm, ppl_acm, ntokens_acm, loss_acm, nxs = 0., 0., 0., 0., 0., 0.
            if (args.world_size==1 or dist.get_rank() ==0) and batch_acm%args.save_every == -1%args.save_every:
                if not os.path.exists(args.save_dir):
                    os.mkdir(args.save_dir)
                # Validate, then checkpoint model + optimizer + args.
                model.eval()
                eval_epoch(args, model, vocab, local_rank, "epoch-" + str(train_data.epoch_id) + "-acm-" + str(batch_acm))
                model.train()
                torch.save({'args':args, 'model':model.state_dict(), 'optimizer':optimizer.state_dict()}, '%s/epoch%d_batch_%d'%(args.save_dir, train_data.epoch_id, batch_acm))
def init_processes(args, local_rank, fn, backend='nccl'):
    """Join the distributed process group, then hand off to `fn`.

    Sets the rendezvous environment variables, registers this process
    with rank (start_rank + local_rank), and calls fn(args, local_rank).
    """
    rendezvous = {'MASTER_ADDR': args.MASTER_ADDR,
                  'MASTER_PORT': args.MASTER_PORT}
    os.environ.update(rendezvous)
    worker_rank = args.start_rank + local_rank
    dist.init_process_group(backend, rank=worker_rank, world_size=args.world_size)
    fn(args, local_rank)
if __name__ == "__main__":
    # 'spawn' gives each worker a clean interpreter (required with CUDA).
    mp.set_start_method('spawn')
    args = parse_config()
    if args.world_size == 1:
        # Single worker: run directly, no process group needed.
        run(args, 0)
        exit(0)
    # Multi-GPU: one process per local GPU joins the process group.
    processes = []
    for rank in range(args.gpus):
        p = mp.Process(target=init_processes, args=(args, rank, run, args.backend))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
================================================
FILE: train.sh
================================================
CUDA_VISIBLE_DEVICES=1 \
python3 -u train.py --embed_dim 768 \
--ff_embed_dim 3072 \
--num_heads 12 \
--layers 12 \
--dropout 0.2 \
--train_data ./data/train.txt \
--dev_data ./data/dev.txt \
--vocab ./data/vocab.txt \
--min_occur_cnt 1 \
--batch_size 32 \
--warmup_steps 8000 \
--lr 0.5 \
--weight_decay 0 \
--smoothing 0.1 \
--max_len 300 \
--min_len 10 \
--world_size 1 \
--gpus 1 \
--start_rank 0 \
--MASTER_ADDR localhost \
--MASTER_PORT 28512 \
--print_every 100 \
--save_every 1000 \
--save_dir ckpt \
--backend nccl
================================================
FILE: transformer.py
================================================
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from utils import gelu, LayerNorm, get_incremental_state, set_incremental_state
import math
class TransformerLayer(nn.Module):
    """Post-LayerNorm transformer block: self-attention, optional
    cross-attention over external memories, then a GELU feed-forward
    sublayer, each wrapped in residual + LayerNorm.
    """
    def __init__(self, embed_dim, ff_embed_dim, num_heads, dropout, with_external=False, weights_dropout = True):
        super(TransformerLayer, self).__init__()
        self.self_attn = MultiheadAttention(embed_dim, num_heads, dropout, weights_dropout)
        self.fc1 = nn.Linear(embed_dim, ff_embed_dim)
        self.fc2 = nn.Linear(ff_embed_dim, embed_dim)
        self.attn_layer_norm = LayerNorm(embed_dim)
        self.ff_layer_norm = LayerNorm(embed_dim)
        self.with_external = with_external
        self.dropout = dropout
        if self.with_external:
            # Cross-attention over encoder/external memories.
            self.external_attn = MultiheadAttention(embed_dim, num_heads, dropout, weights_dropout)
            self.external_layer_norm = LayerNorm(embed_dim)
        self.reset_parameters()

    def reset_parameters(self):
        # BERT-style init: N(0, 0.02) weights, zero biases.
        nn.init.normal_(self.fc1.weight, std=0.02)
        nn.init.normal_(self.fc2.weight, std=0.02)
        nn.init.constant_(self.fc1.bias, 0.)
        nn.init.constant_(self.fc2.bias, 0.)

    def forward(self, x, kv = None,
                self_padding_mask = None, self_attn_mask = None,
                external_memories = None, external_padding_mask=None,
                need_weights = False):
        # x: seq_len x bsz x embed_dim
        residual = x
        if kv is None:
            # Standard self-attention (q = k = v = x).
            x, self_attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=self_padding_mask, attn_mask=self_attn_mask, need_weights = need_weights)
        else:
            # Attend from x over a separate key/value tensor.
            x, self_attn = self.self_attn(query=x, key=kv, value=kv, key_padding_mask=self_padding_mask, attn_mask=self_attn_mask, need_weights = need_weights)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.attn_layer_norm(residual + x)
        if self.with_external:
            residual = x
            x, external_attn = self.external_attn(query=x, key=external_memories, value=external_memories, key_padding_mask=external_padding_mask, need_weights = need_weights)
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = self.external_layer_norm(residual + x)
        else:
            external_attn = None
        # Position-wise feed-forward with GELU.
        residual = x
        x = gelu(self.fc1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.ff_layer_norm(residual + x)
        return x, self_attn, external_attn

    def work_incremental(self, x, self_padding_mask = None, self_attn_mask = None,
                         external_memories = None, external_padding_mask = None, incremental_state = None):
        """Single-step forward using cached attention state.

        Same wiring as forward() but routes incremental_state into the
        self-attention and applies no dropout (inference path).
        """
        # x: seq_len x bsz x embed_dim
        residual = x
        x, self_attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=self_padding_mask, attn_mask=self_attn_mask, incremental_state=incremental_state)
        x = self.attn_layer_norm(residual + x)
        if self.with_external:
            residual = x
            x, external_attn = self.external_attn(query=x, key=external_memories, value=external_memories, key_padding_mask=external_padding_mask)
            x = self.external_layer_norm(residual + x)
        else:
            external_attn = None
        residual = x
        x = gelu(self.fc1(x))
        x = self.fc2(x)
        x = self.ff_layer_norm(residual + x)
        return x, self_attn, external_attn
class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention with optional incremental
    (cached key/value) decoding.

    `weights_dropout` chooses whether dropout is applied to the attention
    weights (True) or to the attended values (False).
    """
    def __init__(self, embed_dim, num_heads, dropout=0., weights_dropout=True):
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5
        # One packed projection matrix for q, k, v (3*embed_dim rows).
        self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
        self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.weights_dropout = weights_dropout
        self.reset_parameters()

    def reset_parameters(self):
        # BERT-style init: N(0, 0.02) weights, zero biases.
        nn.init.normal_(self.in_proj_weight, std=0.02)
        nn.init.normal_(self.out_proj.weight, std=0.02)
        nn.init.constant_(self.in_proj_bias, 0.)
        nn.init.constant_(self.out_proj.bias, 0.)

    def forward(self, query, key, value, key_padding_mask=None, attn_mask=None, need_weights=False, incremental_state = None):
        """ Input shape: Time x Batch x Channel
            key_padding_mask: Time x batch
            attn_mask: tgt_len x src_len
        """
        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            # bidx is a keep-mask over batch rows, set by the decoding
            # loop when hypotheses finish, used to shrink the KV cache.
            bidx = self._get_bidx(incremental_state)
        else:
            saved_state = None
            bidx = None
        # Detect self-attention vs encoder-decoder attention by aliasing.
        qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
        kv_same = key.data_ptr() == value.data_ptr()
        tgt_len, bsz, embed_dim = query.size()
        assert key.size() == value.size()
        if qkv_same:
            # self-attention
            q, k, v = self.in_proj_qkv(query)
        elif kv_same:
            # encoder-decoder attention
            q = self.in_proj_q(query)
            k, v = self.in_proj_kv(key)
        else:
            q = self.in_proj_q(query)
            k = self.in_proj_k(key)
            v = self.in_proj_v(value)
        q = q * self.scaling  # pre-scale queries by 1/sqrt(head_dim)
        # Reshape to (bsz*heads, len, head_dim).
        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if 'prev_key' in saved_state:
                prev_key = saved_state['prev_key']
                if bidx is not None:
                    prev_key = prev_key[bidx]  # drop rows of finished hypotheses
                prev_key = prev_key.contiguous().view(bsz * self.num_heads, -1, self.head_dim)
                k = torch.cat((prev_key, k), dim=1)
            if 'prev_value' in saved_state:
                prev_value = saved_state['prev_value']
                if bidx is not None:
                    prev_value = prev_value[bidx]
                prev_value = prev_value.contiguous().view(bsz * self.num_heads, -1, self.head_dim)
                v = torch.cat((prev_value, v), dim=1)
            saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)
            self._set_input_buffer(incremental_state, saved_state)
        src_len = k.size(1)
        # k,v: bsz*heads x src_len x dim
        # q: bsz*heads x tgt_len x dim
        attn_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
        if attn_mask is not None:
            # Structural/causal mask: True entries are blocked.
            attn_weights.masked_fill_(
                attn_mask.unsqueeze(0),
                float('-inf')
            )
        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights.masked_fill_(
                key_padding_mask.transpose(0, 1).unsqueeze(1).unsqueeze(2),
                float('-inf')
            )
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
        attn_weights = F.softmax(attn_weights, dim=-1)
        if self.weights_dropout:
            attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
        attn = torch.bmm(attn_weights, v)
        if not self.weights_dropout:
            attn = F.dropout(attn, p=self.dropout, training=self.training)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        # Merge heads back to (tgt_len, bsz, embed_dim) and project out.
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        if need_weights:
            # maximum attention weight over heads
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights, _ = attn_weights.max(dim=1)
            attn_weights = attn_weights.transpose(0, 1)
        else:
            attn_weights = None
        return attn, attn_weights

    def in_proj_qkv(self, query):
        # Project once, split the packed output into q, k, v.
        return self._in_proj(query).chunk(3, dim=-1)

    def in_proj_kv(self, key):
        return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)

    def in_proj_q(self, query):
        return self._in_proj(query, end=self.embed_dim)

    def in_proj_k(self, key):
        return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)

    def in_proj_v(self, value):
        return self._in_proj(value, start=2 * self.embed_dim)

    def _in_proj(self, input, start=0, end=None):
        # Linear with a row-slice of the packed weight/bias.
        weight = self.in_proj_weight
        bias = self.in_proj_bias
        weight = weight[start:end, :]
        if bias is not None:
            bias = bias[start:end]
        return F.linear(input, weight, bias)

    def _get_input_buffer(self, incremental_state):
        return get_incremental_state(
            self,
            incremental_state,
            'attn_state',
        ) or {}

    def _set_input_buffer(self, incremental_state, buffer):
        set_incremental_state(
            self,
            incremental_state,
            'attn_state',
            buffer,)

    def _get_bidx(self, incremental_state):
        # Optional per-batch keep-mask stored by the decoding loop.
        if "bidx" in incremental_state:
            return incremental_state["bidx"]
        else:
            return None
def Embedding(num_embeddings, embedding_dim, padding_idx):
    """Embedding table with BERT-style N(0, 0.02) init and a zeroed pad row."""
    table = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(table.weight, std=0.02)
    nn.init.constant_(table.weight[padding_idx], 0)
    return table
class SelfAttentionMask(nn.Module):
    """Lazily-grown boolean causal mask for decoder self-attention.

    Entries strictly above the diagonal are True (= blocked), so each
    position may only attend to itself and earlier positions.
    """

    def __init__(self, init_size = 100, device = 0):
        super(SelfAttentionMask, self).__init__()
        self.weights = SelfAttentionMask.get_mask(init_size)
        self.device = device

    @staticmethod
    def get_mask(size):
        # True above the diagonal: position j > i is masked out.
        return torch.ones((size, size), dtype=torch.bool).triu(1)

    def forward(self, size):
        # Rebuild the cache only when a longer sequence arrives.
        if self.weights is None or size > self.weights.size(0):
            self.weights = SelfAttentionMask.get_mask(size)
        return self.weights[:size, :size].cuda(self.device).detach()
class LearnedPositionalEmbedding(nn.Module):
    """Trainable positional embeddings shared across the batch dimension."""

    def __init__(self, embedding_dim, init_size=1024, device=0):
        super(LearnedPositionalEmbedding, self).__init__()
        self.weights = nn.Embedding(init_size, embedding_dim)
        self.device = device
        self.reset_parameters()

    def reset_parameters(self):
        # BERT-style init: N(0, 0.02).
        nn.init.normal_(self.weights.weight, std=0.02)

    def forward(self, input, offset=0):
        """Input is expected to be of size [seq_len x bsz]."""
        seq_len, bsz = input.size()
        positions = (offset + torch.arange(seq_len)).cuda(self.device)
        # One embedding per position, broadcast over the batch.
        embedded = self.weights(positions)
        return embedded.unsqueeze(1).expand(-1, bsz, -1)
class SinusoidalPositionalEmbedding(nn.Module):
    """Sinusoidal positional embeddings of any length.

    The construction matches tensor2tensor and differs slightly from
    Section 3.5 of "Attention Is All You Need": all sines come first,
    then all cosines (instead of interleaving).
    """

    def __init__(self, embedding_dim, init_size=1024, device=0):
        super(SinusoidalPositionalEmbedding, self).__init__()
        self.embedding_dim = embedding_dim
        self.weights = SinusoidalPositionalEmbedding.get_embedding(init_size, embedding_dim)
        self.device = device

    @staticmethod
    def get_embedding(num_embeddings, embedding_dim):
        """Build a (num_embeddings, embedding_dim) sinusoidal table."""
        half = embedding_dim // 2
        log_step = math.log(10000) / (half - 1)
        inv_freq = torch.exp(torch.arange(half, dtype=torch.float) * -log_step)
        angles = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * inv_freq.unsqueeze(0)
        table = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # Odd dims get a zero-padded final column.
            table = torch.cat([table, torch.zeros(num_embeddings, 1)], dim=1)
        return table

    def forward(self, input, offset=0):
        """Input is expected to be of size [seq_len x bsz]."""
        seq_len, bsz = input.size()
        needed = seq_len + offset
        if self.weights is None or needed > self.weights.size(0):
            # Grow the cached table lazily for longer sequences.
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                needed,
                self.embedding_dim,
            )
        positions = offset + torch.arange(seq_len)
        rows = self.weights.index_select(0, positions)
        return rows.unsqueeze(1).expand(-1, bsz, -1).cuda(self.device).detach()
================================================
FILE: utils.py
================================================
import torch
from torch import nn
from torch.nn import Parameter
from collections import defaultdict
import math
def gelu(x):
    """Exact GELU activation: x * Phi(x), Phi being the standard normal CDF."""
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learned affine.

    Uses biased variance and eps=1e-12, matching the BERT formulation.
    """

    def __init__(self, hidden_size, eps=1e-12):
        super(LayerNorm, self).__init__()
        self.weight = nn.Parameter(torch.Tensor(hidden_size))
        self.bias = nn.Parameter(torch.Tensor(hidden_size))
        self.eps = eps
        self.reset_parameters()

    def reset_parameters(self):
        # Identity affine at init: scale 1, shift 0.
        nn.init.constant_(self.weight, 1.)
        nn.init.constant_(self.bias, 0.)

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        var = (x - mean).pow(2).mean(-1, keepdim=True)  # biased variance
        normed = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight * normed + self.bias
INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)
def _get_full_incremental_state_key(module_instance, key):
    """Return a namespaced key of the form '<ClassName>.<instance_id>.<key>'.

    Each module instance is lazily assigned a unique per-class ID the first
    time it is seen, so incremental state is never shared across instances.
    """
    cls_name = module_instance.__class__.__name__
    if not hasattr(module_instance, '_guyu_instance_id'):
        INCREMENTAL_STATE_INSTANCE_ID[cls_name] += 1
        module_instance._guyu_instance_id = INCREMENTAL_STATE_INSTANCE_ID[cls_name]
    return '{}.{}.{}'.format(cls_name, module_instance._guyu_instance_id, key)
def get_incremental_state(module, incremental_state, key):
    """Helper for getting incremental state for an nn.Module.

    Returns None when no state dict is supplied or the key is absent.
    """
    full_key = _get_full_incremental_state_key(module, key)
    if incremental_state is not None and full_key in incremental_state:
        return incremental_state[full_key]
    return None
def set_incremental_state(module, incremental_state, key, value):
    """Helper for setting incremental state for an nn.Module."""
    if incremental_state is None:
        return  # no-op when incremental decoding is disabled
    incremental_state[_get_full_incremental_state_key(module, key)] = value
gitextract_oqyy3n7z/ ├── .gitignore ├── LICENSE ├── README.md ├── adam.py ├── biglm.py ├── data.py ├── eval.py ├── eval.sh ├── label_smoothing.py ├── metrics.py ├── optim.py ├── polish.py ├── polish.sh ├── prepare_data.py ├── test.py ├── test.sh ├── train.py ├── train.sh ├── transformer.py └── utils.py
SYMBOL INDEX (103 symbols across 12 files)
FILE: adam.py
class AdamWeightDecayOptimizer (line 5) | class AdamWeightDecayOptimizer(Optimizer):
method __init__ (line 9) | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
method __setstate__ (line 23) | def __setstate__(self, state):
method step (line 28) | def step(self, closure=None):
FILE: biglm.py
class BIGLM (line 9) | class BIGLM(nn.Module):
method __init__ (line 10) | def __init__(self, local_rank, vocab, embed_dim, ff_embed_dim, num_hea...
method reset_parameters (line 35) | def reset_parameters(self):
method label_smotthing_loss (line 41) | def label_smotthing_loss(self, y_pred, y, y_mask, avg=True):
method nll_loss (line 51) | def nll_loss(self, y_pred, y, y_mask, avg=True):
method work_incremental (line 64) | def work_incremental(self, enc, src_padding_mask, ys_inp, ys_tpl, ys_s...
method work (line 92) | def work(self, enc, src_padding_mask, ys_inp, ys_tpl, ys_seg, ys_pos):
method encode (line 114) | def encode(self, xs_tpl, xs_seg, xs_pos):
method ppl (line 120) | def ppl(self, xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg...
method forward (line 141) | def forward(self, xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys...
FILE: data.py
function ListsToTensor (line 17) | def ListsToTensor(xs, vocab=None):
function _back_to_text_for_check (line 28) | def _back_to_text_for_check(x, vocab):
function batchify (line 33) | def batchify(data, vocab):
function s2t (line 61) | def s2t(strs, vocab):
function s2xy (line 71) | def s2xy(lines, vocab, max_len, min_len):
function parse_line (line 80) | def parse_line(line, max_len, min_len):
function s2xy_polish (line 140) | def s2xy_polish(lines, vocab, max_len, min_len):
function parse_line_polish (line 147) | def parse_line_polish(line, max_len, min_len):
class DataLoader (line 204) | class DataLoader(object):
method __init__ (line 205) | def __init__(self, vocab, filename, batch_size, max_len_y, min_len_y):
method __iter__ (line 214) | def __iter__(self):
class Vocab (line 238) | class Vocab(object):
method __init__ (line 239) | def __init__(self, filename, min_occur_cnt, specials = None):
method size (line 255) | def size(self):
method unk_idx (line 259) | def unk_idx(self):
method padding_idx (line 263) | def padding_idx(self):
method random_token (line 266) | def random_token(self):
method idx2token (line 269) | def idx2token(self, x):
method token2idx (line 274) | def token2idx(self, x):
FILE: eval.py
function init_model (line 14) | def init_model(m_path, device, vocab):
FILE: label_smoothing.py
class LabelSmoothing (line 5) | class LabelSmoothing(nn.Module):
method __init__ (line 7) | def __init__(self, device, size, padding_idx, label_smoothing=0.0):
method forward (line 20) | def forward(self, output, target):
FILE: metrics.py
function eval_tpl (line 29) | def eval_tpl(sents1, sents2):
function rhythm_labellig (line 53) | def rhythm_labellig(sents):
function eval_rhythm (line 86) | def eval_rhythm(sents1, sents2):
function eval_endings (line 109) | def eval_endings(sents1, sents2):
function eval (line 126) | def eval(res_file, fid):
FILE: optim.py
class Optim (line 3) | class Optim:
method __init__ (line 5) | def __init__(self, model_size, factor, warmup, optimizer):
method step (line 13) | def step(self):
method rate (line 22) | def rate(self, step = None):
method state_dict (line 28) | def state_dict(self):
method load_state_dict (line 31) | def load_state_dict(self, m):
FILE: polish.py
function init_model (line 13) | def init_model(m_path, device, vocab):
function top_k_inc (line 29) | def top_k_inc(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos,...
function top_k (line 78) | def top_k(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s):
FILE: test.py
function init_seeds (line 14) | def init_seeds():
function init_model (line 23) | def init_model(m_path, device, vocab):
function top_k_inc (line 38) | def top_k_inc(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos,...
function top_k (line 87) | def top_k(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s):
function greedy (line 129) | def greedy(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s):
function beam_decode (line 175) | def beam_decode(s, x, enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp...
function beam_search (line 293) | def beam_search(enc, src_padding_mask, ys_tpl, ys_seg, ys_pos, s):
FILE: train.py
function parse_config (line 15) | def parse_config():
function update_lr (line 48) | def update_lr(optimizer, lr):
function average_gradients (line 52) | def average_gradients(model):
function eval_epoch (line 65) | def eval_epoch(lm_args, model, lm_vocab, local_rank, label):
function run (line 104) | def run(args, local_rank):
function init_processes (line 177) | def init_processes(args, local_rank, fn, backend='nccl'):
FILE: transformer.py
class TransformerLayer (line 9) | class TransformerLayer(nn.Module):
method __init__ (line 11) | def __init__(self, embed_dim, ff_embed_dim, num_heads, dropout, with_e...
method reset_parameters (line 25) | def reset_parameters(self):
method forward (line 31) | def forward(self, x, kv = None,
method work_incremental (line 62) | def work_incremental(self, x, self_padding_mask = None, self_attn_mask...
class MultiheadAttention (line 82) | class MultiheadAttention(nn.Module):
method __init__ (line 84) | def __init__(self, embed_dim, num_heads, dropout=0., weights_dropout=T...
method reset_parameters (line 100) | def reset_parameters(self):
method forward (line 106) | def forward(self, query, key, value, key_padding_mask=None, attn_mask=...
method in_proj_qkv (line 206) | def in_proj_qkv(self, query):
method in_proj_kv (line 209) | def in_proj_kv(self, key):
method in_proj_q (line 212) | def in_proj_q(self, query):
method in_proj_k (line 215) | def in_proj_k(self, key):
method in_proj_v (line 218) | def in_proj_v(self, value):
method _in_proj (line 221) | def _in_proj(self, input, start=0, end=None):
method _get_input_buffer (line 229) | def _get_input_buffer(self, incremental_state):
method _set_input_buffer (line 236) | def _set_input_buffer(self, incremental_state, buffer):
method _get_bidx (line 243) | def _get_bidx(self, incremental_state):
function Embedding (line 249) | def Embedding(num_embeddings, embedding_dim, padding_idx):
class SelfAttentionMask (line 255) | class SelfAttentionMask(nn.Module):
method __init__ (line 256) | def __init__(self, init_size = 100, device = 0):
method get_mask (line 262) | def get_mask(size):
method forward (line 266) | def forward(self, size):
class LearnedPositionalEmbedding (line 272) | class LearnedPositionalEmbedding(nn.Module):
method __init__ (line 275) | def __init__(self, embedding_dim, init_size=1024, device=0):
method reset_parameters (line 281) | def reset_parameters(self):
method forward (line 284) | def forward(self, input, offset=0):
class SinusoidalPositionalEmbedding (line 291) | class SinusoidalPositionalEmbedding(nn.Module):
method __init__ (line 294) | def __init__(self, embedding_dim, init_size=1024, device=0):
method get_embedding (line 304) | def get_embedding(num_embeddings, embedding_dim):
method forward (line 318) | def forward(self, input, offset=0):
FILE: utils.py
function gelu (line 8) | def gelu(x):
class LayerNorm (line 12) | class LayerNorm(nn.Module):
method __init__ (line 13) | def __init__(self, hidden_size, eps=1e-12):
method reset_parameters (line 19) | def reset_parameters(self):
method forward (line 23) | def forward(self, x):
function _get_full_incremental_state_key (line 32) | def _get_full_incremental_state_key(module_instance, key):
function get_incremental_state (line 43) | def get_incremental_state(module, incremental_state, key):
function set_incremental_state (line 50) | def set_incremental_state(module, incremental_state, key, value):
Condensed preview — 20 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (80K chars).
[
{
"path": ".gitignore",
"chars": 56,
"preview": "*.pyc\n*.log\nckpt\n/data*/*\n/model*/*\n/ckpt*/*\n/result*/*\n"
},
{
"path": "LICENSE",
"chars": 1064,
"preview": "MIT License\n\nCopyright (c) 2021 Piji Li\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof"
},
{
"path": "README.md",
"chars": 1107,
"preview": "# SongNet\nSongNet: SongCi + Song (Lyrics) + Sonnet + etc.\n\n```\n@inproceedings{li-etal-2020-rigid,\n title = \"Rigid For"
},
{
"path": "adam.py",
"chars": 4134,
"preview": "# coding=utf-8\nimport torch\nfrom torch.optim import Optimizer\n\nclass AdamWeightDecayOptimizer(Optimizer):\n \"\"\"A basic"
},
{
"path": "biglm.py",
"chars": 7558,
"preview": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\n\nfrom utils import gelu, LayerNorm\nfrom transformer im"
},
{
"path": "data.py",
"chars": 8186,
"preview": "import random\nimport torch\nimport numpy as np\n\nPAD, UNK, BOS, EOS = '<pad>', '<unk>', '<bos>', '<eos>'\nBOC, EOC = '<boc>"
},
{
"path": "eval.py",
"chars": 2078,
"preview": "import sys\nimport torch\nfrom torch import nn\nimport torch.nn.functional as F\nimport random\nimport numpy as np\nimport cop"
},
{
"path": "eval.sh",
"chars": 149,
"preview": "#!/bin/bash\npath=./ckpt/\nFILES=$path/*\nfor f in $FILES; do\n echo \"==========================\" ${f##*/}\n python -u "
},
{
"path": "label_smoothing.py",
"chars": 1291,
"preview": "import torch\r\nimport torch.nn as nn\r\nimport torch.nn.functional as F\r\n\r\nclass LabelSmoothing(nn.Module):\r\n \"Implement"
},
{
"path": "metrics.py",
"chars": 7545,
"preview": "import os\nimport sys\nimport numpy as np\nfrom pypinyin import Style, lazy_pinyin\n\nfrom data import PUNCS\n\nyunjiaos = {\n "
},
{
"path": "optim.py",
"chars": 936,
"preview": "# -*- coding: utf-8 -*-\n\nclass Optim:\n \"Optim wrapper that implements rate.\"\n def __init__(self, model_size, facto"
},
{
"path": "polish.py",
"chars": 5139,
"preview": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\nimport random\nimport numpy as np\nimport copy \nimport t"
},
{
"path": "polish.sh",
"chars": 21,
"preview": "python3 -u polish.py\n"
},
{
"path": "prepare_data.py",
"chars": 2492,
"preview": "import sys, re\nfrom collections import Counter\nimport random\ncnt = Counter()\n\nf_ci = \"./data/ci.txt\"\nf_cipai = \"./data/c"
},
{
"path": "test.py",
"chars": 10828,
"preview": "import torch\nfrom torch import nn\nimport torch.nn.functional as F\nimport random\nimport numpy as np\nimport copy \nimport t"
},
{
"path": "test.sh",
"chars": 19,
"preview": "python3 -u test.py\n"
},
{
"path": "train.py",
"chars": 7681,
"preview": "# coding=utf-8\nimport torch\nimport torch.distributed as dist\nimport torch.nn as nn\nimport torch.nn.functional as F\nimpor"
},
{
"path": "train.sh",
"chars": 1052,
"preview": "CUDA_VISIBLE_DEVICES=1 \\\npython3 -u train.py --embed_dim 768 \\\n --ff_embed_dim 3072 \\\n "
},
{
"path": "transformer.py",
"chars": 13554,
"preview": "import torch\nfrom torch import nn\nfrom torch.nn import Parameter\nimport torch.nn.functional as F\n\nfrom utils import gelu"
},
{
"path": "utils.py",
"chars": 1985,
"preview": "import torch\nfrom torch import nn\nfrom torch.nn import Parameter\nfrom collections import defaultdict\n\nimport math\n\ndef g"
}
]
About this extraction
This page contains the full source code of the lipiji/SongNet GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 20 files (75.1 KB), approximately 21.1k tokens, and a symbol index with 103 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.