[
  {
    "path": "Post_process/convet_kro_dataloader.py",
    "content": "import numpy as np\nimport torch\nfrom torch.utils.data import Dataset\nimport matplotlib\n# matplotlib.use('Agg')\nimport matplotlib.pyplot as plt\n\n\nclass Kro_dataset(Dataset):\n\n    def __init__(self, num_nodes):\n        super(Kro_dataset, self).__init__()\n\n        x1 = np.loadtxt('krodata/kroA%d.tsp'%num_nodes, skiprows=6, usecols=(1, 2), delimiter=' ', dtype=float)\n        x1 = x1 / (np.max(x1,0))\n        x2 = np.loadtxt('krodata/kroB%d.tsp'%num_nodes, skiprows=6, usecols=(1, 2), delimiter=' ', dtype=float)\n        x2 = x2 / (np.max(x2,0))\n        x = np.concatenate((x1, x2),axis=1)\n        x = x.T\n        x = x.reshape(1, 4, num_nodes)\n\n        self.dataset = torch.from_numpy(x).float()\n        self.dynamic = torch.zeros(1, 1, num_nodes)\n        self.num_nodes = num_nodes\n        self.size = 1\n\n\n    def __len__(self):\n        return self.size\n\n    def __getitem__(self, idx):\n        # (static, dynamic, start_loc)\n        return (self.dataset[idx], self.dynamic[idx], [])"
  },
  {
    "path": "Post_process/dis_matrix.py",
    "content": "import numpy as np\nimport torch\n\ndef dis_matrix(static, s_size):\n    static = static.squeeze(0)\n\n    # [2,20]\n    obj1 = static[:2, :]\n    # [20]\n    obj2 = static[2:, :]\n\n    l = obj1.size()[1]\n    obj1_matrix = np.zeros((l, l))\n    obj2_matrix = np.zeros((l, l))\n    for i in range(l):\n        for j in range(l):\n            if i != j:\n                obj1_matrix[i,j] = torch.sqrt(torch.sum(torch.pow(obj1[:, i] - obj1[:, j], 2))).detach()\n                if s_size == 3:\n                    obj2_matrix[i, j] = torch.abs(obj2[i] - obj2[j]).detach()\n                else:\n                    obj2_matrix[i, j] = torch.sqrt(torch.sum(torch.pow(obj2[:, i] - obj2[:, j], 2))).detach()\n\n    return obj1_matrix, obj2_matrix"
  },
  {
    "path": "Post_process/krodata/kroA100.tsp",
    "content": "NAME: kroA100\nTYPE: TSP\nCOMMENT: 100-city problem A (Krolak/Felts/Nelson)\nDIMENSION: 100\nEDGE_WEIGHT_TYPE : EUC_2D\nNODE_COORD_SECTION\n1 1380 939\n2 2848 96\n3 3510 1671\n4 457 334\n5 3888 666\n6 984 965\n7 2721 1482\n8 1286 525\n9 2716 1432\n10 738 1325\n11 1251 1832\n12 2728 1698\n13 3815 169\n14 3683 1533\n15 1247 1945\n16 123 862\n17 1234 1946\n18 252 1240\n19 611 673\n20 2576 1676\n21 928 1700\n22 53 857\n23 1807 1711\n24 274 1420\n25 2574 946\n26 178 24\n27 2678 1825\n28 1795 962\n29 3384 1498\n30 3520 1079\n31 1256 61\n32 1424 1728\n33 3913 192\n34 3085 1528\n35 2573 1969\n36 463 1670\n37 3875 598\n38 298 1513\n39 3479 821\n40 2542 236\n41 3955 1743\n42 1323 280\n43 3447 1830\n44 2936 337\n45 1621 1830\n46 3373 1646\n47 1393 1368\n48 3874 1318\n49 938 955\n50 3022 474\n51 2482 1183\n52 3854 923\n53 376 825\n54 2519 135\n55 2945 1622\n56 953 268\n57 2628 1479\n58 2097 981\n59 890 1846\n60 2139 1806\n61 2421 1007\n62 2290 1810\n63 1115 1052\n64 2588 302\n65 327 265\n66 241 341\n67 1917 687\n68 2991 792\n69 2573 599\n70 19 674\n71 3911 1673\n72 872 1559\n73 2863 558\n74 929 1766\n75 839 620\n76 3893 102\n77 2178 1619\n78 3822 899\n79 378 1048\n80 1178 100\n81 2599 901\n82 3416 143\n83 2961 1605\n84 611 1384\n85 3113 885\n86 2597 1830\n87 2586 1286\n88 161 906\n89 1429 134\n90 742 1025\n91 1625 1651\n92 1187 706\n93 1787 1009\n94 22 987\n95 3640 43\n96 3756 882\n97 776 392\n98 1724 1642\n99 198 1810\n100 3950 1558\n"
  },
  {
    "path": "Post_process/krodata/kroA150.tsp",
    "content": "NAME: kroA150\nTYPE: TSP\nCOMMENT: 150-city problem A (Krolak/Felts/Nelson)\nDIMENSION: 150\nEDGE_WEIGHT_TYPE : EUC_2D\nNODE_COORD_SECTION\n1 1380 939\n2 2848 96\n3 3510 1671\n4 457 334\n5 3888 666\n6 984 965\n7 2721 1482\n8 1286 525\n9 2716 1432\n10 738 1325\n11 1251 1832\n12 2728 1698\n13 3815 169\n14 3683 1533\n15 1247 1945\n16 123 862\n17 1234 1946\n18 252 1240\n19 611 673\n20 2576 1676\n21 928 1700\n22 53 857\n23 1807 1711\n24 274 1420\n25 2574 946\n26 178 24\n27 2678 1825\n28 1795 962\n29 3384 1498\n30 3520 1079\n31 1256 61\n32 1424 1728\n33 3913 192\n34 3085 1528\n35 2573 1969\n36 463 1670\n37 3875 598\n38 298 1513\n39 3479 821\n40 2542 236\n41 3955 1743\n42 1323 280\n43 3447 1830\n44 2936 337\n45 1621 1830\n46 3373 1646\n47 1393 1368\n48 3874 1318\n49 938 955\n50 3022 474\n51 2482 1183\n52 3854 923\n53 376 825\n54 2519 135\n55 2945 1622\n56 953 268\n57 2628 1479\n58 2097 981\n59 890 1846\n60 2139 1806\n61 2421 1007\n62 2290 1810\n63 1115 1052\n64 2588 302\n65 327 265\n66 241 341\n67 1917 687\n68 2991 792\n69 2573 599\n70 19 674\n71 3911 1673\n72 872 1559\n73 2863 558\n74 929 1766\n75 839 620\n76 3893 102\n77 2178 1619\n78 3822 899\n79 378 1048\n80 1178 100\n81 2599 901\n82 3416 143\n83 2961 1605\n84 611 1384\n85 3113 885\n86 2597 1830\n87 2586 1286\n88 161 906\n89 1429 134\n90 742 1025\n91 1625 1651\n92 1187 706\n93 1787 1009\n94 22 987\n95 3640 43\n96 3756 882\n97 776 392\n98 1724 1642\n99 198 1810\n100 3950 1558\n101 3477 949\n102 91 1732\n103 3972 329\n104 198 1632\n105 1806 733\n106 538 1023\n107 3430 1088\n108 2186 766\n109 1513 1646\n110 2143 1611\n111 53 1657\n112 3404 1307\n113 1034 1344\n114 2823 376\n115 3104 1931\n116 3232 324\n117 2790 1457\n118 374 9\n119 741 146\n120 3083 1938\n121 3502 1067\n122 1280 237\n123 3326 1846\n124 217 38\n125 2503 1172\n126 3527 41\n127 739 1850\n128 3548 1999\n129 48 154\n130 1419 872\n131 1689 1223\n132 3468 1404\n133 1628 253\n134 382 872\n135 3029 1242\n136 3646 1758\n137 285 1029\n138 1782 
93\n139 1067 371\n140 2849 1214\n141 920 1835\n142 1741 712\n143 876 220\n144 2753 283\n145 2609 1286\n146 3941 258\n147 3613 523\n148 1754 559\n149 2916 1724\n150 2445 1820\n"
  },
  {
    "path": "Post_process/krodata/kroA200.tsp",
    "content": "NAME: kroA200\nTYPE: TSP\nCOMMENT: 200-city problem A (Krolak/Felts/Nelson)\nDIMENSION: 200\nEDGE_WEIGHT_TYPE : EUC_2D\nNODE_COORD_SECTION\n1 1357 1905\n2 2650 802\n3 1774 107\n4 1307 964\n5 3806 746\n6 2687 1353\n7 43 1957\n8 3092 1668\n9 185 1542\n10 834 629\n11 40 462\n12 1183 1391\n13 2048 1628\n14 1097 643\n15 1838 1732\n16 234 1118\n17 3314 1881\n18 737 1285\n19 779 777\n20 2312 1949\n21 2576 189\n22 3078 1541\n23 2781 478\n24 705 1812\n25 3409 1917\n26 323 1714\n27 1660 1556\n28 3729 1188\n29 693 1383\n30 2361 640\n31 2433 1538\n32 554 1825\n33 913 317\n34 3586 1909\n35 2636 727\n36 1000 457\n37 482 1337\n38 3704 1082\n39 3635 1174\n40 1362 1526\n41 2049 417\n42 2552 1909\n43 3939 640\n44 219 898\n45 812 351\n46 901 1552\n47 2513 1572\n48 242 584\n49 826 1226\n50 3278 799\n51 86 1065\n52 14 454\n53 1327 1893\n54 2773 1286\n55 2469 1838\n56 3835 963\n57 1031 428\n58 3853 1712\n59 1868 197\n60 1544 863\n61 457 1607\n62 3174 1064\n63 192 1004\n64 2318 1925\n65 2232 1374\n66 396 828\n67 2365 1649\n68 2499 658\n69 1410 307\n70 2990 214\n71 3646 1018\n72 3394 1028\n73 1779 90\n74 1058 372\n75 2933 1459\n76 3099 173\n77 2178 978\n78 138 1610\n79 2082 1753\n80 2302 1127\n81 805 272\n82 22 1617\n83 3213 1085\n84 99 536\n85 1533 1780\n86 3564 676\n87 29 6\n88 3808 1375\n89 2221 291\n90 3499 1885\n91 3124 408\n92 781 671\n93 1027 1041\n94 3249 378\n95 3297 491\n96 213 220\n97 721 186\n98 3736 1542\n99 868 731\n100 960 303\n101 1380 939\n102 2848 96\n103 3510 1671\n104 457 334\n105 3888 666\n106 984 965\n107 2721 1482\n108 1286 525\n109 2716 1432\n110 738 1325\n111 1251 1832\n112 2728 1698\n113 3815 169\n114 3683 1533\n115 1247 1945\n116 123 862\n117 1234 1946\n118 252 1240\n119 611 673\n120 2576 1676\n121 928 1700\n122 53 857\n123 1807 1711\n124 274 1420\n125 2574 946\n126 178 24\n127 2678 1825\n128 1795 962\n129 3384 1498\n130 3520 1079\n131 1256 61\n132 1424 1728\n133 3913 192\n134 3085 1528\n135 2573 1969\n136 463 1670\n137 3875 598\n138 298 
1513\n139 3479 821\n140 2542 236\n141 3955 1743\n142 1323 280\n143 3447 1830\n144 2936 337\n145 1621 1830\n146 3373 1646\n147 1393 1368\n148 3874 1318\n149 938 955\n150 3022 474\n151 2482 1183\n152 3854 923\n153 376 825\n154 2519 135\n155 2945 1622\n156 953 268\n157 2628 1479\n158 2097 981\n159 890 1846\n160 2139 1806\n161 2421 1007\n162 2290 1810\n163 1115 1052\n164 2588 302\n165 327 265\n166 241 341\n167 1917 687\n168 2991 792\n169 2573 599\n170 19 674\n171 3911 1673\n172 872 1559\n173 2863 558\n174 929 1766\n175 839 620\n176 3893 102\n177 2178 1619\n178 3822 899\n179 378 1048\n180 1178 100\n181 2599 901\n182 3416 143\n183 2961 1605\n184 611 1384\n185 3113 885\n186 2597 1830\n187 2586 1286\n188 161 906\n189 1429 134\n190 742 1025\n191 1625 1651\n192 1187 706\n193 1787 1009\n194 22 987\n195 3640 43\n196 3756 882\n197 776 392\n198 1724 1642\n199 198 1810\n200 3950 1558\n"
  },
  {
    "path": "Post_process/krodata/kroB100.tsp",
    "content": "NAME: kroB100\nTYPE: TSP\nCOMMENT: 100-city problem B (Krolak/Felts/Nelson)\nDIMENSION: 100\nEDGE_WEIGHT_TYPE : EUC_2D\nNODE_COORD_SECTION\n1 3140 1401\n2 556 1056\n3 3675 1522\n4 1182 1853\n5 3595 111\n6 962 1895\n7 2030 1186\n8 3507 1851\n9 2642 1269\n10 3438 901\n11 3858 1472\n12 2937 1568\n13 376 1018\n14 839 1355\n15 706 1925\n16 749 920\n17 298 615\n18 694 552\n19 387 190\n20 2801 695\n21 3133 1143\n22 1517 266\n23 1538 224\n24 844 520\n25 2639 1239\n26 3123 217\n27 2489 1520\n28 3834 1827\n29 3417 1808\n30 2938 543\n31 71 1323\n32 3245 1828\n33 731 1741\n34 2312 1270\n35 2426 1851\n36 380 478\n37 2310 635\n38 2830 775\n39 3829 513\n40 3684 445\n41 171 514\n42 627 1261\n43 1490 1123\n44 61 81\n45 422 542\n46 2698 1221\n47 2372 127\n48 177 1390\n49 3084 748\n50 1213 910\n51 3 1817\n52 1782 995\n53 3896 742\n54 1829 812\n55 1286 550\n56 3017 108\n57 2132 1432\n58 2000 1110\n59 3317 1966\n60 1729 1498\n61 2408 1747\n62 3292 152\n63 193 1210\n64 782 1462\n65 2503 352\n66 1697 1924\n67 3821 147\n68 3370 791\n69 3162 367\n70 3938 516\n71 2741 1583\n72 2330 741\n73 3918 1088\n74 1794 1589\n75 2929 485\n76 3453 1998\n77 896 705\n78 399 850\n79 2614 195\n80 2800 653\n81 2630 20\n82 563 1513\n83 1090 1652\n84 2009 1163\n85 3876 1165\n86 3084 774\n87 1526 1612\n88 1612 328\n89 1423 1322\n90 3058 1276\n91 3782 1865\n92 347 252\n93 3904 1444\n94 2191 1579\n95 3220 1454\n96 468 319\n97 3611 1968\n98 3114 1629\n99 3515 1892\n100 3060 155\n"
  },
  {
    "path": "Post_process/krodata/kroB150.tsp",
    "content": "NAME: kroB150\nTYPE: TSP\nCOMMENT: 150-city problem B (Krolak/Felts/Nelson)\nDIMENSION: 150\nEDGE_WEIGHT_TYPE : EUC_2D\nNODE_COORD_SECTION\n1 1357 1905\n2 2650 802\n3 1774 107\n4 1307 964\n5 3806 746\n6 2687 1353\n7 43 1957\n8 3092 1668\n9 185 1542\n10 834 629\n11 40 462\n12 1183 1391\n13 2048 1628\n14 1097 643\n15 1838 1732\n16 234 1118\n17 3314 1881\n18 737 1285\n19 779 777\n20 2312 1949\n21 2576 189\n22 3078 1541\n23 2781 478\n24 705 1812\n25 3409 1917\n26 323 1714\n27 1660 1556\n28 3729 1188\n29 693 1383\n30 2361 640\n31 2433 1538\n32 554 1825\n33 913 317\n34 3586 1909\n35 2636 727\n36 1000 457\n37 482 1337\n38 3704 1082\n39 3635 1174\n40 1362 1526\n41 2049 417\n42 2552 1909\n43 3939 640\n44 219 898\n45 812 351\n46 901 1552\n47 2513 1572\n48 242 584\n49 826 1226\n50 3278 799\n51 86 1065\n52 14 454\n53 1327 1893\n54 2773 1286\n55 2469 1838\n56 3835 963\n57 1031 428\n58 3853 1712\n59 1868 197\n60 1544 863\n61 457 1607\n62 3174 1064\n63 192 1004\n64 2318 1925\n65 2232 1374\n66 396 828\n67 2365 1649\n68 2499 658\n69 1410 307\n70 2990 214\n71 3646 1018\n72 3394 1028\n73 1779 90\n74 1058 372\n75 2933 1459\n76 3099 173\n77 2178 978\n78 138 1610\n79 2082 1753\n80 2302 1127\n81 805 272\n82 22 1617\n83 3213 1085\n84 99 536\n85 1533 1780\n86 3564 676\n87 29 6\n88 3808 1375\n89 2221 291\n90 3499 1885\n91 3124 408\n92 781 671\n93 1027 1041\n94 3249 378\n95 3297 491\n96 213 220\n97 721 186\n98 3736 1542\n99 868 731\n100 960 303\n101 3825 1101\n102 2779 435\n103 201 693\n104 2502 1274\n105 765 833\n106 3105 1823\n107 1937 1400\n108 3364 1498\n109 3702 1624\n110 2164 1874\n111 3019 189\n112 3098 1594\n113 3239 1376\n114 3359 1693\n115 2081 1011\n116 1398 1100\n117 618 1953\n118 1878 59\n119 3803 886\n120 397 1217\n121 3035 152\n122 2502 146\n123 3230 380\n124 3479 1023\n125 958 1670\n126 3423 1241\n127 78 1066\n128 96 691\n129 3431 78\n130 2053 1461\n131 3048 1\n132 571 1711\n133 3393 782\n134 2835 1472\n135 144 1185\n136 923 108\n137 989 1997\n138 3061 
1211\n139 2977 39\n140 1668 658\n141 878 715\n142 678 1599\n143 1086 868\n144 640 110\n145 3551 1673\n146 106 1267\n147 2243 1332\n148 3796 1401\n149 2643 1320\n150 48 267\n"
  },
  {
    "path": "Post_process/krodata/kroB200.tsp",
    "content": "NAME: kroB200\nTYPE: TSP\nCOMMENT: 200-city problem B (Krolak/Felts/Nelson)\nDIMENSION: 200\nEDGE_WEIGHT_TYPE : EUC_2D\nNODE_COORD_SECTION\n1 3140 1401\n2 556 1056\n3 3675 1522\n4 1182 1853\n5 3595 111\n6 962 1895\n7 2030 1186\n8 3507 1851\n9 2642 1269\n10 3438 901\n11 3858 1472\n12 2937 1568\n13 376 1018\n14 839 1355\n15 706 1925\n16 749 920\n17 298 615\n18 694 552\n19 387 190\n20 2801 695\n21 3133 1143\n22 1517 266\n23 1538 224\n24 844 520\n25 2639 1239\n26 3123 217\n27 2489 1520\n28 3834 1827\n29 3417 1808\n30 2938 543\n31 71 1323\n32 3245 1828\n33 731 1741\n34 2312 1270\n35 2426 1851\n36 380 478\n37 2310 635\n38 2830 775\n39 3829 513\n40 3684 445\n41 171 514\n42 627 1261\n43 1490 1123\n44 61 81\n45 422 542\n46 2698 1221\n47 2372 127\n48 177 1390\n49 3084 748\n50 1213 910\n51 3 1817\n52 1782 995\n53 3896 742\n54 1829 812\n55 1286 550\n56 3017 108\n57 2132 1432\n58 2000 1110\n59 3317 1966\n60 1729 1498\n61 2408 1747\n62 3292 152\n63 193 1210\n64 782 1462\n65 2503 352\n66 1697 1924\n67 3821 147\n68 3370 791\n69 3162 367\n70 3938 516\n71 2741 1583\n72 2330 741\n73 3918 1088\n74 1794 1589\n75 2929 485\n76 3453 1998\n77 896 705\n78 399 850\n79 2614 195\n80 2800 653\n81 2630 20\n82 563 1513\n83 1090 1652\n84 2009 1163\n85 3876 1165\n86 3084 774\n87 1526 1612\n88 1612 328\n89 1423 1322\n90 3058 1276\n91 3782 1865\n92 347 252\n93 3904 1444\n94 2191 1579\n95 3220 1454\n96 468 319\n97 3611 1968\n98 3114 1629\n99 3515 1892\n100 3060 155\n101 2995 264\n102 202 233\n103 981 848\n104 1346 408\n105 781 670\n106 1009 1001\n107 2927 1777\n108 2982 949\n109 555 1121\n110 464 1302\n111 3452 637\n112 571 1982\n113 2656 128\n114 1623 1723\n115 2067 694\n116 1725 927\n117 3600 459\n118 1109 1196\n119 366 339\n120 778 1282\n121 386 1616\n122 3918 1217\n123 3332 1049\n124 2597 349\n125 811 1295\n126 241 1069\n127 2658 360\n128 394 1944\n129 3786 1862\n130 264 36\n131 2050 1833\n132 3538 125\n133 1646 1817\n134 2993 624\n135 547 25\n136 3373 1902\n137 460 267\n138 3060 
781\n139 1828 456\n140 1021 962\n141 2347 388\n142 3535 1112\n143 1529 581\n144 1203 385\n145 1787 1902\n146 2740 1101\n147 555 1753\n148 47 363\n149 3935 540\n150 3062 329\n151 387 199\n152 2901 920\n153 931 512\n154 1766 692\n155 401 980\n156 149 1629\n157 2214 1977\n158 3805 1619\n159 1179 969\n160 1017 333\n161 2834 1512\n162 634 294\n163 1819 814\n164 1393 859\n165 1768 1578\n166 3023 871\n167 3248 1906\n168 1632 1742\n169 2223 990\n170 3868 697\n171 1541 354\n172 2374 1944\n173 1962 389\n174 3007 1524\n175 3220 1945\n176 2356 1568\n177 1604 706\n178 2028 1736\n179 2581 121\n180 2221 1578\n181 2944 632\n182 1082 1561\n183 997 942\n184 2334 523\n185 1264 1090\n186 1699 1294\n187 235 1059\n188 2592 248\n189 3642 699\n190 3599 514\n191 1766 678\n192 240 619\n193 1272 246\n194 3503 301\n195 80 1533\n196 1677 1238\n197 3766 154\n198 3946 459\n199 1994 1852\n200 278 165\n"
  },
  {
    "path": "Post_process/load_all_reward.py",
    "content": "import torch\nfrom tasks import motsp\nfrom tasks.motsp import TSPDataset, reward\nfrom torch.utils.data import DataLoader\nfrom model import DRL4TSP\nfrom trainer_motsp_transfer import StateCritic\nimport numpy as np\nimport os\nimport matplotlib.pyplot as plt\nimport scipy.io as scio\nfrom Post_process.dis_matrix import dis_matrix\nimport time\n\n# Load the trained model and convert the obtained Pareto Front to the .mat file.\n# It is convenient to visualize it in matlab\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n# \"../tsp_transfer_100run_500000_5epoch_20city/20\"效果一般。应该再训练一遍\nsave_dir = \"../tsp_transfer_100run_500000_5epoch_40city/40\"\n# save_dir = \"../tsp_transfer/100\"\n# param\nupdate_fn = None\nSTATIC_SIZE = 4  # (x, y)\nDYNAMIC_SIZE = 1  # dummy for compatibility\n\n# claim model\nactor = DRL4TSP(STATIC_SIZE,\n                DYNAMIC_SIZE,\n                128,\n                update_fn,\n                motsp.update_mask,\n                1,\n                0.1).to(device)\ncritic = StateCritic(STATIC_SIZE, DYNAMIC_SIZE, 128).to(device)\n\n# data 143\nfrom Post_process.convet_kro_dataloader import Kro_dataset\nkro = 1\nD = 200\nif kro:\n    D = 200\n    Test_data = Kro_dataset(D)\n    Test_loader = DataLoader(Test_data, 1, False, num_workers=0)\nelse:\n    # 40city_train: city20 13 city40 143 city70 2523\n    #\n    Test_data = TSPDataset(D, 1, 2523)\n    Test_loader = DataLoader(Test_data, 1, False, num_workers=0)\n\niter_data = iter(Test_loader)\nstatic, dynamic, x0 = iter_data.next()\nstatic = static.to(device)\ndynamic = dynamic.to(device)\nx0 = x0.to(device) if len(x0) > 0 else None\n\n# load 50 models\nN=100\nw = np.arange(N+1)/N\nobjs = np.zeros((N+1,2))\nstart  = time.time()\nt1_all = 0\nt2_all = 0\ntours=[]\nfor i in range(0, N+1):\n    t1 = time.time()\n    ac = os.path.join(save_dir, \"w_%2.2f_%2.2f\" % (1-w[i], w[i]),\"actor.pt\")\n    cri = os.path.join(save_dir, \"w_%2.2f_%2.2f\" % (1-w[i], 
w[i]),\"critic.pt\")\n    actor.load_state_dict(torch.load(ac, device))\n    critic.load_state_dict(torch.load(cri, device))\n    t1_all = t1_all + time.time()-t1\n    # calculate\n\n    with torch.no_grad():\n        # t2 = time.time()\n        tour_indices, _ = actor.forward(static, dynamic, x0)\n        # t2_all = t2_all + time.time() - t2\n    _, obj1, obj2 = reward(static, tour_indices, 1-w[i], w[i])\n    tours.append(tour_indices.cpu().numpy())\n    objs[i,:] = [obj1, obj2]\n\nprint(\"time_load_model:%2.4f\"%t1_all)\nprint(\"time_predict_model:%2.4f\"%t2_all)\nprint(time.time()-start)\n\nprint(tours)\nplt.figure()\nplt.plot(objs[:,0],objs[:,1],\"ro\")\nplt.show()\n\n# Convert to .mat\nobj1_matrix, obj2_matrix = dis_matrix(static, STATIC_SIZE)\nscio.savemat(\"data/obj1_%d_%d.mat\"%(STATIC_SIZE, D), {'obj1':obj1_matrix})\nscio.savemat(\"data/obj2_%d_%d.mat\"%(STATIC_SIZE, D), {'obj2':obj2_matrix})\nscio.savemat(\"data/rl%d_%d.mat\"%(STATIC_SIZE, D),{'rl':objs})\nscio.savemat(\"data/tour%d_%d.mat\"%(STATIC_SIZE, D),{'tour':np.array(tours)})\n\n\n# from load_test_plot import show\n# show_if = 1\n# if show_if:\n#     i = 0\n#     ac = os.path.join(save_dir, \"w_%2.2f_%2.2f\" % (1-w[i], w[i]),\"actor.pt\")\n#     cri = os.path.join(save_dir, \"w_%2.2f_%2.2f\" % (1-w[i], w[i]),\"critic.pt\")\n#     actor.load_state_dict(torch.load(ac, device))\n#     critic.load_state_dict(torch.load(cri, device))\n#\n#     show(Test_loader, actor)\n\n"
  },
  {
    "path": "README.md",
    "content": "# Using Deep Reinforcement Learning method and Attention model to solve the Multiobjective TSP. \n## This code is the model with four-dimension input (Euclidean-type).\n### The model with three-dimension input (Mixed-type) is in the RL_3static_MOTSP.zip.\n### Matlab code for visualizing and comparisons in the paper is in the MOTSP_compare_EMO.zip.\n\n+ Trained models are available in the tsp_transfer_... dirs.\n+ To test the model, use load_all_reward.py in the Post_process dir.\n+ To train the model, run train_motsp_transfer.py\n+ To visualize the obtained Pareto Front, the result should be visualized using Matlab.\n+ The Matlab code is in the .zip file. It is in the \" MOTSP_compare_EMO/Problems/Combinatorial MOPs/compare.m \". It is used to produce the figures in batch. \n    \n    > First you need to run the train_motsp_transfer.py to train the model. \n    \n    > Run the load_all_reward.py to load and test the model. It also converts the obtained Pareto Front to the .mat file\n    \n    > Run the Matlab code to visualize the Pareto Front and compare with NSGA-II and MOEA/D\n    \n    \n\n### Much of the code is inherited from https://github.com/mveres01/pytorch-drl4vrp\n"
  },
  {
    "path": "model.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n#device = torch.device('cpu')\n\n\nclass Encoder(nn.Module):\n    \"\"\"Encodes the static & dynamic states using 1d Convolution.\"\"\"\n\n    def __init__(self, input_size, hidden_size):\n        super(Encoder, self).__init__()\n        self.conv = nn.Conv1d(input_size, hidden_size, kernel_size=1)\n\n    def forward(self, input):\n        output = self.conv(input)\n        return output  # (batch, hidden_size, seq_len)\n\n\nclass Attention(nn.Module):\n    \"\"\"Calculates attention over the input nodes given the current state.\"\"\"\n\n    def __init__(self, hidden_size):\n        super(Attention, self).__init__()\n\n        # W processes features from static decoder elements\n        self.v = nn.Parameter(torch.zeros((1, 1, hidden_size),\n                                          device=device, requires_grad=True))\n\n        self.W = nn.Parameter(torch.zeros((1, hidden_size, 3 * hidden_size),\n                                          device=device, requires_grad=True))\n\n    def forward(self, static_hidden, dynamic_hidden, decoder_hidden):\n\n        batch_size, hidden_size, _ = static_hidden.size()\n\n        hidden = decoder_hidden.unsqueeze(2).expand_as(static_hidden)\n        hidden = torch.cat((static_hidden, dynamic_hidden, hidden), 1)\n\n        # Broadcast some dimensions so we can do batch-matrix-multiply\n        v = self.v.expand(batch_size, 1, hidden_size)\n        W = self.W.expand(batch_size, hidden_size, -1)\n\n        attns = torch.bmm(v, torch.tanh(torch.bmm(W, hidden)))\n        attns = F.softmax(attns, dim=2)  # (batch, seq_len)\n        return attns\n\n\nclass Pointer(nn.Module):\n    \"\"\"Calculates the next state given the previous state and input embeddings.\"\"\"\n\n    def __init__(self, hidden_size, num_layers=1, dropout=0.2):\n        super(Pointer, self).__init__()\n\n       
 self.hidden_size = hidden_size\n        self.num_layers = num_layers\n\n        # Used to calculate probability of selecting next state\n        self.v = nn.Parameter(torch.zeros((1, 1, hidden_size),\n                                          device=device, requires_grad=True))\n\n        self.W = nn.Parameter(torch.zeros((1, hidden_size, 2 * hidden_size),\n                                          device=device, requires_grad=True))\n\n        # Used to compute a representation of the current decoder output\n        # GRU（输入dim，隐含层dim，层数）\n        self.gru = nn.GRU(hidden_size, hidden_size, num_layers,\n                          batch_first=True,\n                          dropout=dropout if num_layers > 1 else 0)\n        self.encoder_attn = Attention(hidden_size)\n\n        self.drop_rnn = nn.Dropout(p=dropout)\n        self.drop_hh = nn.Dropout(p=dropout)\n\n    def forward(self, static_hidden, dynamic_hidden, decoder_hidden, last_hh):\n\n        rnn_out, last_hh = self.gru(decoder_hidden.transpose(2, 1), last_hh)\n        rnn_out = rnn_out.squeeze(1)\n\n        # Always apply dropout on the RNN output\n        rnn_out = self.drop_rnn(rnn_out)\n        if self.num_layers == 1:\n            # If > 1 layer dropout is already applied\n            last_hh = self.drop_hh(last_hh) \n\n        # Given a summary of the output, find an  input context\n        enc_attn = self.encoder_attn(static_hidden, dynamic_hidden, rnn_out)\n        context = enc_attn.bmm(static_hidden.permute(0, 2, 1))  # (B, 1, num_feats)\n\n        # Calculate the next output using Batch-matrix-multiply ops\n        context = context.transpose(1, 2).expand_as(static_hidden)\n        energy = torch.cat((static_hidden, context), dim=1)  # (B, num_feats, seq_len)\n\n        v = self.v.expand(static_hidden.size(0), -1, -1)\n        W = self.W.expand(static_hidden.size(0), -1, -1)\n\n        probs = torch.bmm(v, torch.tanh(torch.bmm(W, energy))).squeeze(1)\n\n        return probs, last_hh\n\n\nclass 
DRL4TSP(nn.Module):\n    \"\"\"Defines the main Encoder, Decoder, and Pointer combinatorial models.\n\n    Parameters\n    ----------\n    static_size: int\n        Defines how many features are in the static elements of the model\n        (e.g. 2 for (x, y) coordinates)\n    dynamic_size: int > 1\n        Defines how many features are in the dynamic elements of the model\n        (e.g. 2 for the VRP which has (load, demand) attributes. The TSP doesn't\n        have dynamic elements, but to ensure compatility with other optimization\n        problems, assume we just pass in a vector of zeros.\n    hidden_size: int\n        Defines the number of units in the hidden layer for all static, dynamic,\n        and decoder output units.\n    update_fn: function or None\n        If provided, this method is used to calculate how the input dynamic\n        elements are updated, and is called after each 'point' to the input element.\n    mask_fn: function or None\n        Allows us to specify which elements of the input sequence are allowed to\n        be selected. This is useful for speeding up training of the networks,\n        by providing a sort of 'rules' guidlines to the algorithm. 
If no mask\n        is provided, we terminate the search after a fixed number of iterations\n        to avoid tours that stretch forever\n    num_layers: int\n        Specifies the number of hidden layers to use in the decoder RNN\n    dropout: float\n        Defines the dropout rate for the decoder\n    \"\"\"\n\n    def __init__(self, static_size, dynamic_size, hidden_size,\n                 update_fn=None, mask_fn=None, num_layers=1, dropout=0.):\n        super(DRL4TSP, self).__init__()\n\n        if dynamic_size < 1:\n            raise ValueError(':param dynamic_size: must be > 0, even if the '\n                             'problem has no dynamic elements')\n\n        self.update_fn = update_fn\n        self.mask_fn = mask_fn\n\n        # Define the encoder & decoder models\n        self.static_encoder = Encoder(static_size, hidden_size)\n        self.dynamic_encoder = Encoder(dynamic_size, hidden_size)\n        self.decoder = Encoder(static_size, hidden_size)\n        self.pointer = Pointer(hidden_size, num_layers, dropout)\n\n        for p in self.parameters():\n            if len(p.shape) > 1:\n                nn.init.xavier_uniform_(p)\n\n        # Used as a proxy initial state in the decoder when not specified\n        self.x0 = torch.zeros((1, static_size, 1), requires_grad=True, device=device)\n\n    def forward(self, static, dynamic, decoder_input=None, last_hh=None):\n        \"\"\"\n        Parameters\n        ----------\n        static: Array of size (batch_size, feats, num_cities)\n            Defines the elements to consider as static. For the TSP, this could be\n            things like the (x, y) coordinates, which won't change\n        dynamic: Array of size (batch_size, feats, num_cities)\n            Defines the elements to consider as static. For the VRP, this can be\n            things like the (load, demand) of each city. 
If there are no dynamic\n            elements, this can be set to None\n        decoder_input: Array of size (batch_size, num_feats)\n            Defines the outputs for the decoder. Currently, we just use the\n            static elements (e.g. (x, y) coordinates), but this can technically\n            be other things as well\n        last_hh: Array of size (batch_size, num_hidden)\n            Defines the last hidden state for the RNN\n        \"\"\"\n\n        batch_size, input_size, sequence_size = static.size()\n\n        if decoder_input is None:\n            decoder_input = self.x0.expand(batch_size, -1, -1)\n\n        # Always use a mask - if no function is provided, we don't update it\n        mask = torch.ones(batch_size, sequence_size, device=device)\n\n        # Structures for holding the output sequences\n        tour_idx, tour_logp = [], []\n        max_steps = sequence_size if self.mask_fn is None else 1000\n\n        # Static elements only need to be processed once, and can be used across\n        # all 'pointing' iterations. When / if the dynamic elements change,\n        # their representations will need to get calculated again.\n        static_hidden = self.static_encoder(static)\n        dynamic_hidden = self.dynamic_encoder(dynamic)\n\n        for _ in range(max_steps):\n\n            if not mask.byte().any():\n                break\n\n            # ... 
but compute a hidden rep for each element added to sequence\n            decoder_hidden = self.decoder(decoder_input)\n\n            probs, last_hh = self.pointer(static_hidden,\n                                          dynamic_hidden,\n                                          decoder_hidden, last_hh)\n            probs = F.softmax(probs + mask.log(), dim=1)\n\n            # When training, sample the next step according to its probability.\n            # During testing, we can take the greedy approach and choose highest\n            if self.training:\n                m = torch.distributions.Categorical(probs)\n\n                # Sometimes an issue with Categorical & sampling on GPU; See:\n                # https://github.com/pemami4911/neural-combinatorial-rl-pytorch/issues/5\n                ptr = m.sample()\n                while not torch.gather(mask, 1, ptr.data.unsqueeze(1)).byte().all():\n                    ptr = m.sample()\n                logp = m.log_prob(ptr)\n            else:\n                prob, ptr = torch.max(probs, 1)  # Greedy\n                logp = prob.log()\n\n            # After visiting a node update the dynamic representation\n            if self.update_fn is not None:\n                dynamic = self.update_fn(dynamic, ptr.data)\n                dynamic_hidden = self.dynamic_encoder(dynamic)\n\n                # Since we compute the VRP in minibatches, some tours may have\n                # number of stops. We force the vehicles to remain at the depot \n                # in these cases, and logp := 0\n                is_done = dynamic[:, 1].sum(1).eq(0).float()\n                logp = logp * (1. 
- is_done)\n\n            # And update the mask so we don't re-visit if we don't need to\n            if self.mask_fn is not None:\n                mask = self.mask_fn(mask, dynamic, ptr.data).detach()\n\n            tour_logp.append(logp.unsqueeze(1))\n            tour_idx.append(ptr.data.unsqueeze(1))\n\n            decoder_input = torch.gather(static, 2,\n                                         ptr.view(-1, 1, 1)\n                                         .expand(-1, input_size, 1)).detach()\n\n        tour_idx = torch.cat(tour_idx, dim=1)  # (batch_size, seq_len)\n        tour_logp = torch.cat(tour_logp, dim=1)  # (batch_size, seq_len)\n\n        return tour_idx, tour_logp\n\n\nif __name__ == '__main__':\n    raise Exception('Cannot be called from main')\n"
  },
  {
    "path": "parameter_transfer.py",
    "content": "import torch\nimport os\nfrom model import DRL4TSP, Encoder\nimport argparse\nfrom tasks import motsp\nfrom trainer_motsp_transfer import StateCritic\n\n'''\nThis file is used to test. It has been obsoleted\nThis file is used to convert the trained single-TSP PN model to the parameters from which we can transfer.\nThe trained single-TSP PN model can be found here: https://github.com/mveres01/pytorch-drl4vrp. Save it as \"tsp20\".\nThen the start-up parameters for the first subproblem of the MOTSP to transfer can be obtained.\n'''\n\n\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\nSTATIC_SIZE_original = 2  # (x, y)\nSTATIC_SIZE = 3  # (x, y)\nDYNAMIC_SIZE = 1  # dummy for compatibility\nupdate_fn = None\nhidden_size = 128\nnum_layers = 1\ndropout = 0.1\ncheckpoint = \"tsp20\"\nactor = DRL4TSP(STATIC_SIZE_original,\n                DYNAMIC_SIZE,\n                hidden_size,\n                update_fn,\n                motsp.update_mask,\n                num_layers,\n                dropout).to(device)\n\ncritic = StateCritic(STATIC_SIZE_original, DYNAMIC_SIZE, hidden_size).to(device)\n# 加载原128*2*1的原模型\npath = os.path.join(checkpoint, 'actor.pt')\nactor.load_state_dict(torch.load(path, device))\n\npath = os.path.join(checkpoint, 'critic.pt')\ncritic.load_state_dict(torch.load(path, device))\n# 其中actor的static_encoder，decoder需要更改维度，critic需要更改维度\n# static_encoder\nstatic_parameter = actor.static_encoder.state_dict()\ntemp = static_parameter['conv.weight']\ntemp = torch.cat([temp, temp[:,1,:].unsqueeze(1)], dim=1)   # 在第二维拓展一列\nstatic_parameter['conv.weight'] = temp\nactor.static_encoder = Encoder(STATIC_SIZE, hidden_size)\nactor.static_encoder.load_state_dict(static_parameter)\n# decoder\nstatic_parameter = actor.decoder.state_dict()\ntemp = static_parameter['conv.weight']\ntemp = torch.cat([temp, temp[:,1,:].unsqueeze(1)], dim=1)   # 在第二维拓展一列\nstatic_parameter['conv.weight'] = temp\nactor.decoder = Encoder(STATIC_SIZE, 
hidden_size)\nactor.decoder.load_state_dict(static_parameter)\n\n# CRITIC\nstatic_parameter = critic.static_encoder.state_dict()\ntemp = static_parameter['conv.weight']\ntemp = torch.cat([temp, temp[:,1,:].unsqueeze(1)], dim=1)   # 在第二维拓展一列\nstatic_parameter['conv.weight'] = temp\ncritic.static_encoder = Encoder(STATIC_SIZE, hidden_size)\ncritic.static_encoder.load_state_dict(static_parameter)\n\nsave_path = os.path.join(\"modified_checkpoint_3obj\", 'actor.pt')\ntorch.save(actor.state_dict(), save_path)\nsave_path = os.path.join(\"modified_checkpoint_3obj\", 'critic.pt')\ntorch.save(critic.state_dict(), save_path)\n\nprint(actor,critic)\n"
  },
  {
    "path": "tasks/motsp.py",
    "content": "\"\"\"Defines the main task for the TSP\n\nThe TSP is defined by the following traits:\n    1. Each city in the list must be visited once and only once\n    2. The salesman must return to the original node at the end of the tour\n\nSince the TSP doesn't have dynamic elements, we return an empty list on\n__getitem__, which gets processed in trainer.py to be None\n\n\"\"\"\n\nimport os\nimport numpy as np\nimport torch\nfrom torch.utils.data import Dataset\nimport matplotlib\n# matplotlib.use('Agg')\nimport matplotlib.pyplot as plt\n\n\nclass TSPDataset(Dataset):\n\n    def __init__(self, size=50, num_samples=1e6, seed=None):\n        super(TSPDataset, self).__init__()\n\n        if seed is None:\n            seed = np.random.randint(123456789)\n\n        np.random.seed(seed)\n        torch.manual_seed(seed)\n        self.dataset = torch.rand((num_samples, 4, size))\n        self.dynamic = torch.zeros(num_samples, 1, size)\n        self.num_nodes = size\n        self.size = num_samples\n\n\n    def __len__(self):\n        return self.size\n\n    def __getitem__(self, idx):\n        # (static, dynamic, start_loc)\n        return (self.dataset[idx], self.dynamic[idx], [])\n\n\ndef update_mask(mask, dynamic, chosen_idx):\n    \"\"\"Marks the visited city, so it can't be selected a second time.\"\"\"\n    mask.scatter_(1, chosen_idx.unsqueeze(1), 0)\n    return mask\n\n\ndef reward(static, tour_indices, w1=1, w2=0):\n    \"\"\"\n    Parameters\n    ----------\n    static: torch.FloatTensor containing static (e.g. x, y) data\n    tour_indices: torch.IntTensor of size (batch_size, num_cities)\n\n    Returns\n    -------\n    Euclidean distance between consecutive nodes on the route. 
of size\n    (batch_size, num_cities)\n    \"\"\"\n\n    # Convert the indices back into a tour\n    idx = tour_indices.unsqueeze(1).expand_as(static)\n    tour = torch.gather(static.data, 2, idx).permute(0, 2, 1)\n\n    # Make a full tour by returning to the start\n    y = torch.cat((tour, tour[:, :1]), dim=1)\n    # first 2 is xy coordinate, third column is another obj\n    y_dis = y[:, :, :2]\n    y_dis2 = y[:, :, 2:]\n\n    # Euclidean distance between each consecutive point\n    tour_len = torch.sqrt(torch.sum(torch.pow(y_dis[:, :-1] - y_dis[:, 1:], 2), dim=2))\n    obj1 = tour_len.sum(1).detach()\n\n    tour_len2 = torch.sqrt(torch.sum(torch.pow(y_dis2[:, :-1] - y_dis2[:, 1:], 2), dim=2))\n    obj2 = tour_len2.sum(1).detach()\n\n    obj = w1*obj1 + w2*obj2\n    return obj, obj1, obj2\n\n\n\ndef render(static, tour_indices, save_path):\n    \"\"\"Plots the found tours.\"\"\"\n\n    plt.close('all')\n\n    num_plots = 3 if int(np.sqrt(len(tour_indices))) >= 3 else 1\n\n    _, axes = plt.subplots(nrows=num_plots, ncols=num_plots,\n                           sharex='col', sharey='row')\n\n    if num_plots == 1:\n        axes = [[axes]]\n    axes = [a for ax in axes for a in ax]\n\n    for i, ax in enumerate(axes):\n\n        # Convert the indices back into a tour\n        idx = tour_indices[i]\n        if len(idx.size()) == 1:\n            idx = idx.unsqueeze(0)\n\n        # End tour at the starting index\n        idx = idx.expand(static.size(1), -1)\n        idx = torch.cat((idx, idx[:, 0:1]), dim=1)\n\n        data = torch.gather(static[i].data, 1, idx).cpu().numpy()\n\n        #plt.subplot(num_plots, num_plots, i + 1)\n        ax.plot(data[0], data[1], zorder=1)\n        ax.scatter(data[0], data[1], s=4, c='r', zorder=2)\n        ax.scatter(data[0, 0], data[1, 0], s=20, c='k', marker='*', zorder=3)\n\n        ax.set_xlim(0, 1)\n        ax.set_ylim(0, 1)\n\n    plt.tight_layout()\n    plt.savefig(save_path, bbox_inches='tight', dpi=400)\n"
  },
  {
    "path": "tasks/tsp.py",
    "content": "\"\"\"Defines the main task for the TSP\n\nThe TSP is defined by the following traits:\n    1. Each city in the list must be visited once and only once\n    2. The salesman must return to the original node at the end of the tour\n\nSince the TSP doesn't have dynamic elements, we return an empty list on\n__getitem__, which gets processed in trainer.py to be None\n\n\"\"\"\n\nimport os\nimport numpy as np\nimport torch\nfrom torch.utils.data import Dataset\nimport matplotlib\n# matplotlib.use('Agg')\nimport matplotlib.pyplot as plt\n\n\nclass TSPDataset(Dataset):\n\n    def __init__(self, size=50, num_samples=1e6, seed=None):\n        super(TSPDataset, self).__init__()\n\n        if seed is None:\n            seed = np.random.randint(123456789)\n\n        np.random.seed(seed)\n        torch.manual_seed(seed)\n        self.dataset = torch.rand((num_samples, 2, size))\n        self.dynamic = torch.zeros(num_samples, 1, size)\n        self.num_nodes = size\n        self.size = num_samples\n\n    def __len__(self):\n        return self.size\n\n    def __getitem__(self, idx):\n        # (static, dynamic, start_loc)\n        return (self.dataset[idx], self.dynamic[idx], [])\n\n\ndef update_mask(mask, dynamic, chosen_idx):\n    \"\"\"Marks the visited city, so it can't be selected a second time.\"\"\"\n    mask.scatter_(1, chosen_idx.unsqueeze(1), 0)\n    return mask\n\n\ndef reward(static, tour_indices):\n    \"\"\"\n    Parameters\n    ----------\n    static: torch.FloatTensor containing static (e.g. x, y) data\n    tour_indices: torch.IntTensor of size (batch_size, num_cities)\n\n    Returns\n    -------\n    Euclidean distance between consecutive nodes on the route. 
of size\n    (batch_size, num_cities)\n    \"\"\"\n\n    # Convert the indices back into a tour\n    idx = tour_indices.unsqueeze(1).expand_as(static)\n    tour = torch.gather(static.data, 2, idx).permute(0, 2, 1)\n\n    # Make a full tour by returning to the start\n    y = torch.cat((tour, tour[:, :1]), dim=1)\n\n    # Euclidean distance between each consecutive point\n    tour_len = torch.sqrt(torch.sum(torch.pow(y[:, :-1] - y[:, 1:], 2), dim=2))\n\n    return tour_len.sum(1).detach()\n\n\ndef render(static, tour_indices, save_path):\n    \"\"\"Plots the found tours.\"\"\"\n\n    plt.close('all')\n\n    num_plots = 3 if int(np.sqrt(len(tour_indices))) >= 3 else 1\n\n    _, axes = plt.subplots(nrows=num_plots, ncols=num_plots,\n                           sharex='col', sharey='row')\n\n    if num_plots == 1:\n        axes = [[axes]]\n    axes = [a for ax in axes for a in ax]\n\n    for i, ax in enumerate(axes):\n\n        # Convert the indices back into a tour\n        idx = tour_indices[i]\n        if len(idx.size()) == 1:\n            idx = idx.unsqueeze(0)\n\n        # End tour at the starting index\n        idx = idx.expand(static.size(1), -1)\n        idx = torch.cat((idx, idx[:, 0:1]), dim=1)\n\n        data = torch.gather(static[i].data, 1, idx).cpu().numpy()\n\n        #plt.subplot(num_plots, num_plots, i + 1)\n        ax.plot(data[0], data[1], zorder=1)\n        ax.scatter(data[0], data[1], s=4, c='r', zorder=2)\n        ax.scatter(data[0, 0], data[1, 0], s=20, c='k', marker='*', zorder=3)\n\n        ax.set_xlim(0, 1)\n        ax.set_ylim(0, 1)\n\n    plt.tight_layout()\n    plt.savefig(save_path, bbox_inches='tight', dpi=400)\n"
  },
  {
    "path": "tasks/vrp.py",
    "content": "\"\"\"Defines the main task for the VRP.\n\nThe VRP is defined by the following traits:\n    1. Each city has a demand in [1, 9], which must be serviced by the vehicle\n    2. Each vehicle has a capacity (depends on problem), the must visit all cities\n    3. When the vehicle load is 0, it __must__ return to the depot to refill\n\"\"\"\n\nimport os\nimport numpy as np\nimport torch\nfrom torch.utils.data import Dataset\nfrom torch.autograd import Variable\nimport matplotlib\nmatplotlib.use('Agg')\nimport matplotlib.pyplot as plt\n\n\nclass VehicleRoutingDataset(Dataset):\n    def __init__(self, num_samples, input_size, max_load=20, max_demand=9,\n                 seed=None):\n        super(VehicleRoutingDataset, self).__init__()\n\n        if max_load < max_demand:\n            raise ValueError(':param max_load: must be > max_demand')\n\n        if seed is None:\n            seed = np.random.randint(1234567890)\n        np.random.seed(seed)\n        torch.manual_seed(seed)\n\n        self.num_samples = num_samples\n        self.max_load = max_load\n        self.max_demand = max_demand\n\n        # Depot location will be the first node in each\n        locations = torch.rand((num_samples, 2, input_size + 1))\n        self.static = locations\n\n        # All states will broadcast the drivers current load\n        # Note that we only use a load between [0, 1] to prevent large\n        # numbers entering the neural network\n        dynamic_shape = (num_samples, 1, input_size + 1)\n        loads = torch.full(dynamic_shape, 1.)\n\n        # All states will have their own intrinsic demand in [1, max_demand), \n        # then scaled by the maximum load. E.g. 
if load=10 and max_demand=30, \n        # demands will be scaled to the range (0, 3)\n        #######################\n        # demands = torch.randint(1, max_demand + 1, dynamic_shape)\n        demands = torch.randint(1, max_demand + 1, dynamic_shape).float()\n        demands = demands / float(max_load)\n\n        demands[:, 0, 0] = 0  # depot starts with a demand of 0\n        self.dynamic = torch.tensor(np.concatenate((loads, demands), axis=1))\n\n    def __len__(self):\n        return self.num_samples\n\n    def __getitem__(self, idx):\n        # (static, dynamic, start_loc)\n        return (self.static[idx], self.dynamic[idx], self.static[idx, :, 0:1])\n\n    def update_mask(self, mask, dynamic, chosen_idx=None):\n        \"\"\"Updates the mask used to hide non-valid states.\n\n        Parameters\n        ----------\n        dynamic: torch.autograd.Variable of size (1, num_feats, seq_len)\n        \"\"\"\n\n        # Convert floating point to integers for calculations\n        loads = dynamic.data[:, 0]  # (batch_size, seq_len)\n        demands = dynamic.data[:, 1]  # (batch_size, seq_len)\n\n        # If there is no positive demand left, we can end the tour.\n        # Note that the first node is the depot, which always has a negative demand\n        if demands.eq(0).all():\n            return demands * 0.\n\n        # Otherwise, we can choose to go anywhere where demand is > 0\n        new_mask = demands.ne(0) * demands.lt(loads)\n\n        # We should avoid traveling to the depot back-to-back\n        repeat_home = chosen_idx.ne(0)\n\n        if repeat_home.any():\n            new_mask[repeat_home.nonzero(), 0] = 1.\n        if (1 - repeat_home).any():\n            new_mask[(1 - repeat_home).nonzero(), 0] = 0.\n\n        # ... 
unless we're waiting for all other samples in a minibatch to finish\n        has_no_load = loads[:, 0].eq(0).float()\n        has_no_demand = demands[:, 1:].sum(1).eq(0).float()\n\n        combined = (has_no_load + has_no_demand).gt(0)\n        if combined.any():\n            new_mask[combined.nonzero(), 0] = 1.\n            new_mask[combined.nonzero(), 1:] = 0.\n\n        return new_mask.float()\n\n    def update_dynamic(self, dynamic, chosen_idx):\n        \"\"\"Updates the (load, demand) dataset values.\"\"\"\n\n        # Update the dynamic elements differently for if we visit depot vs. a city\n        visit = chosen_idx.ne(0)\n        depot = chosen_idx.eq(0)\n\n        # Clone the dynamic variable so we don't mess up graph\n        all_loads = dynamic[:, 0].clone()\n        all_demands = dynamic[:, 1].clone()\n\n        load = torch.gather(all_loads, 1, chosen_idx.unsqueeze(1))\n        demand = torch.gather(all_demands, 1, chosen_idx.unsqueeze(1))\n\n        # Across the minibatch - if we've chosen to visit a city, try to satisfy\n        # as much demand as possible\n        if visit.any():\n\n            new_load = torch.clamp(load - demand, min=0)\n            new_demand = torch.clamp(demand - load, min=0)\n\n            # Broadcast the load to all nodes, but update demand seperately\n            visit_idx = visit.nonzero().squeeze()\n\n            all_loads[visit_idx] = new_load[visit_idx]\n            all_demands[visit_idx, chosen_idx[visit_idx]] = new_demand[visit_idx].view(-1)\n            all_demands[visit_idx, 0] = -1. 
+ new_load[visit_idx].view(-1)\n\n        # Return to depot to fill vehicle load\n        if depot.any():\n            all_loads[depot.nonzero().squeeze()] = 1.\n            all_demands[depot.nonzero().squeeze(), 0] = 0.\n\n        tensor = torch.cat((all_loads.unsqueeze(1), all_demands.unsqueeze(1)), 1)\n        return torch.tensor(tensor.data, device=dynamic.device)\n\n\ndef reward(static, tour_indices):\n    \"\"\"\n    Euclidean distance between all cities / nodes given by tour_indices\n    \"\"\"\n\n    # Convert the indices back into a tour\n    idx = tour_indices.unsqueeze(1).expand(-1, static.size(1), -1)\n    tour = torch.gather(static.data, 2, idx).permute(0, 2, 1)\n\n    # Ensure we're always returning to the depot - note the extra concat\n    # won't add any extra loss, as the euclidean distance between consecutive\n    # points is 0\n    start = static.data[:, :, 0].unsqueeze(1)\n    y = torch.cat((start, tour, start), dim=1)\n\n    # Euclidean distance between each consecutive point\n    tour_len = torch.sqrt(torch.sum(torch.pow(y[:, :-1] - y[:, 1:], 2), dim=2))\n\n    return tour_len.sum(1)\n\n\ndef render(static, tour_indices, save_path):\n    \"\"\"Plots the found solution.\"\"\"\n\n    plt.close('all')\n\n    num_plots = 3 if int(np.sqrt(len(tour_indices))) >= 3 else 1\n\n    _, axes = plt.subplots(nrows=num_plots, ncols=num_plots,\n                           sharex='col', sharey='row')\n\n    if num_plots == 1:\n        axes = [[axes]]\n    axes = [a for ax in axes for a in ax]\n\n    for i, ax in enumerate(axes):\n\n        # Convert the indices back into a tour\n        idx = tour_indices[i]\n        if len(idx.size()) == 1:\n            idx = idx.unsqueeze(0)\n\n        idx = idx.expand(static.size(1), -1)\n        data = torch.gather(static[i].data, 1, idx).cpu().numpy()\n\n        start = static[i, :, 0].cpu().data.numpy()\n        x = np.hstack((start[0], data[0], start[0]))\n        y = np.hstack((start[1], data[1], start[1]))\n\n        # 
Assign each subtour a different colour & label in order traveled\n        idx = np.hstack((0, tour_indices[i].cpu().numpy().flatten(), 0))\n        where = np.where(idx == 0)[0]\n\n        for j in range(len(where) - 1):\n\n            low = where[j]\n            high = where[j + 1]\n\n            if low + 1 == high:\n                continue\n\n            ax.plot(x[low: high + 1], y[low: high + 1], zorder=1, label=j)\n\n        ax.legend(loc=\"upper right\", fontsize=3, framealpha=0.5)\n        ax.scatter(x, y, s=4, c='r', zorder=2)\n        ax.scatter(x[0], y[0], s=20, c='k', marker='*', zorder=3)\n\n        ax.set_xlim(0, 1)\n        ax.set_ylim(0, 1)\n\n    plt.tight_layout()\n    plt.savefig(save_path, bbox_inches='tight', dpi=200)\n\n\n'''\ndef render(static, tour_indices, save_path):\n    \"\"\"Plots the found solution.\"\"\"\n\n    path = 'C:/Users/Matt/Documents/ffmpeg-3.4.2-win64-static/bin/ffmpeg.exe'\n    plt.rcParams['animation.ffmpeg_path'] = path\n\n    plt.close('all')\n\n    num_plots = min(int(np.sqrt(len(tour_indices))), 3)\n    fig, axes = plt.subplots(nrows=num_plots, ncols=num_plots,\n                             sharex='col', sharey='row')\n    axes = [a for ax in axes for a in ax]\n\n    all_lines = []\n    all_tours = []\n    for i, ax in enumerate(axes):\n\n        # Convert the indices back into a tour\n        idx = tour_indices[i]\n        if len(idx.size()) == 1:\n            idx = idx.unsqueeze(0)\n\n        idx = idx.expand(static.size(1), -1)\n        data = torch.gather(static[i].data, 1, idx).cpu().numpy()\n\n        start = static[i, :, 0].cpu().data.numpy()\n        x = np.hstack((start[0], data[0], start[0]))\n        y = np.hstack((start[1], data[1], start[1]))\n\n        cur_tour = np.vstack((x, y))\n\n        all_tours.append(cur_tour)\n        all_lines.append(ax.plot([], [])[0])\n\n        ax.scatter(x, y, s=4, c='r', zorder=2)\n        ax.scatter(x[0], y[0], s=20, c='k', marker='*', zorder=3)\n\n    from 
matplotlib.animation import FuncAnimation\n\n    tours = all_tours\n\n    def update(idx):\n\n        for i, line in enumerate(all_lines):\n\n            if idx >= tours[i].shape[1]:\n                continue\n\n            data = tours[i][:, idx]\n\n            xy_data = line.get_xydata()\n            xy_data = np.vstack((xy_data, np.atleast_2d(data)))\n\n            line.set_data(xy_data[:, 0], xy_data[:, 1])\n            line.set_linewidth(0.75)\n\n        return all_lines\n\n    anim = FuncAnimation(fig, update, init_func=None,\n                         frames=100, interval=200, blit=False,\n                         repeat=False)\n\n    anim.save('line.mp4', dpi=160)\n    plt.show()\n\n    import sys\n    sys.exit(1)\n'''\n"
  },
  {
    "path": "trainer_motsp_no_transfer.py",
    "content": "\"\"\"Defines the main trainer model for combinatorial problems\n\nEach task must define the following functions:\n* mask_fn: can be None\n* update_fn: can be None\n* reward_fn: specifies the quality of found solutions\n* render_fn: Specifies how to plot found solutions. Can be None\n\"\"\"\n\nimport os\nimport time\nimport argparse\nimport datetime\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom torch.utils.data import DataLoader\n\nfrom model import DRL4TSP, Encoder\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n#device = torch.device('cpu')\n\n\nclass StateCritic(nn.Module):\n    \"\"\"Estimates the problem complexity.\n\n    This is a basic module that just looks at the log-probabilities predicted by\n    the encoder + decoder, and returns an estimate of complexity\n    \"\"\"\n\n    def __init__(self, static_size, dynamic_size, hidden_size):\n        super(StateCritic, self).__init__()\n\n        self.static_encoder = Encoder(static_size, hidden_size)\n        self.dynamic_encoder = Encoder(dynamic_size, hidden_size)\n\n        # Define the encoder & decoder models\n        self.fc1 = nn.Conv1d(hidden_size * 2, 20, kernel_size=1)\n        self.fc2 = nn.Conv1d(20, 20, kernel_size=1)\n        self.fc3 = nn.Conv1d(20, 1, kernel_size=1)\n\n        for p in self.parameters():\n            if len(p.shape) > 1:\n                nn.init.xavier_uniform_(p)\n\n    def forward(self, static, dynamic):\n\n        # Use the probabilities of visiting each\n        static_hidden = self.static_encoder(static)\n        dynamic_hidden = self.dynamic_encoder(dynamic)\n\n        hidden = torch.cat((static_hidden, dynamic_hidden), 1)\n\n        output = F.relu(self.fc1(hidden))\n        output = F.relu(self.fc2(output))\n        output = self.fc3(output).sum(dim=2)\n        return output\n\n\nclass Critic(nn.Module):\n    \"\"\"Estimates the problem 
complexity.\n\n    This is a basic module that just looks at the log-probabilities predicted by\n    the encoder + decoder, and returns an estimate of complexity\n    \"\"\"\n\n    def __init__(self, hidden_size):\n        super(Critic, self).__init__()\n\n        # Define the encoder & decoder models\n        self.fc1 = nn.Conv1d(1, hidden_size, kernel_size=1)\n        self.fc2 = nn.Conv1d(hidden_size, 20, kernel_size=1)\n        self.fc3 = nn.Conv1d(20, 1, kernel_size=1)\n\n        for p in self.parameters():\n            if len(p.shape) > 1:\n                nn.init.xavier_uniform_(p)\n\n    def forward(self, input):\n\n        output = F.relu(self.fc1(input.unsqueeze(1)))\n        output = F.relu(self.fc2(output)).squeeze(2)\n        output = self.fc3(output).sum(dim=2)\n        return output\n\n\ndef validate(data_loader, actor, reward_fn, w1, w2, render_fn=None, save_dir='.',\n             num_plot=5):\n    \"\"\"Used to monitor progress on a validation set & optionally plot solution.\"\"\"\n\n    actor.eval()\n\n    if not os.path.exists(save_dir):\n        os.makedirs(save_dir)\n\n    rewards = []\n    obj1s = []\n    obj2s = []\n    for batch_idx, batch in enumerate(data_loader):\n\n        static, dynamic, x0 = batch\n\n        static = static.to(device)\n        dynamic = dynamic.to(device)\n        x0 = x0.to(device) if len(x0) > 0 else None\n\n        with torch.no_grad():\n            tour_indices, _ = actor.forward(static, dynamic, x0)\n\n        reward, obj1, obj2 = reward_fn(static, tour_indices, w1, w2)\n\n        rewards.append(torch.mean(reward.detach()).item())\n        obj1s.append(torch.mean(obj1.detach()).item())\n        obj2s.append(torch.mean(obj2.detach()).item())\n        if render_fn is not None and batch_idx < num_plot:\n            name = 'batch%d_%2.4f.png'%(batch_idx, torch.mean(reward.detach()).item())\n            path = os.path.join(save_dir, name)\n            render_fn(static, tour_indices, path)\n\n    actor.train()\n    
return np.mean(rewards), np.mean(obj1s), np.mean(obj2s)\n\n\ndef train(actor, critic, w1, w2, task, num_nodes, train_data, valid_data, reward_fn,\n          render_fn, batch_size, actor_lr, critic_lr, max_grad_norm,\n          **kwargs):\n    \"\"\"Constructs the main actor & critic networks, and performs all training.\"\"\"\n\n    now = '%s' % datetime.datetime.now().time()\n    now = now.replace(':', '_')\n    bname = \"_4static\"\n    save_dir = os.path.join(task+bname, '%d' % num_nodes, 'w_%2.2f_%2.2f' % (w1, w2), now)\n\n    checkpoint_dir = os.path.join(save_dir, 'checkpoints')\n    if not os.path.exists(checkpoint_dir):\n        os.makedirs(checkpoint_dir)\n\n    actor_optim = optim.Adam(actor.parameters(), lr=actor_lr)\n    critic_optim = optim.Adam(critic.parameters(), lr=critic_lr)\n\n    train_loader = DataLoader(train_data, batch_size, True, num_workers=0)\n    valid_loader = DataLoader(valid_data, batch_size, False, num_workers=0)\n\n    best_params = None\n    best_reward = np.inf\n\n    for epoch in range(5):\n        print(\"epoch %d start:\"% epoch)\n        actor.train()\n        critic.train()\n\n        times, losses, rewards, critic_rewards = [], [], [], []\n        obj1s, obj2s = [], []\n\n        epoch_start = time.time()\n        start = epoch_start\n\n        for batch_idx, batch in enumerate(train_loader):\n\n            static, dynamic, x0 = batch\n\n            static = static.to(device)\n            dynamic = dynamic.to(device)\n            x0 = x0.to(device) if len(x0) > 0 else None\n\n            # Full forward pass through the dataset\n            tour_indices, tour_logp = actor(static, dynamic, x0)\n\n            # Sum the log probabilities for each city in the tour\n            reward, obj1, obj2 = reward_fn(static, tour_indices, w1, w2)\n\n            # Query the critic for an estimate of the reward\n            critic_est = critic(static, dynamic).view(-1)\n\n            advantage = (reward - critic_est)\n            actor_loss = 
torch.mean(advantage.detach() * tour_logp.sum(dim=1))\n            critic_loss = torch.mean(advantage ** 2)\n\n            actor_optim.zero_grad()\n            actor_loss.backward()\n            torch.nn.utils.clip_grad_norm_(actor.parameters(), max_grad_norm)\n            actor_optim.step()\n\n            critic_optim.zero_grad()\n            critic_loss.backward()\n            torch.nn.utils.clip_grad_norm_(critic.parameters(), max_grad_norm)\n            critic_optim.step()\n\n            critic_rewards.append(torch.mean(critic_est.detach()).item())\n            rewards.append(torch.mean(reward.detach()).item())\n            losses.append(torch.mean(actor_loss.detach()).item())\n            obj1s.append(torch.mean(obj1.detach()).item())\n            obj2s.append(torch.mean(obj2.detach()).item())\n            if (batch_idx + 1) % 200 == 0:\n                print(\"\\n\")\n                end = time.time()\n                times.append(end - start)\n                start = end\n\n                mean_loss = np.mean(losses[-100:])\n                mean_reward = np.mean(rewards[-100:])\n                mean_obj1 = np.mean(obj1s[-100:])\n                mean_obj2 = np.mean(obj2s[-100:])\n                print('  Batch %d/%d, reward: %2.3f, obj1: %2.3f, obj2: %2.3f, loss: %2.4f, took: %2.4fs' %\n                      (batch_idx, len(train_loader), mean_reward, mean_obj1, mean_obj2, mean_loss,\n                       times[-1]))\n\n        mean_loss = np.mean(losses)\n        mean_reward = np.mean(rewards)\n\n        # Save the weights\n        epoch_dir = os.path.join(checkpoint_dir, '%s' % epoch)\n        if not os.path.exists(epoch_dir):\n            os.makedirs(epoch_dir)\n\n        save_path = os.path.join(epoch_dir, 'actor.pt')\n        torch.save(actor.state_dict(), save_path)\n\n        save_path = os.path.join(epoch_dir, 'critic.pt')\n        torch.save(critic.state_dict(), save_path)\n\n        # Save rendering of validation set tours\n        valid_dir = 
os.path.join(save_dir, '%s' % epoch)\n\n        print(\"begin valid\")\n        s = time.time()\n        mean_valid, mean_obj1_valid, mean_obj2_valid = validate(valid_loader, actor, reward_fn, w1, w2, render_fn,\n                              valid_dir, num_plot=5)\n        print(\"valid end time: %2.4f\" % (time.time()-s) )\n        # Save best model parameters\n        if mean_valid < best_reward:\n\n            best_reward = mean_valid\n\n            # save_path = os.path.join(save_dir, 'actor.pt')\n            # torch.save(actor.state_dict(), save_path)\n            #\n            # save_path = os.path.join(save_dir, 'critic.pt')\n            # torch.save(critic.state_dict(), save_path)\n            # 存在w_1_0主文件夹下，多存一份，用来transfer to next w\n            main_dir = os.path.join(task+bname, '%d' % num_nodes, 'w_%2.2f_%2.2f' % (w1, w2))\n            save_path = os.path.join(main_dir, 'actor.pt')\n            torch.save(actor.state_dict(), save_path)\n            save_path = os.path.join(main_dir, 'critic.pt')\n            torch.save(critic.state_dict(), save_path)\n\n        print('Mean epoch loss/reward: %2.4f, %2.4f, %2.4f, obj1_valid: %2.3f, obj2_valid: %2.3f. 
took: %2.4fs '\\\n              '(%2.4fs / 100 batches)\\n' % \\\n              (mean_loss, mean_reward, mean_valid, mean_obj1_valid, mean_obj2_valid, time.time() - epoch_start,\n              np.mean(times)))\n\n\n\ndef train_tsp(args, w1=1, w2=0, checkpoint = None):\n\n    # Goals from paper:\n    # TSP20, 3.97\n    # TSP50, 6.08\n    # TSP100, 8.44\n\n    from tasks import motsp\n    from tasks.motsp import TSPDataset\n\n    STATIC_SIZE = 4 # (x, y)\n    DYNAMIC_SIZE = 1 # dummy for compatibility\n\n    train_data = TSPDataset(args.num_nodes, args.train_size, args.seed)\n    valid_data = TSPDataset(args.num_nodes, args.valid_size, args.seed + 1)\n\n    update_fn = None\n\n    actor = DRL4TSP(STATIC_SIZE,\n                    DYNAMIC_SIZE,\n                    args.hidden_size,\n                    update_fn,\n                    motsp.update_mask,\n                    args.num_layers,\n                    args.dropout).to(device)\n\n    critic = StateCritic(STATIC_SIZE, DYNAMIC_SIZE, args.hidden_size).to(device)\n\n    kwargs = vars(args)\n    kwargs['train_data'] = train_data\n    kwargs['valid_data'] = valid_data\n    kwargs['reward_fn'] = motsp.reward\n    kwargs['render_fn'] = motsp.render\n\n    if checkpoint:\n        path = os.path.join(checkpoint, 'actor.pt')\n        actor.load_state_dict(torch.load(path, device))\n        # actor.static_encoder.state_dict().get(\"conv.weight\").size()\n        path = os.path.join(checkpoint, 'critic.pt')\n        critic.load_state_dict(torch.load(path, device))\n\n    if not args.test:\n        train(actor, critic, w1, w2, **kwargs)\n\n    test_data = TSPDataset(args.num_nodes, args.valid_size, args.seed + 2)\n\n    test_dir = 'test'\n    test_loader = DataLoader(test_data, args.valid_size, False, num_workers=0)\n    out = validate(test_loader, actor, motsp.reward, w1, w2, motsp.render, test_dir, num_plot=5)\n\n    print('w1=%2.2f,w2=%2.2f. 
Average tour length: ' % (w1, w2), out)\n\n\ndef train_vrp(args):\n\n    # Goals from paper:\n    # VRP10, Capacity 20:  4.84  (Greedy)\n    # VRP20, Capacity 30:  6.59  (Greedy)\n    # VRP50, Capacity 40:  11.39 (Greedy)\n    # VRP100, Capacity 50: 17.23  (Greedy)\n\n    from tasks import vrp\n    from tasks.vrp import VehicleRoutingDataset\n\n    # Determines the maximum amount of load for a vehicle based on num nodes\n    LOAD_DICT = {10: 20, 20: 30, 50: 40, 100: 50}\n    MAX_DEMAND = 9\n    STATIC_SIZE = 2 # (x, y)\n    DYNAMIC_SIZE = 2 # (load, demand)\n\n    max_load = LOAD_DICT[args.num_nodes]\n\n    train_data = VehicleRoutingDataset(args.train_size,\n                                       args.num_nodes,\n                                       max_load,\n                                       MAX_DEMAND,\n                                       args.seed)\n\n    valid_data = VehicleRoutingDataset(args.valid_size,\n                                       args.num_nodes,\n                                       max_load,\n                                       MAX_DEMAND,\n                                       args.seed + 1)\n\n    actor = DRL4TSP(STATIC_SIZE,\n                    DYNAMIC_SIZE,\n                    args.hidden_size,\n                    train_data.update_dynamic,\n                    train_data.update_mask,\n                    args.num_layers,\n                    args.dropout).to(device)\n\n    critic = StateCritic(STATIC_SIZE, DYNAMIC_SIZE, args.hidden_size).to(device)\n\n    kwargs = vars(args)\n    kwargs['train_data'] = train_data\n    kwargs['valid_data'] = valid_data\n    kwargs['reward_fn'] = vrp.reward\n    kwargs['render_fn'] = vrp.render\n\n    if args.checkpoint:\n        path = os.path.join(args.checkpoint, 'actor.pt')\n        actor.load_state_dict(torch.load(path, device))\n\n        path = os.path.join(args.checkpoint, 'critic.pt')\n        critic.load_state_dict(torch.load(path, device))\n\n    if not args.test:\n        
train(actor, critic, **kwargs)\n\n    test_data = VehicleRoutingDataset(args.valid_size,\n                                      args.num_nodes,\n                                      max_load,\n                                      MAX_DEMAND,\n                                      args.seed + 2)\n\n    test_dir = 'test'\n    test_loader = DataLoader(test_data, args.batch_size, False, num_workers=0)\n    out = validate(test_loader, actor, vrp.reward, vrp.render, test_dir, num_plot=5)\n\n    print('Average tour length: ', out)\n\n\nif __name__ == '__main__':\n\n    parser = argparse.ArgumentParser(description='Combinatorial Optimization')\n    parser.add_argument('--seed', default=12345, type=int)\n    # parser.add_argument('--checkpoint', default=\"tsp/20/w_1_0/20_06_30.888074\")\n    parser.add_argument('--test', action='store_true', default=False)\n    parser.add_argument('--task', default='tsp')\n    parser.add_argument('--nodes', dest='num_nodes', default=40, type=int)\n    parser.add_argument('--actor_lr', default=5e-4, type=float)\n    parser.add_argument('--critic_lr', default=5e-4, type=float)\n    parser.add_argument('--max_grad_norm', default=2., type=float)\n    parser.add_argument('--batch_size', default=200, type=int)\n    parser.add_argument('--hidden', dest='hidden_size', default=128, type=int)\n    parser.add_argument('--dropout', default=0.1, type=float)\n    parser.add_argument('--layers', dest='num_layers', default=1, type=int)\n    parser.add_argument('--train-size',default=500000, type=int)\n    parser.add_argument('--valid-size', default=1000, type=int)\n\n    args = parser.parse_args()\n\n    # Trained without transfer\n\n    if args.task == 'tsp':\n        w2_list = np.arange(101)/100\n        for i in range(0,101):\n            print(\"Current w:%2.2f/%2.2f\"% (1-w2_list[i], w2_list[i]))\n            train_tsp(args, 1-w2_list[i], w2_list[i], None)\n\n    elif args.task == 'vrp':\n        train_vrp(args)\n    else:\n        raise 
ValueError('Task <%s> not understood'%args.task)\n"
  },
  {
    "path": "trainer_motsp_transfer.py",
    "content": "\"\"\"Defines the main trainer model for combinatorial problems\n\nEach task must define the following functions:\n* mask_fn: can be None\n* update_fn: can be None\n* reward_fn: specifies the quality of found solutions\n* render_fn: Specifies how to plot found solutions. Can be None\n\"\"\"\n\nimport os\nimport time\nimport argparse\nimport datetime\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom torch.utils.data import DataLoader\n\nfrom model import DRL4TSP, Encoder\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n# device = torch.device('cpu')\n\n\nclass StateCritic(nn.Module):\n    \"\"\"Estimates the problem complexity.\n\n    This is a basic module that just looks at the log-probabilities predicted by\n    the encoder + decoder, and returns an estimate of complexity\n    \"\"\"\n\n    def __init__(self, static_size, dynamic_size, hidden_size):\n        super(StateCritic, self).__init__()\n\n        self.static_encoder = Encoder(static_size, hidden_size)\n        self.dynamic_encoder = Encoder(dynamic_size, hidden_size)\n\n        # Define the encoder & decoder models\n        self.fc1 = nn.Conv1d(hidden_size * 2, 20, kernel_size=1)\n        self.fc2 = nn.Conv1d(20, 20, kernel_size=1)\n        self.fc3 = nn.Conv1d(20, 1, kernel_size=1)\n\n        for p in self.parameters():\n            if len(p.shape) > 1:\n                nn.init.xavier_uniform_(p)\n\n    def forward(self, static, dynamic):\n\n        # Use the probabilities of visiting each\n        static_hidden = self.static_encoder(static)\n        dynamic_hidden = self.dynamic_encoder(dynamic)\n\n        hidden = torch.cat((static_hidden, dynamic_hidden), 1)\n\n        output = F.relu(self.fc1(hidden))\n        output = F.relu(self.fc2(output))\n        output = self.fc3(output).sum(dim=2)\n        return output\n\n\nclass Critic(nn.Module):\n    \"\"\"Estimates the problem 
complexity.\n\n    This is a basic module that just looks at the log-probabilities predicted by\n    the encoder + decoder, and returns an estimate of complexity\n    \"\"\"\n\n    def __init__(self, hidden_size):\n        super(Critic, self).__init__()\n\n        # Define the encoder & decoder models\n        self.fc1 = nn.Conv1d(1, hidden_size, kernel_size=1)\n        self.fc2 = nn.Conv1d(hidden_size, 20, kernel_size=1)\n        self.fc3 = nn.Conv1d(20, 1, kernel_size=1)\n\n        for p in self.parameters():\n            if len(p.shape) > 1:\n                nn.init.xavier_uniform_(p)\n\n    def forward(self, input):\n\n        output = F.relu(self.fc1(input.unsqueeze(1)))\n        output = F.relu(self.fc2(output)).squeeze(2)\n        output = self.fc3(output).sum(dim=2)\n        return output\n\n\ndef validate(data_loader, actor, reward_fn, w1, w2, render_fn=None, save_dir='.',\n             num_plot=5):\n    \"\"\"Used to monitor progress on a validation set & optionally plot solution.\"\"\"\n\n    actor.eval()\n\n    # if not os.path.exists(save_dir):\n    #     os.makedirs(save_dir)\n\n    rewards = []\n    obj1s = []\n    obj2s = []\n    for batch_idx, batch in enumerate(data_loader):\n\n        static, dynamic, x0 = batch\n\n        static = static.to(device)\n        dynamic = dynamic.to(device)\n        x0 = x0.to(device) if len(x0) > 0 else None\n\n        with torch.no_grad():\n            tour_indices, _ = actor.forward(static, dynamic, x0)\n\n        reward, obj1, obj2 = reward_fn(static, tour_indices, w1, w2)\n\n        rewards.append(torch.mean(reward.detach()).item())\n        obj1s.append(torch.mean(obj1.detach()).item())\n        obj2s.append(torch.mean(obj2.detach()).item())\n        # if render_fn is not None and batch_idx < num_plot:\n        #     name = 'batch%d_%2.4f.png'%(batch_idx, torch.mean(reward.detach()).item())\n        #     path = os.path.join(save_dir, name)\n        #     render_fn(static, tour_indices, path)\n\n    
actor.train()\n    return np.mean(rewards), np.mean(obj1s), np.mean(obj2s)\n\n\ndef train(actor, critic, w1, w2, task, num_nodes, train_data, valid_data, reward_fn,\n          render_fn, batch_size, actor_lr, critic_lr, max_grad_norm,\n          **kwargs):\n    \"\"\"Constructs the main actor & critic networks, and performs all training.\"\"\"\n\n    now = '%s' % datetime.datetime.now().time()\n    now = now.replace(':', '_')\n    bname = \"_transfer\"\n    save_dir = os.path.join(task+bname, '%d' % num_nodes, 'w_%2.2f_%2.2f' % (w1, w2), now)\n\n    checkpoint_dir = os.path.join(save_dir, 'checkpoints')\n    if not os.path.exists(checkpoint_dir):\n         os.makedirs(checkpoint_dir)\n\n    actor_optim = optim.Adam(actor.parameters(), lr=actor_lr)\n    critic_optim = optim.Adam(critic.parameters(), lr=critic_lr)\n\n    train_loader = DataLoader(train_data, batch_size, True, num_workers=0)\n    valid_loader = DataLoader(valid_data, batch_size, False, num_workers=0)\n\n    best_params = None\n    best_reward = np.inf\n    start_total = time.time()\n    for epoch in range(3):\n        print(\"epoch %d start:\"% epoch)\n        actor.train()\n        critic.train()\n\n        times, losses, rewards, critic_rewards = [], [], [], []\n        obj1s, obj2s = [], []\n\n        epoch_start = time.time()\n        start = epoch_start\n\n        for batch_idx, batch in enumerate(train_loader):\n\n            static, dynamic, x0 = batch\n\n            static = static.to(device)\n            dynamic = dynamic.to(device)\n            x0 = x0.to(device) if len(x0) > 0 else None\n\n            # Full forward pass through the dataset\n            tour_indices, tour_logp = actor(static, dynamic, x0)\n\n            # Sum the log probabilities for each city in the tour\n            reward, obj1, obj2 = reward_fn(static, tour_indices, w1, w2)\n\n            # Query the critic for an estimate of the reward\n            critic_est = critic(static, dynamic).view(-1)\n\n            advantage 
= (reward - critic_est)\n            actor_loss = torch.mean(advantage.detach() * tour_logp.sum(dim=1))\n            critic_loss = torch.mean(advantage ** 2)\n\n            actor_optim.zero_grad()\n            actor_loss.backward()\n            torch.nn.utils.clip_grad_norm_(actor.parameters(), max_grad_norm)\n            actor_optim.step()\n\n            critic_optim.zero_grad()\n            critic_loss.backward()\n            torch.nn.utils.clip_grad_norm_(critic.parameters(), max_grad_norm)\n            critic_optim.step()\n\n            critic_rewards.append(torch.mean(critic_est.detach()).item())\n            rewards.append(torch.mean(reward.detach()).item())\n            losses.append(torch.mean(actor_loss.detach()).item())\n            obj1s.append(torch.mean(obj1.detach()).item())\n            obj2s.append(torch.mean(obj2.detach()).item())\n            if (batch_idx + 1) % 200 == 0:\n                print(\"\\n\")\n                end = time.time()\n                times.append(end - start)\n                start = end\n\n                mean_loss = np.mean(losses[-100:])\n                mean_reward = np.mean(rewards[-100:])\n                mean_obj1 = np.mean(obj1s[-100:])\n                mean_obj2 = np.mean(obj2s[-100:])\n                print('  Batch %d/%d, reward: %2.3f, obj1: %2.3f, obj2: %2.3f, loss: %2.4f, took: %2.4fs' %\n                      (batch_idx, len(train_loader), mean_reward, mean_obj1, mean_obj2, mean_loss,\n                       times[-1]))\n\n        mean_loss = np.mean(losses)\n        mean_reward = np.mean(rewards)\n\n        # Save the weights\n        # epoch_dir = os.path.join(checkpoint_dir, '%s' % epoch)\n        # if not os.path.exists(epoch_dir):\n        #     os.makedirs(epoch_dir)\n        #\n        # save_path = os.path.join(epoch_dir, 'actor.pt')\n        # torch.save(actor.state_dict(), save_path)\n        #\n        # save_path = os.path.join(epoch_dir, 'critic.pt')\n        # torch.save(critic.state_dict(), 
save_path)\n\n        # Save rendering of validation set tours\n        # valid_dir = os.path.join(save_dir, '%s' % epoch)\n        mean_valid, mean_obj1_valid, mean_obj2_valid = validate(valid_loader, actor, reward_fn, w1, w2, render_fn,\n                              '.', num_plot=5)\n\n        # Save best model parameters\n        if mean_valid < best_reward:\n\n            best_reward = mean_valid\n\n            # save_path = os.path.join(save_dir, 'actor.pt')\n            # torch.save(actor.state_dict(), save_path)\n            #\n            # save_path = os.path.join(save_dir, 'critic.pt')\n            # torch.save(critic.state_dict(), save_path)\n            # 存在w_1_0主文件夹下，多存一份，用来transfer to next w\n            main_dir = os.path.join(task+bname, '%d' % num_nodes, 'w_%2.2f_%2.2f' % (w1, w2))\n            save_path = os.path.join(main_dir, 'actor.pt')\n            torch.save(actor.state_dict(), save_path)\n            save_path = os.path.join(main_dir, 'critic.pt')\n            torch.save(critic.state_dict(), save_path)\n\n        print('Mean epoch loss/reward: %2.4f, %2.4f, %2.4f, obj1_valid: %2.3f, obj2_valid: %2.3f. 
took: %2.4fs '\\\n              '(%2.4fs / 100 batches)\\n' % \\\n              (mean_loss, mean_reward, mean_valid, mean_obj1_valid, mean_obj2_valid, time.time() - epoch_start,\n              np.mean(times)))\n    print(\"Total run time of epoches: %2.4f\" % (time.time() - start_total))\n\n\n\ndef train_tsp(args, w1=1, w2=0, checkpoint = None):\n\n    # Goals from paper:\n    # TSP20, 3.97\n    # TSP50, 6.08\n    # TSP100, 8.44\n\n    from tasks import motsp\n    from tasks.motsp import TSPDataset\n\n    STATIC_SIZE = 4 # (x, y)\n    DYNAMIC_SIZE = 1 # dummy for compatibility\n\n    train_data = TSPDataset(args.num_nodes, args.train_size, args.seed)\n    valid_data = TSPDataset(args.num_nodes, args.valid_size, args.seed + 1)\n\n    update_fn = None\n\n    actor = DRL4TSP(STATIC_SIZE,\n                    DYNAMIC_SIZE,\n                    args.hidden_size,\n                    update_fn,\n                    motsp.update_mask,\n                    args.num_layers,\n                    args.dropout).to(device)\n\n    critic = StateCritic(STATIC_SIZE, DYNAMIC_SIZE, args.hidden_size).to(device)\n\n    kwargs = vars(args)\n    kwargs['train_data'] = train_data\n    kwargs['valid_data'] = valid_data\n    kwargs['reward_fn'] = motsp.reward\n    kwargs['render_fn'] = motsp.render\n\n    if checkpoint:\n        path = os.path.join(checkpoint, 'actor.pt')\n        actor.load_state_dict(torch.load(path, device))\n        # actor.static_encoder.state_dict().get(\"conv.weight\").size()\n        path = os.path.join(checkpoint, 'critic.pt')\n        critic.load_state_dict(torch.load(path, device))\n\n    if not args.test:\n        train(actor, critic, w1, w2, **kwargs)\n\n    test_data = TSPDataset(args.num_nodes, args.valid_size, args.seed + 2)\n\n    test_dir = 'test'\n    test_loader = DataLoader(test_data, args.valid_size, False, num_workers=0)\n    out = validate(test_loader, actor, motsp.reward, w1, w2, motsp.render, test_dir, num_plot=5)\n\n    print('w1=%2.2f,w2=%2.2f. 
Average tour length: ' % (w1, w2), out)\n\n\ndef train_vrp(args):\n\n    # Goals from paper:\n    # VRP10, Capacity 20:  4.84  (Greedy)\n    # VRP20, Capacity 30:  6.59  (Greedy)\n    # VRP50, Capacity 40:  11.39 (Greedy)\n    # VRP100, Capacity 50: 17.23  (Greedy)\n\n    from tasks import vrp\n    from tasks.vrp import VehicleRoutingDataset\n\n    # Determines the maximum amount of load for a vehicle based on num nodes\n    LOAD_DICT = {10: 20, 20: 30, 50: 40, 100: 50}\n    MAX_DEMAND = 9\n    STATIC_SIZE = 2 # (x, y)\n    DYNAMIC_SIZE = 2 # (load, demand)\n\n    max_load = LOAD_DICT[args.num_nodes]\n\n    train_data = VehicleRoutingDataset(args.train_size,\n                                       args.num_nodes,\n                                       max_load,\n                                       MAX_DEMAND,\n                                       args.seed)\n\n    valid_data = VehicleRoutingDataset(args.valid_size,\n                                       args.num_nodes,\n                                       max_load,\n                                       MAX_DEMAND,\n                                       args.seed + 1)\n\n    actor = DRL4TSP(STATIC_SIZE,\n                    DYNAMIC_SIZE,\n                    args.hidden_size,\n                    train_data.update_dynamic,\n                    train_data.update_mask,\n                    args.num_layers,\n                    args.dropout).to(device)\n\n    critic = StateCritic(STATIC_SIZE, DYNAMIC_SIZE, args.hidden_size).to(device)\n\n    kwargs = vars(args)\n    kwargs['train_data'] = train_data\n    kwargs['valid_data'] = valid_data\n    kwargs['reward_fn'] = vrp.reward\n    kwargs['render_fn'] = vrp.render\n\n    if args.checkpoint:\n        path = os.path.join(args.checkpoint, 'actor.pt')\n        actor.load_state_dict(torch.load(path, device))\n\n        path = os.path.join(args.checkpoint, 'critic.pt')\n        critic.load_state_dict(torch.load(path, device))\n\n    if not args.test:\n        
train(actor, critic, **kwargs)\n\n    test_data = VehicleRoutingDataset(args.valid_size,\n                                      args.num_nodes,\n                                      max_load,\n                                      MAX_DEMAND,\n                                      args.seed + 2)\n\n    test_dir = 'test'\n    test_loader = DataLoader(test_data, args.batch_size, False, num_workers=0)\n    out = validate(test_loader, actor, vrp.reward, vrp.render, test_dir, num_plot=5)\n\n    print('Average tour length: ', out)\n\n\nif __name__ == '__main__':\n    num_nodes = 100\n    parser = argparse.ArgumentParser(description='Combinatorial Optimization')\n    parser.add_argument('--seed', default=12345, type=int)\n    # parser.add_argument('--checkpoint', default=\"tsp/20/w_1_0/20_06_30.888074\")\n    parser.add_argument('--test', action='store_true', default=False)\n    parser.add_argument('--task', default='tsp')\n    parser.add_argument('--nodes', dest='num_nodes', default=num_nodes, type=int)\n    parser.add_argument('--actor_lr', default=5e-4, type=float)\n    parser.add_argument('--critic_lr', default=5e-4, type=float)\n    parser.add_argument('--max_grad_norm', default=2., type=float)\n    parser.add_argument('--batch_size', default=200, type=int)\n    parser.add_argument('--hidden', dest='hidden_size', default=128, type=int)\n    parser.add_argument('--dropout', default=0.1, type=float)\n    parser.add_argument('--layers', dest='num_layers', default=1, type=int)\n    parser.add_argument('--train-size',default=120000, type=int)\n    parser.add_argument('--valid-size', default=1000, type=int)\n\n    args = parser.parse_args()\n\n\n    T = 100\n    if args.task == 'tsp':\n        w2_list = np.arange(T+1)/T\n        for i in range(0,T+1):\n            print(\"Current w:%2.2f/%2.2f\"% (1-w2_list[i], w2_list[i]))\n            if i==0:\n                # The first subproblem can be trained from scratch. 
It also can be trained based on a\n                # single-TSP trained model, where the model can be obtained from everywhere in github\n                checkpoint = 'tsp_transfer_100run_500000_5epoch_40city/40/w_1.00_0.00'\n                train_tsp(args, 1, 0, checkpoint)\n            else:\n                # Parameter transfer. train based on the parameters of the previous subproblem\n                checkpoint = 'tsp_transfer/%d/w_%2.2f_%2.2f'%(num_nodes, 1-w2_list[i-1], w2_list[i-1])\n                train_tsp(args, 1-w2_list[i], w2_list[i], checkpoint)\n\n\n"
  }
]