Full Code of kevin031060/RL_TSP_4static for AI

master 4cb6d1d23c5e cached
436 files
75.4 KB
30.8k tokens
59 symbols
1 requests
Download .txt
Repository: kevin031060/RL_TSP_4static
Branch: master
Commit: 4cb6d1d23c5e
Files: 436
Total size: 75.4 KB

Directory structure:
gitextract_7k7g2yn_/

├── Post_process/
│   ├── convet_kro_dataloader.py
│   ├── data/
│   │   ├── obj1_4_100.mat
│   │   ├── obj1_4_150.mat
│   │   ├── obj1_4_200.mat
│   │   ├── obj1_4_40.mat
│   │   ├── obj1_4_500.mat
│   │   ├── obj1_4_70.mat
│   │   ├── obj2_4_100.mat
│   │   ├── obj2_4_150.mat
│   │   ├── obj2_4_200.mat
│   │   ├── obj2_4_40.mat
│   │   ├── obj2_4_500.mat
│   │   ├── obj2_4_70.mat
│   │   ├── rl4_100.mat
│   │   ├── rl4_150.mat
│   │   ├── rl4_200.mat
│   │   ├── rl4_40.mat
│   │   ├── rl4_500.mat
│   │   ├── rl4_70.mat
│   │   ├── tour4_100.mat
│   │   └── tour4_200.mat
│   ├── dis_matrix.py
│   ├── krodata/
│   │   ├── kroA100.tsp
│   │   ├── kroA150.tsp
│   │   ├── kroA200.tsp
│   │   ├── kroB100.tsp
│   │   ├── kroB150.tsp
│   │   └── kroB200.tsp
│   ├── load_all_reward.py
│   ├── obj1.mat
│   ├── obj2.mat
│   └── rl.mat
├── README.md
├── model.py
├── parameter_transfer.py
├── tasks/
│   ├── motsp.py
│   ├── tsp.py
│   └── vrp.py
├── trainer_motsp_no_transfer.py
├── trainer_motsp_transfer.py
├── tsp_transfer_100run_500000_5epoch_20city/
│   └── 20/
│       ├── w_0.04_0.96/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.05_0.95/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.06_0.94/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.07_0.93/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.08_0.92/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.09_0.91/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.10_0.90/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.11_0.89/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.12_0.88/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.13_0.87/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.14_0.86/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.15_0.85/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.16_0.84/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.17_0.83/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.18_0.82/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.19_0.81/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.20_0.80/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.21_0.79/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.22_0.78/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.23_0.77/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.24_0.76/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.25_0.75/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.26_0.74/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.27_0.73/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.28_0.72/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.29_0.71/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.30_0.70/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.31_0.69/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.32_0.68/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.33_0.67/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.34_0.66/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.35_0.65/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.36_0.64/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.37_0.63/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.38_0.62/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.39_0.61/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.40_0.60/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.41_0.59/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.42_0.58/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.43_0.57/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.44_0.56/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.45_0.55/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.46_0.54/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.47_0.53/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.48_0.52/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.49_0.51/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.50_0.50/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.51_0.49/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.52_0.48/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.53_0.47/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.54_0.46/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.55_0.45/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.56_0.44/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.57_0.43/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.58_0.42/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.59_0.41/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.60_0.40/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.61_0.39/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.62_0.38/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.63_0.37/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.64_0.36/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.65_0.35/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.66_0.34/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.67_0.33/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.68_0.32/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.69_0.31/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.70_0.30/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.71_0.29/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.72_0.28/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.73_0.27/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.74_0.26/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.75_0.25/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.76_0.24/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.77_0.23/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.78_0.22/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.79_0.21/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.80_0.20/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.81_0.19/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.82_0.18/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.83_0.17/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.84_0.16/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.85_0.15/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.86_0.14/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.87_0.13/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.88_0.12/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.89_0.11/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.90_0.10/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.91_0.09/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.92_0.08/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.93_0.07/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.94_0.06/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.95_0.05/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.96_0.04/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.97_0.03/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.98_0.02/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.99_0.01/
│       │   ├── actor.pt
│       │   └── critic.pt
│       └── w_1.00_0.00/
│           ├── actor.pt
│           └── critic.pt
└── tsp_transfer_100run_500000_5epoch_40city/
    └── 40/
        ├── w_0.00_1.00/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.01_0.99/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.02_0.98/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.03_0.97/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.04_0.96/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.05_0.95/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.06_0.94/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.07_0.93/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.08_0.92/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.09_0.91/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.10_0.90/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.11_0.89/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.12_0.88/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.13_0.87/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.14_0.86/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.15_0.85/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.16_0.84/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.17_0.83/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.18_0.82/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.19_0.81/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.20_0.80/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.21_0.79/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.22_0.78/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.23_0.77/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.24_0.76/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.25_0.75/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.26_0.74/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.27_0.73/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.28_0.72/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.29_0.71/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.30_0.70/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.31_0.69/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.32_0.68/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.33_0.67/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.34_0.66/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.35_0.65/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.36_0.64/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.37_0.63/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.38_0.62/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.39_0.61/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.40_0.60/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.41_0.59/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.42_0.58/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.43_0.57/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.44_0.56/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.45_0.55/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.46_0.54/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.47_0.53/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.48_0.52/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.49_0.51/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.50_0.50/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.51_0.49/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.52_0.48/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.53_0.47/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.54_0.46/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.55_0.45/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.56_0.44/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.57_0.43/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.58_0.42/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.59_0.41/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.60_0.40/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.61_0.39/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.62_0.38/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.63_0.37/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.64_0.36/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.65_0.35/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.66_0.34/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.67_0.33/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.68_0.32/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.69_0.31/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.70_0.30/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.71_0.29/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.72_0.28/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.73_0.27/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.74_0.26/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.75_0.25/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.76_0.24/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.77_0.23/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.78_0.22/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.79_0.21/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.80_0.20/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.81_0.19/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.82_0.18/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.83_0.17/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.84_0.16/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.85_0.15/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.86_0.14/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.87_0.13/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.88_0.12/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.89_0.11/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.90_0.10/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.91_0.09/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.92_0.08/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.93_0.07/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.94_0.06/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.95_0.05/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.96_0.04/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.97_0.03/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.98_0.02/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.99_0.01/
        │   ├── actor.pt
        │   └── critic.pt
        └── w_1.00_0.00/
            ├── actor.pt
            └── critic.pt

================================================
FILE CONTENTS
================================================

================================================
FILE: Post_process/convet_kro_dataloader.py
================================================
import os

import numpy as np
import torch
from torch.utils.data import Dataset
import matplotlib
# matplotlib.use('Agg')
import matplotlib.pyplot as plt


class Kro_dataset(Dataset):
    """Single-instance bi-objective TSP dataset built from a kroA/kroB file pair.

    Node ``i`` takes its first-objective coordinates from ``kroA<num_nodes>.tsp``
    and its second-objective coordinates from ``kroB<num_nodes>.tsp``. Each
    file's x and y columns are divided by that column's maximum value.

    The dataset holds exactly one item ``(static, dynamic, start_loc)``:
      * static:    float32 tensor of shape (4, num_nodes), rows (xA, yA, xB, yB);
      * dynamic:   zero tensor of shape (1, num_nodes);
      * start_loc: an empty list.
    """

    def __init__(self, num_nodes, data_dir='krodata'):
        """
        Args:
            num_nodes: number of cities; selects ``kroA<num_nodes>.tsp`` and
                ``kroB<num_nodes>.tsp``.
            data_dir: directory containing the .tsp files. The default,
                'krodata', is resolved against the current working directory.

        Raises:
            ValueError: if either file has fewer than ``num_nodes`` coordinate rows.
        """
        super(Kro_dataset, self).__init__()

        x1 = self._load_normalized(os.path.join(data_dir, 'kroA%d.tsp' % num_nodes), num_nodes)
        x2 = self._load_normalized(os.path.join(data_dir, 'kroB%d.tsp' % num_nodes), num_nodes)
        # (num_nodes, 4) -> (4, num_nodes) -> (1, 4, num_nodes)
        x = np.concatenate((x1, x2), axis=1).T.reshape(1, 4, num_nodes)

        self.dataset = torch.from_numpy(x).float()
        self.dynamic = torch.zeros(1, 1, num_nodes)
        self.num_nodes = num_nodes
        self.size = 1

    @staticmethod
    def _load_normalized(path, num_nodes):
        """Read the (x, y) columns of a TSPLIB coordinate file and divide each column by its max."""
        # skiprows=6 skips the NAME/TYPE/COMMENT/DIMENSION/EDGE_WEIGHT_TYPE/
        # NODE_COORD_SECTION header; max_rows stops before a trailing "EOF" line.
        coords = np.loadtxt(path, skiprows=6, usecols=(1, 2), max_rows=num_nodes,
                            dtype=float, ndmin=2)
        if coords.shape[0] != num_nodes:
            raise ValueError('%s: expected %d coordinate rows, found %d'
                             % (path, num_nodes, coords.shape[0]))
        return coords / np.max(coords, 0)

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # (static, dynamic, start_loc)
        return (self.dataset[idx], self.dynamic[idx], [])

================================================
FILE: Post_process/dis_matrix.py
================================================
import numpy as np
import torch

def dis_matrix(static, s_size):
    """Pairwise distance matrices for the two objectives of a bi-objective TSP instance.

    Args:
        static: node-feature tensor of shape (s_size, l), optionally with a
            leading dimension of size 1 (it is squeezed away). Rows 0-1 hold the
            first objective's (x, y) coordinates; rows 2.. hold the second
            objective's features.
        s_size: number of feature rows. With 3, the second objective is a single
            scalar per node and its distance is the absolute difference;
            otherwise it is the Euclidean distance over rows 2.. .

    Returns:
        (obj1_matrix, obj2_matrix): two float64 numpy arrays of shape (l, l)
        with zero diagonals.
    """
    static = static.squeeze(0).detach()

    # [2, l] -> [l, 2]: one first-objective point per row.
    pts1 = static[:2, :].t()
    # [s_size - 2, l]: second-objective features.
    obj2 = static[2:, :]

    obj1_matrix = _pairwise_euclidean(pts1)
    if s_size == 3:
        # obj2 is [1, l]; take its single row so each entry is one node's scalar.
        v = obj2[0]
        obj2_matrix = torch.abs(v[:, None] - v[None, :])
    else:
        obj2_matrix = _pairwise_euclidean(obj2.t())

    obj1_np = obj1_matrix.cpu().numpy().astype(np.float64)
    obj2_np = obj2_matrix.cpu().numpy().astype(np.float64)
    # A node's distance to itself is defined as 0 (also for non-finite inputs).
    np.fill_diagonal(obj1_np, 0.0)
    np.fill_diagonal(obj2_np, 0.0)
    return obj1_np, obj2_np


def _pairwise_euclidean(points):
    """Return the [n, n] matrix of Euclidean distances between the rows of ``points`` ([n, d])."""
    diff = points[:, None, :] - points[None, :, :]
    return torch.sqrt(torch.sum(torch.pow(diff, 2), dim=-1))

================================================
FILE: Post_process/krodata/kroA100.tsp
================================================
NAME: kroA100
TYPE: TSP
COMMENT: 100-city problem A (Krolak/Felts/Nelson)
DIMENSION: 100
EDGE_WEIGHT_TYPE : EUC_2D
NODE_COORD_SECTION
1 1380 939
2 2848 96
3 3510 1671
4 457 334
5 3888 666
6 984 965
7 2721 1482
8 1286 525
9 2716 1432
10 738 1325
11 1251 1832
12 2728 1698
13 3815 169
14 3683 1533
15 1247 1945
16 123 862
17 1234 1946
18 252 1240
19 611 673
20 2576 1676
21 928 1700
22 53 857
23 1807 1711
24 274 1420
25 2574 946
26 178 24
27 2678 1825
28 1795 962
29 3384 1498
30 3520 1079
31 1256 61
32 1424 1728
33 3913 192
34 3085 1528
35 2573 1969
36 463 1670
37 3875 598
38 298 1513
39 3479 821
40 2542 236
41 3955 1743
42 1323 280
43 3447 1830
44 2936 337
45 1621 1830
46 3373 1646
47 1393 1368
48 3874 1318
49 938 955
50 3022 474
51 2482 1183
52 3854 923
53 376 825
54 2519 135
55 2945 1622
56 953 268
57 2628 1479
58 2097 981
59 890 1846
60 2139 1806
61 2421 1007
62 2290 1810
63 1115 1052
64 2588 302
65 327 265
66 241 341
67 1917 687
68 2991 792
69 2573 599
70 19 674
71 3911 1673
72 872 1559
73 2863 558
74 929 1766
75 839 620
76 3893 102
77 2178 1619
78 3822 899
79 378 1048
80 1178 100
81 2599 901
82 3416 143
83 2961 1605
84 611 1384
85 3113 885
86 2597 1830
87 2586 1286
88 161 906
89 1429 134
90 742 1025
91 1625 1651
92 1187 706
93 1787 1009
94 22 987
95 3640 43
96 3756 882
97 776 392
98 1724 1642
99 198 1810
100 3950 1558


================================================
FILE: Post_process/krodata/kroA150.tsp
================================================
NAME: kroA150
TYPE: TSP
COMMENT: 150-city problem A (Krolak/Felts/Nelson)
DIMENSION: 150
EDGE_WEIGHT_TYPE : EUC_2D
NODE_COORD_SECTION
1 1380 939
2 2848 96
3 3510 1671
4 457 334
5 3888 666
6 984 965
7 2721 1482
8 1286 525
9 2716 1432
10 738 1325
11 1251 1832
12 2728 1698
13 3815 169
14 3683 1533
15 1247 1945
16 123 862
17 1234 1946
18 252 1240
19 611 673
20 2576 1676
21 928 1700
22 53 857
23 1807 1711
24 274 1420
25 2574 946
26 178 24
27 2678 1825
28 1795 962
29 3384 1498
30 3520 1079
31 1256 61
32 1424 1728
33 3913 192
34 3085 1528
35 2573 1969
36 463 1670
37 3875 598
38 298 1513
39 3479 821
40 2542 236
41 3955 1743
42 1323 280
43 3447 1830
44 2936 337
45 1621 1830
46 3373 1646
47 1393 1368
48 3874 1318
49 938 955
50 3022 474
51 2482 1183
52 3854 923
53 376 825
54 2519 135
55 2945 1622
56 953 268
57 2628 1479
58 2097 981
59 890 1846
60 2139 1806
61 2421 1007
62 2290 1810
63 1115 1052
64 2588 302
65 327 265
66 241 341
67 1917 687
68 2991 792
69 2573 599
70 19 674
71 3911 1673
72 872 1559
73 2863 558
74 929 1766
75 839 620
76 3893 102
77 2178 1619
78 3822 899
79 378 1048
80 1178 100
81 2599 901
82 3416 143
83 2961 1605
84 611 1384
85 3113 885
86 2597 1830
87 2586 1286
88 161 906
89 1429 134
90 742 1025
91 1625 1651
92 1187 706
93 1787 1009
94 22 987
95 3640 43
96 3756 882
97 776 392
98 1724 1642
99 198 1810
100 3950 1558
101 3477 949
102 91 1732
103 3972 329
104 198 1632
105 1806 733
106 538 1023
107 3430 1088
108 2186 766
109 1513 1646
110 2143 1611
111 53 1657
112 3404 1307
113 1034 1344
114 2823 376
115 3104 1931
116 3232 324
117 2790 1457
118 374 9
119 741 146
120 3083 1938
121 3502 1067
122 1280 237
123 3326 1846
124 217 38
125 2503 1172
126 3527 41
127 739 1850
128 3548 1999
129 48 154
130 1419 872
131 1689 1223
132 3468 1404
133 1628 253
134 382 872
135 3029 1242
136 3646 1758
137 285 1029
138 1782 93
139 1067 371
140 2849 1214
141 920 1835
142 1741 712
143 876 220
144 2753 283
145 2609 1286
146 3941 258
147 3613 523
148 1754 559
149 2916 1724
150 2445 1820


