Repository: wuch15/Fastformer
Branch: main
Commit: 84cc859f6dfb
Files: 4
Total size: 37.8 KB
Directory structure:
gitextract_awcr4gr3/
├── Fastformer-Keras.ipynb
├── Fastformer.ipynb
├── README.md
└── fastformer.json
================================================
FILE CONTENTS
================================================
================================================
FILE: Fastformer-Keras.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"from nltk.tokenize import wordpunct_tokenize\n",
"dataset = load_dataset('ag_news')"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"text=[]\n",
"label=[]\n",
"for row in dataset['train']['text']+dataset['test']['text']:\n",
" text.append(wordpunct_tokenize(row.lower()))\n",
"for row in dataset['train']['label']+dataset['test']['label']:\n",
" label.append(row)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"word_dict={'PADDING':0}\n",
"for sent in text: \n",
" for token in sent: \n",
" if token not in word_dict:\n",
" word_dict[token]=len(word_dict)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"MAX_SENT_LENGTH=256\n",
"\n",
"news_words = []\n",
"for sent in text: \n",
" sample=[]\n",
" for token in sent: \n",
" sample.append(word_dict[token])\n",
" sample = sample[:MAX_SENT_LENGTH]\n",
" news_words.append(sample+[0]*(MAX_SENT_LENGTH-len(sample)))\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"news_words=np.array(news_words,dtype='int32') \n",
"label=np.array(label,dtype='int32') "
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"index=np.arange(len(label))\n",
"train_index=index[:120000]\n",
"np.random.shuffle(train_index)\n",
"test_index=index[120000:]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n",
"/home/user/anaconda3/envs/wuch15/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:523: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
" _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
"/home/user/anaconda3/envs/wuch15/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:524: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
" _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
"/home/user/anaconda3/envs/wuch15/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
" _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
"/home/user/anaconda3/envs/wuch15/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:526: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
" _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
"/home/user/anaconda3/envs/wuch15/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:527: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
" _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
"/home/user/anaconda3/envs/wuch15/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:532: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
" np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n"
]
}
],
"source": [
"import os\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n",
"\n",
"from keras.utils.np_utils import to_categorical\n",
"from keras.layers import *\n",
"from keras.models import Model, load_model\n",
"from keras import backend as K\n",
"from sklearn.metrics import *\n",
"from keras.optimizers import *\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"news_words=np.array(news_words,dtype='int32') \n",
"label=np.array(label,dtype='int32')"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"index=np.arange(len(label))\n",
"train_index=index[:120000]\n",
"test_index=index[120000:]"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"\n",
"class Fastformer(Layer):\n",
"    \"\"\"Fastformer additive attention (core attention only; Keras 2.2.x on TF 1.x).\n",
"\n",
"    Inputs: [Q_seq, K_seq] or [Q_seq, K_seq, Q_mask, K_mask].\n",
"      Q_seq, K_seq: (batch, seq_len, dim) tensors. seq_len must be static,\n",
"        because call() reads it from the shapes recorded in build().\n",
"      Q_mask, K_mask: (batch, seq_len) float masks, 1 for positions to attend\n",
"        to and 0 for padding.\n",
"    Output: (batch, query_seq_len, nb_head * size_per_head). The projected\n",
"    query also serves as the value (Q = V) and as the residual branch.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, nb_head, size_per_head, **kwargs):\n",
"        self.nb_head = nb_head\n",
"        self.size_per_head = size_per_head\n",
"        self.output_dim = nb_head*size_per_head\n",
"        self.now_input_shape=None\n",
"        super(Fastformer, self).__init__(**kwargs)\n",
"\n",
"    def build(self, input_shape):\n",
"        # Keep the input shapes: call() needs the static sequence lengths.\n",
"        self.now_input_shape=input_shape\n",
"        self.WQ = self.add_weight(name='WQ',\n",
"                                  shape=(input_shape[0][-1], self.output_dim),\n",
"                                  initializer='glorot_uniform',\n",
"                                  trainable=True)\n",
"        self.WK = self.add_weight(name='WK',\n",
"                                  shape=(input_shape[1][-1], self.output_dim),\n",
"                                  initializer='glorot_uniform',\n",
"                                  trainable=True)\n",
"        # One additive-attention scoring column per head, for the query and key passes.\n",
"        self.Wq = self.add_weight(name='Wq',\n",
"                                  shape=(self.output_dim, self.nb_head),\n",
"                                  initializer='glorot_uniform',\n",
"                                  trainable=True)\n",
"        self.Wk = self.add_weight(name='Wk',\n",
"                                  shape=(self.output_dim, self.nb_head),\n",
"                                  initializer='glorot_uniform',\n",
"                                  trainable=True)\n",
"        # Output transform applied before the residual connection.\n",
"        self.WP = self.add_weight(name='WP',\n",
"                                  shape=(self.output_dim, self.output_dim),\n",
"                                  initializer='glorot_uniform',\n",
"                                  trainable=True)\n",
"        super(Fastformer, self).build(input_shape)\n",
"\n",
"    def call(self, x):\n",
"        if len(x) == 2:\n",
"            Q_seq, K_seq = x\n",
"            Q_mask = K_mask = None\n",
"        elif len(x) == 4:\n",
"            Q_seq, K_seq, Q_mask, K_mask = x  # different mask lengths, reserved for cross attention\n",
"        else:\n",
"            raise ValueError('Fastformer expects 2 or 4 inputs, got %d' % len(x))\n",
"\n",
"        q_len = self.now_input_shape[0][1]\n",
"        k_len = self.now_input_shape[1][1]\n",
"        scale = self.size_per_head**0.5\n",
"\n",
"        # Projected queries, (batch, q_len, output_dim); the reshape pins the static shape.\n",
"        Q_proj = K.reshape(K.dot(Q_seq, self.WQ), (-1, q_len, self.output_dim))\n",
"\n",
"        # Query attention, (batch, nb_head, q_len); padded positions get a -1e8 bias.\n",
"        Q_att = K.permute_dimensions(K.dot(Q_proj, self.Wq), (0, 2, 1)) / scale\n",
"        if Q_mask is not None:\n",
"            Q_att = Q_att - (1 - K.expand_dims(Q_mask, axis=1)) * 1e8\n",
"        Q_att = K.softmax(Q_att)\n",
"\n",
"        # Split heads: (batch, nb_head, seq_len, size_per_head).\n",
"        Q_heads = K.permute_dimensions(\n",
"            K.reshape(Q_proj, (-1, q_len, self.nb_head, self.size_per_head)), (0, 2, 1, 3))\n",
"        K_heads = K.permute_dimensions(\n",
"            K.reshape(K.dot(K_seq, self.WK), (-1, k_len, self.nb_head, self.size_per_head)), (0, 2, 1, 3))\n",
"\n",
"        # Global query per head, (batch, nb_head, size_per_head), mixed into every key.\n",
"        global_q = K.sum(K.expand_dims(Q_att, axis=3) * Q_heads, axis=2)\n",
"        QK_interaction = K_heads * K.expand_dims(global_q, axis=2)\n",
"\n",
"        # Back to (batch, k_len, output_dim) for scoring. The head axis has to move\n",
"        # behind the sequence axis first: reshaping (batch, nb_head, k_len, size_per_head)\n",
"        # directly would interleave positions across heads. The length here is the key length.\n",
"        QK_interaction_flat = K.reshape(\n",
"            K.permute_dimensions(QK_interaction, (0, 2, 1, 3)), (-1, k_len, self.output_dim))\n",
"        K_att = K.permute_dimensions(K.dot(QK_interaction_flat, self.Wk), (0, 2, 1)) / scale\n",
"        if K_mask is not None:\n",
"            K_att = K_att - (1 - K.expand_dims(K_mask, axis=1)) * 1e8\n",
"        K_att = K.softmax(K_att)\n",
"\n",
"        # Global key per head, (batch, nb_head, size_per_head).\n",
"        global_k = K.sum(K.expand_dims(K_att, axis=3) * QK_interaction, axis=2)\n",
"\n",
"        # Q = V: modulate every query position with the global key, then merge heads.\n",
"        QKQ_interaction = Q_heads * K.expand_dims(global_k, axis=2)\n",
"        QKQ_interaction = K.reshape(\n",
"            K.permute_dimensions(QKQ_interaction, (0, 2, 1, 3)), (-1, q_len, self.output_dim))\n",
"\n",
"        # Output transform plus residual connection to the projected query.\n",
"        return K.dot(QKQ_interaction, self.WP) + Q_proj\n",
"\n",
"    def compute_output_shape(self, input_shape):\n",
"        return (input_shape[0][0], input_shape[0][1], self.output_dim)\n",
"\n",
"    def get_config(self):\n",
"        # Needed to save/reload models containing this layer (pass the class via custom_objects).\n",
"        config = {'nb_head': self.nb_head, 'size_per_head': self.size_per_head}\n",
"        base_config = super(Fastformer, self).get_config()\n",
"        return dict(list(base_config.items()) + list(config.items()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# The notebook imports `from keras import backend as K`; the bare name `keras` is never bound.\n",
"K.clear_session()\n",
"\n",
"text_input = Input(shape=(MAX_SENT_LENGTH,), dtype='int32')\n",
"# 1.0 for real tokens, 0.0 for PADDING (word id 0).\n",
"qmask=Lambda(lambda x: K.cast(K.cast(x,'bool'),'float32'))(text_input)\n",
"word_emb = Embedding(len(word_dict),256, trainable=True)(text_input)\n",
"\n",
"#pos_emb = Embedding(MAX_SENT_LENGTH, 256, trainable=True)(Lambda(lambda x:K.zeros_like(x,dtype='int32')+K.arange(x.shape[1]))(text_input))\n",
"#word_emb =add([word_emb ,pos_emb])\n",
"#We find that position embedding is not important on this dataset and we removed it for simplicity. If needed, please uncomment the two lines above\n",
"\n",
"word_emb=Dropout(0.2)(word_emb)\n",
"\n",
"# Two Fastformer blocks, each: attention -> dropout -> residual add -> layer norm.\n",
"hidden_word_emb = Fastformer(16,16)([word_emb,word_emb,qmask,qmask])\n",
"hidden_word_emb = Dropout(0.2)(hidden_word_emb)\n",
"hidden_word_emb = LayerNormalization()(add([word_emb,hidden_word_emb]))\n",
"#if there is no layer norm in old version, please import an external layernorm class from a higher version.\n",
"\n",
"hidden_word_emb_layer2 = Fastformer(16,16)([hidden_word_emb,hidden_word_emb,qmask,qmask])\n",
"hidden_word_emb_layer2 = Dropout(0.2)(hidden_word_emb_layer2)\n",
"hidden_word_emb_layer2 = LayerNormalization()(add([hidden_word_emb,hidden_word_emb_layer2]))\n",
"\n",
"#without FFNN for simplicity\n",
"\n",
"# Attention pooling over tokens -> text vector -> 4-way softmax classifier.\n",
"word_att = Flatten()(Dense(1)(hidden_word_emb_layer2))\n",
"word_att = Activation('softmax')(word_att)\n",
"text_emb = Dot((1, 1))([hidden_word_emb_layer2 , word_att])\n",
"classifier = Dense(4, activation='softmax')(text_emb)\n",
"\n",
"model = Model([text_input], [classifier])\n",
"model.compile(loss=['categorical_crossentropy'],optimizer=Adam(lr=0.001), metrics=['acc'])\n",
"\n",
"for i in range(1):\n",
"    model.fit(news_words[train_index],to_categorical(label)[train_index],shuffle=True,batch_size=64, epochs=1,verbose=1)\n",
"\n",
"    y_pred = model.predict([news_words[test_index] ], batch_size=128, verbose=1)\n",
"    y_pred = np.argmax(y_pred, axis=1)\n",
"    y_true = label[test_index]\n",
"    acc = accuracy_score(y_true, y_pred)\n",
"    macro_f1 = f1_score(y_true, y_pred, average='macro')\n",
"    print(acc)\n",
"    print(macro_f1)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: Fastformer.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"from nltk.tokenize import wordpunct_tokenize\n",
"dataset = load_dataset('ag_news')"
]
},
{
"cell_type": "code",
"execution_count": 119,
"metadata": {},
"outputs": [],
"source": [
"text=[]\n",
"label=[]\n",
"for row in dataset['train']['text']+dataset['test']['text']:\n",
" text.append(wordpunct_tokenize(row.lower()))\n",
"for row in dataset['train']['label']+dataset['test']['label']:\n",
" label.append(row)"
]
},
{
"cell_type": "code",
"execution_count": 120,
"metadata": {},
"outputs": [],
"source": [
"word_dict={'PADDING':0}\n",
"for sent in text: \n",
" for token in sent: \n",
" if token not in word_dict:\n",
" word_dict[token]=len(word_dict)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"news_words = []\n",
"for sent in text: \n",
" sample=[]\n",
" for token in sent: \n",
" sample.append(word_dict[token])\n",
" sample = sample[:256]\n",
" news_words.append(sample+[0]*(256-len(sample)))\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"news_words=np.array(news_words,dtype='int32') \n",
"label=np.array(label,dtype='int32') "
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [],
"source": [
"index=np.arange(len(label))\n",
"train_index=index[:120000]\n",
"np.random.shuffle(train_index)\n",
"test_index=index[120000:]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import logging\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n"
]
},
{
"cell_type": "code",
"execution_count": 99,
"metadata": {},
"outputs": [],
"source": [
"from transformers import BertConfig\n",
"from transformers.models.bert.modeling_bert import BertSelfOutput, BertIntermediate, BertOutput\n",
"config=BertConfig.from_json_file('fastformer.json')\n"
]
},
{
"cell_type": "code",
"execution_count": 100,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"class AttentionPooling(nn.Module):\n",
"    \"\"\"Additive attention pooling: (batch, seq_len, hidden_size) -> (batch, hidden_size).\"\"\"\n",
"    def __init__(self, config):\n",
"        # nn.Module.__init__ must run before any attribute is assigned on the module.\n",
"        super(AttentionPooling, self).__init__()\n",
"        self.config = config\n",
"        self.att_fc1 = nn.Linear(config.hidden_size, config.hidden_size)\n",
"        self.att_fc2 = nn.Linear(config.hidden_size, 1)\n",
"        self.apply(self.init_weights)\n",
"\n",
"    def init_weights(self, module):\n",
"        if isinstance(module, nn.Linear):\n",
"            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n",
"        if isinstance(module, nn.Linear) and module.bias is not None:\n",
"            module.bias.data.zero_()\n",
"\n",
"    def forward(self, x, attn_mask=None):\n",
"        # x: (batch, seq_len, hidden_size); attn_mask: optional (batch, seq_len), 0 drops a position\n",
"        bz = x.shape[0]\n",
"        e = torch.tanh(self.att_fc1(x))\n",
"        # unnormalised weights, (batch, seq_len, 1)\n",
"        alpha = torch.exp(self.att_fc2(e))\n",
"        if attn_mask is not None:\n",
"            alpha = alpha * attn_mask.unsqueeze(2)\n",
"        # normalise over the sequence; the epsilon guards fully-masked rows\n",
"        alpha = alpha / (torch.sum(alpha, dim=1, keepdim=True) + 1e-8)\n",
"        # weighted sum: (batch, hidden_size, 1) -> (batch, hidden_size)\n",
"        x = torch.bmm(x.permute(0, 2, 1), alpha)\n",
"        x = torch.reshape(x, (bz, -1))\n",
"        return x\n",
"\n",
"class FastSelfAttention(nn.Module):\n",
"    \"\"\"Multi-head Fastformer additive self-attention; the query projection doubles as the value.\"\"\"\n",
"    def __init__(self, config):\n",
"        super(FastSelfAttention, self).__init__()\n",
"        self.config = config\n",
"        if config.hidden_size % config.num_attention_heads != 0:\n",
"            raise ValueError(\n",
"                \"The hidden size (%d) is not a multiple of the number of attention \"\n",
"                \"heads (%d)\" %\n",
"                (config.hidden_size, config.num_attention_heads))\n",
"        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)\n",
"        self.num_attention_heads = config.num_attention_heads\n",
"        self.all_head_size = self.num_attention_heads * self.attention_head_size\n",
"        self.input_dim = config.hidden_size\n",
"\n",
"        self.query = nn.Linear(self.input_dim, self.all_head_size)\n",
"        # one additive-attention score per head\n",
"        self.query_att = nn.Linear(self.all_head_size, self.num_attention_heads)\n",
"        self.key = nn.Linear(self.input_dim, self.all_head_size)\n",
"        self.key_att = nn.Linear(self.all_head_size, self.num_attention_heads)\n",
"        self.transform = nn.Linear(self.all_head_size, self.all_head_size)\n",
"\n",
"        self.softmax = nn.Softmax(dim=-1)\n",
"\n",
"        self.apply(self.init_weights)\n",
"\n",
"    def init_weights(self, module):\n",
"        if isinstance(module, nn.Linear):\n",
"            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n",
"        if isinstance(module, nn.Linear) and module.bias is not None:\n",
"            module.bias.data.zero_()\n",
"\n",
"    def transpose_for_scores(self, x):\n",
"        # (batch_size, seq_len, all_head_size) -> (batch_size, num_head, seq_len, head_dim)\n",
"        new_x_shape = x.size()[:-1] + (self.num_attention_heads,\n",
"                                       self.attention_head_size)\n",
"        x = x.view(*new_x_shape)\n",
"        return x.permute(0, 2, 1, 3)\n",
"\n",
"    def forward(self, hidden_states, attention_mask):\n",
"        # hidden_states: (batch_size, seq_len, hidden_size)\n",
"        # attention_mask: additive mask that must broadcast to (batch_size, num_head, seq_len),\n",
"        #   e.g. the (batch_size, 1, seq_len) mask built by FastformerEncoder\n",
"        batch_size, seq_len, _ = hidden_states.shape\n",
"        # both (batch_size, seq_len, all_head_size)\n",
"        mixed_query_layer = self.query(hidden_states)\n",
"        mixed_key_layer = self.key(hidden_states)\n",
"\n",
"        # (batch_size, num_head, seq_len)\n",
"        query_for_score = self.query_att(mixed_query_layer).transpose(1, 2) / self.attention_head_size**0.5\n",
"        query_for_score = query_for_score + attention_mask\n",
"        # (batch_size, num_head, 1, seq_len)\n",
"        query_weight = self.softmax(query_for_score).unsqueeze(2)\n",
"\n",
"        # (batch_size, num_head, seq_len, head_dim)\n",
"        query_layer = self.transpose_for_scores(mixed_query_layer)\n",
"\n",
"        # global query: (batch_size, num_head, 1, head_dim) -> (batch_size, 1, all_head_size)\n",
"        pooled_query = torch.matmul(query_weight, query_layer).transpose(1, 2).view(-1, 1, self.num_attention_heads*self.attention_head_size)\n",
"        # (batch_size, seq_len, all_head_size)\n",
"        pooled_query_repeat = pooled_query.repeat(1, seq_len, 1)\n",
"        mixed_query_key_layer = mixed_key_layer * pooled_query_repeat\n",
"\n",
"        # (batch_size, num_head, seq_len)\n",
"        query_key_score = (self.key_att(mixed_query_key_layer) / self.attention_head_size**0.5).transpose(1, 2)\n",
"        query_key_score = query_key_score + attention_mask\n",
"        # (batch_size, num_head, 1, seq_len)\n",
"        query_key_weight = self.softmax(query_key_score).unsqueeze(2)\n",
"\n",
"        # (batch_size, num_head, seq_len, head_dim)\n",
"        key_layer = self.transpose_for_scores(mixed_query_key_layer)\n",
"        # global key: (batch_size, num_head, 1, head_dim)\n",
"        pooled_key = torch.matmul(query_key_weight, key_layer)\n",
"\n",
"        # query = value: (batch_size, seq_len, num_head, head_dim)\n",
"        weighted_value = (pooled_key * query_layer).transpose(1, 2)\n",
"        # (batch_size, seq_len, all_head_size)\n",
"        weighted_value = weighted_value.reshape(\n",
"            weighted_value.size()[:-2] + (self.num_attention_heads * self.attention_head_size,))\n",
"        # output transform plus residual to the projected query\n",
"        weighted_value = self.transform(weighted_value) + mixed_query_layer\n",
"\n",
"        return weighted_value"
]
},
{
"cell_type": "code",
"execution_count": 101,
"metadata": {},
"outputs": [],
"source": [
"\n",
"class FastAttention(nn.Module):\n",
"    \"\"\"FastSelfAttention followed by transformers' BertSelfOutput, which also receives the block input.\"\"\"\n",
"    def __init__(self, config):\n",
"        super(FastAttention, self).__init__()\n",
"        self.self = FastSelfAttention(config)\n",
"        self.output = BertSelfOutput(config)\n",
"\n",
"    def forward(self, input_tensor, attention_mask):\n",
"        self_output = self.self(input_tensor, attention_mask)\n",
"        attention_output = self.output(self_output, input_tensor)\n",
"        return attention_output\n",
"\n",
"class FastformerLayer(nn.Module):\n",
"    \"\"\"One encoder layer: FastAttention, then transformers' BertIntermediate and BertOutput.\"\"\"\n",
"    def __init__(self, config):\n",
"        super(FastformerLayer, self).__init__()\n",
"        self.attention = FastAttention(config)\n",
"        self.intermediate = BertIntermediate(config)\n",
"        self.output = BertOutput(config)\n",
"\n",
"    def forward(self, hidden_states, attention_mask):\n",
"        attention_output = self.attention(hidden_states, attention_mask)\n",
"        intermediate_output = self.intermediate(attention_output)\n",
"        layer_output = self.output(intermediate_output, attention_output)\n",
"        return layer_output\n",
"\n",
"class FastformerEncoder(nn.Module):\n",
"    \"\"\"Adds position embeddings, runs the FastformerLayer stack and pools each sequence.\n",
"\n",
"    forward() returns self.poolers[pooler_index] applied to the last layer's hidden states.\n",
"    \"\"\"\n",
"    def __init__(self, config, pooler_count=1):\n",
"        super(FastformerEncoder, self).__init__()\n",
"        self.config = config\n",
"        self.encoders = nn.ModuleList([FastformerLayer(config) for _ in range(config.num_hidden_layers)])\n",
"        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)\n",
"        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)\n",
"        self.dropout = nn.Dropout(config.hidden_dropout_prob)\n",
"\n",
"        # support multiple different poolers with shared bert encoder.\n",
"        # Only pooler_type == 'weightpooler' creates poolers; forward() needs at least one.\n",
"        self.poolers = nn.ModuleList()\n",
"        if config.pooler_type == 'weightpooler':\n",
"            for _ in range(pooler_count):\n",
"                self.poolers.append(AttentionPooling(config))\n",
"        logging.info(f\"This model has {len(self.poolers)} poolers.\")\n",
"\n",
"        self.apply(self.init_weights)\n",
"\n",
"    def init_weights(self, module):\n",
"        if isinstance(module, (nn.Linear, nn.Embedding)):\n",
"            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n",
"            if isinstance(module, (nn.Embedding)) and module.padding_idx is not None:\n",
"                with torch.no_grad():\n",
"                    module.weight[module.padding_idx].fill_(0)\n",
"        elif isinstance(module, nn.LayerNorm):\n",
"            module.bias.data.zero_()\n",
"            module.weight.data.fill_(1.0)\n",
"        if isinstance(module, nn.Linear) and module.bias is not None:\n",
"            module.bias.data.zero_()\n",
"\n",
"    def forward(self,\n",
"                input_embs,\n",
"                attention_mask,\n",
"                pooler_index=0):\n",
"        # input_embs: (batch_size, seq_len, emb_dim)\n",
"        # attention_mask: (batch_size, seq_len); 1.0 keeps a position, 0.0 masks it\n",
"        # Validate up front (an assert would be stripped under -O and only fired after the full pass).\n",
"        if pooler_index >= len(self.poolers):\n",
"            raise IndexError(f'pooler_index {pooler_index} is out of range: this encoder has {len(self.poolers)} pooler(s)')\n",
"\n",
"        # additive mask (batch_size, 1, seq_len): 0 for kept positions, -10000 for masked ones\n",
"        extended_attention_mask = attention_mask.unsqueeze(1)\n",
"        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility\n",
"        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0\n",
"\n",
"        batch_size, seq_length, emb_dim = input_embs.shape\n",
"        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_embs.device)\n",
"        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)\n",
"        position_embeddings = self.position_embeddings(position_ids)\n",
"\n",
"        embeddings = input_embs + position_embeddings\n",
"        embeddings = self.LayerNorm(embeddings)\n",
"        hidden_states = self.dropout(embeddings)\n",
"\n",
"        for layer_module in self.encoders:\n",
"            hidden_states = layer_module(hidden_states, extended_attention_mask)\n",
"\n",
"        # the pooler takes the raw 0/1 mask, not the additive one\n",
"        return self.poolers[pooler_index](hidden_states, attention_mask)\n"
]
},
{
"cell_type": "code",
"execution_count": 106,
"metadata": {},
"outputs": [],
"source": [
" \n",
"class Model(torch.nn.Module):\n",
"    \"\"\"Text classifier: word embeddings -> FastformerEncoder (pooled) -> 4-way linear layer.\"\"\"\n",
"\n",
"    def __init__(self,config):\n",
"        super(Model, self).__init__()\n",
"        self.config = config\n",
"        self.dense_linear = nn.Linear(config.hidden_size,4)\n",
"        # The embedding width must match the encoder width (it is added to position embeddings\n",
"        # of size config.hidden_size), so take it from config instead of hard-coding 256.\n",
"        # The vocabulary size comes from the notebook-level word_dict.\n",
"        self.word_embedding = nn.Embedding(len(word_dict), config.hidden_size, padding_idx=0)\n",
"        self.fastformer_model = FastformerEncoder(config)\n",
"        self.criterion = nn.CrossEntropyLoss()\n",
"        self.apply(self.init_weights)\n",
"\n",
"    def init_weights(self, module):\n",
"        if isinstance(module, (nn.Linear, nn.Embedding)):\n",
"            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)\n",
"        if isinstance(module, (nn.Embedding)) and module.padding_idx is not None:\n",
"            with torch.no_grad():\n",
"                module.weight[module.padding_idx].fill_(0)\n",
"        if isinstance(module, nn.Linear) and module.bias is not None:\n",
"            module.bias.data.zero_()\n",
"\n",
"    def forward(self,input_ids,targets):\n",
"        # 1.0 for real tokens, 0.0 for PADDING (id 0)\n",
"        mask=input_ids.bool().float()\n",
"        embds=self.word_embedding(input_ids)\n",
"        text_vec = self.fastformer_model(embds,mask)\n",
"        score = self.dense_linear(text_vec)\n",
"        loss = self.criterion(score, targets)\n",
"        return loss, score"
]
},
{
"cell_type": "code",
"execution_count": 107,
"metadata": {},
"outputs": [],
"source": [
"def acc(y_true, y_hat):\n",
"    \"\"\"Fraction of rows whose arg-max prediction in y_hat equals the label in y_true.\"\"\"\n",
"    predicted = torch.argmax(y_hat, dim=-1)\n",
"    n_total = y_true.shape[0]\n",
"    n_hit = (y_true == predicted).sum()\n",
"    return n_hit.data.float() * 1.0 / n_total"
]
},
{
"cell_type": "code",
"execution_count": 111,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Model(\n",
" (dense_linear): Linear(in_features=256, out_features=4, bias=True)\n",
" (word_embedding): Embedding(66818, 256, padding_idx=0)\n",
" (fastformer_model): FastformerEncoder(\n",
" (encoders): ModuleList(\n",
" (0): FastformerLayer(\n",
" (attention): FastAttention(\n",
" (self): FastSelfAttention(\n",
" (query): Linear(in_features=256, out_features=256, bias=True)\n",
" (query_att): Linear(in_features=256, out_features=16, bias=True)\n",
" (key): Linear(in_features=256, out_features=256, bias=True)\n",
" (key_att): Linear(in_features=256, out_features=16, bias=True)\n",
" (transform): Linear(in_features=256, out_features=256, bias=True)\n",
" (softmax): Softmax(dim=-1)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=256, out_features=256, bias=True)\n",
" (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.2, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=256, out_features=256, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=256, out_features=256, bias=True)\n",
" (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.2, inplace=False)\n",
" )\n",
" )\n",
" (1): FastformerLayer(\n",
" (attention): FastAttention(\n",
" (self): FastSelfAttention(\n",
" (query): Linear(in_features=256, out_features=256, bias=True)\n",
" (query_att): Linear(in_features=256, out_features=16, bias=True)\n",
" (key): Linear(in_features=256, out_features=256, bias=True)\n",
" (key_att): Linear(in_features=256, out_features=16, bias=True)\n",
" (transform): Linear(in_features=256, out_features=256, bias=True)\n",
" (softmax): Softmax(dim=-1)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=256, out_features=256, bias=True)\n",
" (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.2, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=256, out_features=256, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=256, out_features=256, bias=True)\n",
" (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.2, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" (position_embeddings): Embedding(256, 256)\n",
" (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.2, inplace=False)\n",
" (poolers): ModuleList(\n",
" (0): AttentionPooling(\n",
" (att_fc1): Linear(in_features=256, out_features=256, bias=True)\n",
" (att_fc2): Linear(in_features=256, out_features=1, bias=True)\n",
" )\n",
" )\n",
" )\n",
" (criterion): CrossEntropyLoss()\n",
")"
]
},
"execution_count": 111,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = Model(config)\n",
"import torch.optim as optim\n",
"optimizer = optim.Adam([ {'params': model.parameters(), 'lr': 1e-3}])\n",
"model.cuda()"
]
},
{
"cell_type": "code",
"execution_count": 112,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Ed: 0, train_loss: 1.38936, acc: 0.20312\n",
" Ed: 6400, train_loss: 0.59932, acc: 0.77135\n",
" Ed: 12800, train_loss: 0.47238, acc: 0.82774\n",
" Ed: 19200, train_loss: 0.42074, acc: 0.84988\n",
" Ed: 25600, train_loss: 0.39129, acc: 0.86234\n",
" Ed: 32000, train_loss: 0.37110, acc: 0.87004\n",
" Ed: 38400, train_loss: 0.35466, acc: 0.87656\n",
" Ed: 44800, train_loss: 0.34215, acc: 0.88137\n",
" Ed: 51200, train_loss: 0.33404, acc: 0.88442\n",
" Ed: 57600, train_loss: 0.32441, acc: 0.88820\n",
" Ed: 64000, train_loss: 0.31654, acc: 0.89111\n",
" Ed: 70400, train_loss: 0.31033, acc: 0.89336\n",
" Ed: 76800, train_loss: 0.30431, acc: 0.89562\n",
" Ed: 83200, train_loss: 0.29853, acc: 0.89778\n",
" Ed: 89600, train_loss: 0.29410, acc: 0.89949\n",
" Ed: 96000, train_loss: 0.29110, acc: 0.90074\n",
" Ed: 102400, train_loss: 0.28680, acc: 0.90223\n",
" Ed: 108800, train_loss: 0.28471, acc: 0.90309\n",
" Ed: 115200, train_loss: 0.28077, acc: 0.90427\n",
"0.9235526315789474\n",
" Ed: 0, train_loss: 0.16773, acc: 0.95312\n",
" Ed: 6400, train_loss: 0.18398, acc: 0.93673\n",
" Ed: 12800, train_loss: 0.17103, acc: 0.94038\n",
" Ed: 19200, train_loss: 0.16540, acc: 0.94285\n",
" Ed: 25600, train_loss: 0.16188, acc: 0.94389\n",
" Ed: 32000, train_loss: 0.15936, acc: 0.94414\n",
" Ed: 38400, train_loss: 0.15465, acc: 0.94572\n",
" Ed: 44800, train_loss: 0.15194, acc: 0.94655\n",
" Ed: 51200, train_loss: 0.15035, acc: 0.94737\n",
" Ed: 57600, train_loss: 0.14845, acc: 0.94827\n",
" Ed: 64000, train_loss: 0.14617, acc: 0.94897\n",
" Ed: 70400, train_loss: 0.14536, acc: 0.94941\n",
" Ed: 76800, train_loss: 0.14417, acc: 0.94992\n",
" Ed: 83200, train_loss: 0.14298, acc: 0.95030\n",
" Ed: 89600, train_loss: 0.14216, acc: 0.95063\n",
" Ed: 96000, train_loss: 0.14231, acc: 0.95071\n",
" Ed: 102400, train_loss: 0.14160, acc: 0.95107\n",
" Ed: 108800, train_loss: 0.14129, acc: 0.95128\n",
" Ed: 115200, train_loss: 0.13996, acc: 0.95181\n",
"0.9271052631578948\n"
]
}
],
"source": [
"from sklearn.metrics import accuracy_score\n",
"\n",
"batch_size = 64\n",
"for epoch in range(2):\n",
"    loss = 0.0\n",
"    accuracy = 0.0\n",
"    for cnt in range(len(train_index)//batch_size):\n",
"        # Slice the (pre-shuffled) index first and gather only this batch; indexing the whole\n",
"        # array with train_index would copy every training row on each step.\n",
"        batch_index = train_index[cnt*batch_size:(cnt+1)*batch_size]\n",
"        log_ids = torch.LongTensor(news_words[batch_index, :256]).cuda(non_blocking=True)\n",
"        targets = torch.LongTensor(label[batch_index]).cuda(non_blocking=True)\n",
"\n",
"        bz_loss, y_hat = model(log_ids, targets)\n",
"        loss += bz_loss.data.float()\n",
"        accuracy += acc(targets, y_hat)\n",
"        optimizer.zero_grad()\n",
"        bz_loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        if cnt % 100 == 0:\n",
"            print(' Ed: {}, train_loss: {:.5f}, acc: {:.5f}'.format(cnt * batch_size, loss.data / (cnt+1), accuracy / (cnt+1)))\n",
"\n",
"    # Evaluate without building autograd graphs; range(0, n, batch_size) never yields an empty batch.\n",
"    model.eval()\n",
"    allpred = []\n",
"    with torch.no_grad():\n",
"        for start in range(0, len(test_index), batch_size):\n",
"            batch_index = test_index[start:start+batch_size]\n",
"            log_ids = torch.LongTensor(news_words[batch_index, :256]).cuda(non_blocking=True)\n",
"            targets = torch.LongTensor(label[batch_index]).cuda(non_blocking=True)\n",
"            _, y_hat2 = model(log_ids, targets)\n",
"            allpred += y_hat2.to('cpu').numpy().tolist()\n",
"\n",
"    y_pred = np.argmax(allpred, axis=-1)\n",
"    y_true = label[test_index]\n",
"    print(accuracy_score(y_true, y_pred))\n",
"    model.train()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: README.md
================================================
# Fastformer
## Notes from the authors
PyTorch and Keras implementations of Fastformer. The Keras version includes only the core Fastformer attention. The PyTorch version is written in the Hugging Face Transformers style. The Jupyter notebooks contain quick-start code for text classification on AG's News (without pretrained word embeddings, for simplicity) and can be run directly. In our experiments we noticed that NOT all tasks need the FFNN, residual connections, layer normalization, or even position embeddings. For example, in news recommendation we find it better to use Fastformer directly, without layer normalization and position embeddings, whereas in Ad CVR prediction both position embeddings and layer normalization are needed.
- Keras version: 2.2.4 (may not be compatible with newer versions)
- TensorFlow version: 1.12 to 1.15 (may be compatible with older versions)
- PyTorch version: 1.6.0 (may be compatible with newer or older versions)
## Citation
```
@article{wu2021fastformer,
title={Fastformer: Additive Attention Can Be All You Need},
author={Wu, Chuhan and Wu, Fangzhao and Qi, Tao and Huang, Yongfeng},
journal={arXiv preprint arXiv:2108.09084},
year={2021}
}
```
================================================
FILE: fastformer.json
================================================
{
"hidden_size":256,
"hidden_dropout_prob":0.2,
"num_hidden_layers":2,
"hidden_act":"gelu",
"num_attention_heads":16,
"intermediate_size":256,
"max_position_embeddings":256,
"type_vocab_size":2,
"vocab_size":100000,
"layer_norm_eps": 1e-12,
"initializer_range": 0.02,
"pooler_type": "weightpooler",
"enable_fp16": "False"
}
gitextract_awcr4gr3/ ├── Fastformer-Keras.ipynb ├── Fastformer.ipynb ├── README.md └── fastformer.json
Condensed preview — 4 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (42K chars).
[
{
"path": "Fastformer-Keras.ipynb",
"chars": 13526,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"code\",\n \"execution_count\": 1,\n \"metadata\": {},\n \"outputs\": [],\n \"source\": [\n "
},
{
"path": "Fastformer.ipynb",
"chars": 23554,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"code\",\n \"execution_count\": null,\n \"metadata\": {},\n \"outputs\": [],\n \"source\": "
},
{
"path": "README.md",
"chars": 1206,
"preview": "# Fastformer\n\n## Notes from the authors\n\nPytorch/Keras implementation of Fastformer. The keras version only includes the"
},
{
"path": "fastformer.json",
"chars": 375,
"preview": "{\n \"hidden_size\":256,\n \"hidden_dropout_prob\":0.2,\n \"num_hidden_layers\":2,\n \"hidden_act\":\"gelu\",\n \"num_att"
}
]
About this extraction
This page contains the full source code of the wuch15/Fastformer GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 4 files (37.8 KB), approximately 11.5k tokens. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.