[
  {
    "path": ".gitignore",
    "content": "*.*~\n__pycache__/\n*.pkl\ndata/\n**/*.egg-info\n.python-version\n.idea/\n.vscode/\n.DS_Store\n_build/\nresults/"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2020 Tianshou contributors\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE."
  },
  {
    "path": "README.md",
    "content": "# Trust Region Policy Optimisation in Multi-Agent Reinforcement Learning\nDescribed in the paper \"[Trust Region Policy Optimisation in Multi-Agent Reinforcement Learning](https://arxiv.org/pdf/2109.11251.pdf)\", this repository develops *Heterogeneous Agent Trust Region Policy Optimisation (HATRPO)* and *Heterogeneous-Agent Proximal Policy Optimisation (HAPPO)* algorithms on the bechmarks of SMAC and Multi-agent MUJOCO. *HATRPO* and *HAPPO* are the first trust region methods for multi-agent reinforcement learning **with theoretically-justified monotonic improvement guarantee**. Performance wise, it is the new state-of-the-art algorithm against its rivals such as [IPPO](https://arxiv.org/abs/2011.09533), [MAPPO](https://arxiv.org/abs/2103.01955) and [MADDPG](https://arxiv.org/abs/1706.02275). HAPPO and HATRPO have been integrated into HARL framework, please check the latest changes at [here](https://github.com/PKU-MARL/HARL).\n\n## Installation\n### Create environment\n``` Bash\nconda create -n env_name python=3.9\nconda activate env_name\npip install -r requirements.txt\nconda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch -c nvidia\n```\n\n### Multi-agent MuJoCo\nFollowing the instructios in https://github.com/openai/mujoco-py and https://github.com/schroederdewitt/multiagent_mujoco to setup a mujoco environment. In the end, remember to set the following environment variables:\n``` Bash\nLD_LIBRARY_PATH=${HOME}/.mujoco/mujoco200/bin;\nLD_PRELOAD=/usr/lib/x86_64-linux-gnu/libGLEW.so\n```\n### StarCraft II & SMAC\nRun the script\n``` Bash\nbash install_sc2.sh\n```\nOr you could install them manually to other path you like, just follow here: https://github.com/oxwhirl/smac.\n\n## How to run\nWhen your environment is ready, you could run shell scripts provided. For example:\n``` Bash\ncd scripts\n./train_mujoco.sh  # run with HAPPO/HATRPO on Multi-agent MuJoCo\n./train_smac.sh  # run with HAPPO/HATRPO on StarCraft II\n```\n\nIf you would like to change the configs of experiments, you could modify sh files or look for config files for more details. And you can change algorithm by modify **algo=happo** as **algo=hatrpo**.\n\n\n\n## Some experiment results\n\n### SMAC \n\n<img src=\"plots/smac.png\" width=\"500\" >\n\n\n### Multi-agent MuJoCo on MAPPO\n\n<img src=\"plots/ma-mujoco_1.png\" width=\"500\" > \n\n### \n<img src=\"plots/ma-mujoco_2.png\" width=\"500\" >\n\n## Additional Experiment Setting\n### For SMAC\n#### 2022/4/24 update important ERROR for SMAC\n##### Fix the parameter of **gamma**, the right configuration of **gamma** show as following:\n##### gamma for **3s5z** and **2c_vs_64zg**  is 0.95\n##### gamma for **corridor** is 0.99\n\n"
  },
  {
    "path": "algorithms/__init__.py",
    "content": ""
  },
  {
    "path": "algorithms/actor_critic.py",
    "content": "import torch\nimport torch.nn as nn\nfrom algorithms.utils.util import init, check\nfrom algorithms.utils.cnn import CNNBase\nfrom algorithms.utils.mlp import MLPBase\nfrom algorithms.utils.rnn import RNNLayer\nfrom algorithms.utils.act import ACTLayer\nfrom utils.util import get_shape_from_obs_space\n\n\nclass Actor(nn.Module):\n    \"\"\"\n    Actor network class for HAPPO. Outputs actions given observations.\n    :param args: (argparse.Namespace) arguments containing relevant model information.\n    :param obs_space: (gym.Space) observation space.\n    :param action_space: (gym.Space) action space.\n    :param device: (torch.device) specifies the device to run on (cpu/gpu).\n    \"\"\"\n    def __init__(self, args, obs_space, action_space, device=torch.device(\"cpu\")):\n        super(Actor, self).__init__()\n        self.hidden_size = args.hidden_size\n        self.args=args\n        self._gain = args.gain\n        self._use_orthogonal = args.use_orthogonal\n        self._use_policy_active_masks = args.use_policy_active_masks\n        self._use_naive_recurrent_policy = args.use_naive_recurrent_policy\n        self._use_recurrent_policy = args.use_recurrent_policy\n        self._recurrent_N = args.recurrent_N\n        self.tpdv = dict(dtype=torch.float32, device=device)\n\n        obs_shape = get_shape_from_obs_space(obs_space)\n        base = CNNBase if len(obs_shape) == 3 else MLPBase\n        self.base = base(args, obs_shape)\n\n        if self._use_naive_recurrent_policy or self._use_recurrent_policy:\n            self.rnn = RNNLayer(self.hidden_size, self.hidden_size, self._recurrent_N, self._use_orthogonal)\n\n        self.act = ACTLayer(action_space, self.hidden_size, self._use_orthogonal, self._gain, args)\n\n        self.to(device)\n\n    def forward(self, obs, rnn_states, masks, available_actions=None, deterministic=False):\n        \"\"\"\n        Compute actions from the given inputs.\n        :param obs: (np.ndarray / torch.Tensor) observation inputs into network.\n        :param rnn_states: (np.ndarray / torch.Tensor) if RNN network, hidden states for RNN.\n        :param masks: (np.ndarray / torch.Tensor) mask tensor denoting if hidden states should be reinitialized to zeros.\n        :param available_actions: (np.ndarray / torch.Tensor) denotes which actions are available to agent\n                                                              (if None, all actions available)\n        :param deterministic: (bool) whether to sample from action distribution or return the mode.\n\n        :return actions: (torch.Tensor) actions to take.\n        :return action_log_probs: (torch.Tensor) log probabilities of taken actions.\n        :return rnn_states: (torch.Tensor) updated RNN hidden states.\n        \"\"\"\n        obs = check(obs).to(**self.tpdv)\n        rnn_states = check(rnn_states).to(**self.tpdv)\n        masks = check(masks).to(**self.tpdv)\n        if available_actions is not None:\n            available_actions = check(available_actions).to(**self.tpdv)\n\n        actor_features = self.base(obs)\n\n        if self._use_naive_recurrent_policy or self._use_recurrent_policy:\n            actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks)\n\n        actions, action_log_probs = self.act(actor_features, available_actions, deterministic)\n\n        return actions, action_log_probs, rnn_states\n\n    def evaluate_actions(self, obs, rnn_states, action, masks, available_actions=None, active_masks=None):\n        \"\"\"\n        Compute 
log probability and entropy of given actions.\n        :param obs: (torch.Tensor) observation inputs into network.\n        :param action: (torch.Tensor) actions whose entropy and log probability to evaluate.\n        :param rnn_states: (torch.Tensor) if RNN network, hidden states for RNN.\n        :param masks: (torch.Tensor) mask tensor denoting if hidden states should be reinitialized to zeros.\n        :param available_actions: (torch.Tensor) denotes which actions are available to agent\n                                                              (if None, all actions available)\n        :param active_masks: (torch.Tensor) denotes whether an agent is active or dead.\n\n        :return action_log_probs: (torch.Tensor) log probabilities of the input actions.\n        :return dist_entropy: (torch.Tensor) action distribution entropy for the given inputs.\n        \"\"\"\n        obs = check(obs).to(**self.tpdv)\n        rnn_states = check(rnn_states).to(**self.tpdv)\n        action = check(action).to(**self.tpdv)\n        masks = check(masks).to(**self.tpdv)\n        if available_actions is not None:\n            available_actions = check(available_actions).to(**self.tpdv)\n\n        if active_masks is not None:\n            active_masks = check(active_masks).to(**self.tpdv)\n\n        actor_features = self.base(obs)\n\n        if self._use_naive_recurrent_policy or self._use_recurrent_policy:\n            actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks)\n\n        if self.args.algorithm_name == \"hatrpo\":\n            action_log_probs, dist_entropy, action_mu, action_std, all_probs = self.act.evaluate_actions_trpo(actor_features,\n                                                                    action, available_actions,\n                                                                    active_masks=active_masks if self._use_policy_active_masks\n                                                                    else None)\n\n            return action_log_probs, dist_entropy, action_mu, action_std, all_probs\n        else:\n            action_log_probs, dist_entropy = self.act.evaluate_actions(actor_features,\n                                                                    action, available_actions,\n                                                                    active_masks=active_masks if self._use_policy_active_masks\n                                                                    else None)\n\n            return action_log_probs, dist_entropy\n\n\nclass Critic(nn.Module):\n    \"\"\"\n    Critic network class for HAPPO. Outputs value function predictions given centralized input (HAPPO) or local observations (IPPO).\n
    :param args: (argparse.Namespace) arguments containing relevant model information.\n    :param cent_obs_space: (gym.Space) (centralized) observation space.\n    :param device: (torch.device) specifies the device to run on (cpu/gpu).\n    \"\"\"\n    def __init__(self, args, cent_obs_space, device=torch.device(\"cpu\")):\n        super(Critic, self).__init__()\n        self.hidden_size = args.hidden_size\n        self._use_orthogonal = args.use_orthogonal\n        self._use_naive_recurrent_policy = args.use_naive_recurrent_policy\n        self._use_recurrent_policy = args.use_recurrent_policy\n        self._recurrent_N = args.recurrent_N\n        self.tpdv = dict(dtype=torch.float32, device=device)\n        init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][self._use_orthogonal]\n\n        cent_obs_shape = get_shape_from_obs_space(cent_obs_space)\n        base = CNNBase if len(cent_obs_shape) == 3 else MLPBase\n        self.base = base(args, cent_obs_shape)\n\n        if self._use_naive_recurrent_policy or self._use_recurrent_policy:\n            self.rnn = RNNLayer(self.hidden_size, self.hidden_size, self._recurrent_N, self._use_orthogonal)\n\n        def init_(m):\n            return init(m, init_method, lambda x: nn.init.constant_(x, 0))\n\n        self.v_out = init_(nn.Linear(self.hidden_size, 1))\n\n        self.to(device)\n\n    def forward(self, cent_obs, rnn_states, masks):\n        \"\"\"\n        Compute value function predictions from the given inputs.\n        :param cent_obs: (np.ndarray / torch.Tensor) observation inputs into network.\n        :param rnn_states: (np.ndarray / torch.Tensor) if RNN network, hidden states for RNN.\n        :param masks: (np.ndarray / torch.Tensor) mask tensor denoting if RNN states should be reinitialized to zeros.\n\n        :return values: (torch.Tensor) value function predictions.\n        :return rnn_states: (torch.Tensor) updated RNN hidden states.\n        \"\"\"\n        cent_obs = check(cent_obs).to(**self.tpdv)\n        rnn_states = check(rnn_states).to(**self.tpdv)\n        masks = check(masks).to(**self.tpdv)\n\n        critic_features = self.base(cent_obs)\n        if self._use_naive_recurrent_policy or self._use_recurrent_policy:\n            critic_features, rnn_states = self.rnn(critic_features, rnn_states, masks)\n        values = self.v_out(critic_features)\n\n        return values, rnn_states\n
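\n\n# Shape reference (added sketch, not part of the original file; assumes a flat\n# observation space and this repository's argparse defaults for recurrent_N and\n# hidden_size):\n#   obs / cent_obs : (batch, obs_dim)\n#   rnn_states     : (batch, recurrent_N, hidden_size)\n#   masks          : (batch, 1), zeros where an episode ended, so RNN states are reset\n# Actor.forward  returns (actions, action_log_probs, rnn_states)\n# Critic.forward returns (values of shape (batch, 1), rnn_states)\n"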
  },
  {
    "path": "algorithms/happo_policy.py",
    "content": "import torch\nfrom algorithms.actor_critic import Actor, Critic\nfrom utils.util import update_linear_schedule\n\n\nclass HAPPO_Policy:\n    \"\"\"\n    HAPPO Policy  class. Wraps actor and critic networks to compute actions and value function predictions.\n\n    :param args: (argparse.Namespace) arguments containing relevant model and policy information.\n    :param obs_space: (gym.Space) observation space.\n    :param cent_obs_space: (gym.Space) value function input space (centralized input for HAPPO, decentralized for IPPO).\n    :param action_space: (gym.Space) action space.\n    :param device: (torch.device) specifies the device to run on (cpu/gpu).\n    \"\"\"\n\n    def __init__(self, args, obs_space, cent_obs_space, act_space, device=torch.device(\"cpu\")):\n        self.args=args\n        self.device = device\n        self.lr = args.lr\n        self.critic_lr = args.critic_lr\n        self.opti_eps = args.opti_eps\n        self.weight_decay = args.weight_decay\n\n        self.obs_space = obs_space\n        self.share_obs_space = cent_obs_space\n        self.act_space = act_space\n\n        self.actor = Actor(args, self.obs_space, self.act_space, self.device)\n\n        ######################################Please Note#########################################\n        #####   We create one critic for each agent, but they are trained with same data     #####\n        #####   and using same update setting. Therefore they have the same parameter,       #####\n        #####   you can regard them as the same critic.                                      #####\n        ##########################################################################################\n        self.critic = Critic(args, self.share_obs_space, self.device)\n\n        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),\n                                                lr=self.lr, eps=self.opti_eps,\n                                                weight_decay=self.weight_decay)\n        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),\n                                                 lr=self.critic_lr,\n                                                 eps=self.opti_eps,\n                                                 weight_decay=self.weight_decay)\n\n    def lr_decay(self, episode, episodes):\n        \"\"\"\n        Decay the actor and critic learning rates.\n        :param episode: (int) current training episode.\n        :param episodes: (int) total number of training episodes.\n        \"\"\"\n        update_linear_schedule(self.actor_optimizer, episode, episodes, self.lr)\n        update_linear_schedule(self.critic_optimizer, episode, episodes, self.critic_lr)\n\n    def get_actions(self, cent_obs, obs, rnn_states_actor, rnn_states_critic, masks, available_actions=None,\n                    deterministic=False):\n        \"\"\"\n        Compute actions and value function predictions for the given inputs.\n        :param cent_obs (np.ndarray): centralized input to the critic.\n        :param obs (np.ndarray): local agent inputs to the actor.\n        :param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor.\n        :param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic.\n        :param masks: (np.ndarray) denotes points at which RNN states should be reset.\n        :param available_actions: (np.ndarray) denotes which actions are available to agent\n                                  (if None, all actions available)\n        
        :param deterministic: (bool) whether to return the mode of the action distribution or a sample from it.\n\n        :return values: (torch.Tensor) value function predictions.\n        :return actions: (torch.Tensor) actions to take.\n        :return action_log_probs: (torch.Tensor) log probabilities of chosen actions.\n        :return rnn_states_actor: (torch.Tensor) updated actor network RNN states.\n        :return rnn_states_critic: (torch.Tensor) updated critic network RNN states.\n        \"\"\"\n        actions, action_log_probs, rnn_states_actor = self.actor(obs,\n                                                                 rnn_states_actor,\n                                                                 masks,\n                                                                 available_actions,\n                                                                 deterministic)\n\n        values, rnn_states_critic = self.critic(cent_obs, rnn_states_critic, masks)\n        return values, actions, action_log_probs, rnn_states_actor, rnn_states_critic\n\n    def get_values(self, cent_obs, rnn_states_critic, masks):\n        \"\"\"\n        Get value function predictions.\n        :param cent_obs (np.ndarray): centralized input to the critic.\n        :param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic.\n        :param masks: (np.ndarray) denotes points at which RNN states should be reset.\n\n        :return values: (torch.Tensor) value function predictions.\n        \"\"\"\n        values, _ = self.critic(cent_obs, rnn_states_critic, masks)\n        return values\n\n    def evaluate_actions(self, cent_obs, obs, rnn_states_actor, rnn_states_critic, action, masks,\n                         available_actions=None, active_masks=None):\n        \"\"\"\n        Get action logprobs / entropy and value function predictions for actor update.\n        :param cent_obs (np.ndarray): centralized input to the critic.\n        :param obs (np.ndarray): local agent inputs to the actor.\n        :param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor.\n        :param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic.\n        :param action: (np.ndarray) actions whose log probabilities and entropy to compute.\n        :param masks: (np.ndarray) denotes points at which RNN states should be reset.\n        :param available_actions: (np.ndarray) denotes which actions are available to agent\n                                  (if None, all actions available)\n        :param active_masks: (torch.Tensor) denotes whether an agent is active or dead.\n\n        :return values: (torch.Tensor) value function predictions.\n        :return action_log_probs: (torch.Tensor) log probabilities of the input actions.\n        :return dist_entropy: (torch.Tensor) action distribution entropy for the given inputs.\n        \"\"\"\n\n        action_log_probs, dist_entropy = self.actor.evaluate_actions(obs,\n                                                                rnn_states_actor,\n                                                                action,\n                                                                masks,\n                                                                available_actions,\n                                                                active_masks)\n\n        values, _ = self.critic(cent_obs, rnn_states_critic, masks)\n        return values, action_log_probs, dist_entropy\n\n    def act(self, obs, rnn_states_actor, masks, available_actions=None, deterministic=False):\n
        \"\"\"\n        Compute actions using the given inputs.\n        :param obs (np.ndarray): local agent inputs to the actor.\n        :param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor.\n        :param masks: (np.ndarray) denotes points at which RNN states should be reset.\n        :param available_actions: (np.ndarray) denotes which actions are available to agent\n                                  (if None, all actions available)\n        :param deterministic: (bool) whether to return the mode of the action distribution or a sample from it.\n\n        :return actions: (torch.Tensor) actions to take.\n        :return rnn_states_actor: (torch.Tensor) updated actor network RNN states.\n        \"\"\"\n        actions, _, rnn_states_actor = self.actor(obs, rnn_states_actor, masks, available_actions, deterministic)\n        return actions, rnn_states_actor\n
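\n\n# Usage note (added sketch of the intended call pattern): the runner calls\n# get_actions()/get_values() during data collection, the HAPPO trainer calls\n# evaluate_actions() during the policy update, and act() is meant for\n# deterministic evaluation or rendering.\n"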
  },
  {
    "path": "algorithms/happo_trainer.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nfrom utils.util import get_gard_norm, huber_loss, mse_loss\nfrom utils.popart import PopArt\nfrom algorithms.utils.util import check\n\nclass HAPPO():\n    \"\"\"\n    Trainer class for HAPPO to update policies.\n    :param args: (argparse.Namespace) arguments containing relevant model, policy, and env information.\n    :param policy: (HAPPO_Policy) policy to update.\n    :param device: (torch.device) specifies the device to run on (cpu/gpu).\n    \"\"\"\n    def __init__(self,\n                 args,\n                 policy,\n                 device=torch.device(\"cpu\")):\n\n        self.device = device\n        self.tpdv = dict(dtype=torch.float32, device=device)\n        self.policy = policy\n\n        self.clip_param = args.clip_param\n        self.ppo_epoch = args.ppo_epoch\n        self.num_mini_batch = args.num_mini_batch\n        self.data_chunk_length = args.data_chunk_length\n        self.value_loss_coef = args.value_loss_coef\n        self.entropy_coef = args.entropy_coef\n        self.max_grad_norm = args.max_grad_norm       \n        self.huber_delta = args.huber_delta\n\n        self._use_recurrent_policy = args.use_recurrent_policy\n        self._use_naive_recurrent = args.use_naive_recurrent_policy\n        self._use_max_grad_norm = args.use_max_grad_norm\n        self._use_clipped_value_loss = args.use_clipped_value_loss\n        self._use_huber_loss = args.use_huber_loss\n        self._use_popart = args.use_popart\n        self._use_value_active_masks = args.use_value_active_masks\n        self._use_policy_active_masks = args.use_policy_active_masks\n\n        \n        if self._use_popart:\n            self.value_normalizer = PopArt(1, device=self.device)\n        else:\n            self.value_normalizer = None\n\n    def cal_value_loss(self, values, value_preds_batch, return_batch, active_masks_batch):\n        \"\"\"\n        Calculate value function loss.\n        :param values: (torch.Tensor) value function predictions.\n        :param value_preds_batch: (torch.Tensor) \"old\" value  predictions from data batch (used for value clip loss)\n        :param return_batch: (torch.Tensor) reward to go returns.\n        :param active_masks_batch: (torch.Tensor) denotes if agent is active or dead at a given timesep.\n\n        :return value_loss: (torch.Tensor) value function loss.\n        \"\"\"\n        if self._use_popart:\n            value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(-self.clip_param,\n                                                                                        self.clip_param)\n            error_clipped = self.value_normalizer(return_batch) - value_pred_clipped\n            error_original = self.value_normalizer(return_batch) - values\n        else:\n            value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(-self.clip_param,\n                                                                                        self.clip_param)\n            error_clipped = return_batch - value_pred_clipped\n            error_original = return_batch - values\n\n        if self._use_huber_loss:\n            value_loss_clipped = huber_loss(error_clipped, self.huber_delta)\n            value_loss_original = huber_loss(error_original, self.huber_delta)\n        else:\n            value_loss_clipped = mse_loss(error_clipped)\n            value_loss_original = mse_loss(error_original)\n\n        if self._use_clipped_value_loss:\n        
            value_loss = torch.max(value_loss_original, value_loss_clipped)\n        else:\n            value_loss = value_loss_original\n\n        if self._use_value_active_masks:\n            value_loss = (value_loss * active_masks_batch).sum() / active_masks_batch.sum()\n        else:\n            value_loss = value_loss.mean()\n\n        return value_loss\n\n    def ppo_update(self, sample, update_actor=True):\n        \"\"\"\n        Update actor and critic networks.\n        :param sample: (Tuple) contains data batch with which to update networks.\n        :param update_actor: (bool) whether to update actor network.\n\n        :return value_loss: (torch.Tensor) value function loss.\n        :return critic_grad_norm: (torch.Tensor) gradient norm from critic update.\n        :return policy_loss: (torch.Tensor) actor (policy) loss value.\n        :return dist_entropy: (torch.Tensor) action entropies.\n        :return actor_grad_norm: (torch.Tensor) gradient norm from actor update.\n        :return imp_weights: (torch.Tensor) importance sampling weights.\n        \"\"\"\n        share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, \\\n        value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, \\\n        adv_targ, available_actions_batch, factor_batch = sample\n\n        old_action_log_probs_batch = check(old_action_log_probs_batch).to(**self.tpdv)\n        adv_targ = check(adv_targ).to(**self.tpdv)\n        value_preds_batch = check(value_preds_batch).to(**self.tpdv)\n        return_batch = check(return_batch).to(**self.tpdv)\n        active_masks_batch = check(active_masks_batch).to(**self.tpdv)\n        factor_batch = check(factor_batch).to(**self.tpdv)\n\n        # Reshape to do in a single forward pass for all steps\n        values, action_log_probs, dist_entropy = self.policy.evaluate_actions(share_obs_batch,\n                                                                              obs_batch,\n                                                                              rnn_states_batch,\n                                                                              rnn_states_critic_batch,\n                                                                              actions_batch,\n                                                                              masks_batch,\n                                                                              available_actions_batch,\n                                                                              active_masks_batch)\n        # actor update\n        imp_weights = torch.prod(torch.exp(action_log_probs - old_action_log_probs_batch), dim=-1, keepdim=True)\n\n        surr1 = imp_weights * adv_targ\n        surr2 = torch.clamp(imp_weights, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ\n\n        if self._use_policy_active_masks:\n            policy_action_loss = (-torch.sum(factor_batch * torch.min(surr1, surr2),\n                                             dim=-1,\n                                             keepdim=True) * active_masks_batch).sum() / active_masks_batch.sum()\n        else:\n            policy_action_loss = -torch.sum(factor_batch * torch.min(surr1, surr2), dim=-1, keepdim=True).mean()\n\n        policy_loss = policy_action_loss\n\n        self.policy.actor_optimizer.zero_grad()\n\n        if update_actor:\n            (policy_loss - dist_entropy * self.entropy_coef).backward()\n\n        if self._use_max_grad_norm:\n            actor_grad_norm = nn.utils.clip_grad_norm_(self.policy.actor.parameters(), self.max_grad_norm)\n        else:\n            actor_grad_norm = get_gard_norm(self.policy.actor.parameters())\n\n        self.policy.actor_optimizer.step()\n\n        value_loss = self.cal_value_loss(values, value_preds_batch, return_batch, active_masks_batch)\n\n        self.policy.critic_optimizer.zero_grad()\n\n        (value_loss * self.value_loss_coef).backward()\n\n        if self._use_max_grad_norm:\n            critic_grad_norm = nn.utils.clip_grad_norm_(self.policy.critic.parameters(), self.max_grad_norm)\n        else:\n            critic_grad_norm = get_gard_norm(self.policy.critic.parameters())\n\n        self.policy.critic_optimizer.step()\n\n        return value_loss, critic_grad_norm, policy_loss, dist_entropy, actor_grad_norm, imp_weights\n\n    def train(self, buffer, update_actor=True):\n        \"\"\"\n        Perform a training update using minibatch GD.\n        :param buffer: (SharedReplayBuffer) buffer containing training data.\n        :param update_actor: (bool) whether to update actor network.\n\n        :return train_info: (dict) contains information regarding training update (e.g. loss, grad norms, etc).\n        \"\"\"\n        if self._use_popart:\n            advantages = buffer.returns[:-1] - self.value_normalizer.denormalize(buffer.value_preds[:-1])\n        else:\n            advantages = buffer.returns[:-1] - buffer.value_preds[:-1]\n\n        advantages_copy = advantages.copy()\n        advantages_copy[buffer.active_masks[:-1] == 0.0] = np.nan\n        mean_advantages = np.nanmean(advantages_copy)\n        std_advantages = np.nanstd(advantages_copy)\n        advantages = (advantages - mean_advantages) / (std_advantages + 1e-5)\n\n        train_info = {}\n\n        train_info['value_loss'] = 0\n        train_info['policy_loss'] = 0\n        train_info['dist_entropy'] = 0\n        train_info['actor_grad_norm'] = 0\n        train_info['critic_grad_norm'] = 0\n        train_info['ratio'] = 0\n\n        for _ in range(self.ppo_epoch):\n            if self._use_recurrent_policy:\n                data_generator = buffer.recurrent_generator(advantages, self.num_mini_batch, self.data_chunk_length)\n            elif self._use_naive_recurrent:\n                data_generator = buffer.naive_recurrent_generator(advantages, self.num_mini_batch)\n            else:\n                data_generator = buffer.feed_forward_generator(advantages, self.num_mini_batch)\n\n            for sample in data_generator:\n                value_loss, critic_grad_norm, policy_loss, dist_entropy, actor_grad_norm, imp_weights = self.ppo_update(sample, update_actor=update_actor)\n\n                train_info['value_loss'] += value_loss.item()\n                train_info['policy_loss'] += policy_loss.item()\n                train_info['dist_entropy'] += dist_entropy.item()\n                train_info['actor_grad_norm'] += actor_grad_norm\n                train_info['critic_grad_norm'] += critic_grad_norm\n                train_info['ratio'] += imp_weights.mean()\n\n        num_updates = self.ppo_epoch * self.num_mini_batch\n\n        for k in train_info.keys():\n            train_info[k] /= num_updates\n\n        return train_info\n\n    def prep_training(self):\n        self.policy.actor.train()\n        self.policy.critic.train()\n\n    def prep_rollout(self):\n        self.policy.actor.eval()\n        self.policy.critic.eval()\n
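\n\nif __name__ == \"__main__\":\n    # Illustrative sketch (added; not part of the training pipeline): the HAPPO\n    # surrogate from ppo_update() on toy tensors. Here `factor` plays the role of\n    # factor_batch, the product of importance ratios of previously-updated agents\n    # (all ones for the first agent in the update order).\n    torch.manual_seed(0)\n    logp_new = torch.randn(5, 1)\n    logp_old = torch.randn(5, 1)\n    adv = torch.randn(5, 1)\n    factor = torch.ones(5, 1)\n    clip = 0.2\n    w = torch.prod(torch.exp(logp_new - logp_old), dim=-1, keepdim=True)\n    surr1 = w * adv\n    surr2 = torch.clamp(w, 1.0 - clip, 1.0 + clip) * adv\n    loss = -torch.sum(factor * torch.min(surr1, surr2), dim=-1, keepdim=True).mean()\n    print(\"toy HAPPO policy loss:\", loss.item())\n"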
  },
  {
    "path": "algorithms/hatrpo_policy.py",
    "content": "import torch\nfrom algorithms.actor_critic import Actor, Critic\nfrom utils.util import update_linear_schedule\n\n\nclass HATRPO_Policy:\n    \"\"\"\n    HATRPO Policy  class. Wraps actor and critic networks to compute actions and value function predictions.\n\n    :param args: (argparse.Namespace) arguments containing relevant model and policy information.\n    :param obs_space: (gym.Space) observation space.\n    :param cent_obs_space: (gym.Space) value function input space .\n    :param action_space: (gym.Space) action space.\n    :param device: (torch.device) specifies the device to run on (cpu/gpu).\n    \"\"\"\n\n    def __init__(self, args, obs_space, cent_obs_space, act_space, device=torch.device(\"cpu\")):\n        self.args=args\n        self.device = device\n        self.lr = args.lr\n        self.critic_lr = args.critic_lr\n        self.opti_eps = args.opti_eps\n        self.weight_decay = args.weight_decay\n\n        self.obs_space = obs_space\n        self.share_obs_space = cent_obs_space\n        self.act_space = act_space\n\n        self.actor = Actor(args, self.obs_space, self.act_space, self.device)\n\n        ######################################Please Note#########################################\n        #####   We create one critic for each agent, but they are trained with same data     #####\n        #####   and using same update setting. Therefore they have the same parameter,       #####\n        #####   you can regard them as the same critic.                                      #####\n        ##########################################################################################\n        self.critic = Critic(args, self.share_obs_space, self.device)\n\n        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),\n                                                lr=self.lr, eps=self.opti_eps,\n                                                weight_decay=self.weight_decay)\n        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),\n                                                 lr=self.critic_lr,\n                                                 eps=self.opti_eps,\n                                                 weight_decay=self.weight_decay)\n\n    def lr_decay(self, episode, episodes):\n        \"\"\"\n        Decay the actor and critic learning rates.\n        :param episode: (int) current training episode.\n        :param episodes: (int) total number of training episodes.\n        \"\"\"\n        update_linear_schedule(self.actor_optimizer, episode, episodes, self.lr)\n        update_linear_schedule(self.critic_optimizer, episode, episodes, self.critic_lr)\n\n    def get_actions(self, cent_obs, obs, rnn_states_actor, rnn_states_critic, masks, available_actions=None,\n                    deterministic=False):\n        \"\"\"\n        Compute actions and value function predictions for the given inputs.\n        :param cent_obs (np.ndarray): centralized input to the critic.\n        :param obs (np.ndarray): local agent inputs to the actor.\n        :param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor.\n        :param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic.\n        :param masks: (np.ndarray) denotes points at which RNN states should be reset.\n        :param available_actions: (np.ndarray) denotes which actions are available to agent\n                                  (if None, all actions available)\n        :param deterministic: (bool) whether the action 
\n        :return values: (torch.Tensor) value function predictions.\n        :return actions: (torch.Tensor) actions to take.\n        :return action_log_probs: (torch.Tensor) log probabilities of chosen actions.\n        :return rnn_states_actor: (torch.Tensor) updated actor network RNN states.\n        :return rnn_states_critic: (torch.Tensor) updated critic network RNN states.\n        \"\"\"\n        actions, action_log_probs, rnn_states_actor = self.actor(obs,\n                                                                 rnn_states_actor,\n                                                                 masks,\n                                                                 available_actions,\n                                                                 deterministic)\n\n        values, rnn_states_critic = self.critic(cent_obs, rnn_states_critic, masks)\n        return values, actions, action_log_probs, rnn_states_actor, rnn_states_critic\n\n    def get_values(self, cent_obs, rnn_states_critic, masks):\n        \"\"\"\n        Get value function predictions.\n        :param cent_obs (np.ndarray): centralized input to the critic.\n        :param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic.\n        :param masks: (np.ndarray) denotes points at which RNN states should be reset.\n\n        :return values: (torch.Tensor) value function predictions.\n        \"\"\"\n        values, _ = self.critic(cent_obs, rnn_states_critic, masks)\n        return values\n\n    def evaluate_actions(self, cent_obs, obs, rnn_states_actor, rnn_states_critic, action, masks,\n                         available_actions=None, active_masks=None):\n        \"\"\"\n        Get action logprobs / entropy and value function predictions for actor update.\n        :param cent_obs (np.ndarray): centralized input to the critic.\n        :param obs (np.ndarray): local agent inputs to the actor.\n        :param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor.\n        :param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic.\n        :param action: (np.ndarray) actions whose log probabilities and entropy to compute.\n        :param masks: (np.ndarray) denotes points at which RNN states should be reset.\n        :param available_actions: (np.ndarray) denotes which actions are available to agent\n                                  (if None, all actions available)\n        :param active_masks: (torch.Tensor) denotes whether an agent is active or dead.\n\n        :return values: (torch.Tensor) value function predictions.\n        :return action_log_probs: (torch.Tensor) log probabilities of the input actions.\n        :return dist_entropy: (torch.Tensor) action distribution entropy for the given inputs.\n        :return action_mu: (torch.Tensor) mean of the action distribution (continuous actions).\n        :return action_std: (torch.Tensor) standard deviation of the action distribution (continuous actions).\n        :return all_probs: (torch.Tensor) full action probabilities (discrete actions).\n        \"\"\"\n\n        action_log_probs, dist_entropy, action_mu, action_std, all_probs = self.actor.evaluate_actions(obs,\n                                                                    rnn_states_actor,\n                                                                    action,\n                                                                    masks,\n                                                                    available_actions,\n                                                                    active_masks)\n        values, _ = self.critic(cent_obs, rnn_states_critic, masks)\n        return values, action_log_probs, dist_entropy, action_mu, action_std, all_probs\n\n
    def act(self, obs, rnn_states_actor, masks, available_actions=None, deterministic=False):\n        \"\"\"\n        Compute actions using the given inputs.\n        :param obs (np.ndarray): local agent inputs to the actor.\n        :param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor.\n        :param masks: (np.ndarray) denotes points at which RNN states should be reset.\n        :param available_actions: (np.ndarray) denotes which actions are available to agent\n                                  (if None, all actions available)\n        :param deterministic: (bool) whether to return the mode of the action distribution or a sample from it.\n\n        :return actions: (torch.Tensor) actions to take.\n        :return rnn_states_actor: (torch.Tensor) updated actor network RNN states.\n        \"\"\"\n        actions, _, rnn_states_actor = self.actor(obs, rnn_states_actor, masks, available_actions, deterministic)\n        return actions, rnn_states_actor\n
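\n\n# Note (added): unlike HAPPO_Policy, evaluate_actions() here also returns the\n# Gaussian parameters (action_mu, action_std) for continuous action spaces and the\n# full action probabilities (all_probs) for discrete ones; the HATRPO trainer uses\n# them to compute the exact KL divergence required by the trust-region step.\n"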
  },
  {
    "path": "algorithms/hatrpo_trainer.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nfrom utils.util import get_gard_norm, huber_loss, mse_loss\nfrom utils.popart import PopArt\nfrom algorithms.utils.util import check\nfrom algorithms.actor_critic import Actor\n\nclass HATRPO():\n    \"\"\"\n    Trainer class for MATRPO to update policies.\n    :param args: (argparse.Namespace) arguments containing relevant model, policy, and env information.\n    :param policy: (HATRPO_Policy) policy to update.\n    :param device: (torch.device) specifies the device to run on (cpu/gpu).\n    \"\"\"\n    def __init__(self,\n                 args,\n                 policy,\n                 device=torch.device(\"cpu\")):\n\n        self.device = device\n        self.tpdv = dict(dtype=torch.float32, device=device)\n        self.policy = policy\n\n        self.clip_param = args.clip_param\n        self.num_mini_batch = args.num_mini_batch\n        self.data_chunk_length = args.data_chunk_length\n        self.value_loss_coef = args.value_loss_coef\n        self.entropy_coef = args.entropy_coef\n        self.max_grad_norm = args.max_grad_norm       \n        self.huber_delta = args.huber_delta\n\n        self.kl_threshold = args.kl_threshold\n        self.ls_step = args.ls_step\n        self.accept_ratio = args.accept_ratio\n\n        self._use_recurrent_policy = args.use_recurrent_policy\n        self._use_naive_recurrent = args.use_naive_recurrent_policy\n        self._use_max_grad_norm = args.use_max_grad_norm\n        self._use_clipped_value_loss = args.use_clipped_value_loss\n        self._use_huber_loss = args.use_huber_loss\n        self._use_popart = args.use_popart\n        self._use_value_active_masks = args.use_value_active_masks\n        self._use_policy_active_masks = args.use_policy_active_masks\n        \n        if self._use_popart:\n            self.value_normalizer = PopArt(1, device=self.device)\n        else:\n            self.value_normalizer = None\n\n    def cal_value_loss(self, values, value_preds_batch, return_batch, active_masks_batch):\n        \"\"\"\n        Calculate value function loss.\n        :param values: (torch.Tensor) value function predictions.\n        :param value_preds_batch: (torch.Tensor) \"old\" value  predictions from data batch (used for value clip loss)\n        :param return_batch: (torch.Tensor) reward to go returns.\n        :param active_masks_batch: (torch.Tensor) denotes if agent is active or dead at a given timesep.\n\n        :return value_loss: (torch.Tensor) value function loss.\n        \"\"\"\n        if self._use_popart:\n            value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(-self.clip_param,\n                                                                                        self.clip_param)\n            error_clipped = self.value_normalizer(return_batch) - value_pred_clipped\n            error_original = self.value_normalizer(return_batch) - values\n        else:\n            value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(-self.clip_param,\n                                                                                        self.clip_param)\n            error_clipped = return_batch - value_pred_clipped\n            error_original = return_batch - values\n\n        if self._use_huber_loss:\n            value_loss_clipped = huber_loss(error_clipped, self.huber_delta)\n            value_loss_original = huber_loss(error_original, self.huber_delta)\n        else:\n            value_loss_clipped = 
            value_loss_original = mse_loss(error_original)\n\n        if self._use_clipped_value_loss:\n            value_loss = torch.max(value_loss_original, value_loss_clipped)\n        else:\n            value_loss = value_loss_original\n\n        if self._use_value_active_masks:\n            value_loss = (value_loss * active_masks_batch).sum() / active_masks_batch.sum()\n        else:\n            value_loss = value_loss.mean()\n\n        return value_loss\n\n    def flat_grad(self, grads):\n        grad_flatten = []\n        for grad in grads:\n            if grad is None:\n                continue\n            grad_flatten.append(grad.view(-1))\n        grad_flatten = torch.cat(grad_flatten)\n        return grad_flatten\n\n    def flat_hessian(self, hessians):\n        hessians_flatten = []\n        for hessian in hessians:\n            if hessian is None:\n                continue\n            hessians_flatten.append(hessian.contiguous().view(-1))\n        hessians_flatten = torch.cat(hessians_flatten).data\n        return hessians_flatten\n\n    def flat_params(self, model):\n        params = []\n        for param in model.parameters():\n            params.append(param.data.view(-1))\n        params_flatten = torch.cat(params)\n        return params_flatten\n\n    def update_model(self, model, new_params):\n        index = 0\n        for params in model.parameters():\n            params_length = len(params.view(-1))\n            new_param = new_params[index: index + params_length]\n            new_param = new_param.view(params.size())\n            params.data.copy_(new_param)\n            index += params_length\n\n    def kl_approx(self, q, p):\n        # Low-variance, non-negative estimator of KL(q || p) from log-probs:\n        # with r = exp(p - q), KL is approximated by r - 1 - log(r).\n        r = torch.exp(p - q)\n        kl = r - 1 - p + q\n        return kl\n\n    def kl_divergence(self, obs, rnn_states, action, masks, available_actions, active_masks, new_actor, old_actor):\n        _, _, mu, std, probs = new_actor.evaluate_actions(obs, rnn_states, action, masks, available_actions, active_masks)\n        _, _, mu_old, std_old, probs_old = old_actor.evaluate_actions(obs, rnn_states, action, masks, available_actions, active_masks)\n        if mu.grad_fn is None:\n            probs_old = probs_old.detach()\n            kl = self.kl_approx(probs_old, probs)\n        else:\n            logstd = torch.log(std)\n            mu_old = mu_old.detach()\n            std_old = std_old.detach()\n            logstd_old = torch.log(std_old)\n            # KL divergence between old policy and new policy: D( pi_old || pi_new )\n            # pi_old -> mu_old, logstd_old, std_old / pi_new -> mu, logstd, std\n            # Be careful when computing the KL divergence: it is not a symmetric metric.\n
            kl = logstd - logstd_old + (std_old.pow(2) + (mu_old - mu).pow(2)) / (2.0 * std.pow(2)) - 0.5\n\n        if len(kl.shape) > 1:\n            kl = kl.sum(1, keepdim=True)\n        return kl\n\n    # from openai baseline code\n    # https://github.com/openai/baselines/blob/master/baselines/common/cg.py\n    def conjugate_gradient(self, actor, obs, rnn_states, action, masks, available_actions, active_masks, b, nsteps, residual_tol=1e-10):\n        # Iteratively solve H x = b, accessing H only through Fisher-vector products.\n        x = torch.zeros(b.size()).to(device=self.device)\n        r = b.clone()\n        p = b.clone()\n        rdotr = torch.dot(r, r)\n        for i in range(nsteps):\n            _Avp = self.fisher_vector_product(actor, obs, rnn_states, action, masks, available_actions, active_masks, p)\n            alpha = rdotr / torch.dot(p, _Avp)\n            x += alpha * p\n            r -= alpha * _Avp\n            new_rdotr = torch.dot(r, r)\n            beta = new_rdotr / rdotr\n            p = r + beta * p\n            rdotr = new_rdotr\n            if rdotr < residual_tol:\n                break\n        return x\n\n    def fisher_vector_product(self, actor, obs, rnn_states, action, masks, available_actions, active_masks, p):\n        p = p.detach()\n        kl = self.kl_divergence(obs, rnn_states, action, masks, available_actions, active_masks, new_actor=actor, old_actor=actor)\n        kl = kl.mean()\n        kl_grad = torch.autograd.grad(kl, actor.parameters(), create_graph=True, allow_unused=True)\n        kl_grad = self.flat_grad(kl_grad)  # check kl_grad == 0\n        kl_grad_p = (kl_grad * p).sum()\n        kl_hessian_p = torch.autograd.grad(kl_grad_p, actor.parameters(), allow_unused=True)\n        kl_hessian_p = self.flat_hessian(kl_hessian_p)\n        return kl_hessian_p + 0.1 * p  # damping term stabilizes the Hessian-vector product\n\n    def trpo_update(self, sample, update_actor=True):\n        \"\"\"\n        Update actor and critic networks.\n        :param sample: (Tuple) contains data batch with which to update networks.\n        :param update_actor: (bool) whether to update actor network.\n\n        :return value_loss: (torch.Tensor) value function loss.\n        :return critic_grad_norm: (torch.Tensor) gradient norm from critic update.\n        :return kl: (torch.Tensor) KL divergence between old and new policy after the line search.\n        :return loss_improve: (np.ndarray) realized improvement of the surrogate objective.\n        :return expected_improve: (np.ndarray) first-order prediction of the improvement.\n        :return dist_entropy: (torch.Tensor) action entropies.\n        :return ratio: (torch.Tensor) importance sampling weights.\n        \"\"\"\n        share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, \\\n        value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, \\\n        adv_targ, available_actions_batch, factor_batch = sample\n\n        old_action_log_probs_batch = check(old_action_log_probs_batch).to(**self.tpdv)\n        adv_targ = check(adv_targ).to(**self.tpdv)\n        value_preds_batch = check(value_preds_batch).to(**self.tpdv)\n        return_batch = check(return_batch).to(**self.tpdv)\n        active_masks_batch = check(active_masks_batch).to(**self.tpdv)\n        factor_batch = check(factor_batch).to(**self.tpdv)\n\n        values, action_log_probs, dist_entropy, action_mu, action_std, _ = self.policy.evaluate_actions(share_obs_batch,\n                                                                              obs_batch,\n                                                                              rnn_states_batch,\n                                                                              rnn_states_critic_batch,\n                                                                              actions_batch,\n                                                                              masks_batch,\n                                                                              available_actions_batch,\n                                                                              active_masks_batch)\n\n        # critic update\n        value_loss = self.cal_value_loss(values, value_preds_batch, return_batch, active_masks_batch)\n\n        self.policy.critic_optimizer.zero_grad()\n\n        (value_loss * self.value_loss_coef).backward()\n\n        if self._use_max_grad_norm:\n            critic_grad_norm = nn.utils.clip_grad_norm_(self.policy.critic.parameters(), self.max_grad_norm)\n        else:\n            critic_grad_norm = get_gard_norm(self.policy.critic.parameters())\n\n        self.policy.critic_optimizer.step()\n\n        # actor update\n        ratio = torch.prod(torch.exp(action_log_probs - old_action_log_probs_batch), dim=-1, keepdim=True)\n        if self._use_policy_active_masks:\n            loss = (torch.sum(ratio * factor_batch * adv_targ, dim=-1, keepdim=True) *\n                           active_masks_batch).sum() / active_masks_batch.sum()\n        else:\n            loss = torch.sum(ratio * factor_batch * adv_targ, dim=-1, keepdim=True).mean()\n\n        loss_grad = torch.autograd.grad(loss, self.policy.actor.parameters(), allow_unused=True)\n        loss_grad = self.flat_grad(loss_grad)\n\n        step_dir = self.conjugate_gradient(self.policy.actor,\n                                      obs_batch,\n                                      rnn_states_batch,\n                                      actions_batch,\n                                      masks_batch,\n                                      available_actions_batch,\n                                      active_masks_batch,\n                                      loss_grad.data,\n                                      nsteps=10)\n\n        loss = loss.data.cpu().numpy()\n\n        params = self.flat_params(self.policy.actor)\n        fvp = self.fisher_vector_product(self.policy.actor,\n                                    obs_batch,\n                                    rnn_states_batch,\n                                    actions_batch,\n                                    masks_batch,\n                                    available_actions_batch,\n                                    active_masks_batch,\n                                    step_dir)\n        shs = 0.5 * (step_dir * fvp).sum(0, keepdim=True)\n        step_size = 1 / torch.sqrt(shs / self.kl_threshold)[0]\n        full_step = step_size * step_dir\n\n        old_actor = Actor(self.policy.args,\n                            self.policy.obs_space,\n                            self.policy.act_space,\n                            self.device)\n        self.update_model(old_actor, params)\n        expected_improve = (loss_grad * full_step).sum(0, keepdim=True)\n        expected_improve = expected_improve.data.cpu().numpy()\n\n        # Backtracking line search\n        flag = False\n        fraction = 1\n        for i in range(self.ls_step):\n            new_params = params + fraction * full_step\n            self.update_model(self.policy.actor, new_params)\n            values, action_log_probs, dist_entropy, action_mu, action_std, _ = self.policy.evaluate_actions(share_obs_batch,\n
                                                                              obs_batch,\n                                                                              rnn_states_batch,\n                                                                              rnn_states_critic_batch,\n                                                                              actions_batch,\n                                                                              masks_batch,\n                                                                              available_actions_batch,\n                                                                              active_masks_batch)\n\n            ratio = torch.exp(action_log_probs - old_action_log_probs_batch)\n            if self._use_policy_active_masks:\n                new_loss = (torch.sum(ratio * factor_batch * adv_targ, dim=-1, keepdim=True) *\n                            active_masks_batch).sum() / active_masks_batch.sum()\n            else:\n                new_loss = torch.sum(ratio * factor_batch * adv_targ, dim=-1, keepdim=True).mean()\n\n            new_loss = new_loss.data.cpu().numpy()\n            loss_improve = new_loss - loss\n\n            kl = self.kl_divergence(obs_batch,\n                               rnn_states_batch,\n                               actions_batch,\n                               masks_batch,\n                               available_actions_batch,\n                               active_masks_batch,\n                               new_actor=self.policy.actor,\n                               old_actor=old_actor)\n            kl = kl.mean()\n\n            if kl < self.kl_threshold and (loss_improve / expected_improve) > self.accept_ratio and loss_improve.item() > 0:\n                flag = True\n                break\n            expected_improve *= 0.5\n            fraction *= 0.5\n\n        if not flag:\n            params = self.flat_params(old_actor)\n            self.update_model(self.policy.actor, params)\n            print('policy update does not improve the surrogate')\n\n        return value_loss, critic_grad_norm, kl, loss_improve, expected_improve, dist_entropy, ratio\n\n    def train(self, buffer, update_actor=True):\n        \"\"\"\n        Perform a training update using minibatch GD.\n        :param buffer: (SharedReplayBuffer) buffer containing training data.\n        :param update_actor: (bool) whether to update actor network.\n\n        :return train_info: (dict) contains information regarding training update (e.g. loss, grad norms, etc).\n        \"\"\"\n        if self._use_popart:\n            advantages = buffer.returns[:-1] - self.value_normalizer.denormalize(buffer.value_preds[:-1])\n        else:\n            advantages = buffer.returns[:-1] - buffer.value_preds[:-1]\n        advantages_copy = advantages.copy()\n        advantages_copy[buffer.active_masks[:-1] == 0.0] = np.nan\n        mean_advantages = np.nanmean(advantages_copy)\n        std_advantages = np.nanstd(advantages_copy)\n        advantages = (advantages - mean_advantages) / (std_advantages + 1e-5)\n\n        train_info = {}\n\n        train_info['value_loss'] = 0\n        train_info['kl'] = 0\n        train_info['dist_entropy'] = 0\n        train_info['loss_improve'] = 0\n        train_info['expected_improve'] = 0\n        train_info['critic_grad_norm'] = 0\n        train_info['ratio'] = 0\n\n        if self._use_recurrent_policy:\n            data_generator = buffer.recurrent_generator(advantages, self.num_mini_batch, self.data_chunk_length)\n        elif self._use_naive_recurrent:\n            data_generator = buffer.naive_recurrent_generator(advantages, self.num_mini_batch)\n        else:\n            data_generator = buffer.feed_forward_generator(advantages, self.num_mini_batch)\n\n        for sample in data_generator:\n\n            value_loss, critic_grad_norm, kl, loss_improve, expected_improve, dist_entropy, imp_weights \\\n                = self.trpo_update(sample, update_actor)\n\n            train_info['value_loss'] += value_loss.item()\n            train_info['kl'] += kl\n            train_info['loss_improve'] += loss_improve.item()\n            train_info['expected_improve'] += expected_improve\n            train_info['dist_entropy'] += dist_entropy.item()\n            train_info['critic_grad_norm'] += critic_grad_norm\n            train_info['ratio'] += imp_weights.mean()\n\n        num_updates = self.num_mini_batch\n\n        for k in train_info.keys():\n            train_info[k] /= num_updates\n\n        return train_info\n\n    def prep_training(self):\n        self.policy.actor.train()\n        self.policy.critic.train()\n\n    def prep_rollout(self):\n        self.policy.actor.eval()\n        self.policy.critic.eval()\n
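\n\nif __name__ == \"__main__\":\n    # Illustrative sketch (added; not part of the training pipeline): the same\n    # conjugate-gradient recursion used above solves A x = b for a symmetric\n    # positive-definite A while touching A only through matrix-vector products,\n    # which is exactly how the Fisher matrix is accessed during the HATRPO update.\n    torch.manual_seed(0)\n    M = torch.randn(4, 4)\n    A = M @ M.t() + 4.0 * torch.eye(4)  # symmetric positive definite\n    b = torch.randn(4)\n    x, r, p = torch.zeros(4), b.clone(), b.clone()\n    rdotr = torch.dot(r, r)\n    for _ in range(10):\n        Avp = A @ p\n        alpha = rdotr / torch.dot(p, Avp)\n        x += alpha * p\n        r -= alpha * Avp\n        new_rdotr = torch.dot(r, r)\n        p = r + (new_rdotr / rdotr) * p\n        rdotr = new_rdotr\n        if rdotr < 1e-10:\n            break\n    print(\"CG residual:\", torch.norm(A @ x - b).item())  # expect ~0\n"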
  },
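  {
    "path": "examples/line_search_sketch.py",
    "content": "\"\"\"\nIllustrative sketch, not part of the original repository: a minimal,\nself-contained version of the backtracking line search performed in\ntrpo_update above. It halves the step along the search direction until the\nKL constraint holds and the surrogate improves by at least `accept_ratio`\nof the predicted improvement. All names here (`surrogate`, `kl_fn`, `theta`,\n`full_step`) are hypothetical stand-ins for the trainer's internals.\n\"\"\"\nimport numpy as np\n\n\ndef backtracking_line_search(theta, full_step, surrogate, kl_fn,\n                             expected_improve, kl_threshold=0.01,\n                             accept_ratio=0.5, ls_step=10):\n    old_loss = surrogate(theta)\n    fraction = 1.0\n    for _ in range(ls_step):\n        new_theta = theta + fraction * full_step\n        loss_improve = surrogate(new_theta) - old_loss\n        # accept only if the KL constraint holds and the actual improvement\n        # is a large enough fraction of the predicted (linearized) one\n        if (kl_fn(theta, new_theta) < kl_threshold and loss_improve > 0\n                and loss_improve / (expected_improve * fraction) > accept_ratio):\n            return new_theta\n        fraction *= 0.5\n    return theta  # no acceptable step found; keep the old parameters\n\n\nif __name__ == \"__main__\":\n    # toy check: maximize -x^2 from x=1 with a proposed step of -2\n    theta0 = np.array([1.0])\n    step = np.array([-2.0])\n    new = backtracking_line_search(theta0, step,\n                                   surrogate=lambda t: -float(t[0] ** 2),\n                                   kl_fn=lambda a, b: 0.0,\n                                   expected_improve=1.0)\n    print(new)  # [0.]\n"
  },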
  {
    "path": "algorithms/utils/act.py",
    "content": "from .distributions import Bernoulli, Categorical, DiagGaussian\nimport torch\nimport torch.nn as nn\n\nclass ACTLayer(nn.Module):\n    \"\"\"\n    MLP Module to compute actions.\n    :param action_space: (gym.Space) action space.\n    :param inputs_dim: (int) dimension of network input.\n    :param use_orthogonal: (bool) whether to use orthogonal initialization.\n    :param gain: (float) gain of the output layer of the network.\n    \"\"\"\n    def __init__(self, action_space, inputs_dim, use_orthogonal, gain, args=None):\n        super(ACTLayer, self).__init__()\n        self.mixed_action = False\n        self.multi_discrete = False\n        self.action_type = action_space.__class__.__name__\n        if action_space.__class__.__name__ == \"Discrete\":\n            action_dim = action_space.n\n            self.action_out = Categorical(inputs_dim, action_dim, use_orthogonal, gain)\n        elif action_space.__class__.__name__ == \"Box\":\n            action_dim = action_space.shape[0]\n            self.action_out = DiagGaussian(inputs_dim, action_dim, use_orthogonal, gain, args)\n        elif action_space.__class__.__name__ == \"MultiBinary\":\n            action_dim = action_space.shape[0]\n            self.action_out = Bernoulli(inputs_dim, action_dim, use_orthogonal, gain)\n        elif action_space.__class__.__name__ == \"MultiDiscrete\":\n            self.multi_discrete = True\n            action_dims = action_space.high - action_space.low + 1\n            self.action_outs = []\n            for action_dim in action_dims:\n                self.action_outs.append(Categorical(inputs_dim, action_dim, use_orthogonal, gain))\n            self.action_outs = nn.ModuleList(self.action_outs)\n        else:  # discrete + continous\n            self.mixed_action = True\n            continous_dim = action_space[0].shape[0]\n            discrete_dim = action_space[1].n\n            self.action_outs = nn.ModuleList([DiagGaussian(inputs_dim, continous_dim, use_orthogonal, gain, args),\n                                              Categorical(inputs_dim, discrete_dim, use_orthogonal, gain)])\n    \n    def forward(self, x, available_actions=None, deterministic=False):\n        \"\"\"\n        Compute actions and action logprobs from given input.\n        :param x: (torch.Tensor) input to network.\n        :param available_actions: (torch.Tensor) denotes which actions are available to agent\n                                  (if None, all actions available)\n        :param deterministic: (bool) whether to sample from action distribution or return the mode.\n\n        :return actions: (torch.Tensor) actions to take.\n        :return action_log_probs: (torch.Tensor) log probabilities of taken actions.\n        \"\"\"\n        if self.mixed_action :\n            actions = []\n            action_log_probs = []\n            for action_out in self.action_outs:\n                action_logit = action_out(x)\n                action = action_logit.mode() if deterministic else action_logit.sample()\n                action_log_prob = action_logit.log_probs(action)\n                actions.append(action.float())\n                action_log_probs.append(action_log_prob)\n\n            actions = torch.cat(actions, -1)\n            action_log_probs = torch.sum(torch.cat(action_log_probs, -1), -1, keepdim=True)\n\n        elif self.multi_discrete:\n            actions = []\n            action_log_probs = []\n            for action_out in self.action_outs:\n                action_logit = action_out(x)\n 
               action = action_logit.mode() if deterministic else action_logit.sample()\n                action_log_prob = action_logit.log_probs(action)\n                actions.append(action)\n                action_log_probs.append(action_log_prob)\n\n            actions = torch.cat(actions, -1)\n            action_log_probs = torch.cat(action_log_probs, -1)\n        \n        else:\n            action_logits = self.action_out(x, available_actions)\n            actions = action_logits.mode() if deterministic else action_logits.sample() \n            action_log_probs = action_logits.log_probs(actions)\n        \n        return actions, action_log_probs\n\n    def get_probs(self, x, available_actions=None):\n        \"\"\"\n        Compute action probabilities from inputs.\n        :param x: (torch.Tensor) input to network.\n        :param available_actions: (torch.Tensor) denotes which actions are available to agent\n                                  (if None, all actions available)\n\n        :return action_probs: (torch.Tensor)\n        \"\"\"\n        if self.mixed_action or self.multi_discrete:\n            action_probs = []\n            for action_out in self.action_outs:\n                action_logit = action_out(x)\n                action_prob = action_logit.probs\n                action_probs.append(action_prob)\n            action_probs = torch.cat(action_probs, -1)\n        else:\n            action_logits = self.action_out(x, available_actions)\n            action_probs = action_logits.probs\n        \n        return action_probs\n\n    def evaluate_actions(self, x, action, available_actions=None, active_masks=None):\n        \"\"\"\n        Compute log probability and entropy of given actions.\n        :param x: (torch.Tensor) input to network.\n        :param action: (torch.Tensor) actions whose entropy and log probability to evaluate.\n        :param available_actions: (torch.Tensor) denotes which actions are available to agent\n                                                              (if None, all actions available)\n        :param active_masks: (torch.Tensor) denotes whether an agent is active or dead.\n\n        :return action_log_probs: (torch.Tensor) log probabilities of the input actions.\n        :return dist_entropy: (torch.Tensor) action distribution entropy for the given inputs.\n        \"\"\"\n        if self.mixed_action:\n            a, b = action.split((2, 1), -1)\n            b = b.long()\n            action = [a, b] \n            action_log_probs = [] \n            dist_entropy = []\n            for action_out, act in zip(self.action_outs, action):\n                action_logit = action_out(x)\n                action_log_probs.append(action_logit.log_probs(act))\n                if active_masks is not None:\n                    if len(action_logit.entropy().shape) == len(active_masks.shape):\n                        dist_entropy.append((action_logit.entropy() * active_masks).sum()/active_masks.sum()) \n                    else:\n                        dist_entropy.append((action_logit.entropy() * active_masks.squeeze(-1)).sum()/active_masks.sum())\n                else:\n                    dist_entropy.append(action_logit.entropy().mean())\n                \n            action_log_probs = torch.sum(torch.cat(action_log_probs, -1), -1, keepdim=True)\n            dist_entropy = dist_entropy[0] / 2.0 + dist_entropy[1] / 0.98 \n\n        elif self.multi_discrete:\n            action = torch.transpose(action, 0, 1)\n            action_log_probs = []\n     
       dist_entropy = []\n            for action_out, act in zip(self.action_outs, action):\n                action_logit = action_out(x)\n                action_log_probs.append(action_logit.log_probs(act))\n                if active_masks is not None:\n                    dist_entropy.append((action_logit.entropy()*active_masks.squeeze(-1)).sum()/active_masks.sum())\n                else:\n                    dist_entropy.append(action_logit.entropy().mean())\n\n            action_log_probs = torch.cat(action_log_probs, -1) \n            dist_entropy = torch.tensor(dist_entropy).mean()\n        \n        else:\n            action_logits = self.action_out(x, available_actions)\n            action_log_probs = action_logits.log_probs(action)\n            if active_masks is not None:\n                if self.action_type==\"Discrete\":\n                    dist_entropy = (action_logits.entropy()*active_masks.squeeze(-1)).sum()/active_masks.sum()\n                else:\n                    dist_entropy = (action_logits.entropy()*active_masks).sum()/active_masks.sum()\n            else:\n                dist_entropy = action_logits.entropy().mean()\n        \n        return action_log_probs, dist_entropy\n\n    def evaluate_actions_trpo(self, x, action, available_actions=None, active_masks=None):\n        \"\"\"\n        Compute log probability and entropy of given actions.\n        :param x: (torch.Tensor) input to network.\n        :param action: (torch.Tensor) actions whose entropy and log probability to evaluate.\n        :param available_actions: (torch.Tensor) denotes which actions are available to agent\n                                                              (if None, all actions available)\n        :param active_masks: (torch.Tensor) denotes whether an agent is active or dead.\n\n        :return action_log_probs: (torch.Tensor) log probabilities of the input actions.\n        :return dist_entropy: (torch.Tensor) action distribution entropy for the given inputs.\n        \"\"\"\n\n        if self.multi_discrete:\n            action = torch.transpose(action, 0, 1)\n            action_log_probs = []\n            dist_entropy = []\n            mu_collector = []\n            std_collector = []\n            probs_collector = []\n            for action_out, act in zip(self.action_outs, action):\n                action_logit = action_out(x)\n                mu = action_logit.mean\n                std = action_logit.stddev\n                action_log_probs.append(action_logit.log_probs(act))\n                mu_collector.append(mu)\n                std_collector.append(std)\n                probs_collector.append(action_logit.logits)\n                if active_masks is not None:\n                    dist_entropy.append((action_logit.entropy()*active_masks.squeeze(-1)).sum()/active_masks.sum())\n                else:\n                    dist_entropy.append(action_logit.entropy().mean())\n            action_mu = torch.cat(mu_collector,-1)\n            action_std = torch.cat(std_collector,-1)\n            all_probs = torch.cat(probs_collector,-1)\n            action_log_probs = torch.cat(action_log_probs, -1)\n            dist_entropy = torch.tensor(dist_entropy).mean()\n        \n        else:\n            action_logits = self.action_out(x, available_actions)\n            action_mu = action_logits.mean\n            action_std = action_logits.stddev\n            action_log_probs = action_logits.log_probs(action)\n            if self.action_type==\"Discrete\":\n                all_probs = 
action_logits.logits\n            else:\n                all_probs = None\n            if active_masks is not None:\n                if self.action_type==\"Discrete\":\n                    dist_entropy = (action_logits.entropy()*active_masks.squeeze(-1)).sum()/active_masks.sum()\n                else:\n                    dist_entropy = (action_logits.entropy()*active_masks).sum()/active_masks.sum()\n            else:\n                dist_entropy = action_logits.entropy().mean()\n        \n        return action_log_probs, dist_entropy, action_mu, action_std, all_probs\n"
  },
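  {
    "path": "examples/act_layer_demo.py",
    "content": "\"\"\"\nIllustrative sketch, not part of the original repository: minimal usage of\nACTLayer with a Discrete action space. Assumes `gym` is installed (it is a\ndependency of the training environments); the sizes below are arbitrary.\n\"\"\"\nimport torch\nfrom gym import spaces\n\nfrom algorithms.utils.act import ACTLayer\n\n# a 5-way discrete action head on top of 64-dim features\nact = ACTLayer(spaces.Discrete(5), inputs_dim=64, use_orthogonal=True, gain=0.01)\n\nx = torch.randn(8, 64)  # a batch of 8 feature vectors\nactions, action_log_probs = act(x)      # stochastic actions, shape (8, 1)\ngreedy, _ = act(x, deterministic=True)  # argmax actions\n\n# masking: forbid actions 3 and 4 for the whole batch\navail = torch.ones(8, 5)\navail[:, 3:] = 0\nmasked_actions, _ = act(x, available_actions=avail)\nassert masked_actions.max() < 3\n"
  },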
  {
    "path": "algorithms/utils/cnn.py",
    "content": "import torch.nn as nn\nfrom .util import init\n\n\"\"\"CNN Modules and utils.\"\"\"\n\nclass Flatten(nn.Module):\n    def forward(self, x):\n        return x.view(x.size(0), -1)\n\n\nclass CNNLayer(nn.Module):\n    def __init__(self, obs_shape, hidden_size, use_orthogonal, use_ReLU, kernel_size=3, stride=1):\n        super(CNNLayer, self).__init__()\n\n        active_func = [nn.Tanh(), nn.ReLU()][use_ReLU]\n        init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]\n        gain = nn.init.calculate_gain(['tanh', 'relu'][use_ReLU])\n\n        def init_(m):\n            return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain=gain)\n\n        input_channel = obs_shape[0]\n        input_width = obs_shape[1]\n        input_height = obs_shape[2]\n\n        self.cnn = nn.Sequential(\n            init_(nn.Conv2d(in_channels=input_channel,\n                            out_channels=hidden_size // 2,\n                            kernel_size=kernel_size,\n                            stride=stride)\n                  ),\n            active_func,\n            Flatten(),\n            init_(nn.Linear(hidden_size // 2 * (input_width - kernel_size + stride) * (input_height - kernel_size + stride),\n                            hidden_size)\n                  ),\n            active_func,\n            init_(nn.Linear(hidden_size, hidden_size)), active_func)\n\n    def forward(self, x):\n        x = x / 255.0\n        x = self.cnn(x)\n        return x\n\n\nclass CNNBase(nn.Module):\n    def __init__(self, args, obs_shape):\n        super(CNNBase, self).__init__()\n\n        self._use_orthogonal = args.use_orthogonal\n        self._use_ReLU = args.use_ReLU\n        self.hidden_size = args.hidden_size\n\n        self.cnn = CNNLayer(obs_shape, self.hidden_size, self._use_orthogonal, self._use_ReLU)\n\n    def forward(self, x):\n        x = self.cnn(x)\n        return x\n"
  },
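  {
    "path": "examples/cnn_layer_demo.py",
    "content": "\"\"\"\nIllustrative sketch, not part of the original repository: shapes through\nCNNLayer. With the default kernel_size=3, stride=1, a (C, W, H) observation\nis convolved once, flattened, and projected to `hidden_size`; the flattened\nwidth the linear layer expects is hidden_size // 2 * (W - 2) * (H - 2).\nThe shapes below are arbitrary.\n\"\"\"\nimport torch\n\nfrom algorithms.utils.cnn import CNNLayer\n\ncnn = CNNLayer(obs_shape=(3, 16, 16), hidden_size=64,\n               use_orthogonal=True, use_ReLU=True)\n\nobs = torch.randint(0, 256, (2, 3, 16, 16)).float()  # pixel observations\nout = cnn(obs)  # inputs are scaled by 1/255 inside forward\nprint(out.shape)  # torch.Size([2, 64])\n"
  },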
  {
    "path": "algorithms/utils/distributions.py",
    "content": "import torch\nimport torch.nn as nn\nfrom .util import init\n\n\"\"\"\nModify standard PyTorch distributions so they to make compatible with this codebase. \n\"\"\"\n\n#\n# Standardize distribution interfaces\n#\n\n# Categorical\nclass FixedCategorical(torch.distributions.Categorical):\n    def sample(self):\n        return super().sample().unsqueeze(-1)\n\n    def log_probs(self, actions):\n        return (\n            super()\n            .log_prob(actions.squeeze(-1))\n            .view(actions.size(0), -1)\n            .sum(-1)\n            .unsqueeze(-1)\n        )\n\n    def mode(self):\n        return self.probs.argmax(dim=-1, keepdim=True)\n\n\n# Normal\nclass FixedNormal(torch.distributions.Normal):\n    def log_probs(self, actions):\n        return super().log_prob(actions)\n        # return super().log_prob(actions).sum(-1, keepdim=True)\n\n    def entrop(self):\n        return super.entropy().sum(-1)\n\n    def mode(self):\n        return self.mean\n\n\n# Bernoulli\nclass FixedBernoulli(torch.distributions.Bernoulli):\n    def log_probs(self, actions):\n        return super.log_prob(actions).view(actions.size(0), -1).sum(-1).unsqueeze(-1)\n\n    def entropy(self):\n        return super().entropy().sum(-1)\n\n    def mode(self):\n        return torch.gt(self.probs, 0.5).float()\n\n\nclass Categorical(nn.Module):\n    def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01):\n        super(Categorical, self).__init__()\n        init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]\n        def init_(m): \n            return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain)\n\n        self.linear = init_(nn.Linear(num_inputs, num_outputs))\n\n    def forward(self, x, available_actions=None):\n        x = self.linear(x)\n        if available_actions is not None:\n            x[available_actions == 0] = -1e10\n        return FixedCategorical(logits=x)\n\n\n# class DiagGaussian(nn.Module):\n#     def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01):\n#         super(DiagGaussian, self).__init__()\n#\n#         init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]\n#         def init_(m):\n#             return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain)\n#\n#         self.fc_mean = init_(nn.Linear(num_inputs, num_outputs))\n#         self.logstd = AddBias(torch.zeros(num_outputs))\n#\n#     def forward(self, x, available_actions=None):\n#         action_mean = self.fc_mean(x)\n#\n#         #  An ugly hack for my KFAC implementation.\n#         zeros = torch.zeros(action_mean.size())\n#         if x.is_cuda:\n#             zeros = zeros.cuda()\n#\n#         action_logstd = self.logstd(zeros)\n#         return FixedNormal(action_mean, action_logstd.exp())\n\nclass DiagGaussian(nn.Module):\n    def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01, args=None):\n        super(DiagGaussian, self).__init__()\n\n        init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]\n\n        def init_(m):\n            return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain)\n\n        if args is not None:\n            self.std_x_coef = args.std_x_coef\n            self.std_y_coef = args.std_y_coef\n        else:\n            self.std_x_coef = 1.\n            self.std_y_coef = 0.5\n        self.fc_mean = init_(nn.Linear(num_inputs, num_outputs))\n        log_std = torch.ones(num_outputs) * self.std_x_coef\n       
 self.log_std = torch.nn.Parameter(log_std)\n\n    def forward(self, x, available_actions=None):\n        action_mean = self.fc_mean(x)\n        action_std = torch.sigmoid(self.log_std / self.std_x_coef) * self.std_y_coef\n        return FixedNormal(action_mean, action_std)\n\nclass Bernoulli(nn.Module):\n    def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01):\n        super(Bernoulli, self).__init__()\n        init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]\n        def init_(m): \n            return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain)\n        \n        self.linear = init_(nn.Linear(num_inputs, num_outputs))\n\n    def forward(self, x):\n        x = self.linear(x)\n        return FixedBernoulli(logits=x)\n\nclass AddBias(nn.Module):\n    def __init__(self, bias):\n        super(AddBias, self).__init__()\n        self._bias = nn.Parameter(bias.unsqueeze(1))\n\n    def forward(self, x):\n        if x.dim() == 2:\n            bias = self._bias.t().view(1, -1)\n        else:\n            bias = self._bias.t().view(1, -1, 1, 1)\n\n        return x + bias\n"
  },
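  {
    "path": "examples/distributions_demo.py",
    "content": "\"\"\"\nIllustrative sketch, not part of the original repository: the two heads most\nused by ACTLayer. Categorical masks unavailable actions by pushing their\nlogits to -1e10; DiagGaussian keeps the learned std state-independent and\nbounded in (0, std_y_coef) via sigmoid(log_std / std_x_coef) * std_y_coef.\nSizes below are arbitrary.\n\"\"\"\nimport torch\n\nfrom algorithms.utils.distributions import Categorical, DiagGaussian\n\nx = torch.randn(4, 64)\n\n# discrete head with action masking\ncat_head = Categorical(num_inputs=64, num_outputs=6)\navail = torch.ones(4, 6)\navail[:, 0] = 0  # action 0 unavailable everywhere\ndist = cat_head(x, available_actions=avail)\nprint(dist.probs[:, 0])  # ~0 probability on the masked action\n\n# continuous head: std is bounded by std_y_coef\ngauss_head = DiagGaussian(num_inputs=64, num_outputs=3)\nndist = gauss_head(x)\nprint(ndist.mean.shape, ndist.stddev.max() <= 0.5)  # default std_y_coef = 0.5\n"
  },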
  {
    "path": "algorithms/utils/mlp.py",
    "content": "import torch.nn as nn\nfrom .util import init, get_clones\n\n\"\"\"MLP modules.\"\"\"\n\nclass MLPLayer(nn.Module):\n    def __init__(self, input_dim, hidden_size, layer_N, use_orthogonal, use_ReLU):\n        super(MLPLayer, self).__init__()\n        self._layer_N = layer_N\n\n        active_func = [nn.Tanh(), nn.ReLU()][use_ReLU]\n        init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]\n        gain = nn.init.calculate_gain(['tanh', 'relu'][use_ReLU])\n\n        def init_(m):\n            return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain=gain)\n\n        self.fc1 = nn.Sequential(\n            init_(nn.Linear(input_dim, hidden_size)), active_func, nn.LayerNorm(hidden_size))\n        # self.fc_h = nn.Sequential(init_(\n        #     nn.Linear(hidden_size, hidden_size)), active_func, nn.LayerNorm(hidden_size))\n        # self.fc2 = get_clones(self.fc_h, self._layer_N)\n        self.fc2 = nn.ModuleList([nn.Sequential(init_(\n            nn.Linear(hidden_size, hidden_size)), active_func, nn.LayerNorm(hidden_size)) for i in range(self._layer_N)])\n\n    def forward(self, x):\n        x = self.fc1(x)\n        for i in range(self._layer_N):\n            x = self.fc2[i](x)\n        return x\n\n\nclass MLPBase(nn.Module):\n    def __init__(self, args, obs_shape, cat_self=True, attn_internal=False):\n        super(MLPBase, self).__init__()\n\n        self._use_feature_normalization = args.use_feature_normalization\n        self._use_orthogonal = args.use_orthogonal\n        self._use_ReLU = args.use_ReLU\n        self._stacked_frames = args.stacked_frames\n        self._layer_N = args.layer_N\n        self.hidden_size = args.hidden_size\n\n        obs_dim = obs_shape[0]\n\n        if self._use_feature_normalization:\n            self.feature_norm = nn.LayerNorm(obs_dim)\n\n        self.mlp = MLPLayer(obs_dim, self.hidden_size,\n                              self._layer_N, self._use_orthogonal, self._use_ReLU)\n\n    def forward(self, x):\n        if self._use_feature_normalization:\n            x = self.feature_norm(x)\n\n        x = self.mlp(x)\n\n        return x"
  },
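  {
    "path": "examples/mlp_base_demo.py",
    "content": "\"\"\"\nIllustrative sketch, not part of the original repository: building MLPBase\nwithout the full argument parser by passing a namespace with just the fields\nit reads. The field values below mirror the defaults in configs/config.py.\n\"\"\"\nfrom types import SimpleNamespace\n\nimport torch\n\nfrom algorithms.utils.mlp import MLPBase\n\nargs = SimpleNamespace(\n    use_feature_normalization=True,  # LayerNorm on the raw observation\n    use_orthogonal=True,\n    use_ReLU=True,\n    stacked_frames=1,\n    layer_N=1,                       # hidden layers after the input layer\n    hidden_size=64,\n)\n\nbase = MLPBase(args, obs_shape=(17,))\nout = base(torch.randn(32, 17))\nprint(out.shape)  # torch.Size([32, 64])\n"
  },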
  {
    "path": "algorithms/utils/rnn.py",
    "content": "import torch\nimport torch.nn as nn\n\n\"\"\"RNN modules.\"\"\"\n\n\nclass RNNLayer(nn.Module):\n    def __init__(self, inputs_dim, outputs_dim, recurrent_N, use_orthogonal):\n        super(RNNLayer, self).__init__()\n        self._recurrent_N = recurrent_N\n        self._use_orthogonal = use_orthogonal\n\n        self.rnn = nn.GRU(inputs_dim, outputs_dim, num_layers=self._recurrent_N)\n        for name, param in self.rnn.named_parameters():\n            if 'bias' in name:\n                nn.init.constant_(param, 0)\n            elif 'weight' in name:\n                if self._use_orthogonal:\n                    nn.init.orthogonal_(param)\n                else:\n                    nn.init.xavier_uniform_(param)\n        self.norm = nn.LayerNorm(outputs_dim)\n\n    def forward(self, x, hxs, masks):\n        if x.size(0) == hxs.size(0):\n            x, hxs = self.rnn(x.unsqueeze(0),\n                              (hxs * masks.repeat(1, self._recurrent_N).unsqueeze(-1)).transpose(0, 1).contiguous())\n            x = x.squeeze(0)\n            hxs = hxs.transpose(0, 1)\n        else:\n            # x is a (T, N, -1) tensor that has been flatten to (T * N, -1)\n            N = hxs.size(0)\n            T = int(x.size(0) / N)\n\n            # unflatten\n            x = x.view(T, N, x.size(1))\n\n            # Same deal with masks\n            masks = masks.view(T, N)\n\n            # Let's figure out which steps in the sequence have a zero for any agent\n            # We will always assume t=0 has a zero in it as that makes the logic cleaner\n            has_zeros = ((masks[1:] == 0.0)\n                         .any(dim=-1)\n                         .nonzero()\n                         .squeeze()\n                         .cpu())\n\n            # +1 to correct the masks[1:]\n            if has_zeros.dim() == 0:\n                # Deal with scalar\n                has_zeros = [has_zeros.item() + 1]\n            else:\n                has_zeros = (has_zeros + 1).numpy().tolist()\n\n            # add t=0 and t=T to the list\n            has_zeros = [0] + has_zeros + [T]\n\n            hxs = hxs.transpose(0, 1)\n\n            outputs = []\n            for i in range(len(has_zeros) - 1):\n                # We can now process steps that don't have any zeros in masks together!\n                # This is much faster\n                start_idx = has_zeros[i]\n                end_idx = has_zeros[i + 1]\n                temp = (hxs * masks[start_idx].view(1, -1, 1).repeat(self._recurrent_N, 1, 1)).contiguous()\n                rnn_scores, hxs = self.rnn(x[start_idx:end_idx], temp)\n                outputs.append(rnn_scores)\n\n            # assert len(outputs) == T\n            # x is a (T, N, -1) tensor\n            x = torch.cat(outputs, dim=0)\n\n            # flatten\n            x = x.reshape(T * N, -1)\n            hxs = hxs.transpose(0, 1)\n\n        x = self.norm(x)\n        return x, hxs\n"
  },
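  {
    "path": "examples/rnn_layer_demo.py",
    "content": "\"\"\"\nIllustrative sketch, not part of the original repository: the two code paths\nof RNNLayer.forward. During rollout, x has one row per env and the hidden\nstate is carried explicitly; during training, a whole (T, N) chunk arrives\nflattened to (T * N, -1) and masks with zeros mark episode boundaries where\nthe hidden state must be reset. Sizes below are arbitrary.\n\"\"\"\nimport torch\n\nfrom algorithms.utils.rnn import RNNLayer\n\nN, T, H = 8, 10, 64\nrnn = RNNLayer(inputs_dim=H, outputs_dim=H, recurrent_N=1, use_orthogonal=True)\n\n# rollout step: one row per parallel env\nx = torch.randn(N, H)\nhxs = torch.zeros(N, 1, H)   # (N, recurrent_N, hidden)\nmasks = torch.ones(N, 1)     # 0 would reset the hidden state\ny, hxs = rnn(x, hxs, masks)\nprint(y.shape, hxs.shape)    # (8, 64) (8, 1, 64)\n\n# training chunk: T steps flattened together\nx_chunk = torch.randn(T * N, H)\nmasks_chunk = torch.ones(T * N, 1)\nmasks_chunk[3 * N:4 * N] = 0  # an episode boundary at t=3 for all envs\ny_chunk, _ = rnn(x_chunk, torch.zeros(N, 1, H), masks_chunk)\nprint(y_chunk.shape)          # (80, 64)\n"
  },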
  {
    "path": "algorithms/utils/util.py",
    "content": "import copy\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\n\ndef init(module, weight_init, bias_init, gain=1):\n    weight_init(module.weight.data, gain=gain)\n    bias_init(module.bias.data)\n    return module\n\ndef get_clones(module, N):\n    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])\n\ndef check(input):\n    output = torch.from_numpy(input) if type(input) == np.ndarray else input\n    return output\n"
  },
  {
    "path": "configs/config.py",
    "content": "import argparse\n\ndef get_config():\n    \"\"\"\n    The configuration parser for common hyperparameters of all environment. \n    Please reach each `scripts/train/<env>_runner.py` file to find private hyperparameters\n    only used in <env>.\n\n    Prepare parameters:\n        --algorithm_name <algorithm_name>\n            specifiy the algorithm, including `[\"happo\", \"hatrpo\"]`\n        --experiment_name <str>\n            an identifier to distinguish different experiment.\n        --seed <int>\n            set seed for numpy and torch \n        --seed_specify\n            by default True Random or specify seed for numpy/torch\n        --running_id <int>\n            the running index of experiment (default=1)\n        --cuda\n            by default True, will use GPU to train; or else will use CPU; \n        --cuda_deterministic\n            by default, make sure random seed effective. if set, bypass such function.\n        --n_training_threads <int>\n            number of training threads working in parallel. by default 1\n        --n_rollout_threads <int>\n            number of parallel envs for training rollout. by default 32\n        --n_eval_rollout_threads <int>\n            number of parallel envs for evaluating rollout. by default 1\n        --n_render_rollout_threads <int>\n            number of parallel envs for rendering, could only be set as 1 for some environments.\n        --num_env_steps <int>\n            number of env steps to train (default: 10e6)\n\n    \n    Env parameters:\n        --env_name <str>\n            specify the name of environment\n        --use_obs_instead_of_state\n            [only for some env] by default False, will use global state; or else will use concatenated local obs.\n    \n    Replay Buffer parameters:\n        --episode_length <int>\n            the max length of episode in the buffer. \n    \n    Network parameters:\n        --share_policy\n            by default True, all agents will share the same network; set to make training agents use different policies. \n        --use_centralized_V\n            by default True, use centralized training mode; or else will decentralized training mode.\n        --stacked_frames <int>\n            Number of input frames which should be stack together.\n        --hidden_size <int>\n            Dimension of hidden layers for actor/critic networks\n        --layer_N <int>\n            Number of layers for actor/critic networks\n        --use_ReLU\n            by default True, will use ReLU. or else will use Tanh.\n        --use_popart\n            by default True, use running mean and std to normalize rewards. \n        --use_feature_normalization\n            by default True, apply layernorm to normalize inputs. \n        --use_orthogonal\n            by default True, use Orthogonal initialization for weights and 0 initialization for biases. or else, will use xavier uniform inilialization.\n        --gain\n            by default 0.01, use the gain # of last action layer\n        --use_naive_recurrent_policy\n            by default False, use the whole trajectory to calculate hidden states.\n        --use_recurrent_policy\n            by default, use Recurrent Policy. 
If set, do not use.\n        --recurrent_N <int>\n            The number of recurrent layers ( default 1).\n        --data_chunk_length <int>\n            Time length of chunks used to train a recurrent_policy, default 10.\n    \n    Optimizer parameters:\n        --lr <float>\n            learning rate parameter,  (default: 5e-4, fixed).\n        --critic_lr <float>\n            learning rate of critic  (default: 5e-4, fixed)\n        --opti_eps <float>\n            RMSprop optimizer epsilon (default: 1e-5)\n        --weight_decay <float>\n            coefficience of weight decay (default: 0)\n    \n    TRPO parameters:\n        --kl_threshold <float>\n            the threshold of kl-divergence (default: 0.01)\n        --ls_step <int> \n            the step of line search (default: 10)\n        --accept_ratio <float>\n            accept ratio of loss improve (default: 0.5)\n    \n    PPO parameters:\n        --ppo_epoch <int>\n            number of ppo epochs (default: 15)\n        --use_clipped_value_loss \n            by default, clip loss value. If set, do not clip loss value.\n        --clip_param <float>\n            ppo clip parameter (default: 0.2)\n        --num_mini_batch <int>\n            number of batches for ppo (default: 1)\n        --entropy_coef <float>\n            entropy term coefficient (default: 0.01)\n        --use_max_grad_norm \n            by default, use max norm of gradients. If set, do not use.\n        --max_grad_norm <float>\n            max norm of gradients (default: 0.5)\n        --use_gae\n            by default, use generalized advantage estimation. If set, do not use gae.\n        --gamma <float>\n            discount factor for rewards (default: 0.99)\n        --gae_lambda <float>\n            gae lambda parameter (default: 0.95)\n        --use_proper_time_limits\n            by default, the return value does consider limits of time. If set, compute returns with considering time limits factor.\n        --use_huber_loss\n            by default, use huber loss. If set, do not use huber loss.\n        --use_value_active_masks\n            by default True, whether to mask useless data in value loss.  \n        --huber_delta <float>\n            coefficient of huber loss.  \n\n    \n    Run parameters：\n        --use_linear_lr_decay\n            by default, do not apply linear decay to learning rate. If set, use a linear schedule on the learning rate\n        --save_interval <int>\n            time duration between contiunous twice models saving.\n        --log_interval <int>\n            time duration between contiunous twice log printing.\n        --model_dir <str>\n            by default None. set the path to pretrained model.\n\n    Eval parameters:\n        --use_eval\n            by default, do not start evaluation. If set`, start evaluation alongside with training.\n        --eval_interval <int>\n            time duration between contiunous twice evaluation progress.\n        --eval_episodes <int>\n            number of episodes of a single evaluation.\n    \n    Render parameters:\n        --save_gifs\n            by default, do not save render video. If set, save video.\n        --use_render\n            by default, do not render the env during training. If set, start render. 
Note: something, the environment has internal render process which is not controlled by this hyperparam.\n        --render_episodes <int>\n            the number of episodes to render a given env\n        --ifi <float>\n            the play interval of each rendered image in saved video.\n    \n    Pretrained parameters:\n        \n    \"\"\"\n    parser = argparse.ArgumentParser(description='onpolicy_algorithm', formatter_class=argparse.RawDescriptionHelpFormatter)\n\n    # prepare parameters\n    parser.add_argument(\"--algorithm_name\", type=str,\n                        default=' ', choices=[\"happo\",\"hatrpo\"])\n    parser.add_argument(\"--experiment_name\", type=str, \n                        default=\"check\", help=\"an identifier to distinguish different experiment.\")\n    parser.add_argument(\"--seed\", type=int, \n                        default=1, help=\"Random seed for numpy/torch\")\n    parser.add_argument(\"--seed_specify\", action=\"store_false\",\n                        default=True, help=\"Random or specify seed for numpy/torch\")\n    parser.add_argument(\"--running_id\", type=int, \n                        default=1, help=\"the running index of experiment\")\n    parser.add_argument(\"--cuda\", action='store_false', \n                        default=True, help=\"by default True, will use GPU to train; or else will use CPU;\")\n    parser.add_argument(\"--cuda_deterministic\", action='store_false', \n                        default=True, help=\"by default, make sure random seed effective. if set, bypass such function.\")\n    parser.add_argument(\"--n_training_threads\", type=int,\n                        default=1, help=\"Number of torch threads for training\")\n    parser.add_argument(\"--n_rollout_threads\", type=int, \n                        default=32, help=\"Number of parallel envs for training rollouts\")\n    parser.add_argument(\"--n_eval_rollout_threads\", type=int, \n                        default=1, help=\"Number of parallel envs for evaluating rollouts\")\n    parser.add_argument(\"--n_render_rollout_threads\", type=int, \n                        default=1, help=\"Number of parallel envs for rendering rollouts\")\n    parser.add_argument(\"--num_env_steps\", type=int, \n                        default=10e6, help='Number of environment steps to train (default: 10e6)')\n    parser.add_argument(\"--user_name\", type=str, \n                        default='marl',help=\"[for wandb usage], to specify user's name for simply collecting training data.\")\n    # env parameters\n    parser.add_argument(\"--env_name\", type=str, \n                        default='StarCraft2', help=\"specify the name of environment\")\n    parser.add_argument(\"--use_obs_instead_of_state\", action='store_true',\n                        default=False, help=\"Whether to use global state or concatenated obs\")\n\n    # replay buffer parameters\n    parser.add_argument(\"--episode_length\", type=int,\n                        default=200, help=\"Max length for any episode\")\n\n    # network parameters\n    parser.add_argument(\"--share_policy\", action='store_false',\n                        default=True, help='Whether agent share the same policy')\n    parser.add_argument(\"--use_centralized_V\", action='store_false',\n                        default=True, help=\"Whether to use centralized V function\")\n    parser.add_argument(\"--stacked_frames\", type=int, \n                        default=1, help=\"Dimension of hidden layers for actor/critic networks\")\n    
parser.add_argument(\"--use_stacked_frames\", action='store_true',\n                        default=False, help=\"Whether to use stacked_frames\")\n    parser.add_argument(\"--hidden_size\", type=int, \n                        default=64, help=\"Dimension of hidden layers for actor/critic networks\") \n    parser.add_argument(\"--layer_N\", type=int, \n                        default=1, help=\"Number of layers for actor/critic networks\")\n    parser.add_argument(\"--use_ReLU\", action='store_false',\n                        default=True, help=\"Whether to use ReLU\")\n    parser.add_argument(\"--use_popart\", action='store_false', \n                        default=True, help=\"by default True, use running mean and std to normalize rewards.\")\n    parser.add_argument(\"--use_valuenorm\", action='store_false', \n                        default=True, help=\"by default True, use running mean and std to normalize rewards.\")\n    parser.add_argument(\"--use_feature_normalization\", action='store_false',\n                        default=True, help=\"Whether to apply layernorm to the inputs\")\n    parser.add_argument(\"--use_orthogonal\", action='store_false', \n                        default=True, help=\"Whether to use Orthogonal initialization for weights and 0 initialization for biases\")\n    parser.add_argument(\"--gain\", type=float, \n                        default=0.01, help=\"The gain # of last action layer\")\n\n    # recurrent parameters\n    parser.add_argument(\"--use_naive_recurrent_policy\", action='store_true',\n                        default=False, help='Whether to use a naive recurrent policy')\n    parser.add_argument(\"--use_recurrent_policy\", action='store_true',\n                        default=False, help='use a recurrent policy')\n    parser.add_argument(\"--recurrent_N\", type=int, \n                        default=1, help=\"The number of recurrent layers.\")\n    parser.add_argument(\"--data_chunk_length\", type=int, \n                        default=10, help=\"Time length of chunks used to train a recurrent_policy\")\n    \n    # optimizer parameters\n    parser.add_argument(\"--lr\", type=float, \n                        default=5e-4, help='learning rate (default: 5e-4)')\n    parser.add_argument(\"--critic_lr\", type=float, \n                        default=5e-4, help='critic learning rate (default: 5e-4)')\n    parser.add_argument(\"--opti_eps\", type=float, \n                        default=1e-5, help='RMSprop optimizer epsilon (default: 1e-5)')\n    parser.add_argument(\"--weight_decay\", type=float, default=0)\n    parser.add_argument(\"--std_x_coef\", type=float, default=1)\n    parser.add_argument(\"--std_y_coef\", type=float, default=0.5)\n\n\n    # trpo parameters\n    parser.add_argument(\"--kl_threshold\", type=float, \n                        default=0.01, help='the threshold of kl-divergence (default: 0.01)')\n    parser.add_argument(\"--ls_step\", type=int, \n                        default=10, help='number of line search (default: 10)')\n    parser.add_argument(\"--accept_ratio\", type=float, \n                        default=0.5, help='accept ratio of loss improve (default: 0.5)')\n\n    # ppo parameters\n    parser.add_argument(\"--ppo_epoch\", type=int, \n                        default=15, help='number of ppo epochs (default: 15)')\n    parser.add_argument(\"--use_clipped_value_loss\", action='store_false', \n                        default=True, help=\"by default, clip loss value. 
If set, do not clip loss value.\")\n    parser.add_argument(\"--clip_param\", type=float, \n                        default=0.2, help='ppo clip parameter (default: 0.2)')\n    parser.add_argument(\"--num_mini_batch\", type=int, \n                        default=1, help='number of batches for ppo (default: 1)')\n    parser.add_argument(\"--entropy_coef\", type=float, \n                        default=0.01, help='entropy term coefficient (default: 0.01)')\n    parser.add_argument(\"--value_loss_coef\", type=float,\n                        default=1, help='value loss coefficient (default: 0.5)')\n    parser.add_argument(\"--use_max_grad_norm\", action='store_false', \n                        default=True, help=\"by default, use max norm of gradients. If set, do not use.\")\n    parser.add_argument(\"--max_grad_norm\", type=float, \n                        default=10.0, help='max norm of gradients (default: 0.5)')\n    parser.add_argument(\"--use_gae\", action='store_false',\n                        default=True, help='use generalized advantage estimation')\n    parser.add_argument(\"--gamma\", type=float, default=0.99,\n                        help='discount factor for rewards (default: 0.99)')\n    parser.add_argument(\"--gae_lambda\", type=float, default=0.95,\n                        help='gae lambda parameter (default: 0.95)')\n    parser.add_argument(\"--use_proper_time_limits\", action='store_true',\n                        default=False, help='compute returns taking into account time limits')\n    parser.add_argument(\"--use_huber_loss\", action='store_false', \n                        default=True, help=\"by default, use huber loss. If set, do not use huber loss.\")\n    parser.add_argument(\"--use_value_active_masks\", action='store_false', \n                        default=True, help=\"by default True, whether to mask useless data in value loss.\")\n    parser.add_argument(\"--use_policy_active_masks\", action='store_false', \n                        default=True, help=\"by default True, whether to mask useless data in policy loss.\")\n    parser.add_argument(\"--huber_delta\", type=float, \n                        default=10.0, help=\" coefficience of huber loss.\")\n\n    # run parameters\n    parser.add_argument(\"--use_linear_lr_decay\", action='store_true',\n                        default=False, help='use a linear schedule on the learning rate')\n    parser.add_argument(\"--save_interval\", type=int, \n                        default=1, help=\"time duration between contiunous twice models saving.\")\n    parser.add_argument(\"--log_interval\", type=int, \n                        default=5, help=\"time duration between contiunous twice log printing.\")\n    parser.add_argument(\"--model_dir\", type=str, \n                        default=None, help=\"by default None. set the path to pretrained model.\")\n\n    # eval parameters\n    parser.add_argument(\"--use_eval\", action='store_true', \n                        default=False, help=\"by default, do not start evaluation. 
If set`, start evaluation alongside with training.\")\n    parser.add_argument(\"--eval_interval\", type=int, \n                        default=25, help=\"time duration between contiunous twice evaluation progress.\")\n    parser.add_argument(\"--eval_episodes\", type=int, \n                        default=32, help=\"number of episodes of a single evaluation.\")\n\n    # render parameters\n    parser.add_argument(\"--save_gifs\", action='store_true', \n                        default=False, help=\"by default, do not save render video. If set, save video.\")\n    parser.add_argument(\"--use_render\", action='store_true', \n                        default=False, help=\"by default, do not render the env during training. If set, start render. Note: something, the environment has internal render process which is not controlled by this hyperparam.\")\n    parser.add_argument(\"--render_episodes\", type=int, \n                        default=5, help=\"the number of episodes to render a given env\")\n    parser.add_argument(\"--ifi\", type=float, \n                        default=0.1, help=\"the play interval of each rendered image in saved video.\")\n\n    return parser\n"
  },
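  {
    "path": "examples/config_demo.py",
    "content": "\"\"\"\nIllustrative sketch, not part of the original repository: how the shared\nparser from configs/config.py is typically consumed. Boolean flags defined\nwith action='store_false' (e.g. --cuda) are True by default and passing the\nflag turns them off; parse_known_args lets env-specific scripts add their\nown private arguments on top.\n\"\"\"\nfrom configs.config import get_config\n\nparser = get_config()\nall_args, unknown = parser.parse_known_args([\n    '--algorithm_name', 'hatrpo',\n    '--env_name', 'StarCraft2',\n    '--kl_threshold', '0.01',\n    '--cuda',  # store_false: passing the flag actually disables CUDA\n])\nprint(all_args.algorithm_name, all_args.cuda)  # hatrpo False\n"
  },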
  {
    "path": "envs/__init__.py",
    "content": "\nimport socket\nfrom absl import flags\nFLAGS = flags.FLAGS\nFLAGS(['train_sc.py'])\n\n\n"
  },
  {
    "path": "envs/env_wrappers.py",
    "content": "\"\"\"\nModified from OpenAI Baselines code to work with multi-agent envs\n\"\"\"\nimport numpy as np\nimport torch\nfrom multiprocessing import Process, Pipe\nfrom abc import ABC, abstractmethod\nfrom utils.util import tile_images\n\nclass CloudpickleWrapper(object):\n    \"\"\"\n    Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)\n    \"\"\"\n\n    def __init__(self, x):\n        self.x = x\n\n    def __getstate__(self):\n        import cloudpickle\n        return cloudpickle.dumps(self.x)\n\n    def __setstate__(self, ob):\n        import pickle\n        self.x = pickle.loads(ob)\n\n\nclass ShareVecEnv(ABC):\n    \"\"\"\n    An abstract asynchronous, vectorized environment.\n    Used to batch data from multiple copies of an environment, so that\n    each observation becomes an batch of observations, and expected action is a batch of actions to\n    be applied per-environment.\n    \"\"\"\n    closed = False\n    viewer = None\n\n    metadata = {\n        'render.modes': ['human', 'rgb_array']\n    }\n\n    def __init__(self, num_envs, observation_space, share_observation_space, action_space):\n        self.num_envs = num_envs\n        self.observation_space = observation_space\n        self.share_observation_space = share_observation_space\n        self.action_space = action_space\n\n    @abstractmethod\n    def reset(self):\n        \"\"\"\n        Reset all the environments and return an array of\n        observations, or a dict of observation arrays.\n\n        If step_async is still doing work, that work will\n        be cancelled and step_wait() should not be called\n        until step_async() is invoked again.\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def step_async(self, actions):\n        \"\"\"\n        Tell all the environments to start taking a step\n        with the given actions.\n        Call step_wait() to get the results of the step.\n\n        You should not call this if a step_async run is\n        already pending.\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def step_wait(self):\n        \"\"\"\n        Wait for the step taken with step_async().\n\n        Returns (obs, rews, dones, infos):\n         - obs: an array of observations, or a dict of\n                arrays of observations.\n         - rews: an array of rewards\n         - dones: an array of \"episode done\" booleans\n         - infos: a sequence of info objects\n        \"\"\"\n        pass\n\n    def close_extras(self):\n        \"\"\"\n        Clean up the  extra resources, beyond what's in this base class.\n        Only runs when not self.closed.\n        \"\"\"\n        pass\n\n    def close(self):\n        if self.closed:\n            return\n        if self.viewer is not None:\n            self.viewer.close()\n        self.close_extras()\n        self.closed = True\n\n    def step(self, actions):\n        \"\"\"\n        Step the environments synchronously.\n\n        This is available for backwards compatibility.\n        \"\"\"\n        self.step_async(actions)\n        return self.step_wait()\n\n    def render(self, mode='human'):\n        imgs = self.get_images()\n        bigimg = tile_images(imgs)\n        if mode == 'human':\n            self.get_viewer().imshow(bigimg)\n            return self.get_viewer().isopen\n        elif mode == 'rgb_array':\n            return bigimg\n        else:\n            raise NotImplementedError\n\n    def get_images(self):\n        \"\"\"\n        Return RGB images from each 
environment\n        \"\"\"\n        raise NotImplementedError\n\n    @property\n    def unwrapped(self):\n        if isinstance(self, VecEnvWrapper):\n            return self.venv.unwrapped\n        else:\n            return self\n\n    def get_viewer(self):\n        if self.viewer is None:\n            from gym.envs.classic_control import rendering\n            self.viewer = rendering.SimpleImageViewer()\n        return self.viewer\n\n\ndef worker(remote, parent_remote, env_fn_wrapper):\n    parent_remote.close()\n    env = env_fn_wrapper.x()\n    while True:\n        cmd, data = remote.recv()\n        if cmd == 'step':\n            ob, reward, done, info = env.step(data)\n            if 'bool' in done.__class__.__name__:\n                if done:\n                    ob = env.reset()\n            else:\n                if np.all(done):\n                    ob = env.reset()\n\n            remote.send((ob, reward, done, info))\n        elif cmd == 'reset':\n            ob = env.reset()\n            remote.send((ob))\n        elif cmd == 'render':\n            if data == \"rgb_array\":\n                fr = env.render(mode=data)\n                remote.send(fr)\n            elif data == \"human\":\n                env.render(mode=data)\n        elif cmd == 'reset_task':\n            ob = env.reset_task()\n            remote.send(ob)\n        elif cmd == 'close':\n            env.close()\n            remote.close()\n            break\n        elif cmd == 'get_spaces':\n            remote.send((env.observation_space, env.share_observation_space, env.action_space))\n        else:\n            raise NotImplementedError\n\n\nclass GuardSubprocVecEnv(ShareVecEnv):\n    def __init__(self, env_fns, spaces=None):\n        \"\"\"\n        envs: list of gym environments to run in subprocesses\n        \"\"\"\n        self.waiting = False\n        self.closed = False\n        nenvs = len(env_fns)\n        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])\n        self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))\n                   for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]\n        for p in self.ps:\n            p.daemon = False  # could cause zombie process\n            p.start()\n        for remote in self.work_remotes:\n            remote.close()\n\n        self.remotes[0].send(('get_spaces', None))\n        observation_space, share_observation_space, action_space = self.remotes[0].recv()\n        ShareVecEnv.__init__(self, len(env_fns), observation_space,\n                             share_observation_space, action_space)\n\n    def step_async(self, actions):\n\n        for remote, action in zip(self.remotes, actions):\n            remote.send(('step', action))\n        self.waiting = True\n\n    def step_wait(self):\n        results = [remote.recv() for remote in self.remotes]\n        self.waiting = False\n        obs, rews, dones, infos = zip(*results)\n        return np.stack(obs), np.stack(rews), np.stack(dones), infos\n\n    def reset(self):\n        for remote in self.remotes:\n            remote.send(('reset', None))\n        obs = [remote.recv() for remote in self.remotes]\n        return np.stack(obs)\n\n    def reset_task(self):\n        for remote in self.remotes:\n            remote.send(('reset_task', None))\n        return np.stack([remote.recv() for remote in self.remotes])\n\n    def close(self):\n        if self.closed:\n            return\n        if self.waiting:\n            
for remote in self.remotes:\n                remote.recv()\n        for remote in self.remotes:\n            remote.send(('close', None))\n        for p in self.ps:\n            p.join()\n        self.closed = True\n\n\nclass SubprocVecEnv(ShareVecEnv):\n    def __init__(self, env_fns, spaces=None):\n        \"\"\"\n        envs: list of gym environments to run in subprocesses\n        \"\"\"\n        self.waiting = False\n        self.closed = False\n        nenvs = len(env_fns)\n        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])\n        self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))\n                   for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]\n        for p in self.ps:\n            p.daemon = True  # if the main process crashes, we should not cause things to hang\n            p.start()\n        for remote in self.work_remotes:\n            remote.close()\n\n        self.remotes[0].send(('get_spaces', None))\n        observation_space, share_observation_space, action_space = self.remotes[0].recv()\n        ShareVecEnv.__init__(self, len(env_fns), observation_space,\n                             share_observation_space, action_space)\n\n    def step_async(self, actions):\n        for remote, action in zip(self.remotes, actions):\n            remote.send(('step', action))\n        self.waiting = True\n\n    def step_wait(self):\n        results = [remote.recv() for remote in self.remotes]\n        self.waiting = False\n        obs, rews, dones, infos = zip(*results)\n        return np.stack(obs), np.stack(rews), np.stack(dones), infos\n\n    def reset(self):\n        for remote in self.remotes:\n            remote.send(('reset', None))\n        obs = [remote.recv() for remote in self.remotes]\n        return np.stack(obs)\n\n\n    def reset_task(self):\n        for remote in self.remotes:\n            remote.send(('reset_task', None))\n        return np.stack([remote.recv() for remote in self.remotes])\n\n    def close(self):\n        if self.closed:\n            return\n        if self.waiting:\n            for remote in self.remotes:\n                remote.recv()\n        for remote in self.remotes:\n            remote.send(('close', None))\n        for p in self.ps:\n            p.join()\n        self.closed = True\n\n    def render(self, mode=\"rgb_array\"):\n        for remote in self.remotes:\n            remote.send(('render', mode))\n        if mode == \"rgb_array\":   \n            frame = [remote.recv() for remote in self.remotes]\n            return np.stack(frame) \n\n\ndef shareworker(remote, parent_remote, env_fn_wrapper):\n    parent_remote.close()\n    env = env_fn_wrapper.x()\n    while True:\n        cmd, data = remote.recv()\n        if cmd == 'step':\n            ob, s_ob, reward, done, info, available_actions = env.step(data)\n            if 'bool' in done.__class__.__name__:\n                if done:\n                    ob, s_ob, available_actions = env.reset()\n            else:\n                if np.all(done):\n                    ob, s_ob, available_actions = env.reset()\n\n            remote.send((ob, s_ob, reward, done, info, available_actions))\n        elif cmd == 'reset':\n            ob, s_ob, available_actions = env.reset()\n            remote.send((ob, s_ob, available_actions))\n        elif cmd == 'reset_task':\n            ob = env.reset_task()\n            remote.send(ob)\n        elif cmd == 'render':\n            if data == 
\"rgb_array\":\n                fr = env.render(mode=data)\n                remote.send(fr)\n            elif data == \"human\":\n                env.render(mode=data)\n        elif cmd == 'close':\n            env.close()\n            remote.close()\n            break\n        elif cmd == 'get_spaces':\n            remote.send(\n                (env.observation_space, env.share_observation_space, env.action_space))\n        elif cmd == 'render_vulnerability':\n            fr = env.render_vulnerability(data)\n            remote.send((fr))\n        elif cmd == 'get_num_agents':\n            remote.send((env.n_agents))\n        else:\n            raise NotImplementedError\n\n\nclass ShareSubprocVecEnv(ShareVecEnv):\n    def __init__(self, env_fns, spaces=None):\n        \"\"\"\n        envs: list of gym environments to run in subprocesses\n        \"\"\"\n        self.waiting = False\n        self.closed = False\n        nenvs = len(env_fns)\n        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])\n        self.ps = [Process(target=shareworker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))\n                   for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]\n        for p in self.ps:\n            p.daemon = True  # if the main process crashes, we should not cause things to hang\n            p.start()\n        for remote in self.work_remotes:\n            remote.close()\n        self.remotes[0].send(('get_num_agents', None))\n        self.n_agents = self.remotes[0].recv()\n        self.remotes[0].send(('get_spaces', None))\n        observation_space, share_observation_space, action_space = self.remotes[0].recv(\n        )\n        ShareVecEnv.__init__(self, len(env_fns), observation_space,\n                             share_observation_space, action_space)\n\n    def step_async(self, actions):\n        for remote, action in zip(self.remotes, actions):\n            remote.send(('step', action))\n        self.waiting = True\n\n    def step_wait(self):\n        results = [remote.recv() for remote in self.remotes]\n        self.waiting = False\n        obs, share_obs, rews, dones, infos, available_actions = zip(*results)\n        return np.stack(obs), np.stack(share_obs), np.stack(rews), np.stack(dones), infos, np.stack(available_actions)\n\n    def reset(self):\n        for remote in self.remotes:\n            remote.send(('reset', None))\n        results = [remote.recv() for remote in self.remotes]\n        obs, share_obs, available_actions = zip(*results)\n        return np.stack(obs), np.stack(share_obs), np.stack(available_actions)\n\n    def reset_task(self):\n        for remote in self.remotes:\n            remote.send(('reset_task', None))\n        return np.stack([remote.recv() for remote in self.remotes])\n\n    def close(self):\n        if self.closed:\n            return\n        if self.waiting:\n            for remote in self.remotes:\n                remote.recv()\n        for remote in self.remotes:\n            remote.send(('close', None))\n        for p in self.ps:\n            p.join()\n        self.closed = True\n\n\ndef choosesimpleworker(remote, parent_remote, env_fn_wrapper):\n    parent_remote.close()\n    env = env_fn_wrapper.x()\n    while True:\n        cmd, data = remote.recv()\n        if cmd == 'step':\n            ob, reward, done, info = env.step(data)\n            remote.send((ob, reward, done, info))\n        elif cmd == 'reset':\n            ob = env.reset(data)\n            remote.send((ob))\n 
       elif cmd == 'reset_task':\n            ob = env.reset_task()\n            remote.send(ob)\n        elif cmd == 'close':\n            env.close()\n            remote.close()\n            break\n        elif cmd == 'render':\n            if data == \"rgb_array\":\n                fr = env.render(mode=data)\n                remote.send(fr)\n            elif data == \"human\":\n                env.render(mode=data)\n        elif cmd == 'get_spaces':\n            remote.send(\n                (env.observation_space, env.share_observation_space, env.action_space))\n        else:\n            raise NotImplementedError\n\n\nclass ChooseSimpleSubprocVecEnv(ShareVecEnv):\n    def __init__(self, env_fns, spaces=None):\n        \"\"\"\n        envs: list of gym environments to run in subprocesses\n        \"\"\"\n        self.waiting = False\n        self.closed = False\n        nenvs = len(env_fns)\n        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])\n        self.ps = [Process(target=choosesimpleworker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))\n                   for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]\n        for p in self.ps:\n            p.daemon = True  # if the main process crashes, we should not cause things to hang\n            p.start()\n        for remote in self.work_remotes:\n            remote.close()\n        self.remotes[0].send(('get_spaces', None))\n        observation_space, share_observation_space, action_space = self.remotes[0].recv()\n        ShareVecEnv.__init__(self, len(env_fns), observation_space,\n                             share_observation_space, action_space)\n\n    def step_async(self, actions):\n        for remote, action in zip(self.remotes, actions):\n            remote.send(('step', action))\n        self.waiting = True\n\n    def step_wait(self):\n        results = [remote.recv() for remote in self.remotes]\n        self.waiting = False\n        obs, rews, dones, infos = zip(*results)\n        return np.stack(obs), np.stack(rews), np.stack(dones), infos\n\n    def reset(self, reset_choose):\n        for remote, choose in zip(self.remotes, reset_choose):\n            remote.send(('reset', choose))\n        obs = [remote.recv() for remote in self.remotes]\n        return np.stack(obs)\n\n    def render(self, mode=\"rgb_array\"):\n        for remote in self.remotes:\n            remote.send(('render', mode))\n        if mode == \"rgb_array\":   \n            frame = [remote.recv() for remote in self.remotes]\n            return np.stack(frame)\n\n    def reset_task(self):\n        for remote in self.remotes:\n            remote.send(('reset_task', None))\n        return np.stack([remote.recv() for remote in self.remotes])\n\n    def close(self):\n        if self.closed:\n            return\n        if self.waiting:\n            for remote in self.remotes:\n                remote.recv()\n        for remote in self.remotes:\n            remote.send(('close', None))\n        for p in self.ps:\n            p.join()\n        self.closed = True\n\n\ndef chooseworker(remote, parent_remote, env_fn_wrapper):\n    parent_remote.close()\n    env = env_fn_wrapper.x()\n    while True:\n        cmd, data = remote.recv()\n        if cmd == 'step':\n            ob, s_ob, reward, done, info, available_actions = env.step(data)\n            remote.send((ob, s_ob, reward, done, info, available_actions))\n        elif cmd == 'reset':\n            ob, s_ob, available_actions = env.reset(data)\n   
         remote.send((ob, s_ob, available_actions))\n        elif cmd == 'reset_task':\n            ob = env.reset_task()\n            remote.send(ob)\n        elif cmd == 'close':\n            env.close()\n            remote.close()\n            break\n        elif cmd == 'render':\n            remote.send(env.render(mode='rgb_array'))\n        elif cmd == 'get_spaces':\n            remote.send(\n                (env.observation_space, env.share_observation_space, env.action_space))\n        else:\n            raise NotImplementedError\n\n\n
class ChooseSubprocVecEnv(ShareVecEnv):\n    def __init__(self, env_fns, spaces=None):\n        \"\"\"\n        envs: list of gym environments to run in subprocesses\n        \"\"\"\n        self.waiting = False\n        self.closed = False\n        nenvs = len(env_fns)\n        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])\n
        self.ps = [Process(target=chooseworker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))\n                   for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]\n        for p in self.ps:\n            p.daemon = True  # if the main process crashes, we should not cause things to hang\n            p.start()\n        for remote in self.work_remotes:\n            remote.close()\n
        self.remotes[0].send(('get_spaces', None))\n        observation_space, share_observation_space, action_space = self.remotes[0].recv()\n        ShareVecEnv.__init__(self, len(env_fns), observation_space,\n                             share_observation_space, action_space)\n\n
    def step_async(self, actions):\n        for remote, action in zip(self.remotes, actions):\n            remote.send(('step', action))\n        self.waiting = True\n\n
    def step_wait(self):\n        results = [remote.recv() for remote in self.remotes]\n        self.waiting = False\n        obs, share_obs, rews, dones, infos, available_actions = zip(*results)\n        return np.stack(obs), np.stack(share_obs), np.stack(rews), np.stack(dones), infos, np.stack(available_actions)\n\n
    def reset(self, reset_choose):\n        for remote, choose in zip(self.remotes, reset_choose):\n            remote.send(('reset', choose))\n        results = [remote.recv() for remote in self.remotes]\n        obs, share_obs, available_actions = zip(*results)\n        return np.stack(obs), np.stack(share_obs), np.stack(available_actions)\n\n
    def reset_task(self):\n        for remote in self.remotes:\n            remote.send(('reset_task', None))\n        return np.stack([remote.recv() for remote in self.remotes])\n\n
    def close(self):\n        if self.closed:\n            return\n        if self.waiting:\n            for remote in self.remotes:\n                remote.recv()\n        for remote in self.remotes:\n            remote.send(('close', None))\n        for p in self.ps:\n            p.join()\n        self.closed = True\n\n\n
# worker for ChooseGuardSubprocVecEnv: same command protocol, but for envs returning only (obs, reward, done, info)\ndef chooseguardworker(remote, parent_remote, env_fn_wrapper):\n    parent_remote.close()\n    env = env_fn_wrapper.x()\n    while True:\n        cmd, data = remote.recv()\n        if cmd == 'step':\n            ob, reward, done, info = env.step(data)\n            remote.send((ob, reward, done, info))\n        elif cmd == 'reset':\n            ob = env.reset(data)\n            remote.send(ob)\n        elif cmd == 'reset_task':\n            ob = env.reset_task()\n            remote.send(ob)\n        elif cmd == 'close':\n            env.close()\n            remote.close()\n            break\n        elif cmd == 'get_spaces':\n            remote.send(\n                (env.observation_space, env.share_observation_space, env.action_space))\n        else:\n            raise NotImplementedError\n\n\n
class ChooseGuardSubprocVecEnv(ShareVecEnv):\n    def __init__(self, env_fns, spaces=None):\n        \"\"\"\n        envs: list of gym environments to run in subprocesses\n        \"\"\"\n        self.waiting = False\n        self.closed = False\n        nenvs = len(env_fns)\n        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])\n
        self.ps = [Process(target=chooseguardworker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))\n                   for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]\n        for p in self.ps:\n            p.daemon = False  # not daemonic: a daemonic process is not allowed to spawn children, which some envs need\n            p.start()\n        for remote in self.work_remotes:\n            remote.close()\n
        self.remotes[0].send(('get_spaces', None))\n        observation_space, share_observation_space, action_space = self.remotes[0].recv()\n        ShareVecEnv.__init__(self, len(env_fns), observation_space,\n                             share_observation_space, action_space)\n\n
    def step_async(self, actions):\n        for remote, action in zip(self.remotes, actions):\n            remote.send(('step', action))\n        self.waiting = True\n\n
    def step_wait(self):\n        results = [remote.recv() for remote in self.remotes]\n        self.waiting = False\n        obs, rews, dones, infos = zip(*results)\n        return np.stack(obs), np.stack(rews), np.stack(dones), infos\n\n
    def reset(self, reset_choose):\n        for remote, choose in zip(self.remotes, reset_choose):\n            remote.send(('reset', choose))\n        obs = [remote.recv() for remote in self.remotes]\n        return np.stack(obs)\n\n
    def reset_task(self):\n        for remote in self.remotes:\n            remote.send(('reset_task', None))\n        return np.stack([remote.recv() for remote in self.remotes])\n\n
    def close(self):\n        if self.closed:\n            return\n        if self.waiting:\n            for remote in self.remotes:\n                remote.recv()\n        for remote in self.remotes:\n            remote.send(('close', None))\n        for p in self.ps:\n            p.join()\n        self.closed = True\n\n\n
# single env\nclass DummyVecEnv(ShareVecEnv):\n    def __init__(self, env_fns):\n        self.envs = [fn() for fn in env_fns]\n        env = self.envs[0]\n        ShareVecEnv.__init__(self, len(\n            env_fns), env.observation_space, env.share_observation_space, env.action_space)\n        self.actions = None\n\n
    def step_async(self, actions):\n        self.actions = actions\n\n
    def step_wait(self):\n        results = [env.step(a) for (a, env) in zip(self.actions, self.envs)]\n        obs, rews, dones, infos = map(np.array, zip(*results))\n\n        for (i, done) in enumerate(dones):\n            if 'bool' in done.__class__.__name__:\n                if done:\n                    obs[i] = self.envs[i].reset()\n            else:\n                if np.all(done):\n                    obs[i] = self.envs[i].reset()\n\n        self.actions = None\n        return obs, rews, dones, infos\n\n
    def reset(self):\n        obs = [env.reset() for env in self.envs]\n        return np.array(obs)\n\n
    def close(self):\n        for env in self.envs:\n            env.close()\n\n    def 
render(self, mode=\"human\"):\n        if mode == \"rgb_array\":\n            return np.array([env.render(mode=mode) for env in self.envs])\n        elif mode == \"human\":\n            for env in self.envs:\n                env.render(mode=mode)\n        else:\n            raise NotImplementedError\n\n\n\nclass ShareDummyVecEnv(ShareVecEnv):\n    def __init__(self, env_fns):\n        self.envs = [fn() for fn in env_fns]\n        env = self.envs[0]\n        ShareVecEnv.__init__(self, len(\n            env_fns), env.observation_space, env.share_observation_space, env.action_space)\n        self.actions = None\n\n    def step_async(self, actions):\n        self.actions = actions\n\n    def step_wait(self):\n        results = [env.step(a) for (a, env) in zip(self.actions, self.envs)]\n        obs, share_obs, rews, dones, infos, available_actions = map(\n            np.array, zip(*results))\n\n        for (i, done) in enumerate(dones):\n            if 'bool' in done.__class__.__name__:\n                if done:\n                    obs[i], share_obs[i], available_actions[i] = self.envs[i].reset()\n            else:\n                if np.all(done):\n                    obs[i], share_obs[i], available_actions[i] = self.envs[i].reset()\n        self.actions = None\n\n        return obs, share_obs, rews, dones, infos, available_actions\n\n    def reset(self):\n        results = [env.reset() for env in self.envs]\n        obs, share_obs, available_actions = map(np.array, zip(*results))\n        return obs, share_obs, available_actions\n\n    def close(self):\n        for env in self.envs:\n            env.close()\n    \n    def render(self, mode=\"human\"):\n        if mode == \"rgb_array\":\n            return np.array([env.render(mode=mode) for env in self.envs])\n        elif mode == \"human\":\n            for env in self.envs:\n                env.render(mode=mode)\n        else:\n            raise NotImplementedError\n\n\nclass ChooseDummyVecEnv(ShareVecEnv):\n    def __init__(self, env_fns):\n        self.envs = [fn() for fn in env_fns]\n        env = self.envs[0]\n        ShareVecEnv.__init__(self, len(\n            env_fns), env.observation_space, env.share_observation_space, env.action_space)\n        self.actions = None\n\n    def step_async(self, actions):\n        self.actions = actions\n\n    def step_wait(self):\n        results = [env.step(a) for (a, env) in zip(self.actions, self.envs)]\n        obs, share_obs, rews, dones, infos, available_actions = map(\n            np.array, zip(*results))\n        self.actions = None\n        return obs, share_obs, rews, dones, infos, available_actions\n\n    def reset(self, reset_choose):\n        results = [env.reset(choose)\n                   for (env, choose) in zip(self.envs, reset_choose)]\n        obs, share_obs, available_actions = map(np.array, zip(*results))\n        return obs, share_obs, available_actions\n\n    def close(self):\n        for env in self.envs:\n            env.close()\n\n    def render(self, mode=\"human\"):\n        if mode == \"rgb_array\":\n            return np.array([env.render(mode=mode) for env in self.envs])\n        elif mode == \"human\":\n            for env in self.envs:\n                env.render(mode=mode)\n        else:\n            raise NotImplementedError\n\nclass ChooseSimpleDummyVecEnv(ShareVecEnv):\n    def __init__(self, env_fns):\n        self.envs = [fn() for fn in env_fns]\n        env = self.envs[0]\n        ShareVecEnv.__init__(self, len(\n            env_fns), env.observation_space, 
env.share_observation_space, env.action_space)\n        self.actions = None\n\n
    def step_async(self, actions):\n        self.actions = actions\n\n
    def step_wait(self):\n        results = [env.step(a) for (a, env) in zip(self.actions, self.envs)]\n        obs, rews, dones, infos = map(np.array, zip(*results))\n        self.actions = None\n        return obs, rews, dones, infos\n\n
    def reset(self, reset_choose):\n        obs = [env.reset(choose)\n               for (env, choose) in zip(self.envs, reset_choose)]\n        return np.array(obs)\n\n
    def close(self):\n        for env in self.envs:\n            env.close()\n\n
    def render(self, mode=\"human\"):\n        if mode == \"rgb_array\":\n            return np.array([env.render(mode=mode) for env in self.envs])\n        elif mode == \"human\":\n            for env in self.envs:\n                env.render(mode=mode)\n        else:\n            raise NotImplementedError\n"
  },
  {
    "path": "envs/ma_mujoco/__init__.py",
    "content": ""
  },
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/__init__.py",
    "content": "from .mujoco_multi import MujocoMulti\nfrom .coupled_half_cheetah import CoupledHalfCheetah\nfrom .manyagent_swimmer import ManyAgentSwimmerEnv\nfrom .manyagent_ant import ManyAgentAntEnv\n"
  },
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/assets/.gitignore",
    "content": "*.auto.xml\n"
  },
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/assets/__init__.py",
    "content": ""
  },
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/assets/coupled_half_cheetah.xml",
    "content": "<!-- Cheetah Model\n    The state space is populated with joints in the order that they are\n    defined in this file. The actuators also operate on joints.\n    State-Space (name/joint/parameter):\n        - rootx     slider      position (m)\n        - rootz     slider      position (m)\n        - rooty     hinge       angle (rad)\n        - bthigh    hinge       angle (rad)\n        - bshin     hinge       angle (rad)\n        - bfoot     hinge       angle (rad)\n        - fthigh    hinge       angle (rad)\n        - fshin     hinge       angle (rad)\n        - ffoot     hinge       angle (rad)\n        - rootx     slider      velocity (m/s)\n        - rootz     slider      velocity (m/s)\n        - rooty     hinge       angular velocity (rad/s)\n        - bthigh    hinge       angular velocity (rad/s)\n        - bshin     hinge       angular velocity (rad/s)\n        - bfoot     hinge       angular velocity (rad/s)\n        - fthigh    hinge       angular velocity (rad/s)\n        - fshin     hinge       angular velocity (rad/s)\n        - ffoot     hinge       angular velocity (rad/s)\n    Actuators (name/actuator/parameter):\n        - bthigh    hinge       torque (N m)\n        - bshin     hinge       torque (N m)\n        - bfoot     hinge       torque (N m)\n        - fthigh    hinge       torque (N m)\n        - fshin     hinge       torque (N m)\n        - ffoot     hinge       torque (N m)\n-->\n<mujoco model=\"cheetah\">\n  <compiler angle=\"radian\" coordinate=\"local\" inertiafromgeom=\"true\" settotalmass=\"14\"/>\n  <default>\n    <joint armature=\".1\" damping=\".01\" limited=\"true\" solimplimit=\"0 .8 .03\" solreflimit=\".02 1\" stiffness=\"8\"/>\n    <geom conaffinity=\"0\" condim=\"3\" contype=\"1\" friction=\".4 .1 .1\" rgba=\"0.8 0.6 .4 1\" solimp=\"0.0 0.8 0.01\" solref=\"0.02 1\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1 1\"/>\n  </default>\n  <size nstack=\"300000\" nuser_geom=\"1\"/>\n  <option gravity=\"0 0 -9.81\" timestep=\"0.01\"/>\n  <asset>\n    <texture builtin=\"gradient\" height=\"100\" rgb1=\"1 1 1\" rgb2=\"0 0 0\" type=\"skybox\" width=\"100\"/>\n    <texture builtin=\"flat\" height=\"1278\" mark=\"cross\" markrgb=\"1 1 1\" name=\"texgeom\" random=\"0.01\" rgb1=\"0.8 0.6 0.4\" rgb2=\"0.8 0.6 0.4\" type=\"cube\" width=\"127\"/>\n    <texture builtin=\"checker\" height=\"100\" name=\"texplane\" rgb1=\"0 0 0\" rgb2=\"0.8 0.8 0.8\" type=\"2d\" width=\"100\"/>\n    <material name=\"MatPlane\" reflectance=\"0.5\" shininess=\"1\" specular=\"1\" texrepeat=\"60 60\" texture=\"texplane\"/>\n    <material name=\"geom\" texture=\"texgeom\" texuniform=\"true\"/>\n  </asset>\n  <worldbody>\n    <light cutoff=\"100\" diffuse=\"1 1 1\" dir=\"-0 0 -1.3\" directional=\"true\" exponent=\"1\" pos=\"0 0 1.3\" specular=\".1 .1 .1\"/>\n    <geom conaffinity=\"1\" condim=\"3\" material=\"MatPlane\" name=\"floor\" pos=\"0 0 0\" rgba=\"0.8 0.9 0.8 1\" size=\"40 40 40\" type=\"plane\"/>\n    <body name=\"torso\" pos=\"0 -1 .7\">\n      <site name=\"t1\" pos=\"0.0 0 0\" size=\"0.1\"/>\n      <camera name=\"track\" mode=\"trackcom\" pos=\"0 -3 0.3\" xyaxes=\"1 0 0 0 0 1\"/>\n      <joint armature=\"0\" axis=\"1 0 0\" damping=\"0\" limited=\"false\" name=\"rootx\" pos=\"0 0 0\" stiffness=\"0\" type=\"slide\"/>\n      <joint armature=\"0\" axis=\"0 0 1\" damping=\"0\" limited=\"false\" name=\"rootz\" pos=\"0 0 0\" stiffness=\"0\" type=\"slide\"/>\n      <joint armature=\"0\" axis=\"0 1 0\" damping=\"0\" limited=\"false\" name=\"rooty\" pos=\"0 0 0\" 
stiffness=\"0\" type=\"hinge\"/>\n      <geom fromto=\"-.5 0 0 .5 0 0\" name=\"torso\" size=\"0.046\" type=\"capsule\"/>\n      <geom axisangle=\"0 1 0 .87\" name=\"head\" pos=\".6 0 .1\" size=\"0.046 .15\" type=\"capsule\"/>\n      <!-- <site name='tip'  pos='.15 0 .11'/>-->\n      <body name=\"bthigh\" pos=\"-.5 0 0\">\n        <joint axis=\"0 1 0\" damping=\"6\" name=\"bthigh\" pos=\"0 0 0\" range=\"-.52 1.05\" stiffness=\"240\" type=\"hinge\"/>\n        <geom axisangle=\"0 1 0 -3.8\" name=\"bthigh\" pos=\".1 0 -.13\" size=\"0.046 .145\" type=\"capsule\"/>\n        <body name=\"bshin\" pos=\".16 0 -.25\">\n          <joint axis=\"0 1 0\" damping=\"4.5\" name=\"bshin\" pos=\"0 0 0\" range=\"-.785 .785\" stiffness=\"180\" type=\"hinge\"/>\n          <geom axisangle=\"0 1 0 -2.03\" name=\"bshin\" pos=\"-.14 0 -.07\" rgba=\"0.9 0.6 0.6 1\" size=\"0.046 .15\" type=\"capsule\"/>\n          <body name=\"bfoot\" pos=\"-.28 0 -.14\">\n            <joint axis=\"0 1 0\" damping=\"3\" name=\"bfoot\" pos=\"0 0 0\" range=\"-.4 .785\" stiffness=\"120\" type=\"hinge\"/>\n            <geom axisangle=\"0 1 0 -.27\" name=\"bfoot\" pos=\".03 0 -.097\" rgba=\"0.9 0.6 0.6 1\" size=\"0.046 .094\" type=\"capsule\"/>\n          </body>\n        </body>\n      </body>\n      <body name=\"fthigh\" pos=\".5 0 0\">\n        <joint axis=\"0 1 0\" damping=\"4.5\" name=\"fthigh\" pos=\"0 0 0\" range=\"-1 .7\" stiffness=\"180\" type=\"hinge\"/>\n        <geom axisangle=\"0 1 0 .52\" name=\"fthigh\" pos=\"-.07 0 -.12\" size=\"0.046 .133\" type=\"capsule\"/>\n        <body name=\"fshin\" pos=\"-.14 0 -.24\">\n          <joint axis=\"0 1 0\" damping=\"3\" name=\"fshin\" pos=\"0 0 0\" range=\"-1.2 .87\" stiffness=\"120\" type=\"hinge\"/>\n          <geom axisangle=\"0 1 0 -.6\" name=\"fshin\" pos=\".065 0 -.09\" rgba=\"0.9 0.6 0.6 1\" size=\"0.046 .106\" type=\"capsule\"/>\n          <body name=\"ffoot\" pos=\".13 0 -.18\">\n            <joint axis=\"0 1 0\" damping=\"1.5\" name=\"ffoot\" pos=\"0 0 0\" range=\"-.5 .5\" stiffness=\"60\" type=\"hinge\"/>\n            <geom axisangle=\"0 1 0 -.6\" name=\"ffoot\" pos=\".045 0 -.07\" rgba=\"0.9 0.6 0.6 1\" size=\"0.046 .07\" type=\"capsule\"/>\n          </body>\n        </body>\n      </body>\n    </body>\n    <!-- second cheetah definition -->\n    <body name=\"torso2\" pos=\"0 1 .7\">\n      <site name=\"t2\" pos=\"0 0 0\" size=\"0.1\"/>\n      <camera name=\"track2\" mode=\"trackcom\" pos=\"0 -3 0.3\" xyaxes=\"1 0 0 0 0 1\"/>\n      <joint armature=\"0\" axis=\"1 0 0\" damping=\"0\" limited=\"false\" name=\"rootx2\" pos=\"0 0 0\" stiffness=\"0\" type=\"slide\"/>\n      <joint armature=\"0\" axis=\"0 0 1\" damping=\"0\" limited=\"false\" name=\"rootz2\" pos=\"0 0 0\" stiffness=\"0\" type=\"slide\"/>\n      <joint armature=\"0\" axis=\"0 1 0\" damping=\"0\" limited=\"false\" name=\"rooty2\" pos=\"0 0 0\" stiffness=\"0\" type=\"hinge\"/>\n      <geom fromto=\"-.5 0 0 .5 0 0\" name=\"torso2\" size=\"0.046\" type=\"capsule\"/>\n      <geom axisangle=\"0 1 0 .87\" name=\"head2\" pos=\".6 0 .1\" size=\"0.046 .15\" type=\"capsule\"/>\n      <!-- <site name='tip'  pos='.15 0 .11'/>-->\n      <body name=\"bthigh2\" pos=\"-.5 0 0\">\n        <joint axis=\"0 1 0\" damping=\"6\" name=\"bthigh2\" pos=\"0 0 0\" range=\"-.52 1.05\" stiffness=\"240\" type=\"hinge\"/>\n        <geom axisangle=\"0 1 0 -3.8\" name=\"bthigh2\" pos=\".1 0 -.13\" size=\"0.046 .145\" type=\"capsule\"/>\n        <body name=\"bshin2\" pos=\".16 0 -.25\">\n          <joint axis=\"0 1 0\" damping=\"4.5\" 
name=\"bshin2\" pos=\"0 0 0\" range=\"-.785 .785\" stiffness=\"180\" type=\"hinge\"/>\n          <geom axisangle=\"0 1 0 -2.03\" name=\"bshin2\" pos=\"-.14 0 -.07\" rgba=\"0.9 0.6 0.6 1\" size=\"0.046 .15\" type=\"capsule\"/>\n          <body name=\"bfoot2\" pos=\"-.28 0 -.14\">\n            <joint axis=\"0 1 0\" damping=\"3\" name=\"bfoot2\" pos=\"0 0 0\" range=\"-.4 .785\" stiffness=\"120\" type=\"hinge\"/>\n            <geom axisangle=\"0 1 0 -.27\" name=\"bfoot2\" pos=\".03 0 -.097\" rgba=\"0.9 0.6 0.6 1\" size=\"0.046 .094\" type=\"capsule\"/>\n          </body>\n        </body>\n      </body>\n
      <body name=\"fthigh2\" pos=\".5 0 0\">\n        <joint axis=\"0 1 0\" damping=\"4.5\" name=\"fthigh2\" pos=\"0 0 0\" range=\"-1 .7\" stiffness=\"180\" type=\"hinge\"/>\n        <geom axisangle=\"0 1 0 .52\" name=\"fthigh2\" pos=\"-.07 0 -.12\" size=\"0.046 .133\" type=\"capsule\"/>\n        <body name=\"fshin2\" pos=\"-.14 0 -.24\">\n          <joint axis=\"0 1 0\" damping=\"3\" name=\"fshin2\" pos=\"0 0 0\" range=\"-1.2 .87\" stiffness=\"120\" type=\"hinge\"/>\n          <geom axisangle=\"0 1 0 -.6\" name=\"fshin2\" pos=\".065 0 -.09\" rgba=\"0.9 0.6 0.6 1\" size=\"0.046 .106\" type=\"capsule\"/>\n          <body name=\"ffoot2\" pos=\".13 0 -.18\">\n            <joint axis=\"0 1 0\" damping=\"1.5\" name=\"ffoot2\" pos=\"0 0 0\" range=\"-.5 .5\" stiffness=\"60\" type=\"hinge\"/>\n            <geom axisangle=\"0 1 0 -.6\" name=\"ffoot2\" pos=\".045 0 -.07\" rgba=\"0.9 0.6 0.6 1\" size=\"0.046 .07\" type=\"capsule\"/>\n          </body>\n        </body>\n      </body>\n    </body>\n  </worldbody>\n
  <tendon>\n    <spatial name=\"tendon1\" width=\"0.05\" rgba=\".95 .3 .3 1\" limited=\"true\" range=\"1.5 3.5\" stiffness=\"0.1\">\n        <site site=\"t1\"/>\n        <site site=\"t2\"/>\n    </spatial>\n  </tendon>\n
  <actuator>\n    <motor gear=\"120\" joint=\"bthigh\" name=\"bthigh\"/>\n    <motor gear=\"90\" joint=\"bshin\" name=\"bshin\"/>\n    <motor gear=\"60\" joint=\"bfoot\" name=\"bfoot\"/>\n    <motor gear=\"120\" joint=\"fthigh\" name=\"fthigh\"/>\n    <motor gear=\"60\" joint=\"fshin\" name=\"fshin\"/>\n    <motor gear=\"30\" joint=\"ffoot\" name=\"ffoot\"/>\n    <motor gear=\"120\" joint=\"bthigh2\" name=\"bthigh2\"/>\n    <motor gear=\"90\" joint=\"bshin2\" name=\"bshin2\"/>\n    <motor gear=\"60\" joint=\"bfoot2\" name=\"bfoot2\"/>\n    <motor gear=\"120\" joint=\"fthigh2\" name=\"fthigh2\"/>\n    <motor gear=\"60\" joint=\"fshin2\" name=\"fshin2\"/>\n    <motor gear=\"30\" joint=\"ffoot2\" name=\"ffoot2\"/>\n  </actuator>\n</mujoco>"
  },
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/assets/manyagent_ant.xml",
    "content": "<mujoco model=\"ant\">\n  <size nconmax=\"200\"/>\n  <compiler angle=\"degree\" coordinate=\"local\" inertiafromgeom=\"true\"/>\n  <option integrator=\"RK4\" timestep=\"0.01\"/>\n  <custom>\n    <numeric data=\"0.0 0.0 0.55 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -1.0 0.0 -1.0 0.0 1.0\" name=\"init_qpos\"/>\n  </custom>\n  <default>\n    <joint armature=\"1\" damping=\"1\" limited=\"true\"/>\n    <geom conaffinity=\"0\" condim=\"3\" density=\"5.0\" friction=\"1 0.5 0.5\" margin=\"0.01\" rgba=\"0.8 0.6 0.4 1\"/>\n  </default>\n  <asset>\n    <texture builtin=\"gradient\" height=\"100\" rgb1=\"1 1 1\" rgb2=\"0 0 0\" type=\"skybox\" width=\"100\"/>\n    <texture builtin=\"flat\" height=\"1278\" mark=\"cross\" markrgb=\"1 1 1\" name=\"texgeom\" random=\"0.01\" rgb1=\"0.8 0.6 0.4\" rgb2=\"0.8 0.6 0.4\" type=\"cube\" width=\"127\"/>\n    <texture builtin=\"checker\" height=\"100\" name=\"texplane\" rgb1=\"0 0 0\" rgb2=\"0.8 0.8 0.8\" type=\"2d\" width=\"100\"/>\n    <material name=\"MatPlane\" reflectance=\"0.5\" shininess=\"1\" specular=\"1\" texrepeat=\"60 60\" texture=\"texplane\"/>\n    <material name=\"geom\" texture=\"texgeom\" texuniform=\"true\"/>\n  </asset>\n  <worldbody>\n    <light cutoff=\"100\" diffuse=\"1 1 1\" dir=\"-0 0 -1.3\" directional=\"true\" exponent=\"1\" pos=\"0 0 1.3\" specular=\".1 .1 .1\"/>\n    <geom conaffinity=\"1\" condim=\"3\" material=\"MatPlane\" name=\"floor\" pos=\"0 0 0\" rgba=\"0.8 0.9 0.8 1\" size=\"40 40 40\" type=\"plane\"/>\n    <body name=\"torso\" pos=\"0 0 0.75\">\n      <camera name=\"track\" mode=\"trackcom\" pos=\"0 -3 0.3\" xyaxes=\"1 0 0 0 0 1\"/>\n      <!--<geom name=\"torso_geom\" pos=\"0 0 0\" size=\"0.25\" type=\"sphere\"/>-->\n      <joint armature=\"0\" damping=\"0\" limited=\"false\" margin=\"0.01\" name=\"root\" pos=\"0 0 0\" type=\"free\"/>\n      <body name=\"front_left_leg\" pos=\"0 0 0\">\n        <geom fromto=\"0.0 0.0 0.0 0.2 0.2 0.0\" name=\"aux_1_geom\" size=\"0.08\" type=\"capsule\"/>\n        <body name=\"aux_1\" pos=\"0.2 0.2 0\">\n          <joint axis=\"0 0 1\" name=\"hip_1\" pos=\"0.0 0.0 0.0\" range=\"-30 30\" type=\"hinge\"/>\n          <geom fromto=\"0.0 0.0 0.0 0.2 0.2 0.0\" name=\"left_leg_geom\" size=\"0.08\" type=\"capsule\"/>\n          <body pos=\"0.2 0.2 0\">\n            <joint axis=\"-1 1 0\" name=\"ankle_1\" pos=\"0.0 0.0 0.0\" range=\"30 70\" type=\"hinge\"/>\n            <geom fromto=\"0.0 0.0 0.0 0.4 0.4 0.0\" name=\"left_ankle_geom\" size=\"0.08\" type=\"capsule\"/>\n          </body>\n        </body>\n      </body>\n      <body name=\"right_back_leg\" pos=\"0 0 0\">\n        <geom fromto=\"0.0 0.0 0.0 0.2 -0.2 0.0\" name=\"aux_4_geom\" size=\"0.08\" type=\"capsule\"/>\n        <body name=\"aux_4\" pos=\"0.2 -0.2 0\">\n          <joint axis=\"0 0 1\" name=\"hip_4\" pos=\"0.0 0.0 0.0\" range=\"-30 30\" type=\"hinge\"/>\n          <geom fromto=\"0.0 0.0 0.0 0.2 -0.2 0.0\" name=\"rightback_leg_geom\" size=\"0.08\" type=\"capsule\"/>\n          <body pos=\"0.2 -0.2 0\">\n            <joint axis=\"1 1 0\" name=\"ankle_4\" pos=\"0.0 0.0 0.0\" range=\"30 70\" type=\"hinge\"/>\n            <geom fromto=\"0.0 0.0 0.0 0.4 -0.4 0.0\" name=\"fourth_ankle_geom\" size=\"0.08\" type=\"capsule\"/>\n          </body>\n        </body>\n      </body>\n      <body name=\"midx\" pos=\"0.0 0 0\">\n        <geom density=\"1000\" fromto=\"0 0 0 -1 0 0\" size=\"0.1\" type=\"capsule\"/>\n        <!--<joint axis=\"0 0 1\" limited=\"true\" name=\"rot2\" pos=\"0 0 0\" range=\"-100 100\" type=\"hinge\"/>-->\n        <body 
name=\"front_right_legx\" pos=\"-1 0 0\">\n          <geom fromto=\"0.0 0.0 0.0 0.0 0.2 0.0\" name=\"aux_2_geomx\" size=\"0.08\" type=\"capsule\"/>\n          <body name=\"aux_2x\" pos=\"0.0 0.2 0\">\n            <joint axis=\"0 0 1\" name=\"hip_2x\" pos=\"0.0 0.0 0.0\" range=\"-30 30\" type=\"hinge\"/>\n            <geom fromto=\"0.0 0.0 0.0 -0.2 0.2 0.0\" name=\"right_leg_geomx\" size=\"0.08\" type=\"capsule\"/>\n            <body pos=\"-0.2 0.2 0\">\n              <joint axis=\"1 1 0\" name=\"ankle_2x\" pos=\"0.0 0.0 0.0\" range=\"-70 -30\" type=\"hinge\"/>\n              <geom fromto=\"0.0 0.0 0.0 -0.4 0.4 0.0\" name=\"right_ankle_geomx\" size=\"0.08\" type=\"capsule\"/>\n            </body>\n          </body>\n        </body>\n        <body name=\"back_legx\" pos=\"-1 0 0\">\n          <geom fromto=\"0.0 0.0 0.0 0.0 -0.2 0.0\" name=\"aux_3_geomx\" size=\"0.08\" type=\"capsule\"/>\n          <body name=\"aux_3x\" pos=\"0.0 -0.2 0\">\n            <joint axis=\"0 0 1\" name=\"hip_3x\" pos=\"0.0 0.0 0.0\" range=\"-30 30\" type=\"hinge\"/>\n            <geom fromto=\"0.0 0.0 0.0 -0.2 -0.2 0.0\" name=\"back_leg_geomx\" size=\"0.08\" type=\"capsule\"/>\n            <body pos=\"-0.2 -0.2 0\">\n              <joint axis=\"-1 1 0\" name=\"ankle_3x\" pos=\"0.0 0.0 0.0\" range=\"-70 -30\" type=\"hinge\"/>\n              <geom fromto=\"0.0 0.0 0.0 -0.4 -0.4 0.0\" name=\"third_ankle_geomx\" size=\"0.08\" type=\"capsule\"/>\n            </body>\n          </body>\n        </body>\n        <body name=\"mid\" pos=\"-1 0 0\">\n          <geom density=\"1000\" fromto=\"0 0 0 -1 0 0\" size=\"0.1\" type=\"capsule\"/>\n          <!--<joint axis=\"0 0 1\" limited=\"true\" name=\"rot2\" pos=\"0 0 0\" range=\"-100 100\" type=\"hinge\"/>-->\n          <!--<body name=\"front_right_leg\" pos=\"-1 0 0\">\n            <geom fromto=\"0.0 0.0 0.0 -0.2 0.2 0.0\" name=\"aux_2_geom\" size=\"0.08\" type=\"capsule\"/>\n            <body name=\"aux_2\" pos=\"-0.2 0.2 0\">\n              <joint axis=\"0 0 1\" name=\"hip_2\" pos=\"0.0 0.0 0.0\" range=\"-30 30\" type=\"hinge\"/>\n              <geom fromto=\"0.0 0.0 0.0 -0.2 0.2 0.0\" name=\"right_leg_geom\" size=\"0.08\" type=\"capsule\"/>\n              <body pos=\"-0.2 0.2 0\">\n                <joint axis=\"1 1 0\" name=\"ankle_2\" pos=\"0.0 0.0 0.0\" range=\"-70 -30\" type=\"hinge\"/>\n                <geom fromto=\"0.0 0.0 0.0 -0.4 0.4 0.0\" name=\"right_ankle_geom\" size=\"0.08\" type=\"capsule\"/>\n              </body>\n            </body>\n          </body>\n          <body name=\"back_leg\" pos=\"-1 0 0\">\n            <geom fromto=\"0.0 0.0 0.0 -0.2 -0.2 0.0\" name=\"aux_3_geom\" size=\"0.08\" type=\"capsule\"/>\n            <body name=\"aux_3\" pos=\"-0.2 -0.2 0\">\n              <joint axis=\"0 0 1\" name=\"hip_3\" pos=\"0.0 0.0 0.0\" range=\"-30 30\" type=\"hinge\"/>\n              <geom fromto=\"0.0 0.0 0.0 -0.2 -0.2 0.0\" name=\"back_leg_geom\" size=\"0.08\" type=\"capsule\"/>\n              <body pos=\"-0.2 -0.2 0\">\n                <joint axis=\"-1 1 0\" name=\"ankle_3\" pos=\"0.0 0.0 0.0\" range=\"-70 -30\" type=\"hinge\"/>\n                <geom fromto=\"0.0 0.0 0.0 -0.4 -0.4 0.0\" name=\"third_ankle_geom\" size=\"0.08\" type=\"capsule\"/>\n              </body>\n            </body>\n          </body>-->\n          <body name=\"front_right_leg\" pos=\"-1 0 0\">\n            <geom fromto=\"0.0 0.0 0.0 0.0 0.2 0.0\" name=\"aux_2_geom\" size=\"0.08\" type=\"capsule\"/>\n            <body name=\"aux_2\" pos=\"0.0 0.2 0\">\n              <joint axis=\"0 0 
1\" name=\"hip_2\" pos=\"0.0 0.0 0.0\" range=\"-30 30\" type=\"hinge\"/>\n              <geom fromto=\"0.0 0.0 0.0 -0.2 0.2 0.0\" name=\"right_leg_geom\" size=\"0.08\" type=\"capsule\"/>\n              <body pos=\"-0.2 0.2 0\">\n                <joint axis=\"1 1 0\" name=\"ankle_2\" pos=\"0.0 0.0 0.0\" range=\"-70 -30\" type=\"hinge\"/>\n                <geom fromto=\"0.0 0.0 0.0 -0.4 0.4 0.0\" name=\"right_ankle_geom\" size=\"0.08\" type=\"capsule\"/>\n              </body>\n            </body>\n          </body>\n          <body name=\"back_leg\" pos=\"-1 0 0\">\n            <geom fromto=\"0.0 0.0 0.0 0.0 -0.2 0.0\" name=\"aux_3_geom\" size=\"0.08\" type=\"capsule\"/>\n            <body name=\"aux_3\" pos=\"0.0 -0.2 0\">\n              <joint axis=\"0 0 1\" name=\"hip_3\" pos=\"0.0 0.0 0.0\" range=\"-30 30\" type=\"hinge\"/>\n              <geom fromto=\"0.0 0.0 0.0 -0.2 -0.2 0.0\" name=\"back_leg_geom\" size=\"0.08\" type=\"capsule\"/>\n              <body pos=\"-0.2 -0.2 0\">\n                <joint axis=\"-1 1 0\" name=\"ankle_3\" pos=\"0.0 0.0 0.0\" range=\"-70 -30\" type=\"hinge\"/>\n                <geom fromto=\"0.0 0.0 0.0 -0.4 -0.4 0.0\" name=\"third_ankle_geom\" size=\"0.08\" type=\"capsule\"/>\n              </body>\n            </body>\n          </body>\n        </body>\n      </body>\n    </body>\n  </worldbody>\n  <actuator>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"hip_4\" gear=\"150\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"ankle_4\" gear=\"150\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"hip_1\" gear=\"150\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"ankle_1\" gear=\"150\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"hip_2\" gear=\"150\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"ankle_2\" gear=\"150\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"hip_3\" gear=\"150\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"ankle_3\" gear=\"150\"/>\n  </actuator>\n</mujoco>"
  },
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/assets/manyagent_ant.xml.template",
    "content": "<mujoco model=\"ant\">\n  <size nconmax=\"200\"/>\n  <compiler angle=\"degree\" coordinate=\"local\" inertiafromgeom=\"true\"/>\n  <option integrator=\"RK4\" timestep=\"0.005\"/>\n  <custom>\n    <numeric data=\"0.0 0.0 0.55 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -1.0 0.0 -1.0 0.0 1.0\" name=\"init_qpos\"/>\n  </custom>\n  <default>\n    <joint armature=\"1\" damping=\"1\" limited=\"true\"/>\n    <geom conaffinity=\"0\" condim=\"3\" density=\"5.0\" friction=\"1 0.5 0.5\" margin=\"0.01\" rgba=\"0.8 0.6 0.4 1\"/>\n  </default>\n  <asset>\n    <texture builtin=\"gradient\" height=\"100\" rgb1=\"1 1 1\" rgb2=\"0 0 0\" type=\"skybox\" width=\"100\"/>\n    <texture builtin=\"flat\" height=\"1278\" mark=\"cross\" markrgb=\"1 1 1\" name=\"texgeom\" random=\"0.01\" rgb1=\"0.8 0.6 0.4\" rgb2=\"0.8 0.6 0.4\" type=\"cube\" width=\"127\"/>\n    <texture builtin=\"checker\" height=\"100\" name=\"texplane\" rgb1=\"0 0 0\" rgb2=\"0.8 0.8 0.8\" type=\"2d\" width=\"100\"/>\n    <material name=\"MatPlane\" reflectance=\"0.5\" shininess=\"1\" specular=\"1\" texrepeat=\"60 60\" texture=\"texplane\"/>\n    <material name=\"geom\" texture=\"texgeom\" texuniform=\"true\"/>\n  </asset>\n  <worldbody>\n    <light cutoff=\"100\" diffuse=\"1 1 1\" dir=\"-0 0 -1.3\" directional=\"true\" exponent=\"1\" pos=\"0 0 1.3\" specular=\".1 .1 .1\"/>\n    <geom conaffinity=\"1\" condim=\"3\" material=\"MatPlane\" name=\"floor\" pos=\"0 0 0\" rgba=\"0.8 0.9 0.8 1\" size=\"40 40 40\" type=\"plane\"/>\n    <body name=\"torso_0\" pos=\"0 0 0.75\">\n      <camera name=\"track\" mode=\"trackcom\" pos=\"0 -3 0.3\" xyaxes=\"1 0 0 0 0 1\"/>\n      <!--<geom density=\"1000\" fromto=\"0 0 0 -1 0 0\" size=\"0.1\" type=\"capsule\"/>-->\n      <joint armature=\"0\" damping=\"0\" limited=\"false\" margin=\"0.01\" name=\"root\" pos=\"0 0 0\" type=\"free\"/>\n      <body name=\"front_left_leg_0\" pos=\"0 0 0\">\n        <geom fromto=\"0.0 0.0 0.0 0.2 0.2 0.0\" name=\"aux1_geom_0\" size=\"0.08\" type=\"capsule\"/>\n        <body name=\"aux1_0\" pos=\"0.2 0.2 0\">\n          <joint axis=\"0 0 1\" name=\"hip1_0\" pos=\"0.0 0.0 0.0\" range=\"-30 30\" type=\"hinge\"/>\n          <geom fromto=\"0.0 0.0 0.0 0.2 0.2 0.0\" name=\"left_leg_geom_0\" size=\"0.08\" type=\"capsule\"/>\n          <body pos=\"0.2 0.2 0\">\n            <joint axis=\"-1 1 0\" name=\"ankle1_0\" pos=\"0.0 0.0 0.0\" range=\"30 70\" type=\"hinge\"/>\n            <geom fromto=\"0.0 0.0 0.0 0.4 0.4 0.0\" name=\"left_ankle_geom_0\" size=\"0.08\" type=\"capsule\"/>\n          </body>\n        </body>\n      </body>\n      <body name=\"right_back_leg_0\" pos=\"0 0 0\">\n        <geom fromto=\"0.0 0.0 0.0 0.2 -0.2 0.0\" name=\"aux2_geom_0\" size=\"0.08\" type=\"capsule\"/>\n        <body name=\"aux2_0\" pos=\"0.2 -0.2 0\">\n          <joint axis=\"0 0 1\" name=\"hip2_0\" pos=\"0.0 0.0 0.0\" range=\"-30 30\" type=\"hinge\"/>\n          <geom fromto=\"0.0 0.0 0.0 0.2 -0.2 0.0\" name=\"rightback_leg_geom_0\" size=\"0.08\" type=\"capsule\"/>\n          <body pos=\"0.2 -0.2 0\">\n            <joint axis=\"1 1 0\" name=\"ankle2_0\" pos=\"0.0 0.0 0.0\" range=\"30 70\" type=\"hinge\"/>\n            <geom fromto=\"0.0 0.0 0.0 0.4 -0.4 0.0\" name=\"second_ankle_geom_0\" size=\"0.08\" type=\"capsule\"/>\n          </body>\n        </body>\n      </body>\n      {{ body }}\n    </body>\n  </worldbody>\n  <actuator>\n    {{ actuators }}\n  </actuator>\n</mujoco>"
  },
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/assets/manyagent_ant__stage1.xml",
    "content": "<mujoco model=\"ant\">\n  <compiler angle=\"degree\" coordinate=\"local\" inertiafromgeom=\"true\"/>\n  <option integrator=\"RK4\" timestep=\"0.01\"/>\n  <custom>\n    <numeric data=\"0.0 0.0 0.55 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -1.0 0.0 -1.0 0.0 1.0\" name=\"init_qpos\"/>\n  </custom>\n  <default>\n    <joint armature=\"1\" damping=\"1\" limited=\"true\"/>\n    <geom conaffinity=\"0\" condim=\"3\" density=\"5.0\" friction=\"1 0.5 0.5\" margin=\"0.01\" rgba=\"0.8 0.6 0.4 1\"/>\n  </default>\n  <asset>\n
    <texture builtin=\"gradient\" height=\"100\" rgb1=\"1 1 1\" rgb2=\"0 0 0\" type=\"skybox\" width=\"100\"/>\n    <texture builtin=\"flat\" height=\"1278\" mark=\"cross\" markrgb=\"1 1 1\" name=\"texgeom\" random=\"0.01\" rgb1=\"0.8 0.6 0.4\" rgb2=\"0.8 0.6 0.4\" type=\"cube\" width=\"127\"/>\n    <texture builtin=\"checker\" height=\"100\" name=\"texplane\" rgb1=\"0 0 0\" rgb2=\"0.8 0.8 0.8\" type=\"2d\" width=\"100\"/>\n    <material name=\"MatPlane\" reflectance=\"0.5\" shininess=\"1\" specular=\"1\" texrepeat=\"60 60\" texture=\"texplane\"/>\n    <material name=\"geom\" texture=\"texgeom\" texuniform=\"true\"/>\n  </asset>\n
  <worldbody>\n    <light cutoff=\"100\" diffuse=\"1 1 1\" dir=\"-0 0 -1.3\" directional=\"true\" exponent=\"1\" pos=\"0 0 1.3\" specular=\".1 .1 .1\"/>\n    <geom conaffinity=\"1\" condim=\"3\" material=\"MatPlane\" name=\"floor\" pos=\"0 0 0\" rgba=\"0.8 0.9 0.8 1\" size=\"40 40 40\" type=\"plane\"/>\n    <body name=\"torso\" pos=\"0 0 0.75\">\n      <camera name=\"track\" mode=\"trackcom\" pos=\"0 -3 0.3\" xyaxes=\"1 0 0 0 0 1\"/>\n      <!--<geom name=\"torso_geom\" pos=\"0 0 0\" size=\"0.25\" type=\"sphere\"/>-->\n      <joint armature=\"0\" damping=\"0\" limited=\"false\" margin=\"0.01\" name=\"root\" pos=\"0 0 0\" type=\"free\"/>\n
      <body name=\"front_left_leg\" pos=\"0 0 0\">\n        <geom fromto=\"0.0 0.0 0.0 0.2 0.2 0.0\" name=\"aux_1_geom\" size=\"0.08\" type=\"capsule\"/>\n        <body name=\"aux_1\" pos=\"0.2 0.2 0\">\n          <joint axis=\"0 0 1\" name=\"hip_1\" pos=\"0.0 0.0 0.0\" range=\"-30 30\" type=\"hinge\"/>\n          <geom fromto=\"0.0 0.0 0.0 0.2 0.2 0.0\" name=\"left_leg_geom\" size=\"0.08\" type=\"capsule\"/>\n          <body pos=\"0.2 0.2 0\">\n            <joint axis=\"-1 1 0\" name=\"ankle_1\" pos=\"0.0 0.0 0.0\" range=\"30 70\" type=\"hinge\"/>\n            <geom fromto=\"0.0 0.0 0.0 0.4 0.4 0.0\" name=\"left_ankle_geom\" size=\"0.08\" type=\"capsule\"/>\n          </body>\n        </body>\n      </body>\n
      <body name=\"right_back_leg\" pos=\"0 0 0\">\n        <geom fromto=\"0.0 0.0 0.0 0.2 -0.2 0.0\" name=\"aux_4_geom\" size=\"0.08\" type=\"capsule\"/>\n        <body name=\"aux_4\" pos=\"0.2 -0.2 0\">\n          <joint axis=\"0 0 1\" name=\"hip_4\" pos=\"0.0 0.0 0.0\" range=\"-30 30\" type=\"hinge\"/>\n          <geom fromto=\"0.0 0.0 0.0 0.2 -0.2 0.0\" name=\"rightback_leg_geom\" size=\"0.08\" type=\"capsule\"/>\n          <body pos=\"0.2 -0.2 0\">\n            <joint axis=\"1 1 0\" name=\"ankle_4\" pos=\"0.0 0.0 0.0\" range=\"30 70\" type=\"hinge\"/>\n            <geom fromto=\"0.0 0.0 0.0 0.4 -0.4 0.0\" name=\"fourth_ankle_geom\" size=\"0.08\" type=\"capsule\"/>\n          </body>\n        </body>\n      </body>\n
      <body name=\"mid\" pos=\"0.0 0 0\">\n        <geom density=\"1000\" fromto=\"0 0 0 -1 0 0\" size=\"0.1\" type=\"capsule\"/>\n        <joint axis=\"0 0 1\" limited=\"true\" name=\"rot2\" pos=\"0 0 0\" range=\"-100 100\" type=\"hinge\"/>\n        <body name=\"front_right_leg\" pos=\"-1 0 0\">\n    
      <geom fromto=\"0.0 0.0 0.0 -0.2 0.2 0.0\" name=\"aux_2_geom\" size=\"0.08\" type=\"capsule\"/>\n          <body name=\"aux_2\" pos=\"-0.2 0.2 0\">\n            <joint axis=\"0 0 1\" name=\"hip_2\" pos=\"0.0 0.0 0.0\" range=\"-30 30\" type=\"hinge\"/>\n            <geom fromto=\"0.0 0.0 0.0 -0.2 0.2 0.0\" name=\"right_leg_geom\" size=\"0.08\" type=\"capsule\"/>\n            <body pos=\"-0.2 0.2 0\">\n              <joint axis=\"1 1 0\" name=\"ankle_2\" pos=\"0.0 0.0 0.0\" range=\"-70 -30\" type=\"hinge\"/>\n              <geom fromto=\"0.0 0.0 0.0 -0.4 0.4 0.0\" name=\"right_ankle_geom\" size=\"0.08\" type=\"capsule\"/>\n            </body>\n          </body>\n        </body>\n        <body name=\"back_leg\" pos=\"-1 0 0\">\n          <geom fromto=\"0.0 0.0 0.0 -0.2 -0.2 0.0\" name=\"aux_3_geom\" size=\"0.08\" type=\"capsule\"/>\n          <body name=\"aux_3\" pos=\"-0.2 -0.2 0\">\n            <joint axis=\"0 0 1\" name=\"hip_3\" pos=\"0.0 0.0 0.0\" range=\"-30 30\" type=\"hinge\"/>\n            <geom fromto=\"0.0 0.0 0.0 -0.2 -0.2 0.0\" name=\"back_leg_geom\" size=\"0.08\" type=\"capsule\"/>\n            <body pos=\"-0.2 -0.2 0\">\n              <joint axis=\"-1 1 0\" name=\"ankle_3\" pos=\"0.0 0.0 0.0\" range=\"-70 -30\" type=\"hinge\"/>\n              <geom fromto=\"0.0 0.0 0.0 -0.4 -0.4 0.0\" name=\"third_ankle_geom\" size=\"0.08\" type=\"capsule\"/>\n            </body>\n          </body>\n        </body>\n      </body>\n    </body>\n  </worldbody>\n  <actuator>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"hip_4\" gear=\"150\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"ankle_4\" gear=\"150\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"hip_1\" gear=\"150\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"ankle_1\" gear=\"150\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"hip_2\" gear=\"150\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"ankle_2\" gear=\"150\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"hip_3\" gear=\"150\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"ankle_3\" gear=\"150\"/>\n  </actuator>\n</mujoco>"
  },
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/assets/manyagent_swimmer.xml.template",
    "content": "<mujoco model=\"swimmer\">\n  <compiler angle=\"degree\" coordinate=\"local\" inertiafromgeom=\"true\"/>\n  <option collision=\"predefined\" density=\"4000\" integrator=\"RK4\" timestep=\"0.005\" viscosity=\"0.1\"/>\n  <default>\n    <geom conaffinity=\"1\" condim=\"1\" contype=\"1\" material=\"geom\" rgba=\"0.8 0.6 .4 1\"/>\n    <joint armature='0.1'  />\n  </default>\n  <asset>\n    <texture builtin=\"gradient\" height=\"100\" rgb1=\"1 1 1\" rgb2=\"0 0 0\" type=\"skybox\" width=\"100\"/>\n    <texture builtin=\"flat\" height=\"1278\" mark=\"cross\" markrgb=\"1 1 1\" name=\"texgeom\" random=\"0.01\" rgb1=\"0.8 0.6 0.4\" rgb2=\"0.8 0.6 0.4\" type=\"cube\" width=\"127\"/>\n    <texture builtin=\"checker\" height=\"100\" name=\"texplane\" rgb1=\"0 0 0\" rgb2=\"0.8 0.8 0.8\" type=\"2d\" width=\"100\"/>\n    <material name=\"MatPlane\" reflectance=\"0.5\" shininess=\"1\" specular=\"1\" texrepeat=\"30 30\" texture=\"texplane\"/>\n    <material name=\"geom\" texture=\"texgeom\" texuniform=\"true\"/>\n  </asset>\n  <worldbody>\n    <light cutoff=\"100\" diffuse=\"1 1 1\" dir=\"-0 0 -1.3\" directional=\"true\" exponent=\"1\" pos=\"0 0 1.3\" specular=\".1 .1 .1\"/>\n    <geom conaffinity=\"1\" condim=\"3\" material=\"MatPlane\" name=\"floor\" pos=\"0 0 -0.1\" rgba=\"0.8 0.9 0.8 1\" size=\"40 40 0.1\" type=\"plane\"/>\n    <!--  ================= SWIMMER ================= /-->\n    <body name=\"torso\" pos=\"0 0 0\">\n      <geom density=\"1000\" fromto=\"1.5 0 0 0.5 0 0\" size=\"0.1\" type=\"capsule\"/>\n      <joint axis=\"1 0 0\" name=\"slider1\" pos=\"0 0 0\" type=\"slide\"/>\n      <joint axis=\"0 1 0\" name=\"slider2\" pos=\"0 0 0\" type=\"slide\"/>\n      <joint axis=\"0 0 1\" name=\"rot\" pos=\"0 0 0\" type=\"hinge\"/>\n      <body name=\"mid0\" pos=\"0.5 0 0\">\n        <geom density=\"1000\" fromto=\"0 0 0 -1 0 0\" size=\"0.1\" type=\"capsule\"/>\n        <joint axis=\"0 0 1\" limited=\"true\" name=\"rot0\" pos=\"0 0 0\" range=\"-100 100\" type=\"hinge\"/>\n        {{ body }}\n      </body>\n    </body>\n  </worldbody>\n  <actuator>\n{{ actuators }}\n  </actuator>\n</mujoco>"
  },
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/assets/manyagent_swimmer__bckp2.xml",
    "content": "<mujoco model=\"swimmer\">\n  <compiler angle=\"degree\" coordinate=\"local\" inertiafromgeom=\"true\"/>\n  <option collision=\"predefined\" density=\"4000\" integrator=\"RK4\" timestep=\"0.01\" viscosity=\"0.1\"/>\n  <default>\n    <geom conaffinity=\"1\" condim=\"1\" contype=\"1\" material=\"geom\" rgba=\"0.8 0.6 .4 1\"/>\n    <joint armature='0.1'  />\n  </default>\n  <asset>\n    <texture builtin=\"gradient\" height=\"100\" rgb1=\"1 1 1\" rgb2=\"0 0 0\" type=\"skybox\" width=\"100\"/>\n    <texture builtin=\"flat\" height=\"1278\" mark=\"cross\" markrgb=\"1 1 1\" name=\"texgeom\" random=\"0.01\" rgb1=\"0.8 0.6 0.4\" rgb2=\"0.8 0.6 0.4\" type=\"cube\" width=\"127\"/>\n    <texture builtin=\"checker\" height=\"100\" name=\"texplane\" rgb1=\"0 0 0\" rgb2=\"0.8 0.8 0.8\" type=\"2d\" width=\"100\"/>\n    <material name=\"MatPlane\" reflectance=\"0.5\" shininess=\"1\" specular=\"1\" texrepeat=\"30 30\" texture=\"texplane\"/>\n    <material name=\"geom\" texture=\"texgeom\" texuniform=\"true\"/>\n  </asset>\n  <worldbody>\n    <light cutoff=\"100\" diffuse=\"1 1 1\" dir=\"-0 0 -1.3\" directional=\"true\" exponent=\"1\" pos=\"0 0 1.3\" specular=\".1 .1 .1\"/>\n    <geom conaffinity=\"1\" condim=\"3\" material=\"MatPlane\" name=\"floor\" pos=\"0 0 -0.1\" rgba=\"0.8 0.9 0.8 1\" size=\"40 40 0.1\" type=\"plane\"/>\n    <!--  ================= SWIMMER ================= /-->\n    <body name=\"torso\" pos=\"0 0 0\">\n      <geom density=\"1000\" fromto=\"1.5 0 0 0.5 0 0\" size=\"0.1\" type=\"capsule\"/>\n      <joint axis=\"1 0 0\" name=\"slider1\" pos=\"0 0 0\" type=\"slide\"/>\n      <joint axis=\"0 1 0\" name=\"slider2\" pos=\"0 0 0\" type=\"slide\"/>\n      <joint axis=\"0 0 1\" name=\"rot\" pos=\"0 0 0\" type=\"hinge\"/>\n      <body name=\"mid1\" pos=\"0.5 0 0\">\n        <geom density=\"1000\" fromto=\"0 0 0 -1 0 0\" size=\"0.1\" type=\"capsule\"/>\n        <joint axis=\"0 0 1\" limited=\"true\" name=\"rot0\" pos=\"0 0 0\" range=\"-100 100\" type=\"hinge\"/>\n        <body name=\"mid2\" pos=\"-1 0 0\">\n          <geom density=\"1000\" fromto=\"0 0 0 -1 0 0\" size=\"0.1\" type=\"capsule\"/>\n          <joint axis=\"0 0 -1\" limited=\"true\" name=\"rot1\" pos=\"0 0 0\" range=\"-100 100\" type=\"hinge\"/>\n          <body name=\"mid3\" pos=\"-1 0 0\">\n            <geom density=\"1000\" fromto=\"0 0 0 -1 0 0\" size=\"0.1\" type=\"capsule\"/>\n            <joint axis=\"0 0 1\" limited=\"true\" name=\"rot2\" pos=\"0 0 0\" range=\"-100 100\" type=\"hinge\"/>\n            <body name=\"back\" pos=\"-1 0 0\">\n              <geom density=\"1000\" fromto=\"0 0 0 -1 0 0\" size=\"0.1\" type=\"capsule\"/>\n              <joint axis=\"0 0 1\" limited=\"true\" name=\"rot3\" pos=\"0 0 0\" range=\"-100 100\" type=\"hinge\"/>\n            </body>\n          </body>\n        </body>\n      </body>\n    </body>\n  </worldbody>\n  <actuator>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1 1\" gear=\"150.0\" joint=\"rot0\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1 1\" gear=\"150.0\" joint=\"rot1\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1 1\" gear=\"150.0\" joint=\"rot2\"/>\n     <motor ctrllimited=\"true\" ctrlrange=\"-1 1\" gear=\"150.0\" joint=\"rot3\"/>\n  </actuator>\n</mujoco>"
  },
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/assets/manyagent_swimmer_bckp.xml",
    "content": "<mujoco model=\"swimmer\">\n  <compiler angle=\"degree\" coordinate=\"local\" inertiafromgeom=\"true\"/>\n  <option collision=\"predefined\" density=\"4000\" integrator=\"RK4\" timestep=\"0.01\" viscosity=\"0.1\"/>\n  <default>\n    <geom conaffinity=\"1\" condim=\"1\" contype=\"1\" material=\"geom\" rgba=\"0.8 0.6 .4 1\"/>\n    <joint armature='0.1'  />\n  </default>\n  <asset>\n    <texture builtin=\"gradient\" height=\"100\" rgb1=\"1 1 1\" rgb2=\"0 0 0\" type=\"skybox\" width=\"100\"/>\n    <texture builtin=\"flat\" height=\"1278\" mark=\"cross\" markrgb=\"1 1 1\" name=\"texgeom\" random=\"0.01\" rgb1=\"0.8 0.6 0.4\" rgb2=\"0.8 0.6 0.4\" type=\"cube\" width=\"127\"/>\n    <texture builtin=\"checker\" height=\"100\" name=\"texplane\" rgb1=\"0 0 0\" rgb2=\"0.8 0.8 0.8\" type=\"2d\" width=\"100\"/>\n    <material name=\"MatPlane\" reflectance=\"0.5\" shininess=\"1\" specular=\"1\" texrepeat=\"30 30\" texture=\"texplane\"/>\n    <material name=\"geom\" texture=\"texgeom\" texuniform=\"true\"/>\n  </asset>\n  <worldbody>\n    <light cutoff=\"100\" diffuse=\"1 1 1\" dir=\"-0 0 -1.3\" directional=\"true\" exponent=\"1\" pos=\"0 0 1.3\" specular=\".1 .1 .1\"/>\n    <geom conaffinity=\"1\" condim=\"3\" material=\"MatPlane\" name=\"floor\" pos=\"0 0 -0.1\" rgba=\"0.8 0.9 0.8 1\" size=\"40 40 0.1\" type=\"plane\"/>\n    <!--  ================= SWIMMER ================= /-->\n    <body name=\"torso\" pos=\"0 0 0\">\n      <geom density=\"1000\" fromto=\"1.5 0 0 0.5 0 0\" size=\"0.1\" type=\"capsule\"/>\n      <joint axis=\"1 0 0\" name=\"slider1\" pos=\"0 0 0\" type=\"slide\"/>\n      <joint axis=\"0 1 0\" name=\"slider2\" pos=\"0 0 0\" type=\"slide\"/>\n      <joint axis=\"0 0 1\" name=\"rot\" pos=\"0 0 0\" type=\"hinge\"/>\n      <body name=\"mid1\" pos=\"0.5 0 0\">\n        <geom density=\"1000\" fromto=\"0 0 0 -1 0 0\" size=\"0.1\" type=\"capsule\"/>\n        <joint axis=\"0 0 1\" limited=\"true\" name=\"rot0\" pos=\"0 0 0\" range=\"-100 100\" type=\"hinge\"/>\n        <body name=\"mid2\" pos=\"-1 0 0\">\n          <geom density=\"1000\" fromto=\"0 0 0 -1 0 0\" size=\"0.1\" type=\"capsule\"/>\n          <joint axis=\"0 0 -1\" limited=\"true\" name=\"rot1\" pos=\"0 0 0\" range=\"-100 100\" type=\"hinge\"/>\n          <body name=\"back\" pos=\"-1 0 0\">\n            <geom density=\"1000\" fromto=\"0 0 0 -1 0 0\" size=\"0.1\" type=\"capsule\"/>\n            <joint axis=\"0 0 1\" limited=\"true\" name=\"rot2\" pos=\"0 0 0\" range=\"-100 100\" type=\"hinge\"/>\n          </body>\n        </body>\n      </body>\n    </body>\n  </worldbody>\n  <actuator>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1 1\" gear=\"150.0\" joint=\"rot0\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1 1\" gear=\"150.0\" joint=\"rot1\"/>\n    <motor ctrllimited=\"true\" ctrlrange=\"-1 1\" gear=\"150.0\" joint=\"rot2\"/>\n  </actuator>\n</mujoco>"
  },
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/coupled_half_cheetah.py",
    "content": "import numpy as np\nfrom gym import utils\nfrom gym.envs.mujoco import mujoco_env\nimport os\n\n\nclass CoupledHalfCheetah(mujoco_env.MujocoEnv, utils.EzPickle):\n    def __init__(self, **kwargs):\n        mujoco_env.MujocoEnv.__init__(self, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'coupled_half_cheetah.xml'), 5)\n        utils.EzPickle.__init__(self)\n\n    def step(self, action):\n        xposbefore1 = self.sim.data.qpos[0]\n        xposbefore2 = self.sim.data.qpos[len(self.sim.data.qpos) // 2]\n        self.do_simulation(action, self.frame_skip)\n        xposafter1 = self.sim.data.qpos[0]\n        xposafter2 = self.sim.data.qpos[len(self.sim.data.qpos)//2]\n        ob = self._get_obs()\n        reward_ctrl1 = - 0.1 * np.square(action[0:len(action)//2]).sum()\n        reward_ctrl2 = - 0.1 * np.square(action[len(action)//2:]).sum()\n        reward_run1 = (xposafter1 - xposbefore1)/self.dt\n        reward_run2 = (xposafter2 - xposbefore2) / self.dt\n        reward = (reward_ctrl1 + reward_ctrl2)/2.0 + (reward_run1 + reward_run2)/2.0\n        done = False\n        return ob, reward, done, dict(reward_run1=reward_run1, reward_ctrl1=reward_ctrl1,\n                                      reward_run2=reward_run2, reward_ctrl2=reward_ctrl2)\n\n    def _get_obs(self):\n        return np.concatenate([\n            self.sim.data.qpos.flat[1:],\n            self.sim.data.qvel.flat,\n        ])\n\n    def reset_model(self):\n        qpos = self.init_qpos + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq)\n        qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1\n        self.set_state(qpos, qvel)\n        return self._get_obs()\n\n    def viewer_setup(self):\n        self.viewer.cam.distance = self.model.stat.extent * 0.5\n\n    def get_env_info(self):\n        return {\"episode_limit\": self.episode_limit}"
  },
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/manyagent_ant.py",
    "content": "import numpy as np\nfrom gym import utils\nfrom gym.envs.mujoco import mujoco_env\nfrom jinja2 import Template\nimport os\n\n
class ManyAgentAntEnv(mujoco_env.MujocoEnv, utils.EzPickle):\n    def __init__(self, **kwargs):\n        agent_conf = kwargs.get(\"agent_conf\")\n        n_agents = int(agent_conf.split(\"x\")[0])\n        n_segs_per_agents = int(agent_conf.split(\"x\")[1])\n        n_segs = n_agents * n_segs_per_agents\n\n
        # Check whether asset file exists already, otherwise create it\n        asset_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets',\n                                                  'manyagent_ant_{}_agents_each_{}_segments.auto.xml'.format(n_agents,\n                                                                                                                 n_segs_per_agents))\n
        #if not os.path.exists(asset_path):\n        print(\"Auto-Generating Manyagent Ant asset with {} segments at {}.\".format(n_segs, asset_path))\n        self._generate_asset(n_segs=n_segs, asset_path=asset_path)\n\n
        #asset_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets',\n        #                          'manyagent_swimmer.xml')\n\n
        mujoco_env.MujocoEnv.__init__(self, asset_path, 4)\n        utils.EzPickle.__init__(self)\n\n
    def _generate_asset(self, n_segs, asset_path):\n        template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets',\n                                                  'manyagent_ant.xml.template')\n        with open(template_path, \"r\") as f:\n            t = Template(f.read())\n
        body_str_template = \"\"\"\n        <body name=\"torso_{:d}\" pos=\"-1 0 0\">\n           <!--<joint axis=\"0 1 0\" name=\"nnn_{:d}\" pos=\"0.0 0.0 0.0\" range=\"-1 1\" type=\"hinge\"/>-->\n            <geom density=\"100\" fromto=\"1 0 0 0 0 0\" size=\"0.1\" type=\"capsule\"/>\n
            <body name=\"front_right_leg_{:d}\" pos=\"0 0 0\">\n              <geom fromto=\"0.0 0.0 0.0 0.0 0.2 0.0\" name=\"aux1_geom_{:d}\" size=\"0.08\" type=\"capsule\"/>\n              <body name=\"aux_2_{:d}\" pos=\"0.0 0.2 0\">\n                <joint axis=\"0 0 1\" name=\"hip1_{:d}\" pos=\"0.0 0.0 0.0\" range=\"-30 30\" type=\"hinge\"/>\n                <geom fromto=\"0.0 0.0 0.0 -0.2 0.2 0.0\" name=\"right_leg_geom_{:d}\" size=\"0.08\" type=\"capsule\"/>\n                <body pos=\"-0.2 0.2 0\">\n                  <joint axis=\"1 1 0\" name=\"ankle1_{:d}\" pos=\"0.0 0.0 0.0\" range=\"-70 -30\" type=\"hinge\"/>\n                  <geom fromto=\"0.0 0.0 0.0 -0.4 0.4 0.0\" name=\"right_ankle_geom_{:d}\" size=\"0.08\" type=\"capsule\"/>\n                </body>\n              </body>\n            </body>\n
            <body name=\"back_leg_{:d}\" pos=\"0 0 0\">\n              <geom fromto=\"0.0 0.0 0.0 0.0 -0.2 0.0\" name=\"aux2_geom_{:d}\" size=\"0.08\" type=\"capsule\"/>\n              <body name=\"aux2_{:d}\" pos=\"0.0 -0.2 0\">\n                <joint axis=\"0 0 1\" name=\"hip2_{:d}\" pos=\"0.0 0.0 0.0\" range=\"-30 30\" type=\"hinge\"/>\n                <geom fromto=\"0.0 0.0 0.0 -0.2 -0.2 0.0\" name=\"back_leg_geom_{:d}\" size=\"0.08\" type=\"capsule\"/>\n                <body pos=\"-0.2 -0.2 0\">\n                  <joint axis=\"-1 1 0\" name=\"ankle2_{:d}\" pos=\"0.0 0.0 0.0\" range=\"-70 -30\" type=\"hinge\"/>\n                  <geom fromto=\"0.0 0.0 0.0 -0.4 -0.4 0.0\" name=\"third_ankle_geom_{:d}\" size=\"0.08\" type=\"capsule\"/>\n                </body>\n             
 </body>\n            </body>\n        \"\"\"\n\n        body_close_str_template =\"</body>\\n\"\n        actuator_str_template = \"\"\"\\t     <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"hip1_{:d}\" gear=\"150\"/>\n                                          <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"ankle1_{:d}\" gear=\"150\"/>\n                                          <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"hip2_{:d}\" gear=\"150\"/>\n                                          <motor ctrllimited=\"true\" ctrlrange=\"-1.0 1.0\" joint=\"ankle2_{:d}\" gear=\"150\"/>\\n\"\"\"\n\n        body_str = \"\"\n        for i in range(1,n_segs):\n            body_str += body_str_template.format(*([i]*16))\n        body_str += body_close_str_template*(n_segs-1)\n\n        actuator_str = \"\"\n        for i in range(n_segs):\n            actuator_str += actuator_str_template.format(*([i]*8))\n\n        rt = t.render(body=body_str, actuators=actuator_str)\n        with open(asset_path, \"w\") as f:\n            f.write(rt)\n        pass\n\n    def step(self, a):\n        xposbefore = self.get_body_com(\"torso_0\")[0]\n        self.do_simulation(a, self.frame_skip)\n        xposafter = self.get_body_com(\"torso_0\")[0]\n        forward_reward = (xposafter - xposbefore)/self.dt\n        ctrl_cost = .5 * np.square(a).sum()\n        contact_cost = 0.5 * 1e-3 * np.sum(\n            np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))\n        survive_reward = 1.0\n        reward = forward_reward - ctrl_cost - contact_cost + survive_reward\n        state = self.state_vector()\n        notdone = np.isfinite(state).all() \\\n            and state[2] >= 0.2 and state[2] <= 1.0\n        done = not notdone\n        ob = self._get_obs()\n        return ob, reward, done, dict(\n            reward_forward=forward_reward,\n            reward_ctrl=-ctrl_cost,\n            reward_contact=-contact_cost,\n            reward_survive=survive_reward)\n\n    def _get_obs(self):\n        return np.concatenate([\n            self.sim.data.qpos.flat[2:],\n            self.sim.data.qvel.flat,\n            np.clip(self.sim.data.cfrc_ext, -1, 1).flat,\n        ])\n\n    def reset_model(self):\n        qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-.1, high=.1)\n        qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1\n        self.set_state(qpos, qvel)\n        return self._get_obs()\n\n    def viewer_setup(self):\n        self.viewer.cam.distance = self.model.stat.extent * 0.5"
  },
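  {
    "path": "examples/manyagent_ant_asset_demo.py",
    "content": "\"\"\"Illustrative sketch, not part of the original codebase: mimics how\nManyAgentAntEnv._generate_asset composes an MJCF file, combining str.format\nfor the repeated per-segment snippets with a Jinja2 template for the\nsurrounding document. TOY_TEMPLATE is a stand-in for the real\nassets/manyagent_ant.xml.template; only the render variables (body,\nactuators) match the real code.\"\"\"\nfrom jinja2 import Template\n\nTOY_TEMPLATE = \"\"\"<mujoco>\n  <worldbody>{{ body }}</worldbody>\n  <actuator>{{ actuators }}</actuator>\n</mujoco>\"\"\"\n\n# Simplified per-segment snippets; the real template nests a full leg pair.\nbody_open = '<body name=\"torso_{:d}\" pos=\"-1 0 0\">'\nmotor = '<motor joint=\"hip1_{:d}\" gear=\"150\"/>'\n\nn_agents, n_segs_per_agent = 2, 3  # agent_conf == \"2x3\"\nn_segs = n_agents * n_segs_per_agent\n\n# Segments 1..n_segs-1 hang off the template's root torso as a nested chain,\n# so all opened <body> tags are closed in one go, as in _generate_asset.\nbody = ''.join(body_open.format(i) for i in range(1, n_segs))\nbody += '</body>' * (n_segs - 1)\nactuators = ''.join(motor.format(i) for i in range(n_segs))\n\nprint(Template(TOY_TEMPLATE).render(body=body, actuators=actuators))\n"
  },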
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/manyagent_swimmer.py",
    "content": "import numpy as np\nfrom gym import utils\nfrom gym.envs.mujoco import mujoco_env\nimport os\nfrom jinja2 import Template\n\nclass ManyAgentSwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):\n    def __init__(self, **kwargs):\n        agent_conf = kwargs.get(\"agent_conf\")\n        n_agents = int(agent_conf.split(\"x\")[0])\n        n_segs_per_agents = int(agent_conf.split(\"x\")[1])\n        n_segs = n_agents * n_segs_per_agents\n\n        # Check whether asset file exists already, otherwise create it\n        asset_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets',\n                                                  'manyagent_swimmer_{}_agents_each_{}_segments.auto.xml'.format(n_agents,\n                                                                                                                 n_segs_per_agents))\n        # if not os.path.exists(asset_path):\n        print(\"Auto-Generating Manyagent Swimmer asset with {} segments at {}.\".format(n_segs, asset_path))\n        self._generate_asset(n_segs=n_segs, asset_path=asset_path)\n\n        #asset_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets',git p\n        #                          'manyagent_swimmer.xml')\n\n        mujoco_env.MujocoEnv.__init__(self, asset_path, 4)\n        utils.EzPickle.__init__(self)\n\n    def _generate_asset(self, n_segs, asset_path):\n        template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets',\n                                                  'manyagent_swimmer.xml.template')\n        with open(template_path, \"r\") as f:\n            t = Template(f.read())\n        body_str_template = \"\"\"\n        <body name=\"mid{:d}\" pos=\"-1 0 0\">\n          <geom density=\"1000\" fromto=\"0 0 0 -1 0 0\" size=\"0.1\" type=\"capsule\"/>\n          <joint axis=\"0 0 {:d}\" limited=\"true\" name=\"rot{:d}\" pos=\"0 0 0\" range=\"-100 100\" type=\"hinge\"/>\n        \"\"\"\n\n        body_end_str_template = \"\"\"\n        <body name=\"back\" pos=\"-1 0 0\">\n            <geom density=\"1000\" fromto=\"0 0 0 -1 0 0\" size=\"0.1\" type=\"capsule\"/>\n            <joint axis=\"0 0 1\" limited=\"true\" name=\"rot{:d}\" pos=\"0 0 0\" range=\"-100 100\" type=\"hinge\"/>\n          </body>\n        \"\"\"\n\n        body_close_str_template =\"</body>\\n\"\n        actuator_str_template = \"\"\"\\t <motor ctrllimited=\"true\" ctrlrange=\"-1 1\" gear=\"150.0\" joint=\"rot{:d}\"/>\\n\"\"\"\n\n        body_str = \"\"\n        for i in range(1,n_segs-1):\n            body_str += body_str_template.format(i, (-1)**(i+1), i)\n        body_str += body_end_str_template.format(n_segs-1)\n        body_str += body_close_str_template*(n_segs-2)\n\n        actuator_str = \"\"\n        for i in range(n_segs):\n            actuator_str += actuator_str_template.format(i)\n\n        rt = t.render(body=body_str, actuators=actuator_str)\n        with open(asset_path, \"w\") as f:\n            f.write(rt)\n        pass\n\n    def step(self, a):\n        ctrl_cost_coeff = 0.0001\n        xposbefore = self.sim.data.qpos[0]\n        self.do_simulation(a, self.frame_skip)\n        xposafter = self.sim.data.qpos[0]\n        reward_fwd = (xposafter - xposbefore) / self.dt\n        reward_ctrl = - ctrl_cost_coeff * np.square(a).sum()\n        reward = reward_fwd + reward_ctrl\n        ob = self._get_obs()\n        return ob, reward, False, dict(reward_fwd=reward_fwd, reward_ctrl=reward_ctrl)\n\n    def _get_obs(self):\n        qpos = 
self.sim.data.qpos\n        qvel = self.sim.data.qvel\n        return np.concatenate([qpos.flat[2:], qvel.flat])\n\n    def reset_model(self):\n        self.set_state(\n            self.init_qpos + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq),\n            self.init_qvel + self.np_random.uniform(low=-.1, high=.1, size=self.model.nv)\n        )\n        return self._get_obs()\n"
  },
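  {
    "path": "examples/manyagent_swimmer_demo.py",
    "content": "\"\"\"Hypothetical usage sketch, not part of the original codebase: constructs\nManyAgentSwimmerEnv directly. Assumes mujoco-py is installed and that\nassets/manyagent_swimmer.xml.template ships alongside the module, as\n__init__ expects. agent_conf \"2x2\" means 2 agents x 2 segments each.\"\"\"\nimport numpy as np\n\nfrom envs.ma_mujoco.multiagent_mujoco.manyagent_swimmer import ManyAgentSwimmerEnv\n\nenv = ManyAgentSwimmerEnv(agent_conf=\"2x2\")  # regenerates the .auto.xml asset\nobs = env.reset()\nfor _ in range(5):\n    # One control per rotary joint; the action space comes from the MJCF actuators.\n    action = np.random.uniform(-1.0, 1.0, size=env.action_space.shape)\n    obs, reward, done, info = env.step(action)\n    print(reward, info[\"reward_fwd\"], info[\"reward_ctrl\"])\nenv.close()\n"
  },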
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/mujoco_multi.py",
    "content": "from functools import partial\nimport gym\nfrom gym.spaces import Box\nfrom gym.wrappers import TimeLimit\nimport numpy as np\n\nfrom .multiagentenv import MultiAgentEnv\nfrom .manyagent_swimmer import ManyAgentSwimmerEnv\nfrom .obsk import get_joints_at_kdist, get_parts_and_edges, build_obs\n\n\ndef env_fn(env, **kwargs) -> MultiAgentEnv: # TODO: this may be a more complex function\n    # env_args = kwargs.get(\"env_args\", {})\n    return env(**kwargs)\n\nenv_REGISTRY = {}\nenv_REGISTRY[\"manyagent_swimmer\"] = partial(env_fn, env=ManyAgentSwimmerEnv)\n\n\n# using code from https://github.com/ikostrikov/pytorch-ddpg-naf\nclass NormalizedActions(gym.ActionWrapper):\n\n    def _action(self, action):\n        action = (action + 1) / 2\n        action *= (self.action_space.high - self.action_space.low)\n        action += self.action_space.low\n        return action\n\n    def action(self, action_):\n        return self._action(action_)\n\n    def _reverse_action(self, action):\n        action -= self.action_space.low\n        action /= (self.action_space.high - self.action_space.low)\n        action = action * 2 - 1\n        return action\n\n\nclass MujocoMulti(MultiAgentEnv):\n\n    def __init__(self, batch_size=None, **kwargs):\n        super().__init__(batch_size, **kwargs)\n        self.scenario = kwargs[\"env_args\"][\"scenario\"]  # e.g. Ant-v2\n        self.agent_conf = kwargs[\"env_args\"][\"agent_conf\"]  # e.g. '2x3'\n\n        self.agent_partitions, self.mujoco_edges, self.mujoco_globals = get_parts_and_edges(self.scenario,\n                                                                                            self.agent_conf)\n\n        self.n_agents = len(self.agent_partitions)\n        self.n_actions = max([len(l) for l in self.agent_partitions])\n        self.obs_add_global_pos = kwargs[\"env_args\"].get(\"obs_add_global_pos\", False)\n\n        self.agent_obsk = kwargs[\"env_args\"].get(\"agent_obsk\",\n                                                 None)  # if None, fully observable else k>=0 implies observe nearest k agents or joints\n        self.agent_obsk_agents = kwargs[\"env_args\"].get(\"agent_obsk_agents\",\n                                                        False)  # observe full k nearest agents (True) or just single joints (False)\n\n        if self.agent_obsk is not None:\n            self.k_categories_label = kwargs[\"env_args\"].get(\"k_categories\")\n            if self.k_categories_label is None:\n                if self.scenario in [\"Ant-v2\", \"manyagent_ant\"]:\n                    self.k_categories_label = \"qpos,qvel,cfrc_ext|qpos\"\n                elif self.scenario in [\"Humanoid-v2\", \"HumanoidStandup-v2\"]:\n                    self.k_categories_label = \"qpos,qvel,cfrc_ext,cvel,cinert,qfrc_actuator|qpos\"\n                elif self.scenario in [\"Reacher-v2\"]:\n                    self.k_categories_label = \"qpos,qvel,fingertip_dist|qpos\"\n                elif self.scenario in [\"coupled_half_cheetah\"]:\n                    self.k_categories_label = \"qpos,qvel,ten_J,ten_length,ten_velocity|\"\n                else:\n                    self.k_categories_label = \"qpos,qvel|qpos\"\n\n            k_split = self.k_categories_label.split(\"|\")\n            self.k_categories = [k_split[k if k < len(k_split) else -1].split(\",\") for k in range(self.agent_obsk + 1)]\n\n            self.global_categories_label = kwargs[\"env_args\"].get(\"global_categories\")\n            self.global_categories = 
self.global_categories_label.split(\n                \",\") if self.global_categories_label is not None else []\n\n        if self.agent_obsk is not None:\n            self.k_dicts = [get_joints_at_kdist(agent_id,\n                                                self.agent_partitions,\n                                                self.mujoco_edges,\n                                                k=self.agent_obsk,\n                                                kagents=False, ) for agent_id in range(self.n_agents)]\n\n        # load scenario from script\n        self.episode_limit = self.args.episode_limit\n\n        self.env_version = kwargs[\"env_args\"].get(\"env_version\", 2)\n        if self.env_version == 2:\n            try:\n                self.wrapped_env = NormalizedActions(gym.make(self.scenario))\n            except gym.error.Error:\n                self.wrapped_env = NormalizedActions(\n                    TimeLimit(partial(env_REGISTRY[self.scenario], **kwargs[\"env_args\"])(),\n                              max_episode_steps=self.episode_limit))\n        else:\n            assert False, \"not implemented!\"\n        self.timelimit_env = self.wrapped_env.env\n        self.timelimit_env._max_episode_steps = self.episode_limit\n        self.env = self.timelimit_env.env\n        self.timelimit_env.reset()\n        self.obs_size = self.get_obs_size()\n        self.share_obs_size = self.get_state_size()\n\n        # COMPATIBILITY\n        self.n = self.n_agents\n        # self.observation_space = [Box(low=np.array([-10]*self.n_agents), high=np.array([10]*self.n_agents)) for _ in range(self.n_agents)]\n        self.observation_space = [Box(low=-10, high=10, shape=(self.obs_size,)) for _ in range(self.n_agents)]\n        self.share_observation_space = [Box(low=-10, high=10, shape=(self.share_obs_size,)) for _ in\n                                        range(self.n_agents)]\n\n        acdims = [len(ap) for ap in self.agent_partitions]\n        self.action_space = tuple([Box(self.env.action_space.low[sum(acdims[:a]):sum(acdims[:a + 1])],\n                                       self.env.action_space.high[sum(acdims[:a]):sum(acdims[:a + 1])]) for a in\n                                   range(self.n_agents)])\n\n        pass\n\n    def step(self, actions):\n\n        # need to remove dummy actions that arise due to unequal action vector sizes across agents\n        flat_actions = np.concatenate([actions[i][:self.action_space[i].low.shape[0]] for i in range(self.n_agents)])\n        obs_n, reward_n, done_n, info_n = self.wrapped_env.step(flat_actions)\n        self.steps += 1\n\n        info = {}\n        info.update(info_n)\n\n        if done_n:\n            if self.steps < self.episode_limit:\n                info[\"bad_transition\"] = False  # genuine termination before the time limit\n            else:\n                info[\"bad_transition\"] = True  # episode cut off by the time limit\n\n        rewards = [[reward_n]] * self.n_agents\n        dones = [done_n] * self.n_agents\n        infos = [info for _ in range(self.n_agents)]\n        return self.get_obs(), self.get_state(), rewards, dones, infos, self.get_avail_actions()\n\n    def get_obs(self):\n        \"\"\" 
Returns all agent observations in a list \"\"\"\n        state = self.env._get_obs()\n        obs_n = []\n        for a in range(self.n_agents):\n            agent_id_feats = np.zeros(self.n_agents, dtype=np.float32)\n            agent_id_feats[a] = 1.0\n            # obs_n.append(self.get_obs_agent(a))\n            # obs_n.append(np.concatenate([state, self.get_obs_agent(a), agent_id_feats]))\n            # obs_n.append(np.concatenate([self.get_obs_agent(a), agent_id_feats]))\n            obs_i = np.concatenate([state, agent_id_feats])\n            obs_i = (obs_i - np.mean(obs_i)) / np.std(obs_i)\n            obs_n.append(obs_i)\n        return obs_n\n\n    def get_obs_agent(self, agent_id):\n        if self.agent_obsk is None:\n            return self.env._get_obs()\n        else:\n            # return build_obs(self.env,\n            #                       self.k_dicts[agent_id],\n            #                       self.k_categories,\n            #                       self.mujoco_globals,\n            #                       self.global_categories,\n            #                       vec_len=getattr(self, \"obs_size\", None))\n            return build_obs(self.env,\n                             self.k_dicts[agent_id],\n                             self.k_categories,\n                             self.mujoco_globals,\n                             self.global_categories)\n\n    def get_obs_size(self):\n        \"\"\" Returns the shape of the observation \"\"\"\n        if self.agent_obsk is None:\n            return self.get_obs_agent(0).size\n        else:\n            return len(self.get_obs()[0])\n            # return max([len(self.get_obs_agent(agent_id)) for agent_id in range(self.n_agents)])\n\n    def get_state(self, team=None):\n        # TODO: May want global states for different teams (so cannot see what the other team is communicating e.g.)\n        state = self.env._get_obs()\n        share_obs = []\n        for a in range(self.n_agents):\n            agent_id_feats = np.zeros(self.n_agents, dtype=np.float32)\n            agent_id_feats[a] = 1.0\n            # share_obs.append(np.concatenate([state, self.get_obs_agent(a), agent_id_feats]))\n            state_i = np.concatenate([state, agent_id_feats])\n            state_i = (state_i - np.mean(state_i)) / np.std(state_i)\n            share_obs.append(state_i)\n        return share_obs\n\n    def get_state_size(self):\n        \"\"\" Returns the shape of the state\"\"\"\n        return len(self.get_state()[0])\n\n    def get_avail_actions(self):  # all actions are always available\n        return np.ones(shape=(self.n_agents, self.n_actions,))\n\n    def get_avail_agent_actions(self, agent_id):\n        \"\"\" Returns the available actions for agent_id \"\"\"\n        return np.ones(shape=(self.n_actions,))\n\n    def get_total_actions(self):\n        \"\"\" Returns the total number of actions an agent could ever take \"\"\"\n        return self.n_actions  # CAREFUL! 
- for continuous dims, this is the action-space dimension rather than the number of discrete actions\n        # return self.env.action_space.shape[0]\n\n    def get_stats(self):\n        return {}\n\n    # TODO: Temp hack\n    def get_agg_stats(self, stats):\n        return {}\n\n    def reset(self, **kwargs):\n        \"\"\" Returns initial observations and states\"\"\"\n        self.steps = 0\n        self.timelimit_env.reset()\n        return self.get_obs(), self.get_state(), self.get_avail_actions()\n\n    def render(self, **kwargs):\n        self.env.render(**kwargs)\n\n    def close(self):\n        pass\n\n    def seed(self, args):\n        pass\n\n    def get_env_info(self):\n\n        env_info = {\"state_shape\": self.get_state_size(),\n                    \"obs_shape\": self.get_obs_size(),\n                    \"n_actions\": self.get_total_actions(),\n                    \"n_agents\": self.n_agents,\n                    \"episode_limit\": self.episode_limit,\n                    \"action_spaces\": self.action_space,\n                    \"actions_dtype\": np.float32,\n                    \"normalise_actions\": False\n                    }\n        return env_info\n"
  },
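  {
    "path": "examples/mujoco_multi_demo.py",
    "content": "\"\"\"Hypothetical usage sketch, not part of the original codebase: one\nreset/step round trip through MujocoMulti. Assumes mujoco-py and gym's\nHalfCheetah-v2 are available; the env_args keys mirror those read in\nMujocoMulti.__init__, and episode_limit is an assumed value.\"\"\"\nimport numpy as np\n\nfrom envs.ma_mujoco.multiagent_mujoco.mujoco_multi import MujocoMulti\n\nenv_args = {\n    \"scenario\": \"HalfCheetah-v2\",  # resolved via gym.make\n    \"agent_conf\": \"2x3\",           # 2 agents, 3 joints each\n    \"agent_obsk\": 0,               # k-hop observability; 0 = own joints only\n    \"episode_limit\": 1000,\n}\nenv = MujocoMulti(env_args=env_args)\nobs, state, avail_actions = env.reset()\n\n# Every agent submits a vector padded to the largest per-agent action dim\n# (env.n_actions); step() slices each back down to the agent's true space.\nactions = [np.random.uniform(-1.0, 1.0, env.n_actions) for _ in range(env.n_agents)]\nobs, state, rewards, dones, infos, avail_actions = env.step(actions)\nprint(env.get_env_info()[\"n_agents\"], rewards)\n"
  },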
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/multiagentenv.py",
    "content": "from collections import namedtuple\nimport numpy as np\n\n\ndef convert(dictionary):\n    return namedtuple('GenericDict', dictionary.keys())(**dictionary)\n\nclass MultiAgentEnv(object):\n\n    def __init__(self, batch_size=None, **kwargs):\n        # Unpack arguments from sacred\n        args = kwargs[\"env_args\"]\n        if isinstance(args, dict):\n            args = convert(args)\n        self.args = args\n\n        if getattr(args, \"seed\", None) is not None:\n            self.seed = args.seed\n            self.rs = np.random.RandomState(self.seed) # initialise numpy random state\n\n    def step(self, actions):\n        \"\"\" Returns reward, terminated, info \"\"\"\n        raise NotImplementedError\n\n    def get_obs(self):\n        \"\"\" Returns all agent observations in a list \"\"\"\n        raise NotImplementedError\n\n    def get_obs_agent(self, agent_id):\n        \"\"\" Returns observation for agent_id \"\"\"\n        raise NotImplementedError\n\n    def get_obs_size(self):\n        \"\"\" Returns the shape of the observation \"\"\"\n        raise NotImplementedError\n\n    def get_state(self):\n        raise NotImplementedError\n\n    def get_state_size(self):\n        \"\"\" Returns the shape of the state\"\"\"\n        raise NotImplementedError\n\n    def get_avail_actions(self):\n        raise NotImplementedError\n\n    def get_avail_agent_actions(self, agent_id):\n        \"\"\" Returns the available actions for agent_id \"\"\"\n        raise NotImplementedError\n\n    def get_total_actions(self):\n        \"\"\" Returns the total number of actions an agent could ever take \"\"\"\n        # TODO: This is only suitable for a discrete 1 dimensional action space for each agent\n        raise NotImplementedError\n\n    def get_stats(self):\n        raise NotImplementedError\n\n    # TODO: Temp hack\n    def get_agg_stats(self, stats):\n        return {}\n\n    def reset(self):\n        \"\"\" Returns initial observations and states\"\"\"\n        raise NotImplementedError\n\n    def render(self):\n        raise NotImplementedError\n\n    def close(self):\n        raise NotImplementedError\n\n    def seed(self, seed):\n        raise NotImplementedError\n\n    def get_env_info(self):\n        env_info = {\"state_shape\": self.get_state_size(),\n                    \"obs_shape\": self.get_obs_size(),\n                    \"n_actions\": self.get_total_actions(),\n                    \"n_agents\": self.n_agents,\n                    \"episode_limit\": self.episode_limit}\n        return env_info"
  },
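  {
    "path": "examples/env_args_convert_demo.py",
    "content": "\"\"\"Hypothetical sketch, not part of the original codebase: shows what\nmultiagentenv.convert does. It freezes a plain env_args dict into an\nattribute-style namedtuple, which is why MultiAgentEnv subclasses read\nself.args.episode_limit rather than indexing a dict.\"\"\"\nfrom envs.ma_mujoco.multiagent_mujoco.multiagentenv import convert\n\nargs = convert({\"scenario\": \"Ant-v2\", \"agent_conf\": \"2x4\", \"episode_limit\": 1000})\nprint(args.scenario, args.agent_conf, args.episode_limit)  # attribute access\n"
  },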
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/obsk.py",
    "content": "import itertools\nimport numpy as np\nfrom copy import deepcopy\n\nclass Node():\n    def __init__(self, label, qpos_ids, qvel_ids, act_ids, body_fn=None, bodies=None, extra_obs=None, tendons=None):\n        self.label = label\n        self.qpos_ids = qpos_ids\n        self.qvel_ids = qvel_ids\n        self.act_ids = act_ids\n        self.bodies = bodies\n        self.extra_obs = {} if extra_obs is None else extra_obs\n        self.body_fn = body_fn\n        self.tendons = tendons\n        pass\n\n    def __str__(self):\n        return self.label\n\n    def __repr__(self):\n        return self.label\n\n\nclass HyperEdge():\n    def __init__(self, *edges):\n        self.edges = set(edges)\n\n    def __contains__(self, item):\n        return item in self.edges\n\n    def __str__(self):\n        return \"HyperEdge({})\".format(self.edges)\n\n    def __repr__(self):\n        return \"HyperEdge({})\".format(self.edges)\n\n\ndef get_joints_at_kdist(agent_id, agent_partitions, hyperedges, k=0, kagents=False,):\n    \"\"\" Identify all joints at distance <= k from agent agent_id\n\n    :param agent_id: id of agent to be considered\n    :param agent_partitions: list of joint tuples in order of agentids\n    :param edges: list of tuples (joint1, joint2)\n    :param k: kth degree\n    :param kagents: True (observe all joints of an agent if a single one is) or False (individual joint granularity)\n    :return:\n        dict with k as key, and list of joints at that distance\n    \"\"\"\n    assert not kagents, \"kagents not implemented!\"\n\n    agent_joints = agent_partitions[agent_id]\n\n    def _adjacent(lst, kagents=False):\n        # return all sets adjacent to any element in lst\n        ret = set([])\n        for l in lst:\n            ret = ret.union(set(itertools.chain(*[e.edges.difference({l}) for e in hyperedges if l in e])))\n        return ret\n\n    seen = set([])\n    new = set([])\n    k_dict = {}\n    for _k in range(k+1):\n        if not _k:\n            new = set(agent_joints)\n        else:\n            print(hyperedges)\n            new = _adjacent(new) - seen\n        seen = seen.union(new)\n        k_dict[_k] = sorted(list(new), key=lambda x:x.label)\n    return k_dict\n\n\ndef build_obs(env, k_dict, k_categories, global_dict, global_categories, vec_len=None):\n    \"\"\"Given a k_dict from get_joints_at_kdist, extract observation vector.\n\n    :param k_dict: k_dict\n    :param qpos: qpos numpy array\n    :param qvel: qvel numpy array\n    :param vec_len: if None no padding, else zero-pad to vec_len\n    :return:\n    observation vector\n    \"\"\"\n\n    # TODO: This needs to be fixed, it was designed for half-cheetah only!\n    #if add_global_pos:\n    #    obs_qpos_lst.append(global_qpos)\n    #    obs_qvel_lst.append(global_qvel)\n\n\n    body_set_dict = {}\n    obs_lst = []\n    # Add parts attributes\n    for k in sorted(list(k_dict.keys())):\n        cats = k_categories[k]\n        for _t in k_dict[k]:\n            for c in cats:\n                if c in _t.extra_obs:\n                    items = _t.extra_obs[c](env).tolist()\n                    obs_lst.extend(items if isinstance(items, list) else [items])\n                else:\n                    if c in [\"qvel\",\"qpos\"]: # this is a \"joint position/velocity\" item\n                        items = getattr(env.sim.data, c)[getattr(_t, \"{}_ids\".format(c))]\n                        obs_lst.extend(items if isinstance(items, list) else [items])\n                    elif c in [\"qfrc_actuator\"]: # 
this is a \"vel position\" item\n                        items = getattr(env.sim.data, c)[getattr(_t, \"{}_ids\".format(\"qvel\"))]\n                        obs_lst.extend(items if isinstance(items, list) else [items])\n                    elif c in [\"cvel\", \"cinert\", \"cfrc_ext\"]:  # this is a \"body position\" item\n                        if _t.bodies is not None:\n                            for b in _t.bodies:\n                                if c not in body_set_dict:\n                                    body_set_dict[c] = set()\n                                if b not in body_set_dict[c]:\n                                    items = getattr(env.sim.data, c)[b].tolist()\n                                    items = getattr(_t, \"body_fn\", lambda _id,x:x)(b, items)\n                                    obs_lst.extend(items if isinstance(items, list) else [items])\n                                    body_set_dict[c].add(b)\n\n    # Add global attributes\n    body_set_dict = {}\n    for c in global_categories:\n        if c in [\"qvel\", \"qpos\"]:  # this is a \"joint position\" item\n            for j in global_dict.get(\"joints\", []):\n                items = getattr(env.sim.data, c)[getattr(j, \"{}_ids\".format(c))]\n                obs_lst.extend(items if isinstance(items, list) else [items])\n        else:\n            for b in global_dict.get(\"bodies\", []):\n                if c not in body_set_dict:\n                    body_set_dict[c] = set()\n                if b not in body_set_dict[c]:\n                    obs_lst.extend(getattr(env.sim.data, c)[b].tolist())\n                    body_set_dict[c].add(b)\n\n    if vec_len is not None:\n        pad = np.array((vec_len - len(obs_lst))*[0])\n        if len(pad):\n            return np.concatenate([np.array(obs_lst), pad])\n    return np.array(obs_lst)\n\n\ndef build_actions(agent_partitions, k_dict):\n    # Composes agent actions output from networks\n    # into coherent joint action vector to be sent to the env.\n    pass\n\ndef get_parts_and_edges(label, partitioning):\n    if label in [\"half_cheetah\", \"HalfCheetah-v2\"]:\n\n        # define Mujoco graph\n        bthigh = Node(\"bthigh\", -6, -6, 0)\n        bshin = Node(\"bshin\", -5, -5, 1)\n        bfoot = Node(\"bfoot\", -4, -4, 2)\n        fthigh = Node(\"fthigh\", -3, -3, 3)\n        fshin = Node(\"fshin\", -2, -2, 4)\n        ffoot = Node(\"ffoot\", -1, -1, 5)\n\n        edges = [HyperEdge(bfoot, bshin),\n                 HyperEdge(bshin, bthigh),\n                 HyperEdge(bthigh, fthigh),\n                 HyperEdge(fthigh, fshin),\n                 HyperEdge(fshin, ffoot)]\n\n        root_x = Node(\"root_x\", 0, 0, -1,\n                      extra_obs={\"qpos\": lambda env: np.array([])})\n        root_z = Node(\"root_z\", 1, 1, -1)\n        root_y = Node(\"root_y\", 2, 2, -1)\n        globals = {\"joints\":[root_x, root_y, root_z]}\n\n        if partitioning == \"2x3\":\n            parts = [(bfoot, bshin, bthigh),\n                     (ffoot, fshin, fthigh)]\n        elif partitioning == \"6x1\":\n            parts = [(bfoot,), (bshin,), (bthigh,), (ffoot,), (fshin,), (fthigh,)]\n        elif partitioning == \"3x2\":\n            parts = [(bfoot, bshin,), (bthigh, ffoot,), (fshin, fthigh,)]\n        else:\n            raise Exception(\"UNKNOWN partitioning config: {}\".format(partitioning))\n\n        return parts, edges, globals\n\n    elif label in [\"Ant-v2\"]:\n\n        # define Mujoco graph\n        torso = 1\n        front_left_leg = 2\n        
aux_1 = 3\n        ankle_1 = 4\n        front_right_leg = 5\n        aux_2 = 6\n        ankle_2 = 7\n        back_leg = 8\n        aux_3 = 9\n        ankle_3 = 10\n        right_back_leg = 11\n        aux_4 = 12\n        ankle_4 = 13\n\n        hip1 = Node(\"hip1\", -8, -8, 2, bodies=[torso, front_left_leg], body_fn=lambda _id, x:np.clip(x, -1, 1).tolist()) #\n        ankle1 = Node(\"ankle1\", -7, -7, 3, bodies=[front_left_leg, aux_1, ankle_1], body_fn=lambda _id, x:np.clip(x, -1, 1).tolist())#,\n        hip2 = Node(\"hip2\", -6, -6, 4, bodies=[torso, front_right_leg], body_fn=lambda _id, x:np.clip(x, -1, 1).tolist())#,\n        ankle2 = Node(\"ankle2\", -5, -5, 5, bodies=[front_right_leg, aux_2, ankle_2], body_fn=lambda _id, x:np.clip(x, -1, 1).tolist())#,\n        hip3 = Node(\"hip3\", -4, -4, 6, bodies=[torso, back_leg], body_fn=lambda _id, x:np.clip(x, -1, 1).tolist())#,\n        ankle3 = Node(\"ankle3\", -3, -3, 7, bodies=[back_leg, aux_3, ankle_3], body_fn=lambda _id, x:np.clip(x, -1, 1).tolist())#,\n        hip4 = Node(\"hip4\", -2, -2, 0, bodies=[torso, right_back_leg], body_fn=lambda _id, x:np.clip(x, -1, 1).tolist())#,\n        ankle4 = Node(\"ankle4\", -1, -1, 1, bodies=[right_back_leg, aux_4, ankle_4], body_fn=lambda _id, x:np.clip(x, -1, 1).tolist())#,\n\n        edges = [HyperEdge(ankle4, hip4),\n                 HyperEdge(ankle1, hip1),\n                 HyperEdge(ankle2, hip2),\n                 HyperEdge(ankle3, hip3),\n                 HyperEdge(hip4, hip1, hip2, hip3),\n                 ]\n\n        free_joint = Node(\"free\", 0, 0, -1, extra_obs={\"qpos\": lambda env: env.sim.data.qpos[:7],\n                                                       \"qvel\": lambda env: env.sim.data.qvel[:6],\n                                                       \"cfrc_ext\": lambda env: np.clip(env.sim.data.cfrc_ext[0:1], -1, 1)})\n        globals = {\"joints\": [free_joint]}\n\n        if partitioning == \"2x4\": # neighbouring legs together\n            parts = [(hip1, ankle1, hip2, ankle2),\n                     (hip3, ankle3, hip4, ankle4)]\n        elif partitioning == \"2x4d\": # diagonal legs together\n            parts = [(hip1, ankle1, hip3, ankle3),\n                     (hip2, ankle2, hip4, ankle4)]\n        elif partitioning == \"4x2\":\n            parts = [(hip1, ankle1),\n                     (hip2, ankle2),\n                     (hip3, ankle3),\n                     (hip4, ankle4)]\n        elif partitioning == \"8x1\":\n            parts = [(hip1,), (ankle1,),\n                     (hip2,), (ankle2,),\n                     (hip3,), (ankle3,),\n                     (hip4,), (ankle4,)]\n        else:\n            raise Exception(\"UNKNOWN partitioning config: {}\".format(partitioning))\n\n        return parts, edges, globals\n\n    elif label in [\"Hopper-v2\"]:\n\n        # define Mujoco-Graph\n        thigh_joint = Node(\"thigh_joint\", -3, -3, 0,\n                           extra_obs={\"qvel\": lambda env: np.clip(np.array([env.sim.data.qvel[-3]]), -10, 10)})\n        leg_joint = Node(\"leg_joint\", -2, -2, 1,\n                         extra_obs={\"qvel\": lambda env: np.clip(np.array([env.sim.data.qvel[-2]]), -10, 10)})\n        foot_joint = Node(\"foot_joint\", -1, -1, 2,\n                          extra_obs={\"qvel\": lambda env: np.clip(np.array([env.sim.data.qvel[-1]]), -10, 10)})\n\n        edges = [HyperEdge(foot_joint, leg_joint),\n                 HyperEdge(leg_joint, thigh_joint)]\n\n        root_x = Node(\"root_x\", 0, 0, -1, extra_obs={\"qpos\": lambda env: 
np.array([]),\n                                                     \"qvel\": lambda env: np.clip(np.array([env.sim.data.qvel[0]]), -10, 10)})\n        root_z = Node(\"root_z\", 1, 1, -1, extra_obs={\"qvel\": lambda env: np.clip(np.array([env.sim.data.qvel[1]]), -10, 10)})\n        root_y = Node(\"root_y\", 2, 2, -1, extra_obs={\"qvel\": lambda env: np.clip(np.array([env.sim.data.qvel[2]]), -10, 10)})\n        globals = {\"joints\":[root_x, root_y, root_z]}\n\n        if partitioning == \"3x1\":\n            parts = [(thigh_joint,),\n                     (leg_joint,),\n                     (foot_joint,)]\n\n        else:\n            raise Exception(\"UNKNOWN partitioning config: {}\".format(partitioning))\n\n        return parts, edges, globals\n\n    elif label in [\"Humanoid-v2\", \"HumanoidStandup-v2\"]:\n\n        # define Mujoco-Graph\n        abdomen_y = Node(\"abdomen_y\", -16, -16, 0) # act ordering bug in env -- double check!\n        abdomen_z = Node(\"abdomen_z\", -17, -17, 1)\n        abdomen_x = Node(\"abdomen_x\", -15, -15, 2)\n        right_hip_x = Node(\"right_hip_x\", -14, -14, 3)\n        right_hip_z = Node(\"right_hip_z\", -13, -13, 4)\n        right_hip_y = Node(\"right_hip_y\", -12, -12, 5)\n        right_knee = Node(\"right_knee\", -11, -11, 6)\n        left_hip_x = Node(\"left_hip_x\", -10, -10, 7)\n        left_hip_z = Node(\"left_hip_z\", -9, -9, 8)\n        left_hip_y = Node(\"left_hip_y\", -8, -8, 9)\n        left_knee = Node(\"left_knee\", -7, -7, 10)\n        right_shoulder1 = Node(\"right_shoulder1\", -6, -6, 11)\n        right_shoulder2 = Node(\"right_shoulder2\", -5, -5, 12)\n        right_elbow = Node(\"right_elbow\", -4, -4, 13)\n        left_shoulder1 = Node(\"left_shoulder1\", -3, -3, 14)\n        left_shoulder2 = Node(\"left_shoulder2\", -2, -2, 15)\n        left_elbow = Node(\"left_elbow\", -1, -1, 16)\n\n        edges = [HyperEdge(abdomen_x, abdomen_y, abdomen_z),\n                 HyperEdge(right_hip_x, right_hip_y, right_hip_z),\n                 HyperEdge(left_hip_x, left_hip_y, left_hip_z),\n                 HyperEdge(left_elbow, left_shoulder1, left_shoulder2),\n                 HyperEdge(right_elbow, right_shoulder1, right_shoulder2),\n                 HyperEdge(left_knee, left_hip_x, left_hip_y, left_hip_z),\n                 HyperEdge(right_knee, right_hip_x, right_hip_y, right_hip_z),\n                 HyperEdge(left_shoulder1, left_shoulder2, abdomen_x, abdomen_y, abdomen_z),\n                 HyperEdge(right_shoulder1, right_shoulder2, abdomen_x, abdomen_y, abdomen_z),\n                 HyperEdge(abdomen_x, abdomen_y, abdomen_z, left_hip_x, left_hip_y, left_hip_z),\n                 HyperEdge(abdomen_x, abdomen_y, abdomen_z, right_hip_x, right_hip_y, right_hip_z),\n                 ]\n\n        globals = {}\n\n        if partitioning == \"9|8\": # 17 in total, so one action is a dummy (to be handled by pymarl)\n            # isolate upper and lower body\n            parts = [(left_shoulder1, left_shoulder2, abdomen_x, abdomen_y, abdomen_z,\n                      right_shoulder1, right_shoulder2,\n                      right_elbow, left_elbow),\n                     (left_hip_x, left_hip_y, left_hip_z,\n                      right_hip_x, right_hip_y, right_hip_z,\n                      right_knee, left_knee)]\n            # TODO: There could be tons of decompositions here\n        elif partitioning == \"17x1\": # 17 in total, so one action is a dummy (to be handled by pymarl)\n            # one agent per joint\n            parts = 
[(left_shoulder1,), (left_shoulder2,), (abdomen_x,), (abdomen_y,), (abdomen_z,),\n                     (right_shoulder1,), (right_shoulder2,), (right_elbow,), (left_elbow,),\n                     (left_hip_x,), (left_hip_y,), (left_hip_z,), (right_hip_x,), (right_hip_y,), (right_hip_z,),\n                     (right_knee,), (left_knee,)]\n        else:\n            raise Exception(\"UNKNOWN partitioning config: {}\".format(partitioning))\n\n        return parts, edges, globals\n\n    elif label in [\"Reacher-v2\"]:\n\n        # define Mujoco-Graph\n        body0 = 1\n        body1 = 2\n        fingertip = 3\n        joint0 = Node(\"joint0\", -4, -4, 0,\n                      bodies=[body0, body1],\n                      extra_obs={\"qpos\":(lambda env:np.array([np.sin(env.sim.data.qpos[-4]),\n                                                              np.cos(env.sim.data.qpos[-4])]))})\n        joint1 = Node(\"joint1\", -3, -3, 1,\n                      bodies=[body1, fingertip],\n                      extra_obs={\"fingertip_dist\":(lambda env:env.get_body_com(\"fingertip\") - env.get_body_com(\"target\")),\n                                 \"qpos\":(lambda env:np.array([np.sin(env.sim.data.qpos[-3]),\n                                                              np.cos(env.sim.data.qpos[-3])]))})\n        edges = [HyperEdge(joint0, joint1)]\n\n        worldbody = 0\n        target = 4\n        target_x = Node(\"target_x\", -2, -2, -1, extra_obs={\"qvel\":(lambda env:np.array([]))})\n        target_y = Node(\"target_y\", -1, -1, -1, extra_obs={\"qvel\":(lambda env:np.array([]))})\n        globals = {\"bodies\":[worldbody, target],\n                   \"joints\":[target_x, target_y]}\n\n        if partitioning == \"2x1\":\n            # isolate upper and lower arms\n            parts = [(joint0,), (joint1,)]\n            # TODO: There could be tons of decompositions here\n\n        else:\n            raise Exception(\"UNKNOWN partitioning config: {}\".format(partitioning))\n\n        return parts, edges, globals\n\n    elif label in [\"Swimmer-v2\"]:\n\n        # define Mujoco-Graph\n        joint0 = Node(\"rot2\", -2, -2, 0) # TODO: double-check ids\n        joint1 = Node(\"rot3\", -1, -1, 1)\n\n        edges = [HyperEdge(joint0, joint1)]\n        globals = {}\n\n        if partitioning == \"2x1\":\n            # isolate upper and lower body\n            parts = [(joint0,), (joint1,)]\n            # TODO: There could be tons of decompositions here\n\n        else:\n            raise Exception(\"UNKNOWN partitioning config: {}\".format(partitioning))\n\n        return parts, edges, globals\n\n    elif label in [\"Walker2d-v2\"]:\n\n        # define Mujoco-Graph\n        thigh_joint = Node(\"thigh_joint\", -6, -6, 0)\n        leg_joint = Node(\"leg_joint\", -5, -5, 1)\n        foot_joint = Node(\"foot_joint\", -4, -4, 2)\n        thigh_left_joint = Node(\"thigh_left_joint\", -3, -3, 3)\n        leg_left_joint = Node(\"leg_left_joint\", -2, -2, 4)\n        foot_left_joint = Node(\"foot_left_joint\", -1, -1, 5)\n\n        edges = [HyperEdge(foot_joint, leg_joint),\n                 HyperEdge(leg_joint, thigh_joint),\n                 HyperEdge(foot_left_joint, leg_left_joint),\n                 HyperEdge(leg_left_joint, thigh_left_joint),\n                 HyperEdge(thigh_joint, thigh_left_joint)\n                 ]\n        globals = {}\n\n        if partitioning == \"2x3\":\n            # isolate upper and lower body\n            parts = [(foot_joint, leg_joint, thigh_joint),\n             
        (foot_left_joint, leg_left_joint, thigh_left_joint,)]\n            # TODO: There could be tons of decompositions here\n        elif partitioning == \"6x1\":\n            # isolate upper and lower body\n            parts = [(foot_joint,), (leg_joint,), (thigh_joint,),\n                     (foot_left_joint,), (leg_left_joint,), (thigh_left_joint,)]\n        elif partitioning == \"3x2\":\n            # isolate upper and lower body\n            parts = [(foot_joint, leg_joint,), (thigh_joint, foot_left_joint,),\n                     (leg_left_joint, thigh_left_joint,)]\n        else:\n            raise Exception(\"UNKNOWN partitioning config: {}\".format(partitioning))\n\n        return parts, edges, globals\n\n    elif label in [\"coupled_half_cheetah\"]:\n\n        # define Mujoco graph\n        tendon = 0\n\n        bthigh = Node(\"bthigh\", -6, -6, 0,\n                     tendons=[tendon],\n                     extra_obs = {\"ten_J\": lambda env: env.sim.data.ten_J[tendon],\n                                  \"ten_length\": lambda env: env.sim.data.ten_length,\n                                  \"ten_velocity\": lambda env: env.sim.data.ten_velocity})\n        bshin = Node(\"bshin\", -5, -5, 1)\n        bfoot = Node(\"bfoot\", -4, -4, 2)\n        fthigh = Node(\"fthigh\", -3, -3, 3)\n        fshin = Node(\"fshin\", -2, -2, 4)\n        ffoot = Node(\"ffoot\", -1, -1, 5)\n\n        bthigh2 = Node(\"bthigh2\", -6, -6, 0,\n                      tendons=[tendon],\n                      extra_obs={\"ten_J\": lambda env: env.sim.data.ten_J[tendon],\n                                 \"ten_length\": lambda env: env.sim.data.ten_length,\n                                 \"ten_velocity\": lambda env: env.sim.data.ten_velocity})\n        bshin2 = Node(\"bshin2\", -5, -5, 1)\n        bfoot2 = Node(\"bfoot2\", -4, -4, 2)\n        fthigh2 = Node(\"fthigh2\", -3, -3, 3)\n        fshin2 = Node(\"fshin2\", -2, -2, 4)\n        ffoot2 = Node(\"ffoot2\", -1, -1, 5)\n\n\n        edges = [HyperEdge(bfoot, bshin),\n                 HyperEdge(bshin, bthigh),\n                 HyperEdge(bthigh, fthigh),\n                 HyperEdge(fthigh, fshin),\n                 HyperEdge(fshin, ffoot),\n                 HyperEdge(bfoot2, bshin2),\n                 HyperEdge(bshin2, bthigh2),\n                 HyperEdge(bthigh2, fthigh2),\n                 HyperEdge(fthigh2, fshin2),\n                 HyperEdge(fshin2, ffoot2)\n                 ]\n        globals = {}\n\n        root_x = Node(\"root_x\", 0, 0, -1,\n                      extra_obs={\"qpos\": lambda env: np.array([])})\n        root_z = Node(\"root_z\", 1, 1, -1)\n        root_y = Node(\"root_y\", 2, 2, -1)\n        globals = {\"joints\":[root_x, root_y, root_z]}\n\n        if partitioning == \"1p1\":\n            parts = [(bfoot, bshin, bthigh, ffoot, fshin, fthigh),\n                     (bfoot2, bshin2, bthigh2, ffoot2, fshin2, fthigh2)\n                     ]\n        else:\n            raise Exception(\"UNKNOWN partitioning config: {}\".format(partitioning))\n\n        return parts, edges, globals\n\n    elif label in [\"manyagent_swimmer\"]:\n\n        # Generate asset file\n        try:\n            n_agents = int(partitioning.split(\"x\")[0])\n            n_segs_per_agents = int(partitioning.split(\"x\")[1])\n            n_segs = n_agents * n_segs_per_agents\n        except Exception as e:\n            raise Exception(\"UNKNOWN partitioning config: {}\".format(partitioning))\n\n        # Note: Default Swimmer corresponds to n_segs = 3\n\n        # 
define Mujoco-Graph\n        joints = [Node(\"rot{:d}\".format(i), -n_segs + i, -n_segs + i, i) for i in range(0, n_segs)]\n        edges = [HyperEdge(joints[i], joints[i+1]) for i in range(n_segs-1)]\n        globals = {}\n\n        parts = [tuple(joints[i * n_segs_per_agents:(i + 1) * n_segs_per_agents]) for i in range(n_agents)]\n        return parts, edges, globals\n\n    elif label in [\"manyagent_ant\"]: # TODO: FIX!\n\n        # Generate asset file\n        try:\n            n_agents = int(partitioning.split(\"x\")[0])\n            n_segs_per_agents = int(partitioning.split(\"x\")[1])\n            n_segs = n_agents * n_segs_per_agents\n        except Exception as e:\n            raise Exception(\"UNKNOWN partitioning config: {}\".format(partitioning))\n\n\n        # # define Mujoco graph\n        # torso = 1\n        # front_left_leg = 2\n        # aux_1 = 3\n        # ankle_1 = 4\n        # right_back_leg = 11\n        # aux_4 = 12\n        # ankle_4 = 13\n        #\n        # off = -4*(n_segs-1)\n        # hip1 = Node(\"hip1\", -4-off, -4-off, 2, bodies=[torso, front_left_leg], body_fn=lambda _id, x:np.clip(x, -1, 1).tolist()) #\n        # ankle1 = Node(\"ankle1\", -3-off, -3-off, 3, bodies=[front_left_leg, aux_1, ankle_1], body_fn=lambda _id, x:np.clip(x, -1, 1).tolist())#,\n        # hip4 = Node(\"hip4\", -2-off, -2-off, 0, bodies=[torso, right_back_leg], body_fn=lambda _id, x:np.clip(x, -1, 1).tolist())#,\n        # ankle4 = Node(\"ankle4\", -1-off, -1-off, 1, bodies=[right_back_leg, aux_4, ankle_4], body_fn=lambda _id, x:np.clip(x, -1, 1).tolist())#,\n        #\n        # edges = [HyperEdge(ankle4, hip4),\n        #          HyperEdge(ankle1, hip1),\n        #          HyperEdge(hip4, hip1),\n        #          ]\n\n        edges = []\n        joints = []\n        for si in range(n_segs):\n\n            torso = 1 + si*7\n            front_right_leg = 2 + si*7\n            aux1 = 3 + si*7\n            ankle1 = 4 + si*7\n            back_leg = 5 + si*7\n            aux2 = 6 + si*7\n            ankle2 = 7 + si*7\n\n            off = -4 * (n_segs - 1 - si)\n            hip1n = Node(\"hip1_{:d}\".format(si), -4-off, -4-off, 2+4*si, bodies=[torso, front_right_leg], body_fn=lambda _id, x:np.clip(x, -1, 1).tolist())\n            ankle1n = Node(\"ankle1_{:d}\".format(si), -3-off, -3-off, 3+4*si, bodies=[front_right_leg, aux1, ankle1], body_fn=lambda _id, x:np.clip(x, -1, 1).tolist())\n            hip2n = Node(\"hip2_{:d}\".format(si), -2-off, -2-off, 0+4*si, bodies=[torso, back_leg], body_fn=lambda _id, x:np.clip(x, -1, 1).tolist())\n            ankle2n = Node(\"ankle2_{:d}\".format(si), -1-off, -1-off, 1+4*si, bodies=[back_leg, aux2, ankle2], body_fn=lambda _id, x:np.clip(x, -1, 1).tolist())\n\n            edges += [HyperEdge(ankle1n, hip1n),\n                      HyperEdge(ankle2n, hip2n),\n                      HyperEdge(hip1n, hip2n)]\n            if si:\n                edges += [HyperEdge(hip1m, hip2m, hip1n, hip2n)]\n\n            hip1m = deepcopy(hip1n)\n            hip2m = deepcopy(hip2n)\n            joints.append([hip1n,\n                           ankle1n,\n                           hip2n,\n                           ankle2n])\n\n        free_joint = Node(\"free\", 0, 0, -1, extra_obs={\"qpos\": lambda env: env.sim.data.qpos[:7],\n                                                       \"qvel\": lambda env: env.sim.data.qvel[:6],\n                                                       \"cfrc_ext\": lambda env: np.clip(env.sim.data.cfrc_ext[0:1], -1, 1)})\n        globals = 
{\"joints\": [free_joint]}\n\n        parts =  [[x for sublist in joints[i * n_segs_per_agents:(i + 1) * n_segs_per_agents] for x in sublist] for i in range(n_agents)]\n\n        return parts, edges, globals"
  },
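  {
    "path": "examples/obsk_factorisation_demo.py",
    "content": "\"\"\"Hypothetical sketch, not part of the original codebase: how obsk\nfactorises a MuJoCo morphology. get_parts_and_edges returns per-agent joint\npartitions plus the hyperedge adjacency; get_joints_at_kdist walks that graph\nto collect all joints within k hops of one agent's own joints.\"\"\"\nfrom envs.ma_mujoco.multiagent_mujoco.obsk import get_joints_at_kdist, get_parts_and_edges\n\nparts, edges, globals_ = get_parts_and_edges(\"HalfCheetah-v2\", \"2x3\")\nprint([tuple(str(j) for j in part) for part in parts])  # agent 0: back leg, agent 1: front leg\n\nk_dict = get_joints_at_kdist(0, parts, edges, k=1)\nfor k, joints in sorted(k_dict.items()):\n    print(k, [str(j) for j in joints])  # k=0: own joints; k=1: graph neighbours\n"
  },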
  {
    "path": "envs/starcraft2/StarCraft2_Env.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom .multiagentenv import MultiAgentEnv\nfrom .smac_maps import get_map_params\n\nimport atexit\nfrom operator import attrgetter\nfrom copy import deepcopy\nimport numpy as np\nimport enum\nimport math\nfrom absl import logging\n\nfrom pysc2 import maps\nfrom pysc2 import run_configs\nfrom pysc2.lib import protocol\n\nfrom s2clientprotocol import common_pb2 as sc_common\nfrom s2clientprotocol import sc2api_pb2 as sc_pb\nfrom s2clientprotocol import raw_pb2 as r_pb\nfrom s2clientprotocol import debug_pb2 as d_pb\n\nimport random\nfrom gym.spaces import Discrete\n\nraces = {\n    \"R\": sc_common.Random,\n    \"P\": sc_common.Protoss,\n    \"T\": sc_common.Terran,\n    \"Z\": sc_common.Zerg,\n}\n\ndifficulties = {\n    \"1\": sc_pb.VeryEasy,\n    \"2\": sc_pb.Easy,\n    \"3\": sc_pb.Medium,\n    \"4\": sc_pb.MediumHard,\n    \"5\": sc_pb.Hard,\n    \"6\": sc_pb.Harder,\n    \"7\": sc_pb.VeryHard,\n    \"8\": sc_pb.CheatVision,\n    \"9\": sc_pb.CheatMoney,\n    \"A\": sc_pb.CheatInsane,\n}\n\nactions = {\n    \"move\": 16,  # target: PointOrUnit\n    \"attack\": 23,  # target: PointOrUnit\n    \"stop\": 4,  # target: None\n    \"heal\": 386,  # Unit\n}\n\n\nclass Direction(enum.IntEnum):\n    NORTH = 0\n    SOUTH = 1\n    EAST = 2\n    WEST = 3\n\n\nclass StarCraft2Env(MultiAgentEnv):\n    \"\"\"The StarCraft II environment for decentralised multi-agent\n    micromanagement scenarios.\n    \"\"\"\n\n    def __init__(\n        self,\n        args,\n        step_mul=8,\n        move_amount=2,\n        difficulty=\"7\",\n        game_version=None,\n        seed=None,\n        continuing_episode=False,\n        obs_all_health=True,\n        obs_own_health=True,\n        obs_last_action=True,\n        obs_pathing_grid=False,\n        obs_terrain_height=False,\n        obs_instead_of_state=False,\n        obs_timestep_number=False,\n        obs_agent_id=True,\n        state_pathing_grid=False,\n        state_terrain_height=False,\n        state_last_action=True,\n        state_timestep_number=False,\n        state_agent_id=True,\n        reward_sparse=False,\n        reward_only_positive=True,\n        reward_death_value=10,\n        reward_win=200,\n        reward_defeat=0,\n        reward_negative_scale=0.5,\n        reward_scale=True,\n        reward_scale_rate=20,\n        replay_dir=\"\",\n        replay_prefix=\"\",\n        window_size_x=1920,\n        window_size_y=1200,\n        heuristic_ai=False,\n        heuristic_rest=False,\n        debug=False,\n    ):\n        \"\"\"\n        Create a StarCraftC2Env environment.\n\n        Parameters\n        ----------\n        map_name : str, optional\n            The name of the SC2 map to play (default is \"8m\"). The full list\n            can be found by running bin/map_list.\n        step_mul : int, optional\n            How many game steps per agent step (default is 8). None\n            indicates to use the default map step_mul.\n        move_amount : float, optional\n            How far away units are ordered to move per step (default is 2).\n        difficulty : str, optional\n            The difficulty of built-in computer AI bot (default is \"7\").\n        game_version : str, optional\n            StarCraft II game version (default is None). None indicates the\n            latest version.\n        seed : int, optional\n            Random seed used during game initialisation. 
\n        continuing_episode : bool, optional\n            Whether to consider episodes continuing or finished after time\n            limit is reached (default is False).\n        obs_all_health : bool, optional\n            Agents receive the health of all units (in the sight range) as part\n            of observations (default is True).\n        obs_own_health : bool, optional\n            Agents receive their own health as a part of observations (default\n            is False). This flag is ignored when obs_all_health == True.\n        obs_last_action : bool, optional\n            Agents receive the last actions of all units (in the sight range)\n            as part of observations (default is False).\n        obs_pathing_grid : bool, optional\n            Whether observations include pathing values surrounding the agent\n            (default is False).\n        obs_terrain_height : bool, optional\n            Whether observations include terrain height values surrounding the\n            agent (default is False).\n        obs_instead_of_state : bool, optional\n            Use combination of all agents' observations as the global state\n            (default is False).\n        obs_timestep_number : bool, optional\n            Whether observations include the current timestep of the episode\n            (default is False).\n        state_last_action : bool, optional\n            Include the last actions of all agents as part of the global state\n            (default is True).\n        state_timestep_number : bool, optional\n            Whether the state includes the current timestep of the episode\n            (default is False).\n        reward_sparse : bool, optional\n            Receive 1/-1 reward for winning/losing an episode (default is\n            False). The rest of the reward parameters are ignored if True.\n        reward_only_positive : bool, optional\n            Reward is always positive (default is True).\n        reward_death_value : float, optional\n            The amount of reward received for killing an enemy unit (default\n            is 10). This is also the negative penalty for having an allied unit\n            killed if reward_only_positive == False.\n        reward_win : float, optional\n            The reward for winning in an episode (default is 200).\n        reward_defeat : float, optional\n            The reward for losing in an episode (default is 0). This value\n            should be nonpositive.\n        reward_negative_scale : float, optional\n            Scaling factor for negative rewards (default is 0.5). This\n            parameter is ignored when reward_only_positive == True.\n        reward_scale : bool, optional\n            Whether or not to scale the reward (default is True).\n        reward_scale_rate : float, optional\n            Reward scale rate (default is 20). When reward_scale == True, the\n            reward received by the agents is divided by (max_reward /\n            reward_scale_rate), where max_reward is the maximum possible\n            reward per episode without considering the shield regeneration\n            of Protoss units.\n        replay_dir : str, optional\n            The directory to save replays (default is None). If None, the\n            replay will be saved in Replays directory where StarCraft II is\n            installed.\n        replay_prefix : str, optional\n            The prefix of the replay to be saved (default is None). 
If None,\n            the name of the map will be used.\n        window_size_x : int, optional\n            The width of the StarCraft II window (default is 1920).\n        window_size_y: int, optional\n            The height of the StarCraft II window (default is 1200).\n        heuristic_ai: bool, optional\n            Whether or not to use a non-learning heuristic AI (default False).\n        heuristic_rest: bool, optional\n            At any moment, restrict the actions of the heuristic AI to be\n            chosen from actions available to RL agents (default is False).\n            Ignored if heuristic_ai == False.\n        debug: bool, optional\n            Log messages about observations, state, actions and rewards for\n            debugging purposes (default is False).\n        \"\"\"\n        # Map arguments\n        self.map_name = args.map_name\n        self.add_local_obs = args.add_local_obs\n        self.add_move_state = args.add_move_state\n        self.add_visible_state = args.add_visible_state\n        self.add_distance_state = args.add_distance_state\n        self.add_xy_state = args.add_xy_state\n        self.add_enemy_action_state = args.add_enemy_action_state\n        self.add_agent_id = args.add_agent_id\n        self.use_state_agent = args.use_state_agent\n        self.use_mustalive = args.use_mustalive\n        self.add_center_xy = args.add_center_xy\n        self.use_stacked_frames = args.use_stacked_frames\n        self.stacked_frames = args.stacked_frames\n        \n        map_params = get_map_params(self.map_name)\n        self.n_agents = map_params[\"n_agents\"]\n        self.n_enemies = map_params[\"n_enemies\"]\n        self.episode_limit = map_params[\"limit\"]\n        self._move_amount = move_amount\n        self._step_mul = step_mul\n        self.difficulty = difficulty\n\n        # Observations and state\n        self.obs_own_health = obs_own_health\n        self.obs_all_health = obs_all_health\n        self.obs_instead_of_state = args.use_obs_instead_of_state\n        self.obs_last_action = obs_last_action\n\n        self.obs_pathing_grid = obs_pathing_grid\n        self.obs_terrain_height = obs_terrain_height\n        self.obs_timestep_number = obs_timestep_number\n        self.obs_agent_id = obs_agent_id\n        self.state_pathing_grid = state_pathing_grid\n        self.state_terrain_height = state_terrain_height\n        self.state_last_action = state_last_action\n        self.state_timestep_number = state_timestep_number\n        self.state_agent_id = state_agent_id\n        if self.obs_all_health:\n            self.obs_own_health = True\n        self.n_obs_pathing = 8\n        self.n_obs_height = 9\n\n        # Rewards args\n        self.reward_sparse = reward_sparse\n        self.reward_only_positive = reward_only_positive\n        self.reward_negative_scale = reward_negative_scale\n        self.reward_death_value = reward_death_value\n        self.reward_win = reward_win\n        self.reward_defeat = reward_defeat\n\n        self.reward_scale = reward_scale\n        self.reward_scale_rate = reward_scale_rate\n\n        # Other\n        self.game_version = game_version\n        self.continuing_episode = continuing_episode\n        self._seed = seed\n        self.heuristic_ai = heuristic_ai\n        self.heuristic_rest = heuristic_rest\n        self.debug = debug\n        self.window_size = (window_size_x, window_size_y)\n        self.replay_dir = replay_dir\n        self.replay_prefix = replay_prefix\n\n        # Actions\n        
self.n_actions_no_attack = 6\n        self.n_actions_move = 4\n        self.n_actions = self.n_actions_no_attack + self.n_enemies\n\n        # Map info\n        self._agent_race = map_params[\"a_race\"]\n        self._bot_race = map_params[\"b_race\"]\n        self.shield_bits_ally = 1 if self._agent_race == \"P\" else 0\n        self.shield_bits_enemy = 1 if self._bot_race == \"P\" else 0\n        self.unit_type_bits = map_params[\"unit_type_bits\"]\n        self.map_type = map_params[\"map_type\"]\n\n        self.max_reward = (\n            self.n_enemies * self.reward_death_value + self.reward_win\n        )\n\n        self.agents = {}\n        self.enemies = {}\n        self._episode_count = 0\n        self._episode_steps = 0\n        self._total_steps = 0\n        self._obs = None\n        self.battles_won = 0\n        self.battles_game = 0\n        self.timeouts = 0\n        self.force_restarts = 0\n        self.last_stats = None\n        self.death_tracker_ally = np.zeros(self.n_agents, dtype=np.float32)\n        self.death_tracker_enemy = np.zeros(self.n_enemies, dtype=np.float32)\n        self.previous_ally_units = None\n        self.previous_enemy_units = None\n        self.last_action = np.zeros((self.n_agents, self.n_actions), dtype=np.float32)\n        self._min_unit_type = 0\n        self.marine_id = self.marauder_id = self.medivac_id = 0\n        self.hydralisk_id = self.zergling_id = self.baneling_id = 0\n        self.stalker_id = self.colossus_id = self.zealot_id = 0\n        self.max_distance_x = 0\n        self.max_distance_y = 0\n        self.map_x = 0\n        self.map_y = 0\n        self.terrain_height = None\n        self.pathing_grid = None\n        self._run_config = None\n        self._sc2_proc = None\n        self._controller = None\n\n        # Try to avoid leaking SC2 processes on shutdown\n        atexit.register(lambda: self.close())\n\n        self.action_space = []\n        self.observation_space = []\n        self.share_observation_space = []\n        for i in range(self.n_agents):\n            self.action_space.append(Discrete(self.n_actions))\n            self.observation_space.append(self.get_obs_size())\n            self.share_observation_space.append(self.get_state_size())\n\n        if self.use_stacked_frames:\n            self.stacked_local_obs = np.zeros((self.n_agents, self.stacked_frames, int(self.get_obs_size()[0]/self.stacked_frames)), dtype=np.float32)\n            self.stacked_global_state = np.zeros((self.n_agents, self.stacked_frames, int(self.get_state_size()[0]/self.stacked_frames)), dtype=np.float32)\n\n\n    def _launch(self):\n        \"\"\"Launch the StarCraft II game.\"\"\"\n        self._run_config = run_configs.get(version=self.game_version)\n        _map = maps.get(self.map_name)\n        self._seed += 1\n\n        # Setting up the interface\n        interface_options = sc_pb.InterfaceOptions(raw=True, score=False)\n        self._sc2_proc = self._run_config.start(window_size=self.window_size, want_rgb=False)\n        self._controller = self._sc2_proc.controller\n\n        # Request to create the game\n        create = sc_pb.RequestCreateGame(\n            local_map=sc_pb.LocalMap(\n                map_path=_map.path,\n                map_data=self._run_config.map_data(_map.path)),\n            realtime=False,\n            random_seed=self._seed)\n        create.player_setup.add(type=sc_pb.Participant)\n        create.player_setup.add(type=sc_pb.Computer, race=races[self._bot_race],\n                                
difficulty=difficulties[self.difficulty])\n        self._controller.create_game(create)\n\n        join = sc_pb.RequestJoinGame(race=races[self._agent_race],\n                                     options=interface_options)\n        self._controller.join_game(join)\n\n        game_info = self._controller.game_info()\n        map_info = game_info.start_raw\n        map_play_area_min = map_info.playable_area.p0\n        map_play_area_max = map_info.playable_area.p1\n        self.max_distance_x = map_play_area_max.x - map_play_area_min.x\n        self.max_distance_y = map_play_area_max.y - map_play_area_min.y\n        self.map_x = map_info.map_size.x\n        self.map_y = map_info.map_size.y\n\n        # use the builtin bool dtype below; the np.bool alias is deprecated\n        # and removed in NumPy >= 1.24\n        if map_info.pathing_grid.bits_per_pixel == 1:\n            vals = np.array(list(map_info.pathing_grid.data)).reshape(\n                self.map_x, int(self.map_y / 8))\n            self.pathing_grid = np.transpose(np.array([\n                [(b >> i) & 1 for b in row for i in range(7, -1, -1)]\n                for row in vals], dtype=bool))\n        else:\n            self.pathing_grid = np.invert(np.flip(np.transpose(np.array(\n                list(map_info.pathing_grid.data), dtype=bool).reshape(\n                    self.map_x, self.map_y)), axis=1))\n\n        self.terrain_height = np.flip(\n            np.transpose(np.array(list(map_info.terrain_height.data))\n                         .reshape(self.map_x, self.map_y)), 1) / 255\n\n    def reset(self):\n        \"\"\"Reset the environment. Required after each full episode.\n        Returns initial observations and states.\n        \"\"\"\n        self._episode_steps = 0\n        if self._episode_count == 0:\n            # Launch StarCraft II\n            self._launch()\n        else:\n            self._restart()\n\n        # Information kept for counting the reward\n        self.death_tracker_ally = np.zeros(self.n_agents, dtype=np.float32)\n        self.death_tracker_enemy = np.zeros(self.n_enemies, dtype=np.float32)\n        self.previous_ally_units = None\n        self.previous_enemy_units = None\n        self.win_counted = False\n        self.defeat_counted = False\n\n        self.last_action = np.zeros((self.n_agents, self.n_actions), dtype=np.float32)\n\n        if self.heuristic_ai:\n            self.heuristic_targets = [None] * self.n_agents\n\n        try:\n            self._obs = self._controller.observe()\n            self.init_units()\n        except (protocol.ProtocolError, protocol.ConnectionError):\n            self.full_restart()\n\n        available_actions = []\n        for i in range(self.n_agents):\n            available_actions.append(self.get_avail_agent_actions(i))\n\n        if self.debug:\n            logging.debug(\"Started Episode {}\"\n                          .format(self._episode_count).center(60, \"*\"))\n\n        if self.use_state_agent:\n            global_state = [self.get_state_agent(agent_id) for agent_id in range(self.n_agents)]\n        else:\n            global_state = [self.get_state(agent_id) for agent_id in range(self.n_agents)]\n\n        local_obs = self.get_obs()\n\n        if self.use_stacked_frames:\n            self.stacked_local_obs = np.roll(self.stacked_local_obs, 1, axis=1)\n            self.stacked_global_state = np.roll(self.stacked_global_state, 1, axis=1)\n\n            self.stacked_local_obs[:, -1, :] = np.array(local_obs).copy()\n            self.stacked_global_state[:, -1, :] = np.array(global_state).copy()\n\n            local_obs = 
self.stacked_local_obs.reshape(self.n_agents, -1)\n            global_state = self.stacked_global_state.reshape(self.n_agents, -1)\n\n        return local_obs, global_state, available_actions\n\n    def _restart(self):\n        \"\"\"Restart the environment by killing all units on the map.\n        There is a trigger in the SC2Map file, which restarts the\n        episode when there are no units left.\n        \"\"\"\n        try:\n            self._kill_all_units()\n            self._controller.step(2)\n        except (protocol.ProtocolError, protocol.ConnectionError):\n            self.full_restart()\n\n    def full_restart(self):\n        \"\"\"Full restart. Closes the SC2 process and launches a new one. \"\"\"\n        self._sc2_proc.close()\n        self._launch()\n        self.force_restarts += 1\n\n    def step(self, actions):\n        \"\"\"A single environment step. Returns reward, terminated, info.\"\"\"\n        terminated = False\n        bad_transition = False\n        infos = [{} for i in range(self.n_agents)]\n        dones = np.zeros((self.n_agents), dtype=bool)\n\n        actions_int = [int(a) for a in actions]\n\n        self.last_action = np.eye(self.n_actions)[np.array(actions_int)]\n\n        # Collect individual actions\n        sc_actions = []\n        if self.debug:\n            logging.debug(\"Actions\".center(60, \"-\"))\n\n        for a_id, action in enumerate(actions_int):\n            if not self.heuristic_ai:\n                sc_action = self.get_agent_action(a_id, action)\n            else:\n                sc_action, action_num = self.get_agent_action_heuristic(\n                    a_id, action)\n                actions[a_id] = action_num\n            if sc_action:\n                sc_actions.append(sc_action)\n\n        # Send action request\n        req_actions = sc_pb.RequestAction(actions=sc_actions)\n        try:\n            self._controller.actions(req_actions)\n            # Make step in SC2, i.e. 
apply actions\n            self._controller.step(self._step_mul)\n            # Observe here so that we know if the episode is over.\n            self._obs = self._controller.observe()\n        except (protocol.ProtocolError, protocol.ConnectionError):\n            self.full_restart()\n            terminated = True\n            available_actions = []\n            for i in range(self.n_agents):\n                available_actions.append(self.get_avail_agent_actions(i))\n                infos[i] = {\n                    \"battles_won\": self.battles_won,\n                    \"battles_game\": self.battles_game,\n                    \"battles_draw\": self.timeouts,\n                    \"restarts\": self.force_restarts,\n                    \"bad_transition\": bad_transition,\n                    \"won\": self.win_counted\n                }\n                if terminated:\n                    dones[i] = True\n                else:\n                    if self.death_tracker_ally[i]:\n                        dones[i] = True\n                    else:\n                        dones[i] = False\n\n            if self.use_state_agent:\n                global_state = [self.get_state_agent(agent_id) for agent_id in range(self.n_agents)]\n            else:\n                global_state = [self.get_state(agent_id) for agent_id in range(self.n_agents)]\n\n            local_obs = self.get_obs()\n\n            if self.use_stacked_frames:\n                self.stacked_local_obs = np.roll(self.stacked_local_obs, 1, axis=1)\n                self.stacked_global_state = np.roll(self.stacked_global_state, 1, axis=1)\n\n                self.stacked_local_obs[:, -1, :] = np.array(local_obs).copy()\n                self.stacked_global_state[:, -1, :] = np.array(global_state).copy()\n\n                local_obs = self.stacked_local_obs.reshape(self.n_agents, -1)\n                global_state = self.stacked_global_state.reshape(self.n_agents, -1)\n\n            return local_obs, global_state, [[0]]*self.n_agents, dones, infos, available_actions\n\n        self._total_steps += 1\n        self._episode_steps += 1\n\n        # Update units\n        game_end_code = self.update_units()\n\n        reward = self.reward_battle()\n\n        available_actions = []\n        for i in range(self.n_agents):\n            available_actions.append(self.get_avail_agent_actions(i))\n\n        if game_end_code is not None:\n            # Battle is over\n            terminated = True\n            self.battles_game += 1\n            if game_end_code == 1 and not self.win_counted:\n                self.battles_won += 1\n                self.win_counted = True\n                if not self.reward_sparse:\n                    reward += self.reward_win\n                else:\n                    reward = 1\n            elif game_end_code == -1 and not self.defeat_counted:\n                self.defeat_counted = True\n                if not self.reward_sparse:\n                    reward += self.reward_defeat\n                else:\n                    reward = -1\n\n        elif self._episode_steps >= self.episode_limit:\n            # Episode limit reached\n            terminated = True\n            # bugfix: use the local `bad_transition` flag that is reported in\n            # the per-agent infos below; the original code assigned\n            # `self.bad_transition` and then wrote to an undefined `info`\n            # dict, raising a NameError whenever the limit was reached\n            bad_transition = True\n            self.battles_game += 1\n            self.timeouts += 1\n\n        for i in range(self.n_agents):\n            infos[i] = {\n                \"episode_limit\": bad_transition and self.continuing_episode,\n                \"battles_won\": self.battles_won,\n                \"battles_game\": 
self.battles_game,\n                \"battles_draw\": self.timeouts,\n                \"restarts\": self.force_restarts,\n                \"bad_transition\": bad_transition,\n                \"won\": self.win_counted\n            }\n\n            if terminated:\n                dones[i] = True\n            else:\n                if self.death_tracker_ally[i]:\n                    dones[i] = True\n                else:\n                    dones[i] = False\n\n        if self.debug:\n            logging.debug(\"Reward = {}\".format(reward).center(60, '-'))\n\n        if terminated:\n            self._episode_count += 1\n\n        if self.reward_scale:\n            reward /= self.max_reward / self.reward_scale_rate\n\n        rewards = [[reward]]*self.n_agents\n\n        if self.use_state_agent:\n            global_state = [self.get_state_agent(agent_id) for agent_id in range(self.n_agents)]\n        else:\n            global_state = [self.get_state(agent_id) for agent_id in range(self.n_agents)]\n\n        local_obs = self.get_obs()\n\n        if self.use_stacked_frames:\n            self.stacked_local_obs = np.roll(self.stacked_local_obs, 1, axis=1)\n            self.stacked_global_state = np.roll(self.stacked_global_state, 1, axis=1)\n\n            self.stacked_local_obs[:, -1, :] = np.array(local_obs).copy()\n            self.stacked_global_state[:, -1, :] = np.array(global_state).copy()\n\n            local_obs = self.stacked_local_obs.reshape(self.n_agents, -1)\n            global_state = self.stacked_global_state.reshape(self.n_agents, -1)\n\n        return local_obs, global_state, rewards, dones, infos, available_actions\n\n    def get_agent_action(self, a_id, action):\n        \"\"\"Construct the action for agent a_id.\"\"\"\n        avail_actions = self.get_avail_agent_actions(a_id)\n        assert avail_actions[action] == 1, \\\n            \"Agent {} cannot perform action {}\".format(a_id, action)\n\n        unit = self.get_unit_by_id(a_id)\n        tag = unit.tag\n        x = unit.pos.x\n        y = unit.pos.y\n\n        if action == 0:\n            # no-op (valid only when dead)\n            assert unit.health == 0, \"No-op only available for dead agents.\"\n            if self.debug:\n                logging.debug(\"Agent {}: Dead\".format(a_id))\n            return None\n        elif action == 1:\n            # stop\n            cmd = r_pb.ActionRawUnitCommand(\n                ability_id=actions[\"stop\"],\n                unit_tags=[tag],\n                queue_command=False)\n            if self.debug:\n                logging.debug(\"Agent {}: Stop\".format(a_id))\n\n        elif action == 2:\n            # move north\n            cmd = r_pb.ActionRawUnitCommand(\n                ability_id=actions[\"move\"],\n                target_world_space_pos=sc_common.Point2D(\n                    x=x, y=y + self._move_amount),\n                unit_tags=[tag],\n                queue_command=False)\n            if self.debug:\n                logging.debug(\"Agent {}: Move North\".format(a_id))\n\n        elif action == 3:\n            # move south\n            cmd = r_pb.ActionRawUnitCommand(\n                ability_id=actions[\"move\"],\n                target_world_space_pos=sc_common.Point2D(\n                    x=x, y=y - self._move_amount),\n                unit_tags=[tag],\n                queue_command=False)\n            if self.debug:\n                logging.debug(\"Agent {}: Move South\".format(a_id))\n\n        elif action == 4:\n            # move east\n            
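# note: in SC2 world coordinates y grows northwards and x eastwards,\n            # so moving east offsets x by +self._move_amount (cf. north above)\n            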
cmd = r_pb.ActionRawUnitCommand(\n                ability_id=actions[\"move\"],\n                target_world_space_pos=sc_common.Point2D(\n                    x=x + self._move_amount, y=y),\n                unit_tags=[tag],\n                queue_command=False)\n            if self.debug:\n                logging.debug(\"Agent {}: Move East\".format(a_id))\n\n        elif action == 5:\n            # move west\n            cmd = r_pb.ActionRawUnitCommand(\n                ability_id=actions[\"move\"],\n                target_world_space_pos=sc_common.Point2D(\n                    x=x - self._move_amount, y=y),\n                unit_tags=[tag],\n                queue_command=False)\n            if self.debug:\n                logging.debug(\"Agent {}: Move West\".format(a_id))\n        else:\n            # attack/heal units that are in range\n            target_id = action - self.n_actions_no_attack\n            if self.map_type == \"MMM\" and unit.unit_type == self.medivac_id:\n                target_unit = self.agents[target_id]\n                action_name = \"heal\"\n            else:\n                target_unit = self.enemies[target_id]\n                action_name = \"attack\"\n\n            action_id = actions[action_name]\n            target_tag = target_unit.tag\n\n            cmd = r_pb.ActionRawUnitCommand(\n                ability_id=action_id,\n                target_unit_tag=target_tag,\n                unit_tags=[tag],\n                queue_command=False)\n\n            if self.debug:\n                logging.debug(\"Agent {} {}s unit # {}\".format(\n                    a_id, action_name, target_id))\n\n        sc_action = sc_pb.Action(action_raw=r_pb.ActionRaw(unit_command=cmd))\n        return sc_action\n\n    def get_agent_action_heuristic(self, a_id, action):\n        unit = self.get_unit_by_id(a_id)\n        tag = unit.tag\n\n        target = self.heuristic_targets[a_id]\n        if unit.unit_type == self.medivac_id:\n            if (target is None or self.agents[target].health == 0 or\n                    self.agents[target].health == self.agents[target].health_max):\n                min_dist = math.hypot(self.max_distance_x, self.max_distance_y)\n                min_id = -1\n                for al_id, al_unit in self.agents.items():\n                    if al_unit.unit_type == self.medivac_id:\n                        continue\n                    if (al_unit.health != 0 and\n                            al_unit.health != al_unit.health_max):\n                        dist = self.distance(unit.pos.x, unit.pos.y,\n                                             al_unit.pos.x, al_unit.pos.y)\n                        if dist < min_dist:\n                            min_dist = dist\n                            min_id = al_id\n                self.heuristic_targets[a_id] = min_id\n                if min_id == -1:\n                    self.heuristic_targets[a_id] = None\n                    return None, 0\n            action_id = actions['heal']\n            target_tag = self.agents[self.heuristic_targets[a_id]].tag\n        else:\n            if target is None or self.enemies[target].health == 0:\n                min_dist = math.hypot(self.max_distance_x, self.max_distance_y)\n                min_id = -1\n                for e_id, e_unit in self.enemies.items():\n                    if (unit.unit_type == self.marauder_id and\n                            e_unit.unit_type == self.medivac_id):\n                        continue\n                    if e_unit.health > 0:\n          
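              # among living enemies, pick the nearest one as the new\n              # heuristic target\n          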
              dist = self.distance(unit.pos.x, unit.pos.y,\n                                             e_unit.pos.x, e_unit.pos.y)\n                        if dist < min_dist:\n                            min_dist = dist\n                            min_id = e_id\n                self.heuristic_targets[a_id] = min_id\n                if min_id == -1:\n                    self.heuristic_targets[a_id] = None\n                    return None, 0\n            action_id = actions['attack']\n            target_tag = self.enemies[self.heuristic_targets[a_id]].tag\n\n        action_num = self.heuristic_targets[a_id] + self.n_actions_no_attack\n\n        # Check if the action is available\n        if (self.heuristic_rest and\n                self.get_avail_agent_actions(a_id)[action_num] == 0):\n\n            # Move towards the target rather than attacking/healing\n            if unit.unit_type == self.medivac_id:\n                target_unit = self.agents[self.heuristic_targets[a_id]]\n            else:\n                target_unit = self.enemies[self.heuristic_targets[a_id]]\n\n            delta_x = target_unit.pos.x - unit.pos.x\n            delta_y = target_unit.pos.y - unit.pos.y\n\n            if abs(delta_x) > abs(delta_y):  # east or west\n                if delta_x > 0:  # east\n                    target_pos = sc_common.Point2D(\n                        x=unit.pos.x + self._move_amount, y=unit.pos.y)\n                    action_num = 4\n                else:  # west\n                    target_pos = sc_common.Point2D(\n                        x=unit.pos.x - self._move_amount, y=unit.pos.y)\n                    action_num = 5\n            else:  # north or south\n                if delta_y > 0:  # north\n                    target_pos = sc_common.Point2D(\n                        x=unit.pos.x, y=unit.pos.y + self._move_amount)\n                    action_num = 2\n                else:  # south\n                    target_pos = sc_common.Point2D(\n                        x=unit.pos.x, y=unit.pos.y - self._move_amount)\n                    action_num = 3\n\n            cmd = r_pb.ActionRawUnitCommand(\n                ability_id=actions['move'],\n                target_world_space_pos=target_pos,\n                unit_tags=[tag],\n                queue_command=False)\n        else:\n            # Attack/heal the target\n            cmd = r_pb.ActionRawUnitCommand(\n                ability_id=action_id,\n                target_unit_tag=target_tag,\n                unit_tags=[tag],\n                queue_command=False)\n\n        sc_action = sc_pb.Action(action_raw=r_pb.ActionRaw(unit_command=cmd))\n        return sc_action, action_num\n\n    def reward_battle(self):\n        \"\"\"Reward function when self.reward_sparse == False.\n        Returns the cumulative hit/shield point damage dealt to the enemy,\n        plus reward_death_value per enemy unit killed and, in case\n        self.reward_only_positive == False, minus (damage dealt to ally units\n        + reward_death_value per ally unit killed) * self.reward_negative_scale.\n        \"\"\"\n        if self.reward_sparse:\n            return 0\n\n        reward = 0\n        delta_deaths = 0\n        delta_ally = 0\n        delta_enemy = 0\n\n        neg_scale = self.reward_negative_scale\n\n        # update deaths\n        for al_id, al_unit in self.agents.items():\n            if not self.death_tracker_ally[al_id]:\n                # did not die so far\n                prev_health = (\n                    self.previous_ally_units[al_id].health\n     
               + self.previous_ally_units[al_id].shield\n                )\n                if al_unit.health == 0:\n                    # just died\n                    self.death_tracker_ally[al_id] = 1\n                    if not self.reward_only_positive:\n                        delta_deaths -= self.reward_death_value * neg_scale\n                    delta_ally += prev_health * neg_scale\n                else:\n                    # still alive\n                    delta_ally += neg_scale * (\n                        prev_health - al_unit.health - al_unit.shield\n                    )\n\n        for e_id, e_unit in self.enemies.items():\n            if not self.death_tracker_enemy[e_id]:\n                prev_health = (\n                    self.previous_enemy_units[e_id].health\n                    + self.previous_enemy_units[e_id].shield\n                )\n                if e_unit.health == 0:\n                    self.death_tracker_enemy[e_id] = 1\n                    delta_deaths += self.reward_death_value\n                    delta_enemy += prev_health\n                else:\n                    delta_enemy += prev_health - e_unit.health - e_unit.shield\n\n        if self.reward_only_positive:\n            reward = abs(delta_enemy + delta_deaths)  # abs() guards against negatives from shield regeneration\n        else:\n            reward = delta_enemy + delta_deaths - delta_ally\n\n        return reward\n\n    def get_total_actions(self):\n        \"\"\"Returns the total number of actions an agent could ever take.\"\"\"\n        return self.n_actions\n\n    @staticmethod\n    def distance(x1, y1, x2, y2):\n        \"\"\"Distance between two points.\"\"\"\n        return math.hypot(x2 - x1, y2 - y1)\n\n    def unit_shoot_range(self, agent_id):\n        \"\"\"Returns the shooting range for an agent.\"\"\"\n        return 6\n\n    def unit_sight_range(self, agent_id):\n        \"\"\"Returns the sight range for an agent.\"\"\"\n        return 9\n\n    def unit_max_cooldown(self, unit):\n        \"\"\"Returns the maximal cooldown for a unit.\"\"\"\n        switcher = {\n            self.marine_id: 15,\n            self.marauder_id: 25,\n            self.medivac_id: 200,  # max energy\n            self.stalker_id: 35,\n            self.zealot_id: 22,\n            self.colossus_id: 24,\n            self.hydralisk_id: 10,\n            self.zergling_id: 11,\n            self.baneling_id: 1\n        }\n        return switcher.get(unit.unit_type, 15)\n\n    def save_replay(self):\n        \"\"\"Save a replay.\"\"\"\n        prefix = self.replay_prefix or self.map_name\n        replay_dir = self.replay_dir or \"\"\n        replay_path = self._run_config.save_replay(\n            self._controller.save_replay(), replay_dir=replay_dir, prefix=prefix)\n        logging.info(\"Replay saved at: %s\" % replay_path)\n\n    def unit_max_shield(self, unit):\n        \"\"\"Returns maximal shield for a given unit.\"\"\"\n        if unit.unit_type == 74 or unit.unit_type == self.stalker_id:\n            return 80  # Protoss Stalker\n        if unit.unit_type == 73 or unit.unit_type == self.zealot_id:\n            return 50  # Protoss Zealot\n        if unit.unit_type == 4 or unit.unit_type == self.colossus_id:\n            return 150  # Protoss Colossus\n\n    def can_move(self, unit, direction):\n        \"\"\"Whether a unit can move in a given direction.\"\"\"\n        m = self._move_amount / 2\n\n        if direction == Direction.NORTH:\n            x, y = int(unit.pos.x), int(unit.pos.y + m)\n        elif direction == 
Direction.SOUTH:\n            x, y = int(unit.pos.x), int(unit.pos.y - m)\n        elif direction == Direction.EAST:\n            x, y = int(unit.pos.x + m), int(unit.pos.y)\n        else:\n            x, y = int(unit.pos.x - m), int(unit.pos.y)\n\n        if self.check_bounds(x, y) and self.pathing_grid[x, y]:\n            return True\n\n        return False\n\n    def get_surrounding_points(self, unit, include_self=False):\n        \"\"\"Returns the surrounding points of the unit in 8 directions.\"\"\"\n        x = int(unit.pos.x)\n        y = int(unit.pos.y)\n\n        ma = self._move_amount\n\n        points = [\n            (x, y + 2 * ma),\n            (x, y - 2 * ma),\n            (x + 2 * ma, y),\n            (x - 2 * ma, y),\n            (x + ma, y + ma),\n            (x - ma, y - ma),\n            (x + ma, y - ma),\n            (x - ma, y + ma),\n        ]\n\n        if include_self:\n            points.append((x, y))\n\n        return points\n\n    def check_bounds(self, x, y):\n        \"\"\"Whether a point is within the map bounds.\"\"\"\n        return (0 <= x < self.map_x and 0 <= y < self.map_y)\n\n    def get_surrounding_pathing(self, unit):\n        \"\"\"Returns pathing values of the grid surrounding the given unit.\"\"\"\n        points = self.get_surrounding_points(unit, include_self=False)\n        vals = [\n            self.pathing_grid[x, y] if self.check_bounds(x, y) else 1\n            for x, y in points\n        ]\n        return vals\n\n    def get_surrounding_height(self, unit):\n        \"\"\"Returns height values of the grid surrounding the given unit.\"\"\"\n        points = self.get_surrounding_points(unit, include_self=True)\n        vals = [\n            self.terrain_height[x, y] if self.check_bounds(x, y) else 1\n            for x, y in points\n        ]\n        return vals\n\n    def get_obs_agent(self, agent_id):\n        \"\"\"Returns observation for agent_id. The observation is composed of:\n\n           - agent movement features (where it can move to, height information and pathing grid)\n           - enemy features (available_to_attack, health, relative_x, relative_y, shield, unit_type)\n           - ally features (visible, distance, relative_x, relative_y, shield, unit_type)\n           - agent unit features (health, shield, unit_type)\n\n           All of this information is flattened and concatenated into a list,\n           in the aforementioned order. 
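For\n           illustration, with obs_all_health enabled each enemy row below is\n           [avail_to_attack, distance, rel_x, rel_y, health(, shield)] plus\n           any unit-type one-hot bits. 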
To know the sizes of each of the\n           features inside the final list of features, take a look at the\n           functions ``get_obs_move_feats_size()``,\n           ``get_obs_enemy_feats_size()``, ``get_obs_ally_feats_size()`` and\n           ``get_obs_own_feats_size()``.\n\n           The size of the observation vector may vary, depending on the\n           environment configuration and type of units present in the map.\n           For instance, non-Protoss units will not have shields, movement\n           features may or may not include terrain height and pathing grid,\n           unit_type is not included if there is only one type of unit in the\n           map etc.).\n\n           NOTE: Agents should have access only to their local observations\n           during decentralised execution.\n        \"\"\"\n        unit = self.get_unit_by_id(agent_id)\n\n        move_feats_dim = self.get_obs_move_feats_size()\n        enemy_feats_dim = self.get_obs_enemy_feats_size()\n        ally_feats_dim = self.get_obs_ally_feats_size()\n        own_feats_dim = self.get_obs_own_feats_size()\n\n        move_feats = np.zeros(move_feats_dim, dtype=np.float32)\n        enemy_feats = np.zeros(enemy_feats_dim, dtype=np.float32)\n        ally_feats = np.zeros(ally_feats_dim, dtype=np.float32)\n        own_feats = np.zeros(own_feats_dim, dtype=np.float32)\n        agent_id_feats = np.zeros(self.n_agents, dtype=np.float32)\n\n        if unit.health > 0:  # otherwise dead, return all zeros\n            x = unit.pos.x\n            y = unit.pos.y\n            sight_range = self.unit_sight_range(agent_id)\n\n            # Movement features\n            avail_actions = self.get_avail_agent_actions(agent_id)\n            for m in range(self.n_actions_move):\n                move_feats[m] = avail_actions[m + 2]\n\n            ind = self.n_actions_move\n\n            if self.obs_pathing_grid:\n                move_feats[ind: ind + self.n_obs_pathing] = self.get_surrounding_pathing(unit)\n                ind += self.n_obs_pathing\n\n            if self.obs_terrain_height:\n                move_feats[ind:] = self.get_surrounding_height(unit)\n\n            # Enemy features\n            for e_id, e_unit in self.enemies.items():\n                e_x = e_unit.pos.x\n                e_y = e_unit.pos.y\n                dist = self.distance(x, y, e_x, e_y)\n\n                if (dist < sight_range and e_unit.health > 0):  # visible and alive\n                    # Sight range > shoot range\n                    enemy_feats[e_id, 0] = avail_actions[self.n_actions_no_attack + e_id]  # available\n                    enemy_feats[e_id, 1] = dist / sight_range  # distance\n                    enemy_feats[e_id, 2] = (e_x - x) / sight_range  # relative X\n                    enemy_feats[e_id, 3] = (e_y - y) / sight_range  # relative Y\n\n                    ind = 4\n                    if self.obs_all_health:\n                        enemy_feats[e_id, ind] = (e_unit.health / e_unit.health_max)  # health\n                        ind += 1\n                        if self.shield_bits_enemy > 0:\n                            max_shield = self.unit_max_shield(e_unit)\n                            enemy_feats[e_id, ind] = (e_unit.shield / max_shield)  # shield\n                            ind += 1\n\n                    if self.unit_type_bits > 0:\n                        type_id = self.get_unit_type_id(e_unit, False)\n                        enemy_feats[e_id, ind + type_id] = 1  # unit type\n\n            # Ally features\n            
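# allies are enumerated in agent-id order, skipping this agent itself\n            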
al_ids = [al_id for al_id in range(self.n_agents) if al_id != agent_id]\n            for i, al_id in enumerate(al_ids):\n\n                al_unit = self.get_unit_by_id(al_id)\n                al_x = al_unit.pos.x\n                al_y = al_unit.pos.y\n                dist = self.distance(x, y, al_x, al_y)\n\n                if (dist < sight_range and al_unit.health > 0):  # visible and alive\n                    ally_feats[i, 0] = 1  # visible\n                    ally_feats[i, 1] = dist / sight_range  # distance\n                    ally_feats[i, 2] = (al_x - x) / sight_range  # relative X\n                    ally_feats[i, 3] = (al_y - y) / sight_range  # relative Y\n\n                    ind = 4\n                    if self.obs_all_health:\n                        ally_feats[i, ind] = (al_unit.health / al_unit.health_max)  # health\n                        ind += 1\n                        if self.shield_bits_ally > 0:\n                            max_shield = self.unit_max_shield(al_unit)\n                            ally_feats[i, ind] = (al_unit.shield / max_shield)  # shield\n                            ind += 1\n\n                    if self.unit_type_bits > 0:\n                        type_id = self.get_unit_type_id(al_unit, True)\n                        ally_feats[i, ind + type_id] = 1\n                        ind += self.unit_type_bits\n\n                    if self.obs_last_action:\n                        ally_feats[i, ind:] = self.last_action[al_id]\n\n            # Own features\n            ind = 0\n            own_feats[0] = 1  # visible\n            own_feats[1] = 0  # distance\n            own_feats[2] = 0  # X\n            own_feats[3] = 0  # Y\n            ind = 4\n            if self.obs_own_health:\n                own_feats[ind] = unit.health / unit.health_max\n                ind += 1\n                if self.shield_bits_ally > 0:\n                    max_shield = self.unit_max_shield(unit)\n                    own_feats[ind] = unit.shield / max_shield\n                    ind += 1\n\n            if self.unit_type_bits > 0:\n                type_id = self.get_unit_type_id(unit, True)\n                own_feats[ind + type_id] = 1\n                ind += self.unit_type_bits\n\n            if self.obs_last_action:\n                own_feats[ind:] = self.last_action[agent_id]\n\n        agent_obs = np.concatenate((ally_feats.flatten(),\n                                      enemy_feats.flatten(),\n                                      move_feats.flatten(),\n                                      own_feats.flatten()))\n\n        # Agent id features\n        if self.obs_agent_id:\n            agent_id_feats[agent_id] = 1.\n            agent_obs = np.concatenate((ally_feats.flatten(),\n                                          enemy_feats.flatten(),\n                                          move_feats.flatten(),\n                                          own_feats.flatten(),\n                                          agent_id_feats.flatten()))\n\n        if self.obs_timestep_number:\n            agent_obs = np.append(agent_obs, self._episode_steps / self.episode_limit)\n\n        if self.debug:\n            logging.debug(\"Obs Agent: {}\".format(agent_id).center(60, \"-\"))\n            logging.debug(\"Avail. 
actions {}\".format(\n                self.get_avail_agent_actions(agent_id)))\n            logging.debug(\"Move feats {}\".format(move_feats))\n            logging.debug(\"Enemy feats {}\".format(enemy_feats))\n            logging.debug(\"Ally feats {}\".format(ally_feats))\n            logging.debug(\"Own feats {}\".format(own_feats))\n\n        return agent_obs\n\n    def get_obs(self):\n        \"\"\"Returns all agent observations in a list.\n        NOTE: Agents should have access only to their local observations\n        during decentralised execution.\n        \"\"\"\n        agents_obs = [self.get_obs_agent(i) for i in range(self.n_agents)]\n        return agents_obs\n\n    def get_state(self, agent_id=-1):\n        \"\"\"Returns the global state.\n        NOTE: This function should not be used during decentralised execution.\n        \"\"\"\n        if self.obs_instead_of_state:\n            obs_concat = np.concatenate(self.get_obs(), axis=0).astype(np.float32)\n            return obs_concat\n\n        nf_al = 2 + self.shield_bits_ally + self.unit_type_bits\n        nf_en = 1 + self.shield_bits_enemy + self.unit_type_bits\n\n        if self.add_center_xy:\n            nf_al += 2\n            nf_en += 2\n\n        if self.add_distance_state:\n            nf_al += 1\n            nf_en += 1\n\n        if self.add_xy_state:\n            nf_al += 2\n            nf_en += 2\n\n        if self.add_visible_state:\n            nf_al += 1\n            nf_en += 1\n\n        if self.state_last_action:\n            nf_al += self.n_actions\n            nf_en += self.n_actions\n\n        if self.add_enemy_action_state:\n            nf_en += 1\n\n        nf_mv = self.get_state_move_feats_size()\n\n        ally_state = np.zeros((self.n_agents, nf_al), dtype=np.float32)\n        enemy_state = np.zeros((self.n_enemies, nf_en), dtype=np.float32)\n        move_state = np.zeros((1, nf_mv), dtype=np.float32)\n        agent_id_feats = np.zeros((self.n_agents, 1), dtype=np.float32)\n\n        center_x = self.map_x / 2\n        center_y = self.map_y / 2\n\n        unit = self.get_unit_by_id(agent_id)  # the unit this state is centred on\n        x = unit.pos.x\n        y = unit.pos.y\n        sight_range = self.unit_sight_range(agent_id)\n        avail_actions = self.get_avail_agent_actions(agent_id)\n\n        if (self.use_mustalive and unit.health > 0) or (not self.use_mustalive):  # otherwise leave the state all zeros\n            # Movement features\n            for m in range(self.n_actions_move):\n                move_state[0, m] = avail_actions[m + 2]\n\n            ind = self.n_actions_move\n\n            if self.state_pathing_grid:\n                move_state[0, ind: ind + self.n_obs_pathing] = self.get_surrounding_pathing(unit)\n                ind += self.n_obs_pathing\n\n            if self.state_terrain_height:\n                move_state[0, ind:] = self.get_surrounding_height(unit)\n\n            for al_id, al_unit in self.agents.items():\n                if al_unit.health > 0:\n                    al_x = al_unit.pos.x\n                    al_y = al_unit.pos.y\n                    max_cd = self.unit_max_cooldown(al_unit)\n                    dist = self.distance(x, y, al_x, al_y)\n\n                    ally_state[al_id, 0] = (al_unit.health / al_unit.health_max)  # health\n                    if (self.map_type == \"MMM\" and al_unit.unit_type == self.medivac_id):\n                        ally_state[al_id, 1] = al_unit.energy / max_cd  # energy\n                    else:\n            
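            # non-Medivac units report weapon cooldown here; Medivacs (above)\n                        # report energy, both normalised by unit_max_cooldown\n            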
            ally_state[al_id, 1] = (al_unit.weapon_cooldown / max_cd)  # cooldown\n                    \n                    ind = 2\n                    \n                    if self.add_center_xy:\n                        ally_state[al_id, ind] = (al_x - center_x) / self.max_distance_x  # center X\n                        ally_state[al_id, ind+1] = (al_y - center_y) / self.max_distance_y  # center Y\n                        ind += 2\n\n                    if self.shield_bits_ally > 0:\n                        max_shield = self.unit_max_shield(al_unit)\n                        ally_state[al_id, ind] = (al_unit.shield / max_shield)  # shield\n                        ind += 1\n\n                    if self.unit_type_bits > 0:\n                        type_id = self.get_unit_type_id(al_unit, True)\n                        ally_state[al_id, ind + type_id] = 1\n\n                    if unit.health > 0:\n                        ind += self.unit_type_bits\n                        if self.add_distance_state:\n                            ally_state[al_id, ind] = dist / sight_range  # distance\n                            ind += 1\n                        if self.add_xy_state:\n                            ally_state[al_id, ind] = (al_x - x) / sight_range  # relative X\n                            ally_state[al_id, ind + 1] = (al_y - y) / sight_range  # relative Y\n                            ind += 2\n                        if self.add_visible_state:\n                            if dist < sight_range:\n                                ally_state[al_id, ind] = 1 # visible\n                            ind += 1\n                        if self.state_last_action:\n                            ally_state[al_id, ind:] = self.last_action[al_id]\n\n            for e_id, e_unit in self.enemies.items():\n                if e_unit.health > 0:\n                    e_x = e_unit.pos.x\n                    e_y = e_unit.pos.y\n                    dist = self.distance(x, y, e_x, e_y)\n\n                    enemy_state[e_id, 0] = (e_unit.health / e_unit.health_max)  # health               \n                    \n                    ind = 1\n                    if self.add_center_xy:\n                        enemy_state[e_id, ind] = (e_x - center_x) / self.max_distance_x  # center X\n                        enemy_state[e_id, ind+1] = (e_y - center_y) / self.max_distance_y  # center Y\n                        ind += 2\n                        \n                    if self.shield_bits_enemy > 0:\n                        max_shield = self.unit_max_shield(e_unit)\n                        enemy_state[e_id, ind] = (e_unit.shield / max_shield)  # shield\n                        ind += 1\n\n                    if self.unit_type_bits > 0:\n                        type_id = self.get_unit_type_id(e_unit, False)\n                        enemy_state[e_id, ind + type_id] = 1\n\n                    if unit.health > 0:\n                        ind += self.unit_type_bits\n                        if self.add_distance_state:\n                            enemy_state[e_id, ind] = dist / sight_range  # distance\n                            ind += 1\n                        if self.add_xy_state:\n                            enemy_state[e_id, ind] = (e_x - x) / sight_range  # relative X\n                            enemy_state[e_id, ind + 1] = (e_y - y) / sight_range  # relative Y\n                            ind += 2\n                        if self.add_visible_state:\n                            if dist < sight_range:\n                     
           enemy_state[e_id, ind] = 1  # visible\n                            ind += 1\n                        if self.add_enemy_action_state:\n                            enemy_state[e_id, ind] = avail_actions[self.n_actions_no_attack + e_id]  # available\n\n        state = np.append(ally_state.flatten(), enemy_state.flatten())\n\n        if self.add_move_state:\n            state = np.append(state, move_state.flatten())\n\n        if self.add_local_obs:\n            state = np.append(state, self.get_obs_agent(agent_id).flatten())\n\n        if self.state_timestep_number:\n            state = np.append(state, self._episode_steps / self.episode_limit)\n\n        if self.add_agent_id:\n            agent_id_feats[agent_id] = 1.0\n            state = np.append(state, agent_id_feats.flatten())\n\n        state = state.astype(dtype=np.float32)\n\n        if self.debug:\n            logging.debug(\"STATE\".center(60, \"-\"))\n            logging.debug(\"Ally state {}\".format(ally_state))\n            logging.debug(\"Enemy state {}\".format(enemy_state))\n            logging.debug(\"Move state {}\".format(move_state))\n            if self.state_last_action:\n                logging.debug(\"Last actions {}\".format(self.last_action))\n\n        return state\n\n    def get_state_agent(self, agent_id):\n        \"\"\"Returns an agent-specific global state for agent_id. The state is\n           composed of:\n\n           - agent movement features (where it can move to, height information and pathing grid)\n           - enemy features (available_to_attack, health, relative_x, relative_y, shield, unit_type)\n           - ally features (visible, distance, relative_x, relative_y, shield, unit_type)\n           - agent unit features (health, shield, unit_type)\n\n           All of this information is flattened and concatenated into a list,\n           in the aforementioned order. 
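Unlike\n           get_obs_agent, it also includes centre-relative positions and\n           covers every living unit regardless of sight range. 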
To know the sizes of each of the\n           features inside the final list of features, take a look at the\n           functions ``get_obs_move_feats_size()``,\n           ``get_obs_enemy_feats_size()``, ``get_obs_ally_feats_size()`` and\n           ``get_obs_own_feats_size()``.\n\n           The size of the observation vector may vary, depending on the\n           environment configuration and type of units present in the map.\n           For instance, non-Protoss units will not have shields, movement\n           features may or may not include terrain height and pathing grid,\n           unit_type is not included if there is only one type of unit in the\n           map etc.).\n\n           NOTE: Agents should have access only to their local observations\n           during decentralised execution.\n        \"\"\"\n        if self.obs_instead_of_state:\n            obs_concat = np.concatenate(self.get_obs(), axis=0).astype(np.float32)\n            return obs_concat\n            \n        unit = self.get_unit_by_id(agent_id)\n\n        move_feats_dim = self.get_obs_move_feats_size()\n        enemy_feats_dim = self.get_state_enemy_feats_size()\n        ally_feats_dim = self.get_state_ally_feats_size()\n        own_feats_dim = self.get_state_own_feats_size()\n\n        move_feats = np.zeros(move_feats_dim, dtype=np.float32)\n        enemy_feats = np.zeros(enemy_feats_dim, dtype=np.float32)\n        ally_feats = np.zeros(ally_feats_dim, dtype=np.float32)\n        own_feats = np.zeros(own_feats_dim, dtype=np.float32)\n        agent_id_feats = np.zeros(self.n_agents, dtype=np.float32)\n\n        center_x = self.map_x / 2\n        center_y = self.map_y / 2\n\n        if (self.use_mustalive and unit.health > 0) or (not self.use_mustalive):  # otherwise dead, return all zeros\n            x = unit.pos.x\n            y = unit.pos.y\n            sight_range = self.unit_sight_range(agent_id)\n\n            # Movement features\n            avail_actions = self.get_avail_agent_actions(agent_id)\n            for m in range(self.n_actions_move):\n                move_feats[m] = avail_actions[m + 2]\n\n            ind = self.n_actions_move\n\n            if self.state_pathing_grid:\n                move_feats[ind: ind + self.n_obs_pathing] = self.get_surrounding_pathing(unit)\n                ind += self.n_obs_pathing\n\n            if self.state_terrain_height:\n                move_feats[ind:] = self.get_surrounding_height(unit)\n\n            # Enemy features\n            for e_id, e_unit in self.enemies.items():\n                e_x = e_unit.pos.x\n                e_y = e_unit.pos.y\n                dist = self.distance(x, y, e_x, e_y)\n\n                if e_unit.health > 0:  # visible and alive\n                    # Sight range > shoot range\n                    if unit.health > 0:\n                        enemy_feats[e_id, 0] = avail_actions[self.n_actions_no_attack + e_id]  # available\n                        enemy_feats[e_id, 1] = dist / sight_range  # distance\n                        enemy_feats[e_id, 2] = (e_x - x) / sight_range  # relative X\n                        enemy_feats[e_id, 3] = (e_y - y) / sight_range  # relative Y\n                        if dist < sight_range:\n                            enemy_feats[e_id, 4] = 1  # visible\n\n                    ind = 5\n                    if self.obs_all_health:\n                        enemy_feats[e_id, ind] = (e_unit.health / e_unit.health_max)  # health\n                        ind += 1\n                        if self.shield_bits_enemy > 
0:\n                            max_shield = self.unit_max_shield(e_unit)\n                            enemy_feats[e_id, ind] = (e_unit.shield / max_shield)  # shield\n                            ind += 1\n\n                    if self.unit_type_bits > 0:\n                        type_id = self.get_unit_type_id(e_unit, False)\n                        enemy_feats[e_id, ind + type_id] = 1  # unit type\n                        ind += self.unit_type_bits\n\n                    if self.add_center_xy:\n                        enemy_feats[e_id, ind] = (e_x - center_x) / self.max_distance_x  # center X\n                        enemy_feats[e_id, ind+1] = (e_y - center_y) / self.max_distance_y  # center Y\n\n            # Ally features\n            al_ids = [al_id for al_id in range(self.n_agents) if al_id != agent_id]\n            for i, al_id in enumerate(al_ids):\n\n                al_unit = self.get_unit_by_id(al_id)\n                al_x = al_unit.pos.x\n                al_y = al_unit.pos.y\n                dist = self.distance(x, y, al_x, al_y)\n                max_cd = self.unit_max_cooldown(al_unit)\n\n                if al_unit.health > 0:  # visible and alive\n                    if unit.health > 0:\n                        if dist < sight_range:\n                            ally_feats[i, 0] = 1  # visible\n                        ally_feats[i, 1] = dist / sight_range  # distance\n                        ally_feats[i, 2] = (al_x - x) / sight_range  # relative X\n                        ally_feats[i, 3] = (al_y - y) / sight_range  # relative Y\n\n                    if (self.map_type == \"MMM\" and al_unit.unit_type == self.medivac_id):\n                        ally_feats[i, 4] = al_unit.energy / max_cd  # energy\n                    else:\n                        ally_feats[i, 4] = (al_unit.weapon_cooldown / max_cd)  # cooldown\n\n                    ind = 5\n                    if self.obs_all_health:\n                        ally_feats[i, ind] = (al_unit.health / al_unit.health_max)  # health\n                        ind += 1\n                        if self.shield_bits_ally > 0:\n                            max_shield = self.unit_max_shield(al_unit)\n                            ally_feats[i, ind] = (al_unit.shield / max_shield)  # shield\n                            ind += 1\n\n                    if self.add_center_xy:\n                        ally_feats[i, ind] = (al_x - center_x) / self.max_distance_x  # center X\n                        ally_feats[i, ind+1] = (al_y - center_y) / self.max_distance_y  # center Y\n                        ind += 2\n\n                    if self.unit_type_bits > 0:\n                        type_id = self.get_unit_type_id(al_unit, True)\n                        ally_feats[i, ind + type_id] = 1\n                        ind += self.unit_type_bits\n\n                    if self.state_last_action:\n                        ally_feats[i, ind:] = self.last_action[al_id]\n\n            # Own features\n            ind = 0\n            own_feats[0] = 1  # visible\n            own_feats[1] = 0  # distance\n            own_feats[2] = 0  # X\n            own_feats[3] = 0  # Y\n            ind = 4\n            if self.obs_own_health:\n                own_feats[ind] = unit.health / unit.health_max\n                ind += 1\n                if self.shield_bits_ally > 0:\n                    max_shield = self.unit_max_shield(unit)\n                    own_feats[ind] = unit.shield / max_shield\n                    ind += 1\n\n            if self.add_center_xy:\n            
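    # own position relative to the map centre, normalised by the\n                # playable-area extent\n            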
    own_feats[ind] = (x - center_x) / self.max_distance_x  # center X\n                own_feats[ind+1] = (y - center_y) / self.max_distance_y  # center Y\n                ind += 2\n\n            if self.unit_type_bits > 0:\n                type_id = self.get_unit_type_id(unit, True)\n                own_feats[ind + type_id] = 1\n                ind += self.unit_type_bits\n\n            if self.state_last_action:\n                own_feats[ind:] = self.last_action[agent_id]\n\n        state = np.concatenate((ally_feats.flatten(), \n                                enemy_feats.flatten(),\n                                move_feats.flatten(),\n                                own_feats.flatten()))\n\n        # Agent id features\n        if self.state_agent_id:\n            agent_id_feats[agent_id] = 1.\n            state = np.append(state, agent_id_feats.flatten())\n\n        if self.state_timestep_number:\n            state = np.append(state, self._episode_steps / self.episode_limit)\n\n        if self.debug:\n            logging.debug(\"Obs Agent: {}\".format(agent_id).center(60, \"-\"))\n            logging.debug(\"Avail. actions {}\".format(\n                self.get_avail_agent_actions(agent_id)))\n            logging.debug(\"Move feats {}\".format(move_feats))\n            logging.debug(\"Enemy feats {}\".format(enemy_feats))\n            logging.debug(\"Ally feats {}\".format(ally_feats))\n            logging.debug(\"Own feats {}\".format(own_feats))\n\n        return state\n\n    def get_obs_enemy_feats_size(self):\n        \"\"\" Returns the dimensions of the matrix containing enemy features.\n        Size is n_enemies x n_features.\n        \"\"\"\n        nf_en = 4 + self.unit_type_bits\n\n        if self.obs_all_health:\n            nf_en += 1 + self.shield_bits_enemy\n\n        return self.n_enemies, nf_en\n\n    def get_state_enemy_feats_size(self):\n        \"\"\" Returns the dimensions of the matrix containing enemy features.\n        Size is n_enemies x n_features.\n        \"\"\"\n        nf_en = 5 + self.unit_type_bits\n\n        if self.obs_all_health:\n            nf_en += 1 + self.shield_bits_enemy\n\n        if self.add_center_xy:\n            nf_en += 2\n\n        return self.n_enemies, nf_en\n\n    def get_obs_ally_feats_size(self):\n        \"\"\"Returns the dimensions of the matrix containing ally features.\n        Size is n_allies x n_features.\n        \"\"\"\n        nf_al = 4 + self.unit_type_bits\n\n        if self.obs_all_health:\n            nf_al += 1 + self.shield_bits_ally\n\n        if self.obs_last_action:\n            nf_al += self.n_actions\n\n        return self.n_agents - 1, nf_al\n\n    def get_state_ally_feats_size(self):\n        \"\"\"Returns the dimensions of the matrix containing ally features.\n        Size is n_allies x n_features.\n        \"\"\"\n        nf_al = 5 + self.unit_type_bits\n\n        if self.obs_all_health:\n            nf_al += 1 + self.shield_bits_ally\n\n        if self.obs_last_action:\n            nf_al += self.n_actions\n        \n        if self.add_center_xy:\n            nf_al += 2\n\n        return self.n_agents - 1, nf_al\n\n    def get_obs_own_feats_size(self):\n        \"\"\"Returns the size of the vector containing the agents' own features.\n        \"\"\"\n        own_feats = 4 + self.unit_type_bits\n        if self.obs_own_health:\n            own_feats += 1 + self.shield_bits_ally\n\n        if self.obs_last_action:\n            own_feats += self.n_actions\n\n        return own_feats\n\n    def 
get_state_own_feats_size(self):\n        \"\"\"Returns the size of the vector containing the agents' own features.\n        \"\"\"\n        own_feats = 4 + self.unit_type_bits\n        if self.obs_own_health:\n            own_feats += 1 + self.shield_bits_ally\n\n        if self.obs_last_action:\n            own_feats += self.n_actions\n\n        if self.add_center_xy:\n            own_feats += 2\n\n        return own_feats\n\n    def get_obs_move_feats_size(self):\n        \"\"\"Returns the size of the vector containing the agents's movement-related features.\"\"\"\n        move_feats = self.n_actions_move\n        if self.obs_pathing_grid:\n            move_feats += self.n_obs_pathing\n        if self.obs_terrain_height:\n            move_feats += self.n_obs_height\n\n        return move_feats\n\n    def get_state_move_feats_size(self):\n        \"\"\"Returns the size of the vector containing the agents's movement-related features.\"\"\"\n        move_feats = self.n_actions_move\n        if self.state_pathing_grid:\n            move_feats += self.n_obs_pathing\n        if self.state_terrain_height:\n            move_feats += self.n_obs_height\n\n        return move_feats\n\n    def get_obs_size(self):\n        \"\"\"Returns the size of the observation.\"\"\"\n        own_feats = self.get_obs_own_feats_size()\n        move_feats = self.get_obs_move_feats_size()\n\n        n_enemies, n_enemy_feats = self.get_obs_enemy_feats_size()\n        n_allies, n_ally_feats = self.get_obs_ally_feats_size()\n\n        enemy_feats = n_enemies * n_enemy_feats\n        ally_feats = n_allies * n_ally_feats\n\n        all_feats = move_feats + enemy_feats + ally_feats + own_feats\n\n        agent_id_feats = 0\n        timestep_feats = 0\n\n        if self.obs_agent_id:\n            agent_id_feats = self.n_agents\n            all_feats += agent_id_feats\n\n        if self.obs_timestep_number:\n            timestep_feats = 1\n            all_feats += timestep_feats\n\n        return [all_feats * self.stacked_frames if self.use_stacked_frames else all_feats, [n_allies, n_ally_feats], [n_enemies, n_enemy_feats], [1, move_feats], [1, own_feats+agent_id_feats+timestep_feats]]\n\n    def get_state_size(self):\n        \"\"\"Returns the size of the global state.\"\"\"\n        if self.obs_instead_of_state:\n            return [self.get_obs_size()[0] * self.n_agents, [self.n_agents, self.get_obs_size()[0]]]\n\n        if self.use_state_agent:\n            own_feats = self.get_state_own_feats_size()\n            move_feats = self.get_obs_move_feats_size()\n\n            n_enemies, n_enemy_feats = self.get_state_enemy_feats_size()\n            n_allies, n_ally_feats = self.get_state_ally_feats_size()\n\n            enemy_feats = n_enemies * n_enemy_feats\n            ally_feats = n_allies * n_ally_feats\n\n            all_feats = move_feats + enemy_feats + ally_feats + own_feats\n\n            agent_id_feats = 0\n            timestep_feats = 0\n\n            if self.state_agent_id:\n                agent_id_feats = self.n_agents\n                all_feats += agent_id_feats\n\n            if self.state_timestep_number:\n                timestep_feats = 1\n                all_feats += timestep_feats\n\n            return [all_feats * self.stacked_frames if self.use_stacked_frames else all_feats, [n_allies, n_ally_feats], [n_enemies, n_enemy_feats], [1, move_feats], [1, own_feats+agent_id_feats+timestep_feats]]\n\n        \n        nf_al = 2 + self.shield_bits_ally + self.unit_type_bits\n        nf_en = 1 + 
self.shield_bits_enemy + self.unit_type_bits\n        nf_mv = self.get_state_move_feats_size()\n\n        if self.add_center_xy:\n            nf_al += 2\n            nf_en += 2\n\n        if self.state_last_action:\n            nf_al += self.n_actions\n            nf_en += self.n_actions\n\n        if self.add_visible_state:\n            nf_al += 1\n            nf_en += 1\n\n        if self.add_distance_state:\n            nf_al += 1\n            nf_en += 1\n\n        if self.add_xy_state:\n            nf_al += 2\n            nf_en += 2\n\n        if self.add_enemy_action_state:\n            nf_en += 1\n\n        enemy_state = self.n_enemies * nf_en\n        ally_state = self.n_agents * nf_al\n\n        size = enemy_state + ally_state\n\n        move_state = 0\n        obs_agent_size = 0\n        timestep_state = 0\n        agent_id_feats = 0\n\n        if self.add_move_state:\n            move_state = nf_mv\n            size += move_state\n\n        if self.add_local_obs:\n            obs_agent_size = self.get_obs_size()[0]\n            size += obs_agent_size\n\n        if self.state_timestep_number:\n            timestep_state = 1\n            size += timestep_state\n\n        if self.add_agent_id:\n            agent_id_feats = self.n_agents\n            size += agent_id_feats\n\n        return [size * self.stacked_frames if self.use_stacked_frames else size, [self.n_agents, nf_al], [self.n_enemies, nf_en], [1, move_state + obs_agent_size + timestep_state + agent_id_feats]]\n\n    def get_visibility_matrix(self):\n        \"\"\"Returns a boolean numpy array of dimensions\n        (n_agents, n_agents + n_enemies) indicating which units\n        are visible to each agent.\n        \"\"\"\n        # builtin bool dtype: the np.bool alias is deprecated/removed in NumPy\n        arr = np.zeros((self.n_agents, self.n_agents + self.n_enemies), dtype=bool)\n\n        for agent_id in range(self.n_agents):\n            current_agent = self.get_unit_by_id(agent_id)\n            if current_agent.health > 0:  # if the agent is not dead\n                x = current_agent.pos.x\n                y = current_agent.pos.y\n                sight_range = self.unit_sight_range(agent_id)\n\n                # Enemies\n                for e_id, e_unit in self.enemies.items():\n                    e_x = e_unit.pos.x\n                    e_y = e_unit.pos.y\n                    dist = self.distance(x, y, e_x, e_y)\n\n                    if (dist < sight_range and e_unit.health > 0):\n                        # visible and alive\n                        arr[agent_id, self.n_agents + e_id] = 1\n\n                # The matrix for allies is filled symmetrically\n                al_ids = [\n                    al_id for al_id in range(self.n_agents)\n                    if al_id > agent_id\n                ]\n                for i, al_id in enumerate(al_ids):\n                    al_unit = self.get_unit_by_id(al_id)\n                    al_x = al_unit.pos.x\n                    al_y = al_unit.pos.y\n                    dist = self.distance(x, y, al_x, al_y)\n\n                    if (dist < sight_range and al_unit.health > 0):\n                        # visible and alive\n                        arr[agent_id, al_id] = arr[al_id, agent_id] = 1\n\n        return arr\n\n    def get_unit_type_id(self, unit, ally):\n        \"\"\"Returns the ID of unit type in the given scenario.\"\"\"\n        if ally:  # use new SC2 unit types\n            type_id = unit.unit_type - self._min_unit_type\n        else:  # use default SC2 unit types\n            if self.map_type == \"stalkers_and_zealots\":\n     
                # id(Stalker) = 74, id(Zealot) = 73\n                type_id = unit.unit_type - 73\n            elif self.map_type == \"colossi_stalkers_zealots\":\n                # id(Stalker) = 74, id(Zealot) = 73, id(Colossus) = 4\n                if unit.unit_type == 4:\n                    type_id = 0\n                elif unit.unit_type == 74:\n                    type_id = 1\n                else:\n                    type_id = 2\n            elif self.map_type == \"bane\":\n                if unit.unit_type == 9:\n                    type_id = 0\n                else:\n                    type_id = 1\n            elif self.map_type == \"MMM\":\n                if unit.unit_type == 51:\n                    type_id = 0\n                elif unit.unit_type == 48:\n                    type_id = 1\n                else:\n                    type_id = 2\n\n        return type_id\n\n    def get_avail_agent_actions(self, agent_id):\n        \"\"\"Returns the available actions for agent_id.\"\"\"\n        unit = self.get_unit_by_id(agent_id)\n        if unit.health > 0:\n            # cannot choose no-op when alive\n            avail_actions = [0] * self.n_actions\n\n            # stop should be allowed\n            avail_actions[1] = 1\n\n            # see if we can move\n            if self.can_move(unit, Direction.NORTH):\n                avail_actions[2] = 1\n            if self.can_move(unit, Direction.SOUTH):\n                avail_actions[3] = 1\n            if self.can_move(unit, Direction.EAST):\n                avail_actions[4] = 1\n            if self.can_move(unit, Direction.WEST):\n                avail_actions[5] = 1\n\n            # Can only attack units that are alive and within shooting range\n            shoot_range = self.unit_shoot_range(agent_id)\n\n            target_items = self.enemies.items()\n            if self.map_type == \"MMM\" and unit.unit_type == self.medivac_id:\n                # Medivacs cannot heal themselves or other flying units\n                target_items = [\n                    (t_id, t_unit)\n                    for (t_id, t_unit) in self.agents.items()\n                    if t_unit.unit_type != self.medivac_id\n                ]\n\n            for t_id, t_unit in target_items:\n                if t_unit.health > 0:\n                    dist = self.distance(\n                        unit.pos.x, unit.pos.y, t_unit.pos.x, t_unit.pos.y\n                    )\n                    if dist <= shoot_range:\n                        avail_actions[t_id + self.n_actions_no_attack] = 1\n\n            return avail_actions\n\n        else:\n            # only no-op allowed\n            return [1] + [0] * (self.n_actions - 1)\n\n    def get_avail_actions(self):\n        \"\"\"Returns the available actions of all agents in a list.\"\"\"\n        avail_actions = []\n        for agent_id in range(self.n_agents):\n            avail_agent = self.get_avail_agent_actions(agent_id)\n            avail_actions.append(avail_agent)\n        return avail_actions\n\n    def close(self):\n        \"\"\"Close StarCraft II.\"\"\"\n        if self._sc2_proc:\n            self._sc2_proc.close()\n\n    def seed(self, seed):\n        \"\"\"Sets the random seed used by the environment.\"\"\"\n        self._seed = seed\n\n    def render(self):\n        \"\"\"Not implemented.\"\"\"\n        pass\n\n    def _kill_all_units(self):\n        \"\"\"Kill all units on the map.\"\"\"\n        units_alive = [\n            unit.tag for unit in self.agents.values() if unit.health > 0\n        
] + [unit.tag for unit in self.enemies.values() if unit.health > 0]\n        debug_command = [\n            d_pb.DebugCommand(kill_unit=d_pb.DebugKillUnit(tag=units_alive))\n        ]\n        self._controller.debug(debug_command)\n\n    def init_units(self):\n        \"\"\"Initialise the units.\"\"\"\n        while True:\n            # Sometimes not all units have yet been created by SC2\n            self.agents = {}\n            self.enemies = {}\n\n            ally_units = [\n                unit\n                for unit in self._obs.observation.raw_data.units\n                if unit.owner == 1\n            ]\n            ally_units_sorted = sorted(\n                ally_units,\n                key=attrgetter(\"unit_type\", \"pos.x\", \"pos.y\"),\n                reverse=False,\n            )\n\n            for i in range(len(ally_units_sorted)):\n                self.agents[i] = ally_units_sorted[i]\n                if self.debug:\n                    logging.debug(\n                        \"Unit {} is {}, x = {}, y = {}\".format(\n                            len(self.agents),\n                            self.agents[i].unit_type,\n                            self.agents[i].pos.x,\n                            self.agents[i].pos.y,\n                        )\n                    )\n\n            for unit in self._obs.observation.raw_data.units:\n                if unit.owner == 2:\n                    self.enemies[len(self.enemies)] = unit\n                    if self._episode_count == 0:\n                        self.max_reward += unit.health_max + unit.shield_max\n\n            if self._episode_count == 0:\n                min_unit_type = min(\n                    unit.unit_type for unit in self.agents.values()\n                )\n                self._init_ally_unit_types(min_unit_type)\n\n            all_agents_created = (len(self.agents) == self.n_agents)\n            all_enemies_created = (len(self.enemies) == self.n_enemies)\n\n            if all_agents_created and all_enemies_created:  # all good\n                return\n\n            try:\n                self._controller.step(1)\n                self._obs = self._controller.observe()\n            except (protocol.ProtocolError, protocol.ConnectionError):\n                self.full_restart()\n                self.reset()\n\n    def update_units(self):\n        \"\"\"Update units after an environment step.\n        This function assumes that self._obs is up-to-date.\n        \"\"\"\n        n_ally_alive = 0\n        n_enemy_alive = 0\n\n        # Store previous state\n        self.previous_ally_units = deepcopy(self.agents)\n        self.previous_enemy_units = deepcopy(self.enemies)\n\n        for al_id, al_unit in self.agents.items():\n            updated = False\n            for unit in self._obs.observation.raw_data.units:\n                if al_unit.tag == unit.tag:\n                    self.agents[al_id] = unit\n                    updated = True\n                    n_ally_alive += 1\n                    break\n\n            if not updated:  # dead\n                al_unit.health = 0\n\n        for e_id, e_unit in self.enemies.items():\n            updated = False\n            for unit in self._obs.observation.raw_data.units:\n                if e_unit.tag == unit.tag:\n                    self.enemies[e_id] = unit\n                    updated = True\n                    n_enemy_alive += 1\n                    break\n\n            if not updated:  # dead\n                e_unit.health = 0\n\n        if (n_ally_alive == 
0 and n_enemy_alive > 0\n                or self.only_medivac_left(ally=True)):\n            return -1  # lost\n        if (n_ally_alive > 0 and n_enemy_alive == 0\n                or self.only_medivac_left(ally=False)):\n            return 1  # won\n        if n_ally_alive == 0 and n_enemy_alive == 0:\n            return 0\n\n        return None\n\n    def _init_ally_unit_types(self, min_unit_type):\n        \"\"\"Initialise ally unit types. Should be called once from the\n        init_units function.\n        \"\"\"\n        self._min_unit_type = min_unit_type\n        if self.map_type == \"marines\":\n            self.marine_id = min_unit_type\n        elif self.map_type == \"stalkers_and_zealots\":\n            self.stalker_id = min_unit_type\n            self.zealot_id = min_unit_type + 1\n        elif self.map_type == \"colossi_stalkers_zealots\":\n            self.colossus_id = min_unit_type\n            self.stalker_id = min_unit_type + 1\n            self.zealot_id = min_unit_type + 2\n        elif self.map_type == \"MMM\":\n            self.marauder_id = min_unit_type\n            self.marine_id = min_unit_type + 1\n            self.medivac_id = min_unit_type + 2\n        elif self.map_type == \"zealots\":\n            self.zealot_id = min_unit_type\n        elif self.map_type == \"hydralisks\":\n            self.hydralisk_id = min_unit_type\n        elif self.map_type == \"stalkers\":\n            self.stalker_id = min_unit_type\n        elif self.map_type == \"colossus\":\n            self.colossus_id = min_unit_type\n        elif self.map_type == \"bane\":\n            self.baneling_id = min_unit_type\n            self.zergling_id = min_unit_type + 1\n\n    def only_medivac_left(self, ally):\n        \"\"\"Check if only Medivac units are left.\"\"\"\n        if self.map_type != \"MMM\":\n            return False\n\n        if ally:\n            units_alive = [\n                a\n                for a in self.agents.values()\n                if (a.health > 0 and a.unit_type != self.medivac_id)\n            ]\n            if len(units_alive) == 0:\n                return True\n            return False\n        else:\n            units_alive = [\n                a\n                for a in self.enemies.values()\n                if (a.health > 0 and a.unit_type != self.medivac_id)\n            ]\n            if len(units_alive) == 1 and units_alive[0].unit_type == 54:\n                return True\n            return False\n\n    def get_unit_by_id(self, a_id):\n        \"\"\"Get unit by ID.\"\"\"\n        return self.agents[a_id]\n\n    def get_stats(self):\n        stats = {\n            \"battles_won\": self.battles_won,\n            \"battles_game\": self.battles_game,\n            \"battles_draw\": self.timeouts,\n            \"win_rate\": self.battles_won / self.battles_game,\n            \"timeouts\": self.timeouts,\n            \"restarts\": self.force_restarts,\n        }\n        return stats\n"
  },
  {
    "path": "envs/starcraft2/multiagentenv.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n\nclass MultiAgentEnv(object):\n\n    def step(self, actions):\n        \"\"\"Returns reward, terminated, info.\"\"\"\n        raise NotImplementedError\n\n    def get_obs(self):\n        \"\"\"Returns all agent observations in a list.\"\"\"\n        raise NotImplementedError\n\n    def get_obs_agent(self, agent_id):\n        \"\"\"Returns observation for agent_id.\"\"\"\n        raise NotImplementedError\n\n    def get_obs_size(self):\n        \"\"\"Returns the size of the observation.\"\"\"\n        raise NotImplementedError\n\n    def get_state(self):\n        \"\"\"Returns the global state.\"\"\"\n        raise NotImplementedError\n\n    def get_state_size(self):\n        \"\"\"Returns the size of the global state.\"\"\"\n        raise NotImplementedError\n\n    def get_avail_actions(self):\n        \"\"\"Returns the available actions of all agents in a list.\"\"\"\n        raise NotImplementedError\n\n    def get_avail_agent_actions(self, agent_id):\n        \"\"\"Returns the available actions for agent_id.\"\"\"\n        raise NotImplementedError\n\n    def get_total_actions(self):\n        \"\"\"Returns the total number of actions an agent could ever take.\"\"\"\n        raise NotImplementedError\n\n    def reset(self):\n        \"\"\"Returns initial observations and states.\"\"\"\n        raise NotImplementedError\n\n    def render(self):\n        raise NotImplementedError\n\n    def close(self):\n        raise NotImplementedError\n\n    def seed(self):\n        raise NotImplementedError\n\n    def save_replay(self):\n        \"\"\"Save a replay.\"\"\"\n        raise NotImplementedError\n\n    def get_env_info(self):\n        env_info = {\"state_shape\": self.get_state_size(),\n                    \"obs_shape\": self.get_obs_size(),\n                    \"obs_alone_shape\": self.get_obs_alone_size(),\n                    \"n_actions\": self.get_total_actions(),\n                    \"n_agents\": self.n_agents,\n                    \"episode_limit\": self.episode_limit}\n        return env_info\n"
  },
  {
    "path": "envs/starcraft2/smac_maps.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom pysc2.maps import lib\n\n\nclass SMACMap(lib.Map):\n    directory = \"SMAC_Maps\"\n    download = \"https://github.com/oxwhirl/smac#smac-maps\"\n    players = 2\n    step_mul = 8\n    game_steps_per_episode = 0\n\n\nmap_param_registry = {\n    \"3m\": {\n        \"n_agents\": 3,\n        \"n_enemies\": 3,\n        \"limit\": 60,\n        \"a_race\": \"T\",\n        \"b_race\": \"T\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"marines\",\n    },\n    \"8m\": {\n        \"n_agents\": 8,\n        \"n_enemies\": 8,\n        \"limit\": 120,\n        \"a_race\": \"T\",\n        \"b_race\": \"T\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"marines\",\n    },\n    \"25m\": {\n        \"n_agents\": 25,\n        \"n_enemies\": 25,\n        \"limit\": 150,\n        \"a_race\": \"T\",\n        \"b_race\": \"T\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"marines\",\n    },\n    \"5m_vs_6m\": {\n        \"n_agents\": 5,\n        \"n_enemies\": 6,\n        \"limit\": 70,\n        \"a_race\": \"T\",\n        \"b_race\": \"T\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"marines\",\n    },\n    \"8m_vs_9m\": {\n        \"n_agents\": 8,\n        \"n_enemies\": 9,\n        \"limit\": 120,\n        \"a_race\": \"T\",\n        \"b_race\": \"T\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"marines\",\n    },\n    \"10m_vs_11m\": {\n        \"n_agents\": 10,\n        \"n_enemies\": 11,\n        \"limit\": 150,\n        \"a_race\": \"T\",\n        \"b_race\": \"T\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"marines\",\n    },\n    \"27m_vs_30m\": {\n        \"n_agents\": 27,\n        \"n_enemies\": 30,\n        \"limit\": 180,\n        \"a_race\": \"T\",\n        \"b_race\": \"T\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"marines\",\n    },\n    \"MMM\": {\n        \"n_agents\": 10,\n        \"n_enemies\": 10,\n        \"limit\": 150,\n        \"a_race\": \"T\",\n        \"b_race\": \"T\",\n        \"unit_type_bits\": 3,\n        \"map_type\": \"MMM\",\n    },\n    \"MMM2\": {\n        \"n_agents\": 10,\n        \"n_enemies\": 12,\n        \"limit\": 180,\n        \"a_race\": \"T\",\n        \"b_race\": \"T\",\n        \"unit_type_bits\": 3,\n        \"map_type\": \"MMM\",\n    },\n    \"2s3z\": {\n        \"n_agents\": 5,\n        \"n_enemies\": 5,\n        \"limit\": 120,\n        \"a_race\": \"P\",\n        \"b_race\": \"P\",\n        \"unit_type_bits\": 2,\n        \"map_type\": \"stalkers_and_zealots\",\n    },\n    \"3s5z\": {\n        \"n_agents\": 8,\n        \"n_enemies\": 8,\n        \"limit\": 150,\n        \"a_race\": \"P\",\n        \"b_race\": \"P\",\n        \"unit_type_bits\": 2,\n        \"map_type\": \"stalkers_and_zealots\",\n    },\n    \"3s5z_vs_3s6z\": {\n        \"n_agents\": 8,\n        \"n_enemies\": 9,\n        \"limit\": 170,\n        \"a_race\": \"P\",\n        \"b_race\": \"P\",\n        \"unit_type_bits\": 2,\n        \"map_type\": \"stalkers_and_zealots\",\n    },\n    \"3s_vs_3z\": {\n        \"n_agents\": 3,\n        \"n_enemies\": 3,\n        \"limit\": 150,\n        \"a_race\": \"P\",\n        \"b_race\": \"P\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"stalkers\",\n    },\n    \"3s_vs_4z\": {\n        \"n_agents\": 3,\n        \"n_enemies\": 4,\n        \"limit\": 200,\n        \"a_race\": \"P\",\n        \"b_race\": \"P\",\n        
\"unit_type_bits\": 0,\n        \"map_type\": \"stalkers\",\n    },\n    \"3s_vs_5z\": {\n        \"n_agents\": 3,\n        \"n_enemies\": 5,\n        \"limit\": 250,\n        \"a_race\": \"P\",\n        \"b_race\": \"P\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"stalkers\",\n    },\n    \"1c3s5z\": {\n        \"n_agents\": 9,\n        \"n_enemies\": 9,\n        \"limit\": 180,\n        \"a_race\": \"P\",\n        \"b_race\": \"P\",\n        \"unit_type_bits\": 3,\n        \"map_type\": \"colossi_stalkers_zealots\",\n    },\n    \"2m_vs_1z\": {\n        \"n_agents\": 2,\n        \"n_enemies\": 1,\n        \"limit\": 150,\n        \"a_race\": \"T\",\n        \"b_race\": \"P\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"marines\",\n    },\n    \"corridor\": {\n        \"n_agents\": 6,\n        \"n_enemies\": 24,\n        \"limit\": 400,\n        \"a_race\": \"P\",\n        \"b_race\": \"Z\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"zealots\",\n    },\n    \"6h_vs_8z\": {\n        \"n_agents\": 6,\n        \"n_enemies\": 8,\n        \"limit\": 150,\n        \"a_race\": \"Z\",\n        \"b_race\": \"P\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"hydralisks\",\n    },\n    \"2s_vs_1sc\": {\n        \"n_agents\": 2,\n        \"n_enemies\": 1,\n        \"limit\": 300,\n        \"a_race\": \"P\",\n        \"b_race\": \"Z\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"stalkers\",\n    },\n    \"so_many_baneling\": {\n        \"n_agents\": 7,\n        \"n_enemies\": 32,\n        \"limit\": 100,\n        \"a_race\": \"P\",\n        \"b_race\": \"Z\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"zealots\",\n    },\n    \"bane_vs_bane\": {\n        \"n_agents\": 24,\n        \"n_enemies\": 24,\n        \"limit\": 200,\n        \"a_race\": \"Z\",\n        \"b_race\": \"Z\",\n        \"unit_type_bits\": 2,\n        \"map_type\": \"bane\",\n    },\n    \"2c_vs_64zg\": {\n        \"n_agents\": 2,\n        \"n_enemies\": 64,\n        \"limit\": 400,\n        \"a_race\": \"P\",\n        \"b_race\": \"Z\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"colossus\",\n    },\n\n    # This is adhoc environment\n    \"1c2z_vs_1c1s1z\": {\n        \"n_agents\": 3,\n        \"n_enemies\": 3,\n        \"limit\": 180,\n        \"a_race\": \"P\",\n        \"b_race\": \"P\",\n        \"unit_type_bits\": 3,\n        \"map_type\": \"colossi_stalkers_zealots\",\n    },\n    \"1c2s_vs_1c1s1z\": {\n        \"n_agents\": 3,\n        \"n_enemies\": 3,\n        \"limit\": 180,\n        \"a_race\": \"P\",\n        \"b_race\": \"P\",\n        \"unit_type_bits\": 3,\n        \"map_type\": \"colossi_stalkers_zealots\",\n    },\n    \"2c1z_vs_1c1s1z\": {\n        \"n_agents\": 3,\n        \"n_enemies\": 3,\n        \"limit\": 180,\n        \"a_race\": \"P\",\n        \"b_race\": \"P\",\n        \"unit_type_bits\": 3,\n        \"map_type\": \"colossi_stalkers_zealots\",\n    },\n    \"2c1s_vs_1c1s1z\": {\n        \"n_agents\": 3,\n        \"n_enemies\": 3,\n        \"limit\": 180,\n        \"a_race\": \"P\",\n        \"b_race\": \"P\",\n        \"unit_type_bits\": 3,\n        \"map_type\": \"colossi_stalkers_zealots\",\n    },\n    \"1c1s1z_vs_1c1s1z\": {\n        \"n_agents\": 3,\n        \"n_enemies\": 3,\n        \"limit\": 180,\n        \"a_race\": \"P\",\n        \"b_race\": \"P\",\n        \"unit_type_bits\": 3,\n        \"map_type\": \"colossi_stalkers_zealots\",\n    },\n\n    \"3s5z_vs_4s4z\": {\n        \"n_agents\": 8,\n        
\"n_enemies\": 8,\n        \"limit\": 150,\n        \"a_race\": \"P\",\n        \"b_race\": \"P\",\n        \"unit_type_bits\": 2,\n        \"map_type\": \"stalkers_and_zealots\",\n    },\n    \"4s4z_vs_4s4z\": {\n        \"n_agents\": 8,\n        \"n_enemies\": 8,\n        \"limit\": 150,\n        \"a_race\": \"P\",\n        \"b_race\": \"P\",\n        \"unit_type_bits\": 2,\n        \"map_type\": \"stalkers_and_zealots\",\n    },\n    \"5s3z_vs_4s4z\": {\n        \"n_agents\": 8,\n        \"n_enemies\": 8,\n        \"limit\": 150,\n        \"a_race\": \"P\",\n        \"b_race\": \"P\",\n        \"unit_type_bits\": 2,\n        \"map_type\": \"stalkers_and_zealots\",\n    },\n    \"6s2z_vs_4s4z\": {\n        \"n_agents\": 8,\n        \"n_enemies\": 8,\n        \"limit\": 150,\n        \"a_race\": \"P\",\n        \"b_race\": \"P\",\n        \"unit_type_bits\": 2,\n        \"map_type\": \"stalkers_and_zealots\",\n    },\n    \"2s6z_vs_4s4z\": {\n        \"n_agents\": 8,\n        \"n_enemies\": 8,\n        \"limit\": 150,\n        \"a_race\": \"P\",\n        \"b_race\": \"P\",\n        \"unit_type_bits\": 2,\n        \"map_type\": \"stalkers_and_zealots\",\n    },\n\n    \"6m_vs_6m_tz\": {\n        \"n_agents\": 6,\n        \"n_enemies\": 6,\n        \"limit\": 70,\n        \"a_race\": \"T\",\n        \"b_race\": \"T\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"marines\",\n    },\n    \"5m_vs_6m_tz\": {\n        \"n_agents\": 5,\n        \"n_enemies\": 6,\n        \"limit\": 70,\n        \"a_race\": \"T\",\n        \"b_race\": \"T\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"marines\",\n    },\n    \"3s6z_vs_3s6z\": {\n        \"n_agents\": 9,\n        \"n_enemies\": 9,\n        \"limit\": 170,\n        \"a_race\": \"P\",\n        \"b_race\": \"P\",\n        \"unit_type_bits\": 2,\n        \"map_type\": \"stalkers_and_zealots\",\n    },\n    \"7h_vs_8z\": {\n        \"n_agents\": 7,\n        \"n_enemies\": 8,\n        \"limit\": 150,\n        \"a_race\": \"Z\",\n        \"b_race\": \"P\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"hydralisks\",\n    },\n    \"2s2z_vs_zg\": {\n        \"n_agents\": 4,\n        \"n_enemies\": 20,\n        \"limit\": 200,\n        \"a_race\": \"P\",\n        \"b_race\": \"Z\",\n        \"unit_type_bits\": 2,\n        \"map_type\": \"stalkers_and_zealots_vs_zergling\",\n    },\n    \"1s3z_vs_zg\": {\n        \"n_agents\": 4,\n        \"n_enemies\": 20,\n        \"limit\": 200,\n        \"a_race\": \"P\",\n        \"b_race\": \"Z\",\n        \"unit_type_bits\": 2,\n        \"map_type\": \"stalkers_and_zealots_vs_zergling\",\n    },\n    \"3s1z_vs_zg\": {\n        \"n_agents\": 4,\n        \"n_enemies\": 20,\n        \"limit\": 200,\n        \"a_race\": \"P\",\n        \"b_race\": \"Z\",\n        \"unit_type_bits\": 2,\n        \"map_type\": \"stalkers_and_zealots_vs_zergling\",\n    },\n\n    \"2s2z_vs_zg_easy\": {\n        \"n_agents\": 4,\n        \"n_enemies\": 18,\n        \"limit\": 200,\n        \"a_race\": \"P\",\n        \"b_race\": \"Z\",\n        \"unit_type_bits\": 2,\n        \"map_type\": \"stalkers_and_zealots_vs_zergling\",\n    },\n    \"1s3z_vs_zg_easy\": {\n        \"n_agents\": 4,\n        \"n_enemies\": 18,\n        \"limit\": 200,\n        \"a_race\": \"P\",\n        \"b_race\": \"Z\",\n        \"unit_type_bits\": 2,\n        \"map_type\": \"stalkers_and_zealots_vs_zergling\",\n    },\n    \"3s1z_vs_zg_easy\": {\n        \"n_agents\": 4,\n        \"n_enemies\": 18,\n        \"limit\": 200,\n        
\"a_race\": \"P\",\n        \"b_race\": \"Z\",\n        \"unit_type_bits\": 2,\n        \"map_type\": \"stalkers_and_zealots_vs_zergling\",\n    },\n    \"28m_vs_30m\": {\n        \"n_agents\": 28,\n        \"n_enemies\": 30,\n        \"limit\": 180,\n        \"a_race\": \"T\",\n        \"b_race\": \"T\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"marines\",\n    },\n    \"29m_vs_30m\": {\n        \"n_agents\": 29,\n        \"n_enemies\": 30,\n        \"limit\": 180,\n        \"a_race\": \"T\",\n        \"b_race\": \"T\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"marines\",\n    },\n    \"30m_vs_30m\": {\n        \"n_agents\": 30,\n        \"n_enemies\": 30,\n        \"limit\": 180,\n        \"a_race\": \"T\",\n        \"b_race\": \"T\",\n        \"unit_type_bits\": 0,\n        \"map_type\": \"marines\",\n    },\n    \"MMM2_test\": {\n        \"n_agents\": 10,\n        \"n_enemies\": 12,\n        \"limit\": 180,\n        \"a_race\": \"T\",\n        \"b_race\": \"T\",\n        \"unit_type_bits\": 3,\n        \"map_type\": \"MMM\",\n    },\n}\n\n\ndef get_smac_map_registry():\n    return map_param_registry\n\n\nfor name in map_param_registry.keys():\n    globals()[name] = type(name, (SMACMap,), dict(filename=name))\n\n\ndef get_map_params(map_name):\n    map_param_registry = get_smac_map_registry()\n    return map_param_registry[map_name]\n"
  },
  {
    "path": "install_sc2.sh",
    "content": "#!/bin/bash\n# Install SC2 and add the custom maps\n\nif [ -z \"$EXP_DIR\" ]\nthen\n    EXP_DIR=~\nfi\n\necho \"EXP_DIR: $EXP_DIR\"\ncd $EXP_DIR/pymarl\n\nmkdir 3rdparty\ncd 3rdparty\n\nexport SC2PATH=`pwd`'/StarCraftII'\necho 'SC2PATH is set to '$SC2PATH\n\nif [ ! -d $SC2PATH ]; then\n        echo 'StarCraftII is not installed. Installing now ...';\n        wget http://blzdistsc2-a.akamaihd.net/Linux/SC2.4.10.zip\n        unzip -P iagreetotheeula SC2.4.10.zip\n        rm -rf SC2.4.10.zip\nelse\n        echo 'StarCraftII is already installed.'\nfi\n\necho 'Adding SMAC maps.'\nMAP_DIR=\"$SC2PATH/Maps/\"\necho 'MAP_DIR is set to '$MAP_DIR\n\nif [ ! -d $MAP_DIR ]; then\n        mkdir -p $MAP_DIR\nfi\n\ncd ..\nwget https://github.com/oxwhirl/smac/releases/download/v0.1-beta1/SMAC_Maps.zip\nunzip SMAC_Maps.zip\nmv SMAC_Maps $MAP_DIR\nrm -rf SMAC_Maps.zip\n\necho 'StarCraft II and SMAC are installed.'\n"
  },
  {
    "path": "requirements.txt",
    "content": "absl-py==0.9.0\naiohttp==3.6.2\naioredis==1.3.1\nastor==0.8.0\nastunparse==1.6.3\nasync-timeout==3.0.1\natari-py==0.2.6\natomicwrites==1.2.1\nattrs==18.2.0\nbeautifulsoup4==4.9.1\nblessings==1.7\ncachetools==4.1.1\ncertifi==2020.4.5.2\ncffi==1.14.1\nchardet==3.0.4\nclick==7.1.2\ncloudpickle==1.3.0\ncolorama==0.4.3\ncolorful==0.5.4\nconfigparser==5.0.1\ncontextvars==2.4\ncycler==0.10.0\nCython==0.29.21\ndeepdiff==4.3.2\ndill==0.3.2\ndocker-pycreds==0.4.0\ndocopt==0.6.2\nfasteners==0.15\nfilelock==3.0.12\nfuncsigs==1.0.2\nfuture==0.16.0\ngast==0.2.2\ngin==0.1.6\ngin-config==0.3.0\ngitdb==4.0.5\nGitPython==3.1.9\nglfw==1.12.0\ngoogle==3.0.0\ngoogle-api-core==1.22.1\ngoogle-auth==1.21.0\ngoogle-auth-oauthlib==0.4.1\ngoogle-pasta==0.2.0\ngoogleapis-common-protos==1.52.0\ngpustat==0.6.0\ngql==0.2.0\ngraphql-core==1.1\ngrpcio==1.31.0\ngym==0.17.2\nh5py==2.10.0\nhiredis==1.1.0\nidna==2.7\nidna-ssl==1.1.0\nimageio==2.4.1\nimmutables==0.14\nimportlib-metadata==1.7.0\njoblib==0.16.0\njsonnet==0.16.0\njsonpickle==0.9.6\njsonschema==3.2.0\nKeras-Applications==1.0.8\nKeras-Preprocessing==1.1.2\nkiwisolver==1.0.1\nlockfile==0.12.2\nMarkdown==3.1.1\nmatplotlib==3.0.0\nmkl-fft==1.2.0\nmkl-random==1.2.0\nmkl-service==2.3.0\nmock==2.0.0\nmonotonic==1.5\nmore-itertools==4.3.0\nmpi4py==3.0.3\nmpyq==0.2.5\nmsgpack==1.0.0\nmujoco-py==2.0.2.8\nmultidict==4.7.6\nmunch==2.3.2\nnumpy\nnvidia-ml-py3==7.352.0\noauthlib==3.1.0\nopencensus==0.7.10\nopencensus-context==0.1.1\nopencv-python==4.2.0.34\nopt-einsum==3.1.0\nordered-set==4.0.2\npackaging==20.4\npandas==1.1.1\npathlib2==2.3.2\npathtools==0.1.2\npbr==4.3.0\nPillow==5.3.0\npluggy==0.7.1\nportpicker==1.2.0\nprobscale==0.2.3\nprogressbar2==3.53.1\nprometheus-client==0.8.0\npromise==2.3\nprotobuf==3.12.4\npsutil==5.7.2\npy==1.6.0\npy-spy==0.3.3\npyasn1==0.4.8\npyasn1-modules==0.2.8\npycparser==2.20\npygame==1.9.4\npyglet==1.5.0\nPyOpenGL==3.1.5\nPyOpenGL-accelerate==3.1.5\npyparsing==2.2.2\npyrsistent==0.16.0\nPySC2==3.0.0\npytest==3.8.2\npython-dateutil==2.7.3\npython-utils==2.4.0\npytz==2020.1\nPyYAML==3.13\npyzmq==19.0.2\nredis==3.4.1\nrequests==2.24.0\nrequests-oauthlib==1.3.0\nrsa==4.6\ns2clientprotocol==4.10.1.75800.0\ns2protocol==4.11.4.78285.0\nsacred==0.7.2\nscipy==1.4.1\nseaborn==0.10.1\nsentry-sdk==0.18.0\nsetproctitle==1.1.10\nshortuuid==1.0.1\nsix==1.15.0\nsk-video==1.1.10\nsmmap==3.0.4\nsnakeviz==1.0.0\nsoupsieve==2.0.1\nsubprocess32==3.5.4\ntabulate==0.8.7\ntensorboard==2.0.2\ntensorboard-logger==0.1.0\ntensorboard-plugin-wit==1.7.0\ntensorboardX==2.0\ntensorflow==2.0.0\ntensorflow-estimator==2.0.0\ntermcolor==1.1.0\ntorch\ntorchvision\ntornado\ntqdm==4.48.2\ntyping-extensions==3.7.4.3\nurllib3==1.23\nwatchdog==0.10.3\nwebsocket-client==0.53.0\nWerkzeug==0.16.1\nwhichcraft==0.5.2\nwrapt==1.12.1\nxmltodict==0.12.0\nyarl==1.5.1\nzipp==3.1.0\nzmq==0.0.0\n"
  },
  {
    "path": "runners/__init__.py",
    "content": "from runners import separated\n\n__all__=[\n\n    \"separated\"\n]"
  },
  {
    "path": "runners/separated/__init__.py",
    "content": "from runners.separated import base_runner,smac_runner\n\n__all__=[\n    \"base_runner\",\n    \"smac_runner\"\n]"
  },
  {
    "path": "runners/separated/base_runner.py",
    "content": "    \nimport time\nimport os\nimport numpy as np\nfrom itertools import chain\nimport torch\nfrom tensorboardX import SummaryWriter\nfrom utils.separated_buffer import SeparatedReplayBuffer\nfrom utils.util import update_linear_schedule\n\ndef _t2n(x):\n    return x.detach().cpu().numpy()\n\nclass Runner(object):\n    def __init__(self, config):\n\n        self.all_args = config['all_args']\n        self.envs = config['envs']\n        self.eval_envs = config['eval_envs']\n        self.device = config['device']\n        self.num_agents = config['num_agents']\n\n        # parameters\n        self.env_name = self.all_args.env_name\n        self.algorithm_name = self.all_args.algorithm_name\n        self.experiment_name = self.all_args.experiment_name\n        self.use_centralized_V = self.all_args.use_centralized_V\n        self.use_obs_instead_of_state = self.all_args.use_obs_instead_of_state\n        self.num_env_steps = self.all_args.num_env_steps\n        self.episode_length = self.all_args.episode_length\n        self.n_rollout_threads = self.all_args.n_rollout_threads\n        self.n_eval_rollout_threads = self.all_args.n_eval_rollout_threads\n        self.use_linear_lr_decay = self.all_args.use_linear_lr_decay\n        self.hidden_size = self.all_args.hidden_size\n        self.use_render = self.all_args.use_render\n        self.recurrent_N = self.all_args.recurrent_N\n        self.use_single_network = self.all_args.use_single_network\n        # interval\n        self.save_interval = self.all_args.save_interval\n        self.use_eval = self.all_args.use_eval\n        self.eval_interval = self.all_args.eval_interval\n        self.log_interval = self.all_args.log_interval\n\n        # dir\n        self.model_dir = self.all_args.model_dir\n\n        if self.use_render:\n            import imageio\n            self.run_dir = config[\"run_dir\"]\n            self.gif_dir = str(self.run_dir / 'gifs')\n            if not os.path.exists(self.gif_dir):\n                os.makedirs(self.gif_dir)\n        else:\n            self.run_dir = config[\"run_dir\"]\n            self.log_dir = str(self.run_dir / 'logs')\n            if not os.path.exists(self.log_dir):\n                os.makedirs(self.log_dir)\n            self.writter = SummaryWriter(self.log_dir)\n            self.save_dir = str(self.run_dir / 'models')\n            if not os.path.exists(self.save_dir):\n                os.makedirs(self.save_dir)\n\n\n        if self.all_args.algorithm_name == \"happo\":\n            from algorithms.happo_trainer import HAPPO as TrainAlgo\n            from algorithms.happo_policy import HAPPO_Policy as Policy\n        elif self.all_args.algorithm_name == \"hatrpo\":\n            from algorithms.hatrpo_trainer import HATRPO as TrainAlgo\n            from algorithms.hatrpo_policy import HATRPO_Policy as Policy\n        else:\n            raise NotImplementedError\n\n        print(\"share_observation_space: \", self.envs.share_observation_space)\n        print(\"observation_space: \", self.envs.observation_space)\n        print(\"action_space: \", self.envs.action_space)\n\n        self.policy = []\n        for agent_id in range(self.num_agents):\n            share_observation_space = self.envs.share_observation_space[agent_id] if self.use_centralized_V else self.envs.observation_space[agent_id]\n            # policy network\n            po = Policy(self.all_args,\n                        self.envs.observation_space[agent_id],\n                        share_observation_space,\n             
           self.envs.action_space[agent_id],\n                        device = self.device)\n            self.policy.append(po)\n\n        if self.model_dir is not None:\n            self.restore()\n\n        self.trainer = []\n        self.buffer = []\n        for agent_id in range(self.num_agents):\n            # algorithm\n            tr = TrainAlgo(self.all_args, self.policy[agent_id], device = self.device)\n            # buffer\n            share_observation_space = self.envs.share_observation_space[agent_id] if self.use_centralized_V else self.envs.observation_space[agent_id]\n            bu = SeparatedReplayBuffer(self.all_args,\n                                       self.envs.observation_space[agent_id],\n                                       share_observation_space,\n                                       self.envs.action_space[agent_id])\n            self.buffer.append(bu)\n            self.trainer.append(tr)\n            \n    def run(self):\n        raise NotImplementedError\n\n    def warmup(self):\n        raise NotImplementedError\n\n    def collect(self, step):\n        raise NotImplementedError\n\n    def insert(self, data):\n        raise NotImplementedError\n    \n    @torch.no_grad()\n    def compute(self):\n        for agent_id in range(self.num_agents):\n            self.trainer[agent_id].prep_rollout()\n            next_value = self.trainer[agent_id].policy.get_values(self.buffer[agent_id].share_obs[-1], \n                                                                self.buffer[agent_id].rnn_states_critic[-1],\n                                                                self.buffer[agent_id].masks[-1])\n            next_value = _t2n(next_value)\n            self.buffer[agent_id].compute_returns(next_value, self.trainer[agent_id].value_normalizer)\n\n    def train(self):\n        train_infos = []\n        # random update order\n\n        action_dim=self.buffer[0].actions.shape[-1]\n        factor = np.ones((self.episode_length, self.n_rollout_threads, 1), dtype=np.float32)\n\n        for agent_id in torch.randperm(self.num_agents):\n            self.trainer[agent_id].prep_training()\n            self.buffer[agent_id].update_factor(factor)\n            available_actions = None if self.buffer[agent_id].available_actions is None \\\n                else self.buffer[agent_id].available_actions[:-1].reshape(-1, *self.buffer[agent_id].available_actions.shape[2:])\n            \n            if self.all_args.algorithm_name == \"hatrpo\":\n                old_actions_logprob, _, _, _, _ =self.trainer[agent_id].policy.actor.evaluate_actions(self.buffer[agent_id].obs[:-1].reshape(-1, *self.buffer[agent_id].obs.shape[2:]),\n                                                            self.buffer[agent_id].rnn_states[0:1].reshape(-1, *self.buffer[agent_id].rnn_states.shape[2:]),\n                                                            self.buffer[agent_id].actions.reshape(-1, *self.buffer[agent_id].actions.shape[2:]),\n                                                            self.buffer[agent_id].masks[:-1].reshape(-1, *self.buffer[agent_id].masks.shape[2:]),\n                                                            available_actions,\n                                                            self.buffer[agent_id].active_masks[:-1].reshape(-1, *self.buffer[agent_id].active_masks.shape[2:]))\n            else:\n                old_actions_logprob, _ =self.trainer[agent_id].policy.actor.evaluate_actions(self.buffer[agent_id].obs[:-1].reshape(-1, 
*self.buffer[agent_id].obs.shape[2:]),\n                                                            self.buffer[agent_id].rnn_states[0:1].reshape(-1, *self.buffer[agent_id].rnn_states.shape[2:]),\n                                                            self.buffer[agent_id].actions.reshape(-1, *self.buffer[agent_id].actions.shape[2:]),\n                                                            self.buffer[agent_id].masks[:-1].reshape(-1, *self.buffer[agent_id].masks.shape[2:]),\n                                                            available_actions,\n                                                            self.buffer[agent_id].active_masks[:-1].reshape(-1, *self.buffer[agent_id].active_masks.shape[2:]))\n            train_info = self.trainer[agent_id].train(self.buffer[agent_id])\n\n            if self.all_args.algorithm_name == \"hatrpo\":\n                new_actions_logprob, _, _, _, _ = self.trainer[agent_id].policy.actor.evaluate_actions(self.buffer[agent_id].obs[:-1].reshape(-1, *self.buffer[agent_id].obs.shape[2:]),\n                                                            self.buffer[agent_id].rnn_states[0:1].reshape(-1, *self.buffer[agent_id].rnn_states.shape[2:]),\n                                                            self.buffer[agent_id].actions.reshape(-1, *self.buffer[agent_id].actions.shape[2:]),\n                                                            self.buffer[agent_id].masks[:-1].reshape(-1, *self.buffer[agent_id].masks.shape[2:]),\n                                                            available_actions,\n                                                            self.buffer[agent_id].active_masks[:-1].reshape(-1, *self.buffer[agent_id].active_masks.shape[2:]))\n            else:\n                new_actions_logprob, _ = self.trainer[agent_id].policy.actor.evaluate_actions(self.buffer[agent_id].obs[:-1].reshape(-1, *self.buffer[agent_id].obs.shape[2:]),\n                                                            self.buffer[agent_id].rnn_states[0:1].reshape(-1, *self.buffer[agent_id].rnn_states.shape[2:]),\n                                                            self.buffer[agent_id].actions.reshape(-1, *self.buffer[agent_id].actions.shape[2:]),\n                                                            self.buffer[agent_id].masks[:-1].reshape(-1, *self.buffer[agent_id].masks.shape[2:]),\n                                                            available_actions,\n                                                            self.buffer[agent_id].active_masks[:-1].reshape(-1, *self.buffer[agent_id].active_masks.shape[2:]))\n\n            # HAPPO sequential update: fold this agent's post-update probability\n            # ratio into the factor that corrects the advantages of the agents\n            # updated after it\n            factor = factor * _t2n(torch.prod(torch.exp(new_actions_logprob - old_actions_logprob), dim=-1).reshape(self.episode_length, self.n_rollout_threads, 1))\n            train_infos.append(train_info)\n            self.buffer[agent_id].after_update()\n\n        return train_infos\n\n    def save(self):\n        for agent_id in range(self.num_agents):\n            if self.use_single_network:\n                policy_model = self.trainer[agent_id].policy.model\n                torch.save(policy_model.state_dict(), str(self.save_dir) + \"/model_agent\" + str(agent_id) + \".pt\")\n            else:\n                policy_actor = self.trainer[agent_id].policy.actor\n                torch.save(policy_actor.state_dict(), str(self.save_dir) + \"/actor_agent\" + str(agent_id) + \".pt\")\n                policy_critic = self.trainer[agent_id].policy.critic\n                
torch.save(policy_critic.state_dict(), str(self.save_dir) + \"/critic_agent\" + str(agent_id) + \".pt\")\n\n    def restore(self):\n        for agent_id in range(self.num_agents):\n            if self.use_single_network:\n                policy_model_state_dict = torch.load(str(self.model_dir) + '/model_agent' + str(agent_id) + '.pt')\n                self.policy[agent_id].model.load_state_dict(policy_model_state_dict)\n            else:\n                policy_actor_state_dict = torch.load(str(self.model_dir) + '/actor_agent' + str(agent_id) + '.pt')\n                self.policy[agent_id].actor.load_state_dict(policy_actor_state_dict)\n                policy_critic_state_dict = torch.load(str(self.model_dir) + '/critic_agent' + str(agent_id) + '.pt')\n                self.policy[agent_id].critic.load_state_dict(policy_critic_state_dict)\n\n    def log_train(self, train_infos, total_num_steps): \n        for agent_id in range(self.num_agents):\n            for k, v in train_infos[agent_id].items():\n                agent_k = \"agent%i/\" % agent_id + k\n                self.writter.add_scalars(agent_k, {agent_k: v}, total_num_steps)\n\n    def log_env(self, env_infos, total_num_steps):\n        for k, v in env_infos.items():\n            if len(v) > 0:\n                self.writter.add_scalars(k, {k: np.mean(v)}, total_num_steps)\n"
  },
  {
    "path": "runners/separated/mujoco_runner.py",
    "content": "import time\nimport numpy as np\nfrom functools import reduce\nimport torch\nfrom runners.separated.base_runner import Runner\n\n\ndef _t2n(x):\n    return x.detach().cpu().numpy()\n\n\nclass MujocoRunner(Runner):\n    \"\"\"Runner class to perform training, evaluation. and data collection for SMAC. See parent class for details.\"\"\"\n\n    def __init__(self, config):\n        super(MujocoRunner, self).__init__(config)\n\n    def run(self):\n        self.warmup()\n\n        start = time.time()\n        episodes = int(self.num_env_steps) // self.episode_length // self.n_rollout_threads\n\n        train_episode_rewards = [0 for _ in range(self.n_rollout_threads)]\n\n        for episode in range(episodes):\n            if self.use_linear_lr_decay:\n                self.trainer.policy.lr_decay(episode, episodes)\n\n            done_episodes_rewards = []\n\n            for step in range(self.episode_length):\n                # Sample actions\n                values, actions, action_log_probs, rnn_states, rnn_states_critic = self.collect(step)\n\n                # Obser reward and next obs\n                obs, share_obs, rewards, dones, infos, _ = self.envs.step(actions)\n\n                dones_env = np.all(dones, axis=1)\n                reward_env = np.mean(rewards, axis=1).flatten()\n                train_episode_rewards += reward_env\n                for t in range(self.n_rollout_threads):\n                    if dones_env[t]:\n                        done_episodes_rewards.append(train_episode_rewards[t])\n                        train_episode_rewards[t] = 0\n\n                data = obs, share_obs, rewards, dones, infos, \\\n                       values, actions, action_log_probs, \\\n                       rnn_states, rnn_states_critic\n\n                # insert data into buffer\n                self.insert(data)\n\n            # compute return and update network\n            self.compute()\n            train_infos = self.train()\n\n            # post process\n            total_num_steps = (episode + 1) * self.episode_length * self.n_rollout_threads\n            # save model\n            if (episode % self.save_interval == 0 or episode == episodes - 1):\n                self.save()\n\n            # log information\n            if episode % self.log_interval == 0:\n                end = time.time()\n                print(\"\\n Scenario {} Algo {} Exp {} updates {}/{} episodes, total num timesteps {}/{}, FPS {}.\\n\"\n                      .format(self.all_args.scenario,\n                              self.algorithm_name,\n                              self.experiment_name,\n                              episode,\n                              episodes,\n                              total_num_steps,\n                              self.num_env_steps,\n                              int(total_num_steps / (end - start))))\n\n                self.log_train(train_infos, total_num_steps)\n\n                if len(done_episodes_rewards) > 0:\n                    aver_episode_rewards = np.mean(done_episodes_rewards)\n                    print(\"some episodes done, average rewards: \", aver_episode_rewards)\n                    self.writter.add_scalars(\"train_episode_rewards\", {\"aver_rewards\": aver_episode_rewards},\n                                             total_num_steps)\n\n            # eval\n            if episode % self.eval_interval == 0 and self.use_eval:\n                self.eval(total_num_steps)\n\n    def warmup(self):\n        # reset env\n        obs, 
share_obs, _ = self.envs.reset()\n        # replay buffer\n        if not self.use_centralized_V:\n            share_obs = obs\n\n        for agent_id in range(self.num_agents):\n            self.buffer[agent_id].share_obs[0] = share_obs[:, agent_id].copy()\n            self.buffer[agent_id].obs[0] = obs[:, agent_id].copy()\n\n    @torch.no_grad()\n    def collect(self, step):\n        value_collector = []\n        action_collector = []\n        action_log_prob_collector = []\n        rnn_state_collector = []\n        rnn_state_critic_collector = []\n        for agent_id in range(self.num_agents):\n            self.trainer[agent_id].prep_rollout()\n            value, action, action_log_prob, rnn_state, rnn_state_critic \\\n                = self.trainer[agent_id].policy.get_actions(self.buffer[agent_id].share_obs[step],\n                                                            self.buffer[agent_id].obs[step],\n                                                            self.buffer[agent_id].rnn_states[step],\n                                                            self.buffer[agent_id].rnn_states_critic[step],\n                                                            self.buffer[agent_id].masks[step])\n            value_collector.append(_t2n(value))\n            action_collector.append(_t2n(action))\n            action_log_prob_collector.append(_t2n(action_log_prob))\n            rnn_state_collector.append(_t2n(rnn_state))\n            rnn_state_critic_collector.append(_t2n(rnn_state_critic))\n        # [self.envs, agents, dim]\n        values = np.array(value_collector).transpose(1, 0, 2)\n        actions = np.array(action_collector).transpose(1, 0, 2)\n        action_log_probs = np.array(action_log_prob_collector).transpose(1, 0, 2)\n        rnn_states = np.array(rnn_state_collector).transpose(1, 0, 2, 3)\n        rnn_states_critic = np.array(rnn_state_critic_collector).transpose(1, 0, 2, 3)\n\n        return values, actions, action_log_probs, rnn_states, rnn_states_critic\n\n    def insert(self, data):\n        obs, share_obs, rewards, dones, infos, \\\n        values, actions, action_log_probs, rnn_states, rnn_states_critic = data\n\n        dones_env = np.all(dones, axis=1)\n\n        rnn_states[dones_env == True] = np.zeros(\n            ((dones_env == True).sum(), self.num_agents, self.recurrent_N, self.hidden_size), dtype=np.float32)\n        rnn_states_critic[dones_env == True] = np.zeros(\n            ((dones_env == True).sum(), self.num_agents, *self.buffer[0].rnn_states_critic.shape[2:]), dtype=np.float32)\n\n        masks = np.ones((self.n_rollout_threads, self.num_agents, 1), dtype=np.float32)\n        masks[dones_env == True] = np.zeros(((dones_env == True).sum(), self.num_agents, 1), dtype=np.float32)\n\n        active_masks = np.ones((self.n_rollout_threads, self.num_agents, 1), dtype=np.float32)\n        active_masks[dones == True] = np.zeros(((dones == True).sum(), 1), dtype=np.float32)\n        active_masks[dones_env == True] = np.ones(((dones_env == True).sum(), self.num_agents, 1), dtype=np.float32)\n\n        if not self.use_centralized_V:\n            share_obs = obs\n\n        for agent_id in range(self.num_agents):\n            self.buffer[agent_id].insert(share_obs[:, agent_id], obs[:, agent_id], rnn_states[:, agent_id],\n                                         rnn_states_critic[:, agent_id], actions[:, agent_id],\n                                         action_log_probs[:, agent_id],\n                                         values[:, agent_id], 
rewards[:, agent_id], masks[:, agent_id], None,\n                                         active_masks[:, agent_id], None)\n\n    def log_train(self, train_infos, total_num_steps):\n        print(\"average_step_rewards is {}.\".format(np.mean(self.buffer[0].rewards)))\n        for agent_id in range(self.num_agents):\n            train_infos[agent_id][\"average_step_rewards\"] = np.mean(self.buffer[agent_id].rewards)\n            for k, v in train_infos[agent_id].items():\n                agent_k = \"agent%i/\" % agent_id + k\n                self.writter.add_scalars(agent_k, {agent_k: v}, total_num_steps)\n\n    @torch.no_grad()\n    def eval(self, total_num_steps):\n        eval_episode = 0\n        eval_episode_rewards = []\n        one_episode_rewards = []\n        for eval_i in range(self.n_eval_rollout_threads):\n            one_episode_rewards.append([])\n            eval_episode_rewards.append([])\n\n        eval_obs, eval_share_obs, _ = self.eval_envs.reset()\n\n        eval_rnn_states = np.zeros((self.n_eval_rollout_threads, self.num_agents, self.recurrent_N, self.hidden_size),\n                                   dtype=np.float32)\n        eval_masks = np.ones((self.n_eval_rollout_threads, self.num_agents, 1), dtype=np.float32)\n\n        while True:\n            eval_actions_collector = []\n            eval_rnn_states_collector = []\n            for agent_id in range(self.num_agents):\n                self.trainer[agent_id].prep_rollout()\n                eval_actions, temp_rnn_state = \\\n                    self.trainer[agent_id].policy.act(eval_obs[:, agent_id],\n                                                      eval_rnn_states[:, agent_id],\n                                                      eval_masks[:, agent_id],\n                                                      deterministic=True)\n                eval_rnn_states[:, agent_id] = _t2n(temp_rnn_state)\n                eval_actions_collector.append(_t2n(eval_actions))\n\n            eval_actions = np.array(eval_actions_collector).transpose(1, 0, 2)\n\n            # Observe reward and next obs\n            eval_obs, eval_share_obs, eval_rewards, eval_dones, eval_infos, _ = self.eval_envs.step(\n                eval_actions)\n            for eval_i in range(self.n_eval_rollout_threads):\n                one_episode_rewards[eval_i].append(eval_rewards[eval_i])\n\n            eval_dones_env = np.all(eval_dones, axis=1)\n\n            eval_rnn_states[eval_dones_env == True] = np.zeros(\n                ((eval_dones_env == True).sum(), self.num_agents, self.recurrent_N, self.hidden_size), dtype=np.float32)\n\n            eval_masks = np.ones((self.all_args.n_eval_rollout_threads, self.num_agents, 1), dtype=np.float32)\n            eval_masks[eval_dones_env == True] = np.zeros(((eval_dones_env == True).sum(), self.num_agents, 1),\n                                                          dtype=np.float32)\n\n            for eval_i in range(self.n_eval_rollout_threads):\n                if eval_dones_env[eval_i]:\n                    eval_episode += 1\n                    eval_episode_rewards[eval_i].append(np.sum(one_episode_rewards[eval_i], axis=0))\n                    one_episode_rewards[eval_i] = []\n\n            if eval_episode >= self.all_args.eval_episodes:\n                eval_episode_rewards = np.concatenate(eval_episode_rewards)\n                eval_env_infos = {'eval_average_episode_rewards': eval_episode_rewards,\n                                  'eval_max_episode_rewards': [np.max(eval_episode_rewards)]}\n 
               self.log_env(eval_env_infos, total_num_steps)\n                print(\"eval_average_episode_rewards is {}.\".format(np.mean(eval_episode_rewards)))\n                break\n"
  },
  {
    "path": "runners/separated/smac_runner.py",
    "content": "import time\nimport numpy as np\nfrom functools import reduce\nimport torch\nfrom runners.separated.base_runner import Runner\n\ndef _t2n(x):\n    return x.detach().cpu().numpy()\n\nclass SMACRunner(Runner):\n    \"\"\"Runner class to perform training, evaluation. and data collection for SMAC. See parent class for details.\"\"\"\n    def __init__(self, config):\n        super(SMACRunner, self).__init__(config)\n\n    def run(self):\n        self.warmup()   \n\n        start = time.time()\n        episodes = int(self.num_env_steps) // self.episode_length // self.n_rollout_threads\n\n        last_battles_game = np.zeros(self.n_rollout_threads, dtype=np.float32)\n        last_battles_won = np.zeros(self.n_rollout_threads, dtype=np.float32)\n\n        for episode in range(episodes):\n            if self.use_linear_lr_decay:\n                self.trainer.policy.lr_decay(episode, episodes)\n\n            for step in range(self.episode_length):\n                # Sample actions\n                values, actions, action_log_probs, rnn_states, rnn_states_critic = self.collect(step)\n                # Obser reward and next obs\n                obs, share_obs, rewards, dones, infos, available_actions = self.envs.step(actions)\n\n                data = obs, share_obs, rewards, dones, infos, available_actions, \\\n                       values, actions, action_log_probs, \\\n                       rnn_states, rnn_states_critic \n                \n                # insert data into buffer\n                self.insert(data)\n\n            # compute return and update network\n            self.compute()\n            train_infos = self.train()\n            \n            # post process\n            total_num_steps = (episode + 1) * self.episode_length * self.n_rollout_threads           \n            # save model\n            if (episode % self.save_interval == 0 or episode == episodes - 1):\n                self.save()\n\n            # log information\n            if episode % self.log_interval == 0:\n                end = time.time()\n                print(\"\\n Map {} Algo {} Exp {} updates {}/{} episodes, total num timesteps {}/{}, FPS {}.\\n\"\n                        .format(self.all_args.map_name,\n                                self.algorithm_name,\n                                self.experiment_name,\n                                episode,\n                                episodes,\n                                total_num_steps,\n                                self.num_env_steps,\n                                int(total_num_steps / (end - start))))\n\n                if self.env_name == \"StarCraft2\":\n                    battles_won = []\n                    battles_game = []\n                    incre_battles_won = []\n                    incre_battles_game = []                    \n\n                    for i, info in enumerate(infos):\n                        if 'battles_won' in info[0].keys():\n                            battles_won.append(info[0]['battles_won'])\n                            incre_battles_won.append(info[0]['battles_won']-last_battles_won[i])\n                        if 'battles_game' in info[0].keys():\n                            battles_game.append(info[0]['battles_game'])\n                            incre_battles_game.append(info[0]['battles_game']-last_battles_game[i])\n\n                    incre_win_rate = np.sum(incre_battles_won)/np.sum(incre_battles_game) if np.sum(incre_battles_game)>0 else 0.0\n                    print(\"incre win rate is 
{}.\".format(incre_win_rate))\n                    self.writter.add_scalars(\"incre_win_rate\", {\"incre_win_rate\": incre_win_rate}, total_num_steps)\n                    \n                    last_battles_game = battles_game\n                    last_battles_won = battles_won\n                # modified\n\n                for agent_id in range(self.num_agents):\n                    train_infos[agent_id]['dead_ratio'] = 1 - self.buffer[agent_id].active_masks.sum() /(self.num_agents* reduce(lambda x, y: x*y, list(self.buffer[agent_id].active_masks.shape)))\n                \n                self.log_train(train_infos, total_num_steps)\n\n            # eval\n            if episode % self.eval_interval == 0 and self.use_eval:\n                self.eval(total_num_steps)\n\n    def warmup(self):\n        # reset env\n        obs, share_obs, available_actions = self.envs.reset()\n        # replay buffer\n        if not self.use_centralized_V:\n            share_obs = obs\n        for agent_id in range(self.num_agents):\n            self.buffer[agent_id].share_obs[0] = share_obs[:,agent_id].copy()\n            self.buffer[agent_id].obs[0] = obs[:,agent_id].copy()\n            self.buffer[agent_id].available_actions[0] = available_actions[:,agent_id].copy()\n\n    @torch.no_grad()\n    def collect(self, step):\n        value_collector=[]\n        action_collector=[]\n        action_log_prob_collector=[]\n        rnn_state_collector=[]\n        rnn_state_critic_collector=[]\n        for agent_id in range(self.num_agents):\n            self.trainer[agent_id].prep_rollout()\n            value, action, action_log_prob, rnn_state, rnn_state_critic \\\n                = self.trainer[agent_id].policy.get_actions(self.buffer[agent_id].share_obs[step],\n                                                self.buffer[agent_id].obs[step],\n                                                self.buffer[agent_id].rnn_states[step],\n                                                self.buffer[agent_id].rnn_states_critic[step],\n                                                self.buffer[agent_id].masks[step],\n                                                self.buffer[agent_id].available_actions[step])\n            value_collector.append(_t2n(value))\n            action_collector.append(_t2n(action))\n            action_log_prob_collector.append(_t2n(action_log_prob))\n            rnn_state_collector.append(_t2n(rnn_state))\n            rnn_state_critic_collector.append(_t2n(rnn_state_critic))\n        # [self.envs, agents, dim]\n        values = np.array(value_collector).transpose(1, 0, 2)\n        actions = np.array(action_collector).transpose(1, 0, 2)\n        action_log_probs = np.array(action_log_prob_collector).transpose(1, 0, 2)\n        rnn_states = np.array(rnn_state_collector).transpose(1, 0, 2, 3)\n        rnn_states_critic = np.array(rnn_state_critic_collector).transpose(1, 0, 2, 3)\n\n        return values, actions, action_log_probs, rnn_states, rnn_states_critic\n\n    def insert(self, data):\n        obs, share_obs, rewards, dones, infos, available_actions, \\\n        values, actions, action_log_probs, rnn_states, rnn_states_critic = data\n\n        dones_env = np.all(dones, axis=1)\n\n        rnn_states[dones_env == True] = np.zeros(((dones_env == True).sum(), self.num_agents, self.recurrent_N, self.hidden_size), dtype=np.float32)\n        rnn_states_critic[dones_env == True] = np.zeros(((dones_env == True).sum(), self.num_agents, *self.buffer[0].rnn_states_critic.shape[2:]), dtype=np.float32)\n\n    
    masks = np.ones((self.n_rollout_threads, self.num_agents, 1), dtype=np.float32)\n        masks[dones_env == True] = np.zeros(((dones_env == True).sum(), self.num_agents, 1), dtype=np.float32)\n\n        active_masks = np.ones((self.n_rollout_threads, self.num_agents, 1), dtype=np.float32)\n        active_masks[dones == True] = np.zeros(((dones == True).sum(), 1), dtype=np.float32)\n        active_masks[dones_env == True] = np.ones(((dones_env == True).sum(), self.num_agents, 1), dtype=np.float32)\n\n        bad_masks = np.array([[[0.0] if info[agent_id]['bad_transition'] else [1.0] for agent_id in range(self.num_agents)] for info in infos])\n        \n        if not self.use_centralized_V:\n            share_obs = obs\n        for agent_id in range(self.num_agents):\n            self.buffer[agent_id].insert(share_obs[:,agent_id], obs[:,agent_id], rnn_states[:,agent_id],\n                    rnn_states_critic[:,agent_id],actions[:,agent_id], action_log_probs[:,agent_id],\n                    values[:,agent_id], rewards[:,agent_id], masks[:,agent_id], bad_masks[:,agent_id], \n                    active_masks[:,agent_id], available_actions[:,agent_id])\n\n    def log_train(self, train_infos, total_num_steps):\n        for agent_id in range(self.num_agents):\n            train_infos[agent_id][\"average_step_rewards\"] = np.mean(self.buffer[agent_id].rewards)\n            for k, v in train_infos[agent_id].items():\n                agent_k = \"agent%i/\" % agent_id + k\n                self.writter.add_scalars(agent_k, {agent_k: v}, total_num_steps)\n    \n    @torch.no_grad()\n    def eval(self, total_num_steps):\n        eval_battles_won = 0\n        eval_episode = 0\n\n        eval_episode_rewards = []\n        one_episode_rewards = []\n        for eval_i in range(self.n_eval_rollout_threads):\n            one_episode_rewards.append([])\n            eval_episode_rewards.append([])\n\n        eval_obs, eval_share_obs, eval_available_actions = self.eval_envs.reset()\n\n        eval_rnn_states = np.zeros((self.n_eval_rollout_threads, self.num_agents, self.recurrent_N, self.hidden_size), dtype=np.float32)\n        eval_masks = np.ones((self.n_eval_rollout_threads, self.num_agents, 1), dtype=np.float32)\n\n        while True:\n            eval_actions_collector=[]\n            eval_rnn_states_collector=[]\n            for agent_id in range(self.num_agents):\n                self.trainer[agent_id].prep_rollout()\n                eval_actions, temp_rnn_state = \\\n                    self.trainer[agent_id].policy.act(eval_obs[:,agent_id],\n                                            eval_rnn_states[:,agent_id],\n                                            eval_masks[:,agent_id],\n                                            eval_available_actions[:,agent_id],\n                                            deterministic=True)\n                eval_rnn_states[:,agent_id]=_t2n(temp_rnn_state)\n                eval_actions_collector.append(_t2n(eval_actions))\n\n            eval_actions = np.array(eval_actions_collector).transpose(1,0,2)\n\n            \n            # Obser reward and next obs\n            eval_obs, eval_share_obs, eval_rewards, eval_dones, eval_infos, eval_available_actions = self.eval_envs.step(eval_actions)\n            for eval_i in range(self.n_eval_rollout_threads):\n                one_episode_rewards[eval_i].append(eval_rewards[eval_i])\n\n            eval_dones_env = np.all(eval_dones, axis=1)\n\n            eval_rnn_states[eval_dones_env == True] = np.zeros(((eval_dones_env 
== True).sum(), self.num_agents, self.recurrent_N, self.hidden_size), dtype=np.float32)\n\n            eval_masks = np.ones((self.all_args.n_eval_rollout_threads, self.num_agents, 1), dtype=np.float32)\n            eval_masks[eval_dones_env == True] = np.zeros(((eval_dones_env == True).sum(), self.num_agents, 1), dtype=np.float32)\n\n            for eval_i in range(self.n_eval_rollout_threads):\n                if eval_dones_env[eval_i]:\n                    eval_episode += 1\n                    eval_episode_rewards[eval_i].append(np.sum(one_episode_rewards[eval_i], axis=0))\n                    one_episode_rewards[eval_i] = []\n                    if eval_infos[eval_i][0]['won']:\n                        eval_battles_won += 1\n\n            if eval_episode >= self.all_args.eval_episodes:\n                eval_episode_rewards = np.concatenate(eval_episode_rewards)\n                eval_env_infos = {'eval_average_episode_rewards': eval_episode_rewards}                \n                self.log_env(eval_env_infos, total_num_steps)\n                eval_win_rate = eval_battles_won/eval_episode\n                print(\"eval win rate is {}.\".format(eval_win_rate))\n                self.writter.add_scalars(\"eval_win_rate\", {\"eval_win_rate\": eval_win_rate}, total_num_steps)\n                break\n"
  },
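  {
    "path": "examples/collect_shapes_demo.py",
    "content": "\"\"\"Illustrative sketch, not part of the original repo (path and sizes are made up):\ndemonstrates the shape convention in SMACRunner.collect(), where per-agent outputs are\nstacked to [n_agents, n_threads, dim] and transposed to env-major [n_threads, n_agents, dim]\nbefore insert() slices them back out per agent.\"\"\"\nimport numpy as np\n\nn_threads, n_agents, act_dim = 4, 3, 1\n\n# One entry per agent, each of shape [n_threads, act_dim], as _t2n() would return.\naction_collector = [np.full((n_threads, act_dim), agent_id, dtype=np.float32)\n                    for agent_id in range(n_agents)]\n\n# np.array(...) stacks to [n_agents, n_threads, act_dim]; transpose to env-major order.\nactions = np.array(action_collector).transpose(1, 0, 2)\nassert actions.shape == (n_threads, n_agents, act_dim)\n\n# insert() then recovers each agent's slice with actions[:, agent_id].\nfor agent_id in range(n_agents):\n    assert (actions[:, agent_id] == agent_id).all()\nprint(\"env-major actions shape:\", actions.shape)\n"
  },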
  {
    "path": "scripts/__init__.py",
    "content": ""
  },
  {
    "path": "scripts/train/__init__.py",
    "content": ""
  },
  {
    "path": "scripts/train/train_mujoco.py",
    "content": "#!/usr/bin/env python\nimport sys\nimport os\nsys.path.append(\"../\")\nimport socket\nimport setproctitle\nimport numpy as np\nfrom pathlib import Path\nimport torch\nfrom configs.config import get_config\nfrom envs.ma_mujoco.multiagent_mujoco.mujoco_multi import MujocoMulti\nfrom envs.env_wrappers import ShareSubprocVecEnv, ShareDummyVecEnv\nfrom runners.separated.mujoco_runner import MujocoRunner as Runner\n\"\"\"Train script for Mujoco.\"\"\"\n\n\ndef make_train_env(all_args):\n    def get_env_fn(rank):\n        def init_env():\n            if all_args.env_name == \"mujoco\":\n                env_args = {\"scenario\": all_args.scenario,\n                            \"agent_conf\": all_args.agent_conf,\n                            \"agent_obsk\": all_args.agent_obsk,\n                            \"episode_limit\": 1000}\n                env = MujocoMulti(env_args=env_args)\n            else:\n                print(\"Can not support the \" + all_args.env_name + \"environment.\")\n                raise NotImplementedError\n            env.seed(all_args.seed + rank * 1000)\n            return env\n\n        return init_env\n\n    if all_args.n_rollout_threads == 1:\n        return ShareDummyVecEnv([get_env_fn(0)])\n    else:\n        return ShareSubprocVecEnv([get_env_fn(i) for i in range(all_args.n_rollout_threads)])\n\n\ndef make_eval_env(all_args):\n    def get_env_fn(rank):\n        def init_env():\n            if all_args.env_name == \"mujoco\":\n                env_args = {\"scenario\": all_args.scenario,\n                            \"agent_conf\": all_args.agent_conf,\n                            \"agent_obsk\": all_args.agent_obsk,\n                            \"episode_limit\": 1000}\n                env = MujocoMulti(env_args=env_args)\n            else:\n                print(\"Can not support the \" + all_args.env_name + \"environment.\")\n                raise NotImplementedError\n            env.seed(all_args.seed * 50000 + rank * 10000)\n            return env\n\n        return init_env\n\n    if all_args.n_eval_rollout_threads == 1:\n        return ShareDummyVecEnv([get_env_fn(0)])\n    else:\n        return ShareSubprocVecEnv([get_env_fn(i) for i in range(all_args.n_eval_rollout_threads)])\n\n\ndef parse_args(args, parser):\n    parser.add_argument('--scenario', type=str, default='Hopper-v2', help=\"Which mujoco task to run on\")\n    parser.add_argument('--agent_conf', type=str, default='3x1')\n    parser.add_argument('--agent_obsk', type=int, default=0)\n    parser.add_argument(\"--add_move_state\", action='store_true', default=False)\n    parser.add_argument(\"--add_local_obs\", action='store_true', default=False)\n    parser.add_argument(\"--add_distance_state\", action='store_true', default=False)\n    parser.add_argument(\"--add_enemy_action_state\", action='store_true', default=False)\n    parser.add_argument(\"--add_agent_id\", action='store_true', default=False)\n    parser.add_argument(\"--add_visible_state\", action='store_true', default=False)\n    parser.add_argument(\"--add_xy_state\", action='store_true', default=False)\n\n    # agent-specific state should be designed carefully\n    parser.add_argument(\"--use_state_agent\", action='store_true', default=False)\n    parser.add_argument(\"--use_mustalive\", action='store_false', default=True)\n    parser.add_argument(\"--add_center_xy\", action='store_true', default=False)\n    parser.add_argument(\"--use_single_network\", action='store_true', default=False)\n\n    all_args = 
parser.parse_known_args(args)[0]\n\n    return all_args\n\n\ndef main(args):\n    parser = get_config()\n    all_args = parse_args(args, parser)\n    print(\"all config: \", all_args)\n    if all_args.seed_specify:\n        all_args.seed=all_args.runing_id\n    else:\n        all_args.seed=np.random.randint(1000,10000)\n    print(\"seed is :\",all_args.seed)\n    # cuda\n    if all_args.cuda and torch.cuda.is_available():\n        print(\"choose to use gpu...\")\n        device = torch.device(\"cuda:0\")\n        torch.set_num_threads(all_args.n_training_threads)\n        if all_args.cuda_deterministic:\n            torch.backends.cudnn.benchmark = False\n            torch.backends.cudnn.deterministic = True\n    else:\n        print(\"choose to use cpu...\")\n        device = torch.device(\"cpu\")\n        torch.set_num_threads(all_args.n_training_threads)\n\n    run_dir = Path(os.path.split(os.path.dirname(os.path.abspath(__file__)))[\n                       0] + \"/results\") / all_args.env_name / all_args.scenario / all_args.algorithm_name / all_args.experiment_name / str(all_args.seed)\n    if not run_dir.exists():\n        os.makedirs(str(run_dir))\n\n    if not run_dir.exists():\n        curr_run = 'run1'\n    else:\n        exst_run_nums = [int(str(folder.name).split('run')[1]) for folder in run_dir.iterdir() if\n                            str(folder.name).startswith('run')]\n        if len(exst_run_nums) == 0:\n            curr_run = 'run1'\n        else:\n            curr_run = 'run%i' % (max(exst_run_nums) + 1)\n    run_dir = run_dir / curr_run\n    if not run_dir.exists():\n        os.makedirs(str(run_dir))\n\n    setproctitle.setproctitle(\n        str(all_args.algorithm_name) + \"-\" + str(all_args.env_name) + \"-\" + str(all_args.experiment_name) + \"@\" + str(\n            all_args.user_name))\n\n    # seed\n    torch.manual_seed(all_args.seed)\n    torch.cuda.manual_seed_all(all_args.seed)\n    np.random.seed(all_args.seed)\n\n    # env\n    envs = make_train_env(all_args)\n    eval_envs = make_eval_env(all_args) if all_args.use_eval else None\n    num_agents = envs.n_agents\n\n    config = {\n        \"all_args\": all_args,\n        \"envs\": envs,\n        \"eval_envs\": eval_envs,\n        \"num_agents\": num_agents,\n        \"device\": device,\n        \"run_dir\": run_dir\n    }\n\n    # run experiments\n    runner = Runner(config)\n    runner.run()\n\n    # post process\n    envs.close()\n    if all_args.use_eval and eval_envs is not envs:\n        eval_envs.close()\n\n    runner.writter.export_scalars_to_json(str(runner.log_dir + '/summary.json'))\n    runner.writter.close()\n\n\nif __name__ == \"__main__\":\n    main(sys.argv[1:])\n"
  },
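  {
    "path": "examples/vecenv_closure_demo.py",
    "content": "\"\"\"Illustrative sketch, not part of the original repo (DummyEnv stands in for\nMujocoMulti): shows why make_train_env wraps construction in get_env_fn(rank)\nclosures. Calling get_env_fn(i) binds rank at call time, so every vectorized\nworker seeds its env with its own offset (seed + rank * 1000 for training) rather\nthan all workers sharing the loop's final value.\"\"\"\n\n\nclass DummyEnv:\n    def __init__(self):\n        self.seed_value = None\n\n    def seed(self, s):\n        self.seed_value = s\n\n\ndef get_env_fn(rank, base_seed=1000):\n    def init_env():\n        env = DummyEnv()\n        env.seed(base_seed + rank * 1000)  # same per-rank offset as make_train_env\n        return env\n    return init_env\n\n\nenv_fns = [get_env_fn(i) for i in range(4)]\nseeds = [fn().seed_value for fn in env_fns]\nassert seeds == [1000, 2000, 3000, 4000]  # one distinct seed per worker\nprint(\"per-worker seeds:\", seeds)\n"
  },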
  {
    "path": "scripts/train/train_smac.py",
    "content": "#!/usr/bin/env python\nimport sys\nimport os\nsys.path.append(\"../\")\nimport socket\nimport setproctitle\nimport numpy as np\nfrom pathlib import Path\nimport torch\nfrom configs.config import get_config\nfrom envs.starcraft2.StarCraft2_Env import StarCraft2Env\nfrom envs.starcraft2.smac_maps import get_map_params\nfrom envs.env_wrappers import ShareSubprocVecEnv, ShareDummyVecEnv\nfrom runners.separated.smac_runner import SMACRunner as Runner\n\"\"\"Train script for SMAC.\"\"\"\n\ndef make_train_env(all_args):\n    def get_env_fn(rank):\n        def init_env():\n            if all_args.env_name == \"StarCraft2\":\n                env = StarCraft2Env(all_args)\n            else:\n                print(\"Can not support the \" + all_args.env_name + \"environment.\")\n                raise NotImplementedError\n            env.seed(all_args.seed + rank * 1000)\n            return env\n\n        return init_env\n\n    if all_args.n_rollout_threads == 1:\n        return ShareDummyVecEnv([get_env_fn(0)])\n    else:\n        return ShareSubprocVecEnv([get_env_fn(i) for i in range(all_args.n_rollout_threads)])\n\n\ndef make_eval_env(all_args):\n    def get_env_fn(rank):\n        def init_env():\n            if all_args.env_name == \"StarCraft2\":\n                env = StarCraft2Env(all_args)\n            else:\n                print(\"Can not support the \" + all_args.env_name + \"environment.\")\n                raise NotImplementedError\n            env.seed(all_args.seed * 50000 + rank * 10000)\n            return env\n\n        return init_env\n\n    if all_args.n_eval_rollout_threads == 1:\n        return ShareDummyVecEnv([get_env_fn(0)])\n    else:\n        return ShareSubprocVecEnv([get_env_fn(i) for i in range(all_args.n_eval_rollout_threads)])\n\n\ndef parse_args(args, parser):\n    parser.add_argument('--map_name', type=str, default='3m',help=\"Which smac map to run on\")\n    parser.add_argument(\"--add_move_state\", action='store_true', default=False)\n    parser.add_argument(\"--add_local_obs\", action='store_true', default=False)\n    parser.add_argument(\"--add_distance_state\", action='store_true', default=False)\n    parser.add_argument(\"--add_enemy_action_state\", action='store_true', default=False)\n    parser.add_argument(\"--add_agent_id\", action='store_true', default=False)\n    parser.add_argument(\"--add_visible_state\", action='store_true', default=False)\n    parser.add_argument(\"--add_xy_state\", action='store_true', default=False)\n    parser.add_argument(\"--use_state_agent\", action='store_true', default=False)\n    parser.add_argument(\"--use_mustalive\", action='store_false', default=True)\n    parser.add_argument(\"--add_center_xy\", action='store_true', default=False)\n    parser.add_argument(\"--use_single_network\", action='store_true', default=False)\n    all_args = parser.parse_known_args(args)[0]\n\n    return all_args\n\n\ndef main(args):\n    parser = get_config()\n    all_args = parse_args(args, parser)\n    print(\"all config: \", all_args)\n    if all_args.seed_specify:\n        all_args.seed=all_args.runing_id\n    else:\n        all_args.seed=np.random.randint(1000,10000)\n    print(\"seed is :\",all_args.seed)\n    # cuda\n    if all_args.cuda and torch.cuda.is_available():\n        print(\"choose to use gpu...\")\n        device = torch.device(\"cuda:0\")\n        torch.set_num_threads(all_args.n_training_threads)\n        if all_args.cuda_deterministic:\n            torch.backends.cudnn.benchmark = False\n            
torch.backends.cudnn.deterministic = True\n    else:\n        print(\"choose to use cpu...\")\n        device = torch.device(\"cpu\")\n        torch.set_num_threads(all_args.n_training_threads)\n\n    run_dir = Path(os.path.split(os.path.dirname(os.path.abspath(__file__)))[\n                       0] + \"/results\") / all_args.env_name / all_args.map_name / all_args.algorithm_name / all_args.experiment_name / str(all_args.seed)\n    if not run_dir.exists():\n        os.makedirs(str(run_dir))\n\n    if not run_dir.exists():\n        curr_run = 'run1'\n    else:\n        exst_run_nums = [int(str(folder.name).split('run')[1]) for folder in run_dir.iterdir() if\n                            str(folder.name).startswith('run')]\n        if len(exst_run_nums) == 0:\n            curr_run = 'run1'\n        else:\n            curr_run = 'run%i' % (max(exst_run_nums) + 1)\n    run_dir = run_dir / curr_run\n    if not run_dir.exists():\n        os.makedirs(str(run_dir))\n\n    setproctitle.setproctitle(\n        str(all_args.algorithm_name) + \"-\" + str(all_args.env_name) + \"-\" + str(all_args.experiment_name) + \"@\" + str(\n            all_args.user_name))\n\n    # seed\n    torch.manual_seed(all_args.seed)\n    torch.cuda.manual_seed_all(all_args.seed)\n    np.random.seed(all_args.seed)\n\n    # env\n    envs = make_train_env(all_args)\n    eval_envs = make_eval_env(all_args) if all_args.use_eval else None\n    num_agents = get_map_params(all_args.map_name)[\"n_agents\"]\n\n    config = {\n        \"all_args\": all_args,\n        \"envs\": envs,\n        \"eval_envs\": eval_envs,\n        \"num_agents\": num_agents,\n        \"device\": device,\n        \"run_dir\": run_dir\n    }\n    # run experiments\n    runner = Runner(config)\n    runner.run()\n\n    # post process\n    envs.close()\n    if all_args.use_eval and eval_envs is not envs:\n        eval_envs.close()\n    runner.writter.export_scalars_to_json(str(runner.log_dir + '/summary.json'))\n    runner.writter.close()\n\n\nif __name__ == \"__main__\":\n    \n    main(sys.argv[1:])\n\n    "
  },
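  {
    "path": "examples/run_dir_demo.py",
    "content": "\"\"\"Illustrative sketch, not part of the original repo: the run-directory numbering\nused by train_mujoco.py and train_smac.py. Each launch scans the per-seed directory\nfor existing 'runN' folders and creates run(max+1), so repeated launches with the\nsame seed never overwrite earlier logs. A temp dir stands in for results/.\"\"\"\nimport os\nimport tempfile\nfrom pathlib import Path\n\nwith tempfile.TemporaryDirectory() as tmp:\n    run_root = Path(tmp)\n    for _ in range(3):  # simulate three launches with the same seed\n        exst_run_nums = [int(str(folder.name).split('run')[1])\n                         for folder in run_root.iterdir()\n                         if str(folder.name).startswith('run')]\n        curr_run = 'run1' if len(exst_run_nums) == 0 else 'run%i' % (max(exst_run_nums) + 1)\n        os.makedirs(str(run_root / curr_run))\n    assert sorted(p.name for p in run_root.iterdir()) == ['run1', 'run2', 'run3']\n    print(\"created:\", sorted(p.name for p in run_root.iterdir()))\n"
  },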
  {
    "path": "scripts/train_mujoco.sh",
    "content": "#!/bin/sh\nenv=\"mujoco\"\nscenario=\"Ant-v2\"\nagent_conf=\"2x4\"\nagent_obsk=2\nalgo=\"happo\"\nexp=\"mlp\"\nrunning_max=20\nkl_threshold=1e-4\necho \"env is ${env}, scenario is ${scenario}, algo is ${algo}, exp is ${exp}, max seed is ${seed_max}\"\nfor number in `seq ${running_max}`;\ndo\n    echo \"the ${number}-th running:\"\n    CUDA_VISIBLE_DEVICES=1 python train/train_mujoco.py --env_name ${env} --algorithm_name ${algo} --experiment_name ${exp} --scenario ${scenario} --agent_conf ${agent_conf} --agent_obsk ${agent_obsk} --lr 5e-6 --critic_lr 5e-3 --std_x_coef 1 --std_y_coef 5e-1 --running_id ${number} --n_training_threads 8 --n_rollout_threads 4 --num_mini_batch 40 --episode_length 1000 --num_env_steps 10000000 --ppo_epoch 5 --kl_threshold ${kl_threshold} --use_value_active_masks --use_eval --add_center_xy --use_state_agent --share_policy\ndone\n"
  },
  {
    "path": "scripts/train_smac.sh",
    "content": "#!/bin/sh\nenv=\"StarCraft2\"\nmap=\"3s5z\"\nalgo=\"happo\"\nexp=\"mlp\"\nrunning_max=20\nkl_threshold=0.06\necho \"env is ${env}, map is ${map}, algo is ${algo}, exp is ${exp}, max seed is ${seed_max}\"\nfor number in `seq ${running_max}`;\ndo\n    echo \"the ${number}-th running:\"\n    CUDA_VISIBLE_DEVICES=1 python train/train_smac.py --env_name ${env} --algorithm_name ${algo} --experiment_name ${exp} --map_name ${map} --running_id ${number} --gamma 0.95 --n_training_threads 32 --n_rollout_threads 20 --num_mini_batch 1 --episode_length 160 --num_env_steps 20000000 --ppo_epoch 5 --stacked_frames 1 --kl_threshold ${kl_threshold} --use_value_active_masks --use_eval --add_center_xy --use_state_agent --share_policy\ndone\n"
  },
  {
    "path": "utils/__init__.py",
    "content": ""
  },
  {
    "path": "utils/multi_discrete.py",
    "content": "import gym\nimport numpy as np\n\n# An old version of OpenAI Gym's multi_discrete.py. (Was getting affected by Gym updates)\n# (https://github.com/openai/gym/blob/1fb81d4e3fb780ccf77fec731287ba07da35eb84/gym/spaces/multi_discrete.py)\nclass MultiDiscrete(gym.Space):\n    \"\"\"\n    - The multi-discrete action space consists of a series of discrete action spaces with different parameters\n    - It can be adapted to both a Discrete action space or a continuous (Box) action space\n    - It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space\n    - It is parametrized by passing an array of arrays containing [min, max] for each discrete action space where the discrete action space can take any integers from `min` to `max` (both inclusive)\n    Note: A value of 0 always need to represent the NOOP action.\n    e.g. Nintendo Game Controller\n    - Can be conceptualized as 3 discrete action spaces:\n        1) Arrow Keys: Discrete 5  - NOOP[0], UP[1], RIGHT[2], DOWN[3], LEFT[4]  - params: min: 0, max: 4\n        2) Button A:   Discrete 2  - NOOP[0], Pressed[1] - params: min: 0, max: 1\n        3) Button B:   Discrete 2  - NOOP[0], Pressed[1] - params: min: 0, max: 1\n    - Can be initialized as\n        MultiDiscrete([ [0,4], [0,1], [0,1] ])\n    \"\"\"\n\n    def __init__(self, array_of_param_array):\n        self.low = np.array([x[0] for x in array_of_param_array])\n        self.high = np.array([x[1] for x in array_of_param_array])\n        self.num_discrete_space = self.low.shape[0]\n        self.n = np.sum(self.high) + 2\n\n    def sample(self):\n        \"\"\" Returns a array with one sample from each discrete action space \"\"\"\n        # For each row: round(random .* (max - min) + min, 0)\n        random_array = np.random.rand(self.num_discrete_space)\n        return [int(x) for x in np.floor(np.multiply((self.high - self.low + 1.), random_array) + self.low)]\n\n    def contains(self, x):\n        return len(x) == self.num_discrete_space and (np.array(x) >= self.low).all() and (np.array(x) <= self.high).all()\n\n    @property\n    def shape(self):\n        return self.num_discrete_space\n\n    def __repr__(self):\n        return \"MultiDiscrete\" + str(self.num_discrete_space)\n\n    def __eq__(self, other):\n        return np.array_equal(self.low, other.low) and np.array_equal(self.high, other.high)\n"
  },
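  {
    "path": "examples/multi_discrete_demo.py",
    "content": "\"\"\"Illustrative usage sketch, not part of the original repo, for\nutils.multi_discrete.MultiDiscrete, using the game-controller example from its\ndocstring: arrow keys take values 0-4 and each button 0-1. Run from the repo root\nso the utils package is importable.\"\"\"\nfrom utils.multi_discrete import MultiDiscrete\n\nspace = MultiDiscrete([[0, 4], [0, 1], [0, 1]])\n\nsample = space.sample()               # e.g. [3, 0, 1]\nassert space.contains(sample)\nassert space.shape == 3               # number of discrete sub-spaces\nassert not space.contains([5, 0, 0])  # 5 exceeds the arrow-key maximum of 4\nprint(\"sample:\", sample, \"| n =\", space.n)\n"
  },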
  {
    "path": "utils/popart.py",
    "content": "\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\n\n\nclass PopArt(nn.Module):\n    \"\"\" Normalize a vector of observations - across the first norm_axes dimensions\"\"\"\n\n    def __init__(self, input_shape, norm_axes=1, beta=0.99999, per_element_update=False, epsilon=1e-5, device=torch.device(\"cpu\")):\n        super(PopArt, self).__init__()\n\n        self.input_shape = input_shape\n        self.norm_axes = norm_axes\n        self.epsilon = epsilon\n        self.beta = beta\n        self.per_element_update = per_element_update\n        self.tpdv = dict(dtype=torch.float32, device=device)\n\n        self.running_mean = nn.Parameter(torch.zeros(input_shape), requires_grad=False).to(**self.tpdv)\n        self.running_mean_sq = nn.Parameter(torch.zeros(input_shape), requires_grad=False).to(**self.tpdv)\n        self.debiasing_term = nn.Parameter(torch.tensor(0.0), requires_grad=False).to(**self.tpdv)\n\n    def reset_parameters(self):\n        self.running_mean.zero_()\n        self.running_mean_sq.zero_()\n        self.debiasing_term.zero_()\n\n    def running_mean_var(self):\n        debiased_mean = self.running_mean / self.debiasing_term.clamp(min=self.epsilon)\n        debiased_mean_sq = self.running_mean_sq / self.debiasing_term.clamp(min=self.epsilon)\n        debiased_var = (debiased_mean_sq - debiased_mean ** 2).clamp(min=1e-2)\n        return debiased_mean, debiased_var\n\n    def forward(self, input_vector, train=True):\n        # Make sure input is float32\n        if type(input_vector) == np.ndarray:\n            input_vector = torch.from_numpy(input_vector)\n        input_vector = input_vector.to(**self.tpdv)\n\n        if train:\n            # Detach input before adding it to running means to avoid backpropping through it on\n            # subsequent batches.\n            detached_input = input_vector.detach()\n            batch_mean = detached_input.mean(dim=tuple(range(self.norm_axes)))\n            batch_sq_mean = (detached_input ** 2).mean(dim=tuple(range(self.norm_axes)))\n\n            if self.per_element_update:\n                batch_size = np.prod(detached_input.size()[:self.norm_axes])\n                weight = self.beta ** batch_size\n            else:\n                weight = self.beta\n\n            self.running_mean.mul_(weight).add_(batch_mean * (1.0 - weight))\n            self.running_mean_sq.mul_(weight).add_(batch_sq_mean * (1.0 - weight))\n            self.debiasing_term.mul_(weight).add_(1.0 * (1.0 - weight))\n\n        mean, var = self.running_mean_var()\n        out = (input_vector - mean[(None,) * self.norm_axes]) / torch.sqrt(var)[(None,) * self.norm_axes]\n        \n        return out\n\n    def denormalize(self, input_vector):\n        \"\"\" Transform normalized data back into original distribution \"\"\"\n        if type(input_vector) == np.ndarray:\n            input_vector = torch.from_numpy(input_vector)\n        input_vector = input_vector.to(**self.tpdv)\n\n        mean, var = self.running_mean_var()\n        out = input_vector * torch.sqrt(var)[(None,) * self.norm_axes] + mean[(None,) * self.norm_axes]\n        \n        out = out.cpu().numpy()\n        \n        return out\n"
  },
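  {
    "path": "examples/popart_demo.py",
    "content": "\"\"\"Illustrative usage sketch, not part of the original repo, for utils.popart.PopArt.\nforward(..., train=True) updates a debiased exponential moving average of the mean and\nsecond moment and returns the normalized input; denormalize() maps normalized value\npredictions back to the original return scale. beta=0.9 (much smaller than the 0.99999\ndefault) is used so the statistics converge within a few updates.\"\"\"\nimport numpy as np\nimport torch\nfrom utils.popart import PopArt\n\ntorch.manual_seed(0)\nvalue_normalizer = PopArt(input_shape=1, beta=0.9)\n\nreturns = torch.randn(64, 1) * 10.0 + 50.0  # fake returns with mean ~50, std ~10\nfor _ in range(100):                        # repeated updates converge the EMA\n    normalized = value_normalizer(returns, train=True)\n\n# After convergence the normalized targets are roughly zero-mean, unit-variance...\nassert abs(normalized.mean().item()) < 0.1\n# ...and denormalize() inverts the transform (it returns a numpy array).\nrecovered = value_normalizer.denormalize(normalized)\nassert np.allclose(recovered, returns.numpy(), atol=1e-3)\nprint(\"normalized mean/std:\", normalized.mean().item(), normalized.std().item())\n"
  },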
  {
    "path": "utils/separated_buffer.py",
    "content": "import torch\nimport numpy as np\nfrom collections import defaultdict\nfrom utils.util import check, get_shape_from_obs_space, get_shape_from_act_space\n\ndef _flatten(T, N, x):\n    return x.reshape(T * N, *x.shape[2:])\n\ndef _cast(x):\n    return x.transpose(1,0,2).reshape(-1, *x.shape[2:])\n\nclass SeparatedReplayBuffer(object):\n    def __init__(self, args, obs_space, share_obs_space, act_space):\n        self.episode_length = args.episode_length\n        self.n_rollout_threads = args.n_rollout_threads\n        self.rnn_hidden_size = args.hidden_size\n        self.recurrent_N = args.recurrent_N\n        self.gamma = args.gamma\n        self.gae_lambda = args.gae_lambda\n        self._use_gae = args.use_gae\n        self._use_popart = args.use_popart\n        self._use_valuenorm = args.use_valuenorm\n        self._use_proper_time_limits = args.use_proper_time_limits\n\n\n\n        obs_shape = get_shape_from_obs_space(obs_space)\n        share_obs_shape = get_shape_from_obs_space(share_obs_space)\n\n        if type(obs_shape[-1]) == list:\n            obs_shape = obs_shape[:1]\n\n        if type(share_obs_shape[-1]) == list:\n            share_obs_shape = share_obs_shape[:1]\n\n        self.share_obs = np.zeros((self.episode_length + 1, self.n_rollout_threads, *share_obs_shape), dtype=np.float32)\n        self.obs = np.zeros((self.episode_length + 1, self.n_rollout_threads, *obs_shape), dtype=np.float32)\n\n        self.rnn_states = np.zeros((self.episode_length + 1, self.n_rollout_threads, self.recurrent_N, self.rnn_hidden_size), dtype=np.float32)\n        self.rnn_states_critic = np.zeros_like(self.rnn_states)\n\n        self.value_preds = np.zeros((self.episode_length + 1, self.n_rollout_threads, 1), dtype=np.float32)\n        self.returns = np.zeros((self.episode_length + 1, self.n_rollout_threads, 1), dtype=np.float32)\n        \n        if act_space.__class__.__name__ == 'Discrete':\n            self.available_actions = np.ones((self.episode_length + 1, self.n_rollout_threads, act_space.n), dtype=np.float32)\n        else:\n            self.available_actions = None\n\n        act_shape = get_shape_from_act_space(act_space)\n\n        self.actions = np.zeros((self.episode_length, self.n_rollout_threads, act_shape), dtype=np.float32)\n        self.action_log_probs = np.zeros((self.episode_length, self.n_rollout_threads, act_shape), dtype=np.float32)\n        self.rewards = np.zeros((self.episode_length, self.n_rollout_threads, 1), dtype=np.float32)\n        \n        self.masks = np.ones((self.episode_length + 1, self.n_rollout_threads, 1), dtype=np.float32)\n        self.bad_masks = np.ones_like(self.masks)\n        self.active_masks = np.ones_like(self.masks)\n\n        self.factor = None\n\n        self.step = 0\n\n    def update_factor(self, factor):\n        self.factor = factor.copy()\n\n    def insert(self, share_obs, obs, rnn_states, rnn_states_critic, actions, action_log_probs,\n               value_preds, rewards, masks, bad_masks=None, active_masks=None, available_actions=None):\n        self.share_obs[self.step + 1] = share_obs.copy()\n        self.obs[self.step + 1] = obs.copy()\n        self.rnn_states[self.step + 1] = rnn_states.copy()\n        self.rnn_states_critic[self.step + 1] = rnn_states_critic.copy()\n        self.actions[self.step] = actions.copy()\n        self.action_log_probs[self.step] = action_log_probs.copy()\n        self.value_preds[self.step] = value_preds.copy()\n        self.rewards[self.step] = rewards.copy()\n        
self.masks[self.step + 1] = masks.copy()\n        if bad_masks is not None:\n            self.bad_masks[self.step + 1] = bad_masks.copy()\n        if active_masks is not None:\n            self.active_masks[self.step + 1] = active_masks.copy()\n        if available_actions is not None:\n            self.available_actions[self.step + 1] = available_actions.copy()\n\n        self.step = (self.step + 1) % self.episode_length\n\n    def chooseinsert(self, share_obs, obs, rnn_states, rnn_states_critic, actions, action_log_probs,\n                     value_preds, rewards, masks, bad_masks=None, active_masks=None, available_actions=None):\n        self.share_obs[self.step] = share_obs.copy()\n        self.obs[self.step] = obs.copy()\n        self.rnn_states[self.step + 1] = rnn_states.copy()\n        self.rnn_states_critic[self.step + 1] = rnn_states_critic.copy()\n        self.actions[self.step] = actions.copy()\n        self.action_log_probs[self.step] = action_log_probs.copy()\n        self.value_preds[self.step] = value_preds.copy()\n        self.rewards[self.step] = rewards.copy()\n        self.masks[self.step + 1] = masks.copy()\n        if bad_masks is not None:\n            self.bad_masks[self.step + 1] = bad_masks.copy()\n        if active_masks is not None:\n            self.active_masks[self.step] = active_masks.copy()\n        if available_actions is not None:\n            self.available_actions[self.step] = available_actions.copy()\n\n        self.step = (self.step + 1) % self.episode_length\n\n    def after_update(self):\n        self.share_obs[0] = self.share_obs[-1].copy()\n        self.obs[0] = self.obs[-1].copy()\n        self.rnn_states[0] = self.rnn_states[-1].copy()\n        self.rnn_states_critic[0] = self.rnn_states_critic[-1].copy()\n        self.masks[0] = self.masks[-1].copy()\n        self.bad_masks[0] = self.bad_masks[-1].copy()\n        self.active_masks[0] = self.active_masks[-1].copy()\n        if self.available_actions is not None:\n            self.available_actions[0] = self.available_actions[-1].copy()\n\n    def chooseafter_update(self):\n        self.rnn_states[0] = self.rnn_states[-1].copy()\n        self.rnn_states_critic[0] = self.rnn_states_critic[-1].copy()\n        self.masks[0] = self.masks[-1].copy()\n        self.bad_masks[0] = self.bad_masks[-1].copy()\n\n    def compute_returns(self, next_value, value_normalizer=None):\n        \"\"\"\n        Compute returns, either with GAE or as discounted reward sums.\n        With use_proper_time_limits, bad_masks mark transitions truncated by the\n        time limit, so those steps bootstrap from the value prediction instead of\n        being treated as true terminations.\n        \"\"\"\n        if self._use_proper_time_limits:\n            if self._use_gae:\n                self.value_preds[-1] = next_value\n                gae = 0\n                for step in reversed(range(self.rewards.shape[0])):\n                    if self._use_popart or self._use_valuenorm:\n                        delta = self.rewards[step] + self.gamma * value_normalizer.denormalize(self.value_preds[\n                            step + 1]) * self.masks[step + 1] - value_normalizer.denormalize(self.value_preds[step])\n                        gae = delta + self.gamma * self.gae_lambda * self.masks[step + 1] * gae\n                        gae = gae * self.bad_masks[step + 1]\n                        self.returns[step] = gae + value_normalizer.denormalize(self.value_preds[step])\n                    else:\n                        delta = self.rewards[step] + self.gamma * self.value_preds[step + 1] * self.masks[step + 1] - self.value_preds[step]\n                        gae = delta + self.gamma * self.gae_lambda * 
self.masks[step + 1] * gae\n                        gae = gae * self.bad_masks[step + 1]\n                        self.returns[step] = gae + self.value_preds[step]\n            else:\n                self.returns[-1] = next_value\n                for step in reversed(range(self.rewards.shape[0])):\n                    if self._use_popart:\n                        self.returns[step] = (self.returns[step + 1] * self.gamma * self.masks[step + 1] + self.rewards[step]) * self.bad_masks[step + 1] \\\n                            + (1 - self.bad_masks[step + 1]) * value_normalizer.denormalize(self.value_preds[step])\n                    else:\n                        self.returns[step] = (self.returns[step + 1] * self.gamma * self.masks[step + 1] + self.rewards[step]) * self.bad_masks[step + 1] \\\n                            + (1 - self.bad_masks[step + 1]) * self.value_preds[step]\n        else:\n            if self._use_gae:\n                self.value_preds[-1] = next_value\n                gae = 0\n                for step in reversed(range(self.rewards.shape[0])):\n                    if self._use_popart or self._use_valuenorm:\n                        delta = self.rewards[step] + self.gamma * value_normalizer.denormalize(self.value_preds[step + 1]) * self.masks[step + 1] - value_normalizer.denormalize(self.value_preds[step])\n                        gae = delta + self.gamma * self.gae_lambda * self.masks[step + 1] * gae\n                        self.returns[step] = gae + value_normalizer.denormalize(self.value_preds[step])\n                    else:\n                        delta = self.rewards[step] + self.gamma * self.value_preds[step + 1] * self.masks[step + 1] - self.value_preds[step]\n                        gae = delta + self.gamma * self.gae_lambda * self.masks[step + 1] * gae\n                        self.returns[step] = gae + self.value_preds[step]\n            else:\n                self.returns[-1] = next_value\n                for step in reversed(range(self.rewards.shape[0])):\n                    self.returns[step] = self.returns[step + 1] * self.gamma * self.masks[step + 1] + self.rewards[step]\n\n    def feed_forward_generator(self, advantages, num_mini_batch=None, mini_batch_size=None):\n        episode_length, n_rollout_threads = self.rewards.shape[0:2]\n        batch_size = n_rollout_threads * episode_length\n\n        if mini_batch_size is None:\n            assert batch_size >= num_mini_batch, (\n                \"PPO requires the number of processes ({}) \"\n                \"* number of steps ({}) = {} \"\n                \"to be greater than or equal to the number of PPO mini batches ({}).\"\n                \"\".format(n_rollout_threads, episode_length, n_rollout_threads * episode_length,\n                          num_mini_batch))\n            mini_batch_size = batch_size // num_mini_batch\n\n        rand = torch.randperm(batch_size).numpy()\n        sampler = [rand[i*mini_batch_size:(i+1)*mini_batch_size] for i in range(num_mini_batch)]\n\n        share_obs = self.share_obs[:-1].reshape(-1, *self.share_obs.shape[2:])\n        obs = self.obs[:-1].reshape(-1, *self.obs.shape[2:])\n        rnn_states = self.rnn_states[:-1].reshape(-1, *self.rnn_states.shape[2:])\n        rnn_states_critic = self.rnn_states_critic[:-1].reshape(-1, *self.rnn_states_critic.shape[2:])\n        actions = self.actions.reshape(-1, self.actions.shape[-1])\n        if self.available_actions is not None:\n            available_actions = self.available_actions[:-1].reshape(-1, 
self.available_actions.shape[-1])\n        value_preds = self.value_preds[:-1].reshape(-1, 1)\n        returns = self.returns[:-1].reshape(-1, 1)\n        masks = self.masks[:-1].reshape(-1, 1)\n        active_masks = self.active_masks[:-1].reshape(-1, 1)\n        action_log_probs = self.action_log_probs.reshape(-1, self.action_log_probs.shape[-1])\n        if self.factor is not None:\n            # factor = self.factor.reshape(-1,1)\n            factor = self.factor.reshape(-1, self.factor.shape[-1])\n        advantages = advantages.reshape(-1, 1)\n\n        for indices in sampler:\n            # obs size [T+1 N Dim]-->[T N Dim]-->[T*N,Dim]-->[index,Dim]\n            share_obs_batch = share_obs[indices]\n            obs_batch = obs[indices]\n            rnn_states_batch = rnn_states[indices]\n            rnn_states_critic_batch = rnn_states_critic[indices]\n            actions_batch = actions[indices]\n            if self.available_actions is not None:\n                available_actions_batch = available_actions[indices]\n            else:\n                available_actions_batch = None\n            value_preds_batch = value_preds[indices]\n            return_batch = returns[indices]\n            masks_batch = masks[indices]\n            active_masks_batch = active_masks[indices]\n            old_action_log_probs_batch = action_log_probs[indices]\n            if advantages is None:\n                adv_targ = None\n            else:\n                adv_targ = advantages[indices]\n\n            if self.factor is None:\n                yield share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, available_actions_batch\n            else:\n                factor_batch = factor[indices]\n                yield share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, available_actions_batch, factor_batch\n\n    def naive_recurrent_generator(self, advantages, num_mini_batch):\n        n_rollout_threads = self.rewards.shape[1]\n        assert n_rollout_threads >= num_mini_batch, (\n            \"PPO requires the number of processes ({}) \"\n            \"to be greater than or equal to the number of \"\n            \"PPO mini batches ({}).\".format(n_rollout_threads, num_mini_batch))\n        num_envs_per_batch = n_rollout_threads // num_mini_batch\n        perm = torch.randperm(n_rollout_threads).numpy()\n        for start_ind in range(0, n_rollout_threads, num_envs_per_batch):\n            share_obs_batch = []\n            obs_batch = []\n            rnn_states_batch = []\n            rnn_states_critic_batch = []\n            actions_batch = []\n            available_actions_batch = []\n            value_preds_batch = []\n            return_batch = []\n            masks_batch = []\n            active_masks_batch = []\n            old_action_log_probs_batch = []\n            adv_targ = []\n            factor_batch = []\n            for offset in range(num_envs_per_batch):\n                ind = perm[start_ind + offset]\n                share_obs_batch.append(self.share_obs[:-1, ind])\n                obs_batch.append(self.obs[:-1, ind])\n                rnn_states_batch.append(self.rnn_states[0:1, ind])\n                rnn_states_critic_batch.append(self.rnn_states_critic[0:1, ind])\n                
actions_batch.append(self.actions[:, ind])\n                if self.available_actions is not None:\n                    available_actions_batch.append(self.available_actions[:-1, ind])\n                value_preds_batch.append(self.value_preds[:-1, ind])\n                return_batch.append(self.returns[:-1, ind])\n                masks_batch.append(self.masks[:-1, ind])\n                active_masks_batch.append(self.active_masks[:-1, ind])\n                old_action_log_probs_batch.append(self.action_log_probs[:, ind])\n                adv_targ.append(advantages[:, ind])\n                if self.factor is not None:\n                    factor_batch.append(self.factor[:,ind])\n\n            # [N[T, dim]]\n            T, N = self.episode_length, num_envs_per_batch\n            # These are all from_numpys of size (T, N, -1)\n            share_obs_batch = np.stack(share_obs_batch, 1)\n            obs_batch = np.stack(obs_batch, 1)\n            actions_batch = np.stack(actions_batch, 1)\n            if self.available_actions is not None:\n                available_actions_batch = np.stack(available_actions_batch, 1)\n            if self.factor is not None:\n                factor_batch=np.stack(factor_batch,1)\n            value_preds_batch = np.stack(value_preds_batch, 1)\n            return_batch = np.stack(return_batch, 1)\n            masks_batch = np.stack(masks_batch, 1)\n            active_masks_batch = np.stack(active_masks_batch, 1)\n            old_action_log_probs_batch = np.stack(old_action_log_probs_batch, 1)\n            adv_targ = np.stack(adv_targ, 1)\n\n            # States is just a (N, -1) from_numpy [N[1,dim]]\n            rnn_states_batch = np.stack(rnn_states_batch, 1).reshape(N, *self.rnn_states.shape[2:])\n            rnn_states_critic_batch = np.stack(rnn_states_critic_batch, 1).reshape(N, *self.rnn_states_critic.shape[2:])\n\n            # Flatten the (T, N, ...) 
from_numpys to (T * N, ...)\n            share_obs_batch = _flatten(T, N, share_obs_batch)\n            obs_batch = _flatten(T, N, obs_batch)\n            actions_batch = _flatten(T, N, actions_batch)\n            if self.available_actions is not None:\n                available_actions_batch = _flatten(T, N, available_actions_batch)\n            else:\n                available_actions_batch = None\n            if self.factor is not None:\n                factor_batch=_flatten(T,N,factor_batch)\n            value_preds_batch = _flatten(T, N, value_preds_batch)\n            return_batch = _flatten(T, N, return_batch)\n            masks_batch = _flatten(T, N, masks_batch)\n            active_masks_batch = _flatten(T, N, active_masks_batch)\n            old_action_log_probs_batch = _flatten(T, N, old_action_log_probs_batch)\n            adv_targ = _flatten(T, N, adv_targ)\n            if self.factor is not None:\n                yield share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, available_actions_batch, factor_batch\n            else:\n                yield share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, available_actions_batch\n\n    def recurrent_generator(self, advantages, num_mini_batch, data_chunk_length):\n        episode_length, n_rollout_threads = self.rewards.shape[0:2]\n        batch_size = n_rollout_threads * episode_length\n        data_chunks = batch_size // data_chunk_length  # [C=r*T/L]\n        mini_batch_size = data_chunks // num_mini_batch\n\n        assert episode_length * n_rollout_threads >= data_chunk_length, (\n            \"PPO requires the number of processes ({}) * episode length ({}) \"\n            \"to be greater than or equal to the number of \"\n            \"data chunk length ({}).\".format(n_rollout_threads, episode_length, data_chunk_length))\n        assert data_chunks >= 2, (\"need larger batch size\")\n\n        rand = torch.randperm(data_chunks).numpy()\n        sampler = [rand[i*mini_batch_size:(i+1)*mini_batch_size] for i in range(num_mini_batch)]\n\n        if len(self.share_obs.shape) > 3:\n            share_obs = self.share_obs[:-1].transpose(1, 0, 2, 3, 4).reshape(-1, *self.share_obs.shape[2:])\n            obs = self.obs[:-1].transpose(1, 0, 2, 3, 4).reshape(-1, *self.obs.shape[2:])\n        else:\n            share_obs = _cast(self.share_obs[:-1])\n            obs = _cast(self.obs[:-1])\n\n        actions = _cast(self.actions)\n        action_log_probs = _cast(self.action_log_probs)\n        advantages = _cast(advantages)\n        value_preds = _cast(self.value_preds[:-1])\n        returns = _cast(self.returns[:-1])\n        masks = _cast(self.masks[:-1])\n        active_masks = _cast(self.active_masks[:-1])\n        if self.factor is not None:\n            factor = _cast(self.factor)\n        # rnn_states = _cast(self.rnn_states[:-1])\n        # rnn_states_critic = _cast(self.rnn_states_critic[:-1])\n        rnn_states = self.rnn_states[:-1].transpose(1, 0, 2, 3).reshape(-1, *self.rnn_states.shape[2:])\n        rnn_states_critic = self.rnn_states_critic[:-1].transpose(1, 0, 2, 3).reshape(-1, *self.rnn_states_critic.shape[2:])\n\n        if self.available_actions is not None:\n            available_actions = _cast(self.available_actions[:-1])\n\n        for indices 
in sampler:\n            share_obs_batch = []\n            obs_batch = []\n            rnn_states_batch = []\n            rnn_states_critic_batch = []\n            actions_batch = []\n            available_actions_batch = []\n            value_preds_batch = []\n            return_batch = []\n            masks_batch = []\n            active_masks_batch = []\n            old_action_log_probs_batch = []\n            adv_targ = []\n            factor_batch = []\n            for index in indices:\n                ind = index * data_chunk_length\n                # size [T+1 N M Dim]-->[T N Dim]-->[N T Dim]-->[T*N,Dim]-->[L,Dim]\n                share_obs_batch.append(share_obs[ind:ind+data_chunk_length])\n                obs_batch.append(obs[ind:ind+data_chunk_length])\n                actions_batch.append(actions[ind:ind+data_chunk_length])\n                if self.available_actions is not None:\n                    available_actions_batch.append(available_actions[ind:ind+data_chunk_length])\n                value_preds_batch.append(value_preds[ind:ind+data_chunk_length])\n                return_batch.append(returns[ind:ind+data_chunk_length])\n                masks_batch.append(masks[ind:ind+data_chunk_length])\n                active_masks_batch.append(active_masks[ind:ind+data_chunk_length])\n                old_action_log_probs_batch.append(action_log_probs[ind:ind+data_chunk_length])\n                adv_targ.append(advantages[ind:ind+data_chunk_length])\n                # size [T+1 N Dim]-->[T N Dim]-->[T*N,Dim]-->[1,Dim]\n                rnn_states_batch.append(rnn_states[ind])\n                rnn_states_critic_batch.append(rnn_states_critic[ind])\n                if self.factor is not None:\n                    factor_batch.append(factor[ind:ind+data_chunk_length])\n            L, N = data_chunk_length, mini_batch_size\n\n            # These are all from_numpys of size (N, L, Dim)\n            share_obs_batch = np.stack(share_obs_batch)\n            obs_batch = np.stack(obs_batch)\n\n            actions_batch = np.stack(actions_batch)\n            if self.available_actions is not None:\n                available_actions_batch = np.stack(available_actions_batch)\n            if self.factor is not None:\n                factor_batch = np.stack(factor_batch)\n            value_preds_batch = np.stack(value_preds_batch)\n            return_batch = np.stack(return_batch)\n            masks_batch = np.stack(masks_batch)\n            active_masks_batch = np.stack(active_masks_batch)\n            old_action_log_probs_batch = np.stack(old_action_log_probs_batch)\n            adv_targ = np.stack(adv_targ)\n\n            # States is just a (N, -1) from_numpy\n            rnn_states_batch = np.stack(rnn_states_batch).reshape(N, *self.rnn_states.shape[2:])\n            rnn_states_critic_batch = np.stack(rnn_states_critic_batch).reshape(N, *self.rnn_states_critic.shape[2:])\n\n            # Flatten the (L, N, ...) 
from_numpys to (L * N, ...)\n            share_obs_batch = _flatten(L, N, share_obs_batch)\n            obs_batch = _flatten(L, N, obs_batch)\n            actions_batch = _flatten(L, N, actions_batch)\n            if self.available_actions is not None:\n                available_actions_batch = _flatten(L, N, available_actions_batch)\n            else:\n                available_actions_batch = None\n            if self.factor is not None:\n                factor_batch = _flatten(L, N, factor_batch)\n            value_preds_batch = _flatten(L, N, value_preds_batch)\n            return_batch = _flatten(L, N, return_batch)\n            masks_batch = _flatten(L, N, masks_batch)\n            active_masks_batch = _flatten(L, N, active_masks_batch)\n            old_action_log_probs_batch = _flatten(L, N, old_action_log_probs_batch)\n            adv_targ = _flatten(L, N, adv_targ)\n            if self.factor is not None:\n                yield share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, available_actions_batch, factor_batch\n            else:\n                yield share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, available_actions_batch\n"
  },
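  {
    "path": "examples/gae_demo.py",
    "content": "\"\"\"Illustrative worked example, not part of the original repo, of the GAE recursion in\nSeparatedReplayBuffer.compute_returns (use_gae branch, no PopArt, no proper time limits):\n    delta_t = r_t + gamma * V_{t+1} * mask_{t+1} - V_t\n    gae_t   = delta_t + gamma * lambda * mask_{t+1} * gae_{t+1}\n    ret_t   = gae_t + V_t\nUses a single rollout thread and made-up numbers.\"\"\"\nimport numpy as np\n\ngamma, gae_lambda = 0.99, 0.95\nT = 4\nrewards = np.array([1.0, 1.0, 1.0, 1.0])\nvalue_preds = np.array([0.5, 0.6, 0.7, 0.8, 0.9])  # length T+1; last entry is next_value\nmasks = np.ones(T + 1)\nmasks[-1] = 0.0  # the episode terminates after the last step, so nothing is bootstrapped there\n\nreturns = np.zeros(T)\ngae = 0.0\nfor step in reversed(range(T)):\n    delta = rewards[step] + gamma * value_preds[step + 1] * masks[step + 1] - value_preds[step]\n    gae = delta + gamma * gae_lambda * masks[step + 1] * gae\n    returns[step] = gae + value_preds[step]\n\nprint(\"returns:\", np.round(returns, 3))\n# At the terminal step the target is just the final reward.\nassert np.isclose(returns[-1], 1.0)\n"
  },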
  {
    "path": "utils/util.py",
    "content": "import numpy as np\nimport math\nimport torch\n\ndef check(input):\n    if type(input) == np.ndarray:\n        return torch.from_numpy(input)\n        \ndef get_gard_norm(it):\n    sum_grad = 0\n    for x in it:\n        if x.grad is None:\n            continue\n        sum_grad += x.grad.norm() ** 2\n    return math.sqrt(sum_grad)\n\ndef update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):\n    \"\"\"Decreases the learning rate linearly\"\"\"\n    lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs)))\n    for param_group in optimizer.param_groups:\n        param_group['lr'] = lr\n\ndef huber_loss(e, d):\n    a = (abs(e) <= d).float()\n    b = (e > d).float()\n    return a*e**2/2 + b*d*(abs(e)-d/2)\n\ndef mse_loss(e):\n    return e**2/2\n\ndef get_shape_from_obs_space(obs_space):\n    if obs_space.__class__.__name__ == 'Box':\n        obs_shape = obs_space.shape\n    elif obs_space.__class__.__name__ == 'list':\n        obs_shape = obs_space\n    else:\n        raise NotImplementedError\n    return obs_shape\n\ndef get_shape_from_act_space(act_space):\n    if act_space.__class__.__name__ == 'Discrete':\n        act_shape = 1\n    elif act_space.__class__.__name__ == \"MultiDiscrete\":\n        act_shape = act_space.shape\n    elif act_space.__class__.__name__ == \"Box\":\n        act_shape = act_space.shape[0]\n    elif act_space.__class__.__name__ == \"MultiBinary\":\n        act_shape = act_space.shape[0]\n    else:  # agar\n        act_shape = act_space[0].shape[0] + 1  \n    return act_shape\n\n\ndef tile_images(img_nhwc):\n    \"\"\"\n    Tile N images into one big PxQ image\n    (P,Q) are chosen to be as close as possible, and if N\n    is square, then P=Q.\n    input: img_nhwc, list or array of images, ndim=4 once turned into array\n        n = batch index, h = height, w = width, c = channel\n    returns:\n        bigim_HWc, ndarray with ndim=3\n    \"\"\"\n    img_nhwc = np.asarray(img_nhwc)\n    N, h, w, c = img_nhwc.shape\n    H = int(np.ceil(np.sqrt(N)))\n    W = int(np.ceil(float(N)/H))\n    img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0]*0 for _ in range(N, H*W)])\n    img_HWhwc = img_nhwc.reshape(H, W, h, w, c)\n    img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4)\n    img_Hh_Ww_c = img_HhWwc.reshape(H*h, W*w, c)\n    return img_Hh_Ww_c"
  }
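  ,
  {
    "path": "examples/util_demo.py",
    "content": "\"\"\"Illustrative usage sketch, not part of the original repo, for utils.util:\ncheck() converts numpy arrays to torch tensors and passes tensors through unchanged,\nand huber_loss(e, d) is quadratic for |e| <= d and linear beyond d.\"\"\"\nimport numpy as np\nimport torch\nfrom utils.util import check, huber_loss, mse_loss\n\n# check(): ndarray in, tensor out; tensors pass through unchanged.\narr = np.ones((2, 3), dtype=np.float32)\nt = check(arr)\nassert isinstance(t, torch.Tensor) and t.shape == (2, 3)\nassert check(t) is t\n\n# huber_loss(): matches mse_loss inside the threshold, grows linearly outside it,\n# and is symmetric in the sign of the error.\ne = torch.tensor([0.5, 2.0, -2.0])\nd = 1.0\nloss = huber_loss(e, d)\nassert torch.isclose(loss[0], mse_loss(e)[0])                   # 0.5**2 / 2 = 0.125\nassert torch.isclose(loss[1], torch.tensor(d * (2.0 - d / 2)))  # 1 * (2 - 0.5) = 1.5\nassert torch.isclose(loss[2], loss[1])\nprint(\"huber:\", loss)\n"
  }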
]