================================================
FILE: Post_process/krodata/kroA200.tsp
================================================
NAME: kroA200
TYPE: TSP
COMMENT: 200-city problem A (Krolak/Felts/Nelson)
DIMENSION: 200
EDGE_WEIGHT_TYPE : EUC_2D
NODE_COORD_SECTION
1 1357 1905
2 2650 802
3 1774 107
4 1307 964
5 3806 746
6 2687 1353
7 43 1957
8 3092 1668
9 185 1542
10 834 629
11 40 462
12 1183 1391
13 2048 1628
14 1097 643
15 1838 1732
16 234 1118
17 3314 1881
18 737 1285
19 779 777
20 2312 1949
21 2576 189
22 3078 1541
23 2781 478
24 705 1812
25 3409 1917
26 323 1714
27 1660 1556
28 3729 1188
29 693 1383
30 2361 640
31 2433 1538
32 554 1825
33 913 317
34 3586 1909
35 2636 727
36 1000 457
37 482 1337
38 3704 1082
39 3635 1174
40 1362 1526
41 2049 417
42 2552 1909
43 3939 640
44 219 898
45 812 351
46 901 1552
47 2513 1572
48 242 584
49 826 1226
50 3278 799
51 86 1065
52 14 454
53 1327 1893
54 2773 1286
55 2469 1838
56 3835 963
57 1031 428
58 3853 1712
59 1868 197
60 1544 863
61 457 1607
62 3174 1064
63 192 1004
64 2318 1925
65 2232 1374
66 396 828
67 2365 1649
68 2499 658
69 1410 307
70 2990 214
71 3646 1018
72 3394 1028
73 1779 90
74 1058 372
75 2933 1459
76 3099 173
77 2178 978
78 138 1610
79 2082 1753
80 2302 1127
81 805 272
82 22 1617
83 3213 1085
84 99 536
85 1533 1780
86 3564 676
87 29 6
88 3808 1375
89 2221 291
90 3499 1885
91 3124 408
92 781 671
93 1027 1041
94 3249 378
95 3297 491
96 213 220
97 721 186
98 3736 1542
99 868 731
100 960 303
101 1380 939
102 2848 96
103 3510 1671
104 457 334
105 3888 666
106 984 965
107 2721 1482
108 1286 525
109 2716 1432
110 738 1325
111 1251 1832
112 2728 1698
113 3815 169
114 3683 1533
115 1247 1945
116 123 862
117 1234 1946
118 252 1240
119 611 673
120 2576 1676
121 928 1700
122 53 857
123 1807 1711
124 274 1420
125 2574 946
126 178 24
127 2678 1825
128 1795 962
129 3384 1498
130 3520 1079
131 1256 61
132 1424 1728
133 3913 192
134 3085 1528
135 2573 1969
136 463 1670
137 3875 598
138 298 1513
139 3479 821
140 2542 236
141 3955 1743
142 1323 280
143 3447 1830
144 2936 337
145 1621 1830
146 3373 1646
147 1393 1368
148 3874 1318
149 938 955
150 3022 474
151 2482 1183
152 3854 923
153 376 825
154 2519 135
155 2945 1622
156 953 268
157 2628 1479
158 2097 981
159 890 1846
160 2139 1806
161 2421 1007
162 2290 1810
163 1115 1052
164 2588 302
165 327 265
166 241 341
167 1917 687
168 2991 792
169 2573 599
170 19 674
171 3911 1673
172 872 1559
173 2863 558
174 929 1766
175 839 620
176 3893 102
177 2178 1619
178 3822 899
179 378 1048
180 1178 100
181 2599 901
182 3416 143
183 2961 1605
184 611 1384
185 3113 885
186 2597 1830
187 2586 1286
188 161 906
189 1429 134
190 742 1025
191 1625 1651
192 1187 706
193 1787 1009
194 22 987
195 3640 43
196 3756 882
197 776 392
198 1724 1642
199 198 1810
200 3950 1558


================================================
FILE: Post_process/krodata/kroB100.tsp
================================================
NAME: kroB100
TYPE: TSP
COMMENT: 100-city problem B (Krolak/Felts/Nelson)
DIMENSION: 100
EDGE_WEIGHT_TYPE : EUC_2D
NODE_COORD_SECTION
1 3140 1401
2 556 1056
3 3675 1522
4 1182 1853
5 3595 111
6 962 1895
7 2030 1186
8 3507 1851
9 2642 1269
10 3438 901
11 3858 1472
12 2937 1568
13 376 1018
14 839 1355
15 706 1925
16 749 920
17 298 615
18 694 552
19 387 190
20 2801 695
21 3133 1143
22 1517 266
23 1538 224
24 844 520
25 2639 1239
26 3123 217
27 2489 1520
28 3834 1827
29 3417 1808
30 2938 543
31 71 1323
32 3245 1828
33 731 1741
34 2312 1270
35 2426 1851
36 380 478
37 2310 635
38 2830 775
39 3829 513
40 3684 445
41 171 514
42 627 1261
43 1490 1123
44 61 81
45 422 542
46 2698 1221
47 2372 127
48 177 1390
49 3084 748
50 1213 910
51 3 1817
52 1782 995
53 3896 742
54 1829 812
55 1286 550
56 3017 108
57 2132 1432
58 2000 1110
59 3317 1966
60 1729 1498
61 2408 1747
62 3292 152
63 193 1210
64 782 1462
65 2503 352
66 1697 1924
67 3821 147
68 3370 791
69 3162 367
70 3938 516
71 2741 1583
72 2330 741
73 3918 1088
74 1794 1589
75 2929 485
76 3453 1998
77 896 705
78 399 850
79 2614 195
80 2800 653
81 2630 20
82 563 1513
83 1090 1652
84 2009 1163
85 3876 1165
86 3084 774
87 1526 1612
88 1612 328
89 1423 1322
90 3058 1276
91 3782 1865
92 347 252
93 3904 1444
94 2191 1579
95 3220 1454
96 468 319
97 3611 1968
98 3114 1629
99 3515 1892
100 3060 155


================================================
FILE: Post_process/krodata/kroB150.tsp
================================================
NAME: kroB150
TYPE: TSP
COMMENT: 150-city problem B (Krolak/Felts/Nelson)
DIMENSION: 150
EDGE_WEIGHT_TYPE : EUC_2D
NODE_COORD_SECTION
1 1357 1905
2 2650 802
3 1774 107
4 1307 964
5 3806 746
6 2687 1353
7 43 1957
8 3092 1668
9 185 1542
10 834 629
11 40 462
12 1183 1391
13 2048 1628
14 1097 643
15 1838 1732
16 234 1118
17 3314 1881
18 737 1285
19 779 777
20 2312 1949
21 2576 189
22 3078 1541
23 2781 478
24 705 1812
25 3409 1917
26 323 1714
27 1660 1556
28 3729 1188
29 693 1383
30 2361 640
31 2433 1538
32 554 1825
33 913 317
34 3586 1909
35 2636 727
36 1000 457
37 482 1337
38 3704 1082
39 3635 1174
40 1362 1526
41 2049 417
42 2552 1909
43 3939 640
44 219 898
45 812 351
46 901 1552
47 2513 1572
48 242 584
49 826 1226
50 3278 799
51 86 1065
52 14 454
53 1327 1893
54 2773 1286
55 2469 1838
56 3835 963
57 1031 428
58 3853 1712
59 1868 197
60 1544 863
61 457 1607
62 3174 1064
63 192 1004
64 2318 1925
65 2232 1374
66 396 828
67 2365 1649
68 2499 658
69 1410 307
70 2990 214
71 3646 1018
72 3394 1028
73 1779 90
74 1058 372
75 2933 1459
76 3099 173
77 2178 978
78 138 1610
79 2082 1753
80 2302 1127
81 805 272
82 22 1617
83 3213 1085
84 99 536
85 1533 1780
86 3564 676
87 29 6
88 3808 1375
89 2221 291
90 3499 1885
91 3124 408
92 781 671
93 1027 1041
94 3249 378
95 3297 491
96 213 220
97 721 186
98 3736 1542
99 868 731
100 960 303
101 3825 1101
102 2779 435
103 201 693
104 2502 1274
105 765 833
106 3105 1823
107 1937 1400
108 3364 1498
109 3702 1624
110 2164 1874
111 3019 189
112 3098 1594
113 3239 1376
114 3359 1693
115 2081 1011
116 1398 1100
117 618 1953
118 1878 59
119 3803 886
120 397 1217
121 3035 152
122 2502 146
123 3230 380
124 3479 1023
125 958 1670
126 3423 1241
127 78 1066
128 96 691
129 3431 78
130 2053 1461
131 3048 1
132 571 1711
133 3393 782
134 2835 1472
135 144 1185
136 923 108
137 989 1997
138 3061 1211
139 2977 39
140 1668 658
141 878 715
142 678 1599
143 1086 868
144 640 110
145 3551 1673
146 106 1267
147 2243 1332
148 3796 1401
149 2643 1320
150 48 267


================================================
FILE: Post_process/krodata/kroB200.tsp
================================================
NAME: kroB200
TYPE: TSP
COMMENT: 200-city problem B (Krolak/Felts/Nelson)
DIMENSION: 200
EDGE_WEIGHT_TYPE : EUC_2D
NODE_COORD_SECTION
1 3140 1401
2 556 1056
3 3675 1522
4 1182 1853
5 3595 111
6 962 1895
7 2030 1186
8 3507 1851
9 2642 1269
10 3438 901
11 3858 1472
12 2937 1568
13 376 1018
14 839 1355
15 706 1925
16 749 920
17 298 615
18 694 552
19 387 190
20 2801 695
21 3133 1143
22 1517 266
23 1538 224
24 844 520
25 2639 1239
26 3123 217
27 2489 1520
28 3834 1827
29 3417 1808
30 2938 543
31 71 1323
32 3245 1828
33 731 1741
34 2312 1270
35 2426 1851
36 380 478
37 2310 635
38 2830 775
39 3829 513
40 3684 445
41 171 514
42 627 1261
43 1490 1123
44 61 81
45 422 542
46 2698 1221
47 2372 127
48 177 1390
49 3084 748
50 1213 910
51 3 1817
52 1782 995
53 3896 742
54 1829 812
55 1286 550
56 3017 108
57 2132 1432
58 2000 1110
59 3317 1966
60 1729 1498
61 2408 1747
62 3292 152
63 193 1210
64 782 1462
65 2503 352
66 1697 1924
67 3821 147
68 3370 791
69 3162 367
70 3938 516
71 2741 1583
72 2330 741
73 3918 1088
74 1794 1589
75 2929 485
76 3453 1998
77 896 705
78 399 850
79 2614 195
80 2800 653
81 2630 20
82 563 1513
83 1090 1652
84 2009 1163
85 3876 1165
86 3084 774
87 1526 1612
88 1612 328
89 1423 1322
90 3058 1276
91 3782 1865
92 347 252
93 3904 1444
94 2191 1579
95 3220 1454
96 468 319
97 3611 1968
98 3114 1629
99 3515 1892
100 3060 155
101 2995 264
102 202 233
103 981 848
104 1346 408
105 781 670
106 1009 1001
107 2927 1777
108 2982 949
109 555 1121
110 464 1302
111 3452 637
112 571 1982
113 2656 128
114 1623 1723
115 2067 694
116 1725 927
117 3600 459
118 1109 1196
119 366 339
120 778 1282
121 386 1616
122 3918 1217
123 3332 1049
124 2597 349
125 811 1295
126 241 1069
127 2658 360
128 394 1944
129 3786 1862
130 264 36
131 2050 1833
132 3538 125
133 1646 1817
134 2993 624
135 547 25
136 3373 1902
137 460 267
138 3060 781
139 1828 456
140 1021 962
141 2347 388
142 3535 1112
143 1529 581
144 1203 385
145 1787 1902
146 2740 1101
147 555 1753
148 47 363
149 3935 540
150 3062 329
151 387 199
152 2901 920
153 931 512
154 1766 692
155 401 980
156 149 1629
157 2214 1977
158 3805 1619
159 1179 969
160 1017 333
161 2834 1512
162 634 294
163 1819 814
164 1393 859
165 1768 1578
166 3023 871
167 3248 1906
168 1632 1742
169 2223 990
170 3868 697
171 1541 354
172 2374 1944
173 1962 389
174 3007 1524
175 3220 1945
176 2356 1568
177 1604 706
178 2028 1736
179 2581 121
180 2221 1578
181 2944 632
182 1082 1561
183 997 942
184 2334 523
185 1264 1090
186 1699 1294
187 235 1059
188 2592 248
189 3642 699
190 3599 514
191 1766 678
192 240 619
193 1272 246
194 3503 301
195 80 1533
196 1677 1238
197 3766 154
198 3946 459
199 1994 1852
200 278 165


================================================
FILE: Post_process/load_all_reward.py
================================================
import torch
from tasks import motsp
from tasks.motsp import TSPDataset, reward
from torch.utils.data import DataLoader
from model import DRL4TSP
from trainer_motsp_transfer import StateCritic
import numpy as np
import os
import matplotlib.pyplot as plt
import scipy.io as scio
from Post_process.dis_matrix import dis_matrix
import time

# Load the trained models (one per weight vector), evaluate each of them on a
# single test instance and export the objective values / tours to .mat files,
# which makes it convenient to visualize the Pareto front in Matlab.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE: the models in "../tsp_transfer_100run_500000_5epoch_20city/20" only
# perform moderately well; they should be retrained.
save_dir = "../tsp_transfer_100run_500000_5epoch_40city/40"
# save_dir = "../tsp_transfer/100"
# param
update_fn = None
STATIC_SIZE = 4  # (x1, y1, x2, y2): one 2-D coordinate pair per objective
DYNAMIC_SIZE = 1  # dummy for compatibility

# claim model
actor = DRL4TSP(STATIC_SIZE,
                DYNAMIC_SIZE,
                128,
                update_fn,
                motsp.update_mask,
                1,
                0.1).to(device)
critic = StateCritic(STATIC_SIZE, DYNAMIC_SIZE, 128).to(device)

# data
from Post_process.convet_kro_dataloader import Kro_dataset
kro = 1  # 1: use the kro benchmark instance, 0: use a random instance
D = 200  # number of cities
if kro:
    Test_data = Kro_dataset(D)
else:
    # Seeds used with the 40-city training: city20 13, city40 143, city70 2523
    Test_data = TSPDataset(D, 1, 2523)
Test_loader = DataLoader(Test_data, 1, False, num_workers=0)

iter_data = iter(Test_loader)
# DataLoader iterators have no `.next()` method in recent PyTorch versions;
# the builtin next() works with every version.
static, dynamic, x0 = next(iter_data)
static = static.to(device)
dynamic = dynamic.to(device)
x0 = x0.to(device) if len(x0) > 0 else None

# Load the N + 1 models, one per weight vector (1 - w, w)
N = 100
w = np.arange(N + 1) / N
objs = np.zeros((N + 1, 2))
start = time.time()
t1_all = 0  # accumulated model-loading time
t2_all = 0  # accumulated inference time
tours = []

