Repository: LeeSureman/Batch_Parallel_LatticeLSTM
Branch: master
Commit: 553edbb5a3c4
Files: 10
Total size: 115.6 KB

Directory structure:
gitextract_lw5g_jo_/

├── README.md
├── load_data.py
├── main.py
├── main_without_fitlog.py
├── models.py
├── modules.py
├── pathes.py
├── small.py
├── utils.py
└── utils_.py

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
# Batch Parallel LatticeLSTM
+ Paper: https://arxiv.org/abs/1805.02023
+ With a batch size of 10, the computation speed already clearly exceeds that of the [original code](https://github.com/jiesutd/LatticeLSTM).
+ Set the paths of the three embedding files and of the corresponding dataset in main.py before running it. The embedding files used in the paper can be downloaded from https://github.com/jiesutd/LatticeLSTM
+ This code has been added to fastNLP.
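
As a rough illustration of the path setup described above, the snippet below defines the three embedding paths and one corpus path. The variable names follow those that appear in the commented-out example at the end of load_data.py; the directory layout and the file names are placeholders (assumptions) and must be replaced with the locations of the downloaded files.

```python
import os

# Placeholder root directories (assumptions, not paths shipped with this repo).
embedding_root = os.path.join('data', 'embeddings')
corpus_root = os.path.join('data', 'corpus')

# Character (unigram), bigram and word embeddings; the file names are placeholders.
yangjie_rich_pretrain_unigram_path = os.path.join(embedding_root, 'unigram.vec')
yangjie_rich_pretrain_bigram_path = os.path.join(embedding_root, 'bigram.vec')
yangjie_rich_pretrain_word_path = os.path.join(embedding_root, 'word.vec')

# Directory expected to contain train.char.bmes, dev.char.bmes and test.char.bmes.
ontonote4ner_cn_path = os.path.join(corpus_root, 'ontonotes4')

for name, path in [('unigram', yangjie_rich_pretrain_unigram_path),
                   ('bigram', yangjie_rich_pretrain_bigram_path),
                   ('word', yangjie_rich_pretrain_word_path),
                   ('corpus', ontonote4ner_cn_path)]:
    # Report missing files early instead of failing deep inside data loading.
    if not os.path.exists(path):
        print('missing {} path: {}'.format(name, path))
```

Checking the paths up front is only a convenience; the loaders themselves do not require it.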

## Environment:
+ python >= 3.7.3
+ fastNLP >= dev.0.5.0
+ pytorch >= 1.1.0
+ numpy >= 1.16.4
+ fitlog >= 0.2.0

## Dataset:
+ Resume, which can be downloaded from [here](https://github.com/jiesutd/LatticeLSTM)
+ Ontonote
+ [Weibo](https://github.com/hltcoe/golden-horse)

To support a dataset that is not included, write a loader function whose output has the same form as that of *load_ontonotes4ner* in load_data.py.
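
A minimal sketch of such a loader follows. It only illustrates the shape of the return value, a `(datasets, vocabs, embeddings)` triple with the keys used by *load_ontonotes4ner*. Plain Python lists and dicts stand in for the fastNLP DataSet, Vocabulary and StaticEmbedding objects, and `load_my_ner`, `read_char_bmes` and the bigram scheme are hypothetical names and assumptions, not part of this repository.

```python
def read_char_bmes(lines):
    """Parse 'char tag' lines; a blank line ends a sentence."""
    sentences, chars, tags = [], [], []
    for line in lines:
        parts = line.split()
        if not parts:
            if chars:
                sentences.append({'chars': chars, 'target': tags})
                chars, tags = [], []
            continue
        chars.append(parts[0])
        tags.append(parts[-1])
    if chars:
        sentences.append({'chars': chars, 'target': tags})
    return sentences


def load_my_ner(split_to_lines):
    """Hypothetical loader returning the (datasets, vocabs, embeddings) triple."""
    datasets = {}
    for split, lines in split_to_lines.items():
        instances = read_char_bmes(lines)
        for ins in instances:
            chars = ins['chars']
            # Assumed bigram scheme: each char joined with its successor.
            ins['bigrams'] = [a + b for a, b in zip(chars, chars[1:] + ['<end>'])]
            ins['seq_len'] = len(chars)
        datasets[split] = instances

    def build_vocab(field):
        # Vocabularies are built from the training split only.
        tokens = sorted({t for ins in datasets['train'] for t in ins[field]})
        return {tok: idx for idx, tok in enumerate(tokens)}

    vocabs = {'char': build_vocab('chars'),
              'bigram': build_vocab('bigrams'),
              'label': build_vocab('target')}
    embeddings = {}  # would hold 'char' / 'bigram' entries when embedding paths are given
    return datasets, vocabs, embeddings


# Tiny in-memory corpus standing in for train/dev/test files.
demo = {'train': ['李 B-NAME', '明 E-NAME', '', '在 O'],
        'dev': ['他 O'],
        'test': ['她 O']}
datasets, vocabs, embeddings = load_my_ner(demo)
```

A real loader would build fastNLP objects instead of dicts, but the three-part return value and the `'train'`/`'dev'`/`'test'` and `'char'`/`'bigram'`/`'label'` keys are what the rest of the pipeline relies on.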

## Performance:
|Dataset|F1 of this code (test)|F1 in paper (test)|
|:----:|:----:|:----:|
|Weibo|58.66 (may be wrong)|58.79|
|Resume|95.18|94.46|
|Ontonote|73.62|73.88|

Note: the Weibo dataset used here is V2, i.e. the revised version. According to an issue in the LatticeLSTM repository on Dr. Jie Yang's GitHub, it should be consistent.

## For any questions, please contact:
+ lixiaonan_xdu@outlook.com


================================================
FILE: load_data.py
================================================
from fastNLP.io import CSVLoader
from fastNLP import Vocabulary
from fastNLP import Const
import numpy as np
import fitlog
import pickle
import os
from fastNLP.embeddings import StaticEmbedding
from fastNLP import cache_results


@cache_results(_cache_fp='mtl16', _refresh=False)
def load_16_task(dict_path):
    '''

    :param dict_path: /remote-home/txsun/fnlp/MTL-LT/data
    :return:
    '''
    task_path = os.path.join(dict_path,'data.pkl')
    embedding_path = os.path.join(dict_path,'word_embedding.npy')

    embedding = np.load(embedding_path).astype(np.float32)

    # Use a context manager so the pickle file is closed after loading.
    with open(task_path, 'rb') as f:
        task_list = pickle.load(f)['task_lst']

    for t in task_list:
        t.train_set.rename_field('words_idx', 'words')
        t.dev_set.rename_field('words_idx', 'words')
        t.test_set.rename_field('words_idx', 'words')

        t.train_set.rename_field('label', 'target')
        t.dev_set.rename_field('label', 'target')
        t.test_set.rename_field('label', 'target')

        t.train_set.add_seq_len('words')
        t.dev_set.add_seq_len('words')
        t.test_set.add_seq_len('words')

        t.train_set.set_input(Const.INPUT, Const.INPUT_LEN)
        t.dev_set.set_input(Const.INPUT, Const.INPUT_LEN)
        t.test_set.set_input(Const.INPUT, Const.INPUT_LEN)

    return task_list,embedding


@cache_results(_cache_fp='SST2', _refresh=False)
def load_sst2(dict_path,embedding_path=None):
    '''

    :param dict_path: /remote-home/xnli/data/corpus/text_classification/SST-2/
    :param embedding_path: glove 300d txt
    :return:
    '''
    train_path = os.path.join(dict_path,'train.tsv')
    dev_path = os.path.join(dict_path,'dev.tsv')

    loader = CSVLoader(headers=('words', 'target'), sep='\t')
    train_data = loader.load(train_path).datasets['train']
    dev_data = loader.load(dev_path).datasets['train']

    train_data.apply_field(lambda x: x.split(), field_name='words', new_field_name='words')
    dev_data.apply_field(lambda x: x.split(), field_name='words', new_field_name='words')

    train_data.apply_field(lambda x: len(x), field_name='words', new_field_name='seq_len')
    dev_data.apply_field(lambda x: len(x), field_name='words', new_field_name='seq_len')

    vocab = Vocabulary(min_freq=2)
    vocab.from_dataset(train_data, field_name='words')
    vocab.from_dataset(dev_data, field_name='words')

    # pretrained_embedding = load_word_emb(embedding_path, 300, vocab)

    label_vocab = Vocabulary(padding=None, unknown=None).from_dataset(train_data, field_name='target')

    label_vocab.index_dataset(train_data, field_name='target')
    label_vocab.index_dataset(dev_data, field_name='target')

    vocab.index_dataset(train_data, field_name='words', new_field_name='words')
    vocab.index_dataset(dev_data, field_name='words', new_field_name='words')

    train_data.set_input(Const.INPUT, Const.INPUT_LEN)
    train_data.set_target(Const.TARGET)

    dev_data.set_input(Const.INPUT, Const.INPUT_LEN)
    dev_data.set_target(Const.TARGET)

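    if embedding_path is not None:
        # NOTE: load_word_emb is neither defined nor imported in this file;
        # it must be supplied elsewhere, otherwise this branch raises NameError.
        pretrained_embedding = load_word_emb(embedding_path, 300, vocab)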
        return (train_data,dev_data),(vocab,label_vocab),pretrained_embedding

    else:
        return (train_data,dev_data),(vocab,label_vocab)

@cache_results(_cache_fp='OntonotesPOS', _refresh=False)
def load_conllized_ontonote_POS(path,embedding_path=None):
    from fastNLP.io.loader import ConllLoader
    header2index = {'words':3,'POS':4,'NER':10}
    headers = ['words','POS']

    if 'NER' in headers:
        print('Warning! The NER tags read by load_conllized_ontonote are not BIOS but plain conll format, which is wrong!')
    indexes = list(map(lambda x:header2index[x],headers))

    loader = ConllLoader(headers,indexes)

    bundle = loader.load(path)

    # print(bundle.datasets)

    train_set = bundle.datasets['train']
    dev_set = bundle.datasets['dev']
    test_set = bundle.datasets['test']




    # train_set = loader.load(os.path.join(path,'train.txt'))
    # dev_set = loader.load(os.path.join(path, 'dev.txt'))
    # test_set = loader.load(os.path.join(path, 'test.txt'))

    # print(len(train_set))

    train_set.add_seq_len('words','seq_len')
    dev_set.add_seq_len('words','seq_len')
    test_set.add_seq_len('words','seq_len')



    # print(dataset['POS'])

    vocab = Vocabulary(min_freq=1)
    vocab.from_dataset(train_set,field_name='words')
    vocab.from_dataset(dev_set, field_name='words')
    vocab.from_dataset(test_set, field_name='words')

    vocab.index_dataset(train_set,field_name='words')
    vocab.index_dataset(dev_set, field_name='words')
    vocab.index_dataset(test_set, field_name='words')




    label_vocab_dict = {}

    for i,h in enumerate(headers):
        if h == 'words':
            continue
        label_vocab_dict[h] = Vocabulary(min_freq=1,padding=None,unknown=None)
        label_vocab_dict[h].from_dataset(train_set,field_name=h)

        label_vocab_dict[h].index_dataset(train_set,field_name=h)
        label_vocab_dict[h].index_dataset(dev_set,field_name=h)
        label_vocab_dict[h].index_dataset(test_set,field_name=h)

    train_set.set_input(Const.INPUT, Const.INPUT_LEN)
    train_set.set_target(headers[1])

    dev_set.set_input(Const.INPUT, Const.INPUT_LEN)
    dev_set.set_target(headers[1])

    test_set.set_input(Const.INPUT, Const.INPUT_LEN)
    test_set.set_target(headers[1])

    if len(headers) > 2:
        print('Warning: there is more than one task, so the target must be set manually each time!')


    print('train:',len(train_set),'dev:',len(dev_set),'test:',len(test_set))

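    if embedding_path is not None:
        # NOTE: load_word_emb is neither defined nor imported in this file;
        # it must be supplied elsewhere, otherwise this branch raises NameError.
        pretrained_embedding = load_word_emb(embedding_path, 300, vocab)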
        return (train_set,dev_set,test_set),(vocab,label_vocab_dict),pretrained_embedding
    else:
        return (train_set, dev_set, test_set), (vocab, label_vocab_dict)


@cache_results(_cache_fp='OntonotesNER', _refresh=False)
def load_conllized_ontonote_NER(path,embedding_path=None):
    from fastNLP.io.pipe.conll import OntoNotesNERPipe
    ontoNotesNERPipe = OntoNotesNERPipe(lower=True,target_pad_val=-100)
    bundle_NER = ontoNotesNERPipe.process_from_file(path)

    train_set_NER = bundle_NER.datasets['train']
    dev_set_NER = bundle_NER.datasets['dev']
    test_set_NER = bundle_NER.datasets['test']

    train_set_NER.add_seq_len('words','seq_len')
    dev_set_NER.add_seq_len('words','seq_len')
    test_set_NER.add_seq_len('words','seq_len')


    NER_vocab = bundle_NER.get_vocab('target')
    word_vocab = bundle_NER.get_vocab('words')

    if embedding_path is not None:

        embed = StaticEmbedding(vocab=word_vocab, model_dir_or_name=embedding_path, word_dropout=0.01,
                                dropout=0.5,lower=True)


        # pretrained_embedding = load_word_emb(embedding_path, 300, word_vocab)
        return (train_set_NER,dev_set_NER,test_set_NER),\
               (word_vocab,NER_vocab),embed
    else:
        # Return the vocabularies in the same (word, NER) order as the branch above.
        return (train_set_NER, dev_set_NER, test_set_NER), (word_vocab, NER_vocab)

@cache_results(_cache_fp='OntonotesPOSNER', _refresh=False)
def load_conllized_ontonote_NER_POS(path,embedding_path=None):
    from fastNLP.io.pipe.conll import OntoNotesNERPipe
    ontoNotesNERPipe = OntoNotesNERPipe(lower=True)
    bundle_NER = ontoNotesNERPipe.process_from_file(path)

    train_set_NER = bundle_NER.datasets['train']
    dev_set_NER = bundle_NER.datasets['dev']
    test_set_NER = bundle_NER.datasets['test']

    NER_vocab = bundle_NER.get_vocab('target')
    word_vocab = bundle_NER.get_vocab('words')

    (train_set_POS,dev_set_POS,test_set_POS),(_,POS_vocab) = load_conllized_ontonote_POS(path)
    POS_vocab = POS_vocab['POS']

    train_set_NER.add_field('pos',train_set_POS['POS'],is_target=True)
    dev_set_NER.add_field('pos', dev_set_POS['POS'], is_target=True)
    test_set_NER.add_field('pos', test_set_POS['POS'], is_target=True)

    if train_set_NER.has_field('target'):
        train_set_NER.rename_field('target','ner')

    if dev_set_NER.has_field('target'):
        dev_set_NER.rename_field('target','ner')

    if test_set_NER.has_field('target'):
        test_set_NER.rename_field('target','ner')



    if train_set_NER.has_field('pos'):
        train_set_NER.rename_field('pos','posid')
    if dev_set_NER.has_field('pos'):
        dev_set_NER.rename_field('pos','posid')
    if test_set_NER.has_field('pos'):
        test_set_NER.rename_field('pos','posid')

    if train_set_NER.has_field('ner'):
        train_set_NER.rename_field('ner','nerid')
    if dev_set_NER.has_field('ner'):
        dev_set_NER.rename_field('ner','nerid')
    if test_set_NER.has_field('ner'):
        test_set_NER.rename_field('ner','nerid')

    if embedding_path is not None:

        embed = StaticEmbedding(vocab=word_vocab, model_dir_or_name=embedding_path, word_dropout=0.01,
                                dropout=0.5,lower=True)

        return (train_set_NER,dev_set_NER,test_set_NER),\
               (word_vocab,POS_vocab,NER_vocab),embed
    else:
        # Return the same (word, POS, NER) vocabulary tuple as the branch above.
        return (train_set_NER, dev_set_NER, test_set_NER), (word_vocab, POS_vocab, NER_vocab)

@cache_results(_cache_fp='Ontonotes3', _refresh=True)
def load_conllized_ontonote_pkl(path,embedding_path=None):

    # Use a context manager so the pickle file is closed after loading.
    with open(path,'rb') as f:
        data_bundle = pickle.load(f)
    train_set = data_bundle.datasets['train']
    dev_set = data_bundle.datasets['dev']
    test_set = data_bundle.datasets['test']

    train_set.rename_field('pos','posid')
    train_set.rename_field('ner','nerid')
    train_set.rename_field('chunk','chunkid')

    dev_set.rename_field('pos','posid')
    dev_set.rename_field('ner','nerid')
    dev_set.rename_field('chunk','chunkid')

    test_set.rename_field('pos','posid')
    test_set.rename_field('ner','nerid')
    test_set.rename_field('chunk','chunkid')


    word_vocab = data_bundle.vocabs['words']
    pos_vocab = data_bundle.vocabs['pos']
    ner_vocab = data_bundle.vocabs['ner']
    chunk_vocab = data_bundle.vocabs['chunk']


    if embedding_path is not None:

        embed = StaticEmbedding(vocab=word_vocab, model_dir_or_name=embedding_path, word_dropout=0.01,
                                dropout=0.5,lower=True)

        return (train_set,dev_set,test_set),\
               (word_vocab,pos_vocab,ner_vocab,chunk_vocab),embed
    else:
        return (train_set, dev_set, test_set), (word_vocab,ner_vocab)
    # print(data_bundle)










# @cache_results(_cache_fp='Conll2003', _refresh=False)
# def load_conll_2003(path,embedding_path=None):
#     f = open(path, 'rb')
#     data_pkl = pickle.load(f)
#
#     task_lst = data_pkl['task_lst']
#     vocabs = data_pkl['vocabs']
#     # word_vocab = vocabs['words']
#     # pos_vocab = vocabs['pos']
#     # chunk_vocab = vocabs['chunk']
#     # ner_vocab = vocabs['ner']
#
#     if embedding_path is not None:
#         embed = StaticEmbedding(vocab=vocabs['words'], model_dir_or_name=embedding_path, word_dropout=0.01,
#                                 dropout=0.5)
#         return task_lst,vocabs,embed
#     else:
#         return task_lst,vocabs

# @cache_results(_cache_fp='Conll2003_mine', _refresh=False)
@cache_results(_cache_fp='Conll2003_mine_embed_100', _refresh=True)
def load_conll_2003_mine(path,embedding_path=None,pad_val=-100):
    # Close the pickle file as soon as it has been read.
    with open(path, 'rb') as f:
        data_pkl = pickle.load(f)
    # print(data_pkl)
    # print(data_pkl)
    train_set = data_pkl[0]['train']
    dev_set = data_pkl[0]['dev']
    test_set = data_pkl[0]['test']

    train_set.set_pad_val('posid',pad_val)
    train_set.set_pad_val('nerid', pad_val)
    train_set.set_pad_val('chunkid', pad_val)

    dev_set.set_pad_val('posid',pad_val)
    dev_set.set_pad_val('nerid', pad_val)
    dev_set.set_pad_val('chunkid', pad_val)

    test_set.set_pad_val('posid',pad_val)
    test_set.set_pad_val('nerid', pad_val)
    test_set.set_pad_val('chunkid', pad_val)

    if train_set.has_field('task_id'):

        train_set.delete_field('task_id')

    if dev_set.has_field('task_id'):
        dev_set.delete_field('task_id')

    if test_set.has_field('task_id'):
        test_set.delete_field('task_id')

    if train_set.has_field('words_idx'):
        train_set.rename_field('words_idx','words')

    if dev_set.has_field('words_idx'):
        dev_set.rename_field('words_idx','words')

    if test_set.has_field('words_idx'):
        test_set.rename_field('words_idx','words')



    word_vocab = data_pkl[1]['words']
    pos_vocab = data_pkl[1]['pos']
    ner_vocab = data_pkl[1]['ner']
    chunk_vocab = data_pkl[1]['chunk']

    if embedding_path is not None:
        embed = StaticEmbedding(vocab=word_vocab, model_dir_or_name=embedding_path, word_dropout=0.01,
                                dropout=0.5,lower=True)
        return (train_set,dev_set,test_set),(word_vocab,pos_vocab,ner_vocab,chunk_vocab),embed
    else:
        return (train_set,dev_set,test_set),(word_vocab,pos_vocab,ner_vocab,chunk_vocab)


def load_conllized_ontonote_pkl_yf(path):
    def init_task(task):
        task_name = task.task_name
        for ds in [task.train_set, task.dev_set, task.test_set]:
            if ds.has_field('words'):
                ds.rename_field('words', 'x')
            else:
                ds.rename_field('words_idx', 'x')
            if ds.has_field('label'):
                ds.rename_field('label', 'y')
            else:
                ds.rename_field(task_name, 'y')
            ds.set_input('x', 'y', 'task_id')
            ds.set_target('y')

            if task_name in ['ner', 'chunk'] or 'pos' in task_name:
                ds.set_input('seq_len')
                ds.set_target('seq_len')
        return task
    #/remote-home/yfshao/workdir/datasets/conll03/data.pkl
    def pload(fn):
        with open(fn, 'rb') as f:
            return pickle.load(f)

    DB = pload(path)
    task_lst = DB['task_lst']
    vocabs = DB['vocabs']
    task_lst = [init_task(task) for task in task_lst]

    return task_lst, vocabs


@cache_results(_cache_fp='weiboNER old uni+bi', _refresh=False)
def load_weibo_ner_old(path,unigram_embedding_path=None,bigram_embedding_path=None,index_token=True,
                  normlize={'char':True,'bigram':True,'word':False}):
    from fastNLP.io.loader import ConllLoader
    from utils import get_bigrams

    loader = ConllLoader(['chars','target'])
    # from fastNLP.io.file_reader import _read_conll
    # from fastNLP.core import Instance,DataSet
    # def _load(path):
    #     ds = DataSet()
    #     for idx, data in _read_conll(path, indexes=loader.indexes, dropna=loader.dropna,
    #                                 encoding='ISO-8859-1'):
    #         ins = {h: data[i] for i, h in enumerate(loader.headers)}
    #         ds.append(Instance(**ins))
    #     return ds
    # from fastNLP.io.utils import check_loader_paths
    # paths = check_loader_paths(path)
    # datasets = {name: _load(path) for name, path in paths.items()}
    datasets = {}
    train_path = os.path.join(path,'train.all.bmes')
    dev_path = os.path.join(path,'dev.all.bmes')
    test_path = os.path.join(path,'test.all.bmes')
    datasets['train'] = loader.load(train_path).datasets['train']
    datasets['dev'] = loader.load(dev_path).datasets['train']
    datasets['test'] = loader.load(test_path).datasets['train']

    for k,v in datasets.items():
        print('{}:{}'.format(k,len(v)))

    vocabs = {}
    word_vocab = Vocabulary()
    bigram_vocab = Vocabulary()
    label_vocab = Vocabulary(padding=None,unknown=None)

    for k,v in datasets.items():
        # ignore the word segmentation tag
        v.apply_field(lambda x: [w[0] for w in x],'chars','chars')
        v.apply_field(get_bigrams,'chars','bigrams')


    word_vocab.from_dataset(datasets['train'],field_name='chars',no_create_entry_dataset=[datasets['dev'],datasets['test']])
    label_vocab.from_dataset(datasets['train'],field_name='target')
    print('label_vocab:{}\n{}'.format(len(label_vocab),label_vocab.idx2word))


    for k,v in datasets.items():
        # v.set_pad_val('target',-100)
        v.add_seq_len('chars',new_field_name='seq_len')


    vocabs['char'] = word_vocab
    vocabs['label'] = label_vocab


    bigram_vocab.from_dataset(datasets['train'],field_name='bigrams',no_create_entry_dataset=[datasets['dev'],datasets['test']])
    if index_token:
        # Index in place the fields created above: 'chars', 'bigrams' and 'target'.
        word_vocab.index_dataset(*list(datasets.values()), field_name='chars', new_field_name='chars')
        bigram_vocab.index_dataset(*list(datasets.values()),field_name='bigrams',new_field_name='bigrams')
        label_vocab.index_dataset(*list(datasets.values()), field_name='target', new_field_name='target')

    # for k,v in datasets.items():
    #     v.set_input('chars','bigrams','seq_len','target')
    #     v.set_target('target','seq_len')

    vocabs['bigram'] = bigram_vocab

    embeddings = {}

    if unigram_embedding_path is not None:
        unigram_embedding = StaticEmbedding(word_vocab, model_dir_or_name=unigram_embedding_path,
                                            word_dropout=0.01,normalize=normlize['char'])
        embeddings['char'] = unigram_embedding

    if bigram_embedding_path is not None:
        bigram_embedding = StaticEmbedding(bigram_vocab, model_dir_or_name=bigram_embedding_path,
                                           word_dropout=0.01,normalize=normlize['bigram'])
        embeddings['bigram'] = bigram_embedding

    return datasets, vocabs, embeddings


@cache_results(_cache_fp='weiboNER uni+bi', _refresh=False)
def load_weibo_ner(path,unigram_embedding_path=None,bigram_embedding_path=None,index_token=True,
                   normlize={'char':True,'bigram':True,'word':False}):
    from fastNLP.io.loader import ConllLoader
    from utils import get_bigrams

    loader = ConllLoader(['chars','target'])
    bundle = loader.load(path)

    datasets = bundle.datasets
    for k,v in datasets.items():
        print('{}:{}'.format(k,len(v)))
    # print(*list(datasets.keys()))
    vocabs = {}
    word_vocab = Vocabulary()
    bigram_vocab = Vocabulary()
    label_vocab = Vocabulary(padding=None,unknown=None)

    for k,v in datasets.items():
        # ignore the word segmentation tag
        v.apply_field(lambda x: [w[0] for w in x],'chars','chars')
        v.apply_field(get_bigrams,'chars','bigrams')


    word_vocab.from_dataset(datasets['train'],field_name='chars',no_create_entry_dataset=[datasets['dev'],datasets['test']])
    label_vocab.from_dataset(datasets['train'],field_name='target')
    print('label_vocab:{}\n{}'.format(len(label_vocab),label_vocab.idx2word))


    for k,v in datasets.items():
        # v.set_pad_val('target',-100)
        v.add_seq_len('chars',new_field_name='seq_len')


    vocabs['char'] = word_vocab
    vocabs['label'] = label_vocab


    bigram_vocab.from_dataset(datasets['train'],field_name='bigrams',no_create_entry_dataset=[datasets['dev'],datasets['test']])
    if index_token:
        # Index in place the fields created above: 'chars', 'bigrams' and 'target'.
        word_vocab.index_dataset(*list(datasets.values()), field_name='chars', new_field_name='chars')
        bigram_vocab.index_dataset(*list(datasets.values()),field_name='bigrams',new_field_name='bigrams')
        label_vocab.index_dataset(*list(datasets.values()), field_name='target', new_field_name='target')

    # for k,v in datasets.items():
    #     v.set_input('chars','bigrams','seq_len','target')
    #     v.set_target('target','seq_len')

    vocabs['bigram'] = bigram_vocab

    embeddings = {}

    if unigram_embedding_path is not None:
        unigram_embedding = StaticEmbedding(word_vocab, model_dir_or_name=unigram_embedding_path,
                                            word_dropout=0.01,normalize=normlize['char'])
        embeddings['char'] = unigram_embedding

    if bigram_embedding_path is not None:
        bigram_embedding = StaticEmbedding(bigram_vocab, model_dir_or_name=bigram_embedding_path,
                                           word_dropout=0.01,normalize=normlize['bigram'])
        embeddings['bigram'] = bigram_embedding

    return datasets, vocabs, embeddings



# datasets,vocabs = load_weibo_ner('/remote-home/xnli/data/corpus/sequence_labelling/ner_weibo')
#
# print(datasets['train'][:5])
# print(vocabs['word'].idx2word)
# print(vocabs['target'].idx2word)


@cache_results(_cache_fp='cache/ontonotes4ner',_refresh=False)
def load_ontonotes4ner(path,char_embedding_path=None,bigram_embedding_path=None,index_token=True,
                       normalize={'char':True,'bigram':True,'word':False}):
    from fastNLP.io.loader import ConllLoader
    from utils import get_bigrams

    train_path = os.path.join(path,'train.char.bmes')
    dev_path = os.path.join(path,'dev.char.bmes')
    test_path = os.path.join(path,'test.char.bmes')

    loader = ConllLoader(['chars','target'])
    train_bundle = loader.load(train_path)
    dev_bundle = loader.load(dev_path)
    test_bundle = loader.load(test_path)


    datasets = dict()
    datasets['train'] = train_bundle.datasets['train']
    datasets['dev'] = dev_bundle.datasets['train']
    datasets['test'] = test_bundle.datasets['train']


    datasets['train'].apply_field(get_bigrams,field_name='chars',new_field_name='bigrams')
    datasets['dev'].apply_field(get_bigrams, field_name='chars', new_field_name='bigrams')
    datasets['test'].apply_field(get_bigrams, field_name='chars', new_field_name='bigrams')

    datasets['train'].add_seq_len('chars')
    datasets['dev'].add_seq_len('chars')
    datasets['test'].add_seq_len('chars')



    char_vocab = Vocabulary()
    bigram_vocab = Vocabulary()
    label_vocab = Vocabulary(padding=None,unknown=None)
    print(datasets.keys())
    print(len(datasets['dev']))
    print(len(datasets['test']))
    print(len(datasets['train']))
    char_vocab.from_dataset(datasets['train'],field_name='chars',
                            no_create_entry_dataset=[datasets['dev'],datasets['test']] )
    bigram_vocab.from_dataset(datasets['train'],field_name='bigrams',
                              no_create_entry_dataset=[datasets['dev'],datasets['test']])
    label_vocab.from_dataset(datasets['train'],field_name='target')
    if index_token:
        char_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'],
                                 field_name='chars',new_field_name='chars')
        bigram_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'],
                                 field_name='bigrams',new_field_name='bigrams')
        label_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'],
                                 field_name='target',new_field_name='target')

    vocabs = {}
    vocabs['char'] = char_vocab
    vocabs['label'] = label_vocab
    vocabs['bigram'] = bigram_vocab

    embeddings = {}
    if char_embedding_path is not None:
        char_embedding = StaticEmbedding(char_vocab,char_embedding_path,word_dropout=0.01,
                                         normalize=normalize['char'])
        embeddings['char'] = char_embedding

    if bigram_embedding_path is not None:
        bigram_embedding = StaticEmbedding(bigram_vocab,bigram_embedding_path,word_dropout=0.01,
                                           normalize=normalize['bigram'])
        embeddings['bigram'] = bigram_embedding

    return datasets,vocabs,embeddings



@cache_results(_cache_fp='cache/resume_ner',_refresh=False)
def load_resume_ner(path,char_embedding_path=None,bigram_embedding_path=None,index_token=True,
                    normalize={'char':True,'bigram':True,'word':False}):
    from fastNLP.io.loader import ConllLoader
    from utils import get_bigrams

    train_path = os.path.join(path,'train.char.bmes')
    dev_path = os.path.join(path,'dev.char.bmes')
    test_path = os.path.join(path,'test.char.bmes')

    loader = ConllLoader(['chars','target'])
    train_bundle = loader.load(train_path)
    dev_bundle = loader.load(dev_path)
    test_bundle = loader.load(test_path)


    datasets = dict()
    datasets['train'] = train_bundle.datasets['train']
    datasets['dev'] = dev_bundle.datasets['train']
    datasets['test'] = test_bundle.datasets['train']


    datasets['train'].apply_field(get_bigrams,field_name='chars',new_field_name='bigrams')
    datasets['dev'].apply_field(get_bigrams, field_name='chars', new_field_name='bigrams')
    datasets['test'].apply_field(get_bigrams, field_name='chars', new_field_name='bigrams')

    datasets['train'].add_seq_len('chars')
    datasets['dev'].add_seq_len('chars')
    datasets['test'].add_seq_len('chars')



    char_vocab = Vocabulary()
    bigram_vocab = Vocabulary()
    label_vocab = Vocabulary(padding=None,unknown=None)
    print(datasets.keys())
    print(len(datasets['dev']))
    print(len(datasets['test']))
    print(len(datasets['train']))
    char_vocab.from_dataset(datasets['train'],field_name='chars',
                            no_create_entry_dataset=[datasets['dev'],datasets['test']] )
    bigram_vocab.from_dataset(datasets['train'],field_name='bigrams',
                              no_create_entry_dataset=[datasets['dev'],datasets['test']])
    label_vocab.from_dataset(datasets['train'],field_name='target')
    if index_token:
        char_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'],
                                 field_name='chars',new_field_name='chars')
        bigram_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'],
                                 field_name='bigrams',new_field_name='bigrams')
        label_vocab.index_dataset(datasets['train'],datasets['dev'],datasets['test'],
                                 field_name='target',new_field_name='target')

    vocabs = {}
    vocabs['char'] = char_vocab
    vocabs['label'] = label_vocab
    vocabs['bigram'] = bigram_vocab

    embeddings = {}
    if char_embedding_path is not None:
        char_embedding = StaticEmbedding(char_vocab,char_embedding_path,word_dropout=0.01,normalize=normalize['char'])
        embeddings['char'] = char_embedding

    if bigram_embedding_path is not None:
        bigram_embedding = StaticEmbedding(bigram_vocab,bigram_embedding_path,word_dropout=0.01,normalize=normalize['bigram'])
        embeddings['bigram'] = bigram_embedding

    return datasets,vocabs,embeddings


@cache_results(_cache_fp='need_to_defined_fp',_refresh=False)
def equip_chinese_ner_with_skip(datasets,vocabs,embeddings,w_list,word_embedding_path=None,
                                normalize={'char':True,'bigram':True,'word':False}):
    from utils_ import Trie,get_skip_path
    from functools import partial
    w_trie = Trie()
    for w in w_list:
        w_trie.insert(w)

    # for k,v in datasets.items():
    #     v.apply_field(partial(get_skip_path,w_trie=w_trie),'chars','skips')

    def skips2skips_l2r(chars,w_trie):
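        '''
        Group the lexicon matches of one sentence by their end position.

        :param chars: list of the characters of one sentence
        :param w_trie: Trie holding the lexicon words
        :return: skips_l2r, where skips_l2r[e] lists [start, word] for every matched word ending at index e
        '''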
        # print(lexicons)
        # print('******')

        lexicons = get_skip_path(chars,w_trie=w_trie)


        # max_len = max(list(map(lambda x:max(x[:2]),lexicons)))+1 if len(lexicons) != 0 else 0

        result = [[] for _ in range(len(chars))]

        for lex in lexicons:
            s = lex[0]
            e = lex[1]
            w = lex[2]

            result[e].append([s,w])

        return result

    def skips2skips_r2l(chars,w_trie):
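        '''
        Group the lexicon matches of one sentence by their start position.

        :param chars: list of the characters of one sentence
        :param w_trie: Trie holding the lexicon words
        :return: skips_r2l, where skips_r2l[s] lists [end, word] for every matched word starting at index s
        '''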
        # print(lexicons)
        # print('******')

        lexicons = get_skip_path(chars,w_trie=w_trie)


        # max_len = max(list(map(lambda x:max(x[:2]),lexicons)))+1 if len(lexicons) != 0 else 0

        result = [[] for _ in range(len(chars))]

        for lex in lexicons:
            s = lex[0]
            e = lex[1]
            w = lex[2]

            result[s].append([e,w])

        return result

    for k,v in datasets.items():
        v.apply_field(partial(skips2skips_l2r,w_trie=w_trie),'chars','skips_l2r')

    for k,v in datasets.items():
        v.apply_field(partial(skips2skips_r2l,w_trie=w_trie),'chars','skips_r2l')

    # print(v['skips_l2r'][0])
    word_vocab = Vocabulary()
    word_vocab.add_word_lst(w_list)
    vocabs['word'] = word_vocab
    for k,v in datasets.items():
        v.apply_field(lambda x:[ list(map(lambda x:x[0],p)) for p in x],'skips_l2r','skips_l2r_source')
        v.apply_field(lambda x:[ list(map(lambda x:x[1],p)) for p in x], 'skips_l2r', 'skips_l2r_word')

    for k,v in datasets.items():
        v.apply_field(lambda x:[ list(map(lambda x:x[0],p)) for p in x],'skips_r2l','skips_r2l_source')
        v.apply_field(lambda x:[ list(map(lambda x:x[1],p)) for p in x], 'skips_r2l', 'skips_r2l_word')

    for k,v in datasets.items():
        v.apply_field(lambda x:list(map(len,x)), 'skips_l2r_word', 'lexicon_count')
        v.apply_field(lambda x:
                      list(map(lambda y:
                               list(map(lambda z:word_vocab.to_index(z),y)),x)),
                      'skips_l2r_word',new_field_name='skips_l2r_word')

        v.apply_field(lambda x:list(map(len,x)), 'skips_r2l_word', 'lexicon_count_back')

        v.apply_field(lambda x:
                      list(map(lambda y:
                               list(map(lambda z:word_vocab.to_index(z),y)),x)),
                      'skips_r2l_word',new_field_name='skips_r2l_word')





    if word_embedding_path is not None:
        word_embedding = StaticEmbedding(word_vocab,word_embedding_path,word_dropout=0,normalize=normalize['word'])
        embeddings['word'] = word_embedding

    vocabs['char'].index_dataset(datasets['train'], datasets['dev'], datasets['test'],
                             field_name='chars', new_field_name='chars')
    vocabs['bigram'].index_dataset(datasets['train'], datasets['dev'], datasets['test'],
                               field_name='bigrams', new_field_name='bigrams')
    vocabs['label'].index_dataset(datasets['train'], datasets['dev'], datasets['test'],
                              field_name='target', new_field_name='target')

    return datasets,vocabs,embeddings



@cache_results(_cache_fp='cache/load_yangjie_rich_pretrain_word_list',_refresh=False)
def load_yangjie_rich_pretrain_word_list(embedding_path,drop_characters=True):
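    # Each line of the embedding file starts with the word, followed by its vector;
    # only the word is needed here.
    # encoding='utf-8' is an assumption about the file; adjust it if yours differs.
    w_list = []
    with open(embedding_path,'r',encoding='utf-8') as f:
        for line in f:
            splited = line.strip().split(' ')
            w = splited[0]
            w_list.append(w)

    if drop_characters:
        # Keep only multi-character entries.
        w_list = [w for w in w_list if len(w) != 1]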

    return w_list
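

# Illustrative sketch only; nothing in this module calls it. It restates, on plain
# Python data, the grouping that skips2skips_r2l performs above: every lexicon match
# (start, end, word) is filed under its start character as an [end, word] pair.
# The function name and the toy input in the docstring are hypothetical.
def _demo_group_matches_by_start(num_chars, matches):
    """Group (start, end, word) triples by start index.

    Example call: _demo_group_matches_by_start(4, [(0, 1, 'AB'), (0, 2, 'ABC'), (2, 3, 'CD')])
    """
    result = [[] for _ in range(num_chars)]
    for s, e, w in matches:
        result[s].append([e, w])
    return result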



# from pathes import *
#
# datasets,vocabs,embeddings = load_ontonotes4ner(ontonote4ner_cn_path,
#                                                 yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path)
# print(datasets.keys())
# print(vocabs.keys())
# print(embeddings)
# yangjie_rich_pretrain_word_path
# datasets['train'].set_pad_val

================================================
FILE: main.py
================================================
import torch.nn as nn
# from pathes import *
from load_data import load_ontonotes4ner,equip_chinese_ner_with_skip,load_yangjie_rich_pretrain_word_list,\
    load_resume_ner,load_weibo_ner,load_weibo_ner_old
from fastNLP.embeddings import StaticEmbedding
from models import LatticeLSTM_SeqLabel,LSTM_SeqLabel,LatticeLSTM_SeqLabel_V1
from fastNLP import CrossEntropyLoss,SpanFPreRecMetric,Trainer,AccuracyMetric,LossInForward
import torch.optim as optim
import argparse
import torch
import sys
from utils_ import LatticeLexiconPadder,SpanFPreRecMetric_YJ
from fastNLP import Tester
import fitlog
from fastNLP.core.callback import FitlogCallback
from utils import set_seed
import os
from fastNLP import LRScheduler
from torch.optim.lr_scheduler import LambdaLR

def str2bool(value):
    # argparse passes command-line values as strings, and any non-empty string
    # is truthy, so boolean options need explicit parsing.
    if isinstance(value,bool):
        return value
    return value.lower() in ('true','1','yes','y')

parser = argparse.ArgumentParser()
parser.add_argument('--device',default='cuda:1')
parser.add_argument('--debug',default=False,type=str2bool)

parser.add_argument('--norm_embed',default=False,type=str2bool)
parser.add_argument('--batch',default=1,type=int)
parser.add_argument('--test_batch',default=1024,type=int)
parser.add_argument('--optim',default='sgd',help='adam|sgd')
parser.add_argument('--lr',default=0.045,type=float)
parser.add_argument('--model',default='lattice',help='lattice|lstm')
parser.add_argument('--skip_before_head',default=False,type=str2bool)#in paper it's false
parser.add_argument('--hidden',default=113,type=int)
parser.add_argument('--momentum',default=0,type=float)
parser.add_argument('--bi',default=True,type=str2bool)
parser.add_argument('--dataset',default='resume',help='resume|ontonote|weibo|weibo_old')
parser.add_argument('--use_bigram',default=True,type=str2bool)

parser.add_argument('--embed_dropout',default=0.5,type=float)
parser.add_argument('--gaz_dropout',default=-1,type=float)
parser.add_argument('--output_dropout',default=0.5,type=float)
parser.add_argument('--epoch',default=100,type=int)
parser.add_argument('--seed',default=100,type=int)

args = parser.parse_args()

set_seed(args.seed)

fit_msg_list = [args.model,'bi' if args.bi else 'uni',str(args.batch)]
if args.model == 'lattice':
    fit_msg_list.append(str(args.skip_before_head))
fit_msg = ' '.join(fit_msg_list)
fitlog.commit(__file__,fit_msg=fit_msg)

device = torch.device(args.device)
for k,v in args.__dict__.items():
    print(k,v)

refresh_data = False


from pathes import *
# ontonote4ner_cn_path = 0
# yangjie_rich_pretrain_unigram_path = 0
# yangjie_rich_pretrain_bigram_path = 0
# resume_ner_path = 0
# weibo_ner_path = 0

if args.dataset == 'ontonote':
    datasets,vocabs,embeddings = load_ontonotes4ner(ontonote4ner_cn_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path,
                                                    _refresh=refresh_data,index_token=False,
                                                    )
elif args.dataset == 'resume':
    datasets,vocabs,embeddings = load_resume_ner(resume_ner_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path,
                                                    _refresh=refresh_data,index_token=False,
                                                    )
elif args.dataset == 'weibo':
    datasets,vocabs,embeddings = load_weibo_ner(weibo_ner_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path,
                                                    _refresh=refresh_data,index_token=False,
                                                    )

elif args.dataset == 'weibo_old':
    datasets,vocabs,embeddings = load_weibo_ner_old(weibo_ner_old_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path,
                                                    _refresh=refresh_data,index_token=False,
                                                    )
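# Dataset-specific hyperparameters. Note that these overwrite any values given on the command line.
if args.dataset == 'ontonote':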
    args.batch = 10
    args.lr = 0.045
elif args.dataset == 'resume':
    args.batch = 1
    args.lr = 0.015
elif args.dataset == 'weibo':
    args.batch = 10
    args.gaz_dropout = 0.1
    args.embed_dropout = 0.1
    args.output_dropout = 0.1
elif args.dataset == 'weibo_old':
    args.embed_dropout = 0.1
    args.output_dropout = 0.1

if args.gaz_dropout < 0:
    args.gaz_dropout = args.embed_dropout

fitlog.add_hyper(args)
w_list = load_yangjie_rich_pretrain_word_list(yangjie_rich_pretrain_word_path,
                                              _refresh=refresh_data)

cache_name = os.path.join('cache',args.dataset+'_lattice')
datasets,vocabs,embeddings = equip_chinese_ner_with_skip(datasets,vocabs,embeddings,w_list,yangjie_rich_pretrain_word_path,
                                                         _refresh=refresh_data,_cache_fp=cache_name)

print(datasets['train'][0])
print('vocab info:')
for k,v in vocabs.items():
    print('{}:{}'.format(k,len(v)))

for k,v in datasets.items():
    if args.model == 'lattice':
        v.set_ignore_type('skips_l2r_word','skips_l2r_source','skips_r2l_word', 'skips_r2l_source')
        if args.skip_before_head:
            v.set_padder('skips_l2r_word',LatticeLexiconPadder())
            v.set_padder('skips_l2r_source',LatticeLexiconPadder())
            v.set_padder('skips_r2l_word',LatticeLexiconPadder())
            v.set_padder('skips_r2l_source',LatticeLexiconPadder(pad_val_dynamic=True))
        else:
            v.set_padder('skips_l2r_word',LatticeLexiconPadder())
            v.set_padder('skips_r2l_word', LatticeLexiconPadder())
            v.set_padder('skips_l2r_source', LatticeLexiconPadder(-1))
            v.set_padder('skips_r2l_source', LatticeLexiconPadder(pad_val_dynamic=True,dynamic_offset=1))
        if args.bi:
            v.set_input('chars','bigrams','seq_len',
                        'skips_l2r_word','skips_l2r_source','lexicon_count',
                        'skips_r2l_word', 'skips_r2l_source','lexicon_count_back',
                        'target',
                        use_1st_ins_infer_dim_type=True)
        else:
            v.set_input('chars','bigrams','seq_len',
                        'skips_l2r_word','skips_l2r_source','lexicon_count',
                        'target',
                        use_1st_ins_infer_dim_type=True)
        v.set_target('target','seq_len')

        v['target'].set_pad_val(0)
    elif args.model == 'lstm':
        v.set_ignore_type('skips_l2r_word','skips_l2r_source')
        v.set_padder('skips_l2r_word',LatticeLexiconPadder())
        v.set_padder('skips_l2r_source',LatticeLexiconPadder())
        v.set_input('chars','bigrams','seq_len','target',
                    use_1st_ins_infer_dim_type=True)
        v.set_target('target','seq_len')

        v['target'].set_pad_val(0)

print(datasets['dev']['skips_l2r_word'][100])


if args.model =='lattice':
    model = LatticeLSTM_SeqLabel_V1(embeddings['char'],embeddings['bigram'],embeddings['word'],
                        hidden_size=args.hidden,label_size=len(vocabs['label']),device=args.device,
                                 embed_dropout=args.embed_dropout,output_dropout=args.output_dropout,
                                 skip_batch_first=True,bidirectional=args.bi,debug=args.debug,
                                 skip_before_head=args.skip_before_head,use_bigram=args.use_bigram,
                                    gaz_dropout=args.gaz_dropout
                                 )
elif args.model == 'lstm':
    model = LSTM_SeqLabel(embeddings['char'],embeddings['bigram'],embeddings['word'],
                        hidden_size=args.hidden,label_size=len(vocabs['label']),device=args.device,
                          bidirectional=args.bi,
                          embed_dropout=args.embed_dropout,output_dropout=args.output_dropout,
                          use_bigram=args.use_bigram)


loss = LossInForward()
encoding_type = 'bmeso'
if args.dataset == 'weibo':
    encoding_type = 'bio'
f1_metric = SpanFPreRecMetric(vocabs['label'],pred='pred',target='target',seq_len='seq_len',encoding_type=encoding_type)
acc_metric = AccuracyMetric(pred='pred',target='target',seq_len='seq_len')
metrics = [f1_metric,acc_metric]

if args.optim == 'adam':
    optimizer = optim.Adam(model.parameters(),lr=args.lr)
elif args.optim == 'sgd':
    optimizer = optim.SGD(model.parameters(),lr=args.lr,momentum=args.momentum)




callbacks = [
    FitlogCallback({'test':datasets['test'],'train':datasets['train']}),
    # Exponential decay: the learning rate at epoch ep is args.lr / 1.03**ep.
    LRScheduler(lr_scheduler=LambdaLR(optimizer, lambda ep: 1 / (1 + 0.03)**ep))
]
print('label_vocab:{}\n{}'.format(len(vocabs['label']),vocabs['label'].idx2word))
trainer = Trainer(datasets['train'],model,
                  optimizer=optimizer,
                  loss=loss,
                  metrics=metrics,
                  dev_data=datasets['dev'],
                  device=device,
                  batch_size=args.batch,
                  n_epochs=args.epoch,
                  dev_batch_size=args.test_batch,
                  callbacks=callbacks)

trainer.train()

================================================
FILE: main_without_fitlog.py
================================================
import torch.nn as nn
# from pathes import *
from load_data import load_ontonotes4ner,equip_chinese_ner_with_skip,load_yangjie_rich_pretrain_word_list,\
    load_resume_ner,load_weibo_ner,load_weibo_ner_old
from fastNLP.embeddings import StaticEmbedding
from models import LatticeLSTM_SeqLabel,LSTM_SeqLabel,LatticeLSTM_SeqLabel_V1
from fastNLP import CrossEntropyLoss,SpanFPreRecMetric,Trainer,AccuracyMetric,LossInForward
import torch.optim as optim
import argparse
import torch
import sys
from utils_ import LatticeLexiconPadder,SpanFPreRecMetric_YJ
from fastNLP import Tester
# import fitlog
# from fastNLP.core.callback import FitlogCallback
from utils import set_seed
import os
from fastNLP import LRScheduler
from torch.optim.lr_scheduler import LambdaLR

def str2bool(value):
    # argparse passes command-line values as strings, and any non-empty string
    # is truthy, so boolean options need explicit parsing.
    if isinstance(value,bool):
        return value
    return value.lower() in ('true','1','yes','y')

parser = argparse.ArgumentParser()
parser.add_argument('--device',default='cuda:1')
parser.add_argument('--debug',default=False,type=str2bool)

parser.add_argument('--norm_embed',default=False,type=str2bool)
parser.add_argument('--batch',default=1,type=int)
parser.add_argument('--test_batch',default=1024,type=int)
parser.add_argument('--optim',default='sgd',help='adam|sgd')
parser.add_argument('--lr',default=0.045,type=float)
parser.add_argument('--model',default='lattice',help='lattice|lstm')
parser.add_argument('--skip_before_head',default=False,type=str2bool)#in paper it's false
parser.add_argument('--hidden',default=113,type=int)
parser.add_argument('--momentum',default=0,type=float)
parser.add_argument('--bi',default=True,type=str2bool)
parser.add_argument('--dataset',default='weibo',help='resume|ontonote|weibo|weibo_old')
parser.add_argument('--use_bigram',default=True,type=str2bool)

parser.add_argument('--embed_dropout',default=0.5,type=float)
parser.add_argument('--gaz_dropout',default=-1,type=float)
parser.add_argument('--output_dropout',default=0.5,type=float)
parser.add_argument('--epoch',default=100,type=int)
parser.add_argument('--seed',default=100,type=int)

args = parser.parse_args()

set_seed(args.seed)

fit_msg_list = [args.model,'bi' if args.bi else 'uni',str(args.batch)]
if args.model == 'lattice':
    fit_msg_list.append(str(args.skip_before_head))
fit_msg = ' '.join(fit_msg_list)
# fitlog.commit(__file__,fit_msg=fit_msg)

device = torch.device(args.device)
for k,v in args.__dict__.items():
    print(k,v)

refresh_data = False


from pathes import *
# ontonote4ner_cn_path = 0
# yangjie_rich_pretrain_unigram_path = 0
# yangjie_rich_pretrain_bigram_path = 0
# resume_ner_path = 0
# weibo_ner_path = 0

if args.dataset == 'ontonote':
    datasets,vocabs,embeddings = load_ontonotes4ner(ontonote4ner_cn_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path,
                                                    _refresh=refresh_data,index_token=False,
                                                    )
elif args.dataset == 'resume':
    datasets,vocabs,embeddings = load_resume_ner(resume_ner_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path,
                                                    _refresh=refresh_data,index_token=False,
                                                    )
elif args.dataset == 'weibo':
    datasets,vocabs,embeddings = load_weibo_ner(weibo_ner_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path,
                                                    _refresh=refresh_data,index_token=False,
                                                    )

elif args.dataset == 'weibo_old':
    datasets,vocabs,embeddings = load_weibo_ner_old(weibo_ner_old_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path,
                                                    _refresh=refresh_data,index_token=False,
                                                    )
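# Dataset-specific hyperparameters. Note that these overwrite any values given on the command line.
if args.dataset == 'ontonote':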
    args.batch = 10
    args.lr = 0.045
elif args.dataset == 'resume':
    args.batch = 1
    args.lr = 0.015
elif args.dataset == 'weibo':
    args.batch = 10
    args.gaz_dropout = 0.1
    args.embed_dropout = 0.1
    args.output_dropout = 0.1
elif args.dataset == 'weibo_old':
    args.embed_dropout = 0.1
    args.output_dropout = 0.1

if args.gaz_dropout < 0:
    args.gaz_dropout = args.embed_dropout

# fitlog.add_hyper(args)
w_list = load_yangjie_rich_pretrain_word_list(yangjie_rich_pretrain_word_path,
                                              _refresh=refresh_data)

cache_name = os.path.join('cache',args.dataset+'_lattice')
datasets,vocabs,embeddings = equip_chinese_ner_with_skip(datasets,vocabs,embeddings,w_list,yangjie_rich_pretrain_word_path,
                                                         _refresh=refresh_data,_cache_fp=cache_name)

print(datasets['train'][0])
print('vocab info:')
for k,v in vocabs.items():
    print('{}:{}'.format(k,len(v)))

for k,v in datasets.items():
    if args.model == 'lattice':
        v.set_ignore_type('skips_l2r_word','skips_l2r_source','skips_r2l_word', 'skips_r2l_source')
        if args.skip_before_head:
            v.set_padder('skips_l2r_word',LatticeLexiconPadder())
            v.set_padder('skips_l2r_source',LatticeLexiconPadder())
            v.set_padder('skips_r2l_word',LatticeLexiconPadder())
            v.set_padder('skips_r2l_source',LatticeLexiconPadder(pad_val_dynamic=True))
        else:
            v.set_padder('skips_l2r_word',LatticeLexiconPadder())
            v.set_padder('skips_r2l_word', LatticeLexiconPadder())
            v.set_padder('skips_l2r_source', LatticeLexiconPadder(-1))
            v.set_padder('skips_r2l_source', LatticeLexiconPadder(pad_val_dynamic=True,dynamic_offset=1))
        if args.bi:
            v.set_input('chars','bigrams','seq_len',
                        'skips_l2r_word','skips_l2r_source','lexicon_count',
                        'skips_r2l_word', 'skips_r2l_source','lexicon_count_back',
                        'target',
                        use_1st_ins_infer_dim_type=True)
        else:
            v.set_input('chars','bigrams','seq_len',
                        'skips_l2r_word','skips_l2r_source','lexicon_count',
                        'target',
                        use_1st_ins_infer_dim_type=True)
        v.set_target('target','seq_len')

        v['target'].set_pad_val(0)
    elif args.model == 'lstm':
        v.set_ignore_type('skips_l2r_word','skips_l2r_source')
        v.set_padder('skips_l2r_word',LatticeLexiconPadder())
        v.set_padder('skips_l2r_source',LatticeLexiconPadder())
        v.set_input('chars','bigrams','seq_len','target',
                    use_1st_ins_infer_dim_type=True)
        v.set_target('target','seq_len')

        v['target'].set_pad_val(0)

print(datasets['dev']['skips_l2r_word'][100])


if args.model =='lattice':
    model = LatticeLSTM_SeqLabel_V1(embeddings['char'],embeddings['bigram'],embeddings['word'],
                        hidden_size=args.hidden,label_size=len(vocabs['label']),device=args.device,
                                 embed_dropout=args.embed_dropout,output_dropout=args.output_dropout,
                                 skip_batch_first=True,bidirectional=args.bi,debug=args.debug,
                                 skip_before_head=args.skip_before_head,use_bigram=args.use_bigram,
                                    gaz_dropout=args.gaz_dropout
                                 )
elif args.model == 'lstm':
    model = LSTM_SeqLabel(embeddings['char'],embeddings['bigram'],embeddings['word'],
                        hidden_size=args.hidden,label_size=len(vocabs['label']),device=args.device,
                          bidirectional=args.bi,
                          embed_dropout=args.embed_dropout,output_dropout=args.output_dropout,
                          use_bigram=args.use_bigram)


loss = LossInForward()
encoding_type = 'bmeso'
if args.dataset == 'weibo':
    encoding_type = 'bio'
f1_metric = SpanFPreRecMetric(vocabs['label'],pred='pred',target='target',seq_len='seq_len',encoding_type=encoding_type)
acc_metric = AccuracyMetric(pred='pred',target='target',seq_len='seq_len')
metrics = [f1_metric,acc_metric]

if args.optim == 'adam':
    optimizer = optim.Adam(model.parameters(),lr=args.lr)
elif args.optim == 'sgd':
    optimizer = optim.SGD(model.parameters(),lr=args.lr,momentum=args.momentum)




callbacks = [
    # FitlogCallback({'test':datasets['test'],'train':datasets['train']}),
    # Exponential decay: the learning rate at epoch ep is args.lr / 1.03**ep.
    LRScheduler(lr_scheduler=LambdaLR(optimizer, lambda ep: 1 / (1 + 0.03)**ep))
]
print('label_vocab:{}\n{}'.format(len(vocabs['label']),vocabs['label'].idx2word))
trainer = Trainer(datasets['train'],model,
                  optimizer=optimizer,
                  loss=loss,
                  metrics=metrics,
                  dev_data=datasets['dev'],
                  device=device,
                  batch_size=args.batch,
                  n_epochs=args.epoch,
                  dev_batch_size=args.test_batch,
                  callbacks=callbacks)

trainer.train()

================================================
FILE: models.py
================================================
import torch.nn as nn
from fastNLP.embeddings import StaticEmbedding
from fastNLP.modules import LSTM, ConditionalRandomField
import torch
from fastNLP import seq_len_to_mask
from utils import better_init_rnn,print_info


class LatticeLSTM_SeqLabel(nn.Module):
    def __init__(self, char_embed, bigram_embed, word_embed, hidden_size, label_size, bias=True, bidirectional=False,
                 device=None, embed_dropout=0, output_dropout=0, skip_batch_first=True,debug=False,
                 skip_before_head=False,use_bigram=True,vocabs=None):
        # Initialise nn.Module before assigning any attributes.
        super().__init__()
        if device is None:
            self.device = torch.device('cpu')
        else:
            self.device = torch.device(device)
        from modules import LatticeLSTMLayer_sup_back_V0
        self.debug = debug
        self.skip_batch_first = skip_batch_first
        self.char_embed_size = char_embed.embedding.weight.size(1)
        self.bigram_embed_size = bigram_embed.embedding.weight.size(1)
        self.word_embed_size = word_embed.embedding.weight.size(1)
        self.hidden_size = hidden_size
        self.label_size = label_size
        self.bidirectional = bidirectional
        self.use_bigram = use_bigram
        self.vocabs = vocabs

        if self.use_bigram:
            self.input_size = self.char_embed_size + self.bigram_embed_size
        else:
            self.input_size = self.char_embed_size

        self.char_embed = char_embed
        self.bigram_embed = bigram_embed
        self.word_embed = word_embed
        self.encoder = LatticeLSTMLayer_sup_back_V0(self.input_size,self.word_embed_size,
                                                 self.hidden_size,
                                                 left2right=True,
                                                 bias=bias,
                                                 device=self.device,
                                                 debug=self.debug,
                                                 skip_before_head=skip_before_head)
        if self.bidirectional:
            self.encoder_back = LatticeLSTMLayer_sup_back_V0(self.input_size,
                                                          self.word_embed_size, self.hidden_size,
                                                          left2right=False,
                                                          bias=bias,
                                                          device=self.device,
                                                          debug=self.debug,
                                                          skip_before_head=skip_before_head)

        self.output = nn.Linear(self.hidden_size * (2 if self.bidirectional else 1), self.label_size)
        self.crf = ConditionalRandomField(label_size, True)

        self.crf.trans_m = nn.Parameter(torch.zeros(size=[label_size, label_size],requires_grad=True))
        if self.crf.include_start_end_trans:
            self.crf.start_scores = nn.Parameter(torch.zeros(size=[label_size],requires_grad=True))
            self.crf.end_scores = nn.Parameter(torch.zeros(size=[label_size],requires_grad=True))

        self.loss_func = nn.CrossEntropyLoss()
        self.embed_dropout = nn.Dropout(embed_dropout)
        self.output_dropout = nn.Dropout(output_dropout)

    def forward(self, chars, bigrams, seq_len, target,
                skips_l2r_source, skips_l2r_word, lexicon_count,
                skips_r2l_source=None, skips_r2l_word=None, lexicon_count_back=None):
        # print('skips_l2r_word_id:{}'.format(skips_l2r_word.size()))
        batch = chars.size(0)
        max_seq_len = chars.size(1)
        # max_lexicon_count = skips_l2r_word.size(2)


        embed_char = self.char_embed(chars)
        if self.use_bigram:

            embed_bigram = self.bigram_embed(bigrams)

            embedding = torch.cat([embed_char, embed_bigram], dim=-1)
        else:

            embedding = embed_char


        embed_nonword = self.embed_dropout(embedding)

        # skips_l2r_word = torch.reshape(skips_l2r_word,shape=[batch,-1])
        embed_word = self.word_embed(skips_l2r_word)
        embed_word = self.embed_dropout(embed_word)
        # embed_word = torch.reshape(embed_word,shape=[batch,max_seq_len,max_lexicon_count,-1])


        encoded_h, encoded_c = self.encoder(embed_nonword, seq_len, skips_l2r_source, embed_word, lexicon_count)

        if self.bidirectional:
            embed_word_back = self.word_embed(skips_r2l_word)
            embed_word_back = self.embed_dropout(embed_word_back)
            encoded_h_back, encoded_c_back = self.encoder_back(embed_nonword, seq_len, skips_r2l_source,
                                                               embed_word_back, lexicon_count_back)
            encoded_h = torch.cat([encoded_h, encoded_h_back], dim=-1)

        encoded_h = self.output_dropout(encoded_h)

        pred = self.output(encoded_h)

        mask = seq_len_to_mask(seq_len)

        if self.training:
            loss = self.crf(pred, target, mask)
            return {'loss': loss}
        else:
            pred, path = self.crf.viterbi_decode(pred, mask)
            return {'pred': pred}

        # batch_size, sent_len = pred.shape[0], pred.shape[1]
        # loss = self.loss_func(pred.reshape(batch_size * sent_len, -1), target.reshape(batch_size * sent_len))
        # return {'pred':pred,'loss':loss}

class LatticeLSTM_SeqLabel_V1(nn.Module):
    def __init__(self, char_embed, bigram_embed, word_embed, hidden_size, label_size, bias=True, bidirectional=False,
                 device=None, embed_dropout=0, output_dropout=0, skip_batch_first=True,debug=False,
                 skip_before_head=False,use_bigram=True,vocabs=None,gaz_dropout=0):
        # Initialise nn.Module before assigning any attributes.
        super().__init__()
        if device is None:
            self.device = torch.device('cpu')
        else:
            self.device = torch.device(device)
        from modules import LatticeLSTMLayer_sup_back_V1
        self.count = 0
        self.debug = debug
        self.skip_batch_first = skip_batch_first
        self.char_embed_size = char_embed.embedding.weight.size(1)
        self.bigram_embed_size = bigram_embed.embedding.weight.size(1)
        self.word_embed_size = word_embed.embedding.weight.size(1)
        self.hidden_size = hidden_size
        self.label_size = label_size
        self.bidirectional = bidirectional
        self.use_bigram = use_bigram
        self.vocabs = vocabs

        if self.use_bigram:
            self.input_size = self.char_embed_size + self.bigram_embed_size
        else:
            self.input_size = self.char_embed_size

        self.char_embed = char_embed
        self.bigram_embed = bigram_embed
        self.word_embed = word_embed
        self.encoder = LatticeLSTMLayer_sup_back_V1(self.input_size,self.word_embed_size,
                                                 self.hidden_size,
                                                 left2right=True,
                                                 bias=bias,
                                                 device=self.device,
                                                 debug=self.debug,
                                                 skip_before_head=skip_before_head)
        if self.bidirectional:
            self.encoder_back = LatticeLSTMLayer_sup_back_V1(self.input_size,
                                                          self.word_embed_size, self.hidden_size,
                                                          left2right=False,
                                                          bias=bias,
                                                          device=self.device,
                                                          debug=self.debug,
                                                          skip_before_head=skip_before_head)

        self.output = nn.Linear(self.hidden_size * (2 if self.bidirectional else 1), self.label_size)
        self.crf = ConditionalRandomField(label_size, True)

        self.crf.trans_m = nn.Parameter(torch.zeros(size=[label_size, label_size],requires_grad=True))
        if self.crf.include_start_end_trans:
            self.crf.start_scores = nn.Parameter(torch.zeros(size=[label_size],requires_grad=True))
            self.crf.end_scores = nn.Parameter(torch.zeros(size=[label_size],requires_grad=True))

        self.loss_func = nn.CrossEntropyLoss()
        self.embed_dropout = nn.Dropout(embed_dropout)
        self.gaz_dropout = nn.Dropout(gaz_dropout)
        self.output_dropout = nn.Dropout(output_dropout)

    def forward(self, chars, bigrams, seq_len, target,
                skips_l2r_source, skips_l2r_word, lexicon_count,
                skips_r2l_source=None, skips_r2l_word=None, lexicon_count_back=None):

        batch = chars.size(0)
        max_seq_len = chars.size(1)



        embed_char = self.char_embed(chars)
        if self.use_bigram:

            embed_bigram = self.bigram_embed(bigrams)

            embedding = torch.cat([embed_char, embed_bigram], dim=-1)
        else:

            embedding = embed_char


        embed_nonword = self.embed_dropout(embedding)

        # skips_l2r_word = torch.reshape(skips_l2r_word,shape=[batch,-1])
        embed_word = self.word_embed(skips_l2r_word)
        # Lexicon (gazetteer) word embeddings use their own dropout rate, gaz_dropout.
        embed_word = self.gaz_dropout(embed_word)

        encoded_h, encoded_c = self.encoder(embed_nonword, seq_len, skips_l2r_source, embed_word, lexicon_count)

        if self.bidirectional:
            embed_word_back = self.word_embed(skips_r2l_word)
            embed_word_back = self.gaz_dropout(embed_word_back)
            encoded_h_back, encoded_c_back = self.encoder_back(embed_nonword, seq_len, skips_r2l_source,
                                                               embed_word_back, lexicon_count_back)
            encoded_h = torch.cat([encoded_h, encoded_h_back], dim=-1)

        encoded_h = self.output_dropout(encoded_h)

        pred = self.output(encoded_h)

        mask = seq_len_to_mask(seq_len)

        if self.training:
            loss = self.crf(pred, target, mask)
            return {'loss': loss}
        else:
            pred, path = self.crf.viterbi_decode(pred, mask)
            return {'pred': pred}


class LSTM_SeqLabel(nn.Module):
    def __init__(self, char_embed, bigram_embed, word_embed, hidden_size, label_size, bias=True,
                 bidirectional=False, device=None, embed_dropout=0, output_dropout=0,use_bigram=True):

        # Initialise nn.Module before assigning any attributes.
        super().__init__()
        if device is None:
            self.device = torch.device('cpu')
        else:
            self.device = torch.device(device)
        self.char_embed_size = char_embed.embedding.weight.size(1)
        self.bigram_embed_size = bigram_embed.embedding.weight.size(1)
        self.word_embed_size = word_embed.embedding.weight.size(1)
        self.hidden_size = hidden_size
        self.label_size = label_size
        self.bidirectional = bidirectional
        self.use_bigram = use_bigram

        self.char_embed = char_embed
        self.bigram_embed = bigram_embed
        self.word_embed = word_embed

        if self.use_bigram:
            self.input_size = self.char_embed_size + self.bigram_embed_size
        else:
            self.input_size = self.char_embed_size

        self.encoder = LSTM(self.input_size, self.hidden_size,
                            bidirectional=self.bidirectional)

        better_init_rnn(self.encoder.lstm)


        self.output = nn.Linear(self.hidden_size * (2 if self.bidirectional else 1), self.label_size)

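        # Debug mode prints tensor shapes and exits after the first forward pass,
        # so it must be off for normal training.
        self.debug = False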
        self.loss_func = nn.CrossEntropyLoss()
        self.embed_dropout = nn.Dropout(embed_dropout)
        self.output_dropout = nn.Dropout(output_dropout)
        self.crf = ConditionalRandomField(label_size, True)

    def forward(self, chars, bigrams, seq_len, target):
        if self.debug:

            print_info('chars:{}'.format(chars.size()))
            print_info('bigrams:{}'.format(bigrams.size()))
            print_info('seq_len:{}'.format(seq_len.size()))
            print_info('target:{}'.format(target.size()))
        embed_char = self.char_embed(chars)

        if self.use_bigram:

            embed_bigram = self.bigram_embed(bigrams)

            embedding = torch.cat([embed_char, embed_bigram], dim=-1)
        else:

            embedding = embed_char

        embedding = self.embed_dropout(embedding)

        encoded_h, encoded_c = self.encoder(embedding, seq_len)

        encoded_h = self.output_dropout(encoded_h)

        pred = self.output(encoded_h)

        mask = seq_len_to_mask(seq_len)

        # pred = self.crf(pred)

        # batch_size, sent_len = pred.shape[0], pred.shape[1]
        # loss = self.loss_func(pred.reshape(batch_size * sent_len, -1), target.reshape(batch_size * sent_len))
        if self.debug:
            print('debug mode:finish')
            exit(1208)
        if self.training:
            loss = self.crf(pred, target, mask)
            return {'loss': loss}
        else:
            # viterbi_decode returns (paths, scores); only the decoded paths are needed.
            pred, _ = self.crf.viterbi_decode(pred, mask)
            return {'pred': pred}


================================================
FILE: modules.py
================================================
import torch.nn as nn
import torch
from fastNLP.core.utils import seq_len_to_mask
from utils import better_init_rnn
import numpy as np


class WordLSTMCell_yangjie(nn.Module):

    """A basic LSTM cell."""

    def __init__(self, input_size, hidden_size, use_bias=True,debug=False, left2right=True):
        """
        Most parts are copied from torch.nn.LSTMCell.
        """

        super().__init__()
        self.left2right = left2right
        self.debug = debug
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.use_bias = use_bias
        self.weight_ih = nn.Parameter(
            torch.FloatTensor(input_size, 3 * hidden_size))
        self.weight_hh = nn.Parameter(
            torch.FloatTensor(hidden_size, 3 * hidden_size))
        if use_bias:
            self.bias = nn.Parameter(torch.FloatTensor(3 * hidden_size))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """
        Initialize parameters following the way proposed in the paper.
        """
        nn.init.orthogonal_(self.weight_ih.data)
        weight_hh_data = torch.eye(self.hidden_size)
        weight_hh_data = weight_hh_data.repeat(1, 3)
        with torch.no_grad():
            self.weight_hh.set_(weight_hh_data)
        # The bias is just set to zero vectors.
        if self.use_bias:
            nn.init.constant_(self.bias.data, val=0)

    def forward(self, input_, hx):
        """
        Args:
            input_: A (batch, input_size) tensor containing input
                features.
            hx: A tuple (h_0, c_0), which contains the initial hidden
                and cell state, where the size of both states is
                (batch, hidden_size).
        Returns:
            c_1: Tensor containing the next cell state. This word cell
                computes no output gate, so no hidden state is returned.
        """

        h_0, c_0 = hx



        batch_size = h_0.size(0)
        bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size()))
        wh_b = torch.addmm(bias_batch, h_0, self.weight_hh)
        wi = torch.mm(input_, self.weight_ih)
        f, i, g = torch.split(wh_b + wi, split_size_or_sections=self.hidden_size, dim=1)
        c_1 = torch.sigmoid(f)*c_0 + torch.sigmoid(i)*torch.tanh(g)

        return c_1

    def __repr__(self):
        s = '{name}({input_size}, {hidden_size})'
        return s.format(name=self.__class__.__name__, **self.__dict__)


class MultiInputLSTMCell_V0(nn.Module):
    def __init__(self, char_input_size, hidden_size, use_bias=True,debug=False):
        super().__init__()
        self.char_input_size = char_input_size
        self.hidden_size = hidden_size
        self.use_bias = use_bias

        self.weight_ih = nn.Parameter(
            torch.FloatTensor(char_input_size, 3 * hidden_size)
        )

        self.weight_hh = nn.Parameter(
            torch.FloatTensor(hidden_size, 3 * hidden_size)
        )

        self.alpha_weight_ih = nn.Parameter(
            torch.FloatTensor(char_input_size, hidden_size)
        )

        self.alpha_weight_hh = nn.Parameter(
            torch.FloatTensor(hidden_size, hidden_size)
        )

        if self.use_bias:
            self.bias = nn.Parameter(torch.FloatTensor(3 * hidden_size))
            self.alpha_bias = nn.Parameter(torch.FloatTensor(hidden_size))
        else:
            self.register_parameter('bias', None)
            self.register_parameter('alpha_bias', None)

        self.debug = debug
        self.reset_parameters()

    def reset_parameters(self):
        """
        Initialize parameters following the way proposed in the paper.
        """
        nn.init.orthogonal_(self.weight_ih.data)
        nn.init.orthogonal_(self.alpha_weight_ih.data)

        weight_hh_data = torch.eye(self.hidden_size)
        weight_hh_data = weight_hh_data.repeat(1, 3)
        with torch.no_grad():
            self.weight_hh.set_(weight_hh_data)

        alpha_weight_hh_data = torch.eye(self.hidden_size)
        alpha_weight_hh_data = alpha_weight_hh_data.repeat(1, 1)
        with torch.no_grad():
            self.alpha_weight_hh.set_(alpha_weight_hh_data)

        # The bias is just set to zero vectors.
        if self.use_bias:
            nn.init.constant_(self.bias.data, val=0)
            nn.init.constant_(self.alpha_bias.data, val=0)

    def forward(self, inp, skip_c, skip_count, hx):
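        '''

        :param inp: chars B * hidden
        :param skip_c: cell states obtained from the skip edges, B * X * hidden
        :param skip_count: number of skip edges at the current position for each example in the batch, used for masking
        :param hx: tuple (h_0, c_0), each B * hidden
        :return: tuple (h_1, c_1), each B * hidden
        '''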
        max_skip_count = torch.max(skip_count).item()



        if True:
            h_0, c_0 = hx
            batch_size = h_0.size(0)

            bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size()))

            wi = torch.matmul(inp, self.weight_ih)
            wh = torch.matmul(h_0, self.weight_hh)



            i, o, g = torch.split(wh + wi + bias_batch, split_size_or_sections=self.hidden_size, dim=1)

            i = torch.sigmoid(i).unsqueeze(1)
            o = torch.sigmoid(o).unsqueeze(1)
            g = torch.tanh(g).unsqueeze(1)



            alpha_wi = torch.matmul(inp, self.alpha_weight_ih)
            alpha_wi.unsqueeze_(1)

            # alpha_wi = alpha_wi.expand(1,skip_count,self.hidden_size)
            alpha_wh = torch.matmul(skip_c, self.alpha_weight_hh)

            alpha_bias_batch = self.alpha_bias.unsqueeze(0)

            alpha = torch.sigmoid(alpha_wi + alpha_wh + alpha_bias_batch)

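            # Cast to float (as in MultiInputLSTMCell_V1) so that 1 - skip_mask also works for bool masks.
            skip_mask = seq_len_to_mask(skip_count,max_len=skip_c.size()[1]).float()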

            skip_mask = 1 - skip_mask


            skip_mask = skip_mask.unsqueeze(-1).expand(*skip_mask.size(), self.hidden_size)

            skip_mask = (skip_mask).float()*1e20

            alpha = alpha - skip_mask

            alpha = torch.exp(torch.cat([i, alpha], dim=1))



            alpha_sum = torch.sum(alpha, dim=1, keepdim=True)

            alpha = torch.div(alpha, alpha_sum)

            merge_i_c = torch.cat([g, skip_c], dim=1)

            c_1 = merge_i_c * alpha

            c_1 = c_1.sum(1, keepdim=True)
            # h_1 = o * c_1
            h_1 = o * torch.tanh(c_1)

            return h_1.squeeze(1), c_1.squeeze(1)

        else:

            h_0, c_0 = hx
            batch_size = h_0.size(0)

            bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size()))

            wi = torch.matmul(inp, self.weight_ih)
            wh = torch.matmul(h_0, self.weight_hh)

            i, o, g = torch.split(wh + wi + bias_batch, split_size_or_sections=self.hidden_size, dim=1)

            i = torch.sigmoid(i).unsqueeze(1)
            o = torch.sigmoid(o).unsqueeze(1)
            g = torch.tanh(g).unsqueeze(1)

            c_1 = g
            h_1 = o * c_1

            return h_1,c_1

class MultiInputLSTMCell_V1(nn.Module):
    def __init__(self, char_input_size, hidden_size, use_bias=True,debug=False):
        super().__init__()
        self.char_input_size = char_input_size
        self.hidden_size = hidden_size
        self.use_bias = use_bias

        self.weight_ih = nn.Parameter(
            torch.FloatTensor(char_input_size, 3 * hidden_size)
        )

        self.weight_hh = nn.Parameter(
            torch.FloatTensor(hidden_size, 3 * hidden_size)
        )

        self.alpha_weight_ih = nn.Parameter(
            torch.FloatTensor(char_input_size, hidden_size)
        )

        self.alpha_weight_hh = nn.Parameter(
            torch.FloatTensor(hidden_size, hidden_size)
        )

        if self.use_bias:
            self.bias = nn.Parameter(torch.FloatTensor(3 * hidden_size))
            self.alpha_bias = nn.Parameter(torch.FloatTensor(hidden_size))
        else:
            self.register_parameter('bias', None)
            self.register_parameter('alpha_bias', None)

        self.debug = debug
        self.reset_parameters()

    def reset_parameters(self):
        """
        Initialize parameters following the way proposed in the paper.
        """
        nn.init.orthogonal_(self.weight_ih.data)
        nn.init.orthogonal_(self.alpha_weight_ih.data)

        weight_hh_data = torch.eye(self.hidden_size)
        weight_hh_data = weight_hh_data.repeat(1, 3)
        with torch.no_grad():
            self.weight_hh.set_(weight_hh_data)

        alpha_weight_hh_data = torch.eye(self.hidden_size)
        alpha_weight_hh_data = alpha_weight_hh_data.repeat(1, 1)
        with torch.no_grad():
            self.alpha_weight_hh.set_(alpha_weight_hh_data)

        # The bias is just set to zero vectors.
        if self.use_bias:
            nn.init.constant_(self.bias.data, val=0)
            nn.init.constant_(self.alpha_bias.data, val=0)

    def forward(self, inp, skip_c, skip_count, hx):
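        '''

        :param inp: chars B * hidden
        :param skip_c: cell states obtained from the skip edges, B * X * hidden
        :param skip_count: number of skip edges at the current position for each example in the batch, used for masking
        :param hx: tuple (h_0, c_0), each B * hidden
        :return: tuple (h_1, c_1), each B * hidden
        '''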
        max_skip_count = torch.max(skip_count).item()



        if True:
            h_0, c_0 = hx
            batch_size = h_0.size(0)

            bias_batch = (self.bias.unsqueeze(0).expand(batch_size, *self.bias.size()))

            wi = torch.matmul(inp, self.weight_ih)
            wh = torch.matmul(h_0, self.weight_hh)


            i, o, g = torch.split(wh + wi + bias_batch, split_size_or_sections=self.hidden_size, dim=1)

            i = torch.sigmoid(i).unsqueeze(1)
            o = torch.sigmoid(o).unsqueeze(1)
            g = torch.tanh(g).unsqueeze(1)



            ##basic lstm start

            f = 1 - i
            c_1_basic = f*c_0.unsqueeze(1) + i*g
            c_1_basic = c_1_basic.squeeze(1)





            alpha_wi = torch.matmul(inp, self.alpha_weight_ih)
            alpha_wi.unsqueeze_(1)


            alpha_wh = torch.matmul(skip_c, self.alpha_weight_hh)

            alpha_bias_batch = self.alpha_bias.unsqueeze(0)

            alpha = torch.sigmoid(alpha_wi + alpha_wh + alpha_bias_batch)

            skip_mask = seq_len_to_mask(skip_count,max_len=skip_c.size()[1]).float()

            skip_mask = 1 - skip_mask


            skip_mask = skip_mask.unsqueeze(-1).expand(*skip_mask.size(), self.hidden_size)

            skip_mask = (skip_mask).float()*1e20

            alpha = alpha - skip_mask

            alpha = torch.exp(torch.cat([i, alpha], dim=1))



            alpha_sum = torch.sum(alpha, dim=1, keepdim=True)

            alpha = torch.div(alpha, alpha_sum)

            merge_i_c = torch.cat([g, skip_c], dim=1)

            c_1 = merge_i_c * alpha

            c_1 = c_1.sum(1, keepdim=True)
            # h_1 = o * c_1
            c_1 = c_1.squeeze(1)
            count_select = (skip_count != 0).float().unsqueeze(-1)




            c_1 = c_1*count_select + c_1_basic*(1-count_select)


            o = o.squeeze(1)
            h_1 = o * torch.tanh(c_1)

            return h_1, c_1
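

# Illustrative sketch only; nothing in this module calls it. It mirrors, for a
# single hidden unit and in plain Python, how the cells above normalise the
# input gate together with the skip-edge gates: padded skip slots get a huge
# negative offset, so exp() sends them to 0 before the weights are divided by
# their sum. The name _masked_gate_weights_sketch and its arguments are
# hypothetical and not part of the original code.
import math


def _masked_gate_weights_sketch(input_gate, skip_gates, skip_count):
    """Return normalised weights for [input gate] + skip gates.

    input_gate: float, already passed through a sigmoid.
    skip_gates: list of floats, already passed through a sigmoid.
    skip_count: number of leading entries in skip_gates that are real edges.
    """
    masked = [input_gate]
    for k, gate in enumerate(skip_gates):
        # Slots at or beyond skip_count are padding and must not contribute.
        masked.append(gate if k < skip_count else gate - 1e20)
    exps = [math.exp(score) for score in masked]
    total = sum(exps)
    return [value / total for value in exps]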

class LatticeLSTMLayer_sup_back_V0(nn.Module):
    def __init__(self, char_input_size, word_input_size, hidden_size, left2right,
                 bias=True,device=None,debug=False,skip_before_head=False):
        super().__init__()

        self.skip_before_head = skip_before_head

        self.hidden_size = hidden_size

        self.char_cell = MultiInputLSTMCell_V0(char_input_size, hidden_size, bias,debug)

        # Use the constructor argument here: self.debug is only assigned further down.
        self.word_cell = WordLSTMCell_yangjie(word_input_size,hidden_size,bias,debug=debug)

        self.word_input_size = word_input_size
        self.left2right = left2right
        self.bias = bias
        self.device = device
        self.debug = debug

    def forward(self, inp, seq_len, skip_sources, skip_words, skip_count, init_state=None):
        '''

        :param inp: batch * seq_len * embedding, chars
        :param seq_len: batch, length of chars
        :param skip_sources: batch * seq_len * X, start positions of the skip edges
        :param skip_words: batch * seq_len * X * embedding, words on the skip edges
        :param skip_count: batch * seq_len, count of lexicon per example per position
        :param init_state: the hx of rnn (currently unused)
        :return: tuple (h, c), each batch * max_seq_len * hidden
        '''


        if self.left2right:

            max_seq_len = max(seq_len)
            batch_size = inp.size(0)
            c_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)
            h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)

            for i in range(max_seq_len):
                max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1)
                h_0, c_0 = h_[:, i, :], c_[:, i, :]

                skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous()

                skip_word_flat = skip_word_flat.view(batch_size*max_lexicon_count,self.word_input_size)
                skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count)


                index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size,max_lexicon_count)
                index_1 = skip_source_flat

                if not self.skip_before_head:
                    c_x = c_[[index_0, index_1+1]]
                    h_x = h_[[index_0, index_1+1]]
                else:
                    c_x = c_[[index_0,index_1]]
                    h_x = h_[[index_0,index_1]]

                c_x_flat = c_x.view(batch_size*max_lexicon_count,self.hidden_size)
                h_x_flat = h_x.view(batch_size*max_lexicon_count,self.hidden_size)




                c_1_flat = self.word_cell(skip_word_flat,(h_x_flat,c_x_flat))

                c_1_skip = c_1_flat.view(batch_size,max_lexicon_count,self.hidden_size)

                h_1,c_1 = self.char_cell(inp[:,i,:],c_1_skip,skip_count[:,i],(h_0,c_0))


                h_ = torch.cat([h_,h_1.unsqueeze(1)],dim=1)
                c_ = torch.cat([c_, c_1.unsqueeze(1)], dim=1)

            return h_[:,1:],c_[:,1:]
        else:
            mask_for_seq_len = seq_len_to_mask(seq_len)

            max_seq_len = max(seq_len)
            batch_size = inp.size(0)
            c_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)
            h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)

            for i in reversed(range(max_seq_len)):
                max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1)



                h_0, c_0 = h_[:, 0, :], c_[:, 0, :]

                skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous()

                skip_word_flat = skip_word_flat.view(batch_size*max_lexicon_count,self.word_input_size)
                skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count)


                index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size,max_lexicon_count)
                index_1 = skip_source_flat-i

                if not self.skip_before_head:
                    c_x = c_[[index_0, index_1-1]]
                    h_x = h_[[index_0, index_1-1]]
                else:
                    c_x = c_[[index_0,index_1]]
                    h_x = h_[[index_0,index_1]]

                c_x_flat = c_x.view(batch_size*max_lexicon_count,self.hidden_size)
                h_x_flat = h_x.view(batch_size*max_lexicon_count,self.hidden_size)




                c_1_flat = self.word_cell(skip_word_flat,(h_x_flat,c_x_flat))

                c_1_skip = c_1_flat.view(batch_size,max_lexicon_count,self.hidden_size)

                h_1,c_1 = self.char_cell(inp[:,i,:],c_1_skip,skip_count[:,i],(h_0,c_0))


                # Build the padding mask with == 0 so it works for both bool and uint8 masks.
                pad_mask = (mask_for_seq_len[:, i] == 0).unsqueeze(-1)
                h_1_mask = h_1.masked_fill(pad_mask, 0)
                c_1_mask = c_1.masked_fill(pad_mask, 0)


                h_ = torch.cat([h_1_mask.unsqueeze(1),h_],dim=1)
                c_ = torch.cat([c_1_mask.unsqueeze(1),c_], dim=1)

            return h_[:,:-1],c_[:,:-1]

class LatticeLSTMLayer_sup_back_V1(nn.Module):
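    # V1 differs from V0 in that, when no lexicon word matches at the current position,
    # V1 falls back to the ordinary LSTM update. The ordinary LSTM update differs from
    # what Yang Jie's Lattice LSTM implementation computes when the lexicon count is 0.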
    def __init__(self, char_input_size, word_input_size, hidden_size, left2right,
                 bias=True,device=None,debug=False,skip_before_head=False):
        super().__init__()

        self.debug = debug

        self.skip_before_head = skip_before_head

        self.hidden_size = hidden_size

        self.char_cell = MultiInputLSTMCell_V1(char_input_size, hidden_size, bias,debug)

        self.word_cell = WordLSTMCell_yangjie(word_input_size,hidden_size,bias,debug=self.debug)

        self.word_input_size = word_input_size
        self.left2right = left2right
        self.bias = bias
        self.device = device

    def forward(self, inp, seq_len, skip_sources, skip_words, skip_count, init_state=None):
        '''

        :param inp: batch * seq_len * embedding, chars
        :param seq_len: batch, length of chars
        :param skip_sources: batch * seq_len * X, start positions of the skip edges
        :param skip_words: batch * seq_len * X * embedding_size, words on the skip edges
        :param skip_count: batch * seq_len,
        skip_count[i,j] is the number of matched words ending at position j of example i
        :param init_state: the hx of rnn (currently unused)
        :return: tuple (h, c), each batch * max_seq_len * hidden
        '''


        if self.left2right:

            max_seq_len = max(seq_len)
            batch_size = inp.size(0)
            c_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)
            h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)

            for i in range(max_seq_len):
                max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1)
                h_0, c_0 = h_[:, i, :], c_[:, i, :]

                # To let the RNN cell process a B * lexicon_count * embedding_size tensor, reshape it into a 2-D tensor.
                # To match PyTorch's [] indexing, the skip sources are reshaped into a 2-D tensor as well.

                skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous()

                skip_word_flat = skip_word_flat.view(batch_size*max_lexicon_count,self.word_input_size)
                skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count)


                index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size,max_lexicon_count)
                index_1 = skip_source_flat


                if not self.skip_before_head:
                    c_x = c_[[index_0, index_1+1]]
                    h_x = h_[[index_0, index_1+1]]
                else:
                    c_x = c_[[index_0,index_1]]
                    h_x = h_[[index_0,index_1]]

                c_x_flat = c_x.view(batch_size*max_lexicon_count,self.hidden_size)
                h_x_flat = h_x.view(batch_size*max_lexicon_count,self.hidden_size)



                c_1_flat = self.word_cell(skip_word_flat,(h_x_flat,c_x_flat))

                c_1_skip = c_1_flat.view(batch_size,max_lexicon_count,self.hidden_size)

                h_1,c_1 = self.char_cell(inp[:,i,:],c_1_skip,skip_count[:,i],(h_0,c_0))


                h_ = torch.cat([h_,h_1.unsqueeze(1)],dim=1)
                c_ = torch.cat([c_, c_1.unsqueeze(1)], dim=1)

            return h_[:,1:],c_[:,1:]
        else:
            mask_for_seq_len = seq_len_to_mask(seq_len)

            max_seq_len = max(seq_len)
            batch_size = inp.size(0)
            c_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)
            h_ = torch.zeros(size=[batch_size, 1, self.hidden_size], requires_grad=True).to(self.device)

            for i in reversed(range(max_seq_len)):
                max_lexicon_count = max(torch.max(skip_count[:, i]).item(), 1)


                h_0, c_0 = h_[:, 0, :], c_[:, 0, :]

                skip_word_flat = skip_words[:, i, :max_lexicon_count].contiguous()

                skip_word_flat = skip_word_flat.view(batch_size*max_lexicon_count,self.word_input_size)
                skip_source_flat = skip_sources[:, i, :max_lexicon_count].contiguous().view(batch_size, max_lexicon_count)


                index_0 = torch.tensor(range(batch_size)).unsqueeze(1).expand(batch_size,max_lexicon_count)
                index_1 = skip_source_flat-i

                if not self.skip_before_head:
                    c_x = c_[[index_0, index_1-1]]
                    h_x = h_[[index_0, index_1-1]]
                else:
                    c_x = c_[[index_0,index_1]]
                    h_x = h_[[index_0,index_1]]

                c_x_flat = c_x.view(batch_size*max_lexicon_count,self.hidden_size)
                h_x_flat = h_x.view(batch_size*max_lexicon_count,self.hidden_size)




                c_1_flat = self.word_cell(skip_word_flat,(h_x_flat,c_x_flat))



                c_1_skip = c_1_flat.view(batch_size,max_lexicon_count,self.hidden_size)

                h_1,c_1 = self.char_cell(inp[:,i,:],c_1_skip,skip_count[:,i],(h_0,c_0))


                h_1_mask = h_1.masked_fill(~ mask_for_seq_len[:,i].unsqueeze(-1),0)
                c_1_mask = c_1.masked_fill(~ mask_for_seq_len[:, i].unsqueeze(-1), 0)


                h_ = torch.cat([h_1_mask.unsqueeze(1),h_],dim=1)
                c_ = torch.cat([c_1_mask.unsqueeze(1),c_], dim=1)



            return h_[:,:-1],c_[:,:-1]






================================================
FILE: pathes.py
================================================


glove_100_path = 'en-glove-6b-100d'
glove_50_path = 'en-glove-6b-50d'
glove_200_path = ''
glove_300_path = 'en-glove-840b-300'
fasttext_path = 'en-fasttext' #300
tencent_chinese_word_path = 'cn' # tencent 200
fasttext_cn_path = 'cn-fasttext' # 300
yangjie_rich_pretrain_unigram_path = '/remote-home/xnli/data/pretrain/chinese/gigaword_chn.all.a2b.uni.ite50.vec'
yangjie_rich_pretrain_bigram_path = '/remote-home/xnli/data/pretrain/chinese/gigaword_chn.all.a2b.bi.ite50.vec'
yangjie_rich_pretrain_word_path = '/remote-home/xnli/data/pretrain/chinese/ctb.50d.vec'


conll_2003_path = '/remote-home/xnli/data/corpus/multi_task/conll_2013/data_mine.pkl'
conllized_ontonote_path = '/remote-home/txsun/data/OntoNotes-5.0-NER-master/v12/english'
conllized_ontonote_pkl_path = '/remote-home/txsun/data/ontonotes5.pkl'
sst2_path = '/remote-home/xnli/data/corpus/text_classification/SST-2/'
# weibo_ner_path = '/remote-home/xnli/data/corpus/sequence_labelling/ner_weibo'
ontonote4ner_cn_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/OntoNote4NER'
msra_ner_cn_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/MSRANER'
resume_ner_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/ResumeNER'
weibo_ner_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/WeiboNER'
weibo_ner_old_path = '/remote-home/xnli/data/corpus/sequence_labelling/chinese_ner/WeiboNER_old'

================================================
FILE: small.py
================================================
from utils_ import get_skip_path_trivial, Trie, get_skip_path
from load_data import load_yangjie_rich_pretrain_word_list, load_ontonotes4ner, equip_chinese_ner_with_skip
from pathes import *
from functools import partial
from fastNLP import cache_results
from fastNLP.embeddings.static_embedding import StaticEmbedding
import torch
import torch.nn as nn
import torch.nn.functional as F
from fastNLP.core.metrics import _bmes_tag_to_spans,_bmeso_tag_to_spans
from load_data import load_resume_ner


# embed = StaticEmbedding(None,embedding_dim=2)
# datasets,vocabs,embeddings = load_ontonotes4ner(ontonote4ner_cn_path,yangjie_rich_pretrain_unigram_path,yangjie_rich_pretrain_bigram_path,
#                                                 _refresh=True,index_token=False)
#
# w_list = load_yangjie_rich_pretrain_word_list(yangjie_rich_pretrain_word_path,
#                                               _refresh=False)
#
# datasets,vocabs,embeddings = equip_chinese_ner_with_skip(datasets,vocabs,embeddings,w_list,yangjie_rich_pretrain_word_path,
#                                                          _refresh=True)
#

def reverse_style(input_string):
    target_position = input_string.index('[')
    input_len = len(input_string)
    output_string = input_string[target_position:input_len] + input_string[0:target_position]
    # print('in:{}.out:{}'.format(input_string, output_string))
    return output_string





def get_yangjie_bmeso(label_list):
    def get_ner_BMESO_yj(label_list):
        # list_len = len(word_list)
        # assert(list_len == len(label_list)), "word list size unmatch with label list"
        list_len = len(label_list)
        begin_label = 'b-'
        end_label = 'e-'
        single_label = 's-'
        whole_tag = ''
        index_tag = ''
        tag_list = []
        stand_matrix = []
        for i in range(0, list_len):
            # wordlabel = word_list[i]
            current_label = label_list[i].lower()
            if begin_label in current_label:
                if index_tag != '':
                    tag_list.append(whole_tag + ',' + str(i - 1))
                whole_tag = current_label.replace(begin_label, "", 1) + '[' + str(i)
                index_tag = current_label.replace(begin_label, "", 1)

            elif single_label in current_label:
                if index_tag != '':
                    tag_list.append(whole_tag + ',' + str(i - 1))
                whole_tag = current_label.replace(single_label, "", 1) + '[' + str(i)
                tag_list.append(whole_tag)
                whole_tag = ""
                index_tag = ""
            elif end_label in current_label:
                if index_tag != '':
                    tag_list.append(whole_tag + ',' + str(i))
                whole_tag = ''
                index_tag = ''
            else:
                continue
        if whole_tag != '' and index_tag != '':
            tag_list.append(whole_tag)
        tag_list_len = len(tag_list)

        for i in range(0, tag_list_len):
            if len(tag_list[i]) > 0:
                tag_list[i] = tag_list[i] + ']'
                insert_list = reverse_style(tag_list[i])
                stand_matrix.append(insert_list)
        # print stand_matrix
        return stand_matrix

    def transform_YJ_to_fastNLP(span):
        span = span[1:]
        span_split = span.split(']')
        # print('span_list:{}'.format(span_split))
        span_type = span_split[1]
        # print('span_split[0].split(','):{}'.format(span_split[0].split(',')))
        if ',' in span_split[0]:
            b, e = span_split[0].split(',')
        else:
            b = span_split[0]
            e = b

        b = int(b)
        e = int(e)

        e += 1

        return (span_type, (b, e))
    yj_form = get_ner_BMESO_yj(label_list)
    # print('label_list:{}'.format(label_list))
    # print('yj_from:{}'.format(yj_form))
    fastNLP_form = list(map(transform_YJ_to_fastNLP,yj_form))
    return fastNLP_form
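

# Illustrative sketch, independent of the helpers above: for well-formed BMESO
# sequences, get_yangjie_bmeso yields lower-cased (type, (begin, end)) spans
# with an exclusive end index. The simplified function below produces the same
# output format directly. It is a hypothetical helper that does not reproduce
# the handling of malformed sequences in get_ner_BMESO_yj.
def _bmeso_to_spans_sketch(tags):
    """Convert a list of BMESO tags into (type, (begin, end)) spans."""
    spans = []
    start, span_type = None, None
    for i, tag in enumerate(tags):
        tag = tag.lower()
        if tag.startswith('s-'):
            # Single-token entity: the span covers exactly position i.
            spans.append((tag[2:], (i, i + 1)))
            start, span_type = None, None
        elif tag.startswith('b-'):
            start, span_type = i, tag[2:]
        elif tag.startswith('e-') and start is not None:
            # Close the span opened by the most recent B- tag.
            spans.append((span_type, (start, i + 1)))
            start, span_type = None, None
    return spans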


# tag_list = ['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O']
# span_list = get_ner_BMES(tag_list)
# print(span_list)
# yangjie_label_list = ['B-NAME', 'E-NAME', 'O', 'B-CONT', 'M-CONT', 'E-CONT', 'B-RACE', 'E-RACE', 'B-TITLE', 'M-TITLE', 'E-TITLE', 'B-EDU', 'M-EDU', 'E-EDU', 'B-ORG', 'M-ORG', 'E-ORG', 'M-NAME', 'B-PRO', 'M-PRO', 'E-PRO', 'S-RACE', 'S-NAME', 'B-LOC', 'M-LOC', 'E-LOC', 'M-RACE', 'S-ORG']
# my_label_list = ['O', 'M-ORG', 'M-TITLE', 'B-TITLE', 'E-TITLE', 'B-ORG', 'E-ORG', 'M-EDU', 'B-NAME', 'E-NAME', 'B-EDU', 'E-EDU', 'M-NAME', 'M-PRO', 'M-CONT', 'B-PRO', 'E-PRO', 'B-CONT', 'E-CONT', 'M-LOC', 'B-RACE', 'E-RACE', 'S-NAME', 'B-LOC', 'E-LOC', 'M-RACE', 'S-RACE', 'S-ORG']
# yangjie_label = set(yangjie_label_list)
# my_label = set(my_label_list)

a = torch.tensor([0,2,0,3])
b = (a==0)
print(b)
print(b.float())
from fastNLP import RandomSampler

# f = open('/remote-home/xnli/weight_debug/lattice_yangjie.pkl','rb')
# weight_dict = torch.load(f)
# print(weight_dict.keys())
# for k,v in weight_dict.items():
#     print("{}:{}".format(k,v.size()))

================================================
FILE: utils.py
================================================
import torch.nn.functional as F
import torch
import random
import numpy as np
from fastNLP import Const
from fastNLP import CrossEntropyLoss
from fastNLP import AccuracyMetric
from fastNLP import Tester
import os
from fastNLP import logger
def should_mask(name, t=''):
    if 'bias' in name:
        return False
    if 'embedding' in name:
        splited = name.split('.')
        if splited[-1]!='weight':
            return False
        if 'embedding' in splited[-2]:
            return False
    if 'c0' in name:
        return False
    if 'h0' in name:
        return False

    if 'output' in name and t not in name:
        return False

    return True
def get_init_mask(model):
    init_masks = {}
    for name, param in model.named_parameters():
        if should_mask(name):
            init_masks[name+'.mask'] = torch.ones_like(param)
            # logger.info(init_masks[name+'.mask'].requires_grad)

    return init_masks

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed+100)
    torch.manual_seed(seed+200)
    torch.cuda.manual_seed_all(seed+300)

def get_parameters_size(model):
    result = {}
    for name,p in model.state_dict().items():
        result[name] = p.size()

    return result

def prune_by_proportion_model(model,proportion,task):
    # print('this time prune to ',proportion*100,'%')
    for name, p in model.named_parameters():
        # print(name)
        if not should_mask(name,task):
            continue

        tensor = p.data.cpu().numpy()
        index = np.nonzero(model.mask[task][name+'.mask'].data.cpu().numpy())
        # print(name,'alive count',len(index[0]))
        alive = tensor[index]
        # print('p and mask size:',p.size(),print(model.mask[task][name+'.mask'].size()))
        percentile_value = np.percentile(abs(alive), (1 - proportion) * 100)
        # tensor = p
        # index = torch.nonzero(model.mask[task][name+'.mask'])
        # # print('nonzero len',index)
        # alive = tensor[index]
        # print('alive size:',alive.shape)
        # prune_by_proportion_model()

        # percentile_value = torch.topk(abs(alive), int((1-proportion)*len(index[0]))).values
        # print('the',(1-proportion)*len(index[0]),'th big')
        # print('threshold:',percentile_value)

        prune_by_threshold_parameter(p, model.mask[task][name+'.mask'],percentile_value)
        # for

def prune_by_proportion_model_global(model,proportion,task):
    # print('this time prune to ',proportion*100,'%')
    alive = None
    for name, p in model.named_parameters():
        # print(name)
        if not should_mask(name,task):
            continue

        tensor = p.data.cpu().numpy()
        index = np.nonzero(model.mask[task][name+'.mask'].data.cpu().numpy())
        # print(name,'alive count',len(index[0]))
        if alive is None:
            alive = tensor[index]
        else:
            alive = np.concatenate([alive,tensor[index]],axis=0)

    percentile_value = np.percentile(abs(alive), (1 - proportion) * 100)

    for name, p in model.named_parameters():
        if should_mask(name,task):
            prune_by_threshold_parameter(p, model.mask[task][name+'.mask'],percentile_value)


def prune_by_threshold_parameter(p, mask, threshold):
    p_abs = torch.abs(p)

    new_mask = (p_abs > threshold).float()
    # print(mask)
    mask[:]*=new_mask
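

# Illustrative sketch only; nothing in this module calls it. It restates the
# per-parameter pruning step on plain Python lists: take the absolute values of
# the weights that are still alive, find the (1 - proportion) percentile with
# linear interpolation (the np.percentile default), and keep only weights whose
# magnitude is strictly above that threshold. The name _prune_mask_sketch is
# hypothetical, and at least one alive weight is assumed.
def _prune_mask_sketch(weights, mask, proportion):
    """Return the updated 0/1 mask after one magnitude-pruning step."""
    alive = sorted(abs(w) for w, m in zip(weights, mask) if m != 0)
    # Fractional rank of the requested percentile among the alive weights.
    rank = (1 - proportion) * (len(alive) - 1)
    low = int(rank)
    high = min(low + 1, len(alive) - 1)
    threshold = alive[low] + (alive[high] - alive[low]) * (rank - low)
    # Already-pruned entries stay at 0 because the old mask is multiplied in.
    return [m * (1.0 if abs(w) > threshold else 0.0) for w, m in zip(weights, mask)]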


def one_time_train_and_prune_single_task(trainer,PRUNE_PER,
                                         optimizer_init_state_dict=None,
                                         model_init_state_dict=None,
                                         is_global=None,
                                         ):


    from fastNLP import Trainer


    trainer.optimizer.load_state_dict(optimizer_init_state_dict)
    trainer.model.load_state_dict(model_init_state_dict)
    # print('metrics:',metrics.__dict__)
    # print('loss:',loss.__dict__)
    # print('trainer input:',task.train_set.get_input_name())
    # trainer = Trainer(model=model, train_data=task.train_set, dev_data=task.dev_set, loss=loss, metrics=metrics,
    #                   optimizer=optimizer, n_epochs=EPOCH, batch_size=BATCH, device=device,callbacks=callbacks)


    trainer.train(load_best_model=True)
    # tester = Tester(task.train_set, model, metrics, BATCH, device=device, verbose=1,use_tqdm=False)
    # print('FOR DEBUG: test train_set:',tester.test())
    # print('**'*20)
    # if task.test_set:
    #     tester = Tester(task.test_set, model, metrics, BATCH, device=device, verbose=1)
    #     tester.test()
    if is_global:

        prune_by_proportion_model_global(trainer.model, PRUNE_PER, trainer.model.now_task)

    else:
        prune_by_proportion_model(trainer.model, PRUNE_PER, trainer.model.now_task)



# def iterative_train_and_prune_single_task(get_trainer,ITER,PRUNE,is_global=False,save_path=None):
def iterative_train_and_prune_single_task(get_trainer,args,model,train_set,dev_set,test_set,device,save_path=None):

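    '''
    Iteratively train and prune the current task of the model.

    :param get_trainer: callable invoked as get_trainer(args, model, train_set, dev_set, test_set, device)
        that returns a trainer
    :param args: provides ``prune`` (overall proportion of weights kept after all rounds) and
        ``iter`` (number of train-and-prune rounds)
    :param model: model passed to get_trainer
    :param train_set: training data passed to get_trainer
    :param dev_set: development data passed to get_trainer
    :param test_set: test data passed to get_trainer
    :param device: device passed to get_trainer
    :param save_path: directory in which the mask of each round is saved (created if missing)
    :return:
    '''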



    from fastNLP import Trainer
    import torch
    import math
    import copy
    PRUNE = args.prune
    ITER = args.iter
    trainer = get_trainer(args,model,train_set,dev_set,test_set,device)
    optimizer_init_state_dict = copy.deepcopy(trainer.optimizer.state_dict())
    model_init_state_dict = copy.deepcopy(trainer.model.state_dict())
    if save_path is not None:
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        # if not os.path.exists(os.path.join(save_path, 'model_init.pkl')):
        #     f = open(os.path.join(save_path, 'model_init.pkl'), 'wb')
        #     torch.save(trainer.model.state_dict(),f)


    mask_count = 0
    model = trainer.model
    task = trainer.model.now_task
    for name, p in model.mask[task].items():
        mask_count += torch.sum(p).item()
    init_mask_count = mask_count
    logger.info('init mask count:{}'.format(mask_count))
    # logger.info('{}th traning mask count: {} / {} = {}%'.format(i, mask_count, init_mask_count,
    #                                                             mask_count / init_mask_count * 100))

    prune_per_iter = math.pow(PRUNE, 1 / ITER)


    for i in range(ITER):
        trainer = get_trainer(args,model,train_set,dev_set,test_set,device)
        one_time_train_and_prune_single_task(trainer,prune_per_iter,optimizer_init_state_dict,model_init_state_dict)
        if save_path is not None:
            # Save this round's mask; the context manager closes the file afterwards.
            with open(os.path.join(save_path,task+'_mask_'+str(i)+'.pkl'),'wb') as f:
                torch.save(model.mask[task],f)

        mask_count = 0
        for name, p in model.mask[task].items():
            mask_count += torch.sum(p).item()
        logger.info('{}th training mask count: {} / {} = {}%'.format(i,mask_count,init_mask_count,mask_count/init_mask_count*100))
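

# Illustrative sketch (not part of the original code): why the per-round ratio is
# PRUNE ** (1 / ITER). Each round keeps that fraction of the weights that are still
# alive, so after ITER rounds the surviving fraction equals PRUNE. The helper name
# `_demo_prune_schedule` and its default values are assumptions made for this example.
def _demo_prune_schedule(prune=0.2, n_iter=5):
    import math

    keep_per_iter = math.pow(prune, 1 / n_iter)  # same formula as prune_per_iter
    remaining = 1.0
    history = []
    for _ in range(n_iter):
        remaining *= keep_per_iter  # fraction of the initial weights still unmasked
        history.append(remaining)
    return history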


def get_appropriate_cuda(task_scale='s'):
    if task_scale not in {'s','m','l'}:
        logger.info('task scale wrong!')
        exit(2)
    import pynvml
    pynvml.nvmlInit()
    total_cuda_num = pynvml.nvmlDeviceGetCount()
    for i in range(total_cuda_num):
        logger.info(i)
        handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # i is the GPU id
        memInfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
        utilizationInfo = pynvml.nvmlDeviceGetUtilizationRates(handle)
        # logger.info takes a single message, so format the values into one string
        logger.info('{} mem: {} util: {}'.format(i, memInfo.used / memInfo.total, utilizationInfo.gpu))
        # NVML reports utilization as an integer percentage, so `< 0.2` only holds for an idle GPU
        if memInfo.used / memInfo.total < 0.15 and utilizationInfo.gpu <0.2:
            logger.info('{} {}'.format(i, memInfo.used / memInfo.total))
            return 'cuda:'+str(i)

    if task_scale=='s':
        max_memory=2000
    elif task_scale=='m':
        max_memory=6000
    else:
        max_memory = 9000

    max_id = -1
    for i in range(total_cuda_num):
        handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # query GPU i, not always GPU 0
        memInfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
        utilizationInfo = pynvml.nvmlDeviceGetUtilizationRates(handle)
        # NOTE: memInfo.free is reported in bytes
        if max_memory < memInfo.free:
            max_memory = memInfo.free
            max_id = i

    if max_id == -1:
        logger.info('no appropriate gpu, wait!')
        exit(2)

    return 'cuda:'+str(max_id)

        # if memInfo.used / memInfo.total < 0.5:
        #     return

def print_mask(mask_dict):
    def seq_mul(*X):
        res = 1
        for x in X:
            res*=x
        return res

    for name,p in mask_dict.items():
        total_size = seq_mul(*p.size())
        # count the non-zero entries directly
        unmasked_size = int((p != 0).sum())

        print(name,':',unmasked_size,'/',total_size,'=',unmasked_size/total_size*100,'%')


    print()


def check_words_same(dataset_1,dataset_2,field_1,field_2):
    if len(dataset_1[field_1]) != len(dataset_2[field_2]):
        logger.info('CHECK: example num not same!')
        return False

    for i, words in enumerate(dataset_1[field_1]):
        if len(dataset_1[field_1][i]) != len(dataset_2[field_2][i]):
            logger.info('CHECK {} th example length not same'.format(i))
            logger.info('1:{}'.format(dataset_1[field_1][i]))
            logger.info('2:{}'.format(dataset_2[field_2][i]))
            return False

        # for j,w in enumerate(words):
        #     if dataset_1[field_1][i][j] != dataset_2[field_2][i][j]:
        #         print('CHECK', i, 'th example has words different!')
        #         print('1:',dataset_1[field_1][i])
        #         print('2:',dataset_2[field_2][i])
        #         return False

    logger.info('CHECK: totally same!')

    return True

def get_now_time():
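    from datetime import datetime, timezone, timedelta
    # Use an aware UTC datetime: astimezone() on the naive result of utcnow()
    # would interpret it as local time and shift the result.
    dt = datetime.now(timezone.utc)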
    # print(dt)
    tzutc_8 = timezone(timedelta(hours=8))
    local_dt = dt.astimezone(tzutc_8)
    result = ("_{}_{}_{}__{}_{}_{}".format(local_dt.year, local_dt.month, local_dt.day, local_dt.hour, local_dt.minute,
                                      local_dt.second))

    return result


def get_bigrams(words):
    result = []
    for i,w in enumerate(words):
        if i!=len(words)-1:
            result.append(words[i]+words[i+1])
        else:
            result.append(words[i]+'<end>')

    return result

def print_info(*inp,islog=False,sep=' '):
    from fastNLP import logger
    if islog:
        print(*inp,sep=sep)
    else:
        inp = sep.join(map(str,inp))
        logger.info(inp)

def better_init_rnn(rnn,coupled=False):
    import torch.nn as nn
    if coupled:
        repeat_size = 3
    else:
        repeat_size = 4
    # print(list(rnn.named_parameters()))
    if hasattr(rnn,'num_layers'):
        for i in range(rnn.num_layers):
            nn.init.orthogonal_(getattr(rnn,'weight_ih_l'+str(i)).data)
            weight_hh_data = torch.eye(rnn.hidden_size)
            weight_hh_data = weight_hh_data.repeat(1, repeat_size)
            with torch.no_grad():
                getattr(rnn,'weight_hh_l'+str(i)).set_(weight_hh_data)
            nn.init.constant_(getattr(rnn,'bias_ih_l'+str(i)).data, val=0)
            nn.init.constant_(getattr(rnn,'bias_hh_l'+str(i)).data, val=0)

        if rnn.bidirectional:
            for i in range(rnn.num_layers):
                nn.init.orthogonal_(getattr(rnn, 'weight_ih_l' + str(i)+'_reverse').data)
                weight_hh_data = torch.eye(rnn.hidden_size)
                weight_hh_data = weight_hh_data.repeat(1, repeat_size)
                with torch.no_grad():
                    getattr(rnn, 'weight_hh_l' + str(i)+'_reverse').set_(weight_hh_data)
                nn.init.constant_(getattr(rnn, 'bias_ih_l' + str(i)+'_reverse').data, val=0)
                nn.init.constant_(getattr(rnn, 'bias_hh_l' + str(i)+'_reverse').data, val=0)


    else:
        nn.init.orthogonal_(rnn.weight_ih.data)
        weight_hh_data = torch.eye(rnn.hidden_size)
        weight_hh_data = weight_hh_data.repeat(repeat_size,1)
        with torch.no_grad():
            rnn.weight_hh.set_(weight_hh_data)
        # The bias is just set to zero vectors.
        print('rnn param size:{},{}'.format(rnn.weight_hh.size(),type(rnn)))
        if rnn.bias:
            nn.init.constant_(rnn.bias_ih.data, val=0)
            nn.init.constant_(rnn.bias_hh.data, val=0)

    # print(list(rnn.named_parameters()))








================================================
FILE: utils_.py
================================================
import collections
from fastNLP import cache_results
def get_skip_path(chars,w_trie):
    sentence = ''.join(chars)
    result = w_trie.get_lexicon(sentence)

    return result

# @cache_results(_cache_fp='cache/get_skip_path_trivial',_refresh=True)
def get_skip_path_trivial(chars,w_list):
    chars = ''.join(chars)
    w_set = set(w_list)
    result = []
    # for i in range(len(chars)):
    #     result.append([])
    for i in range(len(chars)-1):
        for j in range(i+2,len(chars)+1):
            if chars[i:j] in w_set:
                result.append([i,j-1,chars[i:j]])

    return result


class TrieNode:
    def __init__(self):
        self.children = collections.defaultdict(TrieNode)
        self.is_w = False

class Trie:
    def __init__(self):
        self.root = TrieNode()

    def insert(self,w):

        current = self.root
        for c in w:
            current = current.children[c]

        current.is_w = True

    def search(self,w):
        '''

        :param w:
        :return:
        -1:not w route
        0:subroute but not word
        1:subroute and word
        '''
        current = self.root

        for c in w:
            current = current.children.get(c)

            if current is None:
                return -1

        if current.is_w:
            return 1
        else:
            return 0

    def get_lexicon(self,sentence):
        result = []
        for i in range(len(sentence)):
            current = self.root
            for j in range(i, len(sentence)):
                current = current.children.get(sentence[j])
                if current is None:
                    break

                if current.is_w:
                    result.append([i,j,sentence[i:j+1]])

        return result
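

# Illustrative sketch (not part of the original code): the scan performed by
# Trie.get_lexicon, rebuilt on plain nested dicts so that it runs on its own.
# The helper name `_demo_lexicon_matches` and the use of the key `None` as the
# end-of-word marker are assumptions made for this example only.
def _demo_lexicon_matches(sentence, words):
    root = {}
    for w in words:
        node = root
        for c in w:
            node = node.setdefault(c, {})
        node[None] = True  # marks the end of a word, like TrieNode.is_w

    result = []
    for i in range(len(sentence)):
        node = root
        for j in range(i, len(sentence)):
            node = node.get(sentence[j])
            if node is None:
                break
            if None in node:
                # same [start, inclusive end, word] triple that get_lexicon returns
                result.append([i, j, sentence[i:j + 1]])
    return result

# e.g. _demo_lexicon_matches('abcd', ['ab', 'abc', 'cd', 'd'])
# gives [[0, 1, 'ab'], [0, 2, 'abc'], [2, 3, 'cd'], [3, 3, 'd']]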

from fastNLP.core.field import Padder
import numpy as np
import torch
from collections import defaultdict
class LatticeLexiconPadder(Padder):

    def __init__(self, pad_val=0, pad_val_dynamic=False,dynamic_offset=0, **kwargs):
        '''

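        :param pad_val: value used for padding when pad_val_dynamic is False
        :param pad_val_dynamic: if True, pad with max_len - 1 + dynamic_offset instead of pad_val
        :param dynamic_offset: offset added to the dynamic padding value
        :param kwargs: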
        '''
        self.pad_val = pad_val
        self.pad_val_dynamic = pad_val_dynamic
        self.dynamic_offset = dynamic_offset

    def __call__(self, contents, field_name, field_ele_dtype, dim: int):
        # Same as the dim=2 case in AutoPadder
        max_len = max(map(len, contents))

        max_len = max(max_len,1)#avoid 0 size dim which causes cuda wrong

        # default=0 keeps max() from raising when an example has no lexicon entries
        max_word_len = max([max([len(content_ii) for content_ii in content_i], default=0) for
                            content_i in contents])

        max_word_len = max(max_word_len,1)
        if self.pad_val_dynamic:
            # print('pad_val_dynamic:{}'.format(max_len-1))

            array = np.full((len(contents), max_len, max_word_len), max_len-1+self.dynamic_offset,
                            dtype=field_ele_dtype)

        else:
            array = np.full((len(contents), max_len, max_word_len), self.pad_val, dtype=field_ele_dtype)
        for i, content_i in enumerate(contents):
            for j, content_ii in enumerate(content_i):
                array[i, j, :len(content_ii)] = content_ii
        array = torch.tensor(array)

        return array
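

# Illustrative sketch (not part of the original code): the padding layout produced by
# LatticeLexiconPadder with a fixed pad value, written with plain nested lists so that it
# runs without numpy or torch. The helper name `_demo_lattice_padding` is an assumption
# made for this example only.
def _demo_lattice_padding(contents, pad_val=0):
    # longest sentence (at least 1, mirroring the guard against zero-sized dims)
    max_len = max(max(map(len, contents)), 1)
    # longest inner list over all positions of all sentences (at least 1)
    max_word_len = max(max((len(w) for sent in contents for w in sent), default=0), 1)
    padded = [[[pad_val] * max_word_len for _ in range(max_len)] for _ in contents]
    for i, sent in enumerate(contents):
        for j, w in enumerate(sent):
            padded[i][j][:len(w)] = w  # copy the real values, keep padding on the right
    return padded

# e.g. _demo_lattice_padding([[[1, 2], [3]], [[4]]])
# gives [[[1, 2], [3, 0]], [[4, 0], [0, 0]]]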

from fastNLP.core.metrics import MetricBase

def get_yangjie_bmeso(label_list,ignore_labels=None):
    def get_ner_BMESO_yj(label_list):
        def reverse_style(input_string):
            target_position = input_string.index('[')
            input_len = len(input_string)
            output_string = input_string[target_position:input_len] + input_string[0:target_position]
            # print('in:{}.out:{}'.format(input_string, output_string))
            return output_string

        # list_len = len(word_list)
        # assert(list_len == len(label_list)), "word list size unmatch with label list"
        list_len = len(label_list)
        begin_label = 'b-'
        end_label = 'e-'
        single_label = 's-'
        whole_tag = ''
        index_tag = ''
        tag_list = []
        stand_matrix = []
        for i in range(0, list_len):
            # wordlabel = word_list[i]
            current_label = label_list[i].lower()
            if begin_label in current_label:
                if index_tag != '':
                    tag_list.append(whole_tag + ',' + str(i - 1))
                whole_tag = current_label.replace(begin_label, "", 1) + '[' + str(i)
                index_tag = current_label.replace(begin_label, "", 1)

            elif single_label in current_label:
                if index_tag != '':
                    tag_list.append(whole_tag + ',' + str(i - 1))
                whole_tag = current_label.replace(single_label, "", 1) + '[' + str(i)
                tag_list.append(whole_tag)
                whole_tag = ""
                index_tag = ""
            elif end_label in current_label:
                if index_tag != '':
                    tag_list.append(whole_tag + ',' + str(i))
                whole_tag = ''
                index_tag = ''
            else:
                continue
        if whole_tag != '' and index_tag != '':
            tag_list.append(whole_tag)
        tag_list_len = len(tag_list)

        for i in range(0, tag_list_len):
            if len(tag_list[i]) > 0:
                tag_list[i] = tag_list[i] + ']'
                insert_list = reverse_style(tag_list[i])
                stand_matrix.append(insert_list)
        # print stand_matrix
        return stand_matrix

    def transform_YJ_to_fastNLP(span):
        span = span[1:]
        span_split = span.split(']')
        # print('span_list:{}'.format(span_split))
        span_type = span_split[1]
        # print('span_split[0].split(','):{}'.format(span_split[0].split(',')))
        if ',' in span_split[0]:
            b, e = span_split[0].split(',')
        else:
            b = span_split[0]
            e = b

        b = int(b)
        e = int(e)

        e += 1

        return (span_type, (b, e))
    yj_form = get_ner_BMESO_yj(label_list)
    # print('label_list:{}'.format(label_list))
    # print('yj_from:{}'.format(yj_form))
    fastNLP_form = list(map(transform_YJ_to_fastNLP,yj_form))
    return fastNLP_form
class SpanFPreRecMetric_YJ(MetricBase):
    r"""
    Alias: :class:`fastNLP.SpanFPreRecMetric` :class:`fastNLP.core.metrics.SpanFPreRecMetric`

    Computes F, pre and rec over spans for sequence labeling problems.
    For example, Chinese part-of-speech tagging labels each character, so the sentence `中国在亚洲` may be tagged
    (using BMES) as ['B-NN', 'E-NN', 'S-DET', 'B-NN', 'E-NN']. This metric computes F1 in such cases.
    The resulting metric is::

        {
            'f': xxx, # named 'f' so that an f_beta value can be computed later
            'pre': xxx,
            'rec':xxx
        }

    If only_gross=False, the statistics of each label are returned as well::

        {
            'f': xxx,
            'pre': xxx,
            'rec':xxx,
            'f-label': xxx,
            'pre-label': xxx,
            'rec-label':xxx,
            ...
        }

    :param tag_vocab: the tag :class:`~fastNLP.Vocabulary`. Supported tags are "B" (no label) or "B-xxx" (xxx is a label, such as NN in POS).
        When decoding, tags with the same xxx are treated as one label, e.g. ['B-NN', 'E-NN'] is merged into a single 'NN'.
    :param str pred: key used in evaluate() to fetch the prediction from the input dict. If None, `pred` is used
    :param str target: key used in evaluate() to fetch the target from the input dict. If None, `target` is used
    :param str seq_len: key used in evaluate() to fetch the sequence length from the input dict. If None, `seq_len` is used.
    :param str encoding_type: currently supports bio, bmes, bmeso, bioes and bmesoyj
    :param list ignore_labels: list of str. Classes in this list are excluded from the computation. For example, passing ['NN']
        for POS tagging skips the 'NN' label
    :param bool only_gross: whether to compute only the overall f1, precision and recall; if False, the f1, pre and rec of each
        label are returned in addition to the overall values
    :param str f_type: `micro` or `macro`. `micro`: count the overall TP, FN and FP first, then compute f, precision, recall; `macro`:
        compute f, precision, recall for each class separately, then average them (every class has the same weight)
    :param float beta: f_beta score, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` .
        Common values are beta=0.5, 1, 2. With 0.5 precision weighs more than recall; with 1 both weigh the same; with 2 recall weighs more than precision.
    """
    def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type='bio', ignore_labels=None,
                 only_gross=True, f_type='micro', beta=1):
        from fastNLP.core import Vocabulary
        from fastNLP.core.metrics import _bmes_tag_to_spans,_bio_tag_to_spans,\
            _bioes_tag_to_spans,_bmeso_tag_to_spans
        from collections import defaultdict

        encoding_type = encoding_type.lower()

        if not isinstance(tag_vocab, Vocabulary):
            raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab)))
        if f_type not in ('micro', 'macro'):
            raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type))

        self.encoding_type = encoding_type
        # print('encoding_type:{}'self.encoding_type)
        if self.encoding_type == 'bmes':
            self.tag_to_span_func = _bmes_tag_to_spans
        elif self.encoding_type == 'bio':
            self.tag_to_span_func = _bio_tag_to_spans
        elif self.encoding_type == 'bmeso':
            self.tag_to_span_func = _bmeso_tag_to_spans
        elif self.encoding_type == 'bioes':
            self.tag_to_span_func = _bioes_tag_to_spans
        elif self.encoding_type == 'bmesoyj':
            self.tag_to_span_func = get_yangjie_bmeso
            # self.tag_to_span_func =
        else:
            raise ValueError("Only support 'bio', 'bmes', 'bmeso', 'bioes', 'bmesoyj' type.")

        self.ignore_labels = ignore_labels
        self.f_type = f_type
        self.beta = beta
        self.beta_square = self.beta ** 2
        self.only_gross = only_gross

        super().__init__()
        self._init_param_map(pred=pred, target=target, seq_len=seq_len)

        self.tag_vocab = tag_vocab

        self._true_positives = defaultdict(int)
        self._false_positives = defaultdict(int)
        self._false_negatives = defaultdict(int)

    def evaluate(self, pred, target, seq_len):
        """Accumulate the evaluation statistics for one batch of predictions.

        :param pred: [batch, seq_len] or [batch, seq_len, len(tag_vocab)], the predictions
        :param target: [batch, seq_len], the gold labels
        :param seq_len: [batch], the sequence lengths
        :return:
        """
        from fastNLP.core.utils import _get_func_signature
        if not isinstance(pred, torch.Tensor):
            raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                            f"got {type(pred)}.")
        if not isinstance(target, torch.Tensor):
            raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                            f"got {type(target)}.")

        if not isinstance(seq_len, torch.Tensor):
            raise TypeError(f"`seq_len` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
                            f"got {type(seq_len)}.")

        if pred.size() == target.size() and len(target.size()) == 2:
            pass
        elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2:
            num_classes = pred.size(-1)
            pred = pred.argmax(dim=-1)
            if (target >= num_classes).any():
                raise ValueError("A gold label passed to SpanBasedF1Metric contains an "
                                 "id >= {}, the number of classes.".format(num_classes))
        else:
            raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have "
                               f"size:{pred.size()}, target should have size: {pred.size()} or "
                               f"{pred.size()[:-1]}, got {target.size()}.")

        batch_size = pred.size(0)
        pred = pred.tolist()
        target = target.tolist()
        for i in range(batch_size):
            pred_tags = pred[i][:int(seq_len[i])]
            gold_tags = target[i][:int(seq_len[i])]

            pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags]
            gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags]

            pred_spans = self.tag_to_span_func(pred_str_tags, ignore_labels=self.ignore_labels)
            gold_spans = self.tag_to_span_func(gold_str_tags, ignore_labels=self.ignore_labels)

            for span in pred_spans:
                if span in gold_spans:
                    self._true_positives[span[0]] += 1
                    gold_spans.remove(span)
                else:
                    self._false_positives[span[0]] += 1
            for span in gold_spans:
                self._false_negatives[span[0]] += 1

    def get_metric(self, reset=True):
        """Compute the final evaluation result from the statistics accumulated by evaluate()."""
        evaluate_result = {}
        if not self.only_gross or self.f_type == 'macro':
            tags = set(self._false_negatives.keys())
            tags.update(set(self._false_positives.keys()))
            tags.update(set(self._true_positives.keys()))
            f_sum = 0
            pre_sum = 0
            rec_sum = 0
            for tag in tags:
                tp = self._true_positives[tag]
                fn = self._false_negatives[tag]
                fp = self._false_positives[tag]
                f, pre, rec = self._compute_f_pre_rec(tp, fn, fp)
                f_sum += f
                pre_sum += pre
                rec_sum += rec
                if not self.only_gross and tag != '':  # tag != '' guards against the no-tag case
                    f_key = 'f-{}'.format(tag)
                    pre_key = 'pre-{}'.format(tag)
                    rec_key = 'rec-{}'.format(tag)
                    evaluate_result[f_key] = f
                    evaluate_result[pre_key] = pre
                    evaluate_result[rec_key] = rec

            if self.f_type == 'macro':
                evaluate_result['f'] = f_sum / len(tags)
                evaluate_result['pre'] = pre_sum / len(tags)
                evaluate_result['rec'] = rec_sum / len(tags)

        if self.f_type == 'micro':
            f, pre, rec = self._compute_f_pre_rec(sum(self._true_positives.values()),
                                                  sum(self._false_negatives.values()),
                                                  sum(self._false_positives.values()))
            evaluate_result['f'] = f
            evaluate_result['pre'] = pre
            evaluate_result['rec'] = rec

        if reset:
            self._true_positives = defaultdict(int)
            self._false_positives = defaultdict(int)
            self._false_negatives = defaultdict(int)

        for key, value in evaluate_result.items():
            evaluate_result[key] = round(value, 6)

        return evaluate_result

    def _compute_f_pre_rec(self, tp, fn, fp):
        """

        :param tp: int, true positive
        :param fn: int, false negative
        :param fp: int, false positive
        :return: (f, pre, rec)
        """
        pre = tp / (fp + tp + 1e-13)
        rec = tp / (fn + tp + 1e-13)
        f = (1 + self.beta_square) * pre * rec / (self.beta_square * pre + rec + 1e-13)

        return f, pre, rec
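

# Illustrative sketch (not part of the original code): the F-beta formula used by
# _compute_f_pre_rec as a standalone function, so the effect of beta can be checked by
# hand. The helper name `_demo_f_beta` and the `eps` parameter are assumptions made for
# this example; eps plays the role of the 1e-13 smoothing term above.
def _demo_f_beta(tp, fn, fp, beta=1, eps=1e-13):
    pre = tp / (fp + tp + eps)  # precision
    rec = tp / (fn + tp + eps)  # recall
    beta_square = beta ** 2
    f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + eps)
    return f, pre, rec

# e.g. with tp=8, fn=2, fp=8 precision is about 0.5 and recall about 0.8;
# beta=2 weights recall more, so f (about 0.714) is higher than with beta=1 (about 0.615).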



