Repository: lipiji/SongNet
Branch: master
Commit: f8c7064c21c7
Files: 20
Total size: 75.1 KB

Directory structure:
gitextract_oqyy3n7z/
├── .gitignore
├── LICENSE
├── README.md
├── adam.py
├── biglm.py
├── data.py
├── eval.py
├── eval.sh
├── label_smoothing.py
├── metrics.py
├── optim.py
├── polish.py
├── polish.sh
├── prepare_data.py
├── test.py
├── test.sh
├── train.py
├── train.sh
├── transformer.py
└── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
*.pyc
*.log
ckpt
/data*/*
/model*/*
/ckpt*/*
/result*/*

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2021 Piji Li

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

================================================
FILE: README.md
================================================
# SongNet
SongNet: SongCi + Song (Lyrics) + Sonnet + etc.

```
@inproceedings{li-etal-2020-rigid,
    title = "Rigid Formats Controlled Text Generation",
    author = "Li, Piji and Zhang, Haisong and Liu, Xiaojiang and Shi, Shuming",
    booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics",
    month = jul,
    year = "2020",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/2020.acl-main.68",
    doi = "10.18653/v1/2020.acl-main.68",
    pages = "742--751"
}
```

### Run
- python prepare_data.py
- ./train.sh

### Evaluation
- Modify test.py: set `m_path` to the checkpoint with the best dev score
- ./test.sh
- python metrics.py

### Polish
- ./polish.sh

### Download
- The pretrained Chinese language model: https://drive.google.com/file/d/1g2tGyUwPe86vPn2nub1vkQva5lwtZ6Rd/view
- The finetuned SongCi model: https://drive.google.com/file/d/16A2AzuU7slf7xj2QdLcBAorUCCaCk650/view

#### Reference
- Guyu: https://github.com/lipiji/Guyu
- Pretraining: https://github.com/lipiji/big_tpl_zh_10_base

================================================
FILE: adam.py
================================================
# coding=utf-8
import torch
from torch.optim import Optimizer


class AdamWeightDecayOptimizer(Optimizer):
    """A basic Adam optimizer that includes "correct" L2 weight decay.
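    Weight decay is applied directly to the parameters here (decoupled from the
    Adam moment estimates, as in AdamW) rather than added to the loss as an L2
    term. References: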
    https://github.com/google-research/bert/blob/master/optimization.py
    https://raw.githubusercontent.com/pytorch/pytorch/v1.0.0/torch/optim/adam.py"""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad)
        super(AdamWeightDecayOptimizer, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(AdamWeightDecayOptimizer, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                amsgrad = group['amsgrad']

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)
                    if amsgrad:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                if amsgrad:
                    max_exp_avg_sq = state['max_exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                if amsgrad:
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
                else:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                update = (exp_avg/denom).add_(group['weight_decay'], p.data)
                p.data.add_(-group['lr'], update)
        return loss

================================================
FILE: biglm.py
================================================
import torch
from torch import nn
import torch.nn.functional as F

from utils import gelu, LayerNorm
from transformer import TransformerLayer, Embedding, LearnedPositionalEmbedding, SelfAttentionMask
from label_smoothing import LabelSmoothing

class BIGLM(nn.Module):
    def __init__(self, local_rank, vocab, embed_dim, ff_embed_dim, num_heads, dropout, layers, smoothing_factor, approx=None):
        super(BIGLM, self).__init__()
        self.vocab = vocab
        self.embed_dim = embed_dim
        self.tok_embed = Embedding(self.vocab.size, embed_dim, self.vocab.padding_idx)
        self.pos_embed = LearnedPositionalEmbedding(embed_dim, device=local_rank)

        self.layers = nn.ModuleList()
        for i in range(layers):
            self.layers.append(TransformerLayer(embed_dim, ff_embed_dim, num_heads, dropout, with_external=True))
        self.emb_layer_norm = LayerNorm(embed_dim)
        self.one_more = nn.Linear(embed_dim, embed_dim)
        self.one_more_layer_norm = LayerNorm(embed_dim)
        self.out_proj = nn.Linear(embed_dim, self.vocab.size)

        self.attn_mask = SelfAttentionMask(device=local_rank)
        self.smoothing = LabelSmoothing(local_rank, self.vocab.size, self.vocab.padding_idx, smoothing_factor)

        self.dropout = dropout
        self.device = local_rank

        self.approx = approx
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.constant_(self.one_more.bias, 0.)
        nn.init.normal_(self.one_more.weight, std=0.02)
        nn.init.constant_(self.out_proj.bias, 0.)
        nn.init.normal_(self.out_proj.weight, std=0.02)

    def label_smoothing_loss(self, y_pred, y, y_mask, avg=True):
        seq_len, bsz = y.size()
        y_pred = torch.log(y_pred.clamp(min=1e-8))
        loss = self.smoothing(y_pred.view(seq_len * bsz, -1), y.view(seq_len * bsz, -1))
        if avg:
            return loss / torch.sum(y_mask)
        else:
            return loss / bsz

    # sentence-averaged token NLL (natural log); note the per-sentence
    # "perplexity" is then computed as 2 ** cost
    def nll_loss(self, y_pred, y, y_mask, avg=True):
        cost = -torch.log(torch.gather(y_pred, 2, y.view(y.size(0), y.size(1), 1)))
        cost = cost.view(y.shape)
        y_mask = y_mask.view(y.shape)
        if avg:
            cost = torch.sum(cost * y_mask, 0) / torch.sum(y_mask, 0)
        else:
            cost = torch.sum(cost * y_mask, 0)
        cost = cost.view((y.size(1), -1))
        ppl = 2 ** cost
        return cost.sum().item(), ppl.sum().item()

    def work_incremental(self, enc, src_padding_mask, ys_inp, ys_tpl, ys_seg, ys_pos, incremental_state=None):
        seq_len, bsz = ys_inp.size()
        x = self.tok_embed(ys_inp) + self.pos_embed(ys_inp) \
            + self.tok_embed(ys_tpl) + self.tok_embed(ys_seg) + self.tok_embed(ys_pos)
        x = self.emb_layer_norm(x)
        padding_mask = torch.eq(ys_inp, self.vocab.padding_idx)
        if not padding_mask.any():
            padding_mask = None
        if incremental_state is None:
            self_attn_mask = self.attn_mask(seq_len)
            incremental_state = {}
        else:
            x = x[-1, :, :].unsqueeze(0)
            self_attn_mask = None
        for layer in self.layers:
            x, _, _ = layer.work_incremental(x, self_padding_mask=padding_mask,
                                             self_attn_mask=self_attn_mask,
                                             external_memories=enc,
                                             external_padding_mask=src_padding_mask,
                                             incremental_state=incremental_state)
        x = self.one_more_layer_norm(gelu(self.one_more(x)))
        probs = torch.softmax(self.out_proj(x), -1)
        _, pred_y = probs.max(-1)
        return probs, pred_y, incremental_state

    def work(self, enc, src_padding_mask, ys_inp, ys_tpl, ys_seg, ys_pos):
        seq_len, bsz = ys_inp.size()
        self_attn_mask = self.attn_mask(seq_len)
        x = self.tok_embed(ys_inp) + self.pos_embed(ys_inp) \
            + self.tok_embed(ys_tpl) + self.tok_embed(ys_seg) + self.tok_embed(ys_pos)
        x = self.emb_layer_norm(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        padding_mask = torch.eq(ys_inp, self.vocab.padding_idx)
        if not padding_mask.any():
            padding_mask = None
        for layer in self.layers:
            x, _, _ = layer(x, self_padding_mask=padding_mask,
                            self_attn_mask=self_attn_mask,
                            external_memories=enc,
                            external_padding_mask=src_padding_mask)
        x = self.one_more_layer_norm(gelu(self.one_more(x)))
        probs = torch.softmax(self.out_proj(x), -1)
        _, pred_y = probs.max(-1)
        return probs, pred_y

    def encode(self, xs_tpl, xs_seg, xs_pos):
        padding_mask = torch.eq(xs_tpl, self.vocab.padding_idx)
        x = self.tok_embed(xs_tpl) + self.tok_embed(xs_seg) + self.tok_embed(xs_pos)
        x = self.emb_layer_norm(x)
        return x, padding_mask

    def ppl(self, xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk):
        enc, src_padding_mask = self.encode(xs_tpl, xs_seg, xs_pos)
        seq_len, bsz = ys_inp.size()
        self_attn_mask = self.attn_mask(seq_len)
        x = self.tok_embed(ys_inp) + self.pos_embed(ys_inp) \
            + self.tok_embed(ys_tpl) + self.tok_embed(ys_seg) + self.tok_embed(ys_pos)
        x = self.emb_layer_norm(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        padding_mask = torch.eq(ys_truth, self.vocab.padding_idx)
        if not padding_mask.any():
            padding_mask = None
        for layer in self.layers:
            x, _, _ = layer(x, self_padding_mask=padding_mask,
                            self_attn_mask=self_attn_mask,
                            external_memories=enc,
                            external_padding_mask=src_padding_mask)
        x = self.one_more_layer_norm(gelu(self.one_more(x)))
        pred = torch.softmax(self.out_proj(x), -1)
        nll, ppl = self.nll_loss(pred, ys_truth, msk)
        return nll, ppl, bsz

    def forward(self, xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk):
        enc, src_padding_mask = self.encode(xs_tpl, xs_seg, xs_pos)
        seq_len, bsz = ys_inp.size()
        self_attn_mask = self.attn_mask(seq_len)
        x = self.tok_embed(ys_inp) + self.pos_embed(ys_inp) \
            + self.tok_embed(ys_tpl) + self.tok_embed(ys_seg) + self.tok_embed(ys_pos)
        x = self.emb_layer_norm(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        padding_mask = torch.eq(ys_truth, self.vocab.padding_idx)
        if not padding_mask.any():
            padding_mask = None
        for layer in self.layers:
            x, _, _ = layer(x, self_padding_mask=padding_mask,
                            self_attn_mask=self_attn_mask,
                            external_memories=enc,
                            external_padding_mask=src_padding_mask)
        x = self.one_more_layer_norm(gelu(self.one_more(x)))
        pred = torch.softmax(self.out_proj(x), -1)

        loss = self.label_smoothing_loss(pred, ys_truth, msk)
        _, pred_y = pred.max(-1)
        tot_tokens = msk.float().sum().item()
        acc = (torch.eq(pred_y, ys_truth).float() * msk).sum().item()
        nll, ppl = self.nll_loss(pred, ys_truth, msk)
        return (pred_y, ys_truth), loss, acc, nll, ppl, tot_tokens, bsz

================================================
FILE: data.py
================================================
import random
import torch
import numpy as np

PAD, UNK, BOS, EOS = '<pad>', '<unk>', '<bos>', '<eos>'
BOC, EOC = '<boc>', '<eoc>'
LS, RS, SP = '<s>', '</s>', ' '

CS = ['<c-1>'] + ['<c' + str(i) + '>' for i in range(32)]   # content
SS = ['<s-1>'] + ['<s' + str(i) + '>' for i in range(512)]  # segment
PS = ['<p-1>'] + ['<p' + str(i) + '>' for i in range(512)]  # position
TS = ['<t-1>'] + ['<t' + str(i) + '>' for i in range(32)]   # other types

PUNCS = set([",", ".", "?", "!", ":", ",", "。", "?", "!", ":"])

BUFSIZE = 4096000

def ListsToTensor(xs, vocab=None):
    max_len = max(len(x) for x in xs)
    ys = []
    for x in xs:
        if vocab is not None:
            y = vocab.token2idx(x) + [vocab.padding_idx]*(max_len - len(x))
        else:
            y = x + [0]*(max_len - len(x))
        ys.append(y)
    return ys

def _back_to_text_for_check(x, vocab):
    w = x.t().tolist()
    for sent in vocab.idx2token(w):
        print(' '.join(sent))
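# Illustrative sketch, not part of the original file: how ListsToTensor pads a
# ragged batch when no vocab is given (with a vocab, tokens are first mapped to
# ids and padded with vocab.padding_idx instead):
#
#   ListsToTensor([[3, 4, 5], [6]])  ->  [[3, 4, 5], [6, 0, 0]]
#
# batchify() below builds nine such matrices and transposes them with .t_()
# because BIGLM consumes time-major (seq_len x batch_size) tensors.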
def batchify(data, vocab):
    xs_tpl, xs_seg, xs_pos, \
    ys_truth, ys_inp, \
    ys_tpl, ys_seg, ys_pos, msk = [], [], [], [], [], [], [], [], []
    for xs_tpl_i, xs_seg_i, xs_pos_i, ys_i, ys_tpl_i, ys_seg_i, ys_pos_i in data:
        xs_tpl.append(xs_tpl_i)
        xs_seg.append(xs_seg_i)
        xs_pos.append(xs_pos_i)
        ys_truth.append(ys_i)
        ys_inp.append([BOS] + ys_i[:-1])
        ys_tpl.append(ys_tpl_i)
        ys_seg.append(ys_seg_i)
        ys_pos.append(ys_pos_i)
        msk.append([1 for i in range(len(ys_i))])

    xs_tpl = torch.LongTensor(ListsToTensor(xs_tpl, vocab)).t_().contiguous()
    xs_seg = torch.LongTensor(ListsToTensor(xs_seg, vocab)).t_().contiguous()
    xs_pos = torch.LongTensor(ListsToTensor(xs_pos, vocab)).t_().contiguous()
    ys_truth = torch.LongTensor(ListsToTensor(ys_truth, vocab)).t_().contiguous()
    ys_inp = torch.LongTensor(ListsToTensor(ys_inp, vocab)).t_().contiguous()
    ys_tpl = torch.LongTensor(ListsToTensor(ys_tpl, vocab)).t_().contiguous()
    ys_seg = torch.LongTensor(ListsToTensor(ys_seg, vocab)).t_().contiguous()
    ys_pos = torch.LongTensor(ListsToTensor(ys_pos, vocab)).t_().contiguous()
    msk = torch.FloatTensor(ListsToTensor(msk)).t_().contiguous()
    return xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk

def s2t(strs, vocab):
    inp, msk = [], []
    for x in strs:
        inp.append(x)
        msk.append([1 for i in range(len(x))])
    inp = torch.LongTensor(ListsToTensor(inp, vocab)).t_().contiguous()
    msk = torch.FloatTensor(ListsToTensor(msk)).t_().contiguous()
    return inp, msk

def s2xy(lines, vocab, max_len, min_len):
    data = []
    for line in lines:
        res = parse_line(line, max_len, min_len)
        if not res:
            continue
        data.append(res)
    return batchify(data, vocab)

def parse_line(line, max_len, min_len):
    line = line.strip()
    if not line:
        return None
    fs = line.split("")
    author, cipai = fs[0].split("")
    sents = fs[1].strip()
    if len(sents) > max_len:
        sents = sents[:max_len]
    if len(sents) < min_len:
        return None
    sents = sents.split("")

    ys = []
    xs_tpl = []
    xs_seg = []
    xs_pos = []

    ctx = cipai
    ws = [w for w in ctx]
    xs_tpl = ws + [EOC]
    xs_seg = [SS[0] for w in ws] + [EOC]
    xs_pos = [SS[i+300] for i in range(len(ws))] + [EOC]

    ys_tpl = []
    ys_seg = []
    ys_pos = []
    for si, sent in enumerate(sents):
        ws = []
        sent = sent.strip()
        if not sent:
            continue
        for w in sent:
            ws.append(w)
            if w.strip() and w not in PUNCS:
                ys_tpl.append(CS[2])
            else:
                ys_tpl.append(CS[1])
        ys += ws + [RS]
        if ws[-1] in PUNCS:
            ys_tpl[-2] = CS[3]
        else:
            ys_tpl[-1] = CS[3]
        ys_tpl += [RS]
        ys_seg += [SS[si + 1] for w in ws] + [RS]
        ys_pos += [PS[len(ws) - i] for i in range(len(ws))] + [RS]

    ys += [EOS]
    ys_tpl += [EOS]
    ys_seg += [EOS]
    ys_pos += [EOS]

    xs_tpl += ys_tpl
    xs_seg += ys_seg
    xs_pos += ys_pos

    if len(ys) < min_len:
        return None
    return xs_tpl, xs_seg, xs_pos, ys, ys_tpl, ys_seg, ys_pos

def s2xy_polish(lines, vocab, max_len, min_len):
    data = []
    for line in lines:
        res = parse_line_polish(line, max_len, min_len)
        if not res:
            continue
        data.append(res)
    return batchify(data, vocab)

def parse_line_polish(line, max_len, min_len):
    line = line.strip()
    if not line:
        return None
    fs = line.split("")
    author, cipai = fs[0].split("")
    sents = fs[1].strip()
    if len(sents) > max_len:
        sents = sents[:max_len]
    if len(sents) < min_len:
        return None
    sents = sents.split("")

    ys = []
    xs_tpl = []
    xs_seg = []
    xs_pos = []

    ctx = cipai
    ws = [w for w in ctx]
    xs_tpl = ws + [EOC]
    xs_seg = [SS[0] for w in ws] + [EOC]
    xs_pos = [SS[i+300] for i in range(len(ws))] + [EOC]

    ys_tpl = []
    ys_seg = []
    ys_pos = []
    for si, sent in enumerate(sents):
        ws = []
        sent = sent.strip()
        if not sent:
            continue
        for w in sent:
            ws.append(w)
            if w == "_":
                ys_tpl.append(CS[2])
            else:
                ys_tpl.append(w)
        ys += ws + [RS]
        ys_tpl += [RS]
        ys_seg += [SS[si + 1] for w in ws] + [RS]
        ys_pos += [PS[len(ws) - i] for i in range(len(ws))] + [RS]

    ys += [EOS]
    ys_tpl += [EOS]
    ys_seg += [EOS]
    ys_pos += [EOS]

    xs_tpl += ys_tpl
    xs_seg += ys_seg
    xs_pos += ys_pos

    if len(ys) < min_len:
        return None
    return xs_tpl, xs_seg, xs_pos, ys, ys_tpl, ys_seg, ys_pos

class DataLoader(object):
    def __init__(self, vocab, filename, batch_size, max_len_y, min_len_y):
        self.batch_size = batch_size
        self.vocab = vocab
        self.max_len_y = max_len_y
        self.min_len_y = min_len_y
        self.filename = filename
        self.stream = open(self.filename, encoding='utf8')
        self.epoch_id = 0

    def __iter__(self):
        lines = self.stream.readlines(BUFSIZE)
        if not lines:
            self.epoch_id += 1
            self.stream.close()
            self.stream = open(self.filename, encoding='utf8')
            lines = self.stream.readlines(BUFSIZE)

        data = []
        for line in lines[:-1]:  # the last line in the buffer may be incomplete
            res = parse_line(line, self.max_len_y, self.min_len_y)
            if not res:
                continue
            data.append(res)

        random.shuffle(data)
        idx = 0
        while idx < len(data):
            yield batchify(data[idx:idx+self.batch_size], self.vocab)
            idx += self.batch_size

class Vocab(object):
    def __init__(self, filename, min_occur_cnt, specials = None):
        idx2token = [PAD, UNK, BOS, EOS] + [BOC, EOC, LS, RS, SP] + CS + SS + PS + TS \
                    + (specials if specials is not None else [])
        for line in open(filename, encoding='utf8').readlines():
            try:
                token, cnt = line.strip().split()
            except:
                continue
            if int(cnt) >= min_occur_cnt:
                idx2token.append(token)
        self._token2idx = dict(zip(idx2token, range(len(idx2token))))
        self._idx2token = idx2token
        self._padding_idx = self._token2idx[PAD]
        self._unk_idx = self._token2idx[UNK]

    @property
    def size(self):
        return len(self._idx2token)

    @property
    def unk_idx(self):
        return self._unk_idx

    @property
    def padding_idx(self):
        return self._padding_idx

    def random_token(self):
        return self.idx2token(1 + np.random.randint(self.size-1))

    def idx2token(self, x):
        if isinstance(x, list):
            return [self.idx2token(i) for i in x]
        return self._idx2token[x]

    def token2idx(self, x):
        if isinstance(x, list):
            return [self.token2idx(i) for i in x]
        return self._token2idx.get(x, self.unk_idx)

================================================
FILE: eval.py
================================================
import sys
import torch
from torch import nn
import torch.nn.functional as F
import random
import numpy as np
import copy
import time

from biglm import BIGLM
from data import Vocab, DataLoader, s2t, s2xy

gpu = int(sys.argv[2]) if len(sys.argv) > 2 else 0

def init_model(m_path, device, vocab):
    ckpt = torch.load(m_path, map_location='cpu')
    lm_args = ckpt['args']
    lm_vocab = Vocab(vocab, min_occur_cnt=lm_args.min_occur_cnt, specials=[])
    lm_model = BIGLM(device, lm_vocab, lm_args.embed_dim, lm_args.ff_embed_dim,
                     lm_args.num_heads, lm_args.dropout, lm_args.layers, 0.1, lm_args.approx)
    lm_model.load_state_dict(ckpt['model'])
    lm_model = lm_model.cuda(device)
    lm_model.eval()
    return lm_model, lm_vocab, lm_args

#m_path = "./ckpt_d101_6/epoch5_batch_139999"
m_path = sys.argv[1] if len(sys.argv) > 1 else None
lm_model, lm_vocab, lm_args = init_model(m_path, gpu, "./data/vocab.txt")

ds = []
with open("./data/dev.txt", "r") as f:
    for line in f:
        line = line.strip()
        if line:
            ds.append(line)
print(len(ds))

local_rank = gpu
batch_size = 10
batches = round(len(ds) / batch_size)
idx = 0
avg_nll = 0.
avg_ppl = 0.
count = 0.
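# Scoring loop: model.ppl() returns the batch's summed sentence-level NLL and
# PPL, which are accumulated here and divided by the total sentence count.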
while idx < len(ds): cplb = ds[idx:idx + batch_size] xs_tpl, xs_seg, xs_pos, \ ys_truth, ys_inp, \ ys_tpl, ys_seg, ys_pos, msk = s2xy(cplb, lm_vocab, lm_args.max_len, 2) xs_tpl = xs_tpl.cuda(local_rank) xs_seg = xs_seg.cuda(local_rank) xs_pos = xs_pos.cuda(local_rank) ys_truth = ys_truth.cuda(local_rank) ys_inp = ys_inp.cuda(local_rank) ys_tpl = ys_tpl.cuda(local_rank) ys_seg = ys_seg.cuda(local_rank) ys_pos = ys_pos.cuda(local_rank) msk = msk.cuda(local_rank) nll, ppl, bsz = lm_model.ppl(xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk) avg_nll += nll avg_ppl += ppl count += bsz idx += batch_size if count % 200 == 0: print("nll=", avg_nll/count, "ppl=", avg_ppl/count, "count=", count) print("nll=", avg_nll/count, "ppl=", avg_ppl/count, "count=", count) ================================================ FILE: eval.sh ================================================ #!/bin/bash path=./ckpt/ FILES=$path/* for f in $FILES; do echo "==========================" ${f##*/} python -u eval.py $path${f##*/} 1 done ================================================ FILE: label_smoothing.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F class LabelSmoothing(nn.Module): "Implement label smoothing." def __init__(self, device, size, padding_idx, label_smoothing=0.0): super(LabelSmoothing, self).__init__() assert 0.0 < label_smoothing <= 1.0 self.padding_idx = padding_idx self.size = size self.device = device self.smoothing_value = label_smoothing / (size - 2) self.one_hot = torch.full((1, size), self.smoothing_value).to(device) self.one_hot[0, self.padding_idx] = 0 self.confidence = 1.0 - label_smoothing def forward(self, output, target): real_size = output.size(1) if real_size > self.size: real_size -= self.size else: real_size = 0 model_prob = self.one_hot.repeat(target.size(0), 1) if real_size > 0: ext_zeros = torch.full((model_prob.size(0), real_size), self.smoothing_value).to(self.device) model_prob = torch.cat((model_prob, ext_zeros), -1) model_prob.scatter_(1, target, self.confidence) model_prob.masked_fill_((target == self.padding_idx), 0.) return F.kl_div(output, model_prob, reduction='sum') ================================================ FILE: metrics.py ================================================ import os import sys import numpy as np from pypinyin import Style, lazy_pinyin from data import PUNCS yunjiaos = { "0":["a", "ia", "ua", "va", "üa"], "1":["e", "o", "uo", "ie", "ue", "üe", "ve"], "2":["u"], "3":["i", "ü", "v"], "4":["ai", "uai"], "5":["ao", "iao"], "6":["ou", "iu", "iou"], "7":["an", "ian", "uan", "üan", "van"], "8":["en", "in", "un", "ün", "vn"], "9":["ang", "iang", "uang"], "10":["eng", "ing", "ueng", "ong", "iong"], "11":["er"], "12":["ei", "ui", "uei", "vei"], } yun2id = {} for yid, yws in yunjiaos.items(): for w in yws: yun2id[w] = yid def eval_tpl(sents1, sents2): n = 0. 
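    # A generated sentence counts as a format match when it has the same length
    # and the same punctuation pattern as its reference; n is then converted to
    # precision/recall/F1 over the two sentence lists.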
    if len(sents1) > len(sents2):
        sents1 = sents1[:len(sents2)]
    for i, x in enumerate(sents1):
        y = sents2[i]
        if len(x) != len(y):
            continue
        px, py = [], []
        for w in x:
            if w in PUNCS:
                px.append(w)
        for w in y:
            if w in PUNCS:
                py.append(w)
        if px == py:
            n += 1
    p = n / len(sents2)
    r = n / len(sents1)
    f = 2 * p * r / (p + r + 1e-16)
    return p, r, f, n, len(sents1), len(sents2)

def rhythm_labelling(sents):
    rhys = []
    for sent in sents:
        w = sent[-1]
        if w in PUNCS and len(sent) > 1:
            w = sent[-2]
        yunmu = lazy_pinyin(w, style=Style.FINALS)
        rhys.append(yunmu[0])
    assert len(rhys) == len(sents)

    rhy_map = {}
    for i, r in enumerate(rhys):
        if r in yun2id:
            rid = yun2id[r]
            if rid in rhy_map:
                rhy_map[rid] += [i]
            else:
                rhy_map[rid] = [i]
        else:
            pass

    max_len_yuns = -1
    max_rid = ""
    for rid, yuns in rhy_map.items():
        if len(yuns) > max_len_yuns:
            max_len_yuns = len(yuns)
            max_rid = rid

    res = []
    for i in range(len(sents)):
        if max_rid in rhy_map and i in rhy_map[max_rid]:
            res.append(1)
        else:
            res.append(-1)
    return res

def eval_rhythm(sents1, sents2):
    n = 0.
    if len(sents1) > len(sents2):
        sents1 = sents1[:len(sents2)]
    rhys1 = rhythm_labelling(sents1)
    rhys2 = rhythm_labelling(sents2)
    n1, n2 = 0., 0.
    for v in rhys1:
        if v == 1:
            n1 += 1
    for v in rhys2:
        if v == 1:
            n2 += 1
    for i, v1 in enumerate(rhys1):
        v2 = rhys2[i]
        if v1 == 1 and v1 == v2:
            n += 1
    p = n / (n2 + 1e-16)
    r = n / (n1 + 1e-16)
    f1 = 2 * p * r / (p + r + 1e-16)
    return p, r, f1, n, n1, n2

def eval_endings(sents1, sents2):
    n = 0.
    if len(sents1) > len(sents2):
        sents1 = sents1[:len(sents2)]
    sents0 = []
    for si, sent1 in enumerate(sents1):
        sent2 = sents2[si]
        if len(sent2) <= len(sent1):
            sents0.append(sent2)
        else:
            sents0.append(sent2[:len(sent1) - 1] + sent1[-1])
    sent = "".join(sents0)
    return sent

def eval(res_file, fid):
    docs = []
    with open(res_file) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            fs = line.split("\t")
            if len(fs) != 2:
                print("error", line)
                continue
            x, y = fs
            docs.append((x, y))
    print(len(docs))

    ugrams_ = []
    bigrams_ = []
    p_, r_, f1_ = 0., 0., 0.
    n0_, n1_, n2_ = 0., 0., 0.
    p__, r__, f1__ = 0., 0., 0.
    n0__, n1__, n2__ = 0., 0., 0.
    d1_, d2_ = 0., 0.
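    # Accumulators: p_/r_/f1_ (and n0_/n1_/n2_) are the macro (and micro)
    # counts for the format match, the double-underscored variants are the same
    # for rhyme, and d1_/d2_ collect per-document distinct-1/distinct-2 ratios.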
d4ends = [] for x, y in docs: topic, content = x.split("") author, topic = topic.split("") sents1 = content.split("") y = y.replace("", "") sents2 = y.split("") sents1_ = [] for sent in sents1: sent = sent.strip() if sent: sents1_.append(sent) sents1 = sents1_ sents2_ = [] for sent in sents2: sent = sent.strip() if sent: sents2_.append(sent) sents2 = sents2_ p, r, f1, n0, n1, n2 = eval_tpl(sents1, sents2) p_ += p r_ += r f1_ += f1 n0_ += n0 n1_ += n1 n2_ += n2 ugrams = [w for w in ''.join(sents2)] bigrams = [] for bi in range(len(ugrams) - 1): bigrams.append(ugrams[bi] + ugrams[bi+1]) d1_ += len(set(ugrams)) / float(len(ugrams)) d2_ += len(set(bigrams)) / float(len(bigrams)) ugrams_ += ugrams bigrams_ += bigrams p, r, f1, n0, n1, n2 = eval_rhythm(sents1, sents2) p__ += p r__ += r f1__ += f1 n0__ += n0 n1__ += n1 n2__ += n2 d4end = eval_endings(sents1, sents2) d4ends.append(author + "" + topic + "" + d4end) tpl_macro_p = p_ / len(docs) tpl_macro_r = r_ / len(docs) tpl_macro_f1 = 2 * tpl_macro_p * tpl_macro_r / (tpl_macro_p + tpl_macro_r) tpl_micro_p = n0_ / n2_ tpl_micro_r = n0_ / n1_ tpl_micro_f1 = 2 * tpl_micro_p * tpl_micro_r / (tpl_micro_p + tpl_micro_r) rhy_macro_p = p__ / len(docs) rhy_macro_r = r__ / len(docs) rhy_macro_f1 = 2 * rhy_macro_p * rhy_macro_r / (rhy_macro_p + rhy_macro_r) rhy_micro_p = n0__ / n2__ rhy_micro_r = n0__ / n1__ rhy_micro_f1 = 2 * rhy_micro_p * rhy_micro_r / (rhy_micro_p + rhy_micro_r) macro_dist1 = d1_ / len(docs) macro_dist2 = d2_ / len(docs) micro_dist1 = len(set(ugrams_)) / float(len(ugrams_)) micro_dist2 = len(set(bigrams_)) / float(len(bigrams_)) with open("./results_4ending/res4end" + str(fid) + ".txt", "w") as fo: for line in d4ends: fo.write(line + "\n") return tpl_macro_f1, tpl_micro_f1, rhy_macro_f1, rhy_micro_f1, macro_dist1, micro_dist1, macro_dist2, micro_dist2 tpl_macro_f1_, tpl_micro_f1_, rhy_macro_f1_, rhy_micro_f1_, \ macro_dist1_, micro_dist1_, macro_dist2_, micro_dist2_ = [], [], [], [], [], [], [], [] abalation = "top-32" for i in range(5): f_name = "./results/"+abalation+"/out" +str(i+1)+".txt" if not os.path.exists(f_name): continue tpl_macro_f1, tpl_micro_f1, rhy_macro_f1, rhy_micro_f1, macro_dist1, micro_dist1, macro_dist2, micro_dist2 = eval(f_name, i + 1) print(tpl_macro_f1, tpl_micro_f1, rhy_macro_f1, rhy_micro_f1, macro_dist1, micro_dist1, macro_dist2, micro_dist2) tpl_macro_f1_.append(tpl_macro_f1) tpl_micro_f1_.append(tpl_micro_f1) rhy_macro_f1_.append(rhy_macro_f1) rhy_micro_f1_.append(rhy_micro_f1) macro_dist1_.append(macro_dist1) micro_dist1_.append(micro_dist1) macro_dist2_.append(macro_dist2) micro_dist2_.append(micro_dist2) print() print("tpl_macro_f1", np.mean(tpl_macro_f1_), np.std(tpl_macro_f1_, ddof=1)) print("tpl_micro_f1", np.mean(tpl_micro_f1_), np.std(tpl_micro_f1_, ddof=1)) print("rhy_macro_f1", np.mean(rhy_macro_f1_), np.std(rhy_macro_f1_, ddof=1)) print("rhy_micro_f1", np.mean(rhy_micro_f1_), np.std(rhy_micro_f1_, ddof=1)) print("macro_dist1", np.mean(macro_dist1_), np.std(macro_dist1_, ddof=1)) print("micro_dist1", np.mean(micro_dist1_), np.std(micro_dist1_, ddof=1)) print("macro_dist2", np.mean(macro_dist2_), np.std(macro_dist2_, ddof=1)) print("micro_dist2", np.mean(micro_dist2_), np.std(micro_dist2_, ddof=1)) ================================================ FILE: optim.py ================================================ # -*- coding: utf-8 -*- class Optim: "Optim wrapper that implements rate." 
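    # Noam schedule from "Attention Is All You Need":
    #   lr(step) = factor * model_size**-0.5 * min(step**-0.5, step * warmup**-1.5)
    # i.e. linear warmup for `warmup` steps followed by inverse-sqrt decay.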
def __init__(self, model_size, factor, warmup, optimizer): self.optimizer = optimizer self._step = 0 self.warmup = warmup self.factor = factor self.model_size = model_size self._rate = 0 def step(self): "Update parameters and rate" self._step += 1 rate = self.rate() for p in self.optimizer.param_groups: p['lr'] = rate self._rate = rate self.optimizer.step() def rate(self, step = None): "Implement `lrate` above" if step is None: step = self._step return self.factor * (self.model_size ** (-0.5) * min(step ** (-0.5), step * self.warmup ** (-1.5))) def state_dict(self): return self.optimizer.state_dict() def load_state_dict(self, m): self.optimizer.load_state_dict(m) ================================================ FILE: polish.py ================================================ import torch from torch import nn import torch.nn.functional as F import random import numpy as np import copy import time from biglm import BIGLM from data import Vocab, DataLoader, s2t, s2xy_polish gpu = 0 def init_model(m_path, device, vocab): ckpt= torch.load(m_path, map_location='cpu') lm_args = ckpt['args'] lm_vocab = Vocab(vocab, min_occur_cnt=lm_args.min_occur_cnt, specials=[]) lm_model = BIGLM(device, lm_vocab, lm_args.embed_dim, lm_args.ff_embed_dim, lm_args.num_heads, lm_args.dropout, lm_args.layers, 0.1) lm_model.load_state_dict(ckpt['model']) lm_model = lm_model.cuda(device) lm_model.eval() return lm_model, lm_vocab, lm_args m_path = "./model/songci.ckpt" lm_model, lm_vocab, lm_args = init_model(m_path, gpu, "./model/vocab.txt") MAX_LEN = 300 k = 32 def top_k_inc(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s): start = time.time() incremental_state = None inp_y, m = s2t(s, lm_vocab) inp_y = inp_y.cuda(gpu) res = [] for l in range(inp_ys_tpl.size(0)): probs, pred, incremental_state = lm_model.work_incremental(enc, src_padding_mask, \ inp_y, inp_ys_tpl[0:l+1,:], inp_ys_seg[0:l+1,:], inp_ys_pos[0:l+1,:],\ incremental_state) next_tk = [] for i in range(len(s)): ctk = lm_vocab.idx2token(inp_ys_tpl[l,i].item()) if ctk != "" and ctk != "" and ctk != "": next_tk.append(ctk) continue if l == 0: logits = probs[len(s[i]) - 1, i] else: logits = probs[0, i] ps, idx = torch.topk(logits, k=k) ps = ps / torch.sum(ps) sampled = torch.multinomial(ps, num_samples = 1) sampled_idx = idx[sampled] next_tk.append(lm_vocab.idx2token(sampled_idx.item())) s_ = [] bidx = [1] * len(s) for idx, (sent, t) in enumerate(zip(s, next_tk)): if t == "": res.append(sent) bidx[idx] = 0 else: s_.append(sent + [t]) if not s_: break s = s_ inp_y, m = s2t(s, lm_vocab) inp_y = inp_y.cuda(gpu) bidx = torch.BoolTensor(bidx).cuda(gpu) incremental_state["bidx"] = bidx res += s_ #for i in res: # print(''.join(i)) print(time.time()-start) return res def top_k(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s): inp_y, m = s2t(s, lm_vocab) inp_y = inp_y.cuda(gpu) start = time.time() res = [] for l in range(inp_ys_tpl.size(0)): probs, pred = lm_model.work(enc, src_padding_mask, inp_y, inp_ys_tpl[0:l+1,:], inp_ys_seg[0:l+1,:], inp_ys_pos[0:l+1,:]) next_tk = [] for i in range(len(s)): ctk = lm_vocab.idx2token(inp_ys_tpl[l,i].item()) if ctk != "" and ctk != "" and ctk != "": next_tk.append(ctk) continue logits = probs[len(s[i]) - 1, i] ps, idx = torch.topk(logits, k=k) ps = ps / torch.sum(ps) sampled = torch.multinomial(ps, num_samples = 1) sampled_idx = idx[sampled] next_tk.append(lm_vocab.idx2token(sampled_idx.item())) s_ = [] for sent, t in zip(s, next_tk): if t == "": res.append(sent) else: s_.append(sent + [t]) if not s_: break s 
= s_ inp_y, m = s2t(s, lm_vocab) inp_y = inp_y.cuda(gpu) res += s_ #for i in res: # print(''.join(i)) #print(time.time()-start) return res ds = [] with open("./data/polish_tpl.txt", "r") as f: for line in f: line = line.strip() if line: ds.append(line) print(len(ds)) local_rank = gpu batch_size = 1 cp_size = 1 batches = round(len(ds) / batch_size) for i in range(5): fo = open("./results/out"+str(i+1)+".txt", "w") idx = 0 while idx < len(ds): lb = ds[idx:idx + batch_size] cplb = [] for line in lb: cplb += [line for i in range(cp_size)] print(cplb) xs_tpl, xs_seg, xs_pos, \ ys_truth, ys_inp, \ ys_tpl, ys_seg, ys_pos, msk = s2xy_polish(cplb, lm_vocab, lm_args.max_len,2) xs_tpl = xs_tpl.cuda(local_rank) xs_seg = xs_seg.cuda(local_rank) xs_pos = xs_pos.cuda(local_rank) ys_tpl = ys_tpl.cuda(local_rank) ys_seg = ys_seg.cuda(local_rank) ys_pos = ys_pos.cuda(local_rank) enc, src_padding_mask = lm_model.encode(xs_tpl, xs_seg, xs_pos) s = [['']] * batch_size * cp_size res = top_k_inc(enc, src_padding_mask, ys_tpl, ys_seg, ys_pos, s) for i, line in enumerate(cplb): r = ''.join(res[i]) print(line) print(r) fo.write(line + "\t" + r + "\n") idx += batch_size fo.close() ================================================ FILE: polish.sh ================================================ python3 -u polish.py ================================================ FILE: prepare_data.py ================================================ import sys, re from collections import Counter import random cnt = Counter() f_ci = "./data/ci.txt" f_cipai = "./data/cipai.txt" cipai = Counter() with open(f_cipai) as f: for line in f: line = line.strip() fs = line.split() cipai.update(fs) cipai = cipai.keys() docs = {} with open(f_ci) as f: for line in f: line = line.strip() fs = line.split("") author = fs[0] topic, content = fs[1].split("") if "・" in topic: t1, t2 = topic.split("・") if t1 == t2: topic = t1 else: if t1 in cipai: topic = t1 elif t2 in cipai: topic = t2 else: topic = t1 content = content.replace("、", ",") sents = content.split("") ws = [w for w in author + topic + ''.join(sents)] cnt.update(ws) if topic not in docs: docs[topic] = [] docs[topic].append(author + "" + topic + "" + ''.join(sents)) topics = list(docs.keys()) print(len(topics)) random.shuffle(topics) topics_train = topics[:len(topics)-50] topics_dev_test = topics[-50:] topics_dev = topics_dev_test[:25] topics_test = topics_dev_test[-25:] docs_train = [] docs_dev = [] docs_test = [] for t in topics_train: docs_train.extend(docs[t]) for t in topics_dev: docs_dev.extend(docs[t]) for t in topics_test: docs_test.extend(docs[t]) random.shuffle(docs_train) random.shuffle(docs_dev) random.shuffle(docs_test) print(len(docs_train), len(docs_dev), len(docs_test)) train_cps = [] dev_cps = [] test_cps = [] with open('./data/train.txt', 'w', encoding ='utf8') as f: for x in docs_train: s = x.split("")[0] train_cps.append(s.split("")[1]) f.write(x + '\n') print(len(set(train_cps))) with open('./data/dev.txt', 'w', encoding ='utf8') as f: for x in docs_dev: s = x.split("")[0] dev_cps.append(s.split("")[1]) f.write(x + '\n') print(len(set(dev_cps))) with open('./data/test.txt', 'w', encoding ='utf8') as f: for x in docs_test: s = x.split("")[0] test_cps.append(s.split("")[1]) f.write(x + '\n') print(len(set(test_cps))) print("vocab") with open('./data/vocab.txt', 'w', encoding ='utf8') as f: for x, y in cnt.most_common(): f.write(x + '\t' + str(y) + '\n') print("done") ================================================ FILE: test.py 
================================================ import torch from torch import nn import torch.nn.functional as F import random import numpy as np import copy import time from biglm import BIGLM from data import Vocab, DataLoader, s2t, s2xy def init_seeds(): random.seed(123) torch.manual_seed(123) if torch.cuda.is_available(): torch.cuda.manual_seed_all(123) #init_seeds() gpu = 1 def init_model(m_path, device, vocab): ckpt= torch.load(m_path, map_location='cpu') lm_args = ckpt['args'] lm_vocab = Vocab(vocab, min_occur_cnt=lm_args.min_occur_cnt, specials=[]) lm_model = BIGLM(device, lm_vocab, lm_args.embed_dim, lm_args.ff_embed_dim, lm_args.num_heads, lm_args.dropout, lm_args.layers, 0.1) lm_model.load_state_dict(ckpt['model']) lm_model = lm_model.cuda(device) lm_model.eval() return lm_model, lm_vocab, lm_args m_path = "./model/songci.ckpt" lm_model, lm_vocab, lm_args = init_model(m_path, gpu, "./model/vocab.txt") k = 32 def top_k_inc(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s): start = time.time() incremental_state = None inp_y, m = s2t(s, lm_vocab) inp_y = inp_y.cuda(gpu) res = [] for l in range(inp_ys_tpl.size(0)): probs, pred, incremental_state = lm_model.work_incremental(enc, src_padding_mask, \ inp_y, inp_ys_tpl[0:l+1,:], inp_ys_seg[0:l+1,:], inp_ys_pos[0:l+1,:],\ incremental_state) next_tk = [] for i in range(len(s)): ctk = lm_vocab.idx2token(inp_ys_tpl[l,i].item()) if ctk != "" and ctk != "" and ctk != "": next_tk.append(ctk) continue if l == 0: logits = probs[len(s[i]) - 1, i] else: logits = probs[0, i] ps, idx = torch.topk(logits, k=k) ps = ps / torch.sum(ps) sampled = torch.multinomial(ps, num_samples = 1) sampled_idx = idx[sampled] next_tk.append(lm_vocab.idx2token(sampled_idx.item())) s_ = [] bidx = [1] * len(s) for idx, (sent, t) in enumerate(zip(s, next_tk)): if t == "": res.append(sent) bidx[idx] = 0 else: s_.append(sent + [t]) if not s_: break s = s_ inp_y, m = s2t(s, lm_vocab) inp_y = inp_y.cuda(gpu) bidx = torch.BoolTensor(bidx).cuda(gpu) incremental_state["bidx"] = bidx res += s_ #for i in res: # print(''.join(i)) print(time.time()-start) return res def top_k(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s): inp_y, m = s2t(s, lm_vocab) inp_y = inp_y.cuda(gpu) start = time.time() res = [] for l in range(inp_ys_tpl.size(0)): probs, pred = lm_model.work(enc, src_padding_mask, inp_y, inp_ys_tpl[0:l+1,:], inp_ys_seg[0:l+1,:], inp_ys_pos[0:l+1,:]) next_tk = [] for i in range(len(s)): ctk = lm_vocab.idx2token(inp_ys_tpl[l,i].item()) if ctk != "" and ctk != "" and ctk != "": next_tk.append(ctk) continue logits = probs[len(s[i]) - 1, i] ps, idx = torch.topk(logits, k=k) ps = ps / torch.sum(ps) sampled = torch.multinomial(ps, num_samples = 1) sampled_idx = idx[sampled] next_tk.append(lm_vocab.idx2token(sampled_idx.item())) s_ = [] for sent, t in zip(s, next_tk): if t == "": res.append(sent) else: s_.append(sent + [t]) if not s_: break s = s_ inp_y, m = s2t(s, lm_vocab) inp_y = inp_y.cuda(gpu) res += s_ #for i in res: # print(''.join(i)) #print(time.time()-start) return res def greedy(enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos, s): start = time.time() incremental_state = None inp_y, m = s2t(s, lm_vocab) inp_y = inp_y.cuda(gpu) res = [] for l in range(inp_ys_tpl.size(0)): probs, pred, incremental_state = lm_model.work_incremental(enc, src_padding_mask, \ inp_y, inp_ys_tpl[0:l+1,:], inp_ys_seg[0:l+1,:], inp_ys_pos[0:l+1,:],\ incremental_state) next_tk = [] for i in range(len(s)): ctk = lm_vocab.idx2token(inp_ys_tpl[l,i].item()) if ctk != 
"" and ctk != "" and ctk != "": next_tk.append(ctk) continue if l == 0: pred = pred[len(s[i]) - 1, i] else: pred = pred[0, i] next_tk.append(lm_vocab.idx2token(pred.item())) s_ = [] bidx = [1] * len(s) for idx, (sent, t) in enumerate(zip(s, next_tk)): if t == "": res.append(sent) bidx[idx] = 0 else: s_.append(sent + [t]) if not s_: break s = s_ inp_y, m = s2t(s, lm_vocab) inp_y = inp_y.cuda(gpu) bidx = torch.BoolTensor(bidx).cuda(gpu) incremental_state["bidx"] = bidx res += s_ #for i in res: # print(''.join(i)) print(time.time()-start) return res def beam_decode(s, x, enc, src_padding_mask, inp_ys_tpl, inp_ys_seg, inp_ys_pos): beam_size = 5 num_live = 1 num_dead = 0 samples = [] sample_scores = np.zeros(beam_size) last_traces = [[]] last_scores = torch.FloatTensor(np.zeros(1)).to(gpu) x = x.to(gpu) ys = x for l in range(inp_ys_tpl.size(0)): seq_len, bsz = ys.size() enc_ = enc.repeat(1, bsz, 1) src_padding_mask_ = src_padding_mask.repeat(1, bsz) inp_ys_tpl_ = inp_ys_tpl.repeat(1, bsz) inp_ys_seg_ = inp_ys_seg.repeat(1, bsz) inp_ys_pos_ = inp_ys_pos.repeat(1, bsz) y_pred, _ = lm_model.work(enc_, src_padding_mask_, ys, inp_ys_tpl_[0:l+1,:], inp_ys_seg_[0:l+1,:], inp_ys_pos_[0:l+1,:]) dict_size = y_pred.shape[-1] y_pred = y_pred[-1, :, :] cand_y_scores = last_scores + torch.log(y_pred) # larger is better cand_scores = cand_y_scores.flatten() idx_top_joint_scores = torch.topk(cand_scores, beam_size - num_dead)[1] ''' ps, idx_top_joint_scores = torch.topk(cand_scores, 100) ps = F.softmax(ps) sampled = torch.multinomial(ps, num_samples = beam_size - num_dead) idx_top_joint_scores = idx_top_joint_scores[sampled] ''' idx_last_traces = idx_top_joint_scores / dict_size idx_word_now = idx_top_joint_scores % dict_size top_joint_scores = cand_scores[idx_top_joint_scores] traces_now = [] scores_now = np.zeros((beam_size - num_dead)) ys_now = [] for i, [j, k] in enumerate(zip(idx_last_traces, idx_word_now)): traces_now.append(last_traces[j] + [k]) scores_now[i] = copy.copy(top_joint_scores[i]) ys_now.append(copy.copy(ys[:,j])) num_live = 0 last_traces = [] last_scores = [] ys = [] for i in range(len(traces_now)): w = lm_vocab.idx2token(traces_now[i][-1].item()) if w == "": samples.append([str(e.item()) for e in traces_now[i][:-1]]) sample_scores[num_dead] = scores_now[i] num_dead += 1 else: last_traces.append(traces_now[i]) last_scores.append(scores_now[i]) ys.append(ys_now[i]) num_live += 1 if num_live == 0 or num_dead >= beam_size: break ys = torch.stack(ys, dim = 1) last_scores = torch.FloatTensor(np.array(last_scores).reshape((num_live, 1))).to(gpu) next_y = [] for e in last_traces: eid = e[-1].item() next_y.append(eid) next_y = np.array(next_y).reshape((1, num_live)) next_y = torch.LongTensor(next_y).to(gpu) ys = torch.cat([ys, next_y], dim=0) assert num_live + num_dead == beam_size # end for loop if num_live > 0: for i in range(num_live): samples.append([str(e.item()) for e in last_traces[i]]) sample_scores[num_dead] = last_scores[i] num_dead += 1 idx_sorted_scores = np.argsort(sample_scores) # ascending order sorted_samples = [] sorted_scores = [] filter_idx = [] for e in idx_sorted_scores: if len(samples[e]) > 0: filter_idx.append(e) if len(filter_idx) == 0: filter_idx = idx_sorted_scores for e in filter_idx: sorted_samples.append(samples[e]) sorted_scores.append(sample_scores[e]) res = [] dec_words = [] for sample in sorted_samples[::-1]: for e in sample: e = int(e) dec_words.append(lm_vocab.idx2token(e)) #r = ''.join(dec_words) #print(r) res.append(dec_words) dec_words = [] return res def 
beam_search(enc, src_padding_mask, ys_tpl, ys_seg, ys_pos, s): x, m = s2t(s, lm_vocab) return beam_decode(s[0], x, enc, src_padding_mask, ys_tpl, ys_seg, ys_pos) ds = [] with open("./data/test.txt", "r") as f: for line in f: line = line.strip() if line: ds.append(line) print(len(ds)) local_rank = gpu batch_size = 1 cp_size = 1 batches = round(len(ds) / batch_size) for i in range(5): idx = 0 fo = open("./results/top-"+str(k)+"/out"+str(i+1)+".txt", "w") while idx < len(ds): lb = ds[idx:idx + batch_size] cplb = [] for line in lb: cplb += [line for i in range(cp_size)] print(cplb) xs_tpl, xs_seg, xs_pos, \ ys_truth, ys_inp, \ ys_tpl, ys_seg, ys_pos, msk = s2xy(cplb, lm_vocab, lm_args.max_len, 2) xs_tpl = xs_tpl.cuda(local_rank) xs_seg = xs_seg.cuda(local_rank) xs_pos = xs_pos.cuda(local_rank) ys_tpl = ys_tpl.cuda(local_rank) ys_seg = ys_seg.cuda(local_rank) ys_pos = ys_pos.cuda(local_rank) enc, src_padding_mask = lm_model.encode(xs_tpl, xs_seg, xs_pos) s = [['']] * batch_size * cp_size res = top_k_inc(enc, src_padding_mask, ys_tpl, ys_seg, ys_pos, s) for i, line in enumerate(cplb): r = ''.join(res[i]) print(line) print(r) fo.write(line + "\t" + r + "\n") idx += batch_size fo.close() ================================================ FILE: test.sh ================================================ python3 -u test.py ================================================ FILE: train.py ================================================ # coding=utf-8 import torch import torch.distributed as dist import torch.nn as nn import torch.nn.functional as F import torch.multiprocessing as mp from biglm import BIGLM from data import Vocab, DataLoader, s2xy from optim import Optim import argparse, os import random def parse_config(): parser = argparse.ArgumentParser() parser.add_argument('--embed_dim', type=int) parser.add_argument('--ff_embed_dim', type=int) parser.add_argument('--num_heads', type=int) parser.add_argument('--layers', type=int) parser.add_argument('--dropout', type=float) parser.add_argument('--train_data', type=str) parser.add_argument('--dev_data', type=str) parser.add_argument('--vocab', type=str) parser.add_argument('--min_occur_cnt', type=int) parser.add_argument('--batch_size', type=int) parser.add_argument('--warmup_steps', type=int) parser.add_argument('--lr', type=float) parser.add_argument('--smoothing', type=float) parser.add_argument('--weight_decay', type=float) parser.add_argument('--max_len', type=int) parser.add_argument('--min_len', type=int) parser.add_argument('--print_every', type=int) parser.add_argument('--save_every', type=int) parser.add_argument('--start_from', type=str, default=None) parser.add_argument('--save_dir', type=str) parser.add_argument('--world_size', type=int) parser.add_argument('--gpus', type=int) parser.add_argument('--MASTER_ADDR', type=str) parser.add_argument('--MASTER_PORT', type=str) parser.add_argument('--start_rank', type=int) parser.add_argument('--backend', type=str) return parser.parse_args() def update_lr(optimizer, lr): for param_group in optimizer.param_groups: param_group['lr'] = lr def average_gradients(model): """ Gradient averaging. 
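    All-reduces each parameter's gradient (summing across workers) and divides
    by the world size; returns False as soon as a parameter without a gradient
    is found, so the caller can skip the optimizer step.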
""" normal = True size = float(dist.get_world_size()) for param in model.parameters(): if param.grad is not None: dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) param.grad.data /= size else: normal = False break return normal def eval_epoch(lm_args, model, lm_vocab, local_rank, label): print("validating...", flush=True) ds = [] with open(lm_args.dev_data, "r") as f: for line in f: line = line.strip() if line: ds.append(line) batch_size = 10 batches = round(len(ds) / batch_size) idx = 0 avg_nll = 0. avg_ppl = 0. count = 0. while idx < len(ds): cplb = ds[idx:idx + batch_size] xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk = s2xy(cplb, lm_vocab, lm_args.max_len, lm_args.min_len) xs_tpl = xs_tpl.cuda(local_rank) xs_seg = xs_seg.cuda(local_rank) xs_pos = xs_pos.cuda(local_rank) ys_truth = ys_truth.cuda(local_rank) ys_inp = ys_inp.cuda(local_rank) ys_tpl = ys_tpl.cuda(local_rank) ys_seg = ys_seg.cuda(local_rank) ys_pos = ys_pos.cuda(local_rank) msk = msk.cuda(local_rank) nll, ppl, bsz = model.ppl(xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk) avg_nll += nll avg_ppl += ppl count += bsz idx += batch_size print(label, "nll=", avg_nll/count, "ppl=", avg_ppl/count, "count=", count, flush=True) def run(args, local_rank): """ Distributed Synchronous """ torch.manual_seed(1234) vocab = Vocab(args.vocab, min_occur_cnt=args.min_occur_cnt, specials=[]) if (args.world_size == 1 or dist.get_rank() == 0): print ("vocab.size = " + str(vocab.size), flush=True) model = BIGLM(local_rank, vocab, args.embed_dim, args.ff_embed_dim,\ args.num_heads, args.dropout, args.layers, args.smoothing) if args.start_from is not None: ckpt = torch.load(args.start_from, map_location='cpu') model.load_state_dict(ckpt['model']) model = model.cuda(local_rank) optimizer = Optim(model.embed_dim, args.lr, args.warmup_steps, torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.998), eps=1e-9)) if args.start_from is not None: optimizer.load_state_dict(ckpt['optimizer']) train_data = DataLoader(vocab, args.train_data, args.batch_size, args.max_len, args.min_len) batch_acm = 0 acc_acm, nll_acm, ppl_acm, ntokens_acm, nxs, npairs_acm, loss_acm = 0., 0., 0., 0., 0., 0., 0. 
while True: model.train() if train_data.epoch_id > 30: break for xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk in train_data: batch_acm += 1 xs_tpl = xs_tpl.cuda(local_rank) xs_seg = xs_seg.cuda(local_rank) xs_pos = xs_pos.cuda(local_rank) ys_truth = ys_truth.cuda(local_rank) ys_inp = ys_inp.cuda(local_rank) ys_tpl = ys_tpl.cuda(local_rank) ys_seg = ys_seg.cuda(local_rank) ys_pos = ys_pos.cuda(local_rank) msk = msk.cuda(local_rank) model.zero_grad() res, loss, acc, nll, ppl, ntokens, npairs = model(xs_tpl, xs_seg, xs_pos, ys_truth, ys_inp, ys_tpl, ys_seg, ys_pos, msk) loss_acm += loss.item() acc_acm += acc nll_acm += nll ppl_acm += ppl ntokens_acm += ntokens npairs_acm += npairs nxs += npairs loss.backward() if args.world_size > 1: is_normal = average_gradients(model) else: is_normal = True if is_normal: torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() else: print("gradient: none, gpu: " + str(local_rank), flush=True) continue if (args.world_size==1 or dist.get_rank() ==0) and batch_acm%args.print_every == -1%args.print_every: print ('batch_acm %d, loss %.3f, acc %.3f, nll %.3f, ppl %.3f, x_acm %d, lr %.6f'\ %(batch_acm, loss_acm/args.print_every, acc_acm/ntokens_acm, \ nll_acm/nxs, ppl_acm/nxs, npairs_acm, optimizer._rate), flush=True) acc_acm, nll_acm, ppl_acm, ntokens_acm, loss_acm, nxs = 0., 0., 0., 0., 0., 0. if (args.world_size==1 or dist.get_rank() ==0) and batch_acm%args.save_every == -1%args.save_every: if not os.path.exists(args.save_dir): os.mkdir(args.save_dir) model.eval() eval_epoch(args, model, vocab, local_rank, "epoch-" + str(train_data.epoch_id) + "-acm-" + str(batch_acm)) model.train() torch.save({'args':args, 'model':model.state_dict(), 'optimizer':optimizer.state_dict()}, '%s/epoch%d_batch_%d'%(args.save_dir, train_data.epoch_id, batch_acm)) def init_processes(args, local_rank, fn, backend='nccl'): """ Initialize the distributed environment. 
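    Exports MASTER_ADDR/MASTER_PORT for the rendezvous, joins the process group
    with this worker's global rank (start_rank + local_rank), and then calls
    fn(args, local_rank).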
""" os.environ['MASTER_ADDR'] = args.MASTER_ADDR os.environ['MASTER_PORT'] = args.MASTER_PORT dist.init_process_group(backend, rank=args.start_rank + local_rank, world_size=args.world_size) fn(args, local_rank) if __name__ == "__main__": mp.set_start_method('spawn') args = parse_config() if args.world_size == 1: run(args, 0) exit(0) processes = [] for rank in range(args.gpus): p = mp.Process(target=init_processes, args=(args, rank, run, args.backend)) p.start() processes.append(p) for p in processes: p.join() ================================================ FILE: train.sh ================================================ CUDA_VISIBLE_DEVICES=1 \ python3 -u train.py --embed_dim 768 \ --ff_embed_dim 3072 \ --num_heads 12 \ --layers 12 \ --dropout 0.2 \ --train_data ./data/train.txt \ --dev_data ./data/dev.txt \ --vocab ./data/vocab.txt \ --min_occur_cnt 1 \ --batch_size 32 \ --warmup_steps 8000 \ --lr 0.5 \ --weight_decay 0 \ --smoothing 0.1 \ --max_len 300 \ --min_len 10 \ --world_size 1 \ --gpus 1 \ --start_rank 0 \ --MASTER_ADDR localhost \ --MASTER_PORT 28512 \ --print_every 100 \ --save_every 1000 \ --save_dir ckpt \ --backend nccl ================================================ FILE: transformer.py ================================================ import torch from torch import nn from torch.nn import Parameter import torch.nn.functional as F from utils import gelu, LayerNorm, get_incremental_state, set_incremental_state import math class TransformerLayer(nn.Module): def __init__(self, embed_dim, ff_embed_dim, num_heads, dropout, with_external=False, weights_dropout = True): super(TransformerLayer, self).__init__() self.self_attn = MultiheadAttention(embed_dim, num_heads, dropout, weights_dropout) self.fc1 = nn.Linear(embed_dim, ff_embed_dim) self.fc2 = nn.Linear(ff_embed_dim, embed_dim) self.attn_layer_norm = LayerNorm(embed_dim) self.ff_layer_norm = LayerNorm(embed_dim) self.with_external = with_external self.dropout = dropout if self.with_external: self.external_attn = MultiheadAttention(embed_dim, num_heads, dropout, weights_dropout) self.external_layer_norm = LayerNorm(embed_dim) self.reset_parameters() def reset_parameters(self): nn.init.normal_(self.fc1.weight, std=0.02) nn.init.normal_(self.fc2.weight, std=0.02) nn.init.constant_(self.fc1.bias, 0.) nn.init.constant_(self.fc2.bias, 0.) 
def forward(self, x, kv = None, self_padding_mask = None, self_attn_mask = None, external_memories = None, external_padding_mask=None, need_weights = False): # x: seq_len x bsz x embed_dim residual = x if kv is None: x, self_attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=self_padding_mask, attn_mask=self_attn_mask, need_weights = need_weights) else: x, self_attn = self.self_attn(query=x, key=kv, value=kv, key_padding_mask=self_padding_mask, attn_mask=self_attn_mask, need_weights = need_weights) x = F.dropout(x, p=self.dropout, training=self.training) x = self.attn_layer_norm(residual + x) if self.with_external: residual = x x, external_attn = self.external_attn(query=x, key=external_memories, value=external_memories, key_padding_mask=external_padding_mask, need_weights = need_weights) x = F.dropout(x, p=self.dropout, training=self.training) x = self.external_layer_norm(residual + x) else: external_attn = None residual = x x = gelu(self.fc1(x)) x = F.dropout(x, p=self.dropout, training=self.training) x = self.fc2(x) x = F.dropout(x, p=self.dropout, training=self.training) x = self.ff_layer_norm(residual + x) return x, self_attn, external_attn def work_incremental(self, x, self_padding_mask = None, self_attn_mask = None, external_memories = None, external_padding_mask = None, incremental_state = None): # x: seq_len x bsz x embed_dim residual = x x, self_attn = self.self_attn(query=x, key=x, value=x, key_padding_mask=self_padding_mask, attn_mask=self_attn_mask, incremental_state=incremental_state) x = self.attn_layer_norm(residual + x) if self.with_external: residual = x x, external_attn = self.external_attn(query=x, key=external_memories, value=external_memories, key_padding_mask=external_padding_mask) x = self.external_layer_norm(residual + x) else: external_attn = None residual = x x = gelu(self.fc1(x)) x = self.fc2(x) x = self.ff_layer_norm(residual + x) return x, self_attn, external_attn class MultiheadAttention(nn.Module): def __init__(self, embed_dim, num_heads, dropout=0., weights_dropout=True): super(MultiheadAttention, self).__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.dropout = dropout self.head_dim = embed_dim // num_heads assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" self.scaling = self.head_dim ** -0.5 self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim)) self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim)) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) self.weights_dropout = weights_dropout self.reset_parameters() def reset_parameters(self): nn.init.normal_(self.in_proj_weight, std=0.02) nn.init.normal_(self.out_proj.weight, std=0.02) nn.init.constant_(self.in_proj_bias, 0.) nn.init.constant_(self.out_proj.bias, 0.) 
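    # in_proj_weight packs the query/key/value projections into one [3*E, E]
    # matrix; the in_proj_q/k/v helpers below slice it at multiples of E, and
    # heads are formed by reshaping to (bsz * num_heads, len, head_dim).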
def forward(self, query, key, value, key_padding_mask=None, attn_mask=None, need_weights=False, incremental_state = None): """ Input shape: Time x Batch x Channel key_padding_mask: Time x batch attn_mask: tgt_len x src_len """ if incremental_state is not None: saved_state = self._get_input_buffer(incremental_state) bidx = self._get_bidx(incremental_state) else: saved_state = None bidx = None qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr() kv_same = key.data_ptr() == value.data_ptr() tgt_len, bsz, embed_dim = query.size() assert key.size() == value.size() if qkv_same: # self-attention q, k, v = self.in_proj_qkv(query) elif kv_same: # encoder-decoder attention q = self.in_proj_q(query) k, v = self.in_proj_kv(key) else: q = self.in_proj_q(query) k = self.in_proj_k(key) v = self.in_proj_v(value) q = q * self.scaling q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) if saved_state is not None: # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) if 'prev_key' in saved_state: prev_key = saved_state['prev_key'] if bidx is not None: prev_key = prev_key[bidx] prev_key = prev_key.contiguous().view(bsz * self.num_heads, -1, self.head_dim) k = torch.cat((prev_key, k), dim=1) if 'prev_value' in saved_state: prev_value = saved_state['prev_value'] if bidx is not None: prev_value = prev_value[bidx] prev_value = prev_value.contiguous().view(bsz * self.num_heads, -1, self.head_dim) v = torch.cat((prev_value, v), dim=1) saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim) saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim) self._set_input_buffer(incremental_state, saved_state) src_len = k.size(1) # k,v: bsz*heads x src_len x dim # q: bsz*heads x tgt_len x dim attn_weights = torch.bmm(q, k.transpose(1, 2)) assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] if attn_mask is not None: attn_weights.masked_fill_( attn_mask.unsqueeze(0), float('-inf') ) if key_padding_mask is not None: # don't attend to padding symbols attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) attn_weights.masked_fill_( key_padding_mask.transpose(0, 1).unsqueeze(1).unsqueeze(2), float('-inf') ) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) attn_weights = F.softmax(attn_weights, dim=-1) if self.weights_dropout: attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training) attn = torch.bmm(attn_weights, v) if not self.weights_dropout: attn = F.dropout(attn, p=self.dropout, training=self.training) assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) attn = self.out_proj(attn) if need_weights: # maximum attention weight over heads attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) attn_weights, _ = attn_weights.max(dim=1) attn_weights = attn_weights.transpose(0, 1) else: attn_weights = None return attn, attn_weights def in_proj_qkv(self, query): return self._in_proj(query).chunk(3, dim=-1) def in_proj_kv(self, key): return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1) def in_proj_q(self, query): return self._in_proj(query, end=self.embed_dim) def in_proj_k(self, key): return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim) def 
def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, std=0.02)
    nn.init.constant_(m.weight[padding_idx], 0)
    return m


class SelfAttentionMask(nn.Module):
    def __init__(self, init_size=100, device=0):
        super(SelfAttentionMask, self).__init__()
        self.weights = SelfAttentionMask.get_mask(init_size)
        self.device = device

    @staticmethod
    def get_mask(size):
        # upper-triangular boolean mask: True marks the future positions
        # that each query position is not allowed to attend to
        weights = torch.triu(torch.ones((size, size), dtype=torch.bool), 1)
        return weights

    def forward(self, size):
        if self.weights is None or size > self.weights.size(0):
            self.weights = SelfAttentionMask.get_mask(size)
        res = self.weights[:size, :size].cuda(self.device).detach()
        return res


class LearnedPositionalEmbedding(nn.Module):
    """This module produces LearnedPositionalEmbedding."""

    def __init__(self, embedding_dim, init_size=1024, device=0):
        super(LearnedPositionalEmbedding, self).__init__()
        self.weights = nn.Embedding(init_size, embedding_dim)
        self.device = device
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.weights.weight, std=0.02)

    def forward(self, input, offset=0):
        """Input is expected to be of size [seq_len x bsz]."""
        seq_len, bsz = input.size()
        positions = (offset + torch.arange(seq_len)).cuda(self.device)
        res = self.weights(positions).unsqueeze(1).expand(-1, bsz, -1)
        return res
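SelfAttentionMask.forward moves the mask to GPU via .cuda(self.device), so on CPU only the static get_mask construction can be reproduced. A worked example of the causal mask it builds, the same tensor MultiheadAttention.forward fills with -inf before the softmax:

import torch

# Same construction as SelfAttentionMask.get_mask(4): entries strictly above
# the diagonal are True, i.e. each query position is blocked from the future.
mask = torch.triu(torch.ones((4, 4), dtype=torch.bool), 1)
print(mask.int())
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]], dtype=torch.int32)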
""" half_dim = embedding_dim // 2 emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) if embedding_dim % 2 == 1: emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) return emb def forward(self, input, offset=0): """Input is expected to be of size [seq_len x bsz].""" seq_len, bsz = input.size() mx_position = seq_len + offset if self.weights is None or mx_position > self.weights.size(0): # recompute/expand embeddings if needed self.weights = SinusoidalPositionalEmbedding.get_embedding( mx_position, self.embedding_dim, ) positions = offset + torch.arange(seq_len) res = self.weights.index_select(0, positions).unsqueeze(1).expand(-1, bsz, -1).cuda(self.device).detach() return res ================================================ FILE: utils.py ================================================ import torch from torch import nn from torch.nn import Parameter from collections import defaultdict import math def gelu(x): cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) return cdf*x class LayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-12): super(LayerNorm, self).__init__() self.weight = nn.Parameter(torch.Tensor(hidden_size)) self.bias = nn.Parameter(torch.Tensor(hidden_size)) self.eps = eps self.reset_parameters() def reset_parameters(self): nn.init.constant_(self.weight, 1.) nn.init.constant_(self.bias, 0.) def forward(self, x): u = x.mean(-1, keepdim=True) s = (x - u).pow(2).mean(-1, keepdim=True) x = (x - u) / torch.sqrt(s + self.eps) return self.weight * x + self.bias INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0) def _get_full_incremental_state_key(module_instance, key): module_name = module_instance.__class__.__name__ # assign a unique ID to each module instance, so that incremental state is # not shared across module instances if not hasattr(module_instance, '_guyu_instance_id'): INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1 module_instance._guyu_instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name] return '{}.{}.{}'.format(module_name, module_instance._guyu_instance_id, key) def get_incremental_state(module, incremental_state, key): """Helper for getting incremental state for an nn.Module.""" full_key = _get_full_incremental_state_key(module, key) if incremental_state is None or full_key not in incremental_state: return None return incremental_state[full_key] def set_incremental_state(module, incremental_state, key, value): """Helper for setting incremental state for an nn.Module.""" if incremental_state is not None: full_key = _get_full_incremental_state_key(module, key) incremental_state[full_key] = value