# Modules are in training mode by default, in which DRL4TSP samples the next
# city stochastically and dropout is active. Testing must decode greedily.
actor.eval()
critic.eval()
for i in range(0, N + 1):
    t1 = time.time()
    ac = os.path.join(save_dir, "w_%2.2f_%2.2f" % (1 - w[i], w[i]), "actor.pt")
    cri = os.path.join(save_dir, "w_%2.2f_%2.2f" % (1 - w[i], w[i]), "critic.pt")
    actor.load_state_dict(torch.load(ac, device))
    critic.load_state_dict(torch.load(cri, device))
    t1_all = t1_all + time.time() - t1

    # calculate
    with torch.no_grad():
        t2 = time.time()
        tour_indices, _ = actor.forward(static, dynamic, x0)
        t2_all = t2_all + time.time() - t2
    _, obj1, obj2 = reward(static, tour_indices, 1 - w[i], w[i])
    tours.append(tour_indices.cpu().numpy())
    # obj1 / obj2 are 1-element tensors (batch of one), possibly on the GPU;
    # convert them to Python floats before storing them in the numpy array.
    objs[i, :] = [obj1.item(), obj2.item()]

print("time_load_model:%2.4f" % t1_all)
print("time_predict_model:%2.4f" % t2_all)
print(time.time() - start)

print(tours)
plt.figure()
plt.plot(objs[:, 0], objs[:, 1], "ro")
plt.show()

# Convert to .mat
obj1_matrix, obj2_matrix = dis_matrix(static, STATIC_SIZE)
scio.savemat("data/obj1_%d_%d.mat" % (STATIC_SIZE, D), {'obj1': obj1_matrix})
scio.savemat("data/obj2_%d_%d.mat" % (STATIC_SIZE, D), {'obj2': obj2_matrix})
scio.savemat("data/rl%d_%d.mat" % (STATIC_SIZE, D), {'rl': objs})
scio.savemat("data/tour%d_%d.mat" % (STATIC_SIZE, D), {'tour': np.array(tours)})


# from load_test_plot import show
# show_if = 1
# if show_if:
#     i = 0
#     ac = os.path.join(save_dir, "w_%2.2f_%2.2f" % (1-w[i], w[i]),"actor.pt")
#     cri = os.path.join(save_dir, "w_%2.2f_%2.2f" % (1-w[i], w[i]),"critic.pt")
#     actor.load_state_dict(torch.load(ac, device))
#     critic.load_state_dict(torch.load(cri, device))
#
#     show(Test_loader, actor)



================================================
FILE: README.md
================================================
# Using Deep Reinforcement Learning and an Attention Model to Solve the Multiobjective TSP
## This code implements the model with four-dimensional input (Euclidean type).
### The model with three-dimensional input (mixed type) is in RL_3static_MOTSP.zip.
### The Matlab code used for visualization and for the comparisons in the paper is in MOTSP_compare_EMO.zip.

+ Trained models are available in the tsp_transfer_... directories.
+ To test the model, run load_all_reward.py in the Post_process directory.
+ To train the model, run trainer_motsp_transfer.py.
+ The obtained Pareto front should be visualized using Matlab.
+ The Matlab code is in the .zip file, at "MOTSP_compare_EMO/Problems/Combinatorial MOPs/compare.m". It produces the figures in batch.
    
    > First, run trainer_motsp_transfer.py to train the model.
    
    > Run Post_process/load_all_reward.py to load and test the model. It also converts the obtained Pareto front to .mat files.
    
    > Run the Matlab code to visualize the Pareto Front and compare with NSGA-II and MOEA/D
    
    

### Much of the code is inherited from https://github.com/mveres01/pytorch-drl4vrp


================================================
FILE: model.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F

# Device on which the parameters and helper tensors below are created:
# CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Uncomment to force CPU execution:
#device = torch.device('cpu')


class Encoder(nn.Module):
    """Embeds every node independently with a kernel-size-1 Conv1d.

    Maps a (batch, input_size, seq_len) tensor to (batch, hidden_size, seq_len);
    with kernel_size=1 this is the same linear map applied at each position.
    """

    def __init__(self, input_size, hidden_size):
        super(Encoder, self).__init__()
        self.conv = nn.Conv1d(input_size, hidden_size, kernel_size=1)

    def forward(self, input):
        embedded = self.conv(input)  # (batch, hidden_size, seq_len)
        return embedded


class Attention(nn.Module):
    """Additive attention over the encoded input nodes.

    Each node is scored from the concatenation of its static embedding, its
    dynamic embedding and the current decoder state; the scores are then
    normalised with a softmax over the node axis.
    """

    def __init__(self, hidden_size):
        super(Attention, self).__init__()

        # Learnable scoring vector and projection; the projection maps the
        # 3 * hidden_size concatenated features down to hidden_size.
        self.v = nn.Parameter(torch.zeros((1, 1, hidden_size),
                                          device=device, requires_grad=True))
        self.W = nn.Parameter(torch.zeros((1, hidden_size, 3 * hidden_size),
                                          device=device, requires_grad=True))

    def forward(self, static_hidden, dynamic_hidden, decoder_hidden):
        n_batch = static_hidden.size(0)
        n_hidden = static_hidden.size(1)

        # Repeat the decoder state for every node, then stack the three
        # feature groups along the channel axis: (batch, 3 * hidden, seq_len).
        decoder_tiled = decoder_hidden.unsqueeze(2).expand_as(static_hidden)
        features = torch.cat((static_hidden, dynamic_hidden, decoder_tiled), 1)

        # Share the single learned v / W across the batch for bmm.
        proj = self.W.expand(n_batch, n_hidden, -1)
        vec = self.v.expand(n_batch, 1, n_hidden)

        scores = vec.bmm(torch.tanh(proj.bmm(features)))
        return F.softmax(scores, dim=2)  # (batch, 1, seq_len)


class Pointer(nn.Module):
    """One decoding step: updates the GRU state and scores every input node.

    forward() returns unnormalised scores of shape (batch, seq_len) together
    with the new GRU hidden state; the caller applies masking and softmax.
    """

    def __init__(self, hidden_size, num_layers=1, dropout=0.2):
        super(Pointer, self).__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Scoring vector / projection that turn (node, context) pairs into scores
        self.v = nn.Parameter(torch.zeros((1, 1, hidden_size),
                                          device=device, requires_grad=True))
        self.W = nn.Parameter(torch.zeros((1, hidden_size, 2 * hidden_size),
                                          device=device, requires_grad=True))

        # GRU(input dim, hidden dim, number of layers). The GRU's own dropout
        # only acts between stacked layers, so it is disabled for one layer.
        gru_dropout = dropout if num_layers > 1 else 0
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers,
                          batch_first=True,
                          dropout=gru_dropout)
        self.encoder_attn = Attention(hidden_size)

        self.drop_rnn = nn.Dropout(p=dropout)
        self.drop_hh = nn.Dropout(p=dropout)

    def forward(self, static_hidden, dynamic_hidden, decoder_hidden, last_hh):
        step_input = decoder_hidden.transpose(2, 1)
        rnn_out, last_hh = self.gru(step_input, last_hh)
        # Dropout is always applied to the RNN output
        rnn_out = self.drop_rnn(rnn_out.squeeze(1))

        # A single-layer GRU applies no dropout itself, so do it on the state here
        if self.num_layers == 1:
            last_hh = self.drop_hh(last_hh)

        # Attend over the nodes to build a context vector from the static embeddings
        attn = self.encoder_attn(static_hidden, dynamic_hidden, rnn_out)
        context = attn.bmm(static_hidden.permute(0, 2, 1))  # (B, 1, num_feats)
        context = context.transpose(1, 2).expand_as(static_hidden)

        energy = torch.cat((static_hidden, context), dim=1)  # (B, 2 * num_feats, seq_len)

        n_batch = static_hidden.size(0)
        W = self.W.expand(n_batch, -1, -1)
        v = self.v.expand(n_batch, -1, -1)
        scores = v.bmm(torch.tanh(W.bmm(energy))).squeeze(1)

        return scores, last_hh


class DRL4TSP(nn.Module):
    """Defines the main Encoder, Decoder, and Pointer combinatorial models.

    Parameters
    ----------
    static_size: int
        Defines how many features are in the static elements of the model
        (e.g. 2 for (x, y) coordinates)
    dynamic_size: int >= 1
        Defines how many features are in the dynamic elements of the model
        (e.g. 2 for the VRP which has (load, demand) attributes). The TSP doesn't
        have dynamic elements, but to ensure compatibility with other optimization
        problems, assume we just pass in a vector of zeros.
    hidden_size: int
        Defines the number of units in the hidden layer for all static, dynamic,
        and decoder output units.
    update_fn: function or None
        If provided, this method is used to calculate how the input dynamic
        elements are updated, and is called after each 'point' to the input element.
    mask_fn: function or None
        Allows us to specify which elements of the input sequence are allowed to
        be selected. This is useful for speeding up training of the networks,
        by providing a sort of 'rules' guideline to the algorithm. If no mask
        is provided, we terminate the search after a fixed number of iterations
        (sequence length) to avoid tours that stretch forever
    num_layers: int
        Specifies the number of hidden layers to use in the decoder RNN
    dropout: float
        Defines the dropout rate for the decoder
    """

    def __init__(self, static_size, dynamic_size, hidden_size,
                 update_fn=None, mask_fn=None, num_layers=1, dropout=0.):
        super(DRL4TSP, self).__init__()

        if dynamic_size < 1:
            raise ValueError(':param dynamic_size: must be > 0, even if the '
                             'problem has no dynamic elements')

        self.update_fn = update_fn
        self.mask_fn = mask_fn

        # Define the encoder & decoder models
        self.static_encoder = Encoder(static_size, hidden_size)
        self.dynamic_encoder = Encoder(dynamic_size, hidden_size)
        self.decoder = Encoder(static_size, hidden_size)
        self.pointer = Pointer(hidden_size, num_layers, dropout)

        # Xavier-initialise every parameter with more than one dimension
        # (biases keep their default initialisation).
        for p in self.parameters():
            if len(p.shape) > 1:
                nn.init.xavier_uniform_(p)

        # Used as a proxy initial state in the decoder when not specified.
        # It is a plain tensor attribute (not a Parameter or buffer), so it is
        # not part of state_dict() and is not moved by .to(); it is created
        # directly on the module-level `device`.
        self.x0 = torch.zeros((1, static_size, 1), requires_grad=True, device=device)

    def forward(self, static, dynamic, decoder_input=None, last_hh=None):
        """Decode one tour per batch element.

        Parameters
        ----------
        static: Array of size (batch_size, feats, num_cities)
            Defines the elements to consider as static. For the TSP, this could be
            things like the (x, y) coordinates, which won't change
        dynamic: Array of size (batch_size, feats, num_cities)
            Defines the elements to consider as dynamic. For the VRP, this can be
            things like the (load, demand) of each city. The dynamic encoder is
            always applied to it, so a tensor (e.g. zeros) must be passed.
        decoder_input: Array of size (batch_size, num_feats, 1) or None
            Defines the first input of the decoder; defaults to zeros (self.x0).
            Afterwards the static features of the last chosen node are used.
        last_hh: initial hidden state for the decoder GRU, or None
            NOTE(review): presumably (num_layers, batch_size, hidden_size) as
            required by nn.GRU — confirm against callers.

        Returns
        -------
        tour_idx: (batch_size, num_steps) indices of the chosen nodes
        tour_logp: (batch_size, num_steps) log-probabilities of those choices
        In training mode the next node is sampled; in eval mode it is chosen
        greedily (argmax).
        """

        batch_size, input_size, sequence_size = static.size()

        if decoder_input is None:
            decoder_input = self.x0.expand(batch_size, -1, -1)

        # Always use a mask - if no function is provided, we don't update it
        mask = torch.ones(batch_size, sequence_size, device=device)

        # Structures for holding the output sequences
        tour_idx, tour_logp = [], []
        # Without a mask function the mask never changes, so stop after
        # sequence_size steps; with one, stop when the mask is all zeros
        # (checked below) or after a hard cap of 1000 steps.
        max_steps = sequence_size if self.mask_fn is None else 1000

        # Static elements only need to be processed once, and can be used across
        # all 'pointing' iterations. When / if the dynamic elements change,
        # their representations will need to get calculated again.
        static_hidden = self.static_encoder(static)
        dynamic_hidden = self.dynamic_encoder(dynamic)

        for _ in range(max_steps):

            # Stop once no node is selectable in any batch element
            if not mask.byte().any():
                break

            # ... but compute a hidden rep for each element added to sequence
            decoder_hidden = self.decoder(decoder_input)

            probs, last_hh = self.pointer(static_hidden,
                                          dynamic_hidden,
                                          decoder_hidden, last_hh)
            # log(0) = -inf, so masked nodes get probability 0 after softmax
            probs = F.softmax(probs + mask.log(), dim=1)

            # When training, sample the next step according to its probability.
            # During testing, we can take the greedy approach and choose highest
            if self.training:
                m = torch.distributions.Categorical(probs)

                # Sometimes an issue with Categorical & sampling on GPU; See:
                # https://github.com/pemami4911/neural-combinatorial-rl-pytorch/issues/5
                # Resample until every batch element picked a node whose mask is nonzero.
                ptr = m.sample()
                while not torch.gather(mask, 1, ptr.data.unsqueeze(1)).byte().all():
                    ptr = m.sample()
                logp = m.log_prob(ptr)
            else:
                prob, ptr = torch.max(probs, 1)  # Greedy
                logp = prob.log()

            # After visiting a node update the dynamic representation
            if self.update_fn is not None:
                dynamic = self.update_fn(dynamic, ptr.data)
                dynamic_hidden = self.dynamic_encoder(dynamic)

                # Since we compute the VRP in minibatches, some tours may have
                # a different number of stops. We force the vehicles to remain
                # at the depot in these cases, and logp := 0
                # (a sample counts as done when its dynamic row 1 sums to 0).
                is_done = dynamic[:, 1].sum(1).eq(0).float()
                logp = logp * (1. - is_done)

            # And update the mask so we don't re-visit if we don't need to
            if self.mask_fn is not None:
                mask = self.mask_fn(mask, dynamic, ptr.data).detach()

            tour_logp.append(logp.unsqueeze(1))
            tour_idx.append(ptr.data.unsqueeze(1))

            # Next decoder input: the static features of the chosen node,
            # shape (batch_size, input_size, 1), detached from the graph.
            decoder_input = torch.gather(static, 2,
                                         ptr.view(-1, 1, 1)
                                         .expand(-1, input_size, 1)).detach()

        tour_idx = torch.cat(tour_idx, dim=1)  # (batch_size, seq_len)
        tour_logp = torch.cat(tour_logp, dim=1)  # (batch_size, seq_len)

        return tour_idx, tour_logp


# This module only defines the model classes; running it directly is an error.
if __name__ == '__main__':
    raise Exception('Cannot be called from main')


================================================
FILE: parameter_transfer.py
================================================
import torch
import os
from model import DRL4TSP, Encoder
import argparse
from tasks import motsp
from trainer_motsp_transfer import StateCritic

'''
This file was used for testing and is now obsolete.
It converts a trained single-objective TSP pointer-network model (2-D input)
into parameters from which the multi-objective model can be transferred.
The trained single-TSP PN model can be found here: https://github.com/mveres01/pytorch-drl4vrp. Save it as "tsp20".
The start-up parameters for the first subproblem of the MOTSP can then be obtained.
'''


