[
  {
    "path": "README.md",
    "content": "# STAN: Spatial-Temporal-Attention-Network-for-Next-Location-Recommendation\n[[Paper](https://arxiv.org/abs/2102.04095)]. [[Oral Youtube](https://www.youtube.com/watch?v=ajNzESvOvzs)] or [[Oral Bilibili](https://www.bilibili.com/video/BV1WL411E7Qm?from=search&seid=7472683569881802215)]. [[Implementation through LibCity](https://github.com/LibCity/Bigscity-LibCity)].\n\nThank you for your interest in our work! Thank you for reporting possible bugs and please make sure you are forking the latest repo to avoid eariler bugs. Before asking questions regarding the codes or the paper, I strongly recommend you to read the FAQ first. You can also use the [LibCity](https://github.com/LibCity/Bigscity-LibCity) version.\n\n## Description\nBecause of the huge memory of the location matrix, the running speed of STAN is extremely low. You can refer to the implementation of masked attention [[here](https://github.com/yingtaoluo/PyHealth/blob/master/pyhealth/models/sequence/dipole.py)] if you wish to rewrite your own codes. \n\nDivide the dataset into different proportions of users to test the performance and then average. \n\nRun \"load.py\" first and then \"train.py\". You should see on the screen the result of the first proportion:   \n100%|██████████| 100/100 [14:32<00:00,  8.72s/it]  \nepoch:27, time:23587.941201210022, valid_acc:[0.18 0.49 0.56 0.67]  \nepoch:27, time:23587.941201210022, test_acc:[0.15 0.46 0.59 0.67]\n\n## FAQs\nQ1: Can you provide a dataset?  \nA1: Our datasets are collected from the following links. 
Please feel free to do your own data processing for your model while using STAN as a baseline.\nhttp://snap.stanford.edu/data/loc-gowalla.html;  \nhttps://personal.ntu.edu.sg/gaocong/data/poidata.zip;\nhttp://www-public.imtbs-tsp.eu/~zhang_da/pub/dataset_tsmc2014.zip  \n  \nQ2.1: What does it mean \"The number of the training set is 𝑚 − 3, with the first 𝑚′ ∈ [1,𝑚 − 3] check-ins as input sequence and the [2,𝑚 − 2]-nd visited location as the label\"?  \nA2.1: We use [1] as input to predict [2], use [1,2] as input to predict [3], and so on, until we use [1,...,m-3] to predict [m-2]. Basically, we do not use the last few steps for training and reserve them as a simulation of \"future visits\" to test the model, since these last steps are never fed into the model during training.  \n  \nQ2.2: Can you please explain your trajectory encoding process? Do you create the location embeddings using skip-gram-like approaches?  \nA2.2: Pre-training the embeddings is an effective approach and could certainly improve performance further. However, the focus and contribution of this paper are not embedding pre-training but the spatio-temporal linear embedding, and pre-training is not used in the baselines, so we do not use it in our paper.\n\nQ2.3: Would it be better to construct 0-1 edges based on spatial distances instead of using the distances directly?  \nA2.3: If the edges can truly reflect the relations between each location and each user, then yes; an ideal 0-1 edge relation is a stronger representation. However, constructing edges merely from spatial distances can raise problems: a 30-kilometer metro ride can take less time than a 5-kilometer walk, yet from the data we only know distances.  \n\nQ2.4: What do you mean by setting a unit spatiotemporal embedding?  \nA2.4: ![image](https://github.com/yingtaoluo/Spatial-Temporal-Attention-Network-for-POI-Recommendation/blob/master/unit_embedding.png)\n\nQ2.5: What does each column/row in NYC.npy mean?  
\nA2.5: Each row: [user id, check-in location id, time in minutes].  \n\nQ2.6: Can we try a different division of train/dev/test datasets?  \nA2.6: Our goal here is to generalize to the future visits of each user the model has seen (we do not want to test the model on biased past behavior), rather than to generalize to other users whose user-id embeddings are unknown to the model. \n\nQ2.7: How is the value of the recall rate calculated in your paper? For example, the top5 probability of the NYC data set is 0.xx but in the paper it is 0.xxxx.  \nA2.7: It is common practice to run under different seeds and report the average value. We averaged the results of ten runs, and all of them pass the statistical significance test at p = 0.01. \n\nQ3: What environment (and which versions) do you use to run the code?  \nA3: We use Python 3.7.2, CUDA 10.1 and PyTorch 1.7.1. Make sure to install all the libraries that we import.  \r\n"
  },
  {
    "path": "layers.py",
    "content": "from load import *\r\nimport torch\r\nfrom torch import nn\r\nfrom torch.nn import functional as F\r\n\r\nseed = 0\r\nglobal_seed = 0\r\nhours = 24*7\r\ntorch.manual_seed(seed)\r\ndevice = 'cuda'\r\n\r\n\r\ndef to_npy(x):\r\n    return x.cpu().data.numpy() if device == 'cuda' else x.detach().numpy()\r\n\r\n\r\nclass Attn(nn.Module):\r\n    def __init__(self, emb_loc, loc_max, dropout=0.1):\r\n        super(Attn, self).__init__()\r\n        self.value = nn.Linear(max_len, 1, bias=False)\r\n        self.emb_loc = emb_loc\r\n        self.loc_max = loc_max\r\n\r\n    def forward(self, self_attn, self_delta, traj_len):\r\n        # self_attn (N, M, emb), candidate (N, L, emb), self_delta (N, M, L, emb), len [N]\r\n        self_delta = torch.sum(self_delta, -1).transpose(-1, -2)  # squeeze the embed dimension\r\n        [N, L, M] = self_delta.shape\r\n        candidates = torch.linspace(1, int(self.loc_max), int(self.loc_max)).long()  # (L)\r\n        candidates = candidates.unsqueeze(0).expand(N, -1).to(device)  # (N, L)\r\n        emb_candidates = self.emb_loc(candidates)  # (N, L, emb)\r\n        attn = torch.mul(torch.bmm(emb_candidates, self_attn.transpose(-1, -2)), self_delta)  # (N, L, M)\r\n        # pdb.set_trace()\r\n        attn_out = self.value(attn).view(N, L)  # (N, L)\r\n        # attn_out = F.log_softmax(attn_out, dim=-1)  # ignore if cross_entropy_loss\r\n\r\n        return attn_out  # (N, L)\r\n\r\n\r\nclass SelfAttn(nn.Module):\r\n    def __init__(self, emb_size, output_size, dropout=0.1):\r\n        super(SelfAttn, self).__init__()\r\n        self.query = nn.Linear(emb_size, output_size, bias=False)\r\n        self.key = nn.Linear(emb_size, output_size, bias=False)\r\n        self.value = nn.Linear(emb_size, output_size, bias=False)\r\n\r\n    def forward(self, joint, delta, traj_len):\r\n        delta = torch.sum(delta, -1)  # squeeze the embed dimension\r\n        # joint (N, M, emb), delta (N, M, M, emb), len [N]\r\n        # 
construct attention mask\r\n        mask = torch.zeros_like(delta, dtype=torch.float32)\r\n        for i in range(mask.shape[0]):\r\n            mask[i, 0:traj_len[i], 0:traj_len[i]] = 1\r\n\r\n        attn = torch.add(torch.bmm(self.query(joint), self.key(joint).transpose(-1, -2)), delta)  # (N, M, M)\r\n        attn = F.softmax(attn, dim=-1) * mask  # (N, M, M)\r\n\r\n        attn_out = torch.bmm(attn, self.value(joint))  # (N, M, emb)\r\n\r\n        return attn_out  # (N, M, emb)\r\n\r\n\r\nclass Embed(nn.Module):\r\n    def __init__(self, ex, emb_size, loc_max, embed_layers):\r\n        super(Embed, self).__init__()\r\n        _, _, _, self.emb_su, self.emb_sl, self.emb_tu, self.emb_tl = embed_layers\r\n        self.su, self.sl, self.tu, self.tl = ex\r\n        self.emb_size = emb_size\r\n        self.loc_max = loc_max\r\n\r\n    def forward(self, traj_loc, mat2, vec, traj_len):\r\n        # traj_loc (N, M), mat2 (L, L), vec (N, M), delta_t (N, M, L)\r\n        delta_t = vec.unsqueeze(-1).expand(-1, -1, self.loc_max)\r\n        delta_s = torch.zeros_like(delta_t, dtype=torch.float32)\r\n        mask = torch.zeros_like(delta_t, dtype=torch.long)\r\n        for i in range(mask.shape[0]):  # N\r\n            mask[i, 0:traj_len[i]] = 1\r\n            delta_s[i, :traj_len[i]] = torch.index_select(mat2, 0, (traj_loc[i]-1)[:traj_len[i]])\r\n\r\n        # pdb.set_trace()\r\n\r\n        esl, esu, etl, etu = self.emb_sl(mask), self.emb_su(mask), self.emb_tl(mask), self.emb_tu(mask)\r\n        vsl, vsu, vtl, vtu = (delta_s - self.sl).unsqueeze(-1).expand(-1, -1, -1, self.emb_size), \\\r\n                             (self.su - delta_s).unsqueeze(-1).expand(-1, -1, -1, self.emb_size), \\\r\n                             (delta_t - self.tl).unsqueeze(-1).expand(-1, -1, -1, self.emb_size), \\\r\n                             (self.tu - delta_t).unsqueeze(-1).expand(-1, -1, -1, self.emb_size)\r\n\r\n        space_interval = (esl * vsu + esu * vsl) / (self.su - self.sl)\r\n      
  time_interval = (etl * vtu + etu * vtl) / (self.tu - self.tl)\r\n        delta = space_interval + time_interval  # (N, M, L, emb)\r\n\r\n        return delta\r\n\r\n\r\nclass MultiEmbed(nn.Module):\r\n    def __init__(self, ex, emb_size, embed_layers):\r\n        super(MultiEmbed, self).__init__()\r\n        self.emb_t, self.emb_l, self.emb_u, \\\r\n        self.emb_su, self.emb_sl, self.emb_tu, self.emb_tl = embed_layers\r\n        self.su, self.sl, self.tu, self.tl = ex\r\n        self.emb_size = emb_size\r\n\r\n    def forward(self, traj, mat, traj_len):\r\n        # traj (N, M, 3), mat (N, M, M, 2), len [N]\r\n        traj[:, :, 2] = (traj[:, :, 2]-1) % hours + 1  # segment time by 24 hours * 7 days\r\n        time = self.emb_t(traj[:, :, 2])  # (N, M) --> (N, M, embed)\r\n        loc = self.emb_l(traj[:, :, 1])  # (N, M) --> (N, M, embed)\r\n        user = self.emb_u(traj[:, :, 0])  # (N, M) --> (N, M, embed)\r\n        joint = time + loc + user  # (N, M, embed)\r\n\r\n        delta_s, delta_t = mat[:, :, :, 0], mat[:, :, :, 1]  # (N, M, M)\r\n        mask = torch.zeros_like(delta_s, dtype=torch.long)\r\n        for i in range(mask.shape[0]):\r\n            mask[i, 0:traj_len[i], 0:traj_len[i]] = 1\r\n\r\n        esl, esu, etl, etu = self.emb_sl(mask), self.emb_su(mask), self.emb_tl(mask), self.emb_tu(mask)\r\n        vsl, vsu, vtl, vtu = (delta_s - self.sl).unsqueeze(-1).expand(-1, -1, -1, self.emb_size), \\\r\n                             (self.su - delta_s).unsqueeze(-1).expand(-1, -1, -1, self.emb_size), \\\r\n                             (delta_t - self.tl).unsqueeze(-1).expand(-1, -1, -1, self.emb_size), \\\r\n                             (self.tu - delta_t).unsqueeze(-1).expand(-1, -1, -1, self.emb_size)\r\n\r\n        space_interval = (esl*vsu+esu*vsl) / (self.su-self.sl)\r\n        time_interval = (etl*vtu+etu*vtl) / (self.tu-self.tl)\r\n        delta = space_interval + time_interval  # (N, M, M, emb)\r\n\r\n        return joint, delta\r\n"
  },
  {
    "path": "load.py",
    "content": "import numpy as np\r\nimport torch\r\nfrom math import radians, cos, sin, asin, sqrt\r\nimport joblib\r\nfrom torch.nn.utils.rnn import pad_sequence\r\n\r\nmax_len = 100  # max traj len; i.e., M\r\n\r\n\r\ndef haversine(lon1, lat1, lon2, lat2):\r\n    \"\"\"\r\n    Calculate the great circle distance between two points\r\n    on the earth (specified in decimal degrees)\r\n    \"\"\"\r\n    lon1, lat1, lon2, lat2 = map(radians, [lon1, lat1, lon2, lat2])\r\n\r\n    dlon = lon2 - lon1\r\n    dlat = lat2 - lat1\r\n    a = sin(dlat / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dlon / 2) ** 2\r\n    c = 2 * asin(sqrt(a))\r\n    r = 6371\r\n    return c * r\r\n\r\n\r\ndef euclidean(point, each):\r\n    lon1, lat1, lon2, lat2 = point[2], point[1], each[2], each[1]\r\n    return np.sqrt((lon1 - lon2)**2 + (lat1 - lat2)**2)\r\n\r\n\r\ndef rst_mat1(traj, poi):\r\n    # traj (*M, [u, l, t]), poi(L, [l, lat, lon])\r\n    mat = np.zeros((len(traj), len(traj), 2))\r\n    for i, item in enumerate(traj):\r\n        for j, term in enumerate(traj):\r\n            poi_item, poi_term = poi[item[1] - 1], poi[term[1] - 1]  # retrieve poi by loc_id\r\n            mat[i, j, 0] = haversine(lon1=poi_item[2], lat1=poi_item[1], lon2=poi_term[2], lat2=poi_term[1])\r\n            mat[i, j, 1] = abs(item[2] - term[2])\r\n    return mat  # (*M, *M, [dis, tim])\r\n\r\n\r\ndef rs_mat2s(poi, l_max):\r\n    # poi(L, [l, lat, lon])\r\n    candidate_loc = np.linspace(1, l_max, l_max)  # (L)\r\n    mat = np.zeros((l_max, l_max))  # mat (L, L)\r\n    for i, loc1 in enumerate(candidate_loc):\r\n        print(i) if i % 100 == 0 else None\r\n        for j, loc2 in enumerate(candidate_loc):\r\n            poi1, poi2 = poi[int(loc1) - 1], poi[int(loc2) - 1]  # retrieve poi by loc_id\r\n            mat[i, j] = haversine(lon1=poi1[2], lat1=poi1[1], lon2=poi2[2], lat2=poi2[1])\r\n    return mat  # (L, L)\r\n\r\n\r\ndef rt_mat2t(traj_time):  # traj_time (*M+1) triangle matrix\r\n    # construct a list of 
relative times w.r.t. causality\r\n    mat = np.zeros((len(traj_time)-1, len(traj_time)-1))\r\n    for i, item in enumerate(traj_time):  # label\r\n        if i == 0:\r\n            continue\r\n        for j, term in enumerate(traj_time[:i]):  # data\r\n            mat[i-1, j] = np.abs(item - term)\r\n    return mat  # (*M, *M)\r\n\r\n\r\ndef process_traj(dname):  # start from 1\r\n    # data (?, [u, l, t]), poi (L, [l, lat, lon])\r\n    data = np.load('./data/' + dname + '.npy')\r\n    # keep the line below if your raw timestamps are in minutes; it converts them to hours\r\n    data[:, -1] = np.array(data[:, -1]/60, dtype=int)  # np.int is removed in recent NumPy\r\n    poi = np.load('./data/' + dname + '_POI.npy')\r\n    num_user = data[-1, 0]  # max id of users, i.e. NUM\r\n    data_user = data[:, 0]  # user_id sequence in data\r\n    trajs, labels, mat1, mat2t, lens = [], [], [], [], []\r\n    u_max, l_max = np.max(data[:, 0]), np.max(data[:, 1])\r\n\r\n    for u_id in range(1, num_user+1):  # user ids start from 1\r\n        init_mat1 = np.zeros((max_len, max_len, 2))  # first mat (M, M, 2)\r\n        init_mat2t = np.zeros((max_len, max_len))  # second mat of time (M, M)\r\n        user_traj = data[np.where(data_user == u_id)]  # find all check-ins of u_id\r\n        user_traj = user_traj[np.argsort(user_traj[:, 2])].copy()  # sort traj by time\r\n\r\n        print(u_id, len(user_traj)) if u_id % 100 == 0 else None\r\n\r\n        if len(user_traj) > max_len + 1:  # consider only the M+1 recent check-ins\r\n            # 0:-3 are training data, 1:-2 is training label;\r\n            # 1:-2 are validation data, 2:-1 is validation label;\r\n            # 2:-1 are test data, 3: is the label for test.\r\n            # *M would be the real length if <= max_len + 1\r\n            user_traj = user_traj[-max_len-1:]  # (*M+1, [u, l, t])\r\n\r\n        # spatial and temporal intervals\r\n        user_len = len(user_traj[:-1])  # the len of data, i.e. 
*M\r\n        user_mat1 = rst_mat1(user_traj[:-1], poi)  # (*M, *M, [dis, tim])\r\n        user_mat2t = rt_mat2t(user_traj[:, 2])  # (*M, *M)\r\n        init_mat1[0:user_len, 0:user_len] = user_mat1\r\n        init_mat2t[0:user_len, 0:user_len] = user_mat2t\r\n\r\n        trajs.append(torch.LongTensor(user_traj)[:-1])  # (NUM, *M, [u, l, t])\r\n        mat1.append(init_mat1)  # (NUM, M, M, 2)\r\n        mat2t.append(init_mat2t)  # (NUM, M, M)\r\n        labels.append(torch.LongTensor(user_traj[1:, 1]))  # (NUM, *M)\r\n        lens.append(user_len-2)  # (NUM), the real *M for every user\r\n\r\n    # pad zeros into the vacancies on the right\r\n    mat2s = rs_mat2s(poi, l_max)  # contains dis of all locations, (L, L)\r\n    zipped = zip(*sorted(zip(trajs, mat1, mat2t, labels, lens), key=lambda x: len(x[0]), reverse=True))\r\n    trajs, mat1, mat2t, labels, lens = zipped\r\n    trajs, mat1, mat2t, labels, lens = list(trajs), list(mat1), list(mat2t), list(labels), list(lens)\r\n    trajs = pad_sequence(trajs, batch_first=True, padding_value=0)  # (NUM, M, 3)\r\n    labels = pad_sequence(labels, batch_first=True, padding_value=0)  # (NUM, M)\r\n\r\n    data = [trajs, np.array(mat1), mat2s, np.array(mat2t), labels, np.array(lens), u_max, l_max]\r\n    data_pkl = './data/' + dname + '_data.pkl'\r\n    with open(data_pkl, 'wb') as pkl:  # 'wb' creates the file if it does not exist\r\n        joblib.dump(data, pkl)\r\n\r\n\r\nif __name__ == '__main__':\r\n    name = 'NYC'\r\n    process_traj(name)\r\n"
  },
  {
    "path": "models.py",
    "content": "from layers import *\r\n\r\n\r\nclass Model(nn.Module):\r\n    def __init__(self, t_dim, l_dim, u_dim, embed_dim, ex, dropout=0.1):\r\n        super(Model, self).__init__()\r\n        emb_t = nn.Embedding(t_dim, embed_dim, padding_idx=0)\r\n        emb_l = nn.Embedding(l_dim, embed_dim, padding_idx=0)\r\n        emb_u = nn.Embedding(u_dim, embed_dim, padding_idx=0)\r\n        emb_su = nn.Embedding(2, embed_dim, padding_idx=0)\r\n        emb_sl = nn.Embedding(2, embed_dim, padding_idx=0)\r\n        emb_tu = nn.Embedding(2, embed_dim, padding_idx=0)\r\n        emb_tl = nn.Embedding(2, embed_dim, padding_idx=0)\r\n        embed_layers = emb_t, emb_l, emb_u, emb_su, emb_sl, emb_tu, emb_tl\r\n\r\n        self.MultiEmbed = MultiEmbed(ex, embed_dim, embed_layers)\r\n        self.SelfAttn = SelfAttn(embed_dim, embed_dim)\r\n        self.Embed = Embed(ex, embed_dim, l_dim-1, embed_layers)\r\n        self.Attn = Attn(emb_l, l_dim-1)\r\n\r\n    def forward(self, traj, mat1, mat2, vec, traj_len):\r\n        # long(N, M, [u, l, t]), float(N, M, M, 2), float(L, L), float(N, M), long(N)\r\n        joint, delta = self.MultiEmbed(traj, mat1, traj_len)  # (N, M, emb), (N, M, M, emb)\r\n        self_attn = self.SelfAttn(joint, delta, traj_len)  # (N, M, emb)\r\n        self_delta = self.Embed(traj[:, :, 1], mat2, vec, traj_len)  # (N, M, L, emb)\r\n        output = self.Attn(self_attn, self_delta, traj_len)  # (N, L)\r\n        return output\r\n"
  },
  {
    "path": "train.py",
    "content": "from load import *\r\nimport time\r\nimport random\r\nfrom torch import optim\r\nimport torch.utils.data as data\r\nfrom tqdm import tqdm\r\nfrom models import *\r\n\r\n\r\ndef calculate_acc(prob, label):\r\n    # log_prob (N, L), label (N), batch_size [*M]\r\n    acc_train = [0, 0, 0, 0]\r\n    for i, k in enumerate([1, 5, 10, 20]):\r\n        # topk_batch (N, k)\r\n        _, topk_predict_batch = torch.topk(prob, k=k)\r\n        for j, topk_predict in enumerate(to_npy(topk_predict_batch)):\r\n            # topk_predict (k)\r\n            if to_npy(label)[j] in topk_predict:\r\n                acc_train[i] += 1\r\n\r\n    return np.array(acc_train)\r\n\r\n\r\ndef sampling_prob(prob, label, num_neg):\r\n    num_label, l_m = prob.shape[0], prob.shape[1]-1  # prob (N, L)\r\n    label = label.view(-1)  # label (N)\r\n    init_label = np.linspace(0, num_label-1, num_label)  # (N), [0 -- num_label-1]\r\n    init_prob = torch.zeros(size=(num_label, num_neg+len(label)))  # (N, num_neg+num_label)\r\n\r\n    random_ig = random.sample(range(1, l_m+1), num_neg)  # (num_neg) from (1 -- l_max)\r\n    while len([lab for lab in label if lab in random_ig]) != 0:  # no intersection\r\n        random_ig = random.sample(range(1, l_m+1), num_neg)\r\n\r\n    global global_seed\r\n    random.seed(global_seed)\r\n    global_seed += 1\r\n\r\n    # place the pos labels ahead and neg samples in the end\r\n    for k in range(num_label):\r\n        for i in range(num_neg + len(label)):\r\n            if i < len(label):\r\n                init_prob[k, i] = prob[k, label[i]]\r\n            else:\r\n                init_prob[k, i] = prob[k, random_ig[i-len(label)]]\r\n\r\n    return torch.FloatTensor(init_prob), torch.LongTensor(init_label)  # (N, num_neg+num_label), (N)\r\n\r\n\r\nclass DataSet(data.Dataset):\r\n    def __init__(self, traj, m1, v, label, length):\r\n        # (NUM, M, 3), (NUM, M, M, 2), (L, L), (NUM, M), (NUM), (NUM)\r\n        self.traj, self.mat1, self.vec, 
self.label, self.length = traj, m1, v, label, length\r\n\r\n    def __getitem__(self, index):\r\n        traj = self.traj[index].to(device)\r\n        mats1 = self.mat1[index].to(device)\r\n        vector = self.vec[index].to(device)\r\n        label = self.label[index].to(device)\r\n        length = self.length[index].to(device)\r\n        return traj, mats1, vector, label, length\r\n\r\n    def __len__(self):  # used by DataLoader\r\n        return len(self.traj)\r\n\r\n\r\nclass Trainer:\r\n    def __init__(self, model, record):\r\n        # load other parameters\r\n        self.model = model.to(device)\r\n        self.records = record\r\n        self.start_epoch = record['epoch'][-1] if load else 1\r\n        self.num_neg = 10\r\n        self.interval = 1000\r\n        self.batch_size = 1  # N = 1\r\n        self.learning_rate = 3e-3\r\n        self.num_epoch = 100\r\n        self.threshold = np.mean(record['acc_valid'][-1]) if load else 0  # 0 if not updated\r\n\r\n        # (NUM, M, 3), (NUM, M, M, 2), (L, L), (NUM, M, M), (NUM, M), (NUM) i.e. 
[*M]\r\n        self.traj, self.mat1, self.mat2s, self.mat2t, self.label, self.len = \\\r\n            trajs, mat1, mat2s, mat2t, labels, lens\r\n        # cross_entropy counts targets from 0 to C - 1, so we subtract 1 here.\r\n        self.dataset = DataSet(self.traj, self.mat1, self.mat2t, self.label-1, self.len)\r\n        self.data_loader = data.DataLoader(dataset=self.dataset, batch_size=self.batch_size, shuffle=False)\r\n\r\n    def train(self):\r\n        # set optimizer\r\n        optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate, weight_decay=0)\r\n        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=1)  # gamma=1 keeps the learning rate constant\r\n\r\n        for t in range(self.num_epoch):\r\n            # settings for validation and test\r\n            valid_size, test_size = 0, 0\r\n            acc_valid, acc_test = [0, 0, 0, 0], [0, 0, 0, 0]\r\n\r\n            bar = tqdm(total=part)\r\n            for step, item in enumerate(self.data_loader):\r\n                # get batch data, (N, M, 3), (N, M, M, 2), (N, M, M), (N, M), (N)\r\n                person_input, person_m1, person_m2t, person_label, person_traj_len = item\r\n\r\n                # first, try batch_size = 1 and mini_batch = 1\r\n\r\n                input_mask = torch.zeros((self.batch_size, max_len, 3), dtype=torch.long).to(device)\r\n                m1_mask = torch.zeros((self.batch_size, max_len, max_len, 2), dtype=torch.float32).to(device)\r\n                for mask_len in range(1, person_traj_len[0]+1):  # from 1 -> len\r\n                    # if mask_len != person_traj_len[0]:\r\n                    #     continue\r\n                    input_mask[:, :mask_len] = 1.\r\n                    m1_mask[:, :mask_len, :mask_len] = 1.\r\n\r\n                    train_input = person_input * input_mask\r\n                    train_m1 = person_m1 * m1_mask\r\n                    train_m2t = person_m2t[:, mask_len - 1]\r\n                    train_label = person_label[:, mask_len - 1]  # 
(N)\r\n                    train_len = torch.zeros(size=(self.batch_size,), dtype=torch.long).to(device) + mask_len\r\n\r\n                    prob = self.model(train_input, train_m1, self.mat2s, train_m2t, train_len)  # (N, L)\r\n\r\n                    if mask_len <= person_traj_len[0] - 2:  # only training\r\n                        # nn.utils.clip_grad_norm_(self.model.parameters(), 10)\r\n                        prob_sample, label_sample = sampling_prob(prob, train_label, self.num_neg)\r\n                        loss_train = F.cross_entropy(prob_sample, label_sample)\r\n                        loss_train.backward()\r\n                        optimizer.step()\r\n                        optimizer.zero_grad()\r\n                        scheduler.step()\r\n\r\n                    elif mask_len == person_traj_len[0] - 1:  # only validation\r\n                        valid_size += person_input.shape[0]\r\n                        # v_prob_sample, v_label_sample = sampling_prob(prob_valid, valid_label, self.num_neg)\r\n                        # loss_valid += F.cross_entropy(v_prob_sample, v_label_sample, reduction='sum')\r\n                        acc_valid += calculate_acc(prob, train_label)\r\n\r\n                    elif mask_len == person_traj_len[0]:  # only test\r\n                        test_size += person_input.shape[0]\r\n                        # v_prob_sample, v_label_sample = sampling_prob(prob_valid, valid_label, self.num_neg)\r\n                        # loss_valid += F.cross_entropy(v_prob_sample, v_label_sample, reduction='sum')\r\n                        acc_test += calculate_acc(prob, train_label)\r\n\r\n                bar.update(self.batch_size)\r\n            bar.close()\r\n\r\n            acc_valid = np.array(acc_valid) / valid_size\r\n            print('epoch:{}, time:{}, valid_acc:{}'.format(self.start_epoch + t, time.time() - start, acc_valid))\r\n\r\n            acc_test = np.array(acc_test) / test_size\r\n            print('epoch:{}, 
time:{}, test_acc:{}'.format(self.start_epoch + t, time.time() - start, acc_test))\r\n\r\n            self.records['acc_valid'].append(acc_valid)\r\n            self.records['acc_test'].append(acc_test)\r\n            self.records['epoch'].append(self.start_epoch + t)\r\n\r\n            if self.threshold < np.mean(acc_valid):\r\n                self.threshold = np.mean(acc_valid)\r\n                # save the model\r\n                torch.save({'state_dict': self.model.state_dict(),\r\n                            'records': self.records,\r\n                            'time': time.time() - start},\r\n                           'best_stan_win_1000_' + dname + '.pth')\r\n\r\n    def inference(self):\r\n        user_ids = []\r\n        for t in range(self.num_epoch):\r\n            # settings for validation and test\r\n            valid_size, test_size = 0, 0\r\n            acc_valid, acc_test = [0, 0, 0, 0], [0, 0, 0, 0]\r\n            cum_valid, cum_test = [0, 0, 0, 0], [0, 0, 0, 0]\r\n\r\n            for step, item in enumerate(self.data_loader):\r\n                # get batch data, (N, M, 3), (N, M, M, 2), (N, M, M), (N, M), (N)\r\n                person_input, person_m1, person_m2t, person_label, person_traj_len = item\r\n\r\n                # first, try batch_size = 1 and mini_batch = 1\r\n\r\n                input_mask = torch.zeros((self.batch_size, max_len, 3), dtype=torch.long).to(device)\r\n                m1_mask = torch.zeros((self.batch_size, max_len, max_len, 2), dtype=torch.float32).to(device)\r\n                for mask_len in range(1, person_traj_len[0] + 1):  # from 1 -> len\r\n                    # if mask_len != person_traj_len[0]:\r\n                    #     continue\r\n                    input_mask[:, :mask_len] = 1.\r\n                    m1_mask[:, :mask_len, :mask_len] = 1.\r\n\r\n                    train_input = person_input * input_mask\r\n                    train_m1 = person_m1 * m1_mask\r\n                    train_m2t = person_m2t[:, 
mask_len - 1]\r\n                    train_label = person_label[:, mask_len - 1]  # (N)\r\n                    train_len = torch.zeros(size=(self.batch_size,), dtype=torch.long).to(device) + mask_len\r\n\r\n                    prob = self.model(train_input, train_m1, self.mat2s, train_m2t, train_len)  # (N, L)\r\n\r\n                    if mask_len <= person_traj_len[0] - 2:  # only training\r\n                        continue\r\n\r\n                    elif mask_len == person_traj_len[0] - 1:  # only validation\r\n                        acc_valid = calculate_acc(prob, train_label)\r\n                        cum_valid += calculate_acc(prob, train_label)\r\n\r\n                    elif mask_len == person_traj_len[0]:  # only test\r\n                        acc_test = calculate_acc(prob, train_label)\r\n                        cum_test += calculate_acc(prob, train_label)\r\n\r\n                print(step, acc_valid, acc_test)\r\n\r\n                if acc_valid.sum() == 0 and acc_test.sum() == 0:\r\n                    user_ids.append(step)\r\n\r\n\r\nif __name__ == '__main__':\r\n    # load data\r\n    dname = 'NYC'\r\n    file = open('./data/' + dname + '_data.pkl', 'rb')\r\n    file_data = joblib.load(file)\r\n    file.close()\r\n    # tensor(NUM, M, 3), np(NUM, M, M, 2), np(L, L), np(NUM, M, M), tensor(NUM, M), np(NUM)\r\n    [trajs, mat1, mat2s, mat2t, labels, lens, u_max, l_max] = file_data\r\n    mat1, mat2s, mat2t, lens = torch.FloatTensor(mat1), torch.FloatTensor(mat2s).to(device), \\\r\n                               torch.FloatTensor(mat2t), torch.LongTensor(lens)\r\n\r\n    # the run speed is very slow due to the use of the location matrix (also a huge memory cost)\r\n    # please use a partition of the data (recommended)\r\n    part = 100\r\n    trajs, mat1, mat2t, labels, lens = \\\r\n        trajs[:part], mat1[:part], mat2t[:part], labels[:part], lens[:part]\r\n\r\n    ex = mat1[:, :, :, 0].max(), mat1[:, :, :, 0].min(), mat1[:, :, :, 1].max(), mat1[:, :, :, 1].min()\r\n\r\n   
 stan = Model(t_dim=hours+1, l_dim=l_max+1, u_dim=u_max+1, embed_dim=50, ex=ex, dropout=0)\r\n    num_params = 0\r\n\r\n    for name in stan.state_dict():\r\n        print(name)\r\n\r\n    for param in stan.parameters():\r\n        num_params += param.numel()\r\n    print('num of params', num_params)\r\n\r\n    load = False\r\n\r\n    if load:\r\n        # the filename must match the one used by torch.save in train()\r\n        checkpoint = torch.load('best_stan_win_1000_' + dname + '.pth', map_location=device)\r\n        stan.load_state_dict(checkpoint['state_dict'])\r\n        start = time.time() - checkpoint['time']\r\n        records = checkpoint['records']\r\n    else:\r\n        records = {'epoch': [], 'acc_valid': [], 'acc_test': []}\r\n        start = time.time()\r\n\r\n    trainer = Trainer(stan, records)\r\n    trainer.train()\r\n    # trainer.inference()\r\n\r\n"
  }
]