[
  {
    "path": "Fastformer-Keras.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from datasets import load_dataset\\n\",\n    \"from nltk.tokenize import wordpunct_tokenize\\n\",\n    \"dataset = load_dataset('ag_news')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"text=[]\\n\",\n    \"label=[]\\n\",\n    \"for row in dataset['train']['text']+dataset['test']['text']:\\n\",\n    \"    text.append(wordpunct_tokenize(row.lower()))\\n\",\n    \"for row in dataset['train']['label']+dataset['test']['label']:\\n\",\n    \"    label.append(row)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"word_dict={'PADDING':0}\\n\",\n    \"for sent in text:    \\n\",\n    \"    for token in sent:        \\n\",\n    \"        if token not in word_dict:\\n\",\n    \"            word_dict[token]=len(word_dict)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"MAX_SENT_LENGTH=256\\n\",\n    \"\\n\",\n    \"news_words = []\\n\",\n    \"for sent in text:       \\n\",\n    \"    sample=[]\\n\",\n    \"    for token in sent:     \\n\",\n    \"        sample.append(word_dict[token])\\n\",\n    \"    sample = sample[:MAX_SENT_LENGTH]\\n\",\n    \"    news_words.append(sample+[0]*(MAX_SENT_LENGTH-len(sample)))\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 7,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import numpy as np\\n\",\n    \"news_words=np.array(news_words,dtype='int32') \\n\",\n    \"label=np.array(label,dtype='int32') \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 6,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    
\"index=np.arange(len(label))\\n\",\n    \"train_index=index[:120000]\\n\",\n    \"np.random.shuffle(train_index)\\n\",\n    \"test_index=index[120000:]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 3,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"Using TensorFlow backend.\\n\",\n      \"/home/user/anaconda3/envs/wuch15/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:523: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\\n\",\n      \"  _np_qint8 = np.dtype([(\\\"qint8\\\", np.int8, 1)])\\n\",\n      \"/home/user/anaconda3/envs/wuch15/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:524: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\\n\",\n      \"  _np_quint8 = np.dtype([(\\\"quint8\\\", np.uint8, 1)])\\n\",\n      \"/home/user/anaconda3/envs/wuch15/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\\n\",\n      \"  _np_qint16 = np.dtype([(\\\"qint16\\\", np.int16, 1)])\\n\",\n      \"/home/user/anaconda3/envs/wuch15/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:526: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\\n\",\n      \"  _np_quint16 = np.dtype([(\\\"quint16\\\", np.uint16, 1)])\\n\",\n      \"/home/user/anaconda3/envs/wuch15/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:527: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is 
deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\\n\",\n      \"  _np_qint32 = np.dtype([(\\\"qint32\\\", np.int32, 1)])\\n\",\n      \"/home/user/anaconda3/envs/wuch15/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:532: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\\n\",\n      \"  np_resource = np.dtype([(\\\"resource\\\", np.ubyte, 1)])\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"import os\\n\",\n    \"os.environ[\\\"CUDA_VISIBLE_DEVICES\\\"] = \\\"2\\\"\\n\",\n    \"\\n\",\n    \"from keras.utils.np_utils import to_categorical\\n\",\n    \"from keras.layers import *\\n\",\n    \"from keras.models import Model, load_model\\n\",\n    \"from keras import backend as K\\n\",\n    \"from sklearn.metrics import *\\n\",\n    \"from keras.optimizers import *\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import numpy as np\\n\",\n    \"news_words=np.array(news_words,dtype='int32') \\n\",\n    \"label=np.array(label,dtype='int32')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 9,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import random\\n\",\n    \"index=np.arange(len(label))\\n\",\n    \"train_index=index[:120000]\\n\",\n    \"test_index=index[120000:]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 29,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"\\n\",\n    \"class Fastformer(Layer):\\n\",\n    \"\\n\",\n    \"    def __init__(self, nb_head, size_per_head, **kwargs):\\n\",\n    \"        self.nb_head = nb_head\\n\",\n    \"        self.size_per_head = size_per_head\\n\",\n    \"        self.output_dim = nb_head*size_per_head\\n\",\n    \"        self.now_input_shape=None\\n\",\n    \"     
   super(Fastformer, self).__init__(**kwargs)\\n\",\n    \"\\n\",\n    \"    def build(self, input_shape):\\n\",\n    \"        self.now_input_shape=input_shape\\n\",\n    \"        self.WQ = self.add_weight(name='WQ', \\n\",\n    \"                                  shape=(input_shape[0][-1], self.output_dim),\\n\",\n    \"                                  initializer='glorot_uniform',\\n\",\n    \"                                  trainable=True)\\n\",\n    \"        self.WK = self.add_weight(name='WK', \\n\",\n    \"                                  shape=(input_shape[1][-1], self.output_dim),\\n\",\n    \"                                  initializer='glorot_uniform',\\n\",\n    \"                                  trainable=True) \\n\",\n    \"        self.Wq = self.add_weight(name='Wq', \\n\",\n    \"                                  shape=(self.output_dim,self.nb_head),\\n\",\n    \"                                  initializer='glorot_uniform',\\n\",\n    \"                                  trainable=True)\\n\",\n    \"        self.Wk = self.add_weight(name='Wk', \\n\",\n    \"                                  shape=(self.output_dim,self.nb_head),\\n\",\n    \"                                  initializer='glorot_uniform',\\n\",\n    \"                                  trainable=True)\\n\",\n    \"        \\n\",\n    \"        self.WP = self.add_weight(name='WP', \\n\",\n    \"                                  shape=(self.output_dim,self.output_dim),\\n\",\n    \"                                  initializer='glorot_uniform',\\n\",\n    \"                                  trainable=True)\\n\",\n    \"        \\n\",\n    \"        \\n\",\n    \"        super(Fastformer, self).build(input_shape)\\n\",\n    \"        \\n\",\n    \"    def call(self, x):\\n\",\n    \"        if len(x) == 2:\\n\",\n    \"            Q_seq,K_seq = x\\n\",\n    \"        elif len(x) == 4:\\n\",\n    \"            Q_seq,K_seq,Q_mask,K_mask = x #different mask lengths, reserved for 
cross attention\\n\",\n    \"\\n\",\n    \"        Q_seq = K.dot(Q_seq, self.WQ)        \\n\",\n    \"        Q_seq_reshape = K.reshape(Q_seq, (-1, self.now_input_shape[0][1], self.nb_head*self.size_per_head))\\n\",\n    \"\\n\",\n    \"        Q_att=  K.permute_dimensions(K.dot(Q_seq_reshape, self.Wq),(0,2,1))/ self.size_per_head**0.5\\n\",\n    \"\\n\",\n    \"        if len(x)  == 4:\\n\",\n    \"            Q_att = Q_att-(1-K.expand_dims(Q_mask,axis=1))*1e8\\n\",\n    \"\\n\",\n    \"        Q_att = K.softmax(Q_att)\\n\",\n    \"        Q_seq = K.reshape(Q_seq, (-1,self.now_input_shape[0][1], self.nb_head, self.size_per_head))\\n\",\n    \"        Q_seq = K.permute_dimensions(Q_seq, (0,2,1,3))\\n\",\n    \"        \\n\",\n    \"        K_seq = K.dot(K_seq, self.WK)\\n\",\n    \"        K_seq = K.reshape(K_seq, (-1,self.now_input_shape[1][1], self.nb_head, self.size_per_head))\\n\",\n    \"        K_seq = K.permute_dimensions(K_seq, (0,2,1,3))\\n\",\n    \"\\n\",\n    \"        Q_att = Lambda(lambda x: K.repeat_elements(K.expand_dims(x,axis=3),self.size_per_head,axis=3))(Q_att)\\n\",\n    \"        global_q = K.sum(multiply([Q_att, Q_seq]),axis=2)\\n\",\n    \"        \\n\",\n    \"        global_q_repeat = Lambda(lambda x: K.repeat_elements(K.expand_dims(x,axis=2), self.now_input_shape[1][1],axis=2))(global_q)\\n\",\n    \"\\n\",\n    \"        QK_interaction = multiply([K_seq, global_q_repeat])\\n\",\n    \"        QK_interaction_reshape = K.reshape(QK_interaction, (-1, self.now_input_shape[0][1], self.nb_head*self.size_per_head))\\n\",\n    \"        K_att = K.permute_dimensions(K.dot(QK_interaction_reshape, self.Wk),(0,2,1))/ self.size_per_head**0.5\\n\",\n    \"        \\n\",\n    \"        if len(x)  == 4:\\n\",\n    \"            K_att = K_att-(1-K.expand_dims(K_mask,axis=1))*1e8\\n\",\n    \"            \\n\",\n    \"        K_att = K.softmax(K_att)\\n\",\n    \"\\n\",\n    \"        K_att = Lambda(lambda x: 
K.repeat_elements(K.expand_dims(x,axis=3),self.size_per_head,axis=3))(K_att)\\n\",\n    \"\\n\",\n    \"        global_k = K.sum(multiply([K_att, QK_interaction]),axis=2)\\n\",\n    \"     \\n\",\n    \"        global_k_repeat = Lambda(lambda x: K.repeat_elements(K.expand_dims(x,axis=2), self.now_input_shape[0][1],axis=2))(global_k)\\n\",\n    \"        #Q=V\\n\",\n    \"        QKQ_interaction = multiply([global_k_repeat, Q_seq])\\n\",\n    \"        QKQ_interaction = K.permute_dimensions(QKQ_interaction, (0,2,1,3))\\n\",\n    \"        QKQ_interaction = K.reshape(QKQ_interaction, (-1,self.now_input_shape[0][1], self.nb_head*self.size_per_head))\\n\",\n    \"        QKQ_interaction = K.dot(QKQ_interaction, self.WP)\\n\",\n    \"        QKQ_interaction = K.reshape(QKQ_interaction, (-1,self.now_input_shape[0][1], self.nb_head,self.size_per_head))\\n\",\n    \"        QKQ_interaction = K.permute_dimensions(QKQ_interaction, (0,2,1,3))\\n\",\n    \"        QKQ_interaction = QKQ_interaction+Q_seq\\n\",\n    \"        QKQ_interaction = K.permute_dimensions(QKQ_interaction, (0,2,1,3))\\n\",\n    \"        QKQ_interaction = K.reshape(QKQ_interaction, (-1,self.now_input_shape[0][1], self.nb_head*self.size_per_head))\\n\",\n    \"\\n\",\n    \"        #many operations can be optimized if higher versions are used. 
\\n\",\n    \"        \\n\",\n    \"        return QKQ_interaction\\n\",\n    \"        \\n\",\n    \"    def compute_output_shape(self, input_shape):\\n\",\n    \"        return (input_shape[0][0], input_shape[0][1], self.output_dim)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"K.clear_session() \\n\",\n    \"\\n\",\n    \"text_input = Input(shape=(MAX_SENT_LENGTH,), dtype='int32')\\n\",\n    \"qmask=Lambda(lambda x:  K.cast(K.cast(x,'bool'),'float32'))(text_input)\\n\",\n    \"word_emb = Embedding(len(word_dict),256, trainable=True)(text_input)\\n\",\n    \"\\n\",\n    \"#pos_emb = Embedding(MAX_SENT_LENGTH, 256, trainable=True)(Lambda(lambda x:K.zeros_like(x,dtype='int32')+K.arange(x.shape[1]))(text_input))\\n\",\n    \"#word_emb  =add([word_emb ,pos_emb])\\n\",\n    \"#We find that position embedding is not important on this dataset and we removed it for simplicity. If needed, please uncomment the two lines above\\n\",\n    \"\\n\",\n    \"word_emb=Dropout(0.2)(word_emb)\\n\",\n    \"\\n\",\n    \"hidden_word_emb = Fastformer(16,16)([word_emb,word_emb,qmask,qmask])\\n\",\n    \"hidden_word_emb = Dropout(0.2)(hidden_word_emb)\\n\",\n    \"hidden_word_emb = LayerNormalization()(add([word_emb,hidden_word_emb])) \\n\",\n    \"#if there is no layer norm in old version, please import an external layernorm class from a higher version.\\n\",\n    \"\\n\",\n    \"hidden_word_emb_layer2 = Fastformer(16,16)([hidden_word_emb,hidden_word_emb,qmask,qmask])\\n\",\n    \"hidden_word_emb_layer2 = Dropout(0.2)(hidden_word_emb_layer2)\\n\",\n    \"hidden_word_emb_layer2 = LayerNormalization()(add([hidden_word_emb,hidden_word_emb_layer2]))\\n\",\n    \"\\n\",\n    \"#without FFNN for simplicity\\n\",\n    \"\\n\",\n    \"word_att = Flatten()(Dense(1)(hidden_word_emb_layer2))\\n\",\n    \"word_att = Activation('softmax')(word_att)\\n\",\n    \"text_emb = Dot((1, 
1))([hidden_word_emb_layer2 , word_att])\\n\",\n    \"classifier = Dense(4, activation='softmax')(text_emb)\\n\",\n    \"                                      \\n\",\n    \"model = Model([text_input], [classifier])\\n\",\n    \"model.compile(loss=['categorical_crossentropy'],optimizer=Adam(lr=0.001), metrics=['acc'])\\n\",\n    \"\\n\",\n    \"for i in range(1):\\n\",\n    \"    model.fit(news_words[train_index],to_categorical(label)[train_index],shuffle=True,batch_size=64, epochs=1,verbose=1)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"    y_pred = model.predict([news_words[test_index] ], batch_size=128, verbose=1)\\n\",\n    \"    y_pred = np.argmax(y_pred, axis=1)\\n\",\n    \"    y_true = label[test_index]\\n\",\n    \"    acc = accuracy_score(y_true, y_pred)\\n\",\n    \"    report = f1_score(y_true, y_pred, average='macro')  \\n\",\n    \"    print(acc)\\n\",\n    \"    print(report)\\n\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.6.8\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "Fastformer.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from datasets import load_dataset\\n\",\n    \"from nltk.tokenize import wordpunct_tokenize\\n\",\n    \"dataset = load_dataset('ag_news')\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 119,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"text=[]\\n\",\n    \"label=[]\\n\",\n    \"for row in dataset['train']['text']+dataset['test']['text']:\\n\",\n    \"    text.append(wordpunct_tokenize(row.lower()))\\n\",\n    \"for row in dataset['train']['label']+dataset['test']['label']:\\n\",\n    \"    label.append(row)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 120,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"word_dict={'PADDING':0}\\n\",\n    \"for sent in text:    \\n\",\n    \"    for token in sent:        \\n\",\n    \"        if token not in word_dict:\\n\",\n    \"            word_dict[token]=len(word_dict)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"news_words = []\\n\",\n    \"for sent in text:       \\n\",\n    \"    sample=[]\\n\",\n    \"    for token in sent:     \\n\",\n    \"        sample.append(word_dict[token])\\n\",\n    \"    sample = sample[:256]\\n\",\n    \"    news_words.append(sample+[0]*(256-len(sample)))\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import numpy as np\\n\",\n    \"news_words=np.array(news_words,dtype='int32') \\n\",\n    \"label=np.array(label,dtype='int32') \"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 83,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"index=np.arange(len(label))\\n\",\n    
\"train_index=index[:120000]\\n\",\n    \"np.random.shuffle(train_index)\\n\",\n    \"test_index=index[120000:]\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 8,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import os\\n\",\n    \"import logging\\n\",\n    \"os.environ[\\\"CUDA_VISIBLE_DEVICES\\\"] = \\\"1\\\"\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 99,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from transformers import BertConfig\\n\",\n    \"from transformers.models.bert.modeling_bert import BertSelfOutput, BertIntermediate, BertOutput\\n\",\n    \"config=BertConfig.from_json_file('fastformer.json')\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 100,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"import torch\\n\",\n    \"import torch.nn as nn\\n\",\n    \"\\n\",\n    \"class AttentionPooling(nn.Module):\\n\",\n    \"    def __init__(self, config):\\n\",\n    \"        self.config = config\\n\",\n    \"        super(AttentionPooling, self).__init__()\\n\",\n    \"        self.att_fc1 = nn.Linear(config.hidden_size, config.hidden_size)\\n\",\n    \"        self.att_fc2 = nn.Linear(config.hidden_size, 1)\\n\",\n    \"        self.apply(self.init_weights)\\n\",\n    \"        \\n\",\n    \"    def init_weights(self, module):\\n\",\n    \"        if isinstance(module, nn.Linear):\\n\",\n    \"            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\\n\",\n    \"        if isinstance(module, nn.Linear) and module.bias is not None:\\n\",\n    \"            module.bias.data.zero_()\\n\",\n    \"            \\n\",\n    \"                \\n\",\n    \"    def forward(self, x, attn_mask=None):\\n\",\n    \"        bz = x.shape[0]\\n\",\n    \"        e = self.att_fc1(x)\\n\",\n    \"        e = nn.Tanh()(e)\\n\",\n    \"        alpha = self.att_fc2(e)\\n\",\n    \"        alpha = 
torch.exp(alpha)\\n\",\n    \"        if attn_mask is not None:\\n\",\n    \"            alpha = alpha * attn_mask.unsqueeze(2)\\n\",\n    \"        alpha = alpha / (torch.sum(alpha, dim=1, keepdim=True) + 1e-8)\\n\",\n    \"        x = torch.bmm(x.permute(0, 2, 1), alpha)\\n\",\n    \"        x = torch.reshape(x, (bz, -1))  \\n\",\n    \"        return x\\n\",\n    \"\\n\",\n    \"class FastSelfAttention(nn.Module):\\n\",\n    \"    def __init__(self, config):\\n\",\n    \"        super(FastSelfAttention, self).__init__()\\n\",\n    \"        self.config = config\\n\",\n    \"        if config.hidden_size % config.num_attention_heads != 0:\\n\",\n    \"            raise ValueError(\\n\",\n    \"                \\\"The hidden size (%d) is not a multiple of the number of attention \\\"\\n\",\n    \"                \\\"heads (%d)\\\" %\\n\",\n    \"                (config.hidden_size, config.num_attention_heads))\\n\",\n    \"        self.attention_head_size = int(config.hidden_size /config.num_attention_heads)\\n\",\n    \"        self.num_attention_heads = config.num_attention_heads\\n\",\n    \"        self.all_head_size = self.num_attention_heads * self.attention_head_size\\n\",\n    \"        self.input_dim= config.hidden_size\\n\",\n    \"        \\n\",\n    \"        self.query = nn.Linear(self.input_dim, self.all_head_size)\\n\",\n    \"        self.query_att = nn.Linear(self.all_head_size, self.num_attention_heads)\\n\",\n    \"        self.key = nn.Linear(self.input_dim, self.all_head_size)\\n\",\n    \"        self.key_att = nn.Linear(self.all_head_size, self.num_attention_heads)\\n\",\n    \"        self.transform = nn.Linear(self.all_head_size, self.all_head_size)\\n\",\n    \"\\n\",\n    \"        self.softmax = nn.Softmax(dim=-1)\\n\",\n    \"        \\n\",\n    \"        self.apply(self.init_weights)\\n\",\n    \"\\n\",\n    \"    def init_weights(self, module):\\n\",\n    \"        if isinstance(module, nn.Linear):\\n\",\n    \"            
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\\n\",\n    \"        if isinstance(module, nn.Linear) and module.bias is not None:\\n\",\n    \"            module.bias.data.zero_()\\n\",\n    \"                \\n\",\n    \"    def transpose_for_scores(self, x):\\n\",\n    \"        new_x_shape = x.size()[:-1] + (self.num_attention_heads,\\n\",\n    \"                                       self.attention_head_size)\\n\",\n    \"        x = x.view(*new_x_shape)\\n\",\n    \"        return x.permute(0, 2, 1, 3)\\n\",\n    \"    \\n\",\n    \"    def forward(self, hidden_states, attention_mask):\\n\",\n    \"        # batch_size, seq_len, num_head * head_dim, batch_size, seq_len\\n\",\n    \"        batch_size, seq_len, _ = hidden_states.shape\\n\",\n    \"        mixed_query_layer = self.query(hidden_states)\\n\",\n    \"        mixed_key_layer = self.key(hidden_states)\\n\",\n    \"        # batch_size, num_head, seq_len\\n\",\n    \"        query_for_score = self.query_att(mixed_query_layer).transpose(1, 2) / self.attention_head_size**0.5\\n\",\n    \"        # add attention mask\\n\",\n    \"        query_for_score += attention_mask\\n\",\n    \"\\n\",\n    \"        # batch_size, num_head, 1, seq_len\\n\",\n    \"        query_weight = self.softmax(query_for_score).unsqueeze(2)\\n\",\n    \"\\n\",\n    \"        # batch_size, num_head, seq_len, head_dim\\n\",\n    \"        query_layer = self.transpose_for_scores(mixed_query_layer)\\n\",\n    \"\\n\",\n    \"        # batch_size, num_head, head_dim, 1\\n\",\n    \"        pooled_query = torch.matmul(query_weight, query_layer).transpose(1, 2).view(-1,1,self.num_attention_heads*self.attention_head_size)\\n\",\n    \"        pooled_query_repeat= pooled_query.repeat(1, seq_len,1)\\n\",\n    \"        # batch_size, num_head, seq_len, head_dim\\n\",\n    \"\\n\",\n    \"        # batch_size, num_head, seq_len\\n\",\n    \"        mixed_query_key_layer=mixed_key_layer* pooled_query_repeat\\n\",\n 
   \"        \\n\",\n    \"        query_key_score=(self.key_att(mixed_query_key_layer)/ self.attention_head_size**0.5).transpose(1, 2)\\n\",\n    \"        \\n\",\n    \"        # add attention mask\\n\",\n    \"        query_key_score +=attention_mask\\n\",\n    \"\\n\",\n    \"        # batch_size, num_head, 1, seq_len\\n\",\n    \"        query_key_weight = self.softmax(query_key_score).unsqueeze(2)\\n\",\n    \"\\n\",\n    \"        key_layer = self.transpose_for_scores(mixed_query_key_layer)\\n\",\n    \"        pooled_key = torch.matmul(query_key_weight, key_layer)\\n\",\n    \"\\n\",\n    \"        #query = value\\n\",\n    \"        weighted_value =(pooled_key * query_layer).transpose(1, 2)\\n\",\n    \"        weighted_value = weighted_value.reshape(\\n\",\n    \"            weighted_value.size()[:-2] + (self.num_attention_heads * self.attention_head_size,))\\n\",\n    \"        weighted_value = self.transform(weighted_value) + mixed_query_layer\\n\",\n    \"      \\n\",\n    \"        return weighted_value\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 101,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"\\n\",\n    \"class FastAttention(nn.Module):\\n\",\n    \"    def __init__(self, config):\\n\",\n    \"        super(FastAttention, self).__init__()\\n\",\n    \"        self.self = FastSelfAttention(config)\\n\",\n    \"        self.output = BertSelfOutput(config)\\n\",\n    \"\\n\",\n    \"    def forward(self, input_tensor, attention_mask):\\n\",\n    \"        self_output = self.self(input_tensor, attention_mask)\\n\",\n    \"        attention_output = self.output(self_output, input_tensor)\\n\",\n    \"        return attention_output\\n\",\n    \"\\n\",\n    \"class FastformerLayer(nn.Module):\\n\",\n    \"    def __init__(self, config):\\n\",\n    \"        super(FastformerLayer, self).__init__()\\n\",\n    \"        self.attention = FastAttention(config)\\n\",\n    \"        self.intermediate = 