def _widen_encoder(encoder, new_input_size):
    """Return a new Encoder with ``new_input_size`` input channels.

    The conv weights and bias are copied from ``encoder``. Each additional
    input channel is initialised as a copy of input channel 1 of the original
    weight, which is what this script originally did for a single extra column.
    """
    params = encoder.state_dict()
    weight = params['conv.weight']  # (hidden_size, old_input_size, 1)
    n_extra = new_input_size - weight.size(1)
    extra_column = weight[:, 1, :].unsqueeze(1)
    params['conv.weight'] = torch.cat([weight] + [extra_column] * n_extra, dim=1)
    widened = Encoder(new_input_size, weight.size(0))
    widened.load_state_dict(params)
    return widened


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
STATIC_SIZE_original = 2  # (x, y)
STATIC_SIZE = 3  # (x, y) plus one extra static feature
DYNAMIC_SIZE = 1  # dummy for compatibility
update_fn = None
hidden_size = 128
num_layers = 1
dropout = 0.1
checkpoint = "tsp20"
actor = DRL4TSP(STATIC_SIZE_original,
                DYNAMIC_SIZE,
                hidden_size,
                update_fn,
                motsp.update_mask,
                num_layers,
                dropout).to(device)

critic = StateCritic(STATIC_SIZE_original, DYNAMIC_SIZE, hidden_size).to(device)
# Load the original model (conv weights of shape 128 x 2 x 1)
path = os.path.join(checkpoint, 'actor.pt')
actor.load_state_dict(torch.load(path, device))

path = os.path.join(checkpoint, 'critic.pt')
critic.load_state_dict(torch.load(path, device))

# The actor's static_encoder and decoder, and the critic's static_encoder,
# take the static features as input, so their input dimension has to grow.
actor.static_encoder = _widen_encoder(actor.static_encoder, STATIC_SIZE)
actor.decoder = _widen_encoder(actor.decoder, STATIC_SIZE)
critic.static_encoder = _widen_encoder(critic.static_encoder, STATIC_SIZE)

# torch.save does not create missing directories.
out_dir = "modified_checkpoint_3obj"
os.makedirs(out_dir, exist_ok=True)
save_path = os.path.join(out_dir, 'actor.pt')
torch.save(actor.state_dict(), save_path)
save_path = os.path.join(out_dir, 'critic.pt')
torch.save(critic.state_dict(), save_path)

print(actor, critic)


================================================
FILE: tasks/motsp.py
================================================
"""Defines the main task for the TSP

The TSP is defined by the following traits:
    1. Each city in the list must be visited once and only once
    2. The salesman must return to the original node at the end of the tour

Since the TSP doesn't have dynamic elements, we return an empty list on
__getitem__, which gets processed in trainer.py to be None

"""

import os
import numpy as np
import torch
from torch.utils.data import Dataset
import matplotlib
# matplotlib.use('Agg')
import matplotlib.pyplot as plt


class TSPDataset(Dataset):
    """Random MOTSP instances with uniform [0, 1) static features.

    ``dataset`` has shape (num_samples, static_size, size). With the default
    static_size=4, reward() in this module reads rows 0-1 as the (x, y)
    coordinates of the first objective and the remaining rows as those of the
    second objective. ``dynamic`` is all zeros of shape (num_samples, 1, size).

    Parameters
    ----------
    size: int
        Number of cities per instance.
    num_samples: int or float
        Number of instances; converted to int (the default is the float 1e6).
    seed: int or None
        Seed for numpy and torch (note: seeds the global RNGs); random if None.
    static_size: int
        Number of static features per city (default 4).
    """

    def __init__(self, size=50, num_samples=1e6, seed=None, static_size=4):
        super(TSPDataset, self).__init__()

        if seed is None:
            seed = np.random.randint(123456789)

        # torch.rand and __len__ both require an integer sample count.
        num_samples = int(num_samples)

        np.random.seed(seed)
        torch.manual_seed(seed)
        self.dataset = torch.rand((num_samples, static_size, size))
        self.dynamic = torch.zeros(num_samples, 1, size)
        self.num_nodes = size
        self.size = num_samples

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # (static, dynamic, start_loc); TSP has no fixed start location
        return (self.dataset[idx], self.dynamic[idx], [])


def update_mask(mask, dynamic, chosen_idx):
    """Zero the entry of each just-chosen city so it cannot be selected again.

    ``mask`` is modified in place and also returned. ``dynamic`` is unused;
    it only keeps the signature uniform with other mask functions.
    """
    visited = chosen_idx[:, None]
    return mask.scatter_(1, visited, 0)


def reward(static, tour_indices, w1=1, w2=0):
    """Weighted-sum reward for the bi-objective TSP.

    Parameters
    ----------
    static: torch.FloatTensor of size (batch_size, num_feats, num_cities)
        Rows 0-1 are the (x, y) coordinates used by the first objective;
        rows 2 onward are the coordinates used by the second objective.
    tour_indices: torch.LongTensor of size (batch_size, num_cities)
        Visiting order of the cities.
    w1, w2: float
        Weights of the first and second objective.

    Returns
    -------
    (obj, obj1, obj2), each of size (batch_size,)
        obj1 / obj2 are the closed-tour Euclidean lengths (including the edge
        back to the first city) measured in the first / second coordinate
        space, detached from the autograd graph; obj = w1 * obj1 + w2 * obj2.
    """

    # Convert the indices back into a tour: (batch_size, num_cities, num_feats)
    idx = tour_indices.unsqueeze(1).expand_as(static)
    tour = torch.gather(static.data, 2, idx).permute(0, 2, 1)

    # Make a full tour by returning to the start
    y = torch.cat((tour, tour[:, :1]), dim=1)
    # Columns 0-1 are the coordinates for objective 1, columns 2+ for objective 2
    y_dis = y[:, :, :2]
    y_dis2 = y[:, :, 2:]

    # Euclidean distance between each consecutive point
    tour_len = torch.sqrt(torch.sum(torch.pow(y_dis[:, :-1] - y_dis[:, 1:], 2), dim=2))
    obj1 = tour_len.sum(1).detach()

    tour_len2 = torch.sqrt(torch.sum(torch.pow(y_dis2[:, :-1] - y_dis2[:, 1:], 2), dim=2))
    obj2 = tour_len2.sum(1).detach()

    obj = w1 * obj1 + w2 * obj2
    return obj, obj1, obj2



def render(static, tour_indices, save_path):
    """Draw up to a 3x3 grid of closed tours and save the figure.

    Static rows 0 and 1 are plotted as x and y; the starting city is marked
    with a black star.
    """

    plt.close('all')

    side = 1
    if int(np.sqrt(len(tour_indices))) >= 3:
        side = 3

    _, grid = plt.subplots(nrows=side, ncols=side,
                           sharex='col', sharey='row')
    if side == 1:
        grid = [[grid]]
    panels = [cell for row in grid for cell in row]

    for k, panel in enumerate(panels):

        order = tour_indices[k]
        if order.dim() == 1:
            order = order.unsqueeze(0)

        # Broadcast the visiting order over every feature row, then close the loop
        order = order.expand(static.size(1), -1)
        order = torch.cat((order, order[:, 0:1]), dim=1)

        coords = torch.gather(static[k].data, 1, order).cpu().numpy()
        xs, ys = coords[0], coords[1]

        panel.plot(xs, ys, zorder=1)
        panel.scatter(xs, ys, s=4, c='r', zorder=2)
        panel.scatter(xs[0], ys[0], s=20, c='k', marker='*', zorder=3)

        panel.set_xlim(0, 1)
        panel.set_ylim(0, 1)

    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight', dpi=400)


================================================
FILE: tasks/tsp.py
================================================
"""Defines the main task for the TSP

The TSP is defined by the following traits:
    1. Each city in the list must be visited once and only once
    2. The salesman must return to the original node at the end of the tour

Since the TSP doesn't have dynamic elements, we return an empty list on
__getitem__, which gets processed in trainer.py to be None

"""

import os
import numpy as np
import torch
from torch.utils.data import Dataset
import matplotlib
# matplotlib.use('Agg')
import matplotlib.pyplot as plt


class TSPDataset(Dataset):
    """Random single-objective TSP instances with uniform [0, 1) features.

    ``dataset`` has shape (num_samples, static_size, size); with the default
    static_size=2 the rows are the (x, y) coordinates. ``dynamic`` is all
    zeros of shape (num_samples, 1, size).

    Parameters
    ----------
    size: int
        Number of cities per instance.
    num_samples: int or float
        Number of instances; converted to int (the default is the float 1e6).
    seed: int or None
        Seed for numpy and torch (note: seeds the global RNGs); random if None.
    static_size: int
        Number of static features per city (default 2).
    """

    def __init__(self, size=50, num_samples=1e6, seed=None, static_size=2):
        super(TSPDataset, self).__init__()

        if seed is None:
            seed = np.random.randint(123456789)

        # torch.rand and __len__ both require an integer sample count.
        num_samples = int(num_samples)

        np.random.seed(seed)
        torch.manual_seed(seed)
        self.dataset = torch.rand((num_samples, static_size, size))
        self.dynamic = torch.zeros(num_samples, 1, size)
        self.num_nodes = size
        self.size = num_samples

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # (static, dynamic, start_loc); TSP has no fixed start location
        return (self.dataset[idx], self.dynamic[idx], [])


def update_mask(mask, dynamic, chosen_idx):
    """Forbid re-selecting the cities that were just visited.

    Writes 0 into ``mask`` (in place) at column chosen_idx[b] of every row b
    and returns the same tensor; ``dynamic`` is accepted but ignored.
    """
    column = chosen_idx[:, None]
    return mask.scatter_(1, column, 0)


def reward(static, tour_indices):
    """Length of each closed tour.

    Parameters
    ----------
    static: torch.FloatTensor of size (batch_size, num_feats, num_cities)
        Node features; for the TSP these are the (x, y) coordinates.
    tour_indices: torch.LongTensor of size (batch_size, num_cities)
        Visiting order of the cities.

    Returns
    -------
    torch.FloatTensor of size (batch_size,)
        Euclidean length of each tour, including the edge from the last city
        back to the first, detached from the autograd graph.
    """

    # Convert the indices back into a tour: (batch_size, num_cities, num_feats)
    idx = tour_indices.unsqueeze(1).expand_as(static)
    tour = torch.gather(static.data, 2, idx).permute(0, 2, 1)

    # Make a full tour by returning to the start
    y = torch.cat((tour, tour[:, :1]), dim=1)

    # Euclidean distance between each consecutive point
    tour_len = torch.sqrt(torch.sum(torch.pow(y[:, :-1] - y[:, 1:], 2), dim=2))

    return tour_len.sum(1).detach()


def render(static, tour_indices, save_path):
    """Plot up to a 3x3 grid of closed TSP tours and write the figure to disk.

    Each panel uses static rows 0/1 as x/y and marks the first city with a
    black star.
    """

    plt.close('all')

    side = 3 if int(np.sqrt(len(tour_indices))) >= 3 else 1
    _, grid = plt.subplots(nrows=side, ncols=side,
                           sharex='col', sharey='row')
    if side == 1:
        grid = [[grid]]
    panels = [cell for row in grid for cell in row]

    for k, panel in enumerate(panels):

        order = tour_indices[k]
        if order.dim() == 1:
            order = order.unsqueeze(0)

        # Repeat the order for every feature row and append the start city
        order = order.expand(static.size(1), -1)
        closed = torch.cat((order, order[:, 0:1]), dim=1)

        coords = torch.gather(static[k].data, 1, closed).cpu().numpy()
        xs = coords[0]
        ys = coords[1]

        panel.plot(xs, ys, zorder=1)
        panel.scatter(xs, ys, s=4, c='r', zorder=2)
        panel.scatter(xs[0], ys[0], s=20, c='k', marker='*', zorder=3)

        panel.set_xlim(0, 1)
        panel.set_ylim(0, 1)

    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight', dpi=400)


================================================
FILE: tasks/vrp.py
================================================
"""Defines the main task for the VRP.

The VRP is defined by the following traits:
    1. Each city has a demand in [1, 9], which must be serviced by the vehicle
    2. Each vehicle has a capacity (depends on problem), the must visit all cities
    3. When the vehicle load is 0, it __must__ return to the depot to refill
"""

import os
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.autograd import Variable
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt


class VehicleRoutingDataset(Dataset):
    """Randomly generated capacitated VRP instances.

    Every sample has ``input_size + 1`` nodes; node 0 is the depot.

    * ``static``  -- (num_samples, 2, input_size + 1) uniform [0, 1) coordinates.
    * ``dynamic`` -- (num_samples, 2, input_size + 1): row 0 is the vehicle
      load (1 = full), row 1 the demand of each node, an integer in
      [1, max_demand] scaled by 1 / max_load (the depot starts at 0).
    """

    def __init__(self, num_samples, input_size, max_load=20, max_demand=9,
                 seed=None):
        super(VehicleRoutingDataset, self).__init__()

        if max_load < max_demand:
            raise ValueError(':param max_load: must be >= max_demand')

        if seed is None:
            seed = np.random.randint(1234567890)
        np.random.seed(seed)
        torch.manual_seed(seed)

        self.num_samples = num_samples
        self.max_load = max_load
        self.max_demand = max_demand

        # Depot location will be the first node in each sample
        locations = torch.rand((num_samples, 2, input_size + 1))
        self.static = locations

        # All states will broadcast the driver's current load.
        # Note that we only use a load between [0, 1] to prevent large
        # numbers entering the neural network
        dynamic_shape = (num_samples, 1, input_size + 1)
        loads = torch.full(dynamic_shape, 1.)

        # Integer demands in [1, max_demand], scaled by the maximum load;
        # e.g. max_load=20 and max_demand=9 give demands in [0.05, 0.45].
        demands = torch.randint(1, max_demand + 1, dynamic_shape).float()
        demands = demands / float(max_load)

        demands[:, 0, 0] = 0  # depot starts with a demand of 0
        self.dynamic = torch.cat((loads, demands), dim=1)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # (static, dynamic, start_loc); the tour starts at the depot
        return (self.static[idx], self.dynamic[idx], self.static[idx, :, 0:1])

    def update_mask(self, mask, dynamic, chosen_idx=None):
        """Return the mask of nodes that may be visited next.

        Parameters
        ----------
        mask: ignored; the new mask is computed from ``dynamic`` alone.
        dynamic: tensor of size (batch_size, 2, seq_len) with (load, demand) rows
        chosen_idx: tensor of size (batch_size,) with the node chosen last step

        Returns
        -------
        Float tensor of size (batch_size, seq_len); 1 = allowed, 0 = masked.
        """

        loads = dynamic.data[:, 0]  # (batch_size, seq_len)
        demands = dynamic.data[:, 1]  # (batch_size, seq_len)

        # If there is no demand left anywhere, every node is masked and the
        # tour ends. (The depot's demand entry is 0 or negative, never positive.)
        if demands.eq(0).all():
            return demands * 0.

        # Otherwise, we can go anywhere the demand is nonzero and below the load
        new_mask = demands.ne(0) & demands.lt(loads)

        # Avoid travelling to the depot back-to-back: the depot is allowed
        # only if the last stop was a city.
        left_city = chosen_idx.ne(0)
        at_depot = ~left_city

        if left_city.any():
            new_mask[left_city.nonzero(), 0] = 1.
        if at_depot.any():
            new_mask[at_depot.nonzero(), 0] = 0.

        # ... unless we're waiting for all other samples in a minibatch to finish
        has_no_load = loads[:, 0].eq(0).float()
        has_no_demand = demands[:, 1:].sum(1).eq(0).float()

        combined = (has_no_load + has_no_demand).gt(0)
        if combined.any():
            new_mask[combined.nonzero(), 0] = 1.
            new_mask[combined.nonzero(), 1:] = 0.

        return new_mask.float()

    def update_dynamic(self, dynamic, chosen_idx):
        """Return updated (load, demand) values after moving to ``chosen_idx``.

        Visiting a city serves as much of its demand as the current load
        allows; visiting the depot refills the load to 1. The input is not
        modified and the result is detached from the autograd graph.
        """

        # Update the dynamic elements differently for if we visit depot vs. a city
        visit = chosen_idx.ne(0)
        depot = chosen_idx.eq(0)

        # Clone the dynamic variable so we don't mess up graph
        all_loads = dynamic[:, 0].clone()
        all_demands = dynamic[:, 1].clone()

        load = torch.gather(all_loads, 1, chosen_idx.unsqueeze(1))
        demand = torch.gather(all_demands, 1, chosen_idx.unsqueeze(1))

        # Across the minibatch - if we've chosen to visit a city, try to satisfy
        # as much demand as possible
        if visit.any():

            new_load = torch.clamp(load - demand, min=0)
            new_demand = torch.clamp(demand - load, min=0)

            # Broadcast the load to all nodes, but update demand separately.
            # view(-1) keeps the index 1-D even when only one sample matches.
            visit_idx = visit.nonzero().view(-1)

            all_loads[visit_idx] = new_load[visit_idx]
            all_demands[visit_idx, chosen_idx[visit_idx]] = new_demand[visit_idx].view(-1)
            all_demands[visit_idx, 0] = -1. + new_load[visit_idx].view(-1)

        # Return to depot to fill vehicle load
        if depot.any():
            depot_idx = depot.nonzero().view(-1)
            all_loads[depot_idx] = 1.
            all_demands[depot_idx, 0] = 0.

        tensor = torch.cat((all_loads.unsqueeze(1), all_demands.unsqueeze(1)), 1)
        # Detached copy on the same device as `dynamic` (tensor was built from it)
        return tensor.detach().clone()


