Full Code of cyanrain7/TRPO-in-MARL

Repository: cyanrain7/TRPO-in-MARL
Branch: master
Commit: e412f13da689
Files: 56
Total size: 372.2 KB

Directory structure:
gitextract_0j5bd_hz/

├── .gitignore
├── LICENSE
├── README.md
├── algorithms/
│   ├── __init__.py
│   ├── actor_critic.py
│   ├── happo_policy.py
│   ├── happo_trainer.py
│   ├── hatrpo_policy.py
│   ├── hatrpo_trainer.py
│   └── utils/
│       ├── act.py
│       ├── cnn.py
│       ├── distributions.py
│       ├── mlp.py
│       ├── rnn.py
│       └── util.py
├── configs/
│   └── config.py
├── envs/
│   ├── __init__.py
│   ├── env_wrappers.py
│   ├── ma_mujoco/
│   │   ├── __init__.py
│   │   └── multiagent_mujoco/
│   │       ├── __init__.py
│   │       ├── assets/
│   │       │   ├── .gitignore
│   │       │   ├── __init__.py
│   │       │   ├── coupled_half_cheetah.xml
│   │       │   ├── manyagent_ant.xml
│   │       │   ├── manyagent_ant.xml.template
│   │       │   ├── manyagent_ant__stage1.xml
│   │       │   ├── manyagent_swimmer.xml.template
│   │       │   ├── manyagent_swimmer__bckp2.xml
│   │       │   └── manyagent_swimmer_bckp.xml
│   │       ├── coupled_half_cheetah.py
│   │       ├── manyagent_ant.py
│   │       ├── manyagent_swimmer.py
│   │       ├── mujoco_multi.py
│   │       ├── multiagentenv.py
│   │       └── obsk.py
│   └── starcraft2/
│       ├── StarCraft2_Env.py
│       ├── multiagentenv.py
│       └── smac_maps.py
├── install_sc2.sh
├── requirements.txt
├── runners/
│   ├── __init__.py
│   └── separated/
│       ├── __init__.py
│       ├── base_runner.py
│       ├── mujoco_runner.py
│       └── smac_runner.py
├── scripts/
│   ├── __init__.py
│   ├── train/
│   │   ├── __init__.py
│   │   ├── train_mujoco.py
│   │   └── train_smac.py
│   ├── train_mujoco.sh
│   └── train_smac.sh
└── utils/
    ├── __init__.py
    ├── multi_discrete.py
    ├── popart.py
    ├── separated_buffer.py
    └── util.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
*.*~
__pycache__/
*.pkl
data/
**/*.egg-info
.python-version
.idea/
.vscode/
.DS_Store
_build/
results/

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2020 Tianshou contributors

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

================================================
FILE: README.md
================================================
# Trust Region Policy Optimisation in Multi-Agent Reinforcement Learning
Described in the paper "[Trust Region Policy Optimisation in Multi-Agent Reinforcement Learning](https://arxiv.org/pdf/2109.11251.pdf)", this repository implements the *Heterogeneous-Agent Trust Region Policy Optimisation (HATRPO)* and *Heterogeneous-Agent Proximal Policy Optimisation (HAPPO)* algorithms on the SMAC and Multi-Agent MuJoCo benchmarks. *HATRPO* and *HAPPO* are the first trust region methods for multi-agent reinforcement learning **with a theoretically-justified monotonic improvement guarantee**. In terms of performance, they set a new state of the art against rivals such as [IPPO](https://arxiv.org/abs/2011.09533), [MAPPO](https://arxiv.org/abs/2103.01955) and [MADDPG](https://arxiv.org/abs/1706.02275). HAPPO and HATRPO have since been integrated into the HARL framework; please check the latest changes [here](https://github.com/PKU-MARL/HARL).

## Installation
### Create environment
``` Bash
conda create -n env_name python=3.9
conda activate env_name
pip install -r requirements.txt
conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch -c nvidia
```

### Multi-agent MuJoCo
Follow the instructions at https://github.com/openai/mujoco-py and https://github.com/schroederdewitt/multiagent_mujoco to set up a MuJoCo environment. Finally, remember to set the following environment variables:
``` Bash
LD_LIBRARY_PATH=${HOME}/.mujoco/mujoco200/bin;
LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libGLEW.so
```
### StarCraft II & SMAC
Run the script
``` Bash
bash install_sc2.sh
```
Alternatively, you can install them manually to any path you like by following https://github.com/oxwhirl/smac.

## How to run
When your environment is ready, you can run the provided shell scripts. For example:
``` Bash
cd scripts
./train_mujoco.sh  # run with HAPPO/HATRPO on Multi-agent MuJoCo
./train_smac.sh  # run with HAPPO/HATRPO on StarCraft II
```

If you would like to change the experiment configuration, you can modify the shell scripts or look into the config files for more details. You can also switch the algorithm by changing **algo=happo** to **algo=hatrpo**, as in the sketch below.
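The training scripts are thin wrappers that set a few shell variables and call the Python entry points under `scripts/train/`. A minimal, hypothetical sketch is shown below; apart from `algo=happo`/`algo=hatrpo` (quoted above) and `algorithm_name` (used throughout the code), the flag names are assumptions, so check `scripts/train_mujoco.sh` and `configs/config.py` for the authoritative ones.
``` Bash
# Hypothetical excerpt of a launch script -- not verbatim from scripts/train_mujoco.sh.
algo=hatrpo   # switch between happo and hatrpo here

python train/train_mujoco.py --algorithm_name "${algo}"
```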



## Some experiment results

### SMAC 

<img src="plots/smac.png" width="500" >


### Multi-agent MuJoCo

<img src="plots/ma-mujoco_1.png" width="500" > 

<img src="plots/ma-mujoco_2.png" width="500" >

## Additional Experiment Settings
### For SMAC
#### 2022/4/24 update: important gamma error for SMAC
##### The **gamma** parameter was set incorrectly; the correct configuration of **gamma** is as follows:
##### gamma for **3s5z** and **2c_vs_64zg**  is 0.95
##### gamma for **corridor** is 0.99
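If the discount factor is exposed as a command-line flag (an assumption; MAPPO-style configs typically define `--gamma` in `configs/config.py`), the fix can be applied directly when launching training:
``` Bash
# Hypothetical invocations -- the --map_name and --gamma flag names are assumptions; verify them in configs/config.py.
python train/train_smac.py --map_name 3s5z      --gamma 0.95   # same gamma for 2c_vs_64zg
python train/train_smac.py --map_name corridor  --gamma 0.99
```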



================================================
FILE: algorithms/__init__.py
================================================


================================================
FILE: algorithms/actor_critic.py
================================================
import torch
import torch.nn as nn
from algorithms.utils.util import init, check
from algorithms.utils.cnn import CNNBase
from algorithms.utils.mlp import MLPBase
from algorithms.utils.rnn import RNNLayer
from algorithms.utils.act import ACTLayer
from utils.util import get_shape_from_obs_space


class Actor(nn.Module):
    """
    Actor network class for HAPPO. Outputs actions given observations.
    :param args: (argparse.Namespace) arguments containing relevant model information.
    :param obs_space: (gym.Space) observation space.
    :param action_space: (gym.Space) action space.
    :param device: (torch.device) specifies the device to run on (cpu/gpu).
    """
    def __init__(self, args, obs_space, action_space, device=torch.device("cpu")):
        super(Actor, self).__init__()
        self.hidden_size = args.hidden_size
        self.args=args
        self._gain = args.gain
        self._use_orthogonal = args.use_orthogonal
        self._use_policy_active_masks = args.use_policy_active_masks
        self._use_naive_recurrent_policy = args.use_naive_recurrent_policy
        self._use_recurrent_policy = args.use_recurrent_policy
        self._recurrent_N = args.recurrent_N
        self.tpdv = dict(dtype=torch.float32, device=device)

        obs_shape = get_shape_from_obs_space(obs_space)
        base = CNNBase if len(obs_shape) == 3 else MLPBase
        self.base = base(args, obs_shape)

        if self._use_naive_recurrent_policy or self._use_recurrent_policy:
            self.rnn = RNNLayer(self.hidden_size, self.hidden_size, self._recurrent_N, self._use_orthogonal)

        self.act = ACTLayer(action_space, self.hidden_size, self._use_orthogonal, self._gain, args)

        self.to(device)

    def forward(self, obs, rnn_states, masks, available_actions=None, deterministic=False):
        """
        Compute actions from the given inputs.
        :param obs: (np.ndarray / torch.Tensor) observation inputs into network.
        :param rnn_states: (np.ndarray / torch.Tensor) if RNN network, hidden states for RNN.
        :param masks: (np.ndarray / torch.Tensor) mask tensor denoting if hidden states should be reinitialized to zeros.
        :param available_actions: (np.ndarray / torch.Tensor) denotes which actions are available to agent
                                                              (if None, all actions available)
        :param deterministic: (bool) whether to sample from action distribution or return the mode.

        :return actions: (torch.Tensor) actions to take.
        :return action_log_probs: (torch.Tensor) log probabilities of taken actions.
        :return rnn_states: (torch.Tensor) updated RNN hidden states.
        """
        obs = check(obs).to(**self.tpdv)
        rnn_states = check(rnn_states).to(**self.tpdv)
        masks = check(masks).to(**self.tpdv)
        if available_actions is not None:
            available_actions = check(available_actions).to(**self.tpdv)

        actor_features = self.base(obs)

        if self._use_naive_recurrent_policy or self._use_recurrent_policy:
            actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks)

        actions, action_log_probs = self.act(actor_features, available_actions, deterministic)

        return actions, action_log_probs, rnn_states

    def evaluate_actions(self, obs, rnn_states, action, masks, available_actions=None, active_masks=None):
        """
        Compute log probability and entropy of given actions.
        :param obs: (torch.Tensor) observation inputs into network.
        :param action: (torch.Tensor) actions whose entropy and log probability to evaluate.
        :param rnn_states: (torch.Tensor) if RNN network, hidden states for RNN.
        :param masks: (torch.Tensor) mask tensor denoting if hidden states should be reinitialized to zeros.
        :param available_actions: (torch.Tensor) denotes which actions are available to agent
                                                              (if None, all actions available)
        :param active_masks: (torch.Tensor) denotes whether an agent is active or dead.

        :return action_log_probs: (torch.Tensor) log probabilities of the input actions.
        :return dist_entropy: (torch.Tensor) action distribution entropy for the given inputs.
        """
        obs = check(obs).to(**self.tpdv)
        rnn_states = check(rnn_states).to(**self.tpdv)
        action = check(action).to(**self.tpdv)
        masks = check(masks).to(**self.tpdv)
        if available_actions is not None:
            available_actions = check(available_actions).to(**self.tpdv)

        if active_masks is not None:
            active_masks = check(active_masks).to(**self.tpdv)

        actor_features = self.base(obs)

        if self._use_naive_recurrent_policy or self._use_recurrent_policy:
            actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks)

        if self.args.algorithm_name=="hatrpo":
            action_log_probs, dist_entropy ,action_mu, action_std, all_probs= self.act.evaluate_actions_trpo(actor_features,
                                                                    action, available_actions,
                                                                    active_masks=
                                                                    active_masks if self._use_policy_active_masks
                                                                    else None)

            return action_log_probs, dist_entropy, action_mu, action_std, all_probs
        else:
            action_log_probs, dist_entropy = self.act.evaluate_actions(actor_features,
                                                                    action, available_actions,
                                                                    active_masks=
                                                                    active_masks if self._use_policy_active_masks
                                                                    else None)

            return action_log_probs, dist_entropy


class Critic(nn.Module):
    """
    Critic network class for HAPPO. Outputs value function predictions given centralized input (HAPPO) or local observations (IPPO).
    :param args: (argparse.Namespace) arguments containing relevant model information.
    :param cent_obs_space: (gym.Space) (centralized) observation space.
    :param device: (torch.device) specifies the device to run on (cpu/gpu).
    """
    def __init__(self, args, cent_obs_space, device=torch.device("cpu")):
        super(Critic, self).__init__()
        self.hidden_size = args.hidden_size
        self._use_orthogonal = args.use_orthogonal
        self._use_naive_recurrent_policy = args.use_naive_recurrent_policy
        self._use_recurrent_policy = args.use_recurrent_policy
        self._recurrent_N = args.recurrent_N
        self.tpdv = dict(dtype=torch.float32, device=device)
        init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][self._use_orthogonal]

        cent_obs_shape = get_shape_from_obs_space(cent_obs_space)
        base = CNNBase if len(cent_obs_shape) == 3 else MLPBase
        self.base = base(args, cent_obs_shape)

        if self._use_naive_recurrent_policy or self._use_recurrent_policy:
            self.rnn = RNNLayer(self.hidden_size, self.hidden_size, self._recurrent_N, self._use_orthogonal)

        def init_(m):
            return init(m, init_method, lambda x: nn.init.constant_(x, 0))

        self.v_out = init_(nn.Linear(self.hidden_size, 1))

        self.to(device)

    def forward(self, cent_obs, rnn_states, masks):
        """
        Compute value function predictions from the given inputs.
        :param cent_obs: (np.ndarray / torch.Tensor) observation inputs into network.
        :param rnn_states: (np.ndarray / torch.Tensor) if RNN network, hidden states for RNN.
        :param masks: (np.ndarray / torch.Tensor) mask tensor denoting if RNN states should be reinitialized to zeros.

        :return values: (torch.Tensor) value function predictions.
        :return rnn_states: (torch.Tensor) updated RNN hidden states.
        """
        cent_obs = check(cent_obs).to(**self.tpdv)
        rnn_states = check(rnn_states).to(**self.tpdv)
        masks = check(masks).to(**self.tpdv)

        critic_features = self.base(cent_obs)
        if self._use_naive_recurrent_policy or self._use_recurrent_policy:
            critic_features, rnn_states = self.rnn(critic_features, rnn_states, masks)
        values = self.v_out(critic_features)

        return values, rnn_states


================================================
FILE: algorithms/happo_policy.py
================================================
import torch
from algorithms.actor_critic import Actor, Critic
from utils.util import update_linear_schedule


class HAPPO_Policy:
    """
    HAPPO Policy  class. Wraps actor and critic networks to compute actions and value function predictions.

    :param args: (argparse.Namespace) arguments containing relevant model and policy information.
    :param obs_space: (gym.Space) observation space.
    :param cent_obs_space: (gym.Space) value function input space (centralized input for HAPPO, decentralized for IPPO).
    :param action_space: (gym.Space) action space.
    :param device: (torch.device) specifies the device to run on (cpu/gpu).
    """

    def __init__(self, args, obs_space, cent_obs_space, act_space, device=torch.device("cpu")):
        self.args=args
        self.device = device
        self.lr = args.lr
        self.critic_lr = args.critic_lr
        self.opti_eps = args.opti_eps
        self.weight_decay = args.weight_decay

        self.obs_space = obs_space
        self.share_obs_space = cent_obs_space
        self.act_space = act_space

        self.actor = Actor(args, self.obs_space, self.act_space, self.device)

        ######################################Please Note#########################################
        #####   We create one critic for each agent, but they are trained with same data     #####
        #####   and using same update setting. Therefore they have the same parameter,       #####
        #####   you can regard them as the same critic.                                      #####
        ##########################################################################################
        self.critic = Critic(args, self.share_obs_space, self.device)

        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=self.lr, eps=self.opti_eps,
                                                weight_decay=self.weight_decay)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=self.critic_lr,
                                                 eps=self.opti_eps,
                                                 weight_decay=self.weight_decay)

    def lr_decay(self, episode, episodes):
        """
        Decay the actor and critic learning rates.
        :param episode: (int) current training episode.
        :param episodes: (int) total number of training episodes.
        """
        update_linear_schedule(self.actor_optimizer, episode, episodes, self.lr)
        update_linear_schedule(self.critic_optimizer, episode, episodes, self.critic_lr)

    def get_actions(self, cent_obs, obs, rnn_states_actor, rnn_states_critic, masks, available_actions=None,
                    deterministic=False):
        """
        Compute actions and value function predictions for the given inputs.
        :param cent_obs (np.ndarray): centralized input to the critic.
        :param obs (np.ndarray): local agent inputs to the actor.
        :param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor.
        :param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic.
        :param masks: (np.ndarray) denotes points at which RNN states should be reset.
        :param available_actions: (np.ndarray) denotes which actions are available to agent
                                  (if None, all actions available)
        :param deterministic: (bool) whether the action should be mode of distribution or should be sampled.

        :return values: (torch.Tensor) value function predictions.
        :return actions: (torch.Tensor) actions to take.
        :return action_log_probs: (torch.Tensor) log probabilities of chosen actions.
        :return rnn_states_actor: (torch.Tensor) updated actor network RNN states.
        :return rnn_states_critic: (torch.Tensor) updated critic network RNN states.
        """
        actions, action_log_probs, rnn_states_actor = self.actor(obs,
                                                                 rnn_states_actor,
                                                                 masks,
                                                                 available_actions,
                                                                 deterministic)

        values, rnn_states_critic = self.critic(cent_obs, rnn_states_critic, masks)
        return values, actions, action_log_probs, rnn_states_actor, rnn_states_critic

    def get_values(self, cent_obs, rnn_states_critic, masks):
        """
        Get value function predictions.
        :param cent_obs (np.ndarray): centralized input to the critic.
        :param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic.
        :param masks: (np.ndarray) denotes points at which RNN states should be reset.

        :return values: (torch.Tensor) value function predictions.
        """
        values, _ = self.critic(cent_obs, rnn_states_critic, masks)
        return values

    def evaluate_actions(self, cent_obs, obs, rnn_states_actor, rnn_states_critic, action, masks,
                         available_actions=None, active_masks=None):
        """
        Get action logprobs / entropy and value function predictions for actor update.
        :param cent_obs (np.ndarray): centralized input to the critic.
        :param obs (np.ndarray): local agent inputs to the actor.
        :param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor.
        :param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic.
        :param action: (np.ndarray) actions whose log probabilities and entropy to compute.
        :param masks: (np.ndarray) denotes points at which RNN states should be reset.
        :param available_actions: (np.ndarray) denotes which actions are available to agent
                                  (if None, all actions available)
        :param active_masks: (torch.Tensor) denotes whether an agent is active or dead.

        :return values: (torch.Tensor) value function predictions.
        :return action_log_probs: (torch.Tensor) log probabilities of the input actions.
        :return dist_entropy: (torch.Tensor) action distribution entropy for the given inputs.
        """

        action_log_probs, dist_entropy = self.actor.evaluate_actions(obs,
                                                                rnn_states_actor,
                                                                action,
                                                                masks,
                                                                available_actions,
                                                                active_masks)

        values, _ = self.critic(cent_obs, rnn_states_critic, masks)
        return values, action_log_probs, dist_entropy


    def act(self, obs, rnn_states_actor, masks, available_actions=None, deterministic=False):
        """
        Compute actions using the given inputs.
        :param obs (np.ndarray): local agent inputs to the actor.
        :param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor.
        :param masks: (np.ndarray) denotes points at which RNN states should be reset.
        :param available_actions: (np.ndarray) denotes which actions are available to agent
                                  (if None, all actions available)
        :param deterministic: (bool) whether the action should be mode of distribution or should be sampled.
        """
        actions, _, rnn_states_actor = self.actor(obs, rnn_states_actor, masks, available_actions, deterministic)
        return actions, rnn_states_actor


================================================
FILE: algorithms/happo_trainer.py
================================================
import numpy as np
import torch
import torch.nn as nn
from utils.util import get_gard_norm, huber_loss, mse_loss
from utils.popart import PopArt
from algorithms.utils.util import check

class HAPPO():
    """
    Trainer class for HAPPO to update policies.
    :param args: (argparse.Namespace) arguments containing relevant model, policy, and env information.
    :param policy: (HAPPO_Policy) policy to update.
    :param device: (torch.device) specifies the device to run on (cpu/gpu).
    """
    def __init__(self,
                 args,
                 policy,
                 device=torch.device("cpu")):

        self.device = device
        self.tpdv = dict(dtype=torch.float32, device=device)
        self.policy = policy

        self.clip_param = args.clip_param
        self.ppo_epoch = args.ppo_epoch
        self.num_mini_batch = args.num_mini_batch
        self.data_chunk_length = args.data_chunk_length
        self.value_loss_coef = args.value_loss_coef
        self.entropy_coef = args.entropy_coef
        self.max_grad_norm = args.max_grad_norm       
        self.huber_delta = args.huber_delta

        self._use_recurrent_policy = args.use_recurrent_policy
        self._use_naive_recurrent = args.use_naive_recurrent_policy
        self._use_max_grad_norm = args.use_max_grad_norm
        self._use_clipped_value_loss = args.use_clipped_value_loss
        self._use_huber_loss = args.use_huber_loss
        self._use_popart = args.use_popart
        self._use_value_active_masks = args.use_value_active_masks
        self._use_policy_active_masks = args.use_policy_active_masks

        
        if self._use_popart:
            self.value_normalizer = PopArt(1, device=self.device)
        else:
            self.value_normalizer = None

    def cal_value_loss(self, values, value_preds_batch, return_batch, active_masks_batch):
        """
        Calculate value function loss.
        :param values: (torch.Tensor) value function predictions.
        :param value_preds_batch: (torch.Tensor) "old" value  predictions from data batch (used for value clip loss)
        :param return_batch: (torch.Tensor) reward to go returns.
        :param active_masks_batch: (torch.Tensor) denotes if agent is active or dead at a given timestep.

        :return value_loss: (torch.Tensor) value function loss.
        """
        if self._use_popart:
            value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(-self.clip_param,
                                                                                        self.clip_param)
            error_clipped = self.value_normalizer(return_batch) - value_pred_clipped
            error_original = self.value_normalizer(return_batch) - values
        else:
            value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(-self.clip_param,
                                                                                        self.clip_param)
            error_clipped = return_batch - value_pred_clipped
            error_original = return_batch - values

        if self._use_huber_loss:
            value_loss_clipped = huber_loss(error_clipped, self.huber_delta)
            value_loss_original = huber_loss(error_original, self.huber_delta)
        else:
            value_loss_clipped = mse_loss(error_clipped)
            value_loss_original = mse_loss(error_original)

        if self._use_clipped_value_loss:
            value_loss = torch.max(value_loss_original, value_loss_clipped)
        else:
            value_loss = value_loss_original

        if self._use_value_active_masks:
            value_loss = (value_loss * active_masks_batch).sum() / active_masks_batch.sum()
        else:
            value_loss = value_loss.mean()

        return value_loss

    def ppo_update(self, sample, update_actor=True):
        """
        Update actor and critic networks.
        :param sample: (Tuple) contains data batch with which to update networks.
        :param update_actor: (bool) whether to update actor network.

        :return value_loss: (torch.Tensor) value function loss.
        :return critic_grad_norm: (torch.Tensor) gradient norm from critic update.
        :return policy_loss: (torch.Tensor) actor(policy) loss value.
        :return dist_entropy: (torch.Tensor) action entropies.
        :return actor_grad_norm: (torch.Tensor) gradient norm from actor update.
        :return imp_weights: (torch.Tensor) importance sampling weights.
        """
        share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, \
        value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, \
        adv_targ, available_actions_batch, factor_batch = sample



        old_action_log_probs_batch = check(old_action_log_probs_batch).to(**self.tpdv)
        adv_targ = check(adv_targ).to(**self.tpdv)


        value_preds_batch = check(value_preds_batch).to(**self.tpdv)
        return_batch = check(return_batch).to(**self.tpdv)


        active_masks_batch = check(active_masks_batch).to(**self.tpdv)

        factor_batch = check(factor_batch).to(**self.tpdv)
        # Reshape to do in a single forward pass for all steps
        values, action_log_probs, dist_entropy = self.policy.evaluate_actions(share_obs_batch,
                                                                              obs_batch, 
                                                                              rnn_states_batch, 
                                                                              rnn_states_critic_batch, 
                                                                              actions_batch, 
                                                                              masks_batch, 
                                                                              available_actions_batch,
                                                                              active_masks_batch)
        # actor update
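        # HAPPO sequential update: imp_weights is this agent's importance ratio pi_new / pi_old,
        # taken as a product over action dimensions, while factor_batch (per the HATRPO/HAPPO paper,
        # the compounded ratios of agents already updated this iteration) rescales the clipped
        # surrogate computed below.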
        imp_weights = torch.prod(torch.exp(action_log_probs - old_action_log_probs_batch),dim=-1,keepdim=True)

        surr1 = imp_weights * adv_targ
        surr2 = torch.clamp(imp_weights, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ

        if self._use_policy_active_masks:
            policy_action_loss = (-torch.sum(factor_batch * torch.min(surr1, surr2),
                                             dim=-1,
                                             keepdim=True) * active_masks_batch).sum() / active_masks_batch.sum()
        else:
            policy_action_loss = -torch.sum(factor_batch * torch.min(surr1, surr2), dim=-1, keepdim=True).mean()

        policy_loss = policy_action_loss

        self.policy.actor_optimizer.zero_grad()

        if update_actor:
            (policy_loss - dist_entropy * self.entropy_coef).backward()

        if self._use_max_grad_norm:
            actor_grad_norm = nn.utils.clip_grad_norm_(self.policy.actor.parameters(), self.max_grad_norm)
        else:
            actor_grad_norm = get_gard_norm(self.policy.actor.parameters())

        self.policy.actor_optimizer.step()

        value_loss = self.cal_value_loss(values, value_preds_batch, return_batch, active_masks_batch)

        self.policy.critic_optimizer.zero_grad()

        (value_loss * self.value_loss_coef).backward()

        if self._use_max_grad_norm:
            critic_grad_norm = nn.utils.clip_grad_norm_(self.policy.critic.parameters(), self.max_grad_norm)
        else:
            critic_grad_norm = get_gard_norm(self.policy.critic.parameters())

        self.policy.critic_optimizer.step()

        return value_loss, critic_grad_norm, policy_loss, dist_entropy, actor_grad_norm, imp_weights

    def train(self, buffer, update_actor=True):
        """
        Perform a training update using minibatch GD.
        :param buffer: (SharedReplayBuffer) buffer containing training data.
        :param update_actor: (bool) whether to update actor network.

        :return train_info: (dict) contains information regarding training update (e.g. loss, grad norms, etc).
        """
        if self._use_popart:
            advantages = buffer.returns[:-1] - self.value_normalizer.denormalize(buffer.value_preds[:-1])
        else:
            advantages = buffer.returns[:-1] - buffer.value_preds[:-1]

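        # Normalize advantages using statistics from active steps only: timesteps where the agent
        # is dead are set to NaN and excluded via nanmean / nanstd.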
        advantages_copy = advantages.copy()
        advantages_copy[buffer.active_masks[:-1] == 0.0] = np.nan
        mean_advantages = np.nanmean(advantages_copy)
        std_advantages = np.nanstd(advantages_copy)
        advantages = (advantages - mean_advantages) / (std_advantages + 1e-5)

        train_info = {}

        train_info['value_loss'] = 0
        train_info['policy_loss'] = 0
        train_info['dist_entropy'] = 0
        train_info['actor_grad_norm'] = 0
        train_info['critic_grad_norm'] = 0
        train_info['ratio'] = 0

        for _ in range(self.ppo_epoch):
            if self._use_recurrent_policy:
                data_generator = buffer.recurrent_generator(advantages, self.num_mini_batch, self.data_chunk_length)
            elif self._use_naive_recurrent:
                data_generator = buffer.naive_recurrent_generator(advantages, self.num_mini_batch)
            else:
                data_generator = buffer.feed_forward_generator(advantages, self.num_mini_batch)

            for sample in data_generator:
                value_loss, critic_grad_norm, policy_loss, dist_entropy, actor_grad_norm, imp_weights = self.ppo_update(sample, update_actor=update_actor)

                train_info['value_loss'] += value_loss.item()
                train_info['policy_loss'] += policy_loss.item()
                train_info['dist_entropy'] += dist_entropy.item()
                train_info['actor_grad_norm'] += actor_grad_norm
                train_info['critic_grad_norm'] += critic_grad_norm
                train_info['ratio'] += imp_weights.mean()

        num_updates = self.ppo_epoch * self.num_mini_batch

        for k in train_info.keys():
            train_info[k] /= num_updates
 
        return train_info

    def prep_training(self):
        self.policy.actor.train()
        self.policy.critic.train()

    def prep_rollout(self):
        self.policy.actor.eval()
        self.policy.critic.eval()


================================================
FILE: algorithms/hatrpo_policy.py
================================================
import torch
from algorithms.actor_critic import Actor, Critic
from utils.util import update_linear_schedule


class HATRPO_Policy:
    """
    HATRPO Policy  class. Wraps actor and critic networks to compute actions and value function predictions.

    :param args: (argparse.Namespace) arguments containing relevant model and policy information.
    :param obs_space: (gym.Space) observation space.
    :param cent_obs_space: (gym.Space) value function input space .
    :param action_space: (gym.Space) action space.
    :param device: (torch.device) specifies the device to run on (cpu/gpu).
    """

    def __init__(self, args, obs_space, cent_obs_space, act_space, device=torch.device("cpu")):
        self.args=args
        self.device = device
        self.lr = args.lr
        self.critic_lr = args.critic_lr
        self.opti_eps = args.opti_eps
        self.weight_decay = args.weight_decay

        self.obs_space = obs_space
        self.share_obs_space = cent_obs_space
        self.act_space = act_space

        self.actor = Actor(args, self.obs_space, self.act_space, self.device)

        ######################################Please Note#########################################
        #####   We create one critic for each agent, but they are trained with same data     #####
        #####   and using same update setting. Therefore they have the same parameter,       #####
        #####   you can regard them as the same critic.                                      #####
        ##########################################################################################
        self.critic = Critic(args, self.share_obs_space, self.device)

        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
                                                lr=self.lr, eps=self.opti_eps,
                                                weight_decay=self.weight_decay)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
                                                 lr=self.critic_lr,
                                                 eps=self.opti_eps,
                                                 weight_decay=self.weight_decay)

    def lr_decay(self, episode, episodes):
        """
        Decay the actor and critic learning rates.
        :param episode: (int) current training episode.
        :param episodes: (int) total number of training episodes.
        """
        update_linear_schedule(self.actor_optimizer, episode, episodes, self.lr)
        update_linear_schedule(self.critic_optimizer, episode, episodes, self.critic_lr)

    def get_actions(self, cent_obs, obs, rnn_states_actor, rnn_states_critic, masks, available_actions=None,
                    deterministic=False):
        """
        Compute actions and value function predictions for the given inputs.
        :param cent_obs (np.ndarray): centralized input to the critic.
        :param obs (np.ndarray): local agent inputs to the actor.
        :param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor.
        :param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic.
        :param masks: (np.ndarray) denotes points at which RNN states should be reset.
        :param available_actions: (np.ndarray) denotes which actions are available to agent
                                  (if None, all actions available)
        :param deterministic: (bool) whether the action should be mode of distribution or should be sampled.

        :return values: (torch.Tensor) value function predictions.
        :return actions: (torch.Tensor) actions to take.
        :return action_log_probs: (torch.Tensor) log probabilities of chosen actions.
        :return rnn_states_actor: (torch.Tensor) updated actor network RNN states.
        :return rnn_states_critic: (torch.Tensor) updated critic network RNN states.
        """
        actions, action_log_probs, rnn_states_actor = self.actor(obs,
                                                                 rnn_states_actor,
                                                                 masks,
                                                                 available_actions,
                                                                 deterministic)

        values, rnn_states_critic = self.critic(cent_obs, rnn_states_critic, masks)
        return values, actions, action_log_probs, rnn_states_actor, rnn_states_critic

    def get_values(self, cent_obs, rnn_states_critic, masks):
        """
        Get value function predictions.
        :param cent_obs (np.ndarray): centralized input to the critic.
        :param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic.
        :param masks: (np.ndarray) denotes points at which RNN states should be reset.

        :return values: (torch.Tensor) value function predictions.
        """
        values, _ = self.critic(cent_obs, rnn_states_critic, masks)
        return values

    def evaluate_actions(self, cent_obs, obs, rnn_states_actor, rnn_states_critic, action, masks,
                         available_actions=None, active_masks=None):
        """
        Get action logprobs / entropy and value function predictions for actor update.
        :param cent_obs (np.ndarray): centralized input to the critic.
        :param obs (np.ndarray): local agent inputs to the actor.
        :param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor.
        :param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic.
        :param action: (np.ndarray) actions whose log probabilities and entropy to compute.
        :param masks: (np.ndarray) denotes points at which RNN states should be reset.
        :param available_actions: (np.ndarray) denotes which actions are available to agent
                                  (if None, all actions available)
        :param active_masks: (torch.Tensor) denotes whether an agent is active or dead.

        :return values: (torch.Tensor) value function predictions.
        :return action_log_probs: (torch.Tensor) log probabilities of the input actions.
        :return dist_entropy: (torch.Tensor) action distribution entropy for the given inputs.
        """

        action_log_probs, dist_entropy , action_mu, action_std, all_probs= self.actor.evaluate_actions(obs,
                                                                    rnn_states_actor,
                                                                    action,
                                                                    masks,
                                                                    available_actions,
                                                                    active_masks)
        values, _ = self.critic(cent_obs, rnn_states_critic, masks)
        return values, action_log_probs, dist_entropy, action_mu, action_std, all_probs



    def act(self, obs, rnn_states_actor, masks, available_actions=None, deterministic=False):
        """
        Compute actions using the given inputs.
        :param obs (np.ndarray): local agent inputs to the actor.
        :param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor.
        :param masks: (np.ndarray) denotes points at which RNN states should be reset.
        :param available_actions: (np.ndarray) denotes which actions are available to agent
                                  (if None, all actions available)
        :param deterministic: (bool) whether the action should be mode of distribution or should be sampled.
        """
        actions, _, rnn_states_actor = self.actor(obs, rnn_states_actor, masks, available_actions, deterministic)
        return actions, rnn_states_actor


================================================
FILE: algorithms/hatrpo_trainer.py
================================================
import numpy as np
import torch
import torch.nn as nn
from utils.util import get_gard_norm, huber_loss, mse_loss
from utils.popart import PopArt
from algorithms.utils.util import check
from algorithms.actor_critic import Actor

class HATRPO():
    """
    Trainer class for HATRPO to update policies.
    :param args: (argparse.Namespace) arguments containing relevant model, policy, and env information.
    :param policy: (HATRPO_Policy) policy to update.
    :param device: (torch.device) specifies the device to run on (cpu/gpu).
    """
    def __init__(self,
                 args,
                 policy,
                 device=torch.device("cpu")):

        self.device = device
        self.tpdv = dict(dtype=torch.float32, device=device)
        self.policy = policy

        self.clip_param = args.clip_param
        self.num_mini_batch = args.num_mini_batch
        self.data_chunk_length = args.data_chunk_length
        self.value_loss_coef = args.value_loss_coef
        self.entropy_coef = args.entropy_coef
        self.max_grad_norm = args.max_grad_norm       
        self.huber_delta = args.huber_delta

        self.kl_threshold = args.kl_threshold
        self.ls_step = args.ls_step
        self.accept_ratio = args.accept_ratio

        self._use_recurrent_policy = args.use_recurrent_policy
        self._use_naive_recurrent = args.use_naive_recurrent_policy
        self._use_max_grad_norm = args.use_max_grad_norm
        self._use_clipped_value_loss = args.use_clipped_value_loss
        self._use_huber_loss = args.use_huber_loss
        self._use_popart = args.use_popart
        self._use_value_active_masks = args.use_value_active_masks
        self._use_policy_active_masks = args.use_policy_active_masks
        
        if self._use_popart:
            self.value_normalizer = PopArt(1, device=self.device)
        else:
            self.value_normalizer = None

    def cal_value_loss(self, values, value_preds_batch, return_batch, active_masks_batch):
        """
        Calculate value function loss.
        :param values: (torch.Tensor) value function predictions.
        :param value_preds_batch: (torch.Tensor) "old" value  predictions from data batch (used for value clip loss)
        :param return_batch: (torch.Tensor) reward to go returns.
        :param active_masks_batch: (torch.Tensor) denotes if agent is active or dead at a given timestep.

        :return value_loss: (torch.Tensor) value function loss.
        """
        if self._use_popart:
            value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(-self.clip_param,
                                                                                        self.clip_param)
            error_clipped = self.value_normalizer(return_batch) - value_pred_clipped
            error_original = self.value_normalizer(return_batch) - values
        else:
            value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(-self.clip_param,
                                                                                        self.clip_param)
            error_clipped = return_batch - value_pred_clipped
            error_original = return_batch - values

        if self._use_huber_loss:
            value_loss_clipped = huber_loss(error_clipped, self.huber_delta)
            value_loss_original = huber_loss(error_original, self.huber_delta)
        else:
            value_loss_clipped = mse_loss(error_clipped)
            value_loss_original = mse_loss(error_original)

        if self._use_clipped_value_loss:
            value_loss = torch.max(value_loss_original, value_loss_clipped)
        else:
            value_loss = value_loss_original

        if self._use_value_active_masks:
            value_loss = (value_loss * active_masks_batch).sum() / active_masks_batch.sum()
        else:
            value_loss = value_loss.mean()

        return value_loss

    def flat_grad(self, grads):
        grad_flatten = []
        for grad in grads:
            if grad is None:
                continue
            grad_flatten.append(grad.view(-1))
        grad_flatten = torch.cat(grad_flatten)
        return grad_flatten

    def flat_hessian(self, hessians):
        hessians_flatten = []
        for hessian in hessians:
            if hessian is None:
                continue
            hessians_flatten.append(hessian.contiguous().view(-1))
        hessians_flatten = torch.cat(hessians_flatten).data
        return hessians_flatten

    def flat_params(self, model):
        params = []
        for param in model.parameters():
            params.append(param.data.view(-1))
        params_flatten = torch.cat(params)
        return params_flatten

    def update_model(self, model, new_params):
        index = 0
        for params in model.parameters():
            params_length = len(params.view(-1))
            new_param = new_params[index: index + params_length]
            new_param = new_param.view(params.size())
            params.data.copy_(new_param)
            index += params_length

    def kl_approx(self, q, p):
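        # Elementwise KL approximation: with r = exp(p - q) this equals r - 1 - log(r), which is
        # non-negative and zero iff p == q; kl_divergence falls back to it for discrete policies,
        # where the closed-form Gaussian KL does not apply.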
        r = torch.exp(p - q)
        kl = r - 1 - p + q
        return kl

    def kl_divergence(self, obs, rnn_states, action, masks, available_actions, active_masks, new_actor, old_actor):
        _, _, mu, std, probs = new_actor.evaluate_actions(obs, rnn_states, action, masks, available_actions, active_masks)
        _, _, mu_old, std_old, probs_old = old_actor.evaluate_actions(obs, rnn_states, action, masks, available_actions, active_masks)
        if mu.grad_fn==None:
            probs_old=probs_old.detach()
            kl= self.kl_approx(probs_old,probs)
        else:
            logstd = torch.log(std)
            mu_old = mu_old.detach()
            std_old = std_old.detach()
            logstd_old = torch.log(std_old)
            # kl divergence between old policy and new policy : D( pi_old || pi_new )
            # pi_old -> mu0, logstd0, std0 / pi_new -> mu, logstd, std
            # be careful of calculating KL-divergence. It is not symmetric metric
            kl =  logstd - logstd_old  + (std_old.pow(2) + (mu_old - mu).pow(2)) / (2.0 * std.pow(2)) - 0.5
        
        if len(kl.shape)>1:
            kl=kl.sum(1, keepdim=True)
        return kl

    # from openai baseline code
    # https://github.com/openai/baselines/blob/master/baselines/common/cg.py
    def conjugate_gradient(self, actor, obs, rnn_states, action, masks, available_actions, active_masks, b, nsteps, residual_tol=1e-10):
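        # Iteratively solve H x = b, where the (damped) Fisher matrix H is accessed only through
        # fisher_vector_product, so H is never formed or inverted explicitly.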
        x = torch.zeros(b.size()).to(device=self.device)
        r = b.clone()
        p = b.clone()
        rdotr = torch.dot(r, r)
        for i in range(nsteps):
            _Avp = self.fisher_vector_product(actor, obs, rnn_states, action, masks, available_actions, active_masks, p)
            alpha = rdotr / torch.dot(p, _Avp)
            x += alpha * p
            r -= alpha * _Avp
            new_rdotr = torch.dot(r, r)
            betta = new_rdotr / rdotr
            p = r + betta * p
            rdotr = new_rdotr
            if rdotr < residual_tol:
                break
        return x

    def fisher_vector_product(self, actor, obs, rnn_states, action, masks, available_actions, active_masks, p):
        p.detach()
        kl = self.kl_divergence(obs, rnn_states, action, masks, available_actions, active_masks, new_actor=actor, old_actor=actor)
        kl = kl.mean()
        kl_grad = torch.autograd.grad(kl, actor.parameters(), create_graph=True, allow_unused=True)
        kl_grad = self.flat_grad(kl_grad)  # check kl_grad == 0
        kl_grad_p = (kl_grad * p).sum()
        kl_hessian_p = torch.autograd.grad(kl_grad_p, actor.parameters(), allow_unused=True)
        kl_hessian_p = self.flat_hessian(kl_hessian_p)
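        # Return the KL-Hessian (Fisher) vector product; the 0.1 * p term adds damping to keep the
        # conjugate-gradient solve well conditioned.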
        return kl_hessian_p + 0.1 * p

    def trpo_update(self, sample, update_actor=True):
        """
        Update actor and critic networks.
        :param sample: (Tuple) contains data batch with which to update networks.
        :param update_actor: (bool) whether to update actor network.

        :return value_loss: (torch.Tensor) value function loss.
        :return critic_grad_norm: (torch.Tensor) gradient norm from critic update.
        :return policy_loss: (torch.Tensor) actor(policy) loss value.
        :return dist_entropy: (torch.Tensor) action entropies.
        :return actor_grad_norm: (torch.Tensor) gradient norm from actor update.
        :return imp_weights: (torch.Tensor) importance sampling weights.
        """
        share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, \
        value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, \
        adv_targ, available_actions_batch, factor_batch = sample

        old_action_log_probs_batch = check(old_action_log_probs_batch).to(**self.tpdv)
        adv_targ = check(adv_targ).to(**self.tpdv)
        value_preds_batch = check(value_preds_batch).to(**self.tpdv)
        return_batch = check(return_batch).to(**self.tpdv)
        active_masks_batch = check(active_masks_batch).to(**self.tpdv)
        factor_batch = check(factor_batch).to(**self.tpdv)

        values, action_log_probs, dist_entropy, action_mu, action_std, _ = self.policy.evaluate_actions(share_obs_batch,
                                                                              obs_batch, 
                                                                              rnn_states_batch, 
                                                                              rnn_states_critic_batch, 
                                                                              actions_batch, 
                                                                              masks_batch, 
                                                                              available_actions_batch,
                                                                              active_masks_batch)

        # critic update
        value_loss = self.cal_value_loss(values, value_preds_batch, return_batch, active_masks_batch)

        self.policy.critic_optimizer.zero_grad()

        (value_loss * self.value_loss_coef).backward()

        if self._use_max_grad_norm:
            critic_grad_norm = nn.utils.clip_grad_norm_(self.policy.critic.parameters(), self.max_grad_norm)
        else:
            critic_grad_norm = get_gard_norm(self.policy.critic.parameters())

        self.policy.critic_optimizer.step()

        # actor update
        ratio = torch.prod(torch.exp(action_log_probs - old_action_log_probs_batch),dim=-1,keepdim=True)
        if self._use_policy_active_masks:
            loss = (torch.sum(ratio * factor_batch * adv_targ, dim=-1, keepdim=True) *
                           active_masks_batch).sum() / active_masks_batch.sum()
        else:
            loss = torch.sum(ratio * factor_batch * adv_targ, dim=-1, keepdim=True).mean()

        loss_grad = torch.autograd.grad(loss, self.policy.actor.parameters(), allow_unused=True)
        loss_grad = self.flat_grad(loss_grad)

        step_dir = self.conjugate_gradient(self.policy.actor, 
                                      obs_batch, 
                                      rnn_states_batch, 
                                      actions_batch, 
                                      masks_batch, 
                                      available_actions_batch, 
                                      active_masks_batch, 
                                      loss_grad.data, 
                                      nsteps=10)
        
        loss = loss.data.cpu().numpy()

        params = self.flat_params(self.policy.actor)
        fvp = self.fisher_vector_product(self.policy.actor,
                                    obs_batch, 
                                    rnn_states_batch, 
                                    actions_batch, 
                                    masks_batch, 
                                    available_actions_batch, 
                                    active_masks_batch, 
                                    step_dir)
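        # Scale the search direction so the quadratic KL model 0.5 * s^T H s equals kl_threshold:
        # step_size = sqrt(2 * kl_threshold / (s^T H s)), the maximal full step handed to the
        # backtracking line search below.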
        shs = 0.5 * (step_dir * fvp).sum(0, keepdim=True)
        step_size = 1 / torch.sqrt(shs / self.kl_threshold)[0]
        full_step = step_size * step_dir

        old_actor = Actor(self.policy.args, 
                            self.policy.obs_space,  
                            self.policy.act_space, 
                            self.device)
        self.update_model(old_actor, params)
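        # old_actor now snapshots the pre-update parameters; it anchors the KL constraint during
        # the line search and is restored if no step is accepted.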
        expected_improve = (loss_grad * full_step).sum(0, keepdim=True)
        expected_improve = expected_improve.data.cpu().numpy()
        

        # Backtracking line search
        flag = False
        fraction = 1
        for i in range(self.ls_step):
            new_params = params + fraction * full_step
            self.update_model(self.policy.actor, new_params)
            values, action_log_probs, dist_entropy, action_mu, action_std, _ = self.policy.evaluate_actions(share_obs_batch,
                                                                                obs_batch, 
                                                                                rnn_states_batch, 
                                                                                rnn_states_critic_batch, 
                                                                                actions_batch, 
                                                                                masks_batch, 
                                                                                available_actions_batch,
                                                                                active_masks_batch)

            ratio = torch.exp(action_log_probs - old_action_log_probs_batch)
            if self._use_policy_active_masks:
                new_loss = (torch.sum(ratio * factor_batch * adv_targ, dim=-1, keepdim=True) *
                            active_masks_batch).sum() / active_masks_batch.sum()
            else:
                new_loss = torch.sum(ratio * factor_batch * adv_targ, dim=-1, keepdim=True).mean()

            new_loss = new_loss.data.cpu().numpy()
            loss_improve = new_loss - loss
            
            kl = self.kl_divergence(obs_batch, 
                               rnn_states_batch, 
                               actions_batch, 
                               masks_batch, 
                               available_actions_batch, 
                               active_masks_batch,
                               new_actor=self.policy.actor,
                               old_actor=old_actor)
            kl = kl.mean()

            if kl < self.kl_threshold and (loss_improve / expected_improve) > self.accept_ratio and loss_improve.item()>0:
                flag = True
                break
            expected_improve *= 0.5
            fraction *= 0.5

        if not flag:
            params = self.flat_params(old_actor)
            self.update_model(self.policy.actor, params)
            print('policy update does not improve the surrogate')

        return value_loss, critic_grad_norm, kl, loss_improve, expected_improve, dist_entropy, ratio

    def train(self, buffer, update_actor=True):
        """
        Perform a training update using minibatch GD.
        :param buffer: (SharedReplayBuffer) buffer containing training data.
        :param update_actor: (bool) whether to update actor network.

        :return train_info: (dict) contains information regarding training update (e.g. loss, grad norms, etc).
        """
        if self._use_popart:
            advantages = buffer.returns[:-1] - self.value_normalizer.denormalize(buffer.value_preds[:-1])
        else:
            advantages = buffer.returns[:-1] - buffer.value_preds[:-1]
        advantages_copy = advantages.copy()
        advantages_copy[buffer.active_masks[:-1] == 0.0] = np.nan
        mean_advantages = np.nanmean(advantages_copy)
        std_advantages = np.nanstd(advantages_copy)
        advantages = (advantages - mean_advantages) / (std_advantages + 1e-5)
        

        train_info = {}

        train_info['value_loss'] = 0
        train_info['kl'] = 0
        train_info['dist_entropy'] = 0
        train_info['loss_improve'] = 0
        train_info['expected_improve'] = 0
        train_info['critic_grad_norm'] = 0
        train_info['ratio'] = 0


        if self._use_recurrent_policy:
            data_generator = buffer.recurrent_generator(advantages, self.num_mini_batch, self.data_chunk_length)
        elif self._use_naive_recurrent:
            data_generator = buffer.naive_recurrent_generator(advantages, self.num_mini_batch)
        else:
            data_generator = buffer.feed_forward_generator(advantages, self.num_mini_batch)

        for sample in data_generator:

            value_loss, critic_grad_norm, kl, loss_improve, expected_improve, dist_entropy, imp_weights \
                = self.trpo_update(sample, update_actor)

            train_info['value_loss'] += value_loss.item()
            train_info['kl'] += kl
            train_info['loss_improve'] += loss_improve.item()
            train_info['expected_improve'] += expected_improve
            train_info['dist_entropy'] += dist_entropy.item()
            train_info['critic_grad_norm'] += critic_grad_norm
            train_info['ratio'] += imp_weights.mean()

        num_updates = self.num_mini_batch

        for k in train_info.keys():
            train_info[k] /= num_updates
 
        return train_info

    def prep_training(self):
        self.policy.actor.train()
        self.policy.critic.train()

    def prep_rollout(self):
        self.policy.actor.eval()
        self.policy.critic.eval()
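
# --- Illustrative sketch (not part of the repository) ---
# A standalone, toy-sized illustration of the step sizing used in trpo_update above:
# with shs = 0.5 * s^T H s, scaling the search direction by 1/sqrt(shs/kl_threshold)
# makes the quadratic KL estimate of the full step equal kl_threshold, and the
# backtracking loop then halves the step until the surrogate improves. The direction,
# "Fisher matrix" and surrogate below are made-up stand-ins, not values from the repo.
import torch

kl_threshold = 0.01
s = torch.tensor([0.3, -0.2, 0.5])              # search direction (e.g. from conjugate gradient)
H = torch.diag(torch.tensor([2.0, 1.0, 4.0]))   # stand-in for the Fisher information matrix
shs = 0.5 * (s * (H @ s)).sum()                 # 0.5 * s^T H s
step_size = 1.0 / torch.sqrt(shs / kl_threshold)
full_step = step_size * s

# By construction the quadratic KL estimate of the full step equals the threshold.
kl_quad = 0.5 * full_step @ (H @ full_step)
assert torch.allclose(kl_quad, torch.tensor(kl_threshold))

# Backtracking: halve the fraction until the (toy) surrogate improves.
surrogate = lambda step: -(step - 0.05).pow(2).sum()
fraction, base = 1.0, surrogate(torch.zeros_like(s))
for _ in range(10):
    if surrogate(fraction * full_step) > base:
        break
    fraction *= 0.5
print("accepted fraction:", fraction)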


================================================
FILE: algorithms/utils/act.py
================================================
from .distributions import Bernoulli, Categorical, DiagGaussian
import torch
import torch.nn as nn

class ACTLayer(nn.Module):
    """
    MLP Module to compute actions.
    :param action_space: (gym.Space) action space.
    :param inputs_dim: (int) dimension of network input.
    :param use_orthogonal: (bool) whether to use orthogonal initialization.
    :param gain: (float) gain of the output layer of the network.
    """
    def __init__(self, action_space, inputs_dim, use_orthogonal, gain, args=None):
        super(ACTLayer, self).__init__()
        self.mixed_action = False
        self.multi_discrete = False
        self.action_type = action_space.__class__.__name__
        if action_space.__class__.__name__ == "Discrete":
            action_dim = action_space.n
            self.action_out = Categorical(inputs_dim, action_dim, use_orthogonal, gain)
        elif action_space.__class__.__name__ == "Box":
            action_dim = action_space.shape[0]
            self.action_out = DiagGaussian(inputs_dim, action_dim, use_orthogonal, gain, args)
        elif action_space.__class__.__name__ == "MultiBinary":
            action_dim = action_space.shape[0]
            self.action_out = Bernoulli(inputs_dim, action_dim, use_orthogonal, gain)
        elif action_space.__class__.__name__ == "MultiDiscrete":
            self.multi_discrete = True
            action_dims = action_space.high - action_space.low + 1
            self.action_outs = []
            for action_dim in action_dims:
                self.action_outs.append(Categorical(inputs_dim, action_dim, use_orthogonal, gain))
            self.action_outs = nn.ModuleList(self.action_outs)
        else:  # discrete + continuous
            self.mixed_action = True
            continuous_dim = action_space[0].shape[0]
            discrete_dim = action_space[1].n
            self.action_outs = nn.ModuleList([DiagGaussian(inputs_dim, continuous_dim, use_orthogonal, gain, args),
                                              Categorical(inputs_dim, discrete_dim, use_orthogonal, gain)])
    
    def forward(self, x, available_actions=None, deterministic=False):
        """
        Compute actions and action logprobs from given input.
        :param x: (torch.Tensor) input to network.
        :param available_actions: (torch.Tensor) denotes which actions are available to agent
                                  (if None, all actions available)
        :param deterministic: (bool) whether to sample from action distribution or return the mode.

        :return actions: (torch.Tensor) actions to take.
        :return action_log_probs: (torch.Tensor) log probabilities of taken actions.
        """
        if self.mixed_action :
            actions = []
            action_log_probs = []
            for action_out in self.action_outs:
                action_logit = action_out(x)
                action = action_logit.mode() if deterministic else action_logit.sample()
                action_log_prob = action_logit.log_probs(action)
                actions.append(action.float())
                action_log_probs.append(action_log_prob)

            actions = torch.cat(actions, -1)
            action_log_probs = torch.sum(torch.cat(action_log_probs, -1), -1, keepdim=True)

        elif self.multi_discrete:
            actions = []
            action_log_probs = []
            for action_out in self.action_outs:
                action_logit = action_out(x)
                action = action_logit.mode() if deterministic else action_logit.sample()
                action_log_prob = action_logit.log_probs(action)
                actions.append(action)
                action_log_probs.append(action_log_prob)

            actions = torch.cat(actions, -1)
            action_log_probs = torch.cat(action_log_probs, -1)
        
        else:
            action_logits = self.action_out(x, available_actions)
            actions = action_logits.mode() if deterministic else action_logits.sample() 
            action_log_probs = action_logits.log_probs(actions)
        
        return actions, action_log_probs

    def get_probs(self, x, available_actions=None):
        """
        Compute action probabilities from inputs.
        :param x: (torch.Tensor) input to network.
        :param available_actions: (torch.Tensor) denotes which actions are available to agent
                                  (if None, all actions available)

        :return action_probs: (torch.Tensor)
        """
        if self.mixed_action or self.multi_discrete:
            action_probs = []
            for action_out in self.action_outs:
                action_logit = action_out(x)
                action_prob = action_logit.probs
                action_probs.append(action_prob)
            action_probs = torch.cat(action_probs, -1)
        else:
            action_logits = self.action_out(x, available_actions)
            action_probs = action_logits.probs
        
        return action_probs

    def evaluate_actions(self, x, action, available_actions=None, active_masks=None):
        """
        Compute log probability and entropy of given actions.
        :param x: (torch.Tensor) input to network.
        :param action: (torch.Tensor) actions whose entropy and log probability to evaluate.
        :param available_actions: (torch.Tensor) denotes which actions are available to agent
                                                              (if None, all actions available)
        :param active_masks: (torch.Tensor) denotes whether an agent is active or dead.

        :return action_log_probs: (torch.Tensor) log probabilities of the input actions.
        :return dist_entropy: (torch.Tensor) action distribution entropy for the given inputs.
        """
        if self.mixed_action:
            a, b = action.split((2, 1), -1)
            b = b.long()
            action = [a, b] 
            action_log_probs = [] 
            dist_entropy = []
            for action_out, act in zip(self.action_outs, action):
                action_logit = action_out(x)
                action_log_probs.append(action_logit.log_probs(act))
                if active_masks is not None:
                    if len(action_logit.entropy().shape) == len(active_masks.shape):
                        dist_entropy.append((action_logit.entropy() * active_masks).sum()/active_masks.sum()) 
                    else:
                        dist_entropy.append((action_logit.entropy() * active_masks.squeeze(-1)).sum()/active_masks.sum())
                else:
                    dist_entropy.append(action_logit.entropy().mean())
                
            action_log_probs = torch.sum(torch.cat(action_log_probs, -1), -1, keepdim=True)
            dist_entropy = dist_entropy[0] / 2.0 + dist_entropy[1] / 0.98 

        elif self.multi_discrete:
            action = torch.transpose(action, 0, 1)
            action_log_probs = []
            dist_entropy = []
            for action_out, act in zip(self.action_outs, action):
                action_logit = action_out(x)
                action_log_probs.append(action_logit.log_probs(act))
                if active_masks is not None:
                    dist_entropy.append((action_logit.entropy()*active_masks.squeeze(-1)).sum()/active_masks.sum())
                else:
                    dist_entropy.append(action_logit.entropy().mean())

            action_log_probs = torch.cat(action_log_probs, -1) 
            dist_entropy = torch.tensor(dist_entropy).mean()
        
        else:
            action_logits = self.action_out(x, available_actions)
            action_log_probs = action_logits.log_probs(action)
            if active_masks is not None:
                if self.action_type=="Discrete":
                    dist_entropy = (action_logits.entropy()*active_masks.squeeze(-1)).sum()/active_masks.sum()
                else:
                    dist_entropy = (action_logits.entropy()*active_masks).sum()/active_masks.sum()
            else:
                dist_entropy = action_logits.entropy().mean()
        
        return action_log_probs, dist_entropy

    def evaluate_actions_trpo(self, x, action, available_actions=None, active_masks=None):
        """
        Compute log probability and entropy of given actions.
        :param x: (torch.Tensor) input to network.
        :param action: (torch.Tensor) actions whose entropy and log probability to evaluate.
        :param available_actions: (torch.Tensor) denotes which actions are available to agent
                                                              (if None, all actions available)
        :param active_masks: (torch.Tensor) denotes whether an agent is active or dead.

        :return action_log_probs: (torch.Tensor) log probabilities of the input actions.
        :return dist_entropy: (torch.Tensor) action distribution entropy for the given inputs.
        """

        if self.multi_discrete:
            action = torch.transpose(action, 0, 1)
            action_log_probs = []
            dist_entropy = []
            mu_collector = []
            std_collector = []
            probs_collector = []
            for action_out, act in zip(self.action_outs, action):
                action_logit = action_out(x)
                mu = action_logit.mean
                std = action_logit.stddev
                action_log_probs.append(action_logit.log_probs(act))
                mu_collector.append(mu)
                std_collector.append(std)
                probs_collector.append(action_logit.logits)
                if active_masks is not None:
                    dist_entropy.append((action_logit.entropy()*active_masks.squeeze(-1)).sum()/active_masks.sum())
                else:
                    dist_entropy.append(action_logit.entropy().mean())
            action_mu = torch.cat(mu_collector,-1)
            action_std = torch.cat(std_collector,-1)
            all_probs = torch.cat(probs_collector,-1)
            action_log_probs = torch.cat(action_log_probs, -1)
            dist_entropy = torch.tensor(dist_entropy).mean()
        
        else:
            action_logits = self.action_out(x, available_actions)
            action_mu = action_logits.mean
            action_std = action_logits.stddev
            action_log_probs = action_logits.log_probs(action)
            if self.action_type=="Discrete":
                all_probs = action_logits.logits
            else:
                all_probs = None
            if active_masks is not None:
                if self.action_type=="Discrete":
                    dist_entropy = (action_logits.entropy()*active_masks.squeeze(-1)).sum()/active_masks.sum()
                else:
                    dist_entropy = (action_logits.entropy()*active_masks).sum()/active_masks.sum()
            else:
                dist_entropy = action_logits.entropy().mean()
        
        return action_log_probs, dist_entropy, action_mu, action_std, all_probs
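
# --- Illustrative sketch (not part of the repository) ---
# A minimal usage example for ACTLayer with a Discrete action space, assuming `gym`
# is installed and the repository root is on PYTHONPATH so `algorithms.utils.act`
# is importable. It shows how `available_actions` masks out invalid actions.
import torch
from gym import spaces
from algorithms.utils.act import ACTLayer

action_space = spaces.Discrete(5)
act_layer = ACTLayer(action_space, inputs_dim=64, use_orthogonal=True, gain=0.01)

x = torch.randn(4, 64)          # a batch of 4 actor features
avail = torch.ones(4, 5)
avail[:, 2] = 0                 # action 2 is unavailable in every environment

actions, log_probs = act_layer(x, available_actions=avail)
probs = act_layer.get_probs(x, available_actions=avail)
assert (actions != 2).all() and torch.allclose(probs[:, 2], torch.zeros(4))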


================================================
FILE: algorithms/utils/cnn.py
================================================
import torch.nn as nn
from .util import init

"""CNN Modules and utils."""

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


class CNNLayer(nn.Module):
    def __init__(self, obs_shape, hidden_size, use_orthogonal, use_ReLU, kernel_size=3, stride=1):
        super(CNNLayer, self).__init__()

        active_func = [nn.Tanh(), nn.ReLU()][use_ReLU]
        init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]
        gain = nn.init.calculate_gain(['tanh', 'relu'][use_ReLU])

        def init_(m):
            return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain=gain)

        input_channel = obs_shape[0]
        input_width = obs_shape[1]
        input_height = obs_shape[2]

        self.cnn = nn.Sequential(
            init_(nn.Conv2d(in_channels=input_channel,
                            out_channels=hidden_size // 2,
                            kernel_size=kernel_size,
                            stride=stride)
                  ),
            active_func,
            Flatten(),
            init_(nn.Linear(hidden_size // 2 * (input_width - kernel_size + stride) * (input_height - kernel_size + stride),
                            hidden_size)
                  ),
            active_func,
            init_(nn.Linear(hidden_size, hidden_size)), active_func)

    def forward(self, x):
        x = x / 255.0
        x = self.cnn(x)
        return x


class CNNBase(nn.Module):
    def __init__(self, args, obs_shape):
        super(CNNBase, self).__init__()

        self._use_orthogonal = args.use_orthogonal
        self._use_ReLU = args.use_ReLU
        self.hidden_size = args.hidden_size

        self.cnn = CNNLayer(obs_shape, self.hidden_size, self._use_orthogonal, self._use_ReLU)

    def forward(self, x):
        x = self.cnn(x)
        return x
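
# --- Illustrative sketch (not part of the repository) ---
# A quick shape check for CNNLayer: with the default stride=1, the flattened conv
# output has hidden_size // 2 * (W - k + 1) * (H - k + 1) features, matching the
# in_features of the first Linear layer above. The observation shape is a made-up
# example; assumes the repository root is on PYTHONPATH.
import torch
from algorithms.utils.cnn import CNNLayer

obs_shape = (3, 16, 16)                            # (channels, width, height)
layer = CNNLayer(obs_shape, hidden_size=64, use_orthogonal=True, use_ReLU=True)
out = layer(torch.zeros(8, *obs_shape))            # a batch of 8 dummy images
assert out.shape == (8, 64)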


================================================
FILE: algorithms/utils/distributions.py
================================================
import torch
import torch.nn as nn
from .util import init

"""
Modify standard PyTorch distributions to make them compatible with this codebase.
"""

#
# Standardize distribution interfaces
#

# Categorical
class FixedCategorical(torch.distributions.Categorical):
    def sample(self):
        return super().sample().unsqueeze(-1)

    def log_probs(self, actions):
        return (
            super()
            .log_prob(actions.squeeze(-1))
            .view(actions.size(0), -1)
            .sum(-1)
            .unsqueeze(-1)
        )

    def mode(self):
        return self.probs.argmax(dim=-1, keepdim=True)


# Normal
class FixedNormal(torch.distributions.Normal):
    def log_probs(self, actions):
        return super().log_prob(actions)
        # return super().log_prob(actions).sum(-1, keepdim=True)

    def entrop(self):
        # Note: appears unused; callers rely on the base-class entropy() instead.
        return super().entropy().sum(-1)

    def mode(self):
        return self.mean


# Bernoulli
class FixedBernoulli(torch.distributions.Bernoulli):
    def log_probs(self, actions):
        return super().log_prob(actions).view(actions.size(0), -1).sum(-1).unsqueeze(-1)

    def entropy(self):
        return super().entropy().sum(-1)

    def mode(self):
        return torch.gt(self.probs, 0.5).float()


class Categorical(nn.Module):
    def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01):
        super(Categorical, self).__init__()
        init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]
        def init_(m): 
            return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain)

        self.linear = init_(nn.Linear(num_inputs, num_outputs))

    def forward(self, x, available_actions=None):
        x = self.linear(x)
        if available_actions is not None:
            x[available_actions == 0] = -1e10
        return FixedCategorical(logits=x)


# class DiagGaussian(nn.Module):
#     def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01):
#         super(DiagGaussian, self).__init__()
#
#         init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]
#         def init_(m):
#             return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain)
#
#         self.fc_mean = init_(nn.Linear(num_inputs, num_outputs))
#         self.logstd = AddBias(torch.zeros(num_outputs))
#
#     def forward(self, x, available_actions=None):
#         action_mean = self.fc_mean(x)
#
#         #  An ugly hack for my KFAC implementation.
#         zeros = torch.zeros(action_mean.size())
#         if x.is_cuda:
#             zeros = zeros.cuda()
#
#         action_logstd = self.logstd(zeros)
#         return FixedNormal(action_mean, action_logstd.exp())

class DiagGaussian(nn.Module):
    def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01, args=None):
        super(DiagGaussian, self).__init__()

        init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]

        def init_(m):
            return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain)

        if args is not None:
            self.std_x_coef = args.std_x_coef
            self.std_y_coef = args.std_y_coef
        else:
            self.std_x_coef = 1.
            self.std_y_coef = 0.5
        self.fc_mean = init_(nn.Linear(num_inputs, num_outputs))
        log_std = torch.ones(num_outputs) * self.std_x_coef
        self.log_std = torch.nn.Parameter(log_std)

    def forward(self, x, available_actions=None):
        action_mean = self.fc_mean(x)
        action_std = torch.sigmoid(self.log_std / self.std_x_coef) * self.std_y_coef
        return FixedNormal(action_mean, action_std)

class Bernoulli(nn.Module):
    def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01):
        super(Bernoulli, self).__init__()
        init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]
        def init_(m): 
            return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain)
        
        self.linear = init_(nn.Linear(num_inputs, num_outputs))

    def forward(self, x):
        x = self.linear(x)
        return FixedBernoulli(logits=x)

class AddBias(nn.Module):
    def __init__(self, bias):
        super(AddBias, self).__init__()
        self._bias = nn.Parameter(bias.unsqueeze(1))

    def forward(self, x):
        if x.dim() == 2:
            bias = self._bias.t().view(1, -1)
        else:
            bias = self._bias.t().view(1, -1, 1, 1)

        return x + bias
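
# --- Illustrative sketch (not part of the repository) ---
# The DiagGaussian head above parameterizes the action std as
# sigmoid(log_std / std_x_coef) * std_y_coef, so it stays in (0, std_y_coef).
# With the defaults (std_x_coef=1, std_y_coef=0.5) and log_std initialized to
# std_x_coef, the initial std is sigmoid(1) * 0.5 ≈ 0.366 for every dimension.
# Assumes the repository root is on PYTHONPATH.
import torch
from algorithms.utils.distributions import DiagGaussian

head = DiagGaussian(num_inputs=64, num_outputs=2)
dist = head(torch.zeros(1, 64))
expected = torch.sigmoid(torch.tensor(1.0)) * 0.5
assert torch.allclose(dist.stddev, expected.expand(1, 2))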


================================================
FILE: algorithms/utils/mlp.py
================================================
import torch.nn as nn
from .util import init, get_clones

"""MLP modules."""

class MLPLayer(nn.Module):
    def __init__(self, input_dim, hidden_size, layer_N, use_orthogonal, use_ReLU):
        super(MLPLayer, self).__init__()
        self._layer_N = layer_N

        active_func = [nn.Tanh(), nn.ReLU()][use_ReLU]
        init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal]
        gain = nn.init.calculate_gain(['tanh', 'relu'][use_ReLU])

        def init_(m):
            return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain=gain)

        self.fc1 = nn.Sequential(
            init_(nn.Linear(input_dim, hidden_size)), active_func, nn.LayerNorm(hidden_size))
        # self.fc_h = nn.Sequential(init_(
        #     nn.Linear(hidden_size, hidden_size)), active_func, nn.LayerNorm(hidden_size))
        # self.fc2 = get_clones(self.fc_h, self._layer_N)
        self.fc2 = nn.ModuleList([nn.Sequential(init_(
            nn.Linear(hidden_size, hidden_size)), active_func, nn.LayerNorm(hidden_size)) for i in range(self._layer_N)])

    def forward(self, x):
        x = self.fc1(x)
        for i in range(self._layer_N):
            x = self.fc2[i](x)
        return x


class MLPBase(nn.Module):
    def __init__(self, args, obs_shape, cat_self=True, attn_internal=False):
        super(MLPBase, self).__init__()

        self._use_feature_normalization = args.use_feature_normalization
        self._use_orthogonal = args.use_orthogonal
        self._use_ReLU = args.use_ReLU
        self._stacked_frames = args.stacked_frames
        self._layer_N = args.layer_N
        self.hidden_size = args.hidden_size

        obs_dim = obs_shape[0]

        if self._use_feature_normalization:
            self.feature_norm = nn.LayerNorm(obs_dim)

        self.mlp = MLPLayer(obs_dim, self.hidden_size,
                              self._layer_N, self._use_orthogonal, self._use_ReLU)

    def forward(self, x):
        if self._use_feature_normalization:
            x = self.feature_norm(x)

        x = self.mlp(x)

        return x
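
# --- Illustrative sketch (not part of the repository) ---
# MLPBase only reads a handful of attributes from `args`, so a SimpleNamespace is
# enough to exercise it in isolation. The values below are hypothetical and mirror
# the defaults in configs/config.py; assumes the repository root is on PYTHONPATH.
import torch
from types import SimpleNamespace
from algorithms.utils.mlp import MLPBase

args = SimpleNamespace(use_feature_normalization=True, use_orthogonal=True,
                       use_ReLU=True, stacked_frames=1, layer_N=1, hidden_size=64)
base = MLPBase(args, obs_shape=(17,))       # e.g. a 17-dimensional vector observation
features = base(torch.randn(32, 17))
assert features.shape == (32, 64)
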

================================================
FILE: algorithms/utils/rnn.py
================================================
import torch
import torch.nn as nn

"""RNN modules."""


class RNNLayer(nn.Module):
    def __init__(self, inputs_dim, outputs_dim, recurrent_N, use_orthogonal):
        super(RNNLayer, self).__init__()
        self._recurrent_N = recurrent_N
        self._use_orthogonal = use_orthogonal

        self.rnn = nn.GRU(inputs_dim, outputs_dim, num_layers=self._recurrent_N)
        for name, param in self.rnn.named_parameters():
            if 'bias' in name:
                nn.init.constant_(param, 0)
            elif 'weight' in name:
                if self._use_orthogonal:
                    nn.init.orthogonal_(param)
                else:
                    nn.init.xavier_uniform_(param)
        self.norm = nn.LayerNorm(outputs_dim)

    def forward(self, x, hxs, masks):
        if x.size(0) == hxs.size(0):
            x, hxs = self.rnn(x.unsqueeze(0),
                              (hxs * masks.repeat(1, self._recurrent_N).unsqueeze(-1)).transpose(0, 1).contiguous())
            x = x.squeeze(0)
            hxs = hxs.transpose(0, 1)
        else:
            # x is a (T, N, -1) tensor that has been flattened to (T * N, -1)
            N = hxs.size(0)
            T = int(x.size(0) / N)

            # unflatten
            x = x.view(T, N, x.size(1))

            # Same deal with masks
            masks = masks.view(T, N)

            # Let's figure out which steps in the sequence have a zero for any agent
            # We will always assume t=0 has a zero in it as that makes the logic cleaner
            has_zeros = ((masks[1:] == 0.0)
                         .any(dim=-1)
                         .nonzero()
                         .squeeze()
                         .cpu())

            # +1 to correct the masks[1:]
            if has_zeros.dim() == 0:
                # Deal with scalar
                has_zeros = [has_zeros.item() + 1]
            else:
                has_zeros = (has_zeros + 1).numpy().tolist()

            # add t=0 and t=T to the list
            has_zeros = [0] + has_zeros + [T]

            hxs = hxs.transpose(0, 1)

            outputs = []
            for i in range(len(has_zeros) - 1):
                # We can now process steps that don't have any zeros in masks together!
                # This is much faster
                start_idx = has_zeros[i]
                end_idx = has_zeros[i + 1]
                temp = (hxs * masks[start_idx].view(1, -1, 1).repeat(self._recurrent_N, 1, 1)).contiguous()
                rnn_scores, hxs = self.rnn(x[start_idx:end_idx], temp)
                outputs.append(rnn_scores)

            # assert len(outputs) == T
            # x is a (T, N, -1) tensor
            x = torch.cat(outputs, dim=0)

            # flatten
            x = x.reshape(T * N, -1)
            hxs = hxs.transpose(0, 1)

        x = self.norm(x)
        return x, hxs
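
# --- Illustrative sketch (not part of the repository) ---
# The single-step branch of RNNLayer.forward expects hidden states shaped
# (batch, recurrent_N, hidden_size) and masks shaped (batch, 1); a zero mask resets
# the corresponding hidden state at an episode boundary. The sizes below are made-up;
# assumes the repository root is on PYTHONPATH.
import torch
from algorithms.utils.rnn import RNNLayer

rnn = RNNLayer(inputs_dim=64, outputs_dim=64, recurrent_N=1, use_orthogonal=True)
x = torch.randn(8, 64)            # one step of features for 8 parallel envs
hxs = torch.randn(8, 1, 64)       # previous hidden states
masks = torch.ones(8, 1)
masks[0] = 0.0                    # env 0 just reset, so its hidden state is zeroed

out, new_hxs = rnn(x, hxs, masks)
assert out.shape == (8, 64) and new_hxs.shape == (8, 1, 64)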


================================================
FILE: algorithms/utils/util.py
================================================
import copy
import numpy as np

import torch
import torch.nn as nn

def init(module, weight_init, bias_init, gain=1):
    weight_init(module.weight.data, gain=gain)
    bias_init(module.bias.data)
    return module

def get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

def check(input):
    output = torch.from_numpy(input) if type(input) == np.ndarray else input
    return output


================================================
FILE: configs/config.py
================================================
import argparse

def get_config():
    """
    The configuration parser for common hyperparameters of all environments.
    Please refer to each `scripts/train/<env>_runner.py` file for the private hyperparameters
    only used in <env>.

    Prepare parameters:
        --algorithm_name <algorithm_name>
            specify the algorithm, one of `["happo", "hatrpo"]`
        --experiment_name <str>
            an identifier to distinguish different experiments.
        --seed <int>
            set seed for numpy and torch 
        --seed_specify
            by default True, use the specified seed for numpy/torch; if set, use a random seed instead.
        --running_id <int>
            the running index of experiment (default=1)
        --cuda
            by default True, will use GPU to train; or else will use CPU; 
        --cuda_deterministic
            by default, make CUDA operations deterministic so the random seed is fully effective. If set, bypass this.
        --n_training_threads <int>
            number of training threads working in parallel. by default 1
        --n_rollout_threads <int>
            number of parallel envs for training rollout. by default 32
        --n_eval_rollout_threads <int>
            number of parallel envs for evaluating rollout. by default 1
        --n_render_rollout_threads <int>
            number of parallel envs for rendering, could only be set as 1 for some environments.
        --num_env_steps <int>
            number of env steps to train (default: 10e6)

    
    Env parameters:
        --env_name <str>
            specify the name of environment
        --use_obs_instead_of_state
            [only for some env] by default False, will use global state; or else will use concatenated local obs.
    
    Replay Buffer parameters:
        --episode_length <int>
            the max length of episode in the buffer. 
    
    Network parameters:
        --share_policy
            by default True, all agents will share the same network; set to make training agents use different policies. 
        --use_centralized_V
            by default True, use centralized training mode; otherwise use decentralized training mode.
        --stacked_frames <int>
            Number of input frames which should be stacked together.
        --hidden_size <int>
            Dimension of hidden layers for actor/critic networks
        --layer_N <int>
            Number of layers for actor/critic networks
        --use_ReLU
            by default True, will use ReLU. or else will use Tanh.
        --use_popart
            by default True, use running mean and std to normalize rewards. 
        --use_feature_normalization
            by default True, apply layernorm to normalize inputs. 
        --use_orthogonal
            by default True, use orthogonal initialization for weights and 0 initialization for biases; otherwise, use xavier uniform initialization.
        --gain
            by default 0.01, the gain of the last action layer.
        --use_naive_recurrent_policy
            by default False, use the whole trajectory to calculate hidden states.
        --use_recurrent_policy
            by default False, do not use a recurrent policy. If set, use a recurrent policy.
        --recurrent_N <int>
            The number of recurrent layers ( default 1).
        --data_chunk_length <int>
            Time length of chunks used to train a recurrent_policy, default 10.
    
    Optimizer parameters:
        --lr <float>
            learning rate parameter,  (default: 5e-4, fixed).
        --critic_lr <float>
            learning rate of critic  (default: 5e-4, fixed)
        --opti_eps <float>
            RMSprop optimizer epsilon (default: 1e-5)
        --weight_decay <float>
            coefficient of weight decay (default: 0)
    
    TRPO parameters:
        --kl_threshold <float>
            the threshold of kl-divergence (default: 0.01)
        --ls_step <int> 
            the max number of line search steps (default: 10)
        --accept_ratio <float>
            accept ratio of loss improvement (default: 0.5)
    
    PPO parameters:
        --ppo_epoch <int>
            number of ppo epochs (default: 15)
        --use_clipped_value_loss 
            by default, clip loss value. If set, do not clip loss value.
        --clip_param <float>
            ppo clip parameter (default: 0.2)
        --num_mini_batch <int>
            number of batches for ppo (default: 1)
        --entropy_coef <float>
            entropy term coefficient (default: 0.01)
        --use_max_grad_norm 
            by default, use max norm of gradients. If set, do not use.
        --max_grad_norm <float>
            max norm of gradients (default: 0.5)
        --use_gae
            by default, use generalized advantage estimation. If set, do not use gae.
        --gamma <float>
            discount factor for rewards (default: 0.99)
        --gae_lambda <float>
            gae lambda parameter (default: 0.95)
        --use_proper_time_limits
            by default False, do not consider time limits when computing returns. If set, compute returns taking time limits into account.
        --use_huber_loss
            by default, use huber loss. If set, do not use huber loss.
        --use_value_active_masks
            by default True, whether to mask useless data in value loss.  
        --huber_delta <float>
            coefficient of huber loss.  

    
    Run parameters:
        --use_linear_lr_decay
            by default, do not apply linear decay to learning rate. If set, use a linear schedule on the learning rate
        --save_interval <int>
            the interval between two consecutive model saves.
        --log_interval <int>
            the interval between two consecutive log outputs.
        --model_dir <str>
            by default None. set the path to pretrained model.

    Eval parameters:
        --use_eval
            by default, do not start evaluation. If set, start evaluation alongside training.
        --eval_interval <int>
            the interval between two consecutive evaluations.
        --eval_episodes <int>
            number of episodes of a single evaluation.
    
    Render parameters:
        --save_gifs
            by default, do not save render video. If set, save video.
        --use_render
            by default, do not render the env during training. If set, start rendering. Note: sometimes the environment has an internal render process which is not controlled by this hyperparameter.
        --render_episodes <int>
            the number of episodes to render a given env
        --ifi <float>
            the play interval between rendered frames in the saved video.
    
    Pretrained parameters:
        
    """
    parser = argparse.ArgumentParser(description='onpolicy_algorithm', formatter_class=argparse.RawDescriptionHelpFormatter)

    # prepare parameters
    parser.add_argument("--algorithm_name", type=str,
                        default=' ', choices=["happo","hatrpo"])
    parser.add_argument("--experiment_name", type=str, 
                        default="check", help="an identifier to distinguish different experiment.")
    parser.add_argument("--seed", type=int, 
                        default=1, help="Random seed for numpy/torch")
    parser.add_argument("--seed_specify", action="store_false",
                        default=True, help="Random or specify seed for numpy/torch")
    parser.add_argument("--running_id", type=int, 
                        default=1, help="the running index of experiment")
    parser.add_argument("--cuda", action='store_false', 
                        default=True, help="by default True, will use GPU to train; or else will use CPU;")
    parser.add_argument("--cuda_deterministic", action='store_false', 
                        default=True, help="by default, make sure random seed effective. if set, bypass such function.")
    parser.add_argument("--n_training_threads", type=int,
                        default=1, help="Number of torch threads for training")
    parser.add_argument("--n_rollout_threads", type=int, 
                        default=32, help="Number of parallel envs for training rollouts")
    parser.add_argument("--n_eval_rollout_threads", type=int, 
                        default=1, help="Number of parallel envs for evaluating rollouts")
    parser.add_argument("--n_render_rollout_threads", type=int, 
                        default=1, help="Number of parallel envs for rendering rollouts")
    parser.add_argument("--num_env_steps", type=int, 
                        default=10e6, help='Number of environment steps to train (default: 10e6)')
    parser.add_argument("--user_name", type=str, 
                        default='marl',help="[for wandb usage], to specify user's name for simply collecting training data.")
    # env parameters
    parser.add_argument("--env_name", type=str, 
                        default='StarCraft2', help="specify the name of environment")
    parser.add_argument("--use_obs_instead_of_state", action='store_true',
                        default=False, help="Whether to use global state or concatenated obs")

    # replay buffer parameters
    parser.add_argument("--episode_length", type=int,
                        default=200, help="Max length for any episode")

    # network parameters
    parser.add_argument("--share_policy", action='store_false',
                        default=True, help='Whether agent share the same policy')
    parser.add_argument("--use_centralized_V", action='store_false',
                        default=True, help="Whether to use centralized V function")
    parser.add_argument("--stacked_frames", type=int, 
                        default=1, help="Dimension of hidden layers for actor/critic networks")
    parser.add_argument("--use_stacked_frames", action='store_true',
                        default=False, help="Whether to use stacked_frames")
    parser.add_argument("--hidden_size", type=int, 
                        default=64, help="Dimension of hidden layers for actor/critic networks") 
    parser.add_argument("--layer_N", type=int, 
                        default=1, help="Number of layers for actor/critic networks")
    parser.add_argument("--use_ReLU", action='store_false',
                        default=True, help="Whether to use ReLU")
    parser.add_argument("--use_popart", action='store_false', 
                        default=True, help="by default True, use running mean and std to normalize rewards.")
    parser.add_argument("--use_valuenorm", action='store_false', 
                        default=True, help="by default True, use running mean and std to normalize rewards.")
    parser.add_argument("--use_feature_normalization", action='store_false',
                        default=True, help="Whether to apply layernorm to the inputs")
    parser.add_argument("--use_orthogonal", action='store_false', 
                        default=True, help="Whether to use Orthogonal initialization for weights and 0 initialization for biases")
    parser.add_argument("--gain", type=float, 
                        default=0.01, help="The gain # of last action layer")

    # recurrent parameters
    parser.add_argument("--use_naive_recurrent_policy", action='store_true',
                        default=False, help='Whether to use a naive recurrent policy')
    parser.add_argument("--use_recurrent_policy", action='store_true',
                        default=False, help='use a recurrent policy')
    parser.add_argument("--recurrent_N", type=int, 
                        default=1, help="The number of recurrent layers.")
    parser.add_argument("--data_chunk_length", type=int, 
                        default=10, help="Time length of chunks used to train a recurrent_policy")
    
    # optimizer parameters
    parser.add_argument("--lr", type=float, 
                        default=5e-4, help='learning rate (default: 5e-4)')
    parser.add_argument("--critic_lr", type=float, 
                        default=5e-4, help='critic learning rate (default: 5e-4)')
    parser.add_argument("--opti_eps", type=float, 
                        default=1e-5, help='RMSprop optimizer epsilon (default: 1e-5)')
    parser.add_argument("--weight_decay", type=float, default=0)
    parser.add_argument("--std_x_coef", type=float, default=1)
    parser.add_argument("--std_y_coef", type=float, default=0.5)


    # trpo parameters
    parser.add_argument("--kl_threshold", type=float, 
                        default=0.01, help='the threshold of kl-divergence (default: 0.01)')
    parser.add_argument("--ls_step", type=int, 
                        default=10, help='max number of line search steps (default: 10)')
    parser.add_argument("--accept_ratio", type=float, 
                        default=0.5, help='accept ratio of loss improvement (default: 0.5)')

    # ppo parameters
    parser.add_argument("--ppo_epoch", type=int, 
                        default=15, help='number of ppo epochs (default: 15)')
    parser.add_argument("--use_clipped_value_loss", action='store_false', 
                        default=True, help="by default, clip loss value. If set, do not clip loss value.")
    parser.add_argument("--clip_param", type=float, 
                        default=0.2, help='ppo clip parameter (default: 0.2)')
    parser.add_argument("--num_mini_batch", type=int, 
                        default=1, help='number of batches for ppo (default: 1)')
    parser.add_argument("--entropy_coef", type=float, 
                        default=0.01, help='entropy term coefficient (default: 0.01)')
    parser.add_argument("--value_loss_coef", type=float,
                        default=1, help='value loss coefficient (default: 1)')
    parser.add_argument("--use_max_grad_norm", action='store_false', 
                        default=True, help="by default, use max norm of gradients. If set, do not use.")
    parser.add_argument("--max_grad_norm", type=float, 
                        default=10.0, help='max norm of gradients (default: 10.0)')
    parser.add_argument("--use_gae", action='store_false',
                        default=True, help='use generalized advantage estimation')
    parser.add_argument("--gamma", type=float, default=0.99,
                        help='discount factor for rewards (default: 0.99)')
    parser.add_argument("--gae_lambda", type=float, default=0.95,
                        help='gae lambda parameter (default: 0.95)')
    parser.add_argument("--use_proper_time_limits", action='store_true',
                        default=False, help='compute returns taking into account time limits')
    parser.add_argument("--use_huber_loss", action='store_false', 
                        default=True, help="by default, use huber loss. If set, do not use huber loss.")
    parser.add_argument("--use_value_active_masks", action='store_false', 
                        default=True, help="by default True, whether to mask useless data in value loss.")
    parser.add_argument("--use_policy_active_masks", action='store_false', 
                        default=True, help="by default True, whether to mask useless data in policy loss.")
    parser.add_argument("--huber_delta", type=float, 
                        default=10.0, help=" coefficience of huber loss.")

    # run parameters
    parser.add_argument("--use_linear_lr_decay", action='store_true',
                        default=False, help='use a linear schedule on the learning rate')
    parser.add_argument("--save_interval", type=int, 
                        default=1, help="time duration between contiunous twice models saving.")
    parser.add_argument("--log_interval", type=int, 
                        default=5, help="time duration between contiunous twice log printing.")
    parser.add_argument("--model_dir", type=str, 
                        default=None, help="by default None. set the path to pretrained model.")

    # eval parameters
    parser.add_argument("--use_eval", action='store_true', 
                        default=False, help="by default, do not start evaluation. If set`, start evaluation alongside with training.")
    parser.add_argument("--eval_interval", type=int, 
                        default=25, help="time duration between contiunous twice evaluation progress.")
    parser.add_argument("--eval_episodes", type=int, 
                        default=32, help="number of episodes of a single evaluation.")

    # render parameters
    parser.add_argument("--save_gifs", action='store_true', 
                        default=False, help="by default, do not save render video. If set, save video.")
    parser.add_argument("--use_render", action='store_true', 
                        default=False, help="by default, do not render the env during training. If set, start render. Note: something, the environment has internal render process which is not controlled by this hyperparam.")
    parser.add_argument("--render_episodes", type=int, 
                        default=5, help="the number of episodes to render a given env")
    parser.add_argument("--ifi", type=float, 
                        default=0.1, help="the play interval of each rendered image in saved video.")

    return parser
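
# --- Illustrative sketch (not part of the repository) ---
# get_config() returns an argparse parser, so a training script typically extends it
# with environment-specific flags and then parses the command line. The flag values
# below are hypothetical examples, not recommended settings; assumes the repository
# root is on PYTHONPATH so `configs.config` is importable.
from configs.config import get_config

parser = get_config()
args = parser.parse_known_args(["--algorithm_name", "hatrpo",
                                "--env_name", "mujoco",
                                "--kl_threshold", "0.01"])[0]
print(args.algorithm_name, args.env_name, args.kl_threshold)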


================================================
FILE: envs/__init__.py
================================================

import socket
from absl import flags
FLAGS = flags.FLAGS
FLAGS(['train_sc.py'])




================================================
FILE: envs/env_wrappers.py
================================================
"""
Modified from OpenAI Baselines code to work with multi-agent envs
"""
import numpy as np
import torch
from multiprocessing import Process, Pipe
from abc import ABC, abstractmethod
from utils.util import tile_images

class CloudpickleWrapper(object):
    """
    Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
    """

    def __init__(self, x):
        self.x = x

    def __getstate__(self):
        import cloudpickle
        return cloudpickle.dumps(self.x)

    def __setstate__(self, ob):
        import pickle
        self.x = pickle.loads(ob)


class ShareVecEnv(ABC):
    """
    An abstract asynchronous, vectorized environment.
    Used to batch data from multiple copies of an environment, so that
    each observation becomes a batch of observations, and the expected action is a batch of actions to
    be applied per-environment.
    """
    closed = False
    viewer = None

    metadata = {
        'render.modes': ['human', 'rgb_array']
    }

    def __init__(self, num_envs, observation_space, share_observation_space, action_space):
        self.num_envs = num_envs
        self.observation_space = observation_space
        self.share_observation_space = share_observation_space
        self.action_space = action_space

    @abstractmethod
    def reset(self):
        """
        Reset all the environments and return an array of
        observations, or a dict of observation arrays.

        If step_async is still doing work, that work will
        be cancelled and step_wait() should not be called
        until step_async() is invoked again.
        """
        pass

    @abstractmethod
    def step_async(self, actions):
        """
        Tell all the environments to start taking a step
        with the given actions.
        Call step_wait() to get the results of the step.

        You should not call this if a step_async run is
        already pending.
        """
        pass

    @abstractmethod
    def step_wait(self):
        """
        Wait for the step taken with step_async().

        Returns (obs, rews, dones, infos):
         - obs: an array of observations, or a dict of
                arrays of observations.
         - rews: an array of rewards
         - dones: an array of "episode done" booleans
         - infos: a sequence of info objects
        """
        pass

    def close_extras(self):
        """
        Clean up the  extra resources, beyond what's in this base class.
        Only runs when not self.closed.
        """
        pass

    def close(self):
        if self.closed:
            return
        if self.viewer is not None:
            self.viewer.close()
        self.close_extras()
        self.closed = True

    def step(self, actions):
        """
        Step the environments synchronously.

        This is available for backwards compatibility.
        """
        self.step_async(actions)
        return self.step_wait()

    def render(self, mode='human'):
        imgs = self.get_images()
        bigimg = tile_images(imgs)
        if mode == 'human':
            self.get_viewer().imshow(bigimg)
            return self.get_viewer().isopen
        elif mode == 'rgb_array':
            return bigimg
        else:
            raise NotImplementedError

    def get_images(self):
        """
        Return RGB images from each environment
        """
        raise NotImplementedError

    @property
    def unwrapped(self):
        if isinstance(self, VecEnvWrapper):
            return self.venv.unwrapped
        else:
            return self

    def get_viewer(self):
        if self.viewer is None:
            from gym.envs.classic_control import rendering
            self.viewer = rendering.SimpleImageViewer()
        return self.viewer


def worker(remote, parent_remote, env_fn_wrapper):
    parent_remote.close()
    env = env_fn_wrapper.x()
    while True:
        cmd, data = remote.recv()
        if cmd == 'step':
            ob, reward, done, info = env.step(data)
            if 'bool' in done.__class__.__name__:
                if done:
                    ob = env.reset()
            else:
                if np.all(done):
                    ob = env.reset()

            remote.send((ob, reward, done, info))
        elif cmd == 'reset':
            ob = env.reset()
            remote.send((ob))
        elif cmd == 'render':
            if data == "rgb_array":
                fr = env.render(mode=data)
                remote.send(fr)
            elif data == "human":
                env.render(mode=data)
        elif cmd == 'reset_task':
            ob = env.reset_task()
            remote.send(ob)
        elif cmd == 'close':
            env.close()
            remote.close()
            break
        elif cmd == 'get_spaces':
            remote.send((env.observation_space, env.share_observation_space, env.action_space))
        else:
            raise NotImplementedError


class GuardSubprocVecEnv(ShareVecEnv):
    def __init__(self, env_fns, spaces=None):
        """
        envs: list of gym environments to run in subprocesses
        """
        self.waiting = False
        self.closed = False
        nenvs = len(env_fns)
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
        self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
                   for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
        for p in self.ps:
            p.daemon = False  # could cause zombie process
            p.start()
        for remote in self.work_remotes:
            remote.close()

        self.remotes[0].send(('get_spaces', None))
        observation_space, share_observation_space, action_space = self.remotes[0].recv()
        ShareVecEnv.__init__(self, len(env_fns), observation_space,
                             share_observation_space, action_space)

    def step_async(self, actions):

        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        self.waiting = True

    def step_wait(self):
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos

    def reset(self):
        for remote in self.remotes:
            remote.send(('reset', None))
        obs = [remote.recv() for remote in self.remotes]
        return np.stack(obs)

    def reset_task(self):
        for remote in self.remotes:
            remote.send(('reset_task', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def close(self):
        if self.closed:
            return
        if self.waiting:
            for remote in self.remotes:
                remote.recv()
        for remote in self.remotes:
            remote.send(('close', None))
        for p in self.ps:
            p.join()
        self.closed = True


class SubprocVecEnv(ShareVecEnv):
    def __init__(self, env_fns, spaces=None):
        """
        envs: list of gym environments to run in subprocesses
        """
        self.waiting = False
        self.closed = False
        nenvs = len(env_fns)
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
        self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
                   for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
        for p in self.ps:
            p.daemon = True  # if the main process crashes, we should not cause things to hang
            p.start()
        for remote in self.work_remotes:
            remote.close()

        self.remotes[0].send(('get_spaces', None))
        observation_space, share_observation_space, action_space = self.remotes[0].recv()
        ShareVecEnv.__init__(self, len(env_fns), observation_space,
                             share_observation_space, action_space)

    def step_async(self, actions):
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        self.waiting = True

    def step_wait(self):
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos

    def reset(self):
        for remote in self.remotes:
            remote.send(('reset', None))
        obs = [remote.recv() for remote in self.remotes]
        return np.stack(obs)


    def reset_task(self):
        for remote in self.remotes:
            remote.send(('reset_task', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def close(self):
        if self.closed:
            return
        if self.waiting:
            for remote in self.remotes:
                remote.recv()
        for remote in self.remotes:
            remote.send(('close', None))
        for p in self.ps:
            p.join()
        self.closed = True

    def render(self, mode="rgb_array"):
        for remote in self.remotes:
            remote.send(('render', mode))
        if mode == "rgb_array":   
            frame = [remote.recv() for remote in self.remotes]
            return np.stack(frame) 


def shareworker(remote, parent_remote, env_fn_wrapper):
    parent_remote.close()
    env = env_fn_wrapper.x()
    while True:
        cmd, data = remote.recv()
        if cmd == 'step':
            ob, s_ob, reward, done, info, available_actions = env.step(data)
            if 'bool' in done.__class__.__name__:
                if done:
                    ob, s_ob, available_actions = env.reset()
            else:
                if np.all(done):
                    ob, s_ob, available_actions = env.reset()

            remote.send((ob, s_ob, reward, done, info, available_actions))
        elif cmd == 'reset':
            ob, s_ob, available_actions = env.reset()
            remote.send((ob, s_ob, available_actions))
        elif cmd == 'reset_task':
            ob = env.reset_task()
            remote.send(ob)
        elif cmd == 'render':
            if data == "rgb_array":
                fr = env.render(mode=data)
                remote.send(fr)
            elif data == "human":
                env.render(mode=data)
        elif cmd == 'close':
            env.close()
            remote.close()
            break
        elif cmd == 'get_spaces':
            remote.send(
                (env.observation_space, env.share_observation_space, env.action_space))
        elif cmd == 'render_vulnerability':
            fr = env.render_vulnerability(data)
            remote.send((fr))
        elif cmd == 'get_num_agents':
            remote.send((env.n_agents))
        else:
            raise NotImplementedError


class ShareSubprocVecEnv(ShareVecEnv):
    def __init__(self, env_fns, spaces=None):
        """
        envs: list of gym environments to run in subprocesses
        """
        self.waiting = False
        self.closed = False
        nenvs = len(env_fns)
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
        self.ps = [Process(target=shareworker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
                   for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
        for p in self.ps:
            p.daemon = True  # if the main process crashes, we should not cause things to hang
            p.start()
        for remote in self.work_remotes:
            remote.close()
        self.remotes[0].send(('get_num_agents', None))
        self.n_agents = self.remotes[0].recv()
        self.remotes[0].send(('get_spaces', None))
        observation_space, share_observation_space, action_space = self.remotes[0].recv(
        )
        ShareVecEnv.__init__(self, len(env_fns), observation_space,
                             share_observation_space, action_space)

    def step_async(self, actions):
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        self.waiting = True

    def step_wait(self):
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        obs, share_obs, rews, dones, infos, available_actions = zip(*results)
        return np.stack(obs), np.stack(share_obs), np.stack(rews), np.stack(dones), infos, np.stack(available_actions)

    def reset(self):
        for remote in self.remotes:
            remote.send(('reset', None))
        results = [remote.recv() for remote in self.remotes]
        obs, share_obs, available_actions = zip(*results)
        return np.stack(obs), np.stack(share_obs), np.stack(available_actions)

    def reset_task(self):
        for remote in self.remotes:
            remote.send(('reset_task', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def close(self):
        if self.closed:
            return
        if self.waiting:
            for remote in self.remotes:
                remote.recv()
        for remote in self.remotes:
            remote.send(('close', None))
        for p in self.ps:
            p.join()
        self.closed = True


def choosesimpleworker(remote, parent_remote, env_fn_wrapper):
    parent_remote.close()
    env = env_fn_wrapper.x()
    while True:
        cmd, data = remote.recv()
        if cmd == 'step':
            ob, reward, done, info = env.step(data)
            remote.send((ob, reward, done, info))
        elif cmd == 'reset':
            ob = env.reset(data)
            remote.send(ob)
        elif cmd == 'reset_task':
            ob = env.reset_task()
            remote.send(ob)
        elif cmd == 'close':
            env.close()
            remote.close()
            break
        elif cmd == 'render':
            if data == "rgb_array":
                fr = env.render(mode=data)
                remote.send(fr)
            elif data == "human":
                env.render(mode=data)
        elif cmd == 'get_spaces':
            remote.send(
                (env.observation_space, env.share_observation_space, env.action_space))
        else:
            raise NotImplementedError


class ChooseSimpleSubprocVecEnv(ShareVecEnv):
    def __init__(self, env_fns, spaces=None):
        """
        env_fns: list of callables, each returning a gym environment to run in its own subprocess
        """
        self.waiting = False
        self.closed = False
        nenvs = len(env_fns)
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
        self.ps = [Process(target=choosesimpleworker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
                   for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
        for p in self.ps:
            p.daemon = True  # if the main process crashes, we should not cause things to hang
            p.start()
        for remote in self.work_remotes:
            remote.close()
        self.remotes[0].send(('get_spaces', None))
        observation_space, share_observation_space, action_space = self.remotes[0].recv()
        ShareVecEnv.__init__(self, len(env_fns), observation_space,
                             share_observation_space, action_space)

    def step_async(self, actions):
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        self.waiting = True

    def step_wait(self):
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos

    def reset(self, reset_choose):
        for remote, choose in zip(self.remotes, reset_choose):
            remote.send(('reset', choose))
        obs = [remote.recv() for remote in self.remotes]
        return np.stack(obs)

    def render(self, mode="rgb_array"):
        for remote in self.remotes:
            remote.send(('render', mode))
        if mode == "rgb_array":   
            frame = [remote.recv() for remote in self.remotes]
            return np.stack(frame)

    def reset_task(self):
        for remote in self.remotes:
            remote.send(('reset_task', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def close(self):
        if self.closed:
            return
        if self.waiting:
            for remote in self.remotes:
                remote.recv()
        for remote in self.remotes:
            remote.send(('close', None))
        for p in self.ps:
            p.join()
        self.closed = True


def chooseworker(remote, parent_remote, env_fn_wrapper):
    parent_remote.close()
    env = env_fn_wrapper.x()
    while True:
        cmd, data = remote.recv()
        if cmd == 'step':
            ob, s_ob, reward, done, info, available_actions = env.step(data)
            remote.send((ob, s_ob, reward, done, info, available_actions))
        elif cmd == 'reset':
            ob, s_ob, available_actions = env.reset(data)
            remote.send((ob, s_ob, available_actions))
        elif cmd == 'reset_task':
            ob = env.reset_task()
            remote.send(ob)
        elif cmd == 'close':
            env.close()
            remote.close()
            break
        elif cmd == 'render':
            remote.send(env.render(mode='rgb_array'))
        elif cmd == 'get_spaces':
            remote.send(
                (env.observation_space, env.share_observation_space, env.action_space))
        else:
            raise NotImplementedError


class ChooseSubprocVecEnv(ShareVecEnv):
    def __init__(self, env_fns, spaces=None):
        """
        env_fns: list of callables, each returning a gym environment to run in its own subprocess
        """
        self.waiting = False
        self.closed = False
        nenvs = len(env_fns)
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
        self.ps = [Process(target=chooseworker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
                   for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
        for p in self.ps:
            p.daemon = True  # if the main process crashes, we should not cause things to hang
            p.start()
        for remote in self.work_remotes:
            remote.close()
        self.remotes[0].send(('get_spaces', None))
        observation_space, share_observation_space, action_space = self.remotes[0].recv()
        ShareVecEnv.__init__(self, len(env_fns), observation_space,
                             share_observation_space, action_space)

    def step_async(self, actions):
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        self.waiting = True

    def step_wait(self):
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        obs, share_obs, rews, dones, infos, available_actions = zip(*results)
        return np.stack(obs), np.stack(share_obs), np.stack(rews), np.stack(dones), infos, np.stack(available_actions)

    def reset(self, reset_choose):
        for remote, choose in zip(self.remotes, reset_choose):
            remote.send(('reset', choose))
        results = [remote.recv() for remote in self.remotes]
        obs, share_obs, available_actions = zip(*results)
        return np.stack(obs), np.stack(share_obs), np.stack(available_actions)

    def reset_task(self):
        for remote in self.remotes:
            remote.send(('reset_task', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def close(self):
        if self.closed:
            return
        if self.waiting:
            for remote in self.remotes:
                remote.recv()
        for remote in self.remotes:
            remote.send(('close', None))
        for p in self.ps:
            p.join()
        self.closed = True

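# --- Illustrative note on the Choose* wrappers (added comment; not part of the original file) ---
# Unlike ShareSubprocVecEnv.reset(), the Choose* variants forward a per-env
# flag to each worker's reset(data), so a caller can restart only selected
# environments. A sketch with placeholder names:
#
#   reset_choose = np.ones(num_envs, dtype=bool)   # restart every env
#   obs, share_obs, available_actions = choose_vec_env.reset(reset_choose)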

def chooseguardworker(remote, parent_remote, env_fn_wrapper):
    parent_remote.close()
    env = env_fn_wrapper.x()
    while True:
        cmd, data = remote.recv()
        if cmd == 'step':
            ob, reward, done, info = env.step(data)
            remote.send((ob, reward, done, info))
        elif cmd == 'reset':
            ob = env.reset(data)
            remote.send(ob)
        elif cmd == 'reset_task':
            ob = env.reset_task()
            remote.send(ob)
        elif cmd == 'close':
            env.close()
            remote.close()
            break
        elif cmd == 'get_spaces':
            remote.send(
                (env.observation_space, env.share_observation_space, env.action_space))
        else:
            raise NotImplementedError


class ChooseGuardSubprocVecEnv(ShareVecEnv):
    def __init__(self, env_fns, spaces=None):
        """
        env_fns: list of callables, each returning a gym environment to run in its own subprocess
        """
        self.waiting = False
        self.closed = False
        nenvs = len(env_fns)
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
        self.ps = [Process(target=chooseguardworker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
                   for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
        for p in self.ps:
            p.daemon = False  # non-daemonic workers: they are not terminated automatically if the main process exits
            p.start()
        for remote in self.work_remotes:
            remote.close()
        self.remotes[0].send(('get_spaces', None))
        observation_space, share_observation_space, action_space = self.remotes[0].recv()
        ShareVecEnv.__init__(self, len(env_fns), observation_space,
                             share_observation_space, action_space)

    def step_async(self, actions):
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        self.waiting = True

    def step_wait(self):
        results = [remote.recv() for remote in self.remotes]
        self.waiting = False
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos

    def reset(self, reset_choose):
        for remote, choose in zip(self.remotes, reset_choose):
            remote.send(('reset', choose))
        obs = [remote.recv() for remote in self.remotes]
        return np.stack(obs)

    def reset_task(self):
        for remote in self.remotes:
            remote.send(('reset_task', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def close(self):
        if self.closed:
            return
        if self.waiting:
            for remote in self.remotes:
                remote.recv()
        for remote in self.remotes:
            remote.send(('close', None))
        for p in self.ps:
            p.join()
        self.closed = True


# in-process wrappers: envs are stepped directly in the main process (no subprocess workers)
class DummyVecEnv(ShareVecEnv):
    def __init__(self, env_fns):
        self.envs = [fn() for fn in env_fns]
        env = self.envs[0]
        ShareVecEnv.__init__(self, len(env_fns), env.observation_space,
                             env.share_observation_space, env.action_space)
        self.actions = None

    def step_async(self, actions):
        self.actions = actions

    def step_wait(self):
        results = [env.step(a) for (a, env) in zip(self.actions, self.envs)]
        obs, rews, dones, infos = map(np.array, zip(*results))

        # Auto-reset: any environment that reports done is reset in place, so the
        # returned obs for that index is the first observation of a new episode.
        # The class-name check handles both a scalar done flag and a per-agent array.
        for (i, done) in enumerate(dones):
            if 'bool' in done.__class__.__name__:
                if done:
                    obs[i] = self.envs[i].reset()
            else:
                if np.all(done):
                    obs[i] = self.envs[i].reset()

        self.actions = None
        return obs, rews, dones, infos

    def reset(self):
        obs = [env.reset() for env in self.envs]
        return np.array(obs)

    def close(self):
        for env in self.envs:
            env.close()

    def render(self, mode="human"):
        if mode == "rgb_array":
            return np.array([env.render(mode=mode) for env in self.envs])
        elif mode == "human":
            for env in self.envs:
                env.render(mode=mode)
        else:
            raise NotImplementedError



class ShareDummyVecEnv(ShareVecEnv):
    def __init__(self, env_fns):
        self.envs = [fn() for fn in env_fns]
        env = self.envs[0]
        ShareVecEnv.__init__(self, len(env_fns), env.observation_space,
                             env.share_observation_space, env.action_space)
        self.actions = None

    def step_async(self, actions):
        self.actions = actions

    def step_wait(self):
        results = [env.step(a) for (a, env) in zip(self.actions, self.envs)]
        obs, share_obs, rews, dones, infos, available_actions = map(
            np.array, zip(*results))

        for (i, done) in enumerate(dones):
            if 'bool' in done.__class__.__name__:
                if done:
                    obs[i], share_obs[i], available_actions[i] = self.envs[i].reset()
            else:
                if np.all(done):
                    obs[i], share_obs[i], available_actions[i] = self.envs[i].reset()
        self.actions = None

        return obs, share_obs, rews, dones, infos, available_actions

    def reset(self):
        results = [env.reset() for env in self.envs]
        obs, share_obs, available_actions = map(np.array, zip(*results))
        return obs, share_obs, available_actions

    def close(self):
        for env in self.envs:
            env.close()
    
    def render(self, mode="human"):
        if mode == "rgb_array":
            return np.array([env.render(mode=mode) for env in self.envs])
        elif mode == "human":
            for env in self.envs:
                env.render(mode=mode)
        else:
            raise NotImplementedError


class ChooseDummyVecEnv(ShareVecEnv):
    def __init__(self, env_fns):
        self.envs = [fn() for fn in env_fns]
        env = self.envs[0]
        ShareVecEnv.__init__(self, len(env_fns), env.observation_space,
                             env.share_observation_space, env.action_space)
        self.actions = None

    def step_async(self, actions):
        self.actions = actions

    def step_wait(self):
        results = [env.step(a) for (a, env) in zip(self.actions, self.envs)]
        obs, share_obs, rews, dones, infos, available_actions = map(
            np.array, zip(*results))
        self.actions = None
        return obs, share_obs, rews, dones, infos, available_actions

    def reset(self, reset_choose):
        results = [env.reset(choose)
                   for (env, choose) in zip(self.envs, reset_choose)]
        obs, share_obs, available_actions = map(np.array, zip(*results))
        return obs, share_obs, available_actions

    def close(self):
        for env in self.envs:
            env.close()

    def render(self, mode="human"):
        if mode == "rgb_array":
            return np.array([env.render(mode=mode) for env in self.envs])
        elif mode == "human":
            for env in self.envs:
                env.render(mode=mode)
        else:
            raise NotImplementedError

class ChooseSimpleDummyVecEnv(ShareVecEnv):
    def __init__(self, env_fns):
        self.envs = [fn() for fn in env_fns]
        env = self.envs[0]
        ShareVecEnv.__init__(self, len(env_fns), env.observation_space,
                             env.share_observation_space, env.action_space)
        self.actions = None

    def step_async(self, actions):
        self.actions = actions

    def step_wait(self):
        results = [env.step(a) for (a, env) in zip(self.actions, self.envs)]
        obs, rews, dones, infos = map(np.array, zip(*results))
        self.actions = None
        return obs, rews, dones, infos

    def reset(self, reset_choose):
        obs = [env.reset(choose)
               for (env, choose) in zip(self.envs, reset_choose)]
        return np.array(obs)

    def close(self):
        for env in self.envs:
            env.close()

    def render(self, mode="human"):
        if mode == "rgb_array":
            return np.array([env.render(mode=mode) for env in self.envs])
        elif mode == "human":
            for env in self.envs:
                env.render(mode=mode)
        else:
            raise NotImplementedError

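# --- Illustrative selection sketch (added comment; not part of the original file) ---
# Training scripts typically run envs in-process for a single rollout thread
# and in subprocesses otherwise. The helper below is a hypothetical sketch of
# that choice, not a function defined in this file:
#
#   def make_vec_env(env_fns, n_rollout_threads):
#       if n_rollout_threads == 1:
#           return ShareDummyVecEnv(env_fns)    # envs stepped in the main process
#       return ShareSubprocVecEnv(env_fns)      # one worker process per env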

================================================
FILE: envs/ma_mujoco/__init__.py
================================================


================================================
FILE: envs/ma_mujoco/multiagent_mujoco/__init__.py
================================================
from .mujoco_multi import MujocoMulti
from .coupled_half_cheetah import CoupledHalfCheetah
from .manyagent_swimmer import ManyAgentSwimmerEnv
from .manyagent_ant import ManyAgentAntEnv


================================================
FILE: envs/ma_mujoco/multiagent_mujoco/assets/.gitignore
================================================
*.auto.xml


================================================
FILE: envs/ma_mujoco/multiagent_mujoco/assets/__init__.py
================================================


================================================
FILE: envs/ma_mujoco/multiagent_mujoco/assets/coupled_half_cheetah.xml
================================================
<!-- Cheetah Model
    The state space is populated with joints in the order that they are
    defined in this file. The actuators also operate on joints.
    State-Space (name/joint/parameter):
        - rootx     slider      position (m)
        - rootz     slider      position (m)
        - rooty     hinge       angle (rad)
        - bthigh    hinge       angle (rad)
        - bshin     hinge       angle (rad)
        - bfoot     hinge       angle (rad)
        - fthigh    hinge       angle (rad)
        - fshin     hinge       angle (rad)
        - ffoot     hinge       angle (rad)
        - rootx     slider      velocity (m/s)
        - rootz     slider      velocity (m/s)
        - rooty     hinge       angular velocity (rad/s)
        - bthigh    hinge       angular velocity (rad/s)
        - bshin     hinge       angular velocity (rad/s)
        - bfoot     hinge       angular velocity (rad/s)
        - fthigh    hinge       angular velocity (rad/s)
        - fshin     hinge       angular velocity (rad/s)
        - ffoot     hinge       angular velocity (rad/s)
    Actuators (name/actuator/parameter):
        - bthigh    hinge       torque (N m)
        - bshin     hinge       torque (N m)
        - bfoot     hinge       torque (N m)
        - fthigh    hinge       torque (N m)
        - fshin     hinge       torque (N m)
        - ffoot     hinge       torque (N m)
-->
<mujoco model="cheetah">
  <compiler angle="radian" coordinate="local" inertiafromgeom="true" settotalmass="14"/>
  <default>
    <joint armature=".1" damping=".01" limited="true" solimplimit="0 .8 .03" solreflimit=".02 1" stiffness="8"/>
    <geom conaffinity="0" condim="3" contype="1" friction=".4 .1 .1" rgba="0.8 0.6 .4 1" solimp="0.0 0.8 0.01" solref="0.02 1"/>
    <motor ctrllimited="true" ctrlrange="-1 1"/>
  </default>
  <size nstack="300000" nuser_geom="1"/>
  <option gravity="0 0 -9.81" timestep="0.01"/>
  <asset>
    <texture builtin="gradient" height="100" rgb1="1 1 1" rgb2="0 0 0" type="skybox" width="100"/>
    <texture builtin="flat" height="1278" mark="cross" markrgb="1 1 1" name="texgeom" random="0.01" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" type="cube" width="127"/>
    <texture builtin="checker" height="100" name="texplane" rgb1="0 0 0" rgb2="0.8 0.8 0.8" type="2d" width="100"/>
    <material name="MatPlane" reflectance="0.5" shininess="1" specular="1" texrepeat="60 60" texture="texplane"/>
    <material name="geom" texture="texgeom" texuniform="true"/>
  </asset>
  <worldbody>
    <light cutoff="100" diffuse="1 1 1" dir="-0 0 -1.3" directional="true" exponent="1" pos="0 0 1.3" specular=".1 .1 .1"/>
    <geom conaffinity="1" condim="3" material="MatPlane" name="floor" pos="0 0 0" rgba="0.8 0.9 0.8 1" size="40 40 40" type="plane"/>
    <body name="torso" pos="0 -1 .7">
      <site name="t1" pos="0.0 0 0" size="0.1"/>
      <camera name="track" mode="trackcom" pos="0 -3 0.3" xyaxes="1 0 0 0 0 1"/>
      <joint armature="0" axis="1 0 0" damping="0" limited="false" name="rootx" pos="0 0 0" stiffness="0" type="slide"/>
      <joint armature="0" axis="0 0 1" damping="0" limited="false" name="rootz" pos="0 0 0" stiffness="0" type="slide"/>
      <joint armature="0" axis="0 1 0" damping="0" limited="false" name="rooty" pos="0 0 0" stiffness="0" type="hinge"/>
      <geom fromto="-.5 0 0 .5 0 0" name="torso" size="0.046" type="capsule"/>
      <geom axisangle="0 1 0 .87" name="head" pos=".6 0 .1" size="0.046 .15" type="capsule"/>
      <!-- <site name='tip'  pos='.15 0 .11'/>-->
      <body name="bthigh" pos="-.5 0 0">
        <joint axis="0 1 0" damping="6" name="bthigh" pos="0 0 0" range="-.52 1.05" stiffness="240" type="hinge"/>
        <geom axisangle="0 1 0 -3.8" name="bthigh" pos=".1 0 -.13" size="0.046 .145" type="capsule"/>
        <body name="bshin" pos=".16 0 -.25">
          <joint axis="0 1 0" damping="4.5" name="bshin" pos="0 0 0" range="-.785 .785" stiffness="180" type="hinge"/>
          <geom axisangle="0 1 0 -2.03" name="bshin" pos="-.14 0 -.07" rgba="0.9 0.6 0.6 1" size="0.046 .15" type="capsule"/>
          <body name="bfoot" pos="-.28 0 -.14">
            <joint axis="0 1 0" damping="3" name="bfoot" pos="0 0 0" range="-.4 .785" stiffness="120" type="hinge"/>
            <geom axisangle="0 1 0 -.27" name="bfoot" pos=".03 0 -.097" rgba="0.9 0.6 0.6 1" size="0.046 .094" type="capsule"/>
          </body>
        </body>
      </body>
      <body name="fthigh" pos=".5 0 0">
        <joint axis="0 1 0" damping="4.5" name="fthigh" pos="0 0 0" range="-1 .7" stiffness="180" type="hinge"/>
        <geom axisangle="0 1 0 .52" name="fthigh" pos="-.07 0 -.12" size="0.046 .133" type="capsule"/>
        <body name="fshin" pos="-.14 0 -.24">
          <joint axis="0 1 0" damping="3" name="fshin" pos="0 0 0" range="-1.2 .87" stiffness="120" type="hinge"/>
          <geom axisangle="0 1 0 -.6" name="fshin" pos=".065 0 -.09" rgba="0.9 0.6 0.6 1" size="0.046 .106" type="capsule"/>
          <body name="ffoot" pos=".13 0 -.18">
            <joint axis="0 1 0" damping="1.5" name="ffoot" pos="0 0 0" range="-.5 .5" stiffness="60" type="hinge"/>
            <geom axisangle="0 1 0 -.6" name="ffoot" pos=".045 0 -.07" rgba="0.9 0.6 0.6 1" size="0.046 .07" type="capsule"/>
          </body>
        </body>
      </body>
    </body>
    <!-- second cheetah definition -->
    <body name="torso2" pos="0 1 .7">
      <site name="t2" pos="0 0 0" size="0.1"/>
      <camera name="track2" mode="trackcom" pos="0 -3 0.3" xyaxes="1 0 0 0 0 1"/>
      <joint armature="0" axis="1 0 0" damping="0" limited="false" name="rootx2" pos="0 0 0" stiffness="0" type="slide"/>
      <joint armature="0" axis="0 0 1" damping="0" limited="false" name="rootz2" pos="0 0 0" stiffness="0" type="slide"/>
      <joint armature="0" axis="0 1 0" damping="0" limited="false" name="rooty2" pos="0 0 0" stiffness="0" type="hinge"/>
      <geom fromto="-.5 0 0 .5 0 0" name="torso2" size="0.046" type="capsule"/>
      <geom axisangle="0 1 0 .87" name="head2" pos=".6 0 .1" size="0.046 .15" type="capsule"/>
      <!-- <site name='tip'  pos='.15 0 .11'/>-->
      <body name="bthigh2" pos="-.5 0 0">
        <joint axis="0 1 0" damping="6" name="bthigh2" pos="0 0 0" range="-.52 1.05" stiffness="240" type="hinge"/>
        <geom axisangle="0 1 0 -3.8" name="bthigh2" pos=".1 0 -.13" size="0.046 .145" type="capsule"/>
        <body name="bshin2" pos=".16 0 -.25">
          <joint axis="0 1 0" damping="4.5" name="bshin2" pos="0 0 0" range="-.785 .785" stiffness="180" type="hinge"/>
          <geom axisangle="0 1 0 -2.03" name="bshin2" pos="-.14 0 -.07" rgba="0.9 0.6 0.6 1" size="0.046 .15" type="capsule"/>
          <body name="bfoot2" pos="-.28 0 -.14">
            <joint axis="0 1 0" damping="3" name="bfoot2" pos="0 0 0" range="-.4 .785" stiffness="120" type="hinge"/>
            <geom axisangle="0 1 0 -.27" name="bfoot2" pos=".03 0 -.097" rgba="0.9 0.6 0.6 1" size="0.046 .094" type="capsule"/>
          </body>
        </body>
      </body>
      <body name="fthigh2" pos=".5 0 0">
        <joint axis="0 1 0" damping="4.5" name="fthigh2" pos="0 0 0" range="-1 .7" stiffness="180" type="hinge"/>
        <geom axisangle="0 1 0 .52" name="fthigh2" pos="-.07 0 -.12" size="0.046 .133" type="capsule"/>
        <body name="fshin2" pos="-.14 0 -.24">
          <joint axis="0 1 0" damping="3" name="fshin2" pos="0 0 0" range="-1.2 .87" stiffness="120" type="hinge"/>
          <geom axisangle="0 1 0 -.6" name="fshin2" pos=".065 0 -.09" rgba="0.9 0.6 0.6 1" size="0.046 .106" type="capsule"/>
          <body name="ffoot2" pos=".13 0 -.18">
            <joint axis="0 1 0" damping="1.5" name="ffoot2" pos="0 0 0" range="-.5 .5" stiffness="60" type="hinge"/>
            <geom axisangle="0 1 0 -.6" name="ffoot2" pos=".045 0 -.07" rgba="0.9 0.6 0.6 1" size="0.046 .07" type="capsule"/>
          </body>
        </body>
      </body>
    </body>
  </worldbody>
  <tendon>
    <spatial name="tendon1" width="0.05" rgba=".95 .3 .3 1" limited="true" range="1.5 3.5" stiffness="0.1">
        <site site="t1"/>
        <site site="t2"/>
    </spatial>
  </tendon>
  <actuator>
    <motor gear="120" joint="bthigh" name="bthigh"/>
    <motor gear="90" joint="bshin" name="bshin"/>
    <motor gear="60" joint="bfoot" name="bfoot"/>
    <motor gear="120" joint="fthigh" name="fthigh"/>
    <motor gear="60" joint="fshin" name="fshin"/>
    <motor gear="30" joint="ffoot" name="ffoot"/>
    <motor gear="120" joint="bthigh2" name="bthigh2"/>
    <motor gear="90" joint="bshin2" name="bshin2"/>
    <motor gear="60" joint="bfoot2" name="bfoot2"/>
    <motor gear="120" joint="fthigh2" name="fthigh2"/>
    <motor gear="60" joint="fshin2" name="fshin2"/>
    <motor gear="30" joint="ffoot2" name="ffoot2"/>
  </actuator>
</mujoco>

================================================
FILE: envs/ma_mujoco/multiagent_mujoco/assets/manyagent_ant.xml
================================================
<mujoco model="ant">
  <size nconmax="200"/>
  <compiler angle="degree" coordinate="local" inertiafromgeom="true"/>
  <option integrator="RK4" timestep="0.01"/>
  <custom>
    <numeric data="0.0 0.0 0.55 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -1.0 0.0 -1.0 0.0 1.0" name="init_qpos"/>
  </custom>
  <default>
    <joint armature="1" damping="1" limited="true"/>
    <geom conaffinity="0" condim="3" density="5.0" friction="1 0.5 0.5" margin="0.01" rgba="0.8 0.6 0.4 1"/>
  </default>
  <asset>
    <texture builtin="gradient" height="100" rgb1="1 1 1" rgb2="0 0 0" type="skybox" width="100"/>
    <texture builtin="flat" height="1278" mark="cross" markrgb="1 1 1" name="texgeom" random="0.01" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" type="cube" width="127"/>
    <texture builtin="checker" height="100" name="texplane" rgb1="0 0 0" rgb2="0.8 0.8 0.8" type="2d" width="100"/>
    <material name="MatPlane" reflectance="0.5" shininess="1" specular="1" texrepeat="60 60" texture="texplane"/>
    <material name="geom" texture="texgeom" texuniform="true"/>
  </asset>
  <worldbody>
    <light cutoff="100" diffuse="1 1 1" dir="-0 0 -1.3" directional="true" exponent="1" pos="0 0 1.3" specular=".1 .1 .1"/>
    <geom conaffinity="1" condim="3" material="MatPlane" name="floor" pos="0 0 0" rgba="0.8 0.9 0.8 1" size="40 40 40" type="plane"/>
    <body name="torso" pos="0 0 0.75">
      <camera name="track" mode="trackcom" pos="0 -3 0.3" xyaxes="1 0 0 0 0 1"/>
      <!--<geom name="torso_geom" pos="0 0 0" size="0.25" type="sphere"/>-->
      <joint armature="0" damping="0" limited="false" margin="0.01" name="root" pos="0 0 0" type="free"/>
      <body name="front_left_leg" pos="0 0 0">
        <geom fromto="0.0 0.0 0.0 0.2 0.2 0.0" name="aux_1_geom" size="0.08" type="capsule"/>
        <body name="aux_1" pos="0.2 0.2 0">
          <joint axis="0 0 1" name="hip_1" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
          <geom fromto="0.0 0.0 0.0 0.2 0.2 0.0" name="left_leg_geom" size="0.08" type="capsule"/>
          <body pos="0.2 0.2 0">
            <joint axis="-1 1 0" name="ankle_1" pos="0.0 0.0 0.0" range="30 70" type="hinge"/>
            <geom fromto="0.0 0.0 0.0 0.4 0.4 0.0" name="left_ankle_geom" size="0.08" type="capsule"/>
          </body>
        </body>
      </body>
      <body name="right_back_leg" pos="0 0 0">
        <geom fromto="0.0 0.0 0.0 0.2 -0.2 0.0" name="aux_4_geom" size="0.08" type="capsule"/>
        <body name="aux_4" pos="0.2 -0.2 0">
          <joint axis="0 0 1" name="hip_4" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
          <geom fromto="0.0 0.0 0.0 0.2 -0.2 0.0" name="rightback_leg_geom" size="0.08" type="capsule"/>
          <body pos="0.2 -0.2 0">
            <joint axis="1 1 0" name="ankle_4" pos="0.0 0.0 0.0" range="30 70" type="hinge"/>
            <geom fromto="0.0 0.0 0.0 0.4 -0.4 0.0" name="fourth_ankle_geom" size="0.08" type="capsule"/>
          </body>
        </body>
      </body>
      <body name="midx" pos="0.0 0 0">
        <geom density="1000" fromto="0 0 0 -1 0 0" size="0.1" type="capsule"/>
        <!--<joint axis="0 0 1" limited="true" name="rot2" pos="0 0 0" range="-100 100" type="hinge"/>-->
        <body name="front_right_legx" pos="-1 0 0">
          <geom fromto="0.0 0.0 0.0 0.0 0.2 0.0" name="aux_2_geomx" size="0.08" type="capsule"/>
          <body name="aux_2x" pos="0.0 0.2 0">
            <joint axis="0 0 1" name="hip_2x" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
            <geom fromto="0.0 0.0 0.0 -0.2 0.2 0.0" name="right_leg_geomx" size="0.08" type="capsule"/>
            <body pos="-0.2 0.2 0">
              <joint axis="1 1 0" name="ankle_2x" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
              <geom fromto="0.0 0.0 0.0 -0.4 0.4 0.0" name="right_ankle_geomx" size="0.08" type="capsule"/>
            </body>
          </body>
        </body>
        <body name="back_legx" pos="-1 0 0">
          <geom fromto="0.0 0.0 0.0 0.0 -0.2 0.0" name="aux_3_geomx" size="0.08" type="capsule"/>
          <body name="aux_3x" pos="0.0 -0.2 0">
            <joint axis="0 0 1" name="hip_3x" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
            <geom fromto="0.0 0.0 0.0 -0.2 -0.2 0.0" name="back_leg_geomx" size="0.08" type="capsule"/>
            <body pos="-0.2 -0.2 0">
              <joint axis="-1 1 0" name="ankle_3x" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
              <geom fromto="0.0 0.0 0.0 -0.4 -0.4 0.0" name="third_ankle_geomx" size="0.08" type="capsule"/>
            </body>
          </body>
        </body>
        <body name="mid" pos="-1 0 0">
          <geom density="1000" fromto="0 0 0 -1 0 0" size="0.1" type="capsule"/>
          <!--<joint axis="0 0 1" limited="true" name="rot2" pos="0 0 0" range="-100 100" type="hinge"/>-->
          <!--<body name="front_right_leg" pos="-1 0 0">
            <geom fromto="0.0 0.0 0.0 -0.2 0.2 0.0" name="aux_2_geom" size="0.08" type="capsule"/>
            <body name="aux_2" pos="-0.2 0.2 0">
              <joint axis="0 0 1" name="hip_2" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
              <geom fromto="0.0 0.0 0.0 -0.2 0.2 0.0" name="right_leg_geom" size="0.08" type="capsule"/>
              <body pos="-0.2 0.2 0">
                <joint axis="1 1 0" name="ankle_2" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
                <geom fromto="0.0 0.0 0.0 -0.4 0.4 0.0" name="right_ankle_geom" size="0.08" type="capsule"/>
              </body>
            </body>
          </body>
          <body name="back_leg" pos="-1 0 0">
            <geom fromto="0.0 0.0 0.0 -0.2 -0.2 0.0" name="aux_3_geom" size="0.08" type="capsule"/>
            <body name="aux_3" pos="-0.2 -0.2 0">
              <joint axis="0 0 1" name="hip_3" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
              <geom fromto="0.0 0.0 0.0 -0.2 -0.2 0.0" name="back_leg_geom" size="0.08" type="capsule"/>
              <body pos="-0.2 -0.2 0">
                <joint axis="-1 1 0" name="ankle_3" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
                <geom fromto="0.0 0.0 0.0 -0.4 -0.4 0.0" name="third_ankle_geom" size="0.08" type="capsule"/>
              </body>
            </body>
          </body>-->
          <body name="front_right_leg" pos="-1 0 0">
            <geom fromto="0.0 0.0 0.0 0.0 0.2 0.0" name="aux_2_geom" size="0.08" type="capsule"/>
            <body name="aux_2" pos="0.0 0.2 0">
              <joint axis="0 0 1" name="hip_2" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
              <geom fromto="0.0 0.0 0.0 -0.2 0.2 0.0" name="right_leg_geom" size="0.08" type="capsule"/>
              <body pos="-0.2 0.2 0">
                <joint axis="1 1 0" name="ankle_2" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
                <geom fromto="0.0 0.0 0.0 -0.4 0.4 0.0" name="right_ankle_geom" size="0.08" type="capsule"/>
              </body>
            </body>
          </body>
          <body name="back_leg" pos="-1 0 0">
            <geom fromto="0.0 0.0 0.0 0.0 -0.2 0.0" name="aux_3_geom" size="0.08" type="capsule"/>
            <body name="aux_3" pos="0.0 -0.2 0">
              <joint axis="0 0 1" name="hip_3" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
              <geom fromto="0.0 0.0 0.0 -0.2 -0.2 0.0" name="back_leg_geom" size="0.08" type="capsule"/>
              <body pos="-0.2 -0.2 0">
                <joint axis="-1 1 0" name="ankle_3" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
                <geom fromto="0.0 0.0 0.0 -0.4 -0.4 0.0" name="third_ankle_geom" size="0.08" type="capsule"/>
              </body>
            </body>
          </body>
        </body>
      </body>
    </body>
  </worldbody>
  <actuator>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_4" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_4" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_1" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_1" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_2" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_2" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_3" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_3" gear="150"/>
  </actuator>
</mujoco>

================================================
FILE: envs/ma_mujoco/multiagent_mujoco/assets/manyagent_ant.xml.template
================================================
<mujoco model="ant">
  <size nconmax="200"/>
  <compiler angle="degree" coordinate="local" inertiafromgeom="true"/>
  <option integrator="RK4" timestep="0.005"/>
  <custom>
    <numeric data="0.0 0.0 0.55 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -1.0 0.0 -1.0 0.0 1.0" name="init_qpos"/>
  </custom>
  <default>
    <joint armature="1" damping="1" limited="true"/>
    <geom conaffinity="0" condim="3" density="5.0" friction="1 0.5 0.5" margin="0.01" rgba="0.8 0.6 0.4 1"/>
  </default>
  <asset>
    <texture builtin="gradient" height="100" rgb1="1 1 1" rgb2="0 0 0" type="skybox" width="100"/>
    <texture builtin="flat" height="1278" mark="cross" markrgb="1 1 1" name="texgeom" random="0.01" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" type="cube" width="127"/>
    <texture builtin="checker" height="100" name="texplane" rgb1="0 0 0" rgb2="0.8 0.8 0.8" type="2d" width="100"/>
    <material name="MatPlane" reflectance="0.5" shininess="1" specular="1" texrepeat="60 60" texture="texplane"/>
    <material name="geom" texture="texgeom" texuniform="true"/>
  </asset>
  <worldbody>
    <light cutoff="100" diffuse="1 1 1" dir="-0 0 -1.3" directional="true" exponent="1" pos="0 0 1.3" specular=".1 .1 .1"/>
    <geom conaffinity="1" condim="3" material="MatPlane" name="floor" pos="0 0 0" rgba="0.8 0.9 0.8 1" size="40 40 40" type="plane"/>
    <body name="torso_0" pos="0 0 0.75">
      <camera name="track" mode="trackcom" pos="0 -3 0.3" xyaxes="1 0 0 0 0 1"/>
      <!--<geom density="1000" fromto="0 0 0 -1 0 0" size="0.1" type="capsule"/>-->
      <joint armature="0" damping="0" limited="false" margin="0.01" name="root" pos="0 0 0" type="free"/>
      <body name="front_left_leg_0" pos="0 0 0">
        <geom fromto="0.0 0.0 0.0 0.2 0.2 0.0" name="aux1_geom_0" size="0.08" type="capsule"/>
        <body name="aux1_0" pos="0.2 0.2 0">
          <joint axis="0 0 1" name="hip1_0" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
          <geom fromto="0.0 0.0 0.0 0.2 0.2 0.0" name="left_leg_geom_0" size="0.08" type="capsule"/>
          <body pos="0.2 0.2 0">
            <joint axis="-1 1 0" name="ankle1_0" pos="0.0 0.0 0.0" range="30 70" type="hinge"/>
            <geom fromto="0.0 0.0 0.0 0.4 0.4 0.0" name="left_ankle_geom_0" size="0.08" type="capsule"/>
          </body>
        </body>
      </body>
      <body name="right_back_leg_0" pos="0 0 0">
        <geom fromto="0.0 0.0 0.0 0.2 -0.2 0.0" name="aux2_geom_0" size="0.08" type="capsule"/>
        <body name="aux2_0" pos="0.2 -0.2 0">
          <joint axis="0 0 1" name="hip2_0" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
          <geom fromto="0.0 0.0 0.0 0.2 -0.2 0.0" name="rightback_leg_geom_0" size="0.08" type="capsule"/>
          <body pos="0.2 -0.2 0">
            <joint axis="1 1 0" name="ankle2_0" pos="0.0 0.0 0.0" range="30 70" type="hinge"/>
            <geom fromto="0.0 0.0 0.0 0.4 -0.4 0.0" name="second_ankle_geom_0" size="0.08" type="capsule"/>
          </body>
        </body>
      </body>
      {{ body }}
    </body>
  </worldbody>
  <actuator>
    {{ actuators }}
  </actuator>
</mujoco>

================================================
FILE: envs/ma_mujoco/multiagent_mujoco/assets/manyagent_ant__stage1.xml
================================================
<mujoco model="ant">
  <compiler angle="degree" coordinate="local" inertiafromgeom="true"/>
  <option integrator="RK4" timestep="0.01"/>
  <custom>
    <numeric data="0.0 0.0 0.55 1.0 0.0 0.0 0.0 0.0 1.0 0.0 -1.0 0.0 -1.0 0.0 1.0" name="init_qpos"/>
  </custom>
  <default>
    <joint armature="1" damping="1" limited="true"/>
    <geom conaffinity="0" condim="3" density="5.0" friction="1 0.5 0.5" margin="0.01" rgba="0.8 0.6 0.4 1"/>
  </default>
  <asset>
    <texture builtin="gradient" height="100" rgb1="1 1 1" rgb2="0 0 0" type="skybox" width="100"/>
    <texture builtin="flat" height="1278" mark="cross" markrgb="1 1 1" name="texgeom" random="0.01" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" type="cube" width="127"/>
    <texture builtin="checker" height="100" name="texplane" rgb1="0 0 0" rgb2="0.8 0.8 0.8" type="2d" width="100"/>
    <material name="MatPlane" reflectance="0.5" shininess="1" specular="1" texrepeat="60 60" texture="texplane"/>
    <material name="geom" texture="texgeom" texuniform="true"/>
  </asset>
  <worldbody>
    <light cutoff="100" diffuse="1 1 1" dir="-0 0 -1.3" directional="true" exponent="1" pos="0 0 1.3" specular=".1 .1 .1"/>
    <geom conaffinity="1" condim="3" material="MatPlane" name="floor" pos="0 0 0" rgba="0.8 0.9 0.8 1" size="40 40 40" type="plane"/>
    <body name="torso" pos=" 0 0.75">
      <camera name="track" mode="trackcom" pos="0 -3 0.3" xyaxes="1 0 0 0 0 1"/>
      <!--<geom name="torso_geom" pos="0 0 0" size="0.25" type="sphere"/>-->
      <joint armature="0" damping="0" limited="false" margin="0.01" name="root" pos="0 0 0" type="free"/>
      <body name="front_left_leg" pos="0 0 0">
        <geom fromto="0.0 0.0 0.0 0.2 0.2 0.0" name="aux_1_geom" size="0.08" type="capsule"/>
        <body name="aux_1" pos="0.2 0.2 0">
          <joint axis="0 0 1" name="hip_1" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
          <geom fromto="0.0 0.0 0.0 0.2 0.2 0.0" name="left_leg_geom" size="0.08" type="capsule"/>
          <body pos="0.2 0.2 0">
            <joint axis="-1 1 0" name="ankle_1" pos="0.0 0.0 0.0" range="30 70" type="hinge"/>
            <geom fromto="0.0 0.0 0.0 0.4 0.4 0.0" name="left_ankle_geom" size="0.08" type="capsule"/>
          </body>
        </body>
      </body>
      <body name="right_back_leg" pos="0 0 0">
        <geom fromto="0.0 0.0 0.0 0.2 -0.2 0.0" name="aux_4_geom" size="0.08" type="capsule"/>
        <body name="aux_4" pos="0.2 -0.2 0">
          <joint axis="0 0 1" name="hip_4" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
          <geom fromto="0.0 0.0 0.0 0.2 -0.2 0.0" name="rightback_leg_geom" size="0.08" type="capsule"/>
          <body pos="0.2 -0.2 0">
            <joint axis="1 1 0" name="ankle_4" pos="0.0 0.0 0.0" range="30 70" type="hinge"/>
            <geom fromto="0.0 0.0 0.0 0.4 -0.4 0.0" name="fourth_ankle_geom" size="0.08" type="capsule"/>
          </body>
        </body>
      </body>
      <body name="mid" pos="0.0 0 0">
        <geom density="1000" fromto="0 0 0 -1 0 0" size="0.1" type="capsule"/>
        <joint axis="0 0 1" limited="true" name="rot2" pos="0 0 0" range="-100 100" type="hinge"/>
        <body name="front_right_leg" pos="-1 0 0">
          <geom fromto="0.0 0.0 0.0 -0.2 0.2 0.0" name="aux_2_geom" size="0.08" type="capsule"/>
          <body name="aux_2" pos="-0.2 0.2 0">
            <joint axis="0 0 1" name="hip_2" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
            <geom fromto="0.0 0.0 0.0 -0.2 0.2 0.0" name="right_leg_geom" size="0.08" type="capsule"/>
            <body pos="-0.2 0.2 0">
              <joint axis="1 1 0" name="ankle_2" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
              <geom fromto="0.0 0.0 0.0 -0.4 0.4 0.0" name="right_ankle_geom" size="0.08" type="capsule"/>
            </body>
          </body>
        </body>
        <body name="back_leg" pos="-1 0 0">
          <geom fromto="0.0 0.0 0.0 -0.2 -0.2 0.0" name="aux_3_geom" size="0.08" type="capsule"/>
          <body name="aux_3" pos="-0.2 -0.2 0">
            <joint axis="0 0 1" name="hip_3" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
            <geom fromto="0.0 0.0 0.0 -0.2 -0.2 0.0" name="back_leg_geom" size="0.08" type="capsule"/>
            <body pos="-0.2 -0.2 0">
              <joint axis="-1 1 0" name="ankle_3" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
              <geom fromto="0.0 0.0 0.0 -0.4 -0.4 0.0" name="third_ankle_geom" size="0.08" type="capsule"/>
            </body>
          </body>
        </body>
      </body>
    </body>
  </worldbody>
  <actuator>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_4" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_4" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_1" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_1" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_2" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_2" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip_3" gear="150"/>
    <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle_3" gear="150"/>
  </actuator>
</mujoco>

================================================
FILE: envs/ma_mujoco/multiagent_mujoco/assets/manyagent_swimmer.xml.template
================================================
<mujoco model="swimmer">
  <compiler angle="degree" coordinate="local" inertiafromgeom="true"/>
  <option collision="predefined" density="4000" integrator="RK4" timestep="0.005" viscosity="0.1"/>
  <default>
    <geom conaffinity="1" condim="1" contype="1" material="geom" rgba="0.8 0.6 .4 1"/>
    <joint armature='0.1'  />
  </default>
  <asset>
    <texture builtin="gradient" height="100" rgb1="1 1 1" rgb2="0 0 0" type="skybox" width="100"/>
    <texture builtin="flat" height="1278" mark="cross" markrgb="1 1 1" name="texgeom" random="0.01" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" type="cube" width="127"/>
    <texture builtin="checker" height="100" name="texplane" rgb1="0 0 0" rgb2="0.8 0.8 0.8" type="2d" width="100"/>
    <material name="MatPlane" reflectance="0.5" shininess="1" specular="1" texrepeat="30 30" texture="texplane"/>
    <material name="geom" texture="texgeom" texuniform="true"/>
  </asset>
  <worldbody>
    <light cutoff="100" diffuse="1 1 1" dir="-0 0 -1.3" directional="true" exponent="1" pos="0 0 1.3" specular=".1 .1 .1"/>
    <geom conaffinity="1" condim="3" material="MatPlane" name="floor" pos="0 0 -0.1" rgba="0.8 0.9 0.8 1" size="40 40 0.1" type="plane"/>
    <!--  ================= SWIMMER ================= /-->
    <body name="torso" pos="0 0 0">
      <geom density="1000" fromto="1.5 0 0 0.5 0 0" size="0.1" type="capsule"/>
      <joint axis="1 0 0" name="slider1" pos="0 0 0" type="slide"/>
      <joint axis="0 1 0" name="slider2" pos="0 0 0" type="slide"/>
      <joint axis="0 0 1" name="rot" pos="0 0 0" type="hinge"/>
      <body name="mid0" pos="0.5 0 0">
        <geom density="1000" fromto="0 0 0 -1 0 0" size="0.1" type="capsule"/>
        <joint axis="0 0 1" limited="true" name="rot0" pos="0 0 0" range="-100 100" type="hinge"/>
        {{ body }}
      </body>
    </body>
  </worldbody>
  <actuator>
{{ actuators }}
  </actuator>
</mujoco>

================================================
FILE: envs/ma_mujoco/multiagent_mujoco/assets/manyagent_swimmer__bckp2.xml
================================================
<mujoco model="swimmer">
  <compiler angle="degree" coordinate="local" inertiafromgeom="true"/>
  <option collision="predefined" density="4000" integrator="RK4" timestep="0.01" viscosity="0.1"/>
  <default>
    <geom conaffinity="1" condim="1" contype="1" material="geom" rgba="0.8 0.6 .4 1"/>
    <joint armature='0.1'  />
  </default>
  <asset>
    <texture builtin="gradient" height="100" rgb1="1 1 1" rgb2="0 0 0" type="skybox" width="100"/>
    <texture builtin="flat" height="1278" mark="cross" markrgb="1 1 1" name="texgeom" random="0.01" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" type="cube" width="127"/>
    <texture builtin="checker" height="100" name="texplane" rgb1="0 0 0" rgb2="0.8 0.8 0.8" type="2d" width="100"/>
    <material name="MatPlane" reflectance="0.5" shininess="1" specular="1" texrepeat="30 30" texture="texplane"/>
    <material name="geom" texture="texgeom" texuniform="true"/>
  </asset>
  <worldbody>
    <light cutoff="100" diffuse="1 1 1" dir="-0 0 -1.3" directional="true" exponent="1" pos="0 0 1.3" specular=".1 .1 .1"/>
    <geom conaffinity="1" condim="3" material="MatPlane" name="floor" pos="0 0 -0.1" rgba="0.8 0.9 0.8 1" size="40 40 0.1" type="plane"/>
    <!--  ================= SWIMMER ================= /-->
    <body name="torso" pos="0 0 0">
      <geom density="1000" fromto="1.5 0 0 0.5 0 0" size="0.1" type="capsule"/>
      <joint axis="1 0 0" name="slider1" pos="0 0 0" type="slide"/>
      <joint axis="0 1 0" name="slider2" pos="0 0 0" type="slide"/>
      <joint axis="0 0 1" name="rot" pos="0 0 0" type="hinge"/>
      <body name="mid1" pos="0.5 0 0">
        <geom density="1000" fromto="0 0 0 -1 0 0" size="0.1" type="capsule"/>
        <joint axis="0 0 1" limited="true" name="rot0" pos="0 0 0" range="-100 100" type="hinge"/>
        <body name="mid2" pos="-1 0 0">
          <geom density="1000" fromto="0 0 0 -1 0 0" size="0.1" type="capsule"/>
          <joint axis="0 0 -1" limited="true" name="rot1" pos="0 0 0" range="-100 100" type="hinge"/>
          <body name="mid3" pos="-1 0 0">
            <geom density="1000" fromto="0 0 0 -1 0 0" size="0.1" type="capsule"/>
            <joint axis="0 0 1" limited="true" name="rot2" pos="0 0 0" range="-100 100" type="hinge"/>
            <body name="back" pos="-1 0 0">
              <geom density="1000" fromto="0 0 0 -1 0 0" size="0.1" type="capsule"/>
              <joint axis="0 0 1" limited="true" name="rot3" pos="0 0 0" range="-100 100" type="hinge"/>
            </body>
          </body>
        </body>
      </body>
    </body>
  </worldbody>
  <actuator>
    <motor ctrllimited="true" ctrlrange="-1 1" gear="150.0" joint="rot0"/>
    <motor ctrllimited="true" ctrlrange="-1 1" gear="150.0" joint="rot1"/>
    <motor ctrllimited="true" ctrlrange="-1 1" gear="150.0" joint="rot2"/>
     <motor ctrllimited="true" ctrlrange="-1 1" gear="150.0" joint="rot3"/>
  </actuator>
</mujoco>

================================================
FILE: envs/ma_mujoco/multiagent_mujoco/assets/manyagent_swimmer_bckp.xml
================================================
<mujoco model="swimmer">
  <compiler angle="degree" coordinate="local" inertiafromgeom="true"/>
  <option collision="predefined" density="4000" integrator="RK4" timestep="0.01" viscosity="0.1"/>
  <default>
    <geom conaffinity="1" condim="1" contype="1" material="geom" rgba="0.8 0.6 .4 1"/>
    <joint armature='0.1'  />
  </default>
  <asset>
    <texture builtin="gradient" height="100" rgb1="1 1 1" rgb2="0 0 0" type="skybox" width="100"/>
    <texture builtin="flat" height="1278" mark="cross" markrgb="1 1 1" name="texgeom" random="0.01" rgb1="0.8 0.6 0.4" rgb2="0.8 0.6 0.4" type="cube" width="127"/>
    <texture builtin="checker" height="100" name="texplane" rgb1="0 0 0" rgb2="0.8 0.8 0.8" type="2d" width="100"/>
    <material name="MatPlane" reflectance="0.5" shininess="1" specular="1" texrepeat="30 30" texture="texplane"/>
    <material name="geom" texture="texgeom" texuniform="true"/>
  </asset>
  <worldbody>
    <light cutoff="100" diffuse="1 1 1" dir="-0 0 -1.3" directional="true" exponent="1" pos="0 0 1.3" specular=".1 .1 .1"/>
    <geom conaffinity="1" condim="3" material="MatPlane" name="floor" pos="0 0 -0.1" rgba="0.8 0.9 0.8 1" size="40 40 0.1" type="plane"/>
    <!--  ================= SWIMMER ================= /-->
    <body name="torso" pos="0 0 0">
      <geom density="1000" fromto="1.5 0 0 0.5 0 0" size="0.1" type="capsule"/>
      <joint axis="1 0 0" name="slider1" pos="0 0 0" type="slide"/>
      <joint axis="0 1 0" name="slider2" pos="0 0 0" type="slide"/>
      <joint axis="0 0 1" name="rot" pos="0 0 0" type="hinge"/>
      <body name="mid1" pos="0.5 0 0">
        <geom density="1000" fromto="0 0 0 -1 0 0" size="0.1" type="capsule"/>
        <joint axis="0 0 1" limited="true" name="rot0" pos="0 0 0" range="-100 100" type="hinge"/>
        <body name="mid2" pos="-1 0 0">
          <geom density="1000" fromto="0 0 0 -1 0 0" size="0.1" type="capsule"/>
          <joint axis="0 0 -1" limited="true" name="rot1" pos="0 0 0" range="-100 100" type="hinge"/>
          <body name="back" pos="-1 0 0">
            <geom density="1000" fromto="0 0 0 -1 0 0" size="0.1" type="capsule"/>
            <joint axis="0 0 1" limited="true" name="rot2" pos="0 0 0" range="-100 100" type="hinge"/>
          </body>
        </body>
      </body>
    </body>
  </worldbody>
  <actuator>
    <motor ctrllimited="true" ctrlrange="-1 1" gear="150.0" joint="rot0"/>
    <motor ctrllimited="true" ctrlrange="-1 1" gear="150.0" joint="rot1"/>
    <motor ctrllimited="true" ctrlrange="-1 1" gear="150.0" joint="rot2"/>
  </actuator>
</mujoco>

================================================
FILE: envs/ma_mujoco/multiagent_mujoco/coupled_half_cheetah.py
================================================
import numpy as np
from gym import utils
from gym.envs.mujoco import mujoco_env
import os


class CoupledHalfCheetah(mujoco_env.MujocoEnv, utils.EzPickle):
    def __init__(self, **kwargs):
        mujoco_env.MujocoEnv.__init__(self, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'coupled_half_cheetah.xml'), 5)
        utils.EzPickle.__init__(self)

    def step(self, action):
        xposbefore1 = self.sim.data.qpos[0]
        xposbefore2 = self.sim.data.qpos[len(self.sim.data.qpos) // 2]
        self.do_simulation(action, self.frame_skip)
        xposafter1 = self.sim.data.qpos[0]
        xposafter2 = self.sim.data.qpos[len(self.sim.data.qpos)//2]
        ob = self._get_obs()
        reward_ctrl1 = - 0.1 * np.square(action[0:len(action)//2]).sum()
        reward_ctrl2 = - 0.1 * np.square(action[len(action)//2:]).sum()
        reward_run1 = (xposafter1 - xposbefore1)/self.dt
        reward_run2 = (xposafter2 - xposbefore2) / self.dt
        reward = (reward_ctrl1 + reward_ctrl2)/2.0 + (reward_run1 + reward_run2)/2.0
        done = False
        return ob, reward, done, dict(reward_run1=reward_run1, reward_ctrl1=reward_ctrl1,
                                      reward_run2=reward_run2, reward_ctrl2=reward_ctrl2)

    def _get_obs(self):
        return np.concatenate([
            self.sim.data.qpos.flat[1:],
            self.sim.data.qvel.flat,
        ])

    def reset_model(self):
        qpos = self.init_qpos + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq)
        qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1
        self.set_state(qpos, qvel)
        return self._get_obs()

    def viewer_setup(self):
        self.viewer.cam.distance = self.model.stat.extent * 0.5

    def get_env_info(self):
        return {"episode_limit": self.episode_limit}

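# --- Worked reward example (added comment; not part of the original file) ---
# Both cheetahs receive the same shared reward: the mean of the two control
# penalties plus the mean of the two forward-velocity terms. Purely as an
# illustration, with dt = 0.05 (frame_skip 5 x the 0.01 s timestep from the
# XML) and zero control cost, if cheetah 1 advances 0.06 m and cheetah 2
# advances 0.02 m in one step:
#   reward_run1 = 0.06 / 0.05 = 1.2
#   reward_run2 = 0.02 / 0.05 = 0.4
#   reward      = (1.2 + 0.4) / 2 = 0.8   (identical for both agents)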
================================================
FILE: envs/ma_mujoco/multiagent_mujoco/manyagent_ant.py
================================================
import numpy as np
from gym import utils
from gym.envs.mujoco import mujoco_env
from jinja2 import Template
import os

class ManyAgentAntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    def __init__(self, **kwargs):
        agent_conf = kwargs.get("agent_conf")
        n_agents = int(agent_conf.split("x")[0])
        n_segs_per_agents = int(agent_conf.split("x")[1])
        n_segs = n_agents * n_segs_per_agents

        # Check whether asset file exists already, otherwise create it
        asset_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets',
                                                  'manyagent_ant_{}_agents_each_{}_segments.auto.xml'.format(n_agents,
                                                                                                                 n_segs_per_agents))
        #if not os.path.exists(asset_path):
        print("Auto-Generating Manyagent Ant asset with {} segments at {}.".format(n_segs, asset_path))
        self._generate_asset(n_segs=n_segs, asset_path=asset_path)

        #asset_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets',
        #                          'manyagent_swimmer.xml')

        mujoco_env.MujocoEnv.__init__(self, asset_path, 4)
        utils.EzPickle.__init__(self)

    def _generate_asset(self, n_segs, asset_path):
        template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets',
                                                  'manyagent_ant.xml.template')
        with open(template_path, "r") as f:
            t = Template(f.read())
        body_str_template = """
        <body name="torso_{:d}" pos="-1 0 0">
           <!--<joint axis="0 1 0" name="nnn_{:d}" pos="0.0 0.0 0.0" range="-1 1" type="hinge"/>-->
            <geom density="100" fromto="1 0 0 0 0 0" size="0.1" type="capsule"/>
            <body name="front_right_leg_{:d}" pos="0 0 0">
              <geom fromto="0.0 0.0 0.0 0.0 0.2 0.0" name="aux1_geom_{:d}" size="0.08" type="capsule"/>
              <body name="aux_2_{:d}" pos="0.0 0.2 0">
                <joint axis="0 0 1" name="hip1_{:d}" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
                <geom fromto="0.0 0.0 0.0 -0.2 0.2 0.0" name="right_leg_geom_{:d}" size="0.08" type="capsule"/>
                <body pos="-0.2 0.2 0">
                  <joint axis="1 1 0" name="ankle1_{:d}" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
                  <geom fromto="0.0 0.0 0.0 -0.4 0.4 0.0" name="right_ankle_geom_{:d}" size="0.08" type="capsule"/>
                </body>
              </body>
            </body>
            <body name="back_leg_{:d}" pos="0 0 0">
              <geom fromto="0.0 0.0 0.0 0.0 -0.2 0.0" name="aux2_geom_{:d}" size="0.08" type="capsule"/>
              <body name="aux2_{:d}" pos="0.0 -0.2 0">
                <joint axis="0 0 1" name="hip2_{:d}" pos="0.0 0.0 0.0" range="-30 30" type="hinge"/>
                <geom fromto="0.0 0.0 0.0 -0.2 -0.2 0.0" name="back_leg_geom_{:d}" size="0.08" type="capsule"/>
                <body pos="-0.2 -0.2 0">
                  <joint axis="-1 1 0" name="ankle2_{:d}" pos="0.0 0.0 0.0" range="-70 -30" type="hinge"/>
                  <geom fromto="0.0 0.0 0.0 -0.4 -0.4 0.0" name="third_ankle_geom_{:d}" size="0.08" type="capsule"/>
                </body>
              </body>
            </body>
        """

        body_close_str_template = "</body>\n"
        actuator_str_template = """\t     <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip1_{:d}" gear="150"/>
                                          <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle1_{:d}" gear="150"/>
                                          <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="hip2_{:d}" gear="150"/>
                                          <motor ctrllimited="true" ctrlrange="-1.0 1.0" joint="ankle2_{:d}" gear="150"/>\n"""

        body_str = ""
        for i in range(1,n_segs):
            body_str += body_str_template.format(*([i]*16))
        body_str += body_close_str_template*(n_segs-1)

        actuator_str = ""
        for i in range(n_segs):
            actuator_str += actuator_str_template.format(*([i]*8))

        rt = t.render(body=body_str, actuators=actuator_str)
        with open(asset_path, "w") as f:
            f.write(rt)
        pass

    def step(self, a):
        xposbefore = self.get_body_com("torso_0")[0]
        self.do_simulation(a, self.frame_skip)
        xposafter = self.get_body_com("torso_0")[0]
        forward_reward = (xposafter - xposbefore)/self.dt
        ctrl_cost = .5 * np.square(a).sum()
        contact_cost = 0.5 * 1e-3 * np.sum(
            np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
        survive_reward = 1.0
        reward = forward_reward - ctrl_cost - contact_cost + survive_reward
        state = self.state_vector()
        notdone = np.isfinite(state).all() \
            and state[2] >= 0.2 and state[2] <= 1.0
        done = not notdone
        ob = self._get_obs()
        return ob, reward, done, dict(
            reward_forward=forward_reward,
            reward_ctrl=-ctrl_cost,
            reward_contact=-contact_cost,
            reward_survive=survive_reward)

    def _get_obs(self):
        return np.concatenate([
            self.sim.data.qpos.flat[2:],
            self.sim.data.qvel.flat,
            np.clip(self.sim.data.cfrc_ext, -1, 1).flat,
        ])

    def reset_model(self):
        qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-.1, high=.1)
        qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1
        self.set_state(qpos, qvel)
        return self._get_obs()

    def viewer_setup(self):
        self.viewer.cam.distance = self.model.stat.extent * 0.5

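The reward computed in ManyAgentAntEnv.step above combines forward progress, a quadratic control cost, a clipped contact cost and a constant survive bonus. A minimal sketch of that arithmetic under assumed inputs (dt, the action and the contact forces below are made-up values, not outputs of the simulator):

import numpy as np

dt = 0.05                                    # hypothetical self.dt (frame_skip * model timestep)
xposbefore, xposafter = 1.00, 1.03           # torso_0 x-position before/after do_simulation
a = np.array([0.2, -0.1, 0.4, 0.0])          # hypothetical joint action
cfrc_ext = np.zeros((14, 6))                 # hypothetical external contact forces

forward_reward = (xposafter - xposbefore) / dt                            # ~0.6
ctrl_cost = 0.5 * np.square(a).sum()                                      # 0.105
contact_cost = 0.5 * 1e-3 * np.sum(np.square(np.clip(cfrc_ext, -1, 1)))   # 0.0
survive_reward = 1.0
reward = forward_reward - ctrl_cost - contact_cost + survive_reward       # ~1.495
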
================================================
FILE: envs/ma_mujoco/multiagent_mujoco/manyagent_swimmer.py
================================================
import numpy as np
from gym import utils
from gym.envs.mujoco import mujoco_env
import os
from jinja2 import Template

class ManyAgentSwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    def __init__(self, **kwargs):
        agent_conf = kwargs.get("agent_conf")
        n_agents = int(agent_conf.split("x")[0])
        n_segs_per_agents = int(agent_conf.split("x")[1])
        n_segs = n_agents * n_segs_per_agents

        # Check whether asset file exists already, otherwise create it
        asset_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets',
                                                  'manyagent_swimmer_{}_agents_each_{}_segments.auto.xml'.format(n_agents,
                                                                                                                 n_segs_per_agents))
        # if not os.path.exists(asset_path):
        print("Auto-Generating Manyagent Swimmer asset with {} segments at {}.".format(n_segs, asset_path))
        self._generate_asset(n_segs=n_segs, asset_path=asset_path)

        #asset_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets',
        #                          'manyagent_swimmer.xml')

        mujoco_env.MujocoEnv.__init__(self, asset_path, 4)
        utils.EzPickle.__init__(self)

    def _generate_asset(self, n_segs, asset_path):
        template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets',
                                                  'manyagent_swimmer.xml.template')
        with open(template_path, "r") as f:
            t = Template(f.read())
        body_str_template = """
        <body name="mid{:d}" pos="-1 0 0">
          <geom density="1000" fromto="0 0 0 -1 0 0" size="0.1" type="capsule"/>
          <joint axis="0 0 {:d}" limited="true" name="rot{:d}" pos="0 0 0" range="-100 100" type="hinge"/>
        """

        body_end_str_template = """
        <body name="back" pos="-1 0 0">
            <geom density="1000" fromto="0 0 0 -1 0 0" size="0.1" type="capsule"/>
            <joint axis="0 0 1" limited="true" name="rot{:d}" pos="0 0 0" range="-100 100" type="hinge"/>
          </body>
        """

        body_close_str_template ="</body>\n"
        actuator_str_template = """\t <motor ctrllimited="true" ctrlrange="-1 1" gear="150.0" joint="rot{:d}"/>\n"""

        body_str = ""
        for i in range(1,n_segs-1):
            body_str += body_str_template.format(i, (-1)**(i+1), i)
        body_str += body_end_str_template.format(n_segs-1)
        body_str += body_close_str_template*(n_segs-2)

        actuator_str = ""
        for i in range(n_segs):
            actuator_str += actuator_str_template.format(i)

        rt = t.render(body=body_str, actuators=actuator_str)
        with open(asset_path, "w") as f:
            f.write(rt)
        pass

    def step(self, a):
        ctrl_cost_coeff = 0.0001
        xposbefore = self.sim.data.qpos[0]
        self.do_simulation(a, self.frame_skip)
        xposafter = self.sim.data.qpos[0]
        reward_fwd = (xposafter - xposbefore) / self.dt
        reward_ctrl = - ctrl_cost_coeff * np.square(a).sum()
        reward = reward_fwd + reward_ctrl
        ob = self._get_obs()
        return ob, reward, False, dict(reward_fwd=reward_fwd, reward_ctrl=reward_ctrl)

    def _get_obs(self):
        qpos = self.sim.data.qpos
        qvel = self.sim.data.qvel
        return np.concatenate([qpos.flat[2:], qvel.flat])

    def reset_model(self):
        self.set_state(
            self.init_qpos + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq),
            self.init_qvel + self.np_random.uniform(low=-.1, high=.1, size=self.model.nv)
        )
        return self._get_obs()


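For reference, the "NxM" agent_conf string that drives the asset generation above splits into an agent count and a per-agent segment count. A small sketch, assuming a value of "2x3" (an example, not a prescribed configuration):

agent_conf = "2x3"                                  # assumed: 2 agents, 3 segments each
n_agents = int(agent_conf.split("x")[0])            # 2
n_segs_per_agents = int(agent_conf.split("x")[1])   # 3
n_segs = n_agents * n_segs_per_agents               # 6 swimmer segments in total
asset_name = "manyagent_swimmer_{}_agents_each_{}_segments.auto.xml".format(
    n_agents, n_segs_per_agents)
print(n_segs, asset_name)  # 6 manyagent_swimmer_2_agents_each_3_segments.auto.xml
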
================================================
FILE: envs/ma_mujoco/multiagent_mujoco/mujoco_multi.py
================================================
from functools import partial
import gym
from gym.spaces import Box
from gym.wrappers import TimeLimit
import numpy as np

from .multiagentenv import MultiAgentEnv
from .manyagent_swimmer import ManyAgentSwimmerEnv
from .obsk import get_joints_at_kdist, get_parts_and_edges, build_obs


def env_fn(env, **kwargs) -> MultiAgentEnv: # TODO: this may be a more complex function
    # env_args = kwargs.get("env_args", {})
    return env(**kwargs)

env_REGISTRY = {}
env_REGISTRY["manyagent_swimmer"] = partial(env_fn, env=ManyAgentSwimmerEnv)


# using code from https://github.com/ikostrikov/pytorch-ddpg-naf
class NormalizedActions(gym.ActionWrapper):

    def _action(self, action):
        action = (action + 1) / 2
        action *= (self.action_space.high - self.action_space.low)
        action += self.action_space.low
        return action

    def action(self, action_):
        return self._action(action_)

    def _reverse_action(self, action):
        action -= self.action_space.low
        action /= (self.action_space.high - self.action_space.low)
        action = action * 2 - 1
        return action

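# Illustrative note (not part of the original file): NormalizedActions maps a policy
# action in [-1, 1] affinely onto [low, high]. For example, with low = -2 and high = 2:
#   _action(0.5) = (0.5 + 1) / 2 * (2 - (-2)) + (-2) = 1.0
# and _reverse_action(1.0) recovers 0.5.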

class MujocoMulti(MultiAgentEnv):

    def __init__(self, batch_size=None, **kwargs):
        super().__init__(batch_size, **kwargs)
        self.scenario = kwargs["env_args"]["scenario"]  # e.g. Ant-v2
        self.agent_conf = kwargs["env_args"]["agent_conf"]  # e.g. '2x3'

        self.agent_partitions, self.mujoco_edges, self.mujoco_globals = get_parts_and_edges(self.scenario,
                                                                                            self.agent_conf)

        self.n_agents = len(self.agent_partitions)
        self.n_actions = max([len(l) for l in self.agent_partitions])
        self.obs_add_global_pos = kwargs["env_args"].get("obs_add_global_pos", False)

        self.agent_obsk = kwargs["env_args"].get("agent_obsk",
                                                 None)  # if None, fully observable else k>=0 implies observe nearest k agents or joints
        self.agent_obsk_agents = kwargs["env_args"].get("agent_obsk_agents",
                                                        False)  # observe full k nearest agents (True) or just single joints (False)

        if self.agent_obsk is not None:
            self.k_categories_label = kwargs["env_args"].get("k_categories")
            if self.k_categories_label is None:
                if self.scenario in ["Ant-v2", "manyagent_ant"]:
                    self.k_categories_label = "qpos,qvel,cfrc_ext|qpos"
                elif self.scenario in ["Humanoid-v2", "HumanoidStandup-v2"]:
                    self.k_categories_label = "qpos,qvel,cfrc_ext,cvel,cinert,qfrc_actuator|qpos"
                elif self.scenario in ["Reacher-v2"]:
                    self.k_categories_label = "qpos,qvel,fingertip_dist|qpos"
                elif self.scenario in ["coupled_half_cheetah"]:
                    self.k_categories_label = "qpos,qvel,ten_J,ten_length,ten_velocity|"
                else:
                    self.k_categories_label = "qpos,qvel|qpos"

            k_split = self.k_categories_label.split("|")
            self.k_categories = [k_split[k if k < len(k_split) else -1].split(",") for k in range(self.agent_obsk + 1)]

            self.global_categories_label = kwargs["env_args"].get("global_categories")
            self.global_categories = self.global_categories_label.split(
                ",") if self.global_categories_label is not None else []

        if self.agent_obsk is not None:
            self.k_dicts = [get_joints_at_kdist(agent_id,
                                                self.agent_partitions,
                                                self.mujoco_edges,
                                                k=self.agent_obsk,
                                                kagents=False, ) for agent_id in range(self.n_agents)]

        # load scenario from script
        self.episode_limit = self.args.episode_limit

        self.env_version = kwargs["env_args"].get("env_version", 2)
        if self.env_version == 2:
            try:
                self.wrapped_env = NormalizedActions(gym.make(self.scenario))
            except gym.error.Error:
                self.wrapped_env = NormalizedActions(
                    TimeLimit(partial(env_REGISTRY[self.scenario], **kwargs["env_args"])(),
                              max_episode_steps=self.episode_limit))
        else:
            assert False, "not implemented!"
        self.timelimit_env = self.wrapped_env.env
        self.timelimit_env._max_episode_steps = self.episode_limit
        self.env = self.timelimit_env.env
        self.timelimit_env.reset()
        self.obs_size = self.get_obs_size()
        self.share_obs_size = self.get_state_size()

        # COMPATIBILITY
        self.n = self.n_agents
        # self.observation_space = [Box(low=np.array([-10]*self.n_agents), high=np.array([10]*self.n_agents)) for _ in range(self.n_agents)]
        self.observation_space = [Box(low=-10, high=10, shape=(self.obs_size,)) for _ in range(self.n_agents)]
        self.share_observation_space = [Box(low=-10, high=10, shape=(self.share_obs_size,)) for _ in
                                        range(self.n_agents)]

        acdims = [len(ap) for ap in self.agent_partitions]
        self.action_space = tuple([Box(self.env.action_space.low[sum(acdims[:a]):sum(acdims[:a + 1])],
                                       self.env.action_space.high[sum(acdims[:a]):sum(acdims[:a + 1])]) for a in
                                   range(self.n_agents)])

        pass

    def step(self, actions):

        # need to remove dummy actions that arise due to unequal action vector sizes across agents
        flat_actions = np.concatenate([actions[i][:self.action_space[i].low.shape[0]] for i in range(self.n_agents)])
        obs_n, reward_n, done_n, info_n = self.wrapped_env.step(flat_actions)
        self.steps += 1

        info = {}
        info.update(info_n)

        # if done_n:
        #     if self.steps < self.episode_limit:
        #         info["episode_limit"] = False   # the next state will be masked out
        #     else:
        #         info["episode_limit"] = True    # the next state will not be masked out
        if done_n:
            if self.steps < self.episode_limit:
                info["bad_transition"] = False  # the next state will be masked out
            else:
                info["bad_transition"] = True  # the next state will not be masked out

        # return reward_n, done_n, info
        rewards = [[reward_n]] * self.n_agents
        dones = [done_n] * self.n_agents
        infos = [info for _ in range(self.n_agents)]
        return self.get_obs(), self.get_state(), rewards, dones, infos, self.get_avail_actions()

    def get_obs(self):
        """ Returns all agent observat3ions in a list """
        state = self.env._get_obs()
        obs_n = []
        for a in range(self.n_agents):
            agent_id_feats = np.zeros(self.n_agents, dtype=np.float32)
            agent_id_feats[a] = 1.0
            # obs_n.append(self.get_obs_agent(a))
            # obs_n.append(np.concatenate([state, self.get_obs_agent(a), agent_id_feats]))
            # obs_n.append(np.concatenate([self.get_obs_agent(a), agent_id_feats]))
            obs_i = np.concatenate([state, agent_id_feats])
            obs_i = (obs_i - np.mean(obs_i)) / np.std(obs_i)
            obs_n.append(obs_i)
        return obs_n

    def get_obs_agent(self, agent_id):
        if self.agent_obsk is None:
            return self.env._get_obs()
        else:
            # return build_obs(self.env,
            #                       self.k_dicts[agent_id],
            #                       self.k_categories,
            #                       self.mujoco_globals,
            #                       self.global_categories,
            #                       vec_len=getattr(self, "obs_size", None))
            return build_obs(self.env,
                             self.k_dicts[agent_id],
                             self.k_categories,
                             self.mujoco_globals,
                             self.global_categories)

    def get_obs_size(self):
        """ Returns the shape of the observation """
        if self.agent_obsk is None:
            return self.get_obs_agent(0).size
        else:
            return len(self.get_obs()[0])
            # return max([len(self.get_obs_agent(agent_id)) for agent_id in range(self.n_agents)])

    def get_state(self, team=None):
        # TODO: May want separate global states for different teams (e.g. so one team cannot see what the other is communicating)
        state = self.env._get_obs()
        share_obs = []
        for a in range(self.n_agents):
            agent_id_feats = np.zeros(self.n_agents, dtype=np.float32)
            agent_id_feats[a] = 1.0
            # share_obs.append(np.concatenate([state, self.get_obs_agent(a), agent_id_feats]))
            state_i = np.concatenate([state, agent_id_feats])
            state_i = (state_i - np.mean(state_i)) / np.std(state_i)
            share_obs.append(state_i)
        return share_obs

    def get_state_size(self):
        """ Returns the shape of the state"""
        return len(self.get_state()[0])

    def get_avail_actions(self):  # all actions are always available
        return np.ones(shape=(self.n_agents, self.n_actions,))

    def get_avail_agent_actions(self, agent_id):
        """ Returns the available actions for agent_id """
        return np.ones(shape=(self.n_actions,))

    def get_total_actions(self):
        """ Returns the total number of actions an agent could ever take """
        return self.n_actions  # CAREFUL! - for continuous actions this is the per-agent action-space dimension rather than a count of discrete actions
        # return self.env.action_space.shape[0]

    def get_stats(self):
        return {}

    # TODO: Temp hack
    def get_agg_stats(self, stats):
        return {}

    def reset(self, **kwargs):
        """ Returns initial observations and states"""
        self.steps = 0
        self.timelimit_env.reset()
        return self.get_obs(), self.get_state(), self.get_avail_actions()

    def render(self, **kwargs):
        self.env.render(**kwargs)

    def close(self):
        pass

    def seed(self, args):
        pass

    def get_env_info(self):

        env_info = {"state_shape": self.get_state_size(),
                    "obs_shape": self.get_obs_size(),
                    "n_actions": self.get_total_actions(),
                    "n_agents": self.n_agents,
                    "episode_limit": self.episode_limit,
                    "action_spaces": self.action_space,
                    "actions_dtype": np.float32,
                    "normalise_actions": False
                    }
        return env_info


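As a usage illustration (a sketch, not part of the repository: it assumes gym with mujoco-py installed and the repository root on PYTHONPATH, and the env_args values are examples rather than canonical settings), MujocoMulti can be driven as follows:

import numpy as np
from envs.ma_mujoco.multiagent_mujoco.mujoco_multi import MujocoMulti

env_args = {"scenario": "HalfCheetah-v2",   # underlying single-agent gym task
            "agent_conf": "2x3",            # two agents controlling three joints each
            "agent_obsk": 0,                # each agent observes only its own joints
            "episode_limit": 1000}
env = MujocoMulti(env_args=env_args)
obs, state, avail_actions = env.reset()
actions = np.array([space.sample() for space in env.action_space])
obs, state, rewards, dones, infos, avail_actions = env.step(actions)
print(env.n_agents, rewards[0], dones[0])
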
================================================
FILE: envs/ma_mujoco/multiagent_mujoco/multiagentenv.py
================================================
from collections import namedtuple
import numpy as np


def convert(dictionary):
    return namedtuple('GenericDict', dictionary.keys())(**dictionary)

class MultiAgentEnv(object):

    def __init__(self, batch_size=None, **kwargs):
        # Unpack arguments from sacred
        args = kwargs["env_args"]
        if isinstance(args, dict):
            args = convert(args)
        self.args = args

        if getattr(args, "seed", None) is not None:
            self.seed = args.seed
            self.rs = np.random.RandomState(self.seed) # initialise numpy random state

    def step(self, actions):
        """ Returns reward, terminated, info """
        raise NotImplementedError

    def get_obs(self):
        """ Returns all agent observations in a list """
        raise NotImplementedError

    def get_obs_agent(self, agent_id):
        """ Returns observation for agent_id """
        raise NotImplementedError

    def get_obs_size(self):
        """ Returns the shape of the observation """
        raise NotImplementedError

    def get_state(self):
        raise NotImplementedError

    def get_state_size(self):
        """ Returns the shape of the state"""
        raise NotImplementedError

    def get_avail_actions(self):
        raise NotImplementedError

    def get_avail_agent_actions(self, agent_id):
        """ Returns the available actions for agent_id """
        raise NotImplementedError

    def get_total_actions(self):
        """ Returns the total number of actions an agent could ever take """
        # TODO: This is only suitable for a discrete 1 dimensional action space for each agent
        raise NotImplementedError

    def get_stats(self):
        raise NotImplementedError

    # TODO: Temp hack
    def get_agg_stats(self, stats):
        return {}

    def reset(self):
        """ Returns initial observations and states"""
        raise NotImplementedError

    def render(self):
        raise NotImplementedError

    def close(self):
        raise NotImplementedError

    def seed(self, seed):
        raise NotImplementedError

    def get_env_info(self):
        env_info = {"state_shape": self.get_state_size(),
                    "obs_shape": self.get_obs_size(),
                    "n_actions": self.get_total_actions(),
                    "n_agents": self.n_agents,
                    "episode_limit": self.episode_limit}
        return env_info

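The convert helper above simply turns the env_args dict into an attribute-accessible namedtuple. A minimal sketch with made-up arguments:

from collections import namedtuple

def convert(dictionary):
    return namedtuple('GenericDict', dictionary.keys())(**dictionary)

args = convert({"scenario": "Ant-v2", "episode_limit": 1000})   # hypothetical env_args
print(args.scenario, args.episode_limit)                        # Ant-v2 1000
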
================================================
FILE: envs/ma_mujoco/multiagent_mujoco/obsk.py
================================================
import itertools
import numpy as np
from copy import deepcopy

class Node():
    def __init__(self, label, qpos_ids, qvel_ids, act_ids, body_fn=None, bodies=None, extra_obs=None, tendons=None):
        self.label = label
        self.qpos_ids = qpos_ids
        self.qvel_ids = qvel_ids
        self.act_ids = act_ids
        self.bodies = bodies
        self.extra_obs = {} if extra_obs is None else extra_obs
        self.body_fn = body_fn
        self.tendons = tendons
        pass

    def __str__(self):
        return self.label

    def __repr__(self):
        return self.label


class HyperEdge():
    def __init__(self, *edges):
        self.edges = set(edges)

    def __contains__(self, item):
        return item in self.edges

    def __str__(self):
        return "HyperEdge({})".format(self.edges)

    def __repr__(self):
        return "HyperEdge({})".format(self.edges)


def get_joints_at_kdist(agent_id, agent_partitions, hyperedges, k=0, kagents=False,):
    """ Identify all joints at distance <= k from agent agent_id

    :param agent_id: id of agent to be considered
    :param agent_partitions: list of joint tuples in order of agentids
    :param hyperedges: list of HyperEdge objects connecting joints
    :param k: kth degree
    :param kagents: True (observe all joints of an agent if a single one is) or False (individual joint granularity)
    :return:
        dict with k as key, and list of joints at that distance
    """
    assert not kagents, "kagents not implemented!"

    agent_joints = agent_partitions[agent_id]

    def _adjacent(lst, kagents=False):
        # return all sets adjacent to any element in lst
        ret = set([])
        for l in lst:
            ret = ret.union(set(itertools.chain(*[e.edges.difference({l}) for e in hyperedges if l in e])))
        return ret

    seen = set([])
    new = set([])
    k_dict = {}
    for _k in range(k+1):
        if not _k:
            new = set(agent_joints)
        else:
            new = _adjacent(new) - seen
        seen = seen.union(new)
        k_dict[_k] = sorted(list(new), key=lambda x:x.label)
    return k_dict

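# Illustrative note (not part of the original file): for the half_cheetah graph defined
# in get_parts_and_edges below with the "2x3" partitioning, agent 0 owns
# (bfoot, bshin, bthigh); with k=1 the walk over the HyperEdges additionally reaches
# fthigh, so get_joints_at_kdist(0, parts, edges, k=1) returns roughly
#   {0: [bfoot, bshin, bthigh], 1: [fthigh]}
# (each entry sorted by joint label).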

def build_obs(env, k_dict, k_categories, global_dict, global_categories, vec_len=None):
    """Given a k_dict from get_joints_at_kdist, extract observation vector.

    :param k_dict: k_dict
    :param qpos: qpos numpy array
    :param qvel: qvel numpy array
    :param vec_len: if None no padding, else zero-pad to vec_len
    :return:
    observation vector
    """

    # TODO: This needs to be fixed, it was designed for half-cheetah only!
    #if add_global_pos:
    #    obs_qpos_lst.append(global_qpos)
    #    obs_qvel_lst.append(global_qvel)


    body_set_dict = {}
    obs_lst = []
    # Add parts attributes
    for k in sorted(list(k_dict.keys())):
        cats = k_categories[k]
        for _t in k_dict[k]:
            for c in cats:
                if c in _t.extra_obs:
                    items = _t.extra_obs[c](env).tolist()
                    obs_lst.extend(items if isinstance(items, list) else [items])
                else:
                    if c in ["qvel","qpos"]: # this is a "joint position/velocity" item
                        items = getattr(env.sim.data, c)[getattr(_t, "{}_ids".format(c))]
                        obs_lst.extend(items if isinstance(items, list) else [items])
                    elif c in ["qfrc_actuator"]: # this is a "vel position" item
                        items = getattr(env.sim.data, c)[getattr(_t, "{}_ids".format("qvel"))]
                        obs_lst.extend(items if isinstance(items, list) else [items])
                    elif c in ["cvel", "cinert", "cfrc_ext"]:  # this is a "body position" item
                        if _t.bodies is not None:
                            for b in _t.bodies:
                                if c not in body_set_dict:
                                    body_set_dict[c] = set()
                                if b not in body_set_dict[c]:
                                    items = getattr(env.sim.data, c)[b].tolist()
                                    items = getattr(_t, "body_fn", lambda _id,x:x)(b, items)
                                    obs_lst.extend(items if isinstance(items, list) else [items])
                                    body_set_dict[c].add(b)

    # Add global attributes
    body_set_dict = {}
    for c in global_categories:
        if c in ["qvel", "qpos"]:  # this is a "joint position" item
            for j in global_dict.get("joints", []):
                items = getattr(env.sim.data, c)[getattr(j, "{}_ids".format(c))]
                obs_lst.extend(items if isinstance(items, list) else [items])
        else:
            for b in global_dict.get("bodies", []):
                if c not in body_set_dict:
                    body_set_dict[c] = set()
                if b not in body_set_dict[c]:
                    obs_lst.extend(getattr(env.sim.data, c)[b].tolist())
                    body_set_dict[c].add(b)

    if vec_len is not None:
        pad = np.array((vec_len - len(obs_lst))*[0])
        if len(pad):
            return np.concatenate([np.array(obs_lst), pad])
    return np.array(obs_lst)


def build_actions(agent_partitions, k_dict):
    # Composes agent actions output from networks
    # into coherent joint action vector to be sent to the env.
    pass

def get_parts_and_edges(label, partitioning):
    if label in ["half_cheetah", "HalfCheetah-v2"]:

        # define Mujoco graph
        bthigh = Node("bthigh", -6, -6, 0)
        bshin = Node("bshin", -5, -5, 1)
        bfoot = Node("bfoot", -4, -4, 2)
        fthigh = Node("fthigh", -3, -3, 3)
        fshin = Node("fshin", -2, -2, 4)
        ffoot = Node("ffoot", -1, -1, 5)

        edges = [HyperEdge(bfoot, bshin),
                 HyperEdge(bshin, bthigh),
                 HyperEdge(bthigh, fthigh),
                 HyperEdge(fthigh, fshin),
                 HyperEdge(fshin, ffoot)]

        root_x = Node("root_x", 0, 0, -1,
                      extra_obs={"qpos": lambda env: np.array([])})
        root_z = Node("root_z", 1, 1, -1)
        root_y = Node("root_y", 2, 2, -1)
        globals = {"joints":[root_x, root_y, root_z]}

        if partitioning == "2x3":
            parts = [(bfoot, bshin, bthigh),
                     (ffoot, fshin, fthigh)]
        elif partitioning == "6x1":
            parts = [(bfoot,), (bshin,), (bthigh,), (ffoot,), (fshin,), (fthigh,)]
        elif partitioning == "3x2":
            parts = [(bfoot, bshin,), (bthigh, ffoot,), (fshin, fthigh,)]
        else:
            raise Exception("UNKNOWN partitioning config: {}".format(partitioning))

        return parts, edges, globals

    elif label in ["Ant-v2"]:

        # define Mujoco graph
        torso = 1
        front_left_leg = 2
        aux_1 = 3
        ankle_1 = 4
        front_right_leg = 5
        aux_2 = 6
        ankle_2 = 7
        back_leg = 8
        aux_3 = 9
        ankle_3 = 10
        right_back_leg = 11
        aux_4 = 12
        ankle_4 = 13

        hip1 = Node("hip1", -8, -8, 2, bodies=[torso, front_left_leg], body_fn=lambda _id, x:np.clip(x, -1, 1).tolist()) #
        ankle1 = Node("ankle1", -7, -7, 3, bodies=[front_left_leg, aux_1, ankle_1], body_fn=lambda _id, x:np.clip(x, -1, 1).tolist())#,
        hip2 = Node("hip2", -6, -6, 4, bodies=[torso, front_right_leg], body_fn=lambda _id, x:np.clip(x, -1, 1).tolist())#,
        ankle2 = Node("ankle2", -5, -5, 5, bodies=[front_right_leg, aux_2, ankle_2], body_fn=lambda _id, x:np.clip(x, -1, 1).tolist())#,
        hip3 = Node("hip3", -4, -4, 6, bodies=[torso, back_leg], body_fn=lambda _id, x:np.clip(x, -1, 1).tolist())#,
        ankle3 = Node("ankle3", -3, -3, 7, bodies=[back_leg, aux_3, ankle_3], body_fn=lambda _id, x:np.clip(x, -1, 1).tolist())#,
        hip4 = Node("hip4", -2, -2, 0, bodies=[torso, right_back_leg], body_fn=lambda _id, x:np.clip(x, -1, 1).tolist())#,
        ankle4 = Node("ankle4", -1, -1, 1, bodies=[right_back_leg, aux_4, ankle_4], body_fn=lambda _id, x:np.clip(x, -1, 1).tolist())#,

        edges = [HyperEdge(ankle4, hip4),
                 HyperEdge(ankle1, hip1),
                 HyperEdge(ankle2, hip2),
                 HyperEdge(ankle3, hip3),
                 HyperEdge(hip4, hip1, hip2, hip3),
                 ]

        free_joint = Node("free", 0, 0, -1, extra_obs={"qpos": lambda env: env.sim.data.qpos[:7],
                                                       "qvel": lambda env: env.sim.data.qvel[:6],
                                                       "cfrc_ext": lambda env: np.clip(env.sim.data.cfrc_ext[0:1], -1, 1)})
        globals = {"joints": [free_joint]}

        if partitioning == "2x4": # neighbouring legs together
            parts = [(hip1, ankle1, hip2, ankle2),
                     (hip3, ankle3, hip4, ankle4)]
        elif partitioning == "2x4d": # diagonal legs together
            parts = [(hip1, ankle1, hip3, ankle3),
                     (hip2, ankle2, hip4, ankle4)]
        elif partitioning == "4x2":
            parts = [(hip1, ankle1),
                     (hip2, ankle2),
                     (hip3, ankle3),
                     (hip4, ankle4)]
        elif partitioning == "8x1":
            parts = [(hip1,), (ankle1,),
                     (hip2,), (ankle2,),
                     (hip3,), (ankle3,),
                     (hip4,), (ankle4,)]
        else:
            raise Exception("UNKNOWN partitioning config: {}".format(partitioning))

        return parts, edges, globals

    elif label in ["Hopper-v2"]:

        # define Mujoco-Graph
        thigh_joint = Node("thigh_joint", -3, -3, 0,
                           extra_obs={"qvel": lambda env: np.clip(np.array([env.sim.data.qvel[-3]]), -10, 10)})
        leg_joint = Node("leg_joint", -2, -2, 1,
                         extra_obs={"qvel": lambda env: np.clip(np.array([env.sim.data.qvel[-2]]), -10, 10)})
        foot_joint = Node("foot_joint", -1, -1, 2,
                          extra_obs={"qvel": lambda env: np.clip(np.array([env.sim.data.qvel[-1]]), -10, 10)})

        edges = [HyperEdge(foot_joint, leg_joint),
                 HyperEdge(leg_joint, thigh_joint)]

        root_x = Node("root_x", 0, 0, -1, extra_obs={"qpos": lambda env: np.array([]),
                                                     "qvel": lambda env: np.clip(np.array([env.sim.data.qvel[1]]), -10, 10)})
        root_z = Node("root_z", 1, 1, -1, extra_obs={"qvel": lambda env: np.clip(np.array([env.sim.data.qvel[1]]), -10, 10)})
        root_y = Node("root_y", 2, 2, -1, extra_obs={"qvel": lambda env: np.clip(np.array([env.sim.data.qvel[2]]), -10, 10)})
        globals = {"joints":[root_x, root_y, root_z]}

        if partitioning == "3x1":
            parts = [(thigh_joint,),
                     (leg_joint,),
                     (foot_joint,)]

        else:
            raise Exception("UNKNOWN partitioning config: {}".format(partitioning))

        return parts, edges, globals

    elif label in ["Humanoid-v2", "HumanoidStandup-v2"]:

        # define Mujoco-Graph
        abdomen_y = Node("abdomen_y", -16, -16, 0) # act ordering bug in env -- double check!
        abdomen_z = Node("abdomen_z", -17, -17, 1)
        abdomen_x = Node("abdomen_x", -15, -15, 2)
        right_hip_x = Node("right_hip_x", -14, -14, 3)
        right_hip_z = Node("right_hip_z", -13, -13, 4)
        right_hip_y = Node("right_hip_y", -12, -12, 5)
        right_knee = Node("right_knee", -11, -11, 6)
        left_hip_x = Node("left_hip_x", -10, -10, 7)
        left_hip_z = Node("left_hip_z", -9, -9, 8)
        left_hip_y = Node("left_hip_y", -8, -8, 9)
        left_knee = Node("left_knee", -7, -7, 10)
        right_shoulder1 = Node("right_shoulder1", -6, -6, 11)
        right_shoulder2 = Node("right_shoulder2", -5, -5, 12)
        right_elbow = Node("right_elbow", -4, -4, 13)
        left_shoulder1 = Node("left_shoulder1", -3, -3, 14)
        left_shoulder2 = Node("left_shoulder2", -2, -2, 15)
        left_elbow = Node("left_elbow", -1, -1, 16)

        edges = [HyperEdge(abdomen_x, abdomen_y, abdomen_z),
                 HyperEdge(right_hip_x, right_hip_y, right_hip_z),
                 HyperEdge(left_hip_x, left_hip_y, left_hip_z),
                 HyperEdge(left_elbow, left_shoulder1, left_shoulder2),
                 HyperEdge(right_elbow, right_shoulder1, right_shoulder2),
                 HyperEdge(left_knee, left_hip_x, left_hip_y, left_hip_z),
                 HyperEdge(right_knee, right_hip_x, right_hip_y, right_hip_z),
                 HyperEdge(left_shoulder1, left_shoulder2, abdomen_x, abdomen_y, abdomen_z),
                 HyperEdge(right_shoulder1, right_shoulder2, abdomen_x, abdomen_y, abdomen_z),
                 HyperEdge(abdomen_x, abdomen_y, abdomen_z, left_hip_x, left_hip_y, left_hip_z),
                 HyperEdge(abdomen_x, abdomen_y, abdomen_z, right_hip_x, right_hip_y, right_hip_z),
                 ]

        globals = {}

        if partitioning == "9|8": # 17 in total, so one action is a dummy (to be handled by pymarl)
            # isolate upper and lower body
            parts = [(left_shoulder1, left_sho

================================================
SYMBOL INDEX (405 symbols across 31 files)
================================================

FILE: algorithms/actor_critic.py
  class Actor (line 11) | class Actor(nn.Module):
    method __init__ (line 19) | def __init__(self, args, obs_space, action_space, device=torch.device(...
    method forward (line 42) | def forward(self, obs, rnn_states, masks, available_actions=None, dete...
    method evaluate_actions (line 71) | def evaluate_actions(self, obs, rnn_states, action, masks, available_a...
  class Critic (line 118) | class Critic(nn.Module):
    method __init__ (line 125) | def __init__(self, args, cent_obs_space, device=torch.device("cpu")):
    method forward (line 149) | def forward(self, cent_obs, rnn_states, masks):

FILE: algorithms/happo_policy.py
  class HAPPO_Policy (line 6) | class HAPPO_Policy:
    method __init__ (line 17) | def __init__(self, args, obs_space, cent_obs_space, act_space, device=...
    method lr_decay (line 46) | def lr_decay(self, episode, episodes):
    method get_actions (line 55) | def get_actions(self, cent_obs, obs, rnn_states_actor, rnn_states_crit...
    method get_values (line 83) | def get_values(self, cent_obs, rnn_states_critic, masks):
    method evaluate_actions (line 95) | def evaluate_actions(self, cent_obs, obs, rnn_states_actor, rnn_states...
    method act (line 125) | def act(self, obs, rnn_states_actor, masks, available_actions=None, de...

FILE: algorithms/happo_trainer.py
  class HAPPO (line 8) | class HAPPO():
    method __init__ (line 15) | def __init__(self,
    method cal_value_loss (line 48) | def cal_value_loss(self, values, value_preds_batch, return_batch, acti...
    method ppo_update (line 88) | def ppo_update(self, sample, update_actor=True):
    method train (line 169) | def train(self, buffer, update_actor=True):
    method prep_training (line 222) | def prep_training(self):
    method prep_rollout (line 226) | def prep_rollout(self):

FILE: algorithms/hatrpo_policy.py
  class HATRPO_Policy (line 6) | class HATRPO_Policy:
    method __init__ (line 17) | def __init__(self, args, obs_space, cent_obs_space, act_space, device=...
    method lr_decay (line 46) | def lr_decay(self, episode, episodes):
    method get_actions (line 55) | def get_actions(self, cent_obs, obs, rnn_states_actor, rnn_states_crit...
    method get_values (line 83) | def get_values(self, cent_obs, rnn_states_critic, masks):
    method evaluate_actions (line 95) | def evaluate_actions(self, cent_obs, obs, rnn_states_actor, rnn_states...
    method act (line 125) | def act(self, obs, rnn_states_actor, masks, available_actions=None, de...

FILE: algorithms/hatrpo_trainer.py
  class HATRPO (line 9) | class HATRPO():
    method __init__ (line 16) | def __init__(self,
    method cal_value_loss (line 51) | def cal_value_loss(self, values, value_preds_batch, return_batch, acti...
    method flat_grad (line 91) | def flat_grad(self, grads):
    method flat_hessian (line 100) | def flat_hessian(self, hessians):
    method flat_params (line 109) | def flat_params(self, model):
    method update_model (line 116) | def update_model(self, model, new_params):
    method kl_approx (line 125) | def kl_approx(self, q, p):
    method kl_divergence (line 130) | def kl_divergence(self, obs, rnn_states, action, masks, available_acti...
    method conjugate_gradient (line 152) | def conjugate_gradient(self, actor, obs, rnn_states, action, masks, av...
    method fisher_vector_product (line 170) | def fisher_vector_product(self, actor, obs, rnn_states, action, masks,...
    method trpo_update (line 181) | def trpo_update(self, sample, update_actor=True):
    method train (line 321) | def train(self, buffer, update_actor=True):
    method prep_training (line 378) | def prep_training(self):
    method prep_rollout (line 382) | def prep_rollout(self):

FILE: algorithms/utils/act.py
  class ACTLayer (line 5) | class ACTLayer(nn.Module):
    method __init__ (line 13) | def __init__(self, action_space, inputs_dim, use_orthogonal, gain, arg...
    method forward (line 41) | def forward(self, x, available_actions=None, deterministic=False):
    method get_probs (line 85) | def get_probs(self, x, available_actions=None):
    method evaluate_actions (line 107) | def evaluate_actions(self, x, action, available_actions=None, active_m...
    method evaluate_actions_trpo (line 167) | def evaluate_actions_trpo(self, x, action, available_actions=None, act...

FILE: algorithms/utils/cnn.py
  class Flatten (line 6) | class Flatten(nn.Module):
    method forward (line 7) | def forward(self, x):
  class CNNLayer (line 11) | class CNNLayer(nn.Module):
    method __init__ (line 12) | def __init__(self, obs_shape, hidden_size, use_orthogonal, use_ReLU, k...
    method forward (line 40) | def forward(self, x):
  class CNNBase (line 46) | class CNNBase(nn.Module):
    method __init__ (line 47) | def __init__(self, args, obs_shape):
    method forward (line 56) | def forward(self, x):

FILE: algorithms/utils/distributions.py
  class FixedCategorical (line 14) | class FixedCategorical(torch.distributions.Categorical):
    method sample (line 15) | def sample(self):
    method log_probs (line 18) | def log_probs(self, actions):
    method mode (line 27) | def mode(self):
  class FixedNormal (line 32) | class FixedNormal(torch.distributions.Normal):
    method log_probs (line 33) | def log_probs(self, actions):
    method entrop (line 37) | def entrop(self):
    method mode (line 40) | def mode(self):
  class FixedBernoulli (line 45) | class FixedBernoulli(torch.distributions.Bernoulli):
    method log_probs (line 46) | def log_probs(self, actions):
    method entropy (line 49) | def entropy(self):
    method mode (line 52) | def mode(self):
  class Categorical (line 56) | class Categorical(nn.Module):
    method __init__ (line 57) | def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=...
    method forward (line 65) | def forward(self, x, available_actions=None):
  class DiagGaussian (line 94) | class DiagGaussian(nn.Module):
    method __init__ (line 95) | def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=...
    method forward (line 113) | def forward(self, x, available_actions=None):
  class Bernoulli (line 118) | class Bernoulli(nn.Module):
    method __init__ (line 119) | def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=...
    method forward (line 127) | def forward(self, x):
  class AddBias (line 131) | class AddBias(nn.Module):
    method __init__ (line 132) | def __init__(self, bias):
    method forward (line 136) | def forward(self, x):

FILE: algorithms/utils/mlp.py
  class MLPLayer (line 6) | class MLPLayer(nn.Module):
    method __init__ (line 7) | def __init__(self, input_dim, hidden_size, layer_N, use_orthogonal, us...
    method forward (line 26) | def forward(self, x):
  class MLPBase (line 33) | class MLPBase(nn.Module):
    method __init__ (line 34) | def __init__(self, args, obs_shape, cat_self=True, attn_internal=False):
    method forward (line 52) | def forward(self, x):

FILE: algorithms/utils/rnn.py
  class RNNLayer (line 7) | class RNNLayer(nn.Module):
    method __init__ (line 8) | def __init__(self, inputs_dim, outputs_dim, recurrent_N, use_orthogonal):
    method forward (line 24) | def forward(self, x, hxs, masks):

FILE: algorithms/utils/util.py
  function init (line 7) | def init(module, weight_init, bias_init, gain=1):
  function get_clones (line 12) | def get_clones(module, N):
  function check (line 15) | def check(input):

FILE: configs/config.py
  function get_config (line 3) | def get_config():

FILE: envs/env_wrappers.py
  class CloudpickleWrapper (line 10) | class CloudpickleWrapper(object):
    method __init__ (line 15) | def __init__(self, x):
    method __getstate__ (line 18) | def __getstate__(self):
    method __setstate__ (line 22) | def __setstate__(self, ob):
  class ShareVecEnv (line 27) | class ShareVecEnv(ABC):
    method __init__ (line 41) | def __init__(self, num_envs, observation_space, share_observation_spac...
    method reset (line 48) | def reset(self):
    method step_async (line 60) | def step_async(self, actions):
    method step_wait (line 72) | def step_wait(self):
    method close_extras (line 85) | def close_extras(self):
    method close (line 92) | def close(self):
    method step (line 100) | def step(self, actions):
    method render (line 109) | def render(self, mode='human'):
    method get_images (line 120) | def get_images(self):
    method unwrapped (line 127) | def unwrapped(self):
    method get_viewer (line 133) | def get_viewer(self):
  function worker (line 140) | def worker(remote, parent_remote, env_fn_wrapper):
  class GuardSubprocVecEnv (line 177) | class GuardSubprocVecEnv(ShareVecEnv):
    method __init__ (line 178) | def __init__(self, env_fns, spaces=None):
    method step_async (line 199) | def step_async(self, actions):
    method step_wait (line 205) | def step_wait(self):
    method reset (line 211) | def reset(self):
    method reset_task (line 217) | def reset_task(self):
    method close (line 222) | def close(self):
  class SubprocVecEnv (line 235) | class SubprocVecEnv(ShareVecEnv):
    method __init__ (line 236) | def __init__(self, env_fns, spaces=None):
    method step_async (line 257) | def step_async(self, actions):
    method step_wait (line 262) | def step_wait(self):
    method reset (line 268) | def reset(self):
    method reset_task (line 275) | def reset_task(self):
    method close (line 280) | def close(self):
    method render (line 292) | def render(self, mode="rgb_array"):
  function shareworker (line 300) | def shareworker(remote, parent_remote, env_fn_wrapper):
  class ShareSubprocVecEnv (line 343) | class ShareSubprocVecEnv(ShareVecEnv):
    method __init__ (line 344) | def __init__(self, env_fns, spaces=None):
    method step_async (line 367) | def step_async(self, actions):
    method step_wait (line 372) | def step_wait(self):
    method reset (line 378) | def reset(self):
    method reset_task (line 385) | def reset_task(self):
    method close (line 390) | def close(self):
  function choosesimpleworker (line 403) | def choosesimpleworker(remote, parent_remote, env_fn_wrapper):
  class ChooseSimpleSubprocVecEnv (line 434) | class ChooseSimpleSubprocVecEnv(ShareVecEnv):
    method __init__ (line 435) | def __init__(self, env_fns, spaces=None):
    method step_async (line 455) | def step_async(self, actions):
    method step_wait (line 460) | def step_wait(self):
    method reset (line 466) | def reset(self, reset_choose):
    method render (line 472) | def render(self, mode="rgb_array"):
    method reset_task (line 479) | def reset_task(self):
    method close (line 484) | def close(self):
  function chooseworker (line 497) | def chooseworker(remote, parent_remote, env_fn_wrapper):
  class ChooseSubprocVecEnv (line 524) | class ChooseSubprocVecEnv(ShareVecEnv):
    method __init__ (line 525) | def __init__(self, env_fns, spaces=None):
    method step_async (line 546) | def step_async(self, actions):
    method step_wait (line 551) | def step_wait(self):
    method reset (line 557) | def reset(self, reset_choose):
    method reset_task (line 564) | def reset_task(self):
    method close (line 569) | def close(self):
  function chooseguardworker (line 582) | def chooseguardworker(remote, parent_remote, env_fn_wrapper):
  class ChooseGuardSubprocVecEnv (line 607) | class ChooseGuardSubprocVecEnv(ShareVecEnv):
    method __init__ (line 608) | def __init__(self, env_fns, spaces=None):
    method step_async (line 629) | def step_async(self, actions):
    method step_wait (line 634) | def step_wait(self):
    method reset (line 640) | def reset(self, reset_choose):
    method reset_task (line 646) | def reset_task(self):
    method close (line 651) | def close(self):
  class DummyVecEnv (line 665) | class DummyVecEnv(ShareVecEnv):
    method __init__ (line 666) | def __init__(self, env_fns):
    method step_async (line 673) | def step_async(self, actions):
    method step_wait (line 676) | def step_wait(self):
    method reset (line 691) | def reset(self):
    method close (line 695) | def close(self):
    method render (line 699) | def render(self, mode="human"):
  class ShareDummyVecEnv (line 710) | class ShareDummyVecEnv(ShareVecEnv):
    method __init__ (line 711) | def __init__(self, env_fns):
    method step_async (line 718) | def step_async(self, actions):
    method step_wait (line 721) | def step_wait(self):
    method reset (line 737) | def reset(self):
    method close (line 742) | def close(self):
    method render (line 746) | def render(self, mode="human"):
  class ChooseDummyVecEnv (line 756) | class ChooseDummyVecEnv(ShareVecEnv):
    method __init__ (line 757) | def __init__(self, env_fns):
    method step_async (line 764) | def step_async(self, actions):
    method step_wait (line 767) | def step_wait(self):
    method reset (line 774) | def reset(self, reset_choose):
    method close (line 780) | def close(self):
    method render (line 784) | def render(self, mode="human"):
  class ChooseSimpleDummyVecEnv (line 793) | class ChooseSimpleDummyVecEnv(ShareVecEnv):
    method __init__ (line 794) | def __init__(self, env_fns):
    method step_async (line 801) | def step_async(self, actions):
    method step_wait (line 804) | def step_wait(self):
    method reset (line 810) | def reset(self, reset_choose):
    method close (line 815) | def close(self):
    method render (line 819) | def render(self, mode="human"):

FILE: envs/ma_mujoco/multiagent_mujoco/coupled_half_cheetah.py
  class CoupledHalfCheetah (line 7) | class CoupledHalfCheetah(mujoco_env.MujocoEnv, utils.EzPickle):
    method __init__ (line 8) | def __init__(self, **kwargs):
    method step (line 12) | def step(self, action):
    method _get_obs (line 28) | def _get_obs(self):
    method reset_model (line 34) | def reset_model(self):
    method viewer_setup (line 40) | def viewer_setup(self):
    method get_env_info (line 43) | def get_env_info(self):

FILE: envs/ma_mujoco/multiagent_mujoco/manyagent_ant.py
  class ManyAgentAntEnv (line 7) | class ManyAgentAntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    method __init__ (line 8) | def __init__(self, **kwargs):
    method _generate_asset (line 28) | def _generate_asset(self, n_segs, asset_path):
    method step (line 81) | def step(self, a):
    method _get_obs (line 102) | def _get_obs(self):
    method reset_model (line 109) | def reset_model(self):
    method viewer_setup (line 115) | def viewer_setup(self):

FILE: envs/ma_mujoco/multiagent_mujoco/manyagent_swimmer.py
  class ManyAgentSwimmerEnv (line 7) | class ManyAgentSwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    method __init__ (line 8) | def __init__(self, **kwargs):
    method _generate_asset (line 28) | def _generate_asset(self, n_segs, asset_path):
    method step (line 64) | def step(self, a):
    method _get_obs (line 75) | def _get_obs(self):
    method reset_model (line 80) | def reset_model(self):

FILE: envs/ma_mujoco/multiagent_mujoco/mujoco_multi.py
  function env_fn (line 12) | def env_fn(env, **kwargs) -> MultiAgentEnv: # TODO: this may be a more c...
  class NormalizedActions (line 21) | class NormalizedActions(gym.ActionWrapper):
    method _action (line 23) | def _action(self, action):
    method action (line 29) | def action(self, action_):
    method _reverse_action (line 32) | def _reverse_action(self, action):
  class MujocoMulti (line 39) | class MujocoMulti(MultiAgentEnv):
    method __init__ (line 41) | def __init__(self, batch_size=None, **kwargs):
    method step (line 120) | def step(self, actions):
    method get_obs (line 147) | def get_obs(self):
    method get_obs_agent (line 162) | def get_obs_agent(self, agent_id):
    method get_obs_size (line 178) | def get_obs_size(self):
    method get_state (line 186) | def get_state(self, team=None):
    method get_state_size (line 199) | def get_state_size(self):
    method get_avail_actions (line 203) | def get_avail_actions(self):  # all actions are always available
    method get_avail_agent_actions (line 206) | def get_avail_agent_actions(self, agent_id):
    method get_total_actions (line 210) | def get_total_actions(self):
    method get_stats (line 215) | def get_stats(self):
    method get_agg_stats (line 219) | def get_agg_stats(self, stats):
    method reset (line 222) | def reset(self, **kwargs):
    method render (line 228) | def render(self, **kwargs):
    method close (line 231) | def close(self):
    method seed (line 234) | def seed(self, args):
    method get_env_info (line 237) | def get_env_info(self):

FILE: envs/ma_mujoco/multiagent_mujoco/multiagentenv.py
  function convert (line 5) | def convert(dictionary):
  class MultiAgentEnv (line 8) | class MultiAgentEnv(object):
    method __init__ (line 10) | def __init__(self, batch_size=None, **kwargs):
    method step (line 21) | def step(self, actions):
    method get_obs (line 25) | def get_obs(self):
    method get_obs_agent (line 29) | def get_obs_agent(self, agent_id):
    method get_obs_size (line 33) | def get_obs_size(self):
    method get_state (line 37) | def get_state(self):
    method get_state_size (line 40) | def get_state_size(self):
    method get_avail_actions (line 44) | def get_avail_actions(self):
    method get_avail_agent_actions (line 47) | def get_avail_agent_actions(self, agent_id):
    method get_total_actions (line 51) | def get_total_actions(self):
    method get_stats (line 56) | def get_stats(self):
    method get_agg_stats (line 60) | def get_agg_stats(self, stats):
    method reset (line 63) | def reset(self):
    method render (line 67) | def render(self):
    method close (line 70) | def close(self):
    method seed (line 73) | def seed(self, seed):
    method get_env_info (line 76) | def get_env_info(self):

FILE: envs/ma_mujoco/multiagent_mujoco/obsk.py
  class Node (line 5) | class Node():
    method __init__ (line 6) | def __init__(self, label, qpos_ids, qvel_ids, act_ids, body_fn=None, b...
    method __str__ (line 17) | def __str__(self):
    method __repr__ (line 20) | def __repr__(self):
  class HyperEdge (line 24) | class HyperEdge():
    method __init__ (line 25) | def __init__(self, *edges):
    method __contains__ (line 28) | def __contains__(self, item):
    method __str__ (line 31) | def __str__(self):
    method __repr__ (line 34) | def __repr__(self):
  function get_joints_at_kdist (line 38) | def get_joints_at_kdist(agent_id, agent_partitions, hyperedges, k=0, kag...
  function build_obs (line 74) | def build_obs(env, k_dict, k_categories, global_dict, global_categories,...
  function build_actions (line 141) | def build_actions(agent_partitions, k_dict):
  function get_parts_and_edges (line 146) | def get_parts_and_edges(label, partitioning):

FILE: envs/starcraft2/StarCraft2_Env.py
  class Direction (line 56) | class Direction(enum.IntEnum):
  class StarCraft2Env (line 63) | class StarCraft2Env(MultiAgentEnv):
    method __init__ (line 68) | def __init__(
    method _launch (line 330) | def _launch(self):
    method reset (line 381) | def reset(self):
    method _restart (line 438) | def _restart(self):
    method full_restart (line 449) | def full_restart(self):
    method step (line 455) | def step(self, actions):
    method get_agent_action (line 617) | def get_agent_action(self, a_id, action):
    method get_agent_action_heuristic (line 712) | def get_agent_action_heuristic(self, a_id, action):
    method reward_battle (line 809) | def reward_battle(self):
    method get_total_actions (line 866) | def get_total_actions(self):
    method distance (line 871) | def distance(x1, y1, x2, y2):
    method unit_shoot_range (line 875) | def unit_shoot_range(self, agent_id):
    method unit_sight_range (line 879) | def unit_sight_range(self, agent_id):
    method unit_max_cooldown (line 883) | def unit_max_cooldown(self, unit):
    method save_replay (line 898) | def save_replay(self):
    method unit_max_shield (line 906) | def unit_max_shield(self, unit):
    method can_move (line 915) | def can_move(self, unit, direction):
    method get_surrounding_points (line 933) | def get_surrounding_points(self, unit, include_self=False):
    method check_bounds (line 956) | def check_bounds(self, x, y):
    method get_surrounding_pathing (line 960) | def get_surrounding_pathing(self, unit):
    method get_surrounding_height (line 969) | def get_surrounding_height(self, unit):
    method get_obs_agent (line 978) | def get_obs_agent(self, agent_id):
    method get_obs (line 1144) | def get_obs(self):
    method get_state (line 1152) | def get_state(self, agent_id=-1):
    method get_state_agent (line 1327) | def get_state_agent(self, agent_id):
    method get_obs_enemy_feats_size (line 1522) | def get_obs_enemy_feats_size(self):
    method get_state_enemy_feats_size (line 1533) | def get_state_enemy_feats_size(self):
    method get_obs_ally_feats_size (line 1547) | def get_obs_ally_feats_size(self):
    method get_state_ally_feats_size (line 1561) | def get_state_ally_feats_size(self):
    method get_obs_own_feats_size (line 1578) | def get_obs_own_feats_size(self):
    method get_state_own_feats_size (line 1590) | def get_state_own_feats_size(self):
    method get_obs_move_feats_size (line 1605) | def get_obs_move_feats_size(self):
    method get_state_move_feats_size (line 1615) | def get_state_move_feats_size(self):
    method get_obs_size (line 1625) | def get_obs_size(self):
    method get_state_size (line 1651) | def get_state_size(self):
    method get_visibility_matrix (line 1737) | def get_visibility_matrix(self):
    method get_unit_type_id (line 1778) | def get_unit_type_id(self, unit, ally):
    method get_avail_agent_actions (line 1809) | def get_avail_agent_actions(self, agent_id):
    method get_avail_actions (line 1855) | def get_avail_actions(self):
    method close (line 1863) | def close(self):
    method seed (line 1868) | def seed(self, seed):
    method render (line 1872) | def render(self):
    method _kill_all_units (line 1876) | def _kill_all_units(self):
    method init_units (line 1886) | def init_units(self):
    method update_units (line 1941) | def update_units(self):
    method _init_ally_unit_types (line 1987) | def _init_ally_unit_types(self, min_unit_type):
    method only_medivac_left (line 2017) | def only_medivac_left(self, ally):
    method get_unit_by_id (line 2041) | def get_unit_by_id(self, a_id):
    method get_stats (line 2045) | def get_stats(self):

FILE: envs/starcraft2/multiagentenv.py
  class MultiAgentEnv (line 6) | class MultiAgentEnv(object):
    method step (line 8) | def step(self, actions):
    method get_obs (line 12) | def get_obs(self):
    method get_obs_agent (line 16) | def get_obs_agent(self, agent_id):
    method get_obs_size (line 20) | def get_obs_size(self):
    method get_state (line 24) | def get_state(self):
    method get_state_size (line 28) | def get_state_size(self):
    method get_avail_actions (line 32) | def get_avail_actions(self):
    method get_avail_agent_actions (line 36) | def get_avail_agent_actions(self, agent_id):
    method get_total_actions (line 40) | def get_total_actions(self):
    method reset (line 44) | def reset(self):
    method render (line 48) | def render(self):
    method close (line 51) | def close(self):
    method seed (line 54) | def seed(self):
    method save_replay (line 57) | def save_replay(self):
    method get_env_info (line 61) | def get_env_info(self):

FILE: envs/starcraft2/smac_maps.py
  class SMACMap (line 8) | class SMACMap(lib.Map):
  function get_smac_map_registry (line 448) | def get_smac_map_registry():
  function get_map_params (line 456) | def get_map_params(map_name):

FILE: runners/separated/base_runner.py
  function _t2n (line 11) | def _t2n(x):
  class Runner (line 14) | class Runner(object):
    method __init__ (line 15) | def __init__(self, config):
    method run (line 105) | def run(self):
    method warmup (line 108) | def warmup(self):
    method collect (line 111) | def collect(self, step):
    method insert (line 114) | def insert(self, data):
    method compute (line 118) | def compute(self):
    method train (line 127) | def train(self):
    method save (line 177) | def save(self):
    method restore (line 188) | def restore(self):
    method log_train (line 199) | def log_train(self, train_infos, total_num_steps):
    method log_env (line 205) | def log_env(self, env_infos, total_num_steps):
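
Runner.train() is where the HAPPO/HATRPO sequential update happens: agents are updated one at a time, and each subsequent agent's advantage is reweighted by the importance ratio ("factor") induced by the agents already updated. A simplified sketch of that scheme follows; the trainer method names (action_log_probs, train) are placeholders, while update_factor corresponds to the SeparatedReplayBuffer method listed further down under utils/separated_buffer.py.

import numpy as np

def sequential_update(trainers, buffers, episode_length, n_rollout_threads):
    # Every agent starts from a neutral correction factor of 1.
    factor = np.ones((episode_length, n_rollout_threads, 1), dtype=np.float32)
    for agent_id in np.random.permutation(len(trainers)):
        buffers[agent_id].update_factor(factor)
        # Log-probs of the stored actions under the current (old) policy.
        old_logp = trainers[agent_id].action_log_probs(buffers[agent_id])
        trainers[agent_id].train(buffers[agent_id])
        # Log-probs under the freshly updated policy.
        new_logp = trainers[agent_id].action_log_probs(buffers[agent_id])
        # Later agents see the policy change of earlier agents through
        # the accumulated importance ratio.
        factor = factor * np.exp(new_logp - old_logp)
    return factor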

FILE: runners/separated/mujoco_runner.py
  function _t2n (line 8) | def _t2n(x):
  class MujocoRunner (line 12) | class MujocoRunner(Runner):
    method __init__ (line 15) | def __init__(self, config):
    method run (line 18) | def run(self):
    method warmup (line 89) | def warmup(self):
    method collect (line 101) | def collect(self, step):
    method insert (line 129) | def insert(self, data):
    method log_train (line 157) | def log_train(self, train_infos, total_num_steps):
    method eval (line 166) | def eval(self, total_num_steps):

FILE: runners/separated/smac_runner.py
  function _t2n (line 7) | def _t2n(x):
  class SMACRunner (line 10) | class SMACRunner(Runner):
    method __init__ (line 12) | def __init__(self, config):
    method run (line 15) | def run(self):
    method warmup (line 95) | def warmup(self):
    method collect (line 107) | def collect(self, step):
    method insert (line 136) | def insert(self, data):
    method log_train (line 162) | def log_train(self, train_infos, total_num_steps):
    method eval (line 170) | def eval(self, total_num_steps):

FILE: scripts/train/train_mujoco.py
  function make_train_env (line 17) | def make_train_env(all_args):
  function make_eval_env (line 40) | def make_eval_env(all_args):
  function parse_args (line 63) | def parse_args(args, parser):
  function main (line 86) | def main(args):

FILE: scripts/train/train_smac.py
  function make_train_env (line 17) | def make_train_env(all_args):
  function make_eval_env (line 36) | def make_eval_env(all_args):
  function parse_args (line 55) | def parse_args(args, parser):
  function main (line 73) | def main(args):

FILE: utils/multi_discrete.py
  class MultiDiscrete (line 6) | class MultiDiscrete(gym.Space):
    method __init__ (line 22) | def __init__(self, array_of_param_array):
    method sample (line 28) | def sample(self):
    method contains (line 34) | def contains(self, x):
    method shape (line 38) | def shape(self):
    method __repr__ (line 41) | def __repr__(self):
    method __eq__ (line 44) | def __eq__(self, other):
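
MultiDiscrete here is the old OpenAI Gym space in which each discrete dimension is described by a [min, max] pair (kept locally because newer Gym releases changed the API). A small usage sketch, assuming that old convention:

from utils.multi_discrete import MultiDiscrete

# Three discrete dimensions with 5, 2 and 2 choices respectively.
space = MultiDiscrete([[0, 4], [0, 1], [0, 1]])
action = space.sample()          # e.g. [3, 0, 1]
assert space.contains(action)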

FILE: utils/popart.py
  class PopArt (line 8) | class PopArt(nn.Module):
    method __init__ (line 11) | def __init__(self, input_shape, norm_axes=1, beta=0.99999, per_element...
    method reset_parameters (line 25) | def reset_parameters(self):
    method running_mean_var (line 30) | def running_mean_var(self):
    method forward (line 36) | def forward(self, input_vector, train=True):
    method denormalize (line 64) | def denormalize(self, input_vector):
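
PopArt keeps running statistics of the value targets so that the critic is trained on normalized targets while its outputs can be denormalized for bootstrapping. The module above does this with torch buffers and a debiasing term; the toy numpy version below only illustrates the idea and is not the module's actual code.

import numpy as np

class RunningNormalizer:
    def __init__(self, beta=0.99999, eps=1e-5):
        self.beta, self.eps = beta, eps
        self.mean, self.mean_sq, self.debias = 0.0, 0.0, 0.0

    def update(self, x):
        # Exponential moving averages of the targets and their squares.
        x = np.asarray(x, dtype=np.float64)
        self.mean = self.beta * self.mean + (1.0 - self.beta) * x.mean()
        self.mean_sq = self.beta * self.mean_sq + (1.0 - self.beta) * np.mean(x ** 2)
        self.debias = self.beta * self.debias + (1.0 - self.beta)

    def _stats(self):
        mean = self.mean / max(self.debias, self.eps)
        var = max(self.mean_sq / max(self.debias, self.eps) - mean ** 2, self.eps)
        return mean, var

    def normalize(self, x):
        mean, var = self._stats()
        return (np.asarray(x) - mean) / np.sqrt(var)

    def denormalize(self, x):
        mean, var = self._stats()
        return np.asarray(x) * np.sqrt(var) + mean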

FILE: utils/separated_buffer.py
  function _flatten (line 6) | def _flatten(T, N, x):
  function _cast (line 9) | def _cast(x):
  class SeparatedReplayBuffer (line 12) | class SeparatedReplayBuffer(object):
    method __init__ (line 13) | def __init__(self, args, obs_space, share_obs_space, act_space):
    method update_factor (line 64) | def update_factor(self, factor):
    method insert (line 67) | def insert(self, share_obs, obs, rnn_states, rnn_states_critic, action...
    method chooseinsert (line 87) | def chooseinsert(self, share_obs, obs, rnn_states, rnn_states_critic, ...
    method after_update (line 107) | def after_update(self):
    method chooseafter_update (line 118) | def chooseafter_update(self):
    method compute_returns (line 124) | def compute_returns(self, next_value, value_normalizer=None):
    method feed_forward_generator (line 171) | def feed_forward_generator(self, advantages, num_mini_batch=None, mini...
    method naive_recurrent_generator (line 231) | def naive_recurrent_generator(self, advantages, num_mini_batch):
    method recurrent_generator (line 313) | def recurrent_generator(self, advantages, num_mini_batch, data_chunk_l...
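
compute_returns fills the returns buffer used to form advantages, typically with GAE(lambda). A simplified illustration is given below; the real method additionally supports value normalization (e.g. via PopArt) and extra masking options, so treat this as a sketch rather than the buffer's code.

import numpy as np

def gae_returns(rewards, values, next_value, masks, gamma=0.99, gae_lambda=0.95):
    # rewards, values: shape (T, ...); masks: shape (T + 1, ...), 0 where an
    # episode ended; next_value: bootstrap value for the step after T.
    values = np.concatenate([values, next_value[np.newaxis]], axis=0)
    returns = np.zeros_like(rewards)
    gae = 0.0
    for t in reversed(range(rewards.shape[0])):
        delta = rewards[t] + gamma * values[t + 1] * masks[t + 1] - values[t]
        gae = delta + gamma * gae_lambda * masks[t + 1] * gae
        returns[t] = gae + values[t]
    return returns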

FILE: utils/util.py
  function check (line 5) | def check(input):
  function get_gard_norm (line 9) | def get_gard_norm(it):
  function update_linear_schedule (line 17) | def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
  function huber_loss (line 23) | def huber_loss(e, d):
  function mse_loss (line 28) | def mse_loss(e):
  function get_shape_from_obs_space (line 31) | def get_shape_from_obs_space(obs_space):
  function get_shape_from_act_space (line 40) | def get_shape_from_act_space(act_space):
  function tile_images (line 54) | def tile_images(img_nhwc):
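
Most of these helpers are small utilities shared by the trainers. For instance, update_linear_schedule is the usual linear learning-rate decay, shown here in its standard form (which may differ cosmetically from the file):

def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr):
    # Anneal the learning rate linearly from initial_lr toward 0 over training.
    lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs)))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
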
Condensed preview — 56 files, each listed with its path, character count, and a content snippet (the full structured content runs to roughly 400K characters).
[
  {
    "path": ".gitignore",
    "chars": 102,
    "preview": "*.*~\n__pycache__/\n*.pkl\ndata/\n**/*.egg-info\n.python-version\n.idea/\n.vscode/\n.DS_Store\n_build/\nresults/"
  },
  {
    "path": "LICENSE",
    "chars": 1077,
    "preview": "MIT License\n\nCopyright (c) 2020 Tianshou contributors\n\nPermission is hereby granted, free of charge, to any person obtai"
  },
  {
    "path": "README.md",
    "chars": 2638,
    "preview": "# Trust Region Policy Optimisation in Multi-Agent Reinforcement Learning\nDescribed in the paper \"[Trust Region Policy Op"
  },
  {
    "path": "algorithms/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "algorithms/actor_critic.py",
    "chars": 8575,
    "preview": "import torch\nimport torch.nn as nn\nfrom algorithms.utils.util import init, check\nfrom algorithms.utils.cnn import CNNBas"
  },
  {
    "path": "algorithms/happo_policy.py",
    "chars": 7719,
    "preview": "import torch\nfrom algorithms.actor_critic import Actor, Critic\nfrom utils.util import update_linear_schedule\n\n\nclass HAP"
  },
  {
    "path": "algorithms/happo_trainer.py",
    "chars": 10310,
    "preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nfrom utils.util import get_gard_norm, huber_loss, mse_loss\nfrom ut"
  },
  {
    "path": "algorithms/hatrpo_policy.py",
    "chars": 7756,
    "preview": "import torch\nfrom algorithms.actor_critic import Actor, Critic\nfrom utils.util import update_linear_schedule\n\n\nclass HAT"
  },
  {
    "path": "algorithms/hatrpo_trainer.py",
    "chars": 17736,
    "preview": "import numpy as np\nimport torch\nimport torch.nn as nn\nfrom utils.util import get_gard_norm, huber_loss, mse_loss\nfrom ut"
  },
  {
    "path": "algorithms/utils/act.py",
    "chars": 10957,
    "preview": "from .distributions import Bernoulli, Categorical, DiagGaussian\nimport torch\nimport torch.nn as nn\n\nclass ACTLayer(nn.Mo"
  },
  {
    "path": "algorithms/utils/cnn.py",
    "chars": 1852,
    "preview": "import torch.nn as nn\nfrom .util import init\n\n\"\"\"CNN Modules and utils.\"\"\"\n\nclass Flatten(nn.Module):\n    def forward(se"
  },
  {
    "path": "algorithms/utils/distributions.py",
    "chars": 4540,
    "preview": "import torch\nimport torch.nn as nn\nfrom .util import init\n\n\"\"\"\nModify standard PyTorch distributions so they to make com"
  },
  {
    "path": "algorithms/utils/mlp.py",
    "chars": 2075,
    "preview": "import torch.nn as nn\nfrom .util import init, get_clones\n\n\"\"\"MLP modules.\"\"\"\n\nclass MLPLayer(nn.Module):\n    def __init_"
  },
  {
    "path": "algorithms/utils/rnn.py",
    "chars": 2849,
    "preview": "import torch\nimport torch.nn as nn\n\n\"\"\"RNN modules.\"\"\"\n\n\nclass RNNLayer(nn.Module):\n    def __init__(self, inputs_dim, o"
  },
  {
    "path": "algorithms/utils/util.py",
    "chars": 425,
    "preview": "import copy\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\n\ndef init(module, weight_init, bias_init, gain=1):\n  "
  },
  {
    "path": "configs/config.py",
    "chars": 17183,
    "preview": "import argparse\n\ndef get_config():\n    \"\"\"\n    The configuration parser for common hyperparameters of all environment. \n"
  },
  {
    "path": "envs/__init__.py",
    "chars": 83,
    "preview": "\nimport socket\nfrom absl import flags\nFLAGS = flags.FLAGS\nFLAGS(['train_sc.py'])\n\n\n"
  },
  {
    "path": "envs/env_wrappers.py",
    "chars": 28380,
    "preview": "\"\"\"\nModified from OpenAI Baselines code to work with multi-agent envs\n\"\"\"\nimport numpy as np\nimport torch\nfrom multiproc"
  },
  {
    "path": "envs/ma_mujoco/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/__init__.py",
    "chars": 185,
    "preview": "from .mujoco_multi import MujocoMulti\nfrom .coupled_half_cheetah import CoupledHalfCheetah\nfrom .manyagent_swimmer impor"
  },
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/assets/.gitignore",
    "chars": 11,
    "preview": "*.auto.xml\n"
  },
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/assets/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/assets/coupled_half_cheetah.xml",
    "chars": 8773,
    "preview": "<!-- Cheetah Model\n    The state space is populated with joints in the order that they are\n    defined in this file. The"
  },
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/assets/manyagent_ant.xml",
    "chars": 8333,
    "preview": "<mujoco model=\"ant\">\n  <size nconmax=\"200\"/>\n  <compiler angle=\"degree\" coordinate=\"local\" inertiafromgeom=\"true\"/>\n  <o"
  },
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/assets/manyagent_ant.xml.template",
    "chars": 3075,
    "preview": "<mujoco model=\"ant\">\n  <size nconmax=\"200\"/>\n  <compiler angle=\"degree\" coordinate=\"local\" inertiafromgeom=\"true\"/>\n  <o"
  },
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/assets/manyagent_ant__stage1.xml",
    "chars": 5215,
    "preview": "<mujoco model=\"ant\">\n  <compiler angle=\"degree\" coordinate=\"local\" inertiafromgeom=\"true\"/>\n  <option integrator=\"RK4\" t"
  },
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/assets/manyagent_swimmer.xml.template",
    "chars": 1896,
    "preview": "<mujoco model=\"swimmer\">\n  <compiler angle=\"degree\" coordinate=\"local\" inertiafromgeom=\"true\"/>\n  <option collision=\"pre"
  },
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/assets/manyagent_swimmer__bckp2.xml",
    "chars": 2900,
    "preview": "<mujoco model=\"swimmer\">\n  <compiler angle=\"degree\" coordinate=\"local\" inertiafromgeom=\"true\"/>\n  <option collision=\"pre"
  },
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/assets/manyagent_swimmer_bckp.xml",
    "chars": 2570,
    "preview": "<mujoco model=\"swimmer\">\n  <compiler angle=\"degree\" coordinate=\"local\" inertiafromgeom=\"true\"/>\n  <option collision=\"pre"
  },
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/coupled_half_cheetah.py",
    "chars": 1834,
    "preview": "import numpy as np\nfrom gym import utils\nfrom gym.envs.mujoco import mujoco_env\nimport os\n\n\nclass CoupledHalfCheetah(muj"
  },
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/manyagent_ant.py",
    "chars": 5791,
    "preview": "import numpy as np\nfrom gym import utils\nfrom gym.envs.mujoco import mujoco_env\nfrom jinja2 import Template\nimport os\n\nc"
  },
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/manyagent_swimmer.py",
    "chars": 3735,
    "preview": "import numpy as np\nfrom gym import utils\nfrom gym.envs.mujoco import mujoco_env\nimport os\nfrom jinja2 import Template\n\nc"
  },
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/mujoco_multi.py",
    "chars": 10823,
    "preview": "from functools import partial\nimport gym\nfrom gym.spaces import Box\nfrom gym.wrappers import TimeLimit\nimport numpy as n"
  },
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/multiagentenv.py",
    "chars": 2411,
    "preview": "from collections import namedtuple\nimport numpy as np\n\n\ndef convert(dictionary):\n    return namedtuple('GenericDict', di"
  },
  {
    "path": "envs/ma_mujoco/multiagent_mujoco/obsk.py",
    "chars": 24148,
    "preview": "import itertools\nimport numpy as np\nfrom copy import deepcopy\n\nclass Node():\n    def __init__(self, label, qpos_ids, qve"
  },
  {
    "path": "envs/starcraft2/StarCraft2_Env.py",
    "chars": 80827,
    "preview": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom .mult"
  },
  {
    "path": "envs/starcraft2/multiagentenv.py",
    "chars": 2006,
    "preview": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\n\nclass Mul"
  },
  {
    "path": "envs/starcraft2/smac_maps.py",
    "chars": 10539,
    "preview": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\n\nfrom pysc2"
  },
  {
    "path": "install_sc2.sh",
    "chars": 863,
    "preview": "#!/bin/bash\n# Install SC2 and add the custom maps\n\nif [ -z \"$EXP_DIR\" ]\nthen\n    EXP_DIR=~\nfi\n\necho \"EXP_DIR: $EXP_DIR\"\n"
  },
  {
    "path": "requirements.txt",
    "chars": 2683,
    "preview": "absl-py==0.9.0\naiohttp==3.6.2\naioredis==1.3.1\nastor==0.8.0\nastunparse==1.6.3\nasync-timeout==3.0.1\natari-py==0.2.6\natomic"
  },
  {
    "path": "runners/__init__.py",
    "chars": 59,
    "preview": "from runners import separated\n\n__all__=[\n\n    \"separated\"\n]"
  },
  {
    "path": "runners/separated/__init__.py",
    "chars": 103,
    "preview": "from runners.separated import base_runner,smac_runner\n\n__all__=[\n    \"base_runner\",\n    \"smac_runner\"\n]"
  },
  {
    "path": "runners/separated/base_runner.py",
    "chars": 11706,
    "preview": "    \nimport time\nimport os\nimport numpy as np\nfrom itertools import chain\nimport torch\nfrom tensorboardX import SummaryW"
  },
  {
    "path": "runners/separated/mujoco_runner.py",
    "chars": 10573,
    "preview": "import time\nimport numpy as np\nfrom functools import reduce\nimport torch\nfrom runners.separated.base_runner import Runne"
  },
  {
    "path": "runners/separated/smac_runner.py",
    "chars": 11702,
    "preview": "import time\nimport numpy as np\nfrom functools import reduce\nimport torch\nfrom runners.separated.base_runner import Runne"
  },
  {
    "path": "scripts/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "scripts/train/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "scripts/train/train_mujoco.py",
    "chars": 5971,
    "preview": "#!/usr/bin/env python\nimport sys\nimport os\nsys.path.append(\"../\")\nimport socket\nimport setproctitle\nimport numpy as np\nf"
  },
  {
    "path": "scripts/train/train_smac.py",
    "chars": 5371,
    "preview": "#!/usr/bin/env python\nimport sys\nimport os\nsys.path.append(\"../\")\nimport socket\nimport setproctitle\nimport numpy as np\nf"
  },
  {
    "path": "scripts/train_mujoco.sh",
    "chars": 837,
    "preview": "#!/bin/sh\nenv=\"mujoco\"\nscenario=\"Ant-v2\"\nagent_conf=\"2x4\"\nagent_obsk=2\nalgo=\"happo\"\nexp=\"mlp\"\nrunning_max=20\nkl_threshol"
  },
  {
    "path": "scripts/train_smac.sh",
    "chars": 705,
    "preview": "#!/bin/sh\nenv=\"StarCraft2\"\nmap=\"3s5z\"\nalgo=\"happo\"\nexp=\"mlp\"\nrunning_max=20\nkl_threshold=0.06\necho \"env is ${env}, map i"
  },
  {
    "path": "utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "utils/multi_discrete.py",
    "chars": 2346,
    "preview": "import gym\nimport numpy as np\n\n# An old version of OpenAI Gym's multi_discrete.py. (Was getting affected by Gym updates)"
  },
  {
    "path": "utils/popart.py",
    "chars": 3106,
    "preview": "\nimport numpy as np\n\nimport torch\nimport torch.nn as nn\n\n\nclass PopArt(nn.Module):\n    \"\"\" Normalize a vector of observa"
  },
  {
    "path": "utils/separated_buffer.py",
    "chars": 23539,
    "preview": "import torch\nimport numpy as np\nfrom collections import defaultdict\nfrom utils.util import check, get_shape_from_obs_spa"
  },
  {
    "path": "utils/util.py",
    "chars": 2233,
    "preview": "import numpy as np\nimport math\nimport torch\n\ndef check(input):\n    if type(input) == np.ndarray:\n        return torch.fr"
  }
]

About this extraction

This page contains the full source code of the cyanrain7/TRPO-in-MARL GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 56 files (372.2 KB), approximately 96.4k tokens, and a symbol index with 405 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input.

Extracted by GitExtract, a free GitHub-repo-to-text converter for AI, built by Nikandr Surkov.