BertIntermediate(config)\\n\",\n    \"        self.output = BertOutput(config)\\n\",\n    \"\\n\",\n    \"    def forward(self, hidden_states, attention_mask):\\n\",\n    \"        attention_output = self.attention(hidden_states, attention_mask)\\n\",\n    \"        intermediate_output = self.intermediate(attention_output)\\n\",\n    \"        layer_output = self.output(intermediate_output, attention_output)\\n\",\n    \"        return layer_output\\n\",\n    \"    \\n\",\n    \"class FastformerEncoder(nn.Module):\\n\",\n    \"    def __init__(self, config, pooler_count=1):\\n\",\n    \"        super(FastformerEncoder, self).__init__()\\n\",\n    \"        self.config = config\\n\",\n    \"        self.encoders = nn.ModuleList([FastformerLayer(config) for _ in range(config.num_hidden_layers)])\\n\",\n    \"        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)\\n\",\n    \"        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\\n\",\n    \"        self.dropout = nn.Dropout(config.hidden_dropout_prob)\\n\",\n    \"\\n\",\n    \"        # support multiple different poolers with shared bert encoder.\\n\",\n    \"        self.poolers = nn.ModuleList()\\n\",\n    \"        if config.pooler_type == 'weightpooler':\\n\",\n    \"            for _ in range(pooler_count):\\n\",\n    \"                self.poolers.append(AttentionPooling(config))\\n\",\n    \"        logging.info(f\\\"This model has {len(self.poolers)} poolers.\\\")\\n\",\n    \"\\n\",\n    \"        self.apply(self.init_weights)\\n\",\n    \"\\n\",\n    \"    def init_weights(self, module):\\n\",\n    \"        if isinstance(module, (nn.Linear, nn.Embedding)):\\n\",\n    \"            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\\n\",\n    \"            if isinstance(module, (nn.Embedding)) and module.padding_idx is not None:\\n\",\n    \"                with torch.no_grad():\\n\",\n    \"                    
module.weight[module.padding_idx].fill_(0)\\n\",\n    \"        elif isinstance(module, nn.LayerNorm):\\n\",\n    \"            module.bias.data.zero_()\\n\",\n    \"            module.weight.data.fill_(1.0)\\n\",\n    \"        if isinstance(module, nn.Linear) and module.bias is not None:\\n\",\n    \"            module.bias.data.zero_()\\n\",\n    \"\\n\",\n    \"    def forward(self, \\n\",\n    \"                input_embs, \\n\",\n    \"                attention_mask, \\n\",\n    \"                pooler_index=0):\\n\",\n    \"        #input_embs: batch_size, seq_len, emb_dim\\n\",\n    \"        #attention_mask: batch_size, seq_len\\n\",\n    \"\\n\",\n    \"        extended_attention_mask = attention_mask.unsqueeze(1)\\n\",\n    \"        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility\\n\",\n    \"        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0\\n\",\n    \"\\n\",\n    \"        batch_size, seq_length, emb_dim = input_embs.shape\\n\",\n    \"        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_embs.device)\\n\",\n    \"        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)\\n\",\n    \"        position_embeddings = self.position_embeddings(position_ids)\\n\",\n    \"\\n\",\n    \"        embeddings = input_embs + position_embeddings\\n\",\n    \"        \\n\",\n    \"        embeddings = self.LayerNorm(embeddings)\\n\",\n    \"        embeddings = self.dropout(embeddings)\\n\",\n    \"        #print(embeddings.size())\\n\",\n    \"        all_hidden_states = [embeddings]\\n\",\n    \"\\n\",\n    \"        for i, layer_module in enumerate(self.encoders):\\n\",\n    \"            layer_outputs = layer_module(all_hidden_states[-1], extended_attention_mask)\\n\",\n    \"            all_hidden_states.append(layer_outputs)\\n\",\n    \"        assert len(self.poolers) > pooler_index\\n\",\n    \"        output = 
self.poolers[pooler_index](all_hidden_states[-1], attention_mask)\\n\",\n    \"\\n\",\n    \"        return output \\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 106,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"    \\n\",\n    \"class Model(torch.nn.Module):\\n\",\n    \"\\n\",\n    \"    def __init__(self,config):\\n\",\n    \"        super(Model, self).__init__()\\n\",\n    \"        self.config = config\\n\",\n    \"        self.dense_linear = nn.Linear(config.hidden_size,4)\\n\",\n    \"        self.word_embedding = nn.Embedding(len(word_dict),256,padding_idx=0)\\n\",\n    \"        self.fastformer_model = FastformerEncoder(config)\\n\",\n    \"        self.criterion = nn.CrossEntropyLoss() \\n\",\n    \"        self.apply(self.init_weights)\\n\",\n    \"        \\n\",\n    \"    def init_weights(self, module):\\n\",\n    \"        if isinstance(module, (nn.Linear, nn.Embedding)):\\n\",\n    \"            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\\n\",\n    \"            if isinstance(module, (nn.Embedding)) and module.padding_idx is not None:\\n\",\n    \"                with torch.no_grad():\\n\",\n    \"                    module.weight[module.padding_idx].fill_(0)\\n\",\n    \"        if isinstance(module, nn.Linear) and module.bias is not None:\\n\",\n    \"            module.bias.data.zero_()\\n\",\n    \"    \\n\",\n    \"    def forward(self,input_ids,targets):\\n\",\n    \"        mask=input_ids.bool().float()\\n\",\n    \"        embds=self.word_embedding(input_ids)\\n\",\n    \"        text_vec = self.fastformer_model(embds,mask)\\n\",\n    \"        score = self.dense_linear(text_vec)\\n\",\n    \"        loss = self.criterion(score, targets) \\n\",\n    \"        return loss, score\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 107,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"def acc(y_true, y_hat):\\n\",\n    \"    
y_hat = torch.argmax(y_hat, dim=-1)\\n\",\n    \"    tot = y_true.shape[0]\\n\",\n    \"    hit = torch.sum(y_true == y_hat)\\n\",\n    \"    return hit.data.float() * 1.0 / tot\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 111,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"Model(\\n\",\n       \"  (dense_linear): Linear(in_features=256, out_features=4, bias=True)\\n\",\n       \"  (word_embedding): Embedding(66818, 256, padding_idx=0)\\n\",\n       \"  (fastformer_model): FastformerEncoder(\\n\",\n       \"    (encoders): ModuleList(\\n\",\n       \"      (0): FastformerLayer(\\n\",\n       \"        (attention): FastAttention(\\n\",\n       \"          (self): FastSelfAttention(\\n\",\n       \"            (query): Linear(in_features=256, out_features=256, bias=True)\\n\",\n       \"            (query_att): Linear(in_features=256, out_features=16, bias=True)\\n\",\n       \"            (key): Linear(in_features=256, out_features=256, bias=True)\\n\",\n       \"            (key_att): Linear(in_features=256, out_features=16, bias=True)\\n\",\n       \"            (transform): Linear(in_features=256, out_features=256, bias=True)\\n\",\n       \"            (softmax): Softmax(dim=-1)\\n\",\n       \"          )\\n\",\n       \"          (output): BertSelfOutput(\\n\",\n       \"            (dense): Linear(in_features=256, out_features=256, bias=True)\\n\",\n       \"            (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\\n\",\n       \"            (dropout): Dropout(p=0.2, inplace=False)\\n\",\n       \"          )\\n\",\n       \"        )\\n\",\n       \"        (intermediate): BertIntermediate(\\n\",\n       \"          (dense): Linear(in_features=256, out_features=256, bias=True)\\n\",\n       \"        )\\n\",\n       \"        (output): BertOutput(\\n\",\n       \"          (dense): Linear(in_features=256, out_features=256, bias=True)\\n\",\n       \"       
   (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\\n\",\n       \"          (dropout): Dropout(p=0.2, inplace=False)\\n\",\n       \"        )\\n\",\n       \"      )\\n\",\n       \"      (1): FastformerLayer(\\n\",\n       \"        (attention): FastAttention(\\n\",\n       \"          (self): FastSelfAttention(\\n\",\n       \"            (query): Linear(in_features=256, out_features=256, bias=True)\\n\",\n       \"            (query_att): Linear(in_features=256, out_features=16, bias=True)\\n\",\n       \"            (key): Linear(in_features=256, out_features=256, bias=True)\\n\",\n       \"            (key_att): Linear(in_features=256, out_features=16, bias=True)\\n\",\n       \"            (transform): Linear(in_features=256, out_features=256, bias=True)\\n\",\n       \"            (softmax): Softmax(dim=-1)\\n\",\n       \"          )\\n\",\n       \"          (output): BertSelfOutput(\\n\",\n       \"            (dense): Linear(in_features=256, out_features=256, bias=True)\\n\",\n       \"            (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\\n\",\n       \"            (dropout): Dropout(p=0.2, inplace=False)\\n\",\n       \"          )\\n\",\n       \"        )\\n\",\n       \"        (intermediate): BertIntermediate(\\n\",\n       \"          (dense): Linear(in_features=256, out_features=256, bias=True)\\n\",\n       \"        )\\n\",\n       \"        (output): BertOutput(\\n\",\n       \"          (dense): Linear(in_features=256, out_features=256, bias=True)\\n\",\n       \"          (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\\n\",\n       \"          (dropout): Dropout(p=0.2, inplace=False)\\n\",\n       \"        )\\n\",\n       \"      )\\n\",\n       \"    )\\n\",\n       \"    (position_embeddings): Embedding(256, 256)\\n\",\n       \"    (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\\n\",\n       \"    (dropout): Dropout(p=0.2, inplace=False)\\n\",\n       
\"    (poolers): ModuleList(\\n\",\n       \"      (0): AttentionPooling(\\n\",\n       \"        (att_fc1): Linear(in_features=256, out_features=256, bias=True)\\n\",\n       \"        (att_fc2): Linear(in_features=256, out_features=1, bias=True)\\n\",\n       \"      )\\n\",\n       \"    )\\n\",\n       \"  )\\n\",\n       \"  (criterion): CrossEntropyLoss()\\n\",\n       \")\"\n      ]\n     },\n     \"execution_count\": 111,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"model = Model(config)\\n\",\n    \"import torch.optim as optim\\n\",\n    \"optimizer = optim.Adam([ {'params': model.parameters(), 'lr': 1e-3}])\\n\",\n    \"model.cuda()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 112,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \" Ed: 0, train_loss: 1.38936, acc: 0.20312\\n\",\n      \" Ed: 6400, train_loss: 0.59932, acc: 0.77135\\n\",\n      \" Ed: 12800, train_loss: 0.47238, acc: 0.82774\\n\",\n      \" Ed: 19200, train_loss: 0.42074, acc: 0.84988\\n\",\n      \" Ed: 25600, train_loss: 0.39129, acc: 0.86234\\n\",\n      \" Ed: 32000, train_loss: 0.37110, acc: 0.87004\\n\",\n      \" Ed: 38400, train_loss: 0.35466, acc: 0.87656\\n\",\n      \" Ed: 44800, train_loss: 0.34215, acc: 0.88137\\n\",\n      \" Ed: 51200, train_loss: 0.33404, acc: 0.88442\\n\",\n      \" Ed: 57600, train_loss: 0.32441, acc: 0.88820\\n\",\n      \" Ed: 64000, train_loss: 0.31654, acc: 0.89111\\n\",\n      \" Ed: 70400, train_loss: 0.31033, acc: 0.89336\\n\",\n      \" Ed: 76800, train_loss: 0.30431, acc: 0.89562\\n\",\n      \" Ed: 83200, train_loss: 0.29853, acc: 0.89778\\n\",\n      \" Ed: 89600, train_loss: 0.29410, acc: 0.89949\\n\",\n      \" Ed: 96000, train_loss: 0.29110, acc: 0.90074\\n\",\n      \" Ed: 102400, train_loss: 0.28680, acc: 0.90223\\n\",\n      \" Ed: 108800, train_loss: 0.28471, acc: 
0.90309\\n\",\n      \" Ed: 115200, train_loss: 0.28077, acc: 0.90427\\n\",\n      \"0.9235526315789474\\n\",\n      \" Ed: 0, train_loss: 0.16773, acc: 0.95312\\n\",\n      \" Ed: 6400, train_loss: 0.18398, acc: 0.93673\\n\",\n      \" Ed: 12800, train_loss: 0.17103, acc: 0.94038\\n\",\n      \" Ed: 19200, train_loss: 0.16540, acc: 0.94285\\n\",\n      \" Ed: 25600, train_loss: 0.16188, acc: 0.94389\\n\",\n      \" Ed: 32000, train_loss: 0.15936, acc: 0.94414\\n\",\n      \" Ed: 38400, train_loss: 0.15465, acc: 0.94572\\n\",\n      \" Ed: 44800, train_loss: 0.15194, acc: 0.94655\\n\",\n      \" Ed: 51200, train_loss: 0.15035, acc: 0.94737\\n\",\n      \" Ed: 57600, train_loss: 0.14845, acc: 0.94827\\n\",\n      \" Ed: 64000, train_loss: 0.14617, acc: 0.94897\\n\",\n      \" Ed: 70400, train_loss: 0.14536, acc: 0.94941\\n\",\n      \" Ed: 76800, train_loss: 0.14417, acc: 0.94992\\n\",\n      \" Ed: 83200, train_loss: 0.14298, acc: 0.95030\\n\",\n      \" Ed: 89600, train_loss: 0.14216, acc: 0.95063\\n\",\n      \" Ed: 96000, train_loss: 0.14231, acc: 0.95071\\n\",\n      \" Ed: 102400, train_loss: 0.14160, acc: 0.95107\\n\",\n      \" Ed: 108800, train_loss: 0.14129, acc: 0.95128\\n\",\n      \" Ed: 115200, train_loss: 0.13996, acc: 0.95181\\n\",\n      \"0.9271052631578948\\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"for epoch in range(2):\\n\",\n    \"    loss = 0.0\\n\",\n    \"    accuary = 0.0\\n\",\n    \"    for cnt in range(len(train_index)//64):\\n\",\n    \"\\n\",\n    \"        log_ids=news_words[train_index][cnt*64:cnt*64+64,:256]\\n\",\n    \"        targets= label[train_index][cnt*64:cnt*64+64]\\n\",\n    \"\\n\",\n    \"        log_ids = torch.LongTensor(log_ids).cuda(non_blocking=True)\\n\",\n    \"        targets = torch.LongTensor(targets).cuda(non_blocking=True)\\n\",\n    \"        bz_loss, y_hat = model(log_ids, targets)\\n\",\n    \"        loss += bz_loss.data.float()\\n\",\n    \"        accuary += acc(targets, y_hat)\\n\",\n    \"    
    unified_loss=bz_loss\\n\",\n    \"        optimizer.zero_grad()\\n\",\n    \"        unified_loss.backward()\\n\",\n    \"        optimizer.step()\\n\",\n    \"\\n\",\n    \"        if cnt % 100== 0:\\n\",\n    \"            print( ' Ed: {}, train_loss: {:.5f}, acc: {:.5f}'.format(cnt * 64, loss.data / (cnt+1), accuary / (cnt+1)))\\n\",\n    \"    # Evaluate on the test split after each epoch\\n\",\n    \"    model.eval()\\n\",\n    \"    allpred=[]\\n\",\n    \"    # Ceiling division: covers the last partial batch without producing an empty one\\n\",\n    \"    for cnt in range((len(test_index)+63)//64):\\n\",\n    \"    \\n\",\n    \"        log_ids=news_words[test_index][cnt*64:cnt*64+64,:256]\\n\",\n    \"        targets= label[test_index][cnt*64:cnt*64+64]\\n\",\n    \"        log_ids = torch.LongTensor(log_ids).cuda(non_blocking=True)\\n\",\n    \"        targets = torch.LongTensor(targets).cuda(non_blocking=True)\\n\",\n    \"    \\n\",\n    \"        bz_loss2, y_hat2 = model(log_ids, targets)\\n\",\n    \"        allpred+=y_hat2.to('cpu').detach().numpy().tolist()\\n\",\n    \"        \\n\",\n    \"    y_pred=np.argmax(allpred,axis=-1)\\n\",\n    \"    y_true=label[test_index]\\n\",\n    \"    from sklearn.metrics import accuracy_score\\n\",\n    \"    print(accuracy_score(y_true, y_pred))\\n\",\n    \"    model.train()\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.6.9\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "README.md",
    "content": "# Fastformer\n\n## Notes from the authors\n\nPyTorch/Keras implementation of Fastformer. The Keras version includes only the core Fastformer attention module. The PyTorch version is written in the style of Hugging Face Transformers. The Jupyter notebooks contain quickstart code for text classification on AG's News (without pretrained word embeddings, for simplicity) and can be run directly.\n\nIn our experiments, we found that NOT all tasks need the FFNN, residual connections, layer normalization, or even position embeddings. For example, in news recommendation it is better to use Fastformer directly, without layer normalization or position embeddings. In Ad CVR prediction, however, both position embeddings and layer normalization are needed.\n\nKeras version: 2.2.4 (may not be compatible with higher versions)\n\nTF version: 1.12 to 1.15 (may be compatible with lower versions)\n\nPyTorch version: 1.6.0 (may be compatible with higher/lower versions)\n\n## Citation\n```\n@article{wu2021fastformer,\n  title={Fastformer: Additive Attention Can Be All You Need},\n  author={Wu, Chuhan and Wu, Fangzhao and Qi, Tao and Huang, Yongfeng},\n  journal={arXiv preprint arXiv:2108.09084},\n  year={2021}\n}\n```\n"
  },
  {
    "path": "fastformer.json",
    "content": "{\n    \"hidden_size\":256,\n    \"hidden_dropout_prob\":0.2,\n    \"num_hidden_layers\":2,\n    \"hidden_act\":\"gelu\",\n    \"num_attention_heads\":16,\n    \"intermediate_size\":256,\n    \"max_position_embeddings\":256,\n    \"type_vocab_size\":2,\n    \"vocab_size\":100000,\n    \"layer_norm_eps\": 1e-12,\n    \"initializer_range\": 0.02,\n    \"pooler_type\": \"weightpooler\",\n    \"enable_fp16\": \"False\"\n}"
  }
]