def reward(static, tour_indices):
    """
    Total Euclidean length of each route, including the leg from the depot
    (node 0) to the first stop and the leg from the last stop back to it.
    """

    n_feats = static.size(1)
    gather_idx = tour_indices.unsqueeze(1).expand(-1, n_feats, -1)
    visited = torch.gather(static.data, 2, gather_idx).permute(0, 2, 1)

    # Bracket the route with the depot. If the route already starts or ends
    # at the depot, the extra leg has length 0 and changes nothing.
    depot = static.data[:, :, 0].unsqueeze(1)
    path = torch.cat((depot, visited, depot), dim=1)

    steps = path[:, 1:] - path[:, :-1]
    leg_len = steps.pow(2).sum(dim=2).sqrt()

    return leg_len.sum(1)


def render(static, tour_indices, save_path):
    """Plot up to a 3x3 grid of VRP solutions and save the figure.

    Every depot-to-depot subtour is drawn as its own labelled line; the depot
    is marked with a black star.
    """

    plt.close('all')

    side = 3 if int(np.sqrt(len(tour_indices))) >= 3 else 1
    _, grid = plt.subplots(nrows=side, ncols=side,
                           sharex='col', sharey='row')
    if side == 1:
        grid = [[grid]]
    panels = [cell for row in grid for cell in row]

    for k, panel in enumerate(panels):

        order = tour_indices[k]
        if order.dim() == 1:
            order = order.unsqueeze(0)

        order = order.expand(static.size(1), -1)
        visited = torch.gather(static[k].data, 1, order).cpu().numpy()

        # Prepend and append the depot (node 0) so the route starts and ends there
        depot = static[k, :, 0].cpu().data.numpy()
        xs = np.hstack((depot[0], visited[0], depot[0]))
        ys = np.hstack((depot[1], visited[1], depot[1]))

        # Positions of depot visits split the route into subtours
        stops = np.hstack((0, tour_indices[k].cpu().numpy().flatten(), 0))
        depot_pos = np.where(stops == 0)[0]

        for j, (lo, hi) in enumerate(zip(depot_pos[:-1], depot_pos[1:])):
            # Two consecutive depot entries delimit an empty subtour
            if lo + 1 == hi:
                continue
            panel.plot(xs[lo: hi + 1], ys[lo: hi + 1], zorder=1, label=j)

        panel.legend(loc="upper right", fontsize=3, framealpha=0.5)
        panel.scatter(xs, ys, s=4, c='r', zorder=2)
        panel.scatter(xs[0], ys[0], s=20, c='k', marker='*', zorder=3)

        panel.set_xlim(0, 1)
        panel.set_ylim(0, 1)

    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight', dpi=200)


# Disabled alternative renderer, kept for reference only: everything below is
# a bare string literal and is never executed. If revived, it would animate
# the tours with matplotlib's FuncAnimation, save them to 'line.mp4' using a
# hard-coded local ffmpeg path, and then call sys.exit(1).
'''
def render(static, tour_indices, save_path):
    """Plots the found solution."""

    path = 'C:/Users/Matt/Documents/ffmpeg-3.4.2-win64-static/bin/ffmpeg.exe'
    plt.rcParams['animation.ffmpeg_path'] = path

    plt.close('all')

    num_plots = min(int(np.sqrt(len(tour_indices))), 3)
    fig, axes = plt.subplots(nrows=num_plots, ncols=num_plots,
                             sharex='col', sharey='row')
    axes = [a for ax in axes for a in ax]

    all_lines = []
    all_tours = []
    for i, ax in enumerate(axes):

        # Convert the indices back into a tour
        idx = tour_indices[i]
        if len(idx.size()) == 1:
            idx = idx.unsqueeze(0)

        idx = idx.expand(static.size(1), -1)
        data = torch.gather(static[i].data, 1, idx).cpu().numpy()

        start = static[i, :, 0].cpu().data.numpy()
        x = np.hstack((start[0], data[0], start[0]))
        y = np.hstack((start[1], data[1], start[1]))

        cur_tour = np.vstack((x, y))

        all_tours.append(cur_tour)
        all_lines.append(ax.plot([], [])[0])

        ax.scatter(x, y, s=4, c='r', zorder=2)
        ax.scatter(x[0], y[0], s=20, c='k', marker='*', zorder=3)

    from matplotlib.animation import FuncAnimation

    tours = all_tours

    def update(idx):

        for i, line in enumerate(all_lines):

            if idx >= tours[i].shape[1]:
                continue

            data = tours[i][:, idx]

            xy_data = line.get_xydata()
            xy_data = np.vstack((xy_data, np.atleast_2d(data)))

            line.set_data(xy_data[:, 0], xy_data[:, 1])
            line.set_linewidth(0.75)

        return all_lines

    anim = FuncAnimation(fig, update, init_func=None,
                         frames=100, interval=200, blit=False,
                         repeat=False)

    anim.save('line.mp4', dpi=160)
    plt.show()

    import sys
    sys.exit(1)
'''


================================================
FILE: trainer_motsp_no_transfer.py
================================================
"""Defines the main trainer model for combinatorial problems

Each task must define the following functions:
* mask_fn: can be None
* update_fn: can be None
* reward_fn: specifies the quality of found solutions
* render_fn: Specifies how to plot found solutions. Can be None
"""

import os
import time
import argparse
import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from model import DRL4TSP, Encoder

# Run on the GPU whenever PyTorch can see one, otherwise on the CPU.
# To force CPU execution, use: device = torch.device('cpu')
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class StateCritic(nn.Module):
    """Estimates the problem complexity.

    This is a basic module that just looks at the log-probabilities predicted by
    the encoder + decoder, and returns an estimate of complexity
    """

    def __init__(self, static_size, dynamic_size, hidden_size):
        super(StateCritic, self).__init__()

        self.static_encoder = Encoder(static_size, hidden_size)
        self.dynamic_encoder = Encoder(dynamic_size, hidden_size)

        # Define the encoder & decoder models
        self.fc1 = nn.Conv1d(hidden_size * 2, 20, kernel_size=1)
        self.fc2 = nn.Conv1d(20, 20, kernel_size=1)
        self.fc3 = nn.Conv1d(20, 1, kernel_size=1)

        for p in self.parameters():
            if len(p.shape) > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, static, dynamic):

        # Use the probabilities of visiting each
        static_hidden = self.static_encoder(static)
        dynamic_hidden = self.dynamic_encoder(dynamic)

        hidden = torch.cat((static_hidden, dynamic_hidden), 1)

        output = F.relu(self.fc1(hidden))
        output = F.relu(self.fc2(output))
        output = self.fc3(output).sum(dim=2)
        return output


class Critic(nn.Module):
    """Small position-wise network mapping a (batch, length) input to (batch, 1).

    The input is treated as a single-channel sequence, passed through three
    kernel-size-1 convolutions with ReLU between them, and the single output
    channel is summed over the sequence. (The trainer functions in this file
    build ``StateCritic`` rather than this class.)

    Args:
        hidden_size: number of output channels of the first convolution.
    """

    def __init__(self, hidden_size):
        super(Critic, self).__init__()

        self.fc1 = nn.Conv1d(1, hidden_size, kernel_size=1)
        self.fc2 = nn.Conv1d(hidden_size, 20, kernel_size=1)
        self.fc3 = nn.Conv1d(20, 1, kernel_size=1)

        # Xavier-uniform init for weights; biases keep PyTorch's defaults.
        for p in self.parameters():
            if len(p.shape) > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, input):
        """Return a (batch, 1) estimate for a (batch, length) ``input``.

        The tensor is kept 3-D (batch, channels, length) throughout. A former
        ``.squeeze(2)`` after ``fc2`` was a no-op for length > 1 but collapsed
        a length-1 input to 2-D, which ``fc3`` and ``sum(dim=2)`` rejected.
        """
        output = F.relu(self.fc1(input.unsqueeze(1)))
        output = F.relu(self.fc2(output))
        output = self.fc3(output).sum(dim=2)
        return output


def validate(data_loader, actor, reward_fn, w1, w2, render_fn=None, save_dir='.',
             num_plot=5):
    """Used to monitor progress on a validation set & optionally plot solutions.

    Args:
        data_loader: iterable yielding ``(static, dynamic, x0)`` batches; an
            empty ``x0`` is passed to the actor as ``None``.
        actor: policy whose ``forward(static, dynamic, x0)`` returns
            ``(tour_indices, log_probs)``; it is run without gradients.
        reward_fn: ``reward_fn(static, tour_indices, w1, w2)`` returning the
            ``(reward, obj1, obj2)`` tensors for a batch.
        w1, w2: objective weights forwarded to ``reward_fn``.
        render_fn: optional ``render_fn(static, tour_indices, path)`` used to
            plot the first ``num_plot`` batches into ``save_dir``.
        save_dir: output directory for plots (created if missing).
        num_plot: number of leading batches to render.

    Returns:
        ``(mean_reward, mean_obj1, mean_obj2)`` averaged over all validation
        instances, so a shorter final batch is weighted by its real size.
        All three are NaN when the loader is empty. The actor's previous
        train/eval mode is restored on exit.
    """
    was_training = actor.training
    actor.eval()

    os.makedirs(save_dir, exist_ok=True)

    # Size-weighted running sums of (reward, obj1, obj2) and instance count.
    totals = np.zeros(3)
    num_instances = 0
    try:
        for batch_idx, (static, dynamic, x0) in enumerate(data_loader):

            static = static.to(device)
            dynamic = dynamic.to(device)
            x0 = x0.to(device) if len(x0) > 0 else None

            with torch.no_grad():
                tour_indices, _ = actor.forward(static, dynamic, x0)

            reward, obj1, obj2 = reward_fn(static, tour_indices, w1, w2)

            batch_means = np.array([torch.mean(t.detach()).item()
                                    for t in (reward, obj1, obj2)])
            batch_size = static.size(0)
            totals += batch_size * batch_means
            num_instances += batch_size

            if render_fn is not None and batch_idx < num_plot:
                name = 'batch%d_%2.4f.png' % (batch_idx, batch_means[0])
                path = os.path.join(save_dir, name)
                render_fn(static, tour_indices, path)
    finally:
        actor.train(was_training)

    if num_instances == 0:
        nan = float('nan')
        return nan, nan, nan
    mean_reward, mean_obj1, mean_obj2 = totals / num_instances
    return mean_reward, mean_obj1, mean_obj2


def _save_actor_critic(actor, critic, directory):
    """Save the actor/critic state dicts as ``actor.pt``/``critic.pt`` in ``directory``."""
    torch.save(actor.state_dict(), os.path.join(directory, 'actor.pt'))
    torch.save(critic.state_dict(), os.path.join(directory, 'critic.pt'))


def train(actor, critic, w1, w2, task, num_nodes, train_data, valid_data, reward_fn,
          render_fn, batch_size, actor_lr, critic_lr, max_grad_norm,
          **kwargs):
    """Train ``actor`` and ``critic`` for one objective weighting ``(w1, w2)``.

    Actor-critic policy gradient: the critic's estimate is the baseline for the
    actor's update, and the critic is regressed onto the observed reward.
    Lower reward is better (the best model is the one with the smallest mean
    validation reward).

    After every epoch both networks are saved under
    ``<task>_4static/<num_nodes>/w_<w1>_<w2>/<time>/checkpoints/<epoch>/`` and
    validation plots go to ``.../<time>/<epoch>/``. Whenever validation
    improves, an extra copy is written to ``<task>_4static/<num_nodes>/w_<w1>_<w2>/``
    so another weighting can be initialised from it.

    Args:
        actor, critic: networks to optimise (each with its own Adam optimiser).
        w1, w2: objective weights forwarded to ``reward_fn``.
        task, num_nodes: used to build the output directory names.
        train_data, valid_data: datasets yielding ``(static, dynamic, x0)``.
        reward_fn: ``reward_fn(static, tour_indices, w1, w2) -> (reward, obj1, obj2)``.
        render_fn: optional plotting callback passed to ``validate``.
        batch_size, actor_lr, critic_lr: loader / optimiser settings.
        max_grad_norm: gradient-norm clipping threshold for both networks.
        **kwargs: other entries (e.g. the rest of the CLI namespace) are
            ignored, except ``epochs`` (default 5) and ``log_every`` (batches
            between progress lines, default 200).
    """
    num_epochs = kwargs.get('epochs', 5)
    log_every = kwargs.get('log_every', 200)

    now = '%s' % datetime.datetime.now().time()
    now = now.replace(':', '_')
    bname = "_4static"
    main_dir = os.path.join(task + bname, '%d' % num_nodes, 'w_%2.2f_%2.2f' % (w1, w2))
    save_dir = os.path.join(main_dir, now)

    checkpoint_dir = os.path.join(save_dir, 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)

    actor_optim = optim.Adam(actor.parameters(), lr=actor_lr)
    critic_optim = optim.Adam(critic.parameters(), lr=critic_lr)

    train_loader = DataLoader(train_data, batch_size, True, num_workers=0)
    valid_loader = DataLoader(valid_data, batch_size, False, num_workers=0)

    best_reward = np.inf

    for epoch in range(num_epochs):
        print("epoch %d start:" % epoch)
        actor.train()
        critic.train()

        times, losses, rewards = [], [], []
        obj1s, obj2s = [], []

        epoch_start = time.time()
        start = epoch_start

        for batch_idx, (static, dynamic, x0) in enumerate(train_loader):

            static = static.to(device)
            dynamic = dynamic.to(device)
            x0 = x0.to(device) if len(x0) > 0 else None

            # Full forward pass through the batch: sampled tours plus their
            # log-probabilities (summed over dim 1 below, one sum per tour).
            tour_indices, tour_logp = actor(static, dynamic, x0)

            reward, obj1, obj2 = reward_fn(static, tour_indices, w1, w2)

            # Critic's estimate of the reward, used as the baseline.
            critic_est = critic(static, dynamic).view(-1)

            advantage = (reward - critic_est)
            # The advantage is detached so the actor loss does not train the critic.
            actor_loss = torch.mean(advantage.detach() * tour_logp.sum(dim=1))
            critic_loss = torch.mean(advantage ** 2)

            actor_optim.zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(actor.parameters(), max_grad_norm)
            actor_optim.step()

            critic_optim.zero_grad()
            critic_loss.backward()
            torch.nn.utils.clip_grad_norm_(critic.parameters(), max_grad_norm)
            critic_optim.step()

            rewards.append(torch.mean(reward.detach()).item())
            losses.append(torch.mean(actor_loss.detach()).item())
            obj1s.append(torch.mean(obj1.detach()).item())
            obj2s.append(torch.mean(obj2.detach()).item())
            if (batch_idx + 1) % log_every == 0:
                print("\n")
                end = time.time()
                times.append(end - start)
                start = end

                # Average over the same window of batches that the timing covers.
                mean_loss = np.mean(losses[-log_every:])
                mean_reward = np.mean(rewards[-log_every:])
                mean_obj1 = np.mean(obj1s[-log_every:])
                mean_obj2 = np.mean(obj2s[-log_every:])
                print('  Batch %d/%d, reward: %2.3f, obj1: %2.3f, obj2: %2.3f, loss: %2.4f, took: %2.4fs' %
                      (batch_idx, len(train_loader), mean_reward, mean_obj1, mean_obj2, mean_loss,
                       times[-1]))

        mean_loss = np.mean(losses)
        mean_reward = np.mean(rewards)

        # Save the weights of this epoch.
        epoch_dir = os.path.join(checkpoint_dir, '%s' % epoch)
        os.makedirs(epoch_dir, exist_ok=True)
        _save_actor_critic(actor, critic, epoch_dir)

        # Save rendering of validation set tours
        valid_dir = os.path.join(save_dir, '%s' % epoch)

        print("begin valid")
        s = time.time()
        mean_valid, mean_obj1_valid, mean_obj2_valid = validate(
            valid_loader, actor, reward_fn, w1, w2, render_fn, valid_dir, num_plot=5)
        print("valid end time: %2.4f" % (time.time() - s))

        if mean_valid < best_reward:
            best_reward = mean_valid
            # Also keep a copy directly in the w_<w1>_<w2> folder so it can be
            # transferred to the next weighting.
            _save_actor_critic(actor, critic, main_dir)

        mean_time = np.mean(times) if times else float('nan')
        print('Mean epoch loss/reward: %2.4f, %2.4f, %2.4f, obj1_valid: %2.3f, obj2_valid: %2.3f. took: %2.4fs '
              '(%2.4fs / %d batches)\n' %
              (mean_loss, mean_reward, mean_valid, mean_obj1_valid, mean_obj2_valid,
               time.time() - epoch_start, mean_time, log_every))



def train_tsp(args, w1=1, w2=0, checkpoint=None):
    """Build, optionally train, then test the actor/critic for one weighting.

    Args:
        args: parsed command-line namespace. It is only read; the keyword
            arguments for ``train`` are built from a copy of it.
        w1, w2: objective weights forwarded to ``motsp.reward``.
        checkpoint: optional directory holding ``actor.pt`` and ``critic.pt``
            to initialise both networks from (e.g. a neighbouring weighting).

    Prints ``(mean_reward, mean_obj1, mean_obj2)`` on a held-out test set
    generated with ``args.seed + 2``.
    """

    # Goals from paper (single-objective TSP, for reference):
    # TSP20, 3.97
    # TSP50, 6.08
    # TSP100, 8.44

    from tasks import motsp
    from tasks.motsp import TSPDataset

    # NOTE(review): 4 static features per city (an older comment said
    # "(x, y)"); presumably one (x, y) pair per objective - confirm against
    # tasks.motsp.TSPDataset.
    STATIC_SIZE = 4
    DYNAMIC_SIZE = 1 # dummy for compatibility

    train_data = TSPDataset(args.num_nodes, args.train_size, args.seed)
    valid_data = TSPDataset(args.num_nodes, args.valid_size, args.seed + 1)

    update_fn = None

    actor = DRL4TSP(STATIC_SIZE,
                    DYNAMIC_SIZE,
                    args.hidden_size,
                    update_fn,
                    motsp.update_mask,
                    args.num_layers,
                    args.dropout).to(device)

    critic = StateCritic(STATIC_SIZE, DYNAMIC_SIZE, args.hidden_size).to(device)

    # vars(args) *is* args.__dict__: copy it so the datasets are not attached
    # to args. Otherwise, during a weight sweep, the previous training set
    # would stay referenced while the next one is being built.
    kwargs = dict(vars(args))
    kwargs['train_data'] = train_data
    kwargs['valid_data'] = valid_data
    kwargs['reward_fn'] = motsp.reward
    kwargs['render_fn'] = motsp.render

    if checkpoint:
        path = os.path.join(checkpoint, 'actor.pt')
        actor.load_state_dict(torch.load(path, map_location=device))
        path = os.path.join(checkpoint, 'critic.pt')
        critic.load_state_dict(torch.load(path, map_location=device))

    if not args.test:
        train(actor, critic, w1, w2, **kwargs)

    test_data = TSPDataset(args.num_nodes, args.valid_size, args.seed + 2)

    test_dir = 'test'
    test_loader = DataLoader(test_data, args.valid_size, False, num_workers=0)
    out = validate(test_loader, actor, motsp.reward, w1, w2, motsp.render, test_dir, num_plot=5)

    print('w1=%2.2f,w2=%2.2f. Average tour length: ' % (w1, w2), out)


def train_vrp(args, w1=1, w2=0):
    """Build, optionally train, then test the actor/critic on the capacitated VRP.

    ``train`` and ``validate`` in this module call
    ``reward_fn(static, tour_indices, w1, w2)`` and unpack
    ``(reward, obj1, obj2)``, while the VRP reward is single-objective. It is
    therefore wrapped: the tour length is reported as both the reward and
    ``obj1``, with a zero ``obj2``. ``w1``/``w2`` are only forwarded for
    interface compatibility and do not change the VRP reward.

    Args:
        args: parsed command-line namespace (copied, not modified).
            ``args.checkpoint`` is optional.
        w1, w2: objective weights passed through to ``train``/``validate``.

    Raises:
        ValueError: if ``args.num_nodes`` has no configured vehicle capacity.
    """

    # Goals from paper:
    # VRP10, Capacity 20:  4.84  (Greedy)
    # VRP20, Capacity 30:  6.59  (Greedy)
    # VRP50, Capacity 40:  11.39 (Greedy)
    # VRP100, Capacity 50: 17.23  (Greedy)

    from tasks import vrp
    from tasks.vrp import VehicleRoutingDataset

    # Determines the maximum amount of load for a vehicle based on num nodes
    LOAD_DICT = {10: 20, 20: 30, 50: 40, 100: 50}
    MAX_DEMAND = 9
    STATIC_SIZE = 2 # (x, y)
    DYNAMIC_SIZE = 2 # (load, demand)

    if args.num_nodes not in LOAD_DICT:
        raise ValueError('No vehicle capacity configured for %d nodes; supported: %s'
                         % (args.num_nodes, sorted(LOAD_DICT)))
    max_load = LOAD_DICT[args.num_nodes]

    def reward_fn(static, tour_indices, w1, w2):
        # NOTE(review): assumes vrp.reward(static, tour_indices) returns one
        # tour-length tensor per instance - confirm against tasks/vrp.py.
        tour_len = vrp.reward(static, tour_indices)
        return tour_len, tour_len, torch.zeros_like(tour_len)

    train_data = VehicleRoutingDataset(args.train_size,
                                       args.num_nodes,
                                       max_load,
                                       MAX_DEMAND,
                                       args.seed)

    valid_data = VehicleRoutingDataset(args.valid_size,
                                       args.num_nodes,
                                       max_load,
                                       MAX_DEMAND,
                                       args.seed + 1)

    actor = DRL4TSP(STATIC_SIZE,
                    DYNAMIC_SIZE,
                    args.hidden_size,
                    train_data.update_dynamic,
                    train_data.update_mask,
                    args.num_layers,
                    args.dropout).to(device)

    critic = StateCritic(STATIC_SIZE, DYNAMIC_SIZE, args.hidden_size).to(device)

    # Copy so the datasets and callbacks are not attached to args itself.
    kwargs = dict(vars(args))
    kwargs['train_data'] = train_data
    kwargs['valid_data'] = valid_data
    kwargs['reward_fn'] = reward_fn
    kwargs['render_fn'] = vrp.render

    # --checkpoint is not always defined on the parser.
    checkpoint = getattr(args, 'checkpoint', None)
    if checkpoint:
        path = os.path.join(checkpoint, 'actor.pt')
        actor.load_state_dict(torch.load(path, map_location=device))

        path = os.path.join(checkpoint, 'critic.pt')
        critic.load_state_dict(torch.load(path, map_location=device))

    if not args.test:
        train(actor, critic, w1, w2, **kwargs)

    test_data = VehicleRoutingDataset(args.valid_size,
                                      args.num_nodes,
                                      max_load,
                                      MAX_DEMAND,
                                      args.seed + 2)

    test_dir = 'test'
    test_loader = DataLoader(test_data, args.batch_size, False, num_workers=0)
    out = validate(test_loader, actor, reward_fn, w1, w2, vrp.render, test_dir, num_plot=5)

    print('Average tour length: ', out)


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='Combinatorial Optimization')
    parser.add_argument('--seed', default=12345, type=int)
    # Optional directory with actor.pt/critic.pt to warm-start from; read by
    # train_vrp (the TSP sweep below deliberately starts every weighting from
    # scratch). Example: --checkpoint tsp/20/w_1_0/20_06_30.888074
    parser.add_argument('--checkpoint', default=None)
    parser.add_argument('--test', action='store_true', default=False)
    parser.add_argument('--task', default='tsp')
    parser.add_argument('--nodes', dest='num_nodes', default=40, type=int)
    parser.add_argument('--actor_lr', default=5e-4, type=float)
    parser.add_argument('--critic_lr', default=5e-4, type=float)
    parser.add_argument('--max_grad_norm', default=2., type=float)
    parser.add_argument('--batch_size', default=200, type=int)
    parser.add_argument('--hidden', dest='hidden_size', default=128, type=int)
    parser.add_argument('--dropout', default=0.1, type=float)
    parser.add_argument('--layers', dest='num_layers', default=1, type=int)
    parser.add_argument('--train-size',default=500000, type=int)
    parser.add_argument('--valid-size', default=1000, type=int)

    args = parser.parse_args()

    if args.task == 'tsp':
        # Trained without transfer: each of the 101 weightings
        # (w1, w2) = (1 - k/100, k/100), k = 0..100, starts from fresh weights.
        for w2 in np.arange(101) / 100:
            w1 = 1 - w2
            print("Current w:%2.2f/%2.2f" % (w1, w2))
            train_tsp(args, w1, w2, None)

    elif args.task == 'vrp':
        train_vrp(args)
    else:
        raise ValueError('Task <%s> not understood' % args.task)


================================================
FILE: trainer_motsp_transfer.py
================================================
"""Defines the main trainer model for combinatorial problems

Each task must define the following functions:
* mask_fn: can be None
* update_fn: can be None
* reward_fn: specifies the quality of found solutions
* render_fn: Specifies how to plot found solutions. Can be None
"""

import os
import time
import argparse
import datetime
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from model import DRL4TSP, Encoder

# Run on the GPU whenever PyTorch can see one, otherwise on the CPU.
# To force CPU execution, use: device = torch.device('cpu')
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class StateCritic(nn.Module):
    """Estimates the problem complexity.

    This is a basic module that just looks at the log-probabilities predicted by
    the encoder + decoder, and returns an estimate of complexity
    """

    def __init__(self, static_size, dynamic_size, hidden_size):
        super(StateCritic, self).__init__()

        self.static_encoder = Encoder(static_size, hidden_size)
        self.dynamic_encoder = Encoder(dynamic_size, hidden_size)

        # Define the encoder & decoder models
        self.fc1 = nn.Conv1d(hidden_size * 2, 20, kernel_size=1)
        self.fc2 = nn.Conv1d(20, 20, kernel_size=1)
        self.fc3 = nn.Conv1d(20, 1, kernel_size=1)

        for p in self.parameters():
            if len(p.shape) > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, static, dynamic):

        # Use the probabilities of visiting each
        static_hidden = self.static_encoder(static)
        dynamic_hidden = self.dynamic_encoder(dynamic)

        hidden = torch.cat((static_hidden, dynamic_hidden), 1)

        output = F.relu(self.fc1(hidden))
        output = F.relu(self.fc2(output))
        output = self.fc3(output).sum(dim=2)
        return output


class Critic(nn.Module):
    """Small position-wise network mapping a (batch, length) input to (batch, 1).

    The input is treated as a single-channel sequence, passed through three
    kernel-size-1 convolutions with ReLU between them, and the single output
    channel is summed over the sequence. (The trainer functions in this file
    build ``StateCritic`` rather than this class.)

    Args:
        hidden_size: number of output channels of the first convolution.
    """

    def __init__(self, hidden_size):
        super(Critic, self).__init__()

        self.fc1 = nn.Conv1d(1, hidden_size, kernel_size=1)
        self.fc2 = nn.Conv1d(hidden_size, 20, kernel_size=1)
        self.fc3 = nn.Conv1d(20, 1, kernel_size=1)

        # Xavier-uniform init for weights; biases keep PyTorch's defaults.
        for p in self.parameters():
            if len(p.shape) > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, input):
        """Return a (batch, 1) estimate for a (batch, length) ``input``.

        The tensor is kept 3-D (batch, channels, length) throughout. A former
        ``.squeeze(2)`` after ``fc2`` was a no-op for length > 1 but collapsed
        a length-1 input to 2-D, which ``fc3`` and ``sum(dim=2)`` rejected.
        """
        output = F.relu(self.fc1(input.unsqueeze(1)))
        output = F.relu(self.fc2(output))
        output = self.fc3(output).sum(dim=2)
        return output


def validate(data_loader, actor, reward_fn, w1, w2, render_fn=None, save_dir='.',
             num_plot=5):
    """Used to monitor progress on a validation set.

    Plotting is disabled in this trainer: ``render_fn``, ``save_dir`` and
    ``num_plot`` are accepted for signature compatibility but not used, and
    no directory is created.

    Args:
        data_loader: iterable yielding ``(static, dynamic, x0)`` batches; an
            empty ``x0`` is passed to the actor as ``None``.
        actor: policy whose ``forward(static, dynamic, x0)`` returns
            ``(tour_indices, log_probs)``; it is run without gradients.
        reward_fn: ``reward_fn(static, tour_indices, w1, w2)`` returning the
            ``(reward, obj1, obj2)`` tensors for a batch.
        w1, w2: objective weights forwarded to ``reward_fn``.

    Returns:
        ``(mean_reward, mean_obj1, mean_obj2)`` averaged over all validation
        instances, so a shorter final batch is weighted by its real size.
        All three are NaN when the loader is empty. The actor's previous
        train/eval mode is restored on exit.
    """
    was_training = actor.training
    actor.eval()

    # Size-weighted running sums of (reward, obj1, obj2) and instance count.
    totals = np.zeros(3)
    num_instances = 0
    try:
        for static, dynamic, x0 in data_loader:

            static = static.to(device)
            dynamic = dynamic.to(device)
            x0 = x0.to(device) if len(x0) > 0 else None

            with torch.no_grad():
                tour_indices, _ = actor.forward(static, dynamic, x0)

            reward, obj1, obj2 = reward_fn(static, tour_indices, w1, w2)

            batch_means = np.array([torch.mean(t.detach()).item()
                                    for t in (reward, obj1, obj2)])
            batch_size = static.size(0)
            totals += batch_size * batch_means
            num_instances += batch_size
    finally:
        actor.train(was_training)

    if num_instances == 0:
        nan = float('nan')
        return nan, nan, nan
    mean_reward, mean_obj1, mean_obj2 = totals / num_instances
    return mean_reward, mean_obj1, mean_obj2


def _save_actor_critic(actor, critic, directory):
    """Save the actor/critic state dicts as ``actor.pt``/``critic.pt`` in ``directory``."""
    torch.save(actor.state_dict(), os.path.join(directory, 'actor.pt'))
    torch.save(critic.state_dict(), os.path.join(directory, 'critic.pt'))


def train(actor, critic, w1, w2, task, num_nodes, train_data, valid_data, reward_fn,
          render_fn, batch_size, actor_lr, critic_lr, max_grad_norm,
          **kwargs):
    """Train ``actor`` and ``critic`` for one objective weighting ``(w1, w2)``.

    Actor-critic policy gradient: the critic's estimate is the baseline for the
    actor's update, and the critic is regressed onto the observed reward.
    Lower reward is better (the best model is the one with the smallest mean
    validation reward). Whenever validation improves, both networks are saved
    to ``<task>_transfer/<num_nodes>/w_<w1>_<w2>/`` (created if missing) so
    the next weighting can be initialised from them. Per-epoch checkpoints
    and validation plots are disabled in this trainer.

    Args:
        actor, critic: networks to optimise (each with its own Adam optimiser).
        w1, w2: objective weights forwarded to ``reward_fn``.
        task, num_nodes: used to build the output directory name.
        train_data, valid_data: datasets yielding ``(static, dynamic, x0)``.
        reward_fn: ``reward_fn(static, tour_indices, w1, w2) -> (reward, obj1, obj2)``.
        render_fn: passed to ``validate`` (which does not plot here).
        batch_size, actor_lr, critic_lr: loader / optimiser settings.
        max_grad_norm: gradient-norm clipping threshold for both networks.
        **kwargs: other entries (e.g. the rest of the CLI namespace) are
            ignored, except ``epochs`` (default 3) and ``log_every`` (batches
            between progress lines, default 200).
    """
    num_epochs = kwargs.get('epochs', 3)
    log_every = kwargs.get('log_every', 200)

    bname = "_transfer"
    # Best weights for this (w1, w2) are written here for transfer to the
    # next weighting.
    main_dir = os.path.join(task + bname, '%d' % num_nodes, 'w_%2.2f_%2.2f' % (w1, w2))
    os.makedirs(main_dir, exist_ok=True)

    actor_optim = optim.Adam(actor.parameters(), lr=actor_lr)
    critic_optim = optim.Adam(critic.parameters(), lr=critic_lr)

    train_loader = DataLoader(train_data, batch_size, True, num_workers=0)
    valid_loader = DataLoader(valid_data, batch_size, False, num_workers=0)

    best_reward = np.inf
    start_total = time.time()
    for epoch in range(num_epochs):
        print("epoch %d start:" % epoch)
        actor.train()
        critic.train()

        times, losses, rewards = [], [], []
        obj1s, obj2s = [], []

        epoch_start = time.time()
        start = epoch_start

        for batch_idx, (static, dynamic, x0) in enumerate(train_loader):

            static = static.to(device)
            dynamic = dynamic.to(device)
            x0 = x0.to(device) if len(x0) > 0 else None

            # Full forward pass through the batch: sampled tours plus their
            # log-probabilities (summed over dim 1 below, one sum per tour).
            tour_indices, tour_logp = actor(static, dynamic, x0)

            reward, obj1, obj2 = reward_fn(static, tour_indices, w1, w2)

            # Critic's estimate of the reward, used as the baseline.
            critic_est = critic(static, dynamic).view(-1)

            advantage = (reward - critic_est)
            # The advantage is detached so the actor loss does not train the critic.
            actor_loss = torch.mean(advantage.detach() * tour_logp.sum(dim=1))
            critic_loss = torch.mean(advantage ** 2)

            actor_optim.zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(actor.parameters(), max_grad_norm)
            actor_optim.step()

            critic_optim.zero_grad()
            critic_loss.backward()
            torch.nn.utils.clip_grad_norm_(critic.parameters(), max_grad_norm)
            critic_optim.step()

            rewards.append(torch.mean(reward.detach()).item())
            losses.append(torch.mean(actor_loss.detach()).item())
            obj1s.append(torch.mean(obj1.detach()).item())
            obj2s.append(torch.mean(obj2.detach()).item())
            if (batch_idx + 1) % log_every == 0:
                print("\n")
                end = time.time()
                times.append(end - start)
                start = end

                # Average over the same window of batches that the timing covers.
                mean_loss = np.mean(losses[-log_every:])
                mean_reward = np.mean(rewards[-log_every:])
                mean_obj1 = np.mean(obj1s[-log_every:])
                mean_obj2 = np.mean(obj2s[-log_every:])
                print('  Batch %d/%d, reward: %2.3f, obj1: %2.3f, obj2: %2.3f, loss: %2.4f, took: %2.4fs' %
                      (batch_idx, len(train_loader), mean_reward, mean_obj1, mean_obj2, mean_loss,
                       times[-1]))

        mean_loss = np.mean(losses)
        mean_reward = np.mean(rewards)

        mean_valid, mean_obj1_valid, mean_obj2_valid = validate(
            valid_loader, actor, reward_fn, w1, w2, render_fn, '.', num_plot=5)

        if mean_valid < best_reward:
            best_reward = mean_valid
            # Keep the best weights in the w_<w1>_<w2> folder so they can be
            # transferred to the next weighting.
            _save_actor_critic(actor, critic, main_dir)

        mean_time = np.mean(times) if times else float('nan')
        print('Mean epoch loss/reward: %2.4f, %2.4f, %2.4f, obj1_valid: %2.3f, obj2_valid: %2.3f. took: %2.4fs '
              '(%2.4fs / %d batches)\n' %
              (mean_loss, mean_reward, mean_valid, mean_obj1_valid, mean_obj2_valid,
               time.time() - epoch_start, mean_time, log_every))
    print("Total run time of epoches: %2.4f" % (time.time() - start_total))



def train_tsp(args, w1=1, w2=0, checkpoint=None):
    """Build, optionally train, then test the actor/critic for one weighting.

    Args:
        args: parsed command-line namespace. It is only read; the keyword
            arguments for ``train`` are built from a copy of it.
        w1, w2: objective weights forwarded to ``motsp.reward``.
        checkpoint: optional directory holding ``actor.pt`` and ``critic.pt``
            to initialise both networks from (e.g. a neighbouring weighting,
            for parameter transfer).

    Prints ``(mean_reward, mean_obj1, mean_obj2)`` on a held-out test set
    generated with ``args.seed + 2``.
    """

    # Goals from paper (single-objective TSP, for reference):
    # TSP20, 3.97
    # TSP50, 6.08
    # TSP100, 8.44

    from tasks import motsp
    from tasks.motsp import TSPDataset

    # NOTE(review): 4 static features per city (an older comment said
    # "(x, y)"); presumably one (x, y) pair per objective - confirm against
    # tasks.motsp.TSPDataset.
    STATIC_SIZE = 4
    DYNAMIC_SIZE = 1 # dummy for compatibility

    train_data = TSPDataset(args.num_nodes, args.train_size, args.seed)
    valid_data = TSPDataset(args.num_nodes, args.valid_size, args.seed + 1)

    update_fn = None

    actor = DRL4TSP(STATIC_SIZE,
                    DYNAMIC_SIZE,
                    args.hidden_size,
                    update_fn,
                    motsp.update_mask,
                    args.num_layers,
                    args.dropout).to(device)

    critic = StateCritic(STATIC_SIZE, DYNAMIC_SIZE, args.hidden_size).to(device)

    # vars(args) *is* args.__dict__: copy it so the datasets are not attached
    # to args. Otherwise, during a weight sweep, the previous training set
    # would stay referenced while the next one is being built.
    kwargs = dict(vars(args))
    kwargs['train_data'] = train_data
    kwargs['valid_data'] = valid_data
    kwargs['reward_fn'] = motsp.reward
    kwargs['render_fn'] = motsp.render

    if checkpoint:
        path = os.path.join(checkpoint, 'actor.pt')
        actor.load_state_dict(torch.load(path, map_location=device))
        path = os.path.join(checkpoint, 'critic.pt')
        critic.load_state_dict(torch.load(path, map_location=device))

    if not args.test:
        train(actor, critic, w1, w2, **kwargs)

    test_data = TSPDataset(args.num_nodes, args.valid_size, args.seed + 2)

    test_dir = 'test'
    test_loader = DataLoader(test_data, args.valid_size, False, num_workers=0)
    out = validate(test_loader, actor, motsp.reward, w1, w2, motsp.render, test_dir, num_plot=5)

    print('w1=%2.2f,w2=%2.2f. Average tour length: ' % (w1, w2), out)


def train_vrp(args, w1=1, w2=0):
    """Train (unless ``args.test``) and evaluate the capacitated VRP task.

    Args:
        args: parsed command-line namespace (num_nodes, train_size, valid_size,
            seed, batch_size, hidden_size, num_layers, dropout, test, and
            optionally checkpoint). It is NOT modified.
        w1, w2: objective weights. ``train`` and ``validate`` in this module
            take them as required positional arguments. The defaults (1, 0)
            match ``train_tsp``.

    Raises:
        KeyError: if ``args.num_nodes`` has no entry in the vehicle-capacity
            table.
    """

    # Goals from paper:
    # VRP10, Capacity 20:  4.84  (Greedy)
    # VRP20, Capacity 30:  6.59  (Greedy)
    # VRP50, Capacity 40:  11.39 (Greedy)
    # VRP100, Capacity 50: 17.23  (Greedy)

    from tasks import vrp
    from tasks.vrp import VehicleRoutingDataset

    # Determines the maximum amount of load for a vehicle based on num nodes
    LOAD_DICT = {10: 20, 20: 30, 50: 40, 100: 50}
    MAX_DEMAND = 9
    STATIC_SIZE = 2 # (x, y)
    DYNAMIC_SIZE = 2 # (load, demand)

    if args.num_nodes not in LOAD_DICT:
        raise KeyError('No vehicle capacity defined for %d nodes; supported sizes: %s'
                       % (args.num_nodes, sorted(LOAD_DICT)))
    max_load = LOAD_DICT[args.num_nodes]

    def reward_fn(static, tour_indices, *unused_weights):
        # vrp.reward only accepts (static, tour_indices). This adapter tolerates
        # the objective weights being forwarded as extra positional arguments.
        # NOTE(review): confirm how train/validate call reward_fn, and whether
        # they expect a tuple return value (as motsp.reward may provide).
        return vrp.reward(static, tour_indices)

    train_data = VehicleRoutingDataset(args.train_size,
                                       args.num_nodes,
                                       max_load,
                                       MAX_DEMAND,
                                       args.seed)

    valid_data = VehicleRoutingDataset(args.valid_size,
                                       args.num_nodes,
                                       max_load,
                                       MAX_DEMAND,
                                       args.seed + 1)

    actor = DRL4TSP(STATIC_SIZE,
                    DYNAMIC_SIZE,
                    args.hidden_size,
                    train_data.update_dynamic,
                    train_data.update_mask,
                    args.num_layers,
                    args.dropout).to(device)

    critic = StateCritic(STATIC_SIZE, DYNAMIC_SIZE, args.hidden_size).to(device)

    # Copy so the caller's namespace does not get the datasets attached to it.
    kwargs = dict(vars(args))
    kwargs['train_data'] = train_data
    kwargs['valid_data'] = valid_data
    kwargs['reward_fn'] = reward_fn
    kwargs['render_fn'] = vrp.render

    # The --checkpoint option is not always defined by the argument parser,
    # so a missing attribute is treated as "no checkpoint".
    checkpoint = getattr(args, 'checkpoint', None)
    if checkpoint:
        path = os.path.join(checkpoint, 'actor.pt')
        actor.load_state_dict(torch.load(path, device))

        path = os.path.join(checkpoint, 'critic.pt')
        critic.load_state_dict(torch.load(path, device))

    if not args.test:
        train(actor, critic, w1, w2, **kwargs)

    test_data = VehicleRoutingDataset(args.valid_size,
                                      args.num_nodes,
                                      max_load,
                                      MAX_DEMAND,
                                      args.seed + 2)

    test_dir = 'test'
    test_loader = DataLoader(test_data, args.batch_size, False, num_workers=0)
    # Argument order follows validate(data_loader, actor, reward_fn, w1, w2,
    # render_fn, save_dir, ...).
    out = validate(test_loader, actor, reward_fn, w1, w2, vrp.render, test_dir, num_plot=5)

    print('Average tour length: ', out)


if __name__ == '__main__':
    # Entry point: parse options, then run the requested task. For 'tsp' this
    # trains a chain of weighted subproblems, each initialised from the
    # parameters of the previous one.
    num_nodes = 100
    parser = argparse.ArgumentParser(description='Combinatorial Optimization')
    parser.add_argument('--seed', default=12345, type=int)
    # parser.add_argument('--checkpoint', default="tsp/20/w_1_0/20_06_30.888074")
    parser.add_argument('--test', action='store_true', default=False)
    parser.add_argument('--task', default='tsp')
    parser.add_argument('--nodes', dest='num_nodes', default=num_nodes, type=int)
    parser.add_argument('--actor_lr', default=5e-4, type=float)
    parser.add_argument('--critic_lr', default=5e-4, type=float)
    parser.add_argument('--max_grad_norm', default=2., type=float)
    parser.add_argument('--batch_size', default=200, type=int)
    parser.add_argument('--hidden', dest='hidden_size', default=128, type=int)
    parser.add_argument('--dropout', default=0.1, type=float)
    parser.add_argument('--layers', dest='num_layers', default=1, type=int)
    parser.add_argument('--train-size',default=120000, type=int)
    parser.add_argument('--valid-size', default=1000, type=int)

    args = parser.parse_args()

    # Number of weight steps: w2 takes the values 0, 1/T, ..., 1 and w1 = 1 - w2.
    T = 100
    if args.task == 'tsp':
        w2_list = np.arange(T + 1) / T
        for i in range(T + 1):
            print("Current w:%2.2f/%2.2f" % (1 - w2_list[i], w2_list[i]))
            if i == 0:
                # The first subproblem can be trained from scratch. It also can be trained based on a
                # single-TSP trained model, where the model can be obtained from everywhere in github
                checkpoint = 'tsp_transfer_100run_500000_5epoch_40city/40/w_1.00_0.00'
                train_tsp(args, 1, 0, checkpoint)
            else:
                # Parameter transfer: train starting from the parameters of the
                # previous subproblem. Use args.num_nodes, not the hard-coded
                # default, so that --nodes selects the matching checkpoint dir.
                # NOTE(review): with --test this still loads the previous (i-1)
                # checkpoint and evaluates it under weights i; confirm intent.
                checkpoint = 'tsp_transfer/%d/w_%2.2f_%2.2f' % (args.num_nodes, 1 - w2_list[i - 1], w2_list[i - 1])
                train_tsp(args, 1 - w2_list[i], w2_list[i], checkpoint)
    elif args.task == 'vrp':
        train_vrp(args)
    else:
        raise ValueError('Task <%s> not understood' % args.task)


Download .txt
gitextract_7k7g2yn_/

├── Post_process/
│   ├── convet_kro_dataloader.py
│   ├── data/
│   │   ├── obj1_4_100.mat
│   │   ├── obj1_4_150.mat
│   │   ├── obj1_4_200.mat
│   │   ├── obj1_4_40.mat
│   │   ├── obj1_4_500.mat
│   │   ├── obj1_4_70.mat
│   │   ├── obj2_4_100.mat
│   │   ├── obj2_4_150.mat
│   │   ├── obj2_4_200.mat
│   │   ├── obj2_4_40.mat
│   │   ├── obj2_4_500.mat
│   │   ├── obj2_4_70.mat
│   │   ├── rl4_100.mat
│   │   ├── rl4_150.mat
│   │   ├── rl4_200.mat
│   │   ├── rl4_40.mat
│   │   ├── rl4_500.mat
│   │   ├── rl4_70.mat
│   │   ├── tour4_100.mat
│   │   └── tour4_200.mat
│   ├── dis_matrix.py
│   ├── krodata/
│   │   ├── kroA100.tsp
│   │   ├── kroA150.tsp
│   │   ├── kroA200.tsp
│   │   ├── kroB100.tsp
│   │   ├── kroB150.tsp
│   │   └── kroB200.tsp
│   ├── load_all_reward.py
│   ├── obj1.mat
│   ├── obj2.mat
│   └── rl.mat
├── README.md
├── model.py
├── parameter_transfer.py
├── tasks/
│   ├── motsp.py
│   ├── tsp.py
│   └── vrp.py
├── trainer_motsp_no_transfer.py
├── trainer_motsp_transfer.py
├── tsp_transfer_100run_500000_5epoch_20city/
│   └── 20/
│       ├── w_0.04_0.96/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.05_0.95/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.06_0.94/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.07_0.93/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.08_0.92/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.09_0.91/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.10_0.90/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.11_0.89/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.12_0.88/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.13_0.87/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.14_0.86/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.15_0.85/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.16_0.84/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.17_0.83/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.18_0.82/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.19_0.81/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.20_0.80/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.21_0.79/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.22_0.78/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.23_0.77/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.24_0.76/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.25_0.75/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.26_0.74/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.27_0.73/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.28_0.72/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.29_0.71/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.30_0.70/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.31_0.69/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.32_0.68/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.33_0.67/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.34_0.66/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.35_0.65/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.36_0.64/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.37_0.63/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.38_0.62/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.39_0.61/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.40_0.60/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.41_0.59/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.42_0.58/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.43_0.57/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.44_0.56/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.45_0.55/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.46_0.54/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.47_0.53/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.48_0.52/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.49_0.51/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.50_0.50/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.51_0.49/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.52_0.48/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.53_0.47/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.54_0.46/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.55_0.45/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.56_0.44/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.57_0.43/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.58_0.42/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.59_0.41/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.60_0.40/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.61_0.39/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.62_0.38/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.63_0.37/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.64_0.36/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.65_0.35/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.66_0.34/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.67_0.33/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.68_0.32/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.69_0.31/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.70_0.30/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.71_0.29/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.72_0.28/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.73_0.27/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.74_0.26/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.75_0.25/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.76_0.24/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.77_0.23/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.78_0.22/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.79_0.21/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.80_0.20/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.81_0.19/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.82_0.18/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.83_0.17/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.84_0.16/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.85_0.15/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.86_0.14/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.87_0.13/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.88_0.12/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.89_0.11/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.90_0.10/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.91_0.09/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.92_0.08/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.93_0.07/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.94_0.06/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.95_0.05/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.96_0.04/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.97_0.03/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.98_0.02/
│       │   ├── actor.pt
│       │   └── critic.pt
│       ├── w_0.99_0.01/
│       │   ├── actor.pt
│       │   └── critic.pt
│       └── w_1.00_0.00/
│           ├── actor.pt
│           └── critic.pt
└── tsp_transfer_100run_500000_5epoch_40city/
    └── 40/
        ├── w_0.00_1.00/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.01_0.99/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.02_0.98/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.03_0.97/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.04_0.96/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.05_0.95/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.06_0.94/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.07_0.93/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.08_0.92/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.09_0.91/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.10_0.90/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.11_0.89/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.12_0.88/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.13_0.87/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.14_0.86/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.15_0.85/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.16_0.84/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.17_0.83/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.18_0.82/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.19_0.81/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.20_0.80/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.21_0.79/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.22_0.78/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.23_0.77/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.24_0.76/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.25_0.75/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.26_0.74/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.27_0.73/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.28_0.72/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.29_0.71/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.30_0.70/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.31_0.69/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.32_0.68/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.33_0.67/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.34_0.66/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.35_0.65/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.36_0.64/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.37_0.63/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.38_0.62/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.39_0.61/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.40_0.60/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.41_0.59/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.42_0.58/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.43_0.57/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.44_0.56/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.45_0.55/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.46_0.54/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.47_0.53/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.48_0.52/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.49_0.51/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.50_0.50/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.51_0.49/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.52_0.48/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.53_0.47/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.54_0.46/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.55_0.45/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.56_0.44/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.57_0.43/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.58_0.42/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.59_0.41/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.60_0.40/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.61_0.39/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.62_0.38/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.63_0.37/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.64_0.36/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.65_0.35/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.66_0.34/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.67_0.33/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.68_0.32/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.69_0.31/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.70_0.30/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.71_0.29/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.72_0.28/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.73_0.27/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.74_0.26/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.75_0.25/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.76_0.24/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.77_0.23/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.78_0.22/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.79_0.21/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.80_0.20/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.81_0.19/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.82_0.18/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.83_0.17/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.84_0.16/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.85_0.15/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.86_0.14/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.87_0.13/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.88_0.12/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.89_0.11/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.90_0.10/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.91_0.09/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.92_0.08/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.93_0.07/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.94_0.06/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.95_0.05/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.96_0.04/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.97_0.03/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.98_0.02/
        │   ├── actor.pt
        │   └── critic.pt
        ├── w_0.99_0.01/
        │   ├── actor.pt
        │   └── critic.pt
        └── w_1.00_0.00/
            ├── actor.pt
            └── critic.pt
Download .txt
SYMBOL INDEX (59 symbols across 8 files)

FILE: Post_process/convet_kro_dataloader.py
  class Kro_dataset (line 9) | class Kro_dataset(Dataset):
    method __init__ (line 11) | def __init__(self, num_nodes):
    method __len__ (line 28) | def __len__(self):
    method __getitem__ (line 31) | def __getitem__(self, idx):

FILE: Post_process/dis_matrix.py
  function dis_matrix (line 4) | def dis_matrix(static, s_size):

FILE: model.py
  class Encoder (line 9) | class Encoder(nn.Module):
    method __init__ (line 12) | def __init__(self, input_size, hidden_size):
    method forward (line 16) | def forward(self, input):
  class Attention (line 21) | class Attention(nn.Module):
    method __init__ (line 24) | def __init__(self, hidden_size):
    method forward (line 34) | def forward(self, static_hidden, dynamic_hidden, decoder_hidden):
  class Pointer (line 50) | class Pointer(nn.Module):
    method __init__ (line 53) | def __init__(self, hidden_size, num_layers=1, dropout=0.2):
    method forward (line 76) | def forward(self, static_hidden, dynamic_hidden, decoder_hidden, last_...
  class DRL4TSP (line 103) | class DRL4TSP(nn.Module):
    method __init__ (line 134) | def __init__(self, static_size, dynamic_size, hidden_size,
    method forward (line 158) | def forward(self, static, dynamic, decoder_input=None, last_hh=None):

FILE: tasks/motsp.py
  class TSPDataset (line 21) | class TSPDataset(Dataset):
    method __init__ (line 23) | def __init__(self, size=50, num_samples=1e6, seed=None):
    method __len__ (line 37) | def __len__(self):
    method __getitem__ (line 40) | def __getitem__(self, idx):
  function update_mask (line 45) | def update_mask(mask, dynamic, chosen_idx):
  function reward (line 51) | def reward(static, tour_indices, w1=1, w2=0):
  function render (line 86) | def render(static, tour_indices, save_path):

FILE: tasks/tsp.py
  class TSPDataset (line 21) | class TSPDataset(Dataset):
    method __init__ (line 23) | def __init__(self, size=50, num_samples=1e6, seed=None):
    method __len__ (line 36) | def __len__(self):
    method __getitem__ (line 39) | def __getitem__(self, idx):
  function update_mask (line 44) | def update_mask(mask, dynamic, chosen_idx):
  function reward (line 50) | def reward(static, tour_indices):
  function render (line 76) | def render(static, tour_indices, save_path):

FILE: tasks/vrp.py
  class VehicleRoutingDataset (line 19) | class VehicleRoutingDataset(Dataset):
    method __init__ (line 20) | def __init__(self, num_samples, input_size, max_load=20, max_demand=9,
    method __len__ (line 57) | def __len__(self):
    method __getitem__ (line 60) | def __getitem__(self, idx):
    method update_mask (line 64) | def update_mask(self, mask, dynamic, chosen_idx=None):
    method update_dynamic (line 103) | def update_dynamic(self, dynamic, chosen_idx):
  function reward (line 140) | def reward(static, tour_indices):
  function render (line 161) | def render(static, tour_indices, save_path):

FILE: trainer_motsp_no_transfer.py
  class StateCritic (line 27) | class StateCritic(nn.Module):
    method __init__ (line 34) | def __init__(self, static_size, dynamic_size, hidden_size):
    method forward (line 49) | def forward(self, static, dynamic):
  class Critic (line 63) | class Critic(nn.Module):
    method __init__ (line 70) | def __init__(self, hidden_size):
    method forward (line 82) | def forward(self, input):
  function validate (line 90) | def validate(data_loader, actor, reward_fn, w1, w2, render_fn=None, save...
  function train (line 127) | def train(actor, critic, w1, w2, task, num_nodes, train_data, valid_data...
  function train_tsp (line 257) | def train_tsp(args, w1=1, w2=0, checkpoint = None):
  function train_vrp (line 310) | def train_vrp(args):

FILE: trainer_motsp_transfer.py
  class StateCritic (line 27) | class StateCritic(nn.Module):
    method __init__ (line 34) | def __init__(self, static_size, dynamic_size, hidden_size):
    method forward (line 49) | def forward(self, static, dynamic):
  class Critic (line 63) | class Critic(nn.Module):
    method __init__ (line 70) | def __init__(self, hidden_size):
    method forward (line 82) | def forward(self, input):
  function validate (line 90) | def validate(data_loader, actor, reward_fn, w1, w2, render_fn=None, save...
  function train (line 127) | def train(actor, critic, w1, w2, task, num_nodes, train_data, valid_data...
  function train_tsp (line 255) | def train_tsp(args, w1=1, w2=0, checkpoint = None):
  function train_vrp (line 308) | def train_vrp(args):
Condensed preview — 436 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (81K chars).
[
  {
    "path": "Post_process/convet_kro_dataloader.py",
    "chars": 983,
    "preview": "import numpy as np\nimport torch\nfrom torch.utils.data import Dataset\nimport matplotlib\n# matplotlib.use('Agg')\nimport ma"
  },
  {
    "path": "Post_process/dis_matrix.py",
    "chars": 721,
    "preview": "import numpy as np\nimport torch\n\ndef dis_matrix(static, s_size):\n    static = static.squeeze(0)\n\n    # [2,20]\n    obj1 ="
  },
  {
    "path": "Post_process/krodata/kroA100.tsp",
    "chars": 1341,
    "preview": "NAME: kroA100\nTYPE: TSP\nCOMMENT: 100-city problem A (Krolak/Felts/Nelson)\nDIMENSION: 100\nEDGE_WEIGHT_TYPE : EUC_2D\nNODE_"
  },
  {
    "path": "Post_process/krodata/kroA150.tsp",
    "chars": 1997,
    "preview": "NAME: kroA150\nTYPE: TSP\nCOMMENT: 150-city problem A (Krolak/Felts/Nelson)\nDIMENSION: 150\nEDGE_WEIGHT_TYPE : EUC_2D\nNODE_"
  },
  {
    "path": "Post_process/krodata/kroA200.tsp",
    "chars": 2652,
    "preview": "NAME: kroA200\nTYPE: TSP\nCOMMENT: 200-city problem A (Krolak/Felts/Nelson)\nDIMENSION: 200\nEDGE_WEIGHT_TYPE : EUC_2D\nNODE_"
  },
  {
    "path": "Post_process/krodata/kroB100.tsp",
    "chars": 1348,
    "preview": "NAME: kroB100\nTYPE: TSP\nCOMMENT: 100-city problem B (Krolak/Felts/Nelson)\nDIMENSION: 100\nEDGE_WEIGHT_TYPE : EUC_2D\nNODE_"
  },
  {
    "path": "Post_process/krodata/kroB150.tsp",
    "chars": 1993,
    "preview": "NAME: kroB150\nTYPE: TSP\nCOMMENT: 150-city problem B (Krolak/Felts/Nelson)\nDIMENSION: 150\nEDGE_WEIGHT_TYPE : EUC_2D\nNODE_"
  },
  {
    "path": "Post_process/krodata/kroB200.tsp",
    "chars": 2659,
    "preview": "NAME: kroB200\nTYPE: TSP\nCOMMENT: 200-city problem B (Krolak/Felts/Nelson)\nDIMENSION: 200\nEDGE_WEIGHT_TYPE : EUC_2D\nNODE_"
  },
  {
    "path": "Post_process/load_all_reward.py",
    "chars": 3298,
    "preview": "import torch\nfrom tasks import motsp\nfrom tasks.motsp import TSPDataset, reward\nfrom torch.utils.data import DataLoader\n"
  },
  {
    "path": "README.md",
    "chars": 1160,
    "preview": "# Using Deep Reinforcement Learning method and Attention model to solve the Multiobjectve TSP. \n## This code is the mode"
  },
  {
    "path": "model.py",
    "chars": 10593,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\ndevice = torch.device('cuda' if torch.cuda.is_availa"
  },
  {
    "path": "parameter_transfer.py",
    "chars": 2545,
    "preview": "import torch\nimport os\nfrom model import DRL4TSP, Encoder\nimport argparse\nfrom tasks import motsp\nfrom trainer_motsp_tra"
  },
  {
    "path": "tasks/motsp.py",
    "chars": 3534,
    "preview": "\"\"\"Defines the main task for the TSP\n\nThe TSP is defined by the following traits:\n    1. Each city in the list must be v"
  },
  {
    "path": "tasks/tsp.py",
    "chars": 3219,
    "preview": "\"\"\"Defines the main task for the TSP\n\nThe TSP is defined by the following traits:\n    1. Each city in the list must be v"
  },
  {
    "path": "tasks/vrp.py",
    "chars": 9405,
    "preview": "\"\"\"Defines the main task for the VRP.\n\nThe VRP is defined by the following traits:\n    1. Each city has a demand in [1, "
  },
  {
    "path": "trainer_motsp_no_transfer.py",
    "chars": 14634,
    "preview": "\"\"\"Defines the main trainer model for combinatorial problems\n\nEach task must define the following functions:\n* mask_fn: "
  },
  {
    "path": "trainer_motsp_transfer.py",
    "chars": 15139,
    "preview": "\"\"\"Defines the main trainer model for combinatorial problems\n\nEach task must define the following functions:\n* mask_fn: "
  }
]

// ... and 419 more files (download for full content)

About this extraction

This page contains the full source code of the kevin031060/RL_TSP_4static GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 436 files (75.4 KB), approximately 30.8k tokens, and a symbol index with 59 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!