Repository: vwxyzjn/invalid-action-masking
Branch: master
Commit: 6daedd29e4b4
Files: 40
Total size: 361.1 KB

Directory structure:
gitextract_uyz0i984/

├── .gitignore
├── .python-version
├── LICENSE
├── README.MD
├── build.sh
├── gym_vec_api/
│   ├── ppo_multidiscrete.py
│   └── ppo_multidiscrete_mask.py
├── invalid_action_masking/
│   ├── ppo_10x10.py
│   ├── ppo_16x16.py
│   ├── ppo_24x24.py
│   ├── ppo_4x4.py
│   ├── ppo_no_adj_10x10.py
│   ├── ppo_no_adj_16x16.py
│   ├── ppo_no_adj_24x24.py
│   ├── ppo_no_adj_4x4.py
│   ├── ppo_no_mask_10x10.py
│   ├── ppo_no_mask_16x16.py
│   ├── ppo_no_mask_24x24.py
│   └── ppo_no_mask_4x4.py
├── plots/
│   ├── analysis.py
│   ├── approx_kl.py
│   ├── charts_episode_reward/
│   │   ├── all_df_cache.pkl
│   │   ├── data/
│   │   │   ├── MicrortsMining10x10F9-v0.pkl
│   │   │   ├── MicrortsMining16x16F9-v0.pkl
│   │   │   ├── MicrortsMining24x24F9-v0.pkl
│   │   │   └── MicrortsMining4x4F9-v0.pkl
│   │   ├── envs_cache.pkl
│   │   └── exp_names_cache.pkl
│   ├── episode_reward.py
│   └── losses_approx_kl/
│       ├── all_df_cache.pkl
│       ├── data/
│       │   ├── MicrortsMining10x10F9-v0.pkl
│       │   ├── MicrortsMining16x16F9-v0.pkl
│       │   ├── MicrortsMining24x24F9-v0.pkl
│       │   └── MicrortsMining4x4F9-v0.pkl
│       ├── envs_cache.pkl
│       └── exp_names_cache.pkl
├── ppo.py
├── pyproject.toml
├── requirements.txt
└── test.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
**.tfevents.**

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/


================================================
FILE: .python-version
================================================
3.9.5/envs/invalid-action-masking

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2020 neurips2020submission

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.MD
================================================
# A Closer Look at Invalid Action Masking in Policy Gradient Algorithms

This repo contains the source code to reproduce the results in the paper [*A Closer Look at Invalid Action Masking in Policy Gradient Algorithms*](https://arxiv.org/abs/2006.14171). 

## Get started

If you have pyenv or poetry:
```bash
poetry install

rm -rf ~/microrts && mkdir ~/microrts && \
    wget -O ~/microrts/microrts.zip http://microrts.s3.amazonaws.com/microrts/artifacts/202004222224.microrts.zip && \
    unzip ~/microrts/microrts.zip -d ~/microrts/ && \
    rm ~/microrts/microrts.zip
```

Otherwise, install the dependencies with `pip install -r requirements.txt`.

## 10x10 Experiments
```bash
poetry run python invalid_action_masking/ppo_10x10.py
poetry run python invalid_action_masking/ppo_no_adj_10x10.py
poetry run python invalid_action_masking/ppo_no_mask_10x10.py
poetry run python ppo.py # newer & recommended PPO implementation that matches implementation details in `openai/baselines`
```
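
For orientation, the masking idea that `CategoricalMasked` in `gym_vec_api/ppo_multidiscrete_mask.py` applies to each action component can be sketched in plain Python. This is only an illustrative sketch, not the code used for the experiments: the helper names `masked_softmax` and `masked_entropy` are hypothetical, and the `-1e8` fill value is taken from the script. Logits of invalid actions are replaced with a large negative number before the softmax, and the entropy is summed over valid actions only.

```python
import math


def masked_softmax(logits, mask, fill=-1e8):
    """Softmax over `logits`; entries whose mask is False get ~0 probability."""
    # Replace the logits of invalid actions with a large negative number,
    # mirroring the torch.where(...) call in CategoricalMasked.
    masked = [x if ok else fill for x, ok in zip(logits, mask)]
    # Subtract the maximum for numerical stability before exponentiating.
    m = max(masked)
    exps = [math.exp(x - m) for x in masked]
    total = sum(exps)
    return [e / total for e in exps]


def masked_entropy(logits, mask):
    """Entropy of the masked distribution, summed over valid actions only."""
    probs = masked_softmax(logits, mask)
    return -sum(p * math.log(p) for p, ok in zip(probs, mask) if ok and p > 0.0)


# The second action is invalid, so it receives zero probability.
probs = masked_softmax([1.0, 2.0, 3.0], [True, False, True])
print(probs)
```

The remaining probability mass is renormalized over the valid actions, so sampling from the masked distribution should never select an action the environment has marked as invalid.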

## Citation

```bibtex
@inproceedings{huang2020closer,
  author    = {Shengyi Huang and
               Santiago Onta{\~{n}}{\'{o}}n},
  editor    = {Roman Bart{\'{a}}k and
               Fazel Keshtkar and
               Michael Franklin},
  title     = {A Closer Look at Invalid Action Masking in Policy Gradient Algorithms},
  booktitle = {Proceedings of the Thirty-Fifth International Florida Artificial Intelligence
               Research Society Conference, {FLAIRS} 2022, Hutchinson Island, Jensen
               Beach, Florida, USA, May 15-18, 2022},
  year      = {2022},
  url       = {https://doi.org/10.32473/flairs.v35i.130584},
  doi       = {10.32473/flairs.v35i.130584},
  timestamp = {Thu, 09 Jun 2022 16:44:11 +0200},
  biburl    = {https://dblp.org/rec/conf/flairs/HuangO22.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}
```


================================================
FILE: build.sh
================================================
docker build -t invalid_action_masking:latest -f sharedmemory.Dockerfile .


================================================
FILE: gym_vec_api/ppo_multidiscrete.py
================================================
import argparse
import os
import random
import time
from distutils.util import strtobool

import gym
import gym_microrts # fmt: off
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter


def parse_args():
    # fmt: off
    parser = argparse.ArgumentParser()
    # `str.rstrip(".py")` strips a set of characters, not the suffix, so use splitext instead
    parser.add_argument('--exp-name', type=str, default=os.path.splitext(os.path.basename(__file__))[0],
        help='the name of this experiment')
    parser.add_argument('--gym-id', type=str, default="MicrortsMining10x10F9-v0",
        help='the id of the gym environment')
    parser.add_argument('--learning-rate', type=float, default=2.5e-4,
        help='the learning rate of the optimizer')
    parser.add_argument('--seed', type=int, default=1,
        help='seed of the experiment')
    parser.add_argument('--total-timesteps', type=int, default=10000000,
        help='total timesteps of the experiments')
    parser.add_argument('--torch-deterministic', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
        help='if toggled (the default), sets `torch.backends.cudnn.deterministic=True`; pass `False` to disable')
    parser.add_argument('--cuda', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
        help='if toggled, cuda will be enabled by default')
    parser.add_argument('--track', type=lambda x:bool(strtobool(x)), default=False, nargs='?', const=True,
        help='if toggled, this experiment will be tracked with Weights and Biases')
    parser.add_argument('--wandb-project-name', type=str, default="cleanRL",
        help="the wandb's project name")
    parser.add_argument('--wandb-entity', type=str, default=None,
        help="the entity (team) of wandb's project")
    parser.add_argument('--capture-video', type=lambda x:bool(strtobool(x)), default=False, nargs='?', const=True,
        help="whether to capture videos of the agent's performance (check out the `videos` folder)")

    # Algorithm specific arguments
    parser.add_argument('--num-envs', type=int, default=4,
        help='the number of parallel game environments')
    parser.add_argument('--num-steps', type=int, default=128,
        help='the number of steps to run in each environment per policy rollout')
    parser.add_argument('--anneal-lr', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument('--gae', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
        help='Use GAE for advantage computation')
    parser.add_argument('--gamma', type=float, default=0.99,
        help='the discount factor gamma')
    parser.add_argument('--gae-lambda', type=float, default=0.95,
        help='the lambda for generalized advantage estimation')
    parser.add_argument('--num-minibatches', type=int, default=4,
        help='the number of mini-batches')
    parser.add_argument('--update-epochs', type=int, default=4,
        help="the K epochs to update the policy")
    parser.add_argument('--norm-adv', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
        help="Toggles advantages normalization")
    parser.add_argument('--clip-coef', type=float, default=0.1,
        help="the surrogate clipping coefficient")
    parser.add_argument('--clip-vloss', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
        help='Toggles whether or not to use a clipped loss for the value function, as per the paper.')
    parser.add_argument('--ent-coef', type=float, default=0.01,
        help="coefficient of the entropy")
    parser.add_argument('--vf-coef', type=float, default=0.5,
        help="coefficient of the value function")
    parser.add_argument('--max-grad-norm', type=float, default=0.5,
        help='the maximum norm for the gradient clipping')
    parser.add_argument('--target-kl', type=float, default=None,
        help='the target KL divergence threshold')
    args = parser.parse_args()
    args.batch_size = int(args.num_envs * args.num_steps)
    args.minibatch_size = int(args.batch_size // args.num_minibatches)
    # fmt: on
    return args


def make_env(gym_id, seed, idx, capture_video, run_name):
    def thunk():
        env = gym.make(gym_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        if capture_video:
            if idx == 0:
                env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return thunk


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


class Transpose(nn.Module):
    def __init__(self, permutation):
        super().__init__()
        self.permutation = permutation

    def forward(self, x):
        return x.permute(self.permutation)


class Agent(nn.Module):
    def __init__(self, envs):
        super(Agent, self).__init__()
        self.network = nn.Sequential(
            Transpose((0, 3, 1, 2)),
            layer_init(nn.Conv2d(27, 16, kernel_size=3, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(16, 32, kernel_size=2)),
            nn.ReLU(),
            nn.Flatten(),
            layer_init(nn.Linear(32*3*3, 128)),
            nn.ReLU(),
        )
        self.nvec = envs.single_action_space.nvec
        self.actor = layer_init(nn.Linear(128, self.nvec.sum()), std=0.01)
        self.critic = layer_init(nn.Linear(128, 1), std=1)

    def get_value(self, x):
        return self.critic(self.network(x))

    def get_action_and_value(self, x, action=None):
        hidden = self.network(x)
        logits = self.actor(hidden)
        split_logits = torch.split(logits, self.nvec.tolist(), dim=1)
        multi_categoricals = [Categorical(logits=logits) for logits in split_logits]
        if action is None:
            action = torch.stack([categorical.sample() for categorical in multi_categoricals])
        logprob = torch.stack([categorical.log_prob(a) for a, categorical in zip(action, multi_categoricals)])
        entropy = torch.stack([categorical.entropy() for categorical in multi_categoricals])
        return action.T, logprob.sum(0), entropy.sum(0), self.critic(hidden)


if __name__ == "__main__":
    args = parse_args()
    run_name = f"{args.gym_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # env setup
    envs = gym.vector.SyncVectorEnv(
        [make_env(args.gym_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
    )
    assert isinstance(envs.single_action_space, gym.spaces.MultiDiscrete), "only MultiDiscrete action space is supported"

    agent = Agent(envs).to(device)
    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

    # ALGO Logic: Storage setup
    obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
    logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
    values = torch.zeros((args.num_steps, args.num_envs)).to(device)

    # TRY NOT TO MODIFY: start the game
    global_step = 0
    start_time = time.time()
    next_obs = torch.Tensor(envs.reset()).to(device)
    next_done = torch.zeros(args.num_envs).to(device)
    num_updates = args.total_timesteps // args.batch_size

    for update in range(1, num_updates + 1):
        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            frac = 1.0 - (update - 1.0) / num_updates
            lrnow = frac * args.learning_rate
            optimizer.param_groups[0]["lr"] = lrnow

        for step in range(0, args.num_steps):
            global_step += 1 * args.num_envs
            obs[step] = next_obs
            dones[step] = next_done

            # ALGO LOGIC: action logic
            with torch.no_grad():
                action, logprob, _, value = agent.get_action_and_value(next_obs)
                values[step] = value.flatten()
            actions[step] = action
            logprobs[step] = logprob

            # TRY NOT TO MODIFY: execute the game and log data.
            next_obs, reward, done, info = envs.step(action.cpu().numpy())
            rewards[step] = torch.tensor(reward).to(device).view(-1)
            next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)

            for item in info:
                if "episode" in item.keys():
                    print(f"global_step={global_step}, episodic_return={item['episode']['r']}")
                    writer.add_scalar("charts/episodic_return", item["episode"]["r"], global_step)
                    writer.add_scalar("charts/episodic_length", item["episode"]["l"], global_step)
                    break

        # bootstrap value if not done
        with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(1, -1)
            if args.gae:
                advantages = torch.zeros_like(rewards).to(device)
                lastgaelam = 0
                for t in reversed(range(args.num_steps)):
                    if t == args.num_steps - 1:
                        nextnonterminal = 1.0 - next_done
                        nextvalues = next_value
                    else:
                        nextnonterminal = 1.0 - dones[t + 1]
                        nextvalues = values[t + 1]
                    delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                    advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
                returns = advantages + values
            else:
                returns = torch.zeros_like(rewards).to(device)
                for t in reversed(range(args.num_steps)):
                    if t == args.num_steps - 1:
                        nextnonterminal = 1.0 - next_done
                        next_return = next_value
                    else:
                        nextnonterminal = 1.0 - dones[t + 1]
                        next_return = returns[t + 1]
                    returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
                advantages = returns - values

        # flatten the batch
        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
        b_logprobs = logprobs.reshape(-1)
        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)

        # Optimizing the policy and value networks
        b_inds = np.arange(args.batch_size)
        clipfracs = []
        for epoch in range(args.update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, args.batch_size, args.minibatch_size):
                end = start + args.minibatch_size
                mb_inds = b_inds[start:end]

                _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds].T)
                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    # old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]

                mb_advantages = b_advantages[mb_inds]
                if args.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                # Policy loss
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss
                newvalue = newvalue.view(-1)
                if args.clip_vloss:
                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                    v_clipped = b_values[mb_inds] + torch.clamp(
                        newvalue - b_values[mb_inds],
                        -args.clip_coef,
                        args.clip_coef,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

                entropy_loss = entropy.mean()
                loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()

            if args.target_kl is not None:
                if approx_kl > args.target_kl:
                    break

        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
        writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
        writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
        writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
        writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
        writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
        writer.add_scalar("losses/explained_variance", explained_var, global_step)
        print("SPS:", int(global_step / (time.time() - start_time)))
        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

    envs.close()
    writer.close()


================================================
FILE: gym_vec_api/ppo_multidiscrete_mask.py
================================================
import argparse
import os
import random
import time
from distutils.util import strtobool

import gym
import gym_microrts # fmt: off
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter


def parse_args():
    # fmt: off
    parser = argparse.ArgumentParser()
    # `str.rstrip(".py")` strips a set of characters, not the suffix, so use splitext instead
    parser.add_argument('--exp-name', type=str, default=os.path.splitext(os.path.basename(__file__))[0],
        help='the name of this experiment')
    parser.add_argument('--gym-id', type=str, default="MicrortsMining10x10F9-v0",
        help='the id of the gym environment')
    parser.add_argument('--learning-rate', type=float, default=2.5e-4,
        help='the learning rate of the optimizer')
    parser.add_argument('--seed', type=int, default=1,
        help='seed of the experiment')
    parser.add_argument('--total-timesteps', type=int, default=10000000,
        help='total timesteps of the experiments')
    parser.add_argument('--torch-deterministic', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
        help='if toggled (the default), sets `torch.backends.cudnn.deterministic=True`; pass `False` to disable')
    parser.add_argument('--cuda', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
        help='if toggled, cuda will be enabled by default')
    parser.add_argument('--track', type=lambda x:bool(strtobool(x)), default=False, nargs='?', const=True,
        help='if toggled, this experiment will be tracked with Weights and Biases')
    parser.add_argument('--wandb-project-name', type=str, default="cleanRL",
        help="the wandb's project name")
    parser.add_argument('--wandb-entity', type=str, default=None,
        help="the entity (team) of wandb's project")
    parser.add_argument('--capture-video', type=lambda x:bool(strtobool(x)), default=False, nargs='?', const=True,
        help="whether to capture videos of the agent's performance (check out the `videos` folder)")

    # Algorithm specific arguments
    parser.add_argument('--num-envs', type=int, default=4,
        help='the number of parallel game environments')
    parser.add_argument('--num-steps', type=int, default=128,
        help='the number of steps to run in each environment per policy rollout')
    parser.add_argument('--anneal-lr', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument('--gae', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
        help='Use GAE for advantage computation')
    parser.add_argument('--gamma', type=float, default=0.99,
        help='the discount factor gamma')
    parser.add_argument('--gae-lambda', type=float, default=0.95,
        help='the lambda for generalized advantage estimation')
    parser.add_argument('--num-minibatches', type=int, default=4,
        help='the number of mini-batches')
    parser.add_argument('--update-epochs', type=int, default=4,
        help="the K epochs to update the policy")
    parser.add_argument('--norm-adv', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
        help="Toggles advantages normalization")
    parser.add_argument('--clip-coef', type=float, default=0.1,
        help="the surrogate clipping coefficient")
    parser.add_argument('--clip-vloss', type=lambda x:bool(strtobool(x)), default=True, nargs='?', const=True,
        help='Toggles whether or not to use a clipped loss for the value function, as per the paper.')
    parser.add_argument('--ent-coef', type=float, default=0.01,
        help="coefficient of the entropy")
    parser.add_argument('--vf-coef', type=float, default=0.5,
        help="coefficient of the value function")
    parser.add_argument('--max-grad-norm', type=float, default=0.5,
        help='the maximum norm for the gradient clipping')
    parser.add_argument('--target-kl', type=float, default=None,
        help='the target KL divergence threshold')
    args = parser.parse_args()
    args.batch_size = int(args.num_envs * args.num_steps)
    args.minibatch_size = int(args.batch_size // args.num_minibatches)
    # fmt: on
    return args


def make_env(gym_id, seed, idx, capture_video, run_name):
    def thunk():
        env = gym.make(gym_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        if capture_video:
            if idx == 0:
                env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return thunk


def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    torch.nn.init.orthogonal_(layer.weight, std)
    torch.nn.init.constant_(layer.bias, bias_const)
    return layer


class Transpose(nn.Module):
    def __init__(self, permutation):
        super().__init__()
        self.permutation = permutation

    def forward(self, x):
        return x.permute(self.permutation)


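class CategoricalMasked(Categorical):
    """Categorical distribution that gives (numerically) zero probability to masked-out actions."""

    def __init__(self, probs=None, logits=None, validate_args=None, masks=None):
        # NOTE: relies on the module-level `device` defined in the `__main__` block.
        # `None` replaces the original mutable default `[]`; an empty mask means "no masking".
        self.masks = [] if masks is None else masks
        if len(self.masks) == 0:
            super(CategoricalMasked, self).__init__(probs, logits, validate_args)
        else:
            self.masks = masks.type(torch.BoolTensor).to(device)
            # Replace the logits of invalid actions with a large negative number
            # so that the softmax assigns them (numerically) zero probability.
            logits = torch.where(self.masks, logits, torch.tensor(-1e+8).to(device))
            super(CategoricalMasked, self).__init__(probs, logits, validate_args)

    def entropy(self):
        if len(self.masks) == 0:
            return super(CategoricalMasked, self).entropy()
        # `self.logits` holds normalized log-probabilities, so this is p * log(p).
        p_log_p = self.logits * self.probs
        # Only valid actions contribute to the entropy; masked entries are zeroed out.
        p_log_p = torch.where(self.masks, p_log_p, torch.tensor(0.).to(device))
        return -p_log_p.sum(-1)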


class Agent(nn.Module):
    def __init__(self, envs):
        super(Agent, self).__init__()
        self.network = nn.Sequential(
            Transpose((0, 3, 1, 2)),
            layer_init(nn.Conv2d(27, 16, kernel_size=3, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(16, 32, kernel_size=2)),
            nn.ReLU(),
            nn.Flatten(),
            layer_init(nn.Linear(32*3*3, 128)),
            nn.ReLU(),
        )
        self.nvec = envs.single_action_space.nvec
        self.actor = layer_init(nn.Linear(128, self.nvec.sum()), std=0.01)
        self.critic = layer_init(nn.Linear(128, 1), std=1)

    def get_value(self, x):
        return self.critic(self.network(x))

    def get_action_and_value(self, x, action_mask, action=None):
        hidden = self.network(x)
        logits = self.actor(hidden)
        split_logits = torch.split(logits, self.nvec.tolist(), dim=1)
        split_action_masks = torch.split(action_mask, self.nvec.tolist(), dim=1)
        multi_categoricals = [
            CategoricalMasked(logits=logits, masks=iam)
            for (logits, iam) in zip(split_logits, split_action_masks)
        ]
        if action is None:
            action = torch.stack([categorical.sample() for categorical in multi_categoricals])
        logprob = torch.stack([categorical.log_prob(a) for a, categorical in zip(action, multi_categoricals)])
        entropy = torch.stack([categorical.entropy() for categorical in multi_categoricals])
        return action.T, logprob.sum(0), entropy.sum(0), self.critic(hidden)


if __name__ == "__main__":
    args = parse_args()
    run_name = f"{args.gym_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    # env setup
    envs = gym.vector.SyncVectorEnv(
        [make_env(args.gym_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
    )
    assert isinstance(envs.single_action_space, gym.spaces.MultiDiscrete), "only MultiDiscrete action space is supported"

    agent = Agent(envs).to(device)
    optimizer = optim.Adam(agent.parameters(), lr=args.learning_rate, eps=1e-5)

    # ALGO Logic: Storage setup
    obs = torch.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape).to(device)
    actions = torch.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape).to(device)
    logprobs = torch.zeros((args.num_steps, args.num_envs)).to(device)
    rewards = torch.zeros((args.num_steps, args.num_envs)).to(device)
    dones = torch.zeros((args.num_steps, args.num_envs)).to(device)
    values = torch.zeros((args.num_steps, args.num_envs)).to(device)
    action_masks = torch.zeros((args.num_steps, args.num_envs) + (envs.single_action_space.nvec.sum(),)).to(device)

    # TRY NOT TO MODIFY: start the game
    global_step = 0
    start_time = time.time()
    next_obs = torch.Tensor(envs.reset()).to(device)
    next_done = torch.zeros(args.num_envs).to(device)
    num_updates = args.total_timesteps // args.batch_size

    for update in range(1, num_updates + 1):
        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            frac = 1.0 - (update - 1.0) / num_updates
            lrnow = frac * args.learning_rate
            optimizer.param_groups[0]["lr"] = lrnow

        for step in range(0, args.num_steps):
            global_step += 1 * args.num_envs
            obs[step] = next_obs
            dones[step] = next_done
            action_masks[step] = torch.Tensor(
                np.array([env.action_mask for env in envs.envs])
            )

            # ALGO LOGIC: action logic
            with torch.no_grad():
                action, logprob, _, value = agent.get_action_and_value(next_obs, action_masks[step])
                values[step] = value.flatten()
            actions[step] = action
            logprobs[step] = logprob

            # TRY NOT TO MODIFY: execute the game and log data.
            next_obs, reward, done, info = envs.step(action.cpu().numpy())
            rewards[step] = torch.tensor(reward).to(device).view(-1)
            next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device)

            for item in info:
                if "episode" in item.keys():
                    print(f"global_step={global_step}, episodic_return={item['episode']['r']}")
                    writer.add_scalar("charts/episodic_return", item["episode"]["r"], global_step)
                    writer.add_scalar("charts/episodic_length", item["episode"]["l"], global_step)
                    break

        # bootstrap value if not done
        with torch.no_grad():
            next_value = agent.get_value(next_obs).reshape(1, -1)
            if args.gae:
                advantages = torch.zeros_like(rewards).to(device)
                lastgaelam = 0
                for t in reversed(range(args.num_steps)):
                    if t == args.num_steps - 1:
                        nextnonterminal = 1.0 - next_done
                        nextvalues = next_value
                    else:
                        nextnonterminal = 1.0 - dones[t + 1]
                        nextvalues = values[t + 1]
                    delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
                    advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
                returns = advantages + values
            else:
                returns = torch.zeros_like(rewards).to(device)
                for t in reversed(range(args.num_steps)):
                    if t == args.num_steps - 1:
                        nextnonterminal = 1.0 - next_done
                        next_return = next_value
                    else:
                        nextnonterminal = 1.0 - dones[t + 1]
                        next_return = returns[t + 1]
                    returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
                advantages = returns - values

        # flatten the batch
        b_obs = obs.reshape((-1,) + envs.single_observation_space.shape)
        b_logprobs = logprobs.reshape(-1)
        b_actions = actions.reshape((-1,) + envs.single_action_space.shape)
        b_advantages = advantages.reshape(-1)
        b_returns = returns.reshape(-1)
        b_values = values.reshape(-1)
        b_action_masks = action_masks.reshape((-1, action_masks.shape[-1]))

        # Optimizing the policy and value network
        b_inds = np.arange(args.batch_size)
        clipfracs = []
        for epoch in range(args.update_epochs):
            np.random.shuffle(b_inds)
            for start in range(0, args.batch_size, args.minibatch_size):
                end = start + args.minibatch_size
                mb_inds = b_inds[start:end]

                _, newlogprob, entropy, newvalue = agent.get_action_and_value(
                    b_obs[mb_inds],
                    b_action_masks[mb_inds],
                    b_actions.long()[mb_inds].T,
                )
                logratio = newlogprob - b_logprobs[mb_inds]
                ratio = logratio.exp()

                with torch.no_grad():
                    # calculate approx_kl http://joschu.net/blog/kl-approx.html
                    # old_approx_kl = (-logratio).mean()
                    approx_kl = ((ratio - 1) - logratio).mean()
                    clipfracs += [((ratio - 1.0).abs() > args.clip_coef).float().mean().item()]

                mb_advantages = b_advantages[mb_inds]
                if args.norm_adv:
                    mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

                # Policy loss
                pg_loss1 = -mb_advantages * ratio
                pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
                pg_loss = torch.max(pg_loss1, pg_loss2).mean()

                # Value loss
                newvalue = newvalue.view(-1)
                if args.clip_vloss:
                    v_loss_unclipped = (newvalue - b_returns[mb_inds]) ** 2
                    v_clipped = b_values[mb_inds] + torch.clamp(
                        newvalue - b_values[mb_inds],
                        -args.clip_coef,
                        args.clip_coef,
                    )
                    v_loss_clipped = (v_clipped - b_returns[mb_inds]) ** 2
                    v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                    v_loss = 0.5 * v_loss_max.mean()
                else:
                    v_loss = 0.5 * ((newvalue - b_returns[mb_inds]) ** 2).mean()

                entropy_loss = entropy.mean()
                loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
                optimizer.step()

            if args.target_kl is not None:
                if approx_kl > args.target_kl:
                    break

        y_pred, y_true = b_values.cpu().numpy(), b_returns.cpu().numpy()
        var_y = np.var(y_true)
        explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        writer.add_scalar("charts/learning_rate", optimizer.param_groups[0]["lr"], global_step)
        writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
        writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
        writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
        writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
        writer.add_scalar("losses/clipfrac", np.mean(clipfracs), global_step)
        writer.add_scalar("losses/explained_variance", explained_var, global_step)
        print("SPS:", int(global_step / (time.time() - start_time)))
        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

    envs.close()
    writer.close()
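

# --- Illustrative sketch (not part of the original script) ---
# The GAE branch of the bootstrap step above can be restated for a single
# environment with plain Python lists, which makes it easy to check by hand.
# `gae_sketch` and its argument names are hypothetical; the script itself
# operates on (num_steps, num_envs) tensors.
def gae_sketch(rewards, values, dones, next_value, next_done, gamma, gae_lambda):
    """Return (advantages, returns) for one rollout of a single environment.

    dones[t] flags that the observation at step t starts a new episode, which
    is the convention the rollout loop uses when it stores `next_done`.
    """
    num_steps = len(rewards)
    advantages = [0.0] * num_steps
    lastgaelam = 0.0
    for t in reversed(range(num_steps)):
        if t == num_steps - 1:
            nextnonterminal = 1.0 - next_done
            nextvalues = next_value
        else:
            nextnonterminal = 1.0 - dones[t + 1]
            nextvalues = values[t + 1]
        # One-step TD error, then the exponentially weighted sum of TD errors.
        delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
        lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
        advantages[t] = lastgaelam
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns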


================================================
FILE: invalid_action_masking/ppo_10x10.py
================================================
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter

from cleanrl.common import preprocess_obs_space, preprocess_ac_space
import argparse
import numpy as np
import gym
import gym_microrts
from gym.wrappers import TimeLimit, Monitor
from gym.spaces import Discrete, Box, MultiBinary, MultiDiscrete, Space
import time
import random
import os
import pandas as pd

# taken from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py
class RunningMeanStd(object):
    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = np.zeros(shape, 'float64')
        self.var = np.ones(shape, 'float64')
        self.count = epsilon

    def update(self, x):
        batch_mean = np.mean([x], axis=0)
        batch_var = np.var([x], axis=0)
        batch_count = 1
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count)

def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
    delta = batch_mean - mean
    tot_count = count + batch_count

    new_mean = mean + delta * batch_count / tot_count
    m_a = var * count
    m_b = batch_var * batch_count
    M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
    new_var = M2 / tot_count
    new_count = tot_count

    return new_mean, new_var, new_count
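
# --- Illustrative sketch (not part of the original script) ---
# update_mean_var_count_from_moments merges the running moments with a batch's
# moments using the parallel-variance formula. The hypothetical helper below
# restates it for scalars so that it can be checked by hand.
def merge_moments_sketch(mean, var, count, batch_mean, batch_var, batch_count):
    delta = batch_mean - mean
    tot_count = count + batch_count
    new_mean = mean + delta * batch_count / tot_count
    m2 = var * count + batch_var * batch_count + delta ** 2 * count * batch_count / tot_count
    return new_mean, m2 / tot_count, tot_count
# Merging the moments of [1, 3] (mean 2, variance 1) with those of [5, 7]
# (mean 6, variance 1) gives mean 4 and population variance 5, which are the
# moments of [1, 3, 5, 7].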
class NormalizedEnv(gym.core.Wrapper):
    def __init__(self, env, ob=True, ret=True, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8):
        super(NormalizedEnv, self).__init__(env)
        self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
        self.ret_rms = RunningMeanStd(shape=(1,)) if ret else None
        self.clipob = clipob
        self.cliprew = cliprew
        self.ret = np.zeros(())
        self.gamma = gamma
        self.epsilon = epsilon

    def step(self, action):
        obs, rews, news, infos = self.env.step(action)
        infos['real_reward'] = rews
        # print("before", self.ret)
        self.ret = self.ret * self.gamma + rews
        # print("after", self.ret)
        obs = self._obfilt(obs)
        if self.ret_rms:
            self.ret_rms.update(np.array([self.ret].copy()))
            rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew)
        self.ret = self.ret * (1-float(news))
        return obs, rews, news, infos

    def _obfilt(self, obs):
        if self.ob_rms:
            self.ob_rms.update(obs)
            obs = np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob, self.clipob)
            return obs
        else:
            return obs

    def reset(self):
        self.ret = np.zeros(())
        obs = self.env.reset()
        return self._obfilt(obs)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PPO agent')
    # Common arguments
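    # splitext removes only the extension; rstrip(".py") strips any trailing '.', 'p' or 'y' characters
    parser.add_argument('--exp-name', type=str, default=os.path.splitext(os.path.basename(__file__))[0],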
                       help='the name of this experiment')
    parser.add_argument('--gym-id', type=str, default="MicrortsMining10x10F9-v0",
                       help='the id of the gym environment')
    parser.add_argument('--seed', type=int, default=1,
                       help='seed of the experiment')
    parser.add_argument('--episode-length', type=int, default=0,
                       help='the maximum length of each episode')
    parser.add_argument('--total-timesteps', type=int, default=100000,
                       help='total timesteps of the experiments')
    parser.add_argument('--no-torch-deterministic', action='store_false', dest="torch_deterministic", default=True,
                       help='if toggled, `torch.backends.cudnn.deterministic=False`')
    parser.add_argument('--no-cuda', action='store_false', dest="cuda", default=True,
                       help='if toggled, cuda will not be enabled by default')
    parser.add_argument('--prod-mode', action='store_true', default=False,
                       help='run the script in production mode and use wandb to log outputs')
    parser.add_argument('--capture-video', action='store_true', default=False,
                       help='whether to capture videos of the agent performances (check out `videos` folder)')
    parser.add_argument('--wandb-project-name', type=str, default="cleanRL",
                       help="the wandb's project name")
    parser.add_argument('--wandb-entity', type=str, default=None,
                       help="the entity (team) of wandb's project")

    # Algorithm specific arguments
    parser.add_argument('--batch-size', type=int, default=2048,
                       help='the batch size of ppo')
    parser.add_argument('--minibatch-size', type=int, default=256,
                       help='the mini batch size of ppo')
    parser.add_argument('--gamma', type=float, default=0.99,
                       help='the discount factor gamma')
    parser.add_argument('--gae-lambda', type=float, default=0.97,
                       help='the lambda for the generalized advantage estimation')
    parser.add_argument('--ent-coef', type=float, default=0.01,
                       help="coefficient of the entropy")
    parser.add_argument('--max-grad-norm', type=float, default=0.5,
                       help='the maximum norm for the gradient clipping')
    parser.add_argument('--clip-coef', type=float, default=0.2,
                       help="the surrogate clipping coefficient")
    parser.add_argument('--update-epochs', type=int, default=10,
                        help="the K epochs to update the policy")
    parser.add_argument('--kle-stop', action='store_true', default=False,
                        help='If toggled, the policy updates will be early stopped w.r.t target-kl')
    parser.add_argument('--kle-rollback', action='store_true', default=False,
                        help='If toggled, the policy updates will roll back to previous policy if KL exceeds target-kl')
    parser.add_argument('--target-kl', type=float, default=0.015,
                        help='the KL threshold used by --kle-stop and --kle-rollback')
    parser.add_argument('--gae', action='store_true', default=True,
                        help='Use GAE for advantage computation')
    parser.add_argument('--policy-lr', type=float, default=3e-4,
                        help="the learning rate of the policy optimizer")
    parser.add_argument('--value-lr', type=float, default=3e-4,
                        help="the learning rate of the critic optimizer")
    parser.add_argument('--norm-obs', action='store_true', default=False,
                        help="Toggles observation normalization")
    parser.add_argument('--norm-returns', action='store_true', default=False,
                        help="Toggles returns normalization")
    parser.add_argument('--norm-adv', action='store_true', default=False,
                        help="Toggles advantages normalization")
    parser.add_argument('--obs-clip', type=float, default=10.0,
                        help="Value for observation clipping, as per the paper")
    parser.add_argument('--rew-clip', type=float, default=10.0,
                        help="Value for reward clipping, as per the paper")
    parser.add_argument('--anneal-lr', action='store_true', default=False,
                        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument('--weights-init', default="orthogonal", choices=["xavier", 'orthogonal'],
                        help='Selects the scheme to be used for weights initialization')
    parser.add_argument('--clip-vloss', action="store_true", default=False,
                        help='Toggles whether or not to use a clipped loss for the value function, as per the paper.')
    parser.add_argument('--pol-layer-norm', action='store_true', default=False,
                       help='Enables layer normalization in the policy network')
    args = parser.parse_args()
    if not args.seed:
        args.seed = int(time.time())

args.features_turned_on = sum([args.kle_stop, args.kle_rollback, args.gae, args.norm_obs, args.norm_returns, args.norm_adv, args.anneal_lr, args.clip_vloss, args.pol_layer_norm])
# TRY NOT TO MODIFY: setup the environment
experiment_name = f"{args.gym_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
writer = SummaryWriter(f"runs/{experiment_name}")
writer.add_text('hyperparameters', "|param|value|\n|-|-|\n%s" % (
        '\n'.join([f"|{key}|{value}|" for key, value in vars(args).items()])))

if args.prod_mode:
    import wandb
    wandb.init(project=args.wandb_project_name, entity=args.wandb_entity, tensorboard=True, config=vars(args), name=experiment_name, monitor_gym=True)
    writer = SummaryWriter(f"/tmp/{experiment_name}")
    wandb.save(os.path.abspath(__file__))

# TRY NOT TO MODIFY: seeding
device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu')
env = gym.make(args.gym_id)
# respect the default timelimit
assert isinstance(env.action_space, MultiDiscrete), "only MultiDiscrete action space is supported"
assert isinstance(env, TimeLimit) or int(args.episode_length), "the gym env does not have a built in TimeLimit, please specify by using --episode-length"
if isinstance(env, TimeLimit):
    if int(args.episode_length):
        env._max_episode_steps = int(args.episode_length)
    args.episode_length = env._max_episode_steps
else:
    env = TimeLimit(env, int(args.episode_length))
env = NormalizedEnv(env.env, ob=args.norm_obs, ret=args.norm_returns, clipob=args.obs_clip, cliprew=args.rew_clip, gamma=args.gamma)
env = TimeLimit(env, int(args.episode_length))
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic
env.seed(args.seed)
env.action_space.seed(args.seed)
env.observation_space.seed(args.seed)
if args.capture_video:
    env = Monitor(env, f'videos/{experiment_name}')

# ALGO LOGIC: initialize agent here:
class CategoricalMasked(Categorical):
    def __init__(self, probs=None, logits=None, validate_args=None, masks=[]):
        self.masks = masks
        if len(self.masks) == 0:
            super(CategoricalMasked, self).__init__(probs, logits, validate_args)
        else:
            self.masks = masks.type(torch.BoolTensor).to(device)
            logits = torch.where(self.masks, logits, torch.tensor(-1e+8).to(device))
            super(CategoricalMasked, self).__init__(probs, logits, validate_args)
    
    def entropy(self):
        if len(self.masks) == 0:
            return super(CategoricalMasked, self).entropy()
        p_log_p = self.logits * self.probs
        p_log_p = torch.where(self.masks, p_log_p, torch.tensor(0.).to(device))
        return -p_log_p.sum(-1)
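
# --- Illustrative sketch (not part of the original script) ---
# CategoricalMasked replaces the logits of invalid actions with -1e8, so their
# probability underflows to zero after the softmax. The hypothetical helper
# below shows the same effect using only the standard library.
import math

def masked_softmax_sketch(logits, mask, fill=-1e8):
    # Keep the logit where the mask is truthy, otherwise use the large negative fill.
    masked = [l if m else fill for l, m in zip(logits, mask)]
    top = max(masked)  # subtract the maximum for numerical stability
    exps = [math.exp(l - top) for l in masked]
    total = sum(exps)
    return [e / total for e in exps]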

class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(27, 16, kernel_size=3,),
            nn.MaxPool2d(1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3),
            nn.MaxPool2d(1),
            nn.ReLU())
        self.fc = nn.Sequential(
            nn.Linear(32*6*6, 128),
            nn.ReLU(),
            nn.Linear(128, env.action_space.nvec.sum())
        )

    def forward(self, x):
        x = torch.Tensor(np.moveaxis(x, -1, 1)).to(device)
        x = self.features(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x

    def get_action(self, x, action=None, invalid_action_masks=None):
        logits = self.forward(x)
        split_logits = torch.split(logits, env.action_space.nvec.tolist(), dim=1)
        
        if invalid_action_masks is not None:
            split_invalid_action_masks = torch.split(invalid_action_masks, env.action_space.nvec.tolist(), dim=1)
            multi_categoricals = [CategoricalMasked(logits=logits, masks=iam) for (logits, iam) in zip(split_logits, split_invalid_action_masks)]
        else:
            multi_categoricals = [Categorical(logits=logits) for logits in split_logits]
        
        if action is None:
            action = torch.stack([categorical.sample() for categorical in multi_categoricals])
        logprob = torch.stack([categorical.log_prob(a) for a, categorical in zip(action, multi_categoricals)])
        return action, logprob, [], multi_categoricals

class Value(nn.Module):
    def __init__(self):
        super(Value, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(27, 16, kernel_size=3,),
            nn.MaxPool2d(1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3),
            nn.MaxPool2d(1),
            nn.ReLU())
        self.fc = nn.Sequential(
            nn.Linear(32*6*6, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        x = torch.Tensor(np.moveaxis(x, -1, 1)).to(device)
        x = self.features(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x

def discount_cumsum(x, dones, gamma):
    """
    Compute discounted cumulative sums of a vector, resetting at each done.
    input:
        vector x,  vector dones,
        [x0,       [0,
         x1,        0,
         x2,        1,
         x3,        0,
         x4]        0]
    output:
        [x0 + discount * x1 + discount^2 * x2,
         x1 + discount * x2,
         x2,
         x3 + discount * x4,
         x4]
    """
    discount_cumsum = np.zeros_like(x)
    discount_cumsum[-1] = x[-1]
    for t in reversed(range(x.shape[0]-1)):
        discount_cumsum[t] = x[t] + gamma * discount_cumsum[t+1] * (1-dones[t])
    return discount_cumsum
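
# --- Illustrative sketch (not part of the original script) ---
# A list-based restatement of discount_cumsum, useful for checking the
# docstring example by hand. `discount_cumsum_sketch` is a hypothetical name.
def discount_cumsum_sketch(x, dones, gamma):
    out = list(x)
    for t in reversed(range(len(x) - 1)):
        # A done at step t cuts the sum off from everything after it.
        out[t] = x[t] + gamma * out[t + 1] * (1 - dones[t])
    return out
# With x = [1, 1, 1, 1, 1], dones = [0, 0, 1, 0, 0] and gamma = 0.5 this
# returns [1.75, 1.5, 1.0, 1.5, 1]: the sum restarts after the done at index 2.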

pg = Policy().to(device)
vf = Value().to(device)

# MODIFIED: Separate optimizer and learning rates
pg_optimizer = optim.Adam(list(pg.parameters()), lr=args.policy_lr)
v_optimizer = optim.Adam(list(vf.parameters()), lr=args.value_lr)

# MODIFIED: Initialize the learning rate annealing schedulers when needed
if args.anneal_lr:
    anneal_fn = lambda f: max(0, 1-f / args.total_timesteps)
    pg_lr_scheduler = optim.lr_scheduler.LambdaLR(pg_optimizer, lr_lambda=anneal_fn)
    vf_lr_scheduler = optim.lr_scheduler.LambdaLR(v_optimizer, lr_lambda=anneal_fn)

loss_fn = nn.MSELoss()

def evaluate_with_no_mask():
    evaluate_rewards = []
    evaluate_invalid_action_stats = []
    if args.capture_video:
        env.stats_recorder.done=True
    next_obs = np.array(env.reset())

    # ALGO Logic: Storage for epoch data
    obs = np.empty((args.batch_size,) + env.observation_space.shape)

    actions = np.empty((args.batch_size,) + env.action_space.shape)
    logprobs = torch.zeros((env.action_space.nvec.shape[0], args.batch_size,)).to(device)

    rewards = np.zeros((args.batch_size,))
    raw_rewards = np.zeros((len(env.rfs),args.batch_size,))
    
    real_rewards = []
    invalid_action_stats = []

    dones = np.zeros((args.batch_size,))
    values = torch.zeros((args.batch_size,)).to(device)

    # TRY NOT TO MODIFY: prepare the execution of the game.
    for step in range(args.batch_size):
        env.render()
        obs[step] = next_obs.copy()

        # ALGO LOGIC: put action logic here
        with torch.no_grad():
            values[step] = vf.forward(obs[step:step+1])
            action, logproba, _, probs = pg.get_action(obs[step:step+1])
        
        actions[step] = action[:,0].data.cpu().numpy()
        logprobs[:,[step]] = logproba

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards[step], dones[step], info = env.step(action[:,0].data.cpu().numpy())
        raw_rewards[:,step] = info["rewards"]
        real_rewards += [info['real_reward']]
        invalid_action_stats += [info['invalid_action_stats']]
        next_obs = np.array(next_obs)

        if dones[step]:
            # Episode finished: log the summed (undiscounted) real rewards
            writer.add_scalar("charts/episode_reward", np.sum(real_rewards), global_step)
            evaluate_rewards += [np.sum(real_rewards)]
            real_rewards = []
            evaluate_invalid_action_stats += [pd.DataFrame(invalid_action_stats).sum(0)]
            invalid_action_stats = []
            next_obs = np.array(env.reset())
    
    return np.average(evaluate_rewards), pd.DataFrame(evaluate_invalid_action_stats).mean(0)

# TRY NOT TO MODIFY: start the game
global_step = 0
while global_step < args.total_timesteps:
    if args.capture_video:
        env.stats_recorder.done=True
    next_obs = np.array(env.reset())

    # ALGO Logic: Storage for epoch data
    obs = np.empty((args.batch_size,) + env.observation_space.shape)

    actions = np.empty((args.batch_size,) + env.action_space.shape)
    logprobs = torch.zeros((env.action_space.nvec.shape[0], args.batch_size,)).to(device)

    rewards = np.zeros((args.batch_size,))
    raw_rewards = np.zeros((len(env.rfs),args.batch_size,))
    
    real_rewards = []
    invalid_action_stats = []

    dones = np.zeros((args.batch_size,))
    values = torch.zeros((args.batch_size,)).to(device)

    invalid_action_masks = torch.zeros((args.batch_size, env.action_space.nvec.sum()))
    # TRY NOT TO MODIFY: prepare the execution of the game.
    for step in range(args.batch_size):
        env.render()
        global_step += 1
        obs[step] = next_obs.copy()

        # ALGO LOGIC: put action logic here
        invalid_action_mask = torch.ones(env.action_space.nvec.sum())
        invalid_action_mask[0:env.action_space.nvec[0]] = torch.tensor(env.unit_location_mask)
        invalid_action_mask[-env.action_space.nvec[-1]:] = torch.tensor(env.target_unit_location_mask)
        invalid_action_masks[step] = invalid_action_mask
        with torch.no_grad():
            values[step] = vf.forward(obs[step:step+1])
            action, logproba, _, probs = pg.get_action(obs[step:step+1], invalid_action_masks=invalid_action_masks[step:step+1])
        
        actions[step] = action[:,0].data.cpu().numpy()
        logprobs[:,[step]] = logproba

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards[step], dones[step], info = env.step(action[:,0].data.cpu().numpy())
        raw_rewards[:,step] = info["rewards"]
        real_rewards += [info['real_reward']]
        invalid_action_stats += [info['invalid_action_stats']]
        next_obs = np.array(next_obs)

        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            pg_lr_scheduler.step()
            vf_lr_scheduler.step()

        if dones[step]:
            # Episode finished: log the summed (undiscounted) real rewards
            writer.add_scalar("charts/episode_reward", np.sum(real_rewards), global_step)
            print(f"global_step={global_step}, episode_reward={np.sum(real_rewards)}")
            for i in range(len(env.rfs)):
                writer.add_scalar(f"charts/episode_reward/{str(env.rfs[i])}", raw_rewards.sum(1)[i], global_step)
            real_rewards = []
            for key, idx in zip(info['invalid_action_stats'], range(len(info['invalid_action_stats']))):
                writer.add_scalar(f"stats/{key}", pd.DataFrame(invalid_action_stats).sum(0)[idx], global_step)
            invalid_action_stats = []
            next_obs = np.array(env.reset())

    # bootstrap reward if not done. reached the batch limit
    last_value = 0
    if not dones[step]:
        last_value = vf.forward(next_obs.reshape((1,)+next_obs.shape))[0].detach().cpu().numpy()[0]
    bootstrapped_rewards = np.append(rewards, last_value)

    # calculate the returns and advantages
    if args.gae:
        bootstrapped_values = np.append(values.detach().cpu().numpy(), last_value)
        deltas = bootstrapped_rewards[:-1] + args.gamma * bootstrapped_values[1:] * (1-dones) - bootstrapped_values[:-1]
        advantages = discount_cumsum(deltas, dones, args.gamma * args.gae_lambda)
        advantages = torch.Tensor(advantages).to(device)
        returns = advantages + values
    else:
        returns = discount_cumsum(bootstrapped_rewards, dones, args.gamma)[:-1]
        advantages = returns - values.detach().cpu().numpy()
        advantages = torch.Tensor(advantages).to(device)
        returns = torch.Tensor(returns).to(device)

    # Advantage normalization
    if args.norm_adv:
        EPS = 1e-10
        advantages = (advantages - advantages.mean()) / (advantages.std() + EPS)

    # Optimizing policy network
    entropys = []
    target_pg = Policy().to(device)
    inds = np.arange(args.batch_size,)
    for i_epoch_pi in range(args.update_epochs):
        np.random.shuffle(inds)
        for start in range(0, args.batch_size, args.minibatch_size):
            end = start + args.minibatch_size
            minibatch_ind = inds[start:end]
            target_pg.load_state_dict(pg.state_dict())
            
            _, newlogproba, _, _ = pg.get_action(
                obs[minibatch_ind],
                torch.LongTensor(actions[minibatch_ind].astype(np.int64)).to(device).T,
                invalid_action_masks[minibatch_ind])
            ratio = (newlogproba - logprobs[:,minibatch_ind]).exp()

            # Policy loss as in OpenAI SpinUp
            clip_adv = torch.where(advantages[minibatch_ind] > 0,
                                    (1.+args.clip_coef) * advantages[minibatch_ind],
                                    (1.-args.clip_coef) * advantages[minibatch_ind]).to(device)

            # Entropy computation with resampled actions
            entropy = -(newlogproba.exp() * newlogproba).mean()
            entropys.append(entropy.item())

            policy_loss = -torch.min(ratio * advantages[minibatch_ind], clip_adv) + args.ent_coef * entropy
            policy_loss = policy_loss.mean()
            
            pg_optimizer.zero_grad()
            policy_loss.backward()
            nn.utils.clip_grad_norm_(pg.parameters(), args.max_grad_norm)
            pg_optimizer.step()

            approx_kl = (logprobs[:,minibatch_ind] - newlogproba).mean()
            # Optimizing value network
            new_values = vf.forward(obs[minibatch_ind]).view(-1)

            # Value loss clipping
            if args.clip_vloss:
                v_loss_unclipped = ((new_values - returns[minibatch_ind]) ** 2)
                v_clipped = values[minibatch_ind] + torch.clamp(new_values - values[minibatch_ind], -args.clip_coef, args.clip_coef)
                v_loss_clipped = (v_clipped - returns[minibatch_ind])**2
                v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                v_loss = 0.5 * v_loss_max.mean()
            else:
                v_loss = torch.mean((returns[minibatch_ind]- new_values).pow(2))

            v_optimizer.zero_grad()
            v_loss.backward()
            nn.utils.clip_grad_norm_(vf.parameters(), args.max_grad_norm)
            v_optimizer.step()

        if args.kle_stop:
            if approx_kl > args.target_kl:
                break
        if args.kle_rollback:
            if (logprobs[:,minibatch_ind] - 
                pg.get_action(
                    obs[minibatch_ind],
                    torch.LongTensor(actions[minibatch_ind].astype(np.int64)).to(device).T,
                    invalid_action_masks[minibatch_ind])[1]).mean() > args.target_kl:
                pg.load_state_dict(target_pg.state_dict())
                break

    # TRY NOT TO MODIFY: record rewards for plotting purposes
    writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
    writer.add_scalar("charts/policy_learning_rate", pg_optimizer.param_groups[0]['lr'], global_step)
    writer.add_scalar("charts/value_learning_rate", v_optimizer.param_groups[0]['lr'], global_step)
    writer.add_scalar("losses/policy_loss", policy_loss.item(), global_step)
    writer.add_scalar("losses/entropy", np.mean(entropys), global_step)
    writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
    if args.kle_stop or args.kle_rollback:
        writer.add_scalar("debug/pg_stop_iter", i_epoch_pi, global_step)
        
    # evaluate no mask
    average_reward, average_invalid_action_stats = evaluate_with_no_mask()
    writer.add_scalar("evals/charts/episode_reward", average_reward, global_step)
    print(f"global_step={global_step}, eval_reward={average_reward}")
    for key, idx in zip(info['invalid_action_stats'], range(len(info['invalid_action_stats']))):
        writer.add_scalar(f"evals/stats/{key}", average_invalid_action_stats[idx], global_step)

env.close()
writer.close()


================================================
FILE: invalid_action_masking/ppo_16x16.py
================================================
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter

from cleanrl.common import preprocess_obs_space, preprocess_ac_space
import argparse
import numpy as np
import gym
import gym_microrts
from gym.wrappers import TimeLimit, Monitor
from gym.spaces import Discrete, Box, MultiBinary, MultiDiscrete, Space
import time
import random
import os
import pandas as pd

# taken from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py
class RunningMeanStd(object):
    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = np.zeros(shape, 'float64')
        self.var = np.ones(shape, 'float64')
        self.count = epsilon

    def update(self, x):
        batch_mean = np.mean([x], axis=0)
        batch_var = np.var([x], axis=0)
        batch_count = 1
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count)

def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
    delta = batch_mean - mean
    tot_count = count + batch_count

    new_mean = mean + delta * batch_count / tot_count
    m_a = var * count
    m_b = batch_var * batch_count
    M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
    new_var = M2 / tot_count
    new_count = tot_count

    return new_mean, new_var, new_count

class NormalizedEnv(gym.core.Wrapper):
    def __init__(self, env, ob=True, ret=True, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8):
        super(NormalizedEnv, self).__init__(env)
        self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
        self.ret_rms = RunningMeanStd(shape=(1,)) if ret else None
        self.clipob = clipob
        self.cliprew = cliprew
        self.ret = np.zeros(())
        self.gamma = gamma
        self.epsilon = epsilon

    def step(self, action):
        obs, rews, news, infos = self.env.step(action)
        infos['real_reward'] = rews
        # print("before", self.ret)
        self.ret = self.ret * self.gamma + rews
        # print("after", self.ret)
        obs = self._obfilt(obs)
        if self.ret_rms:
            self.ret_rms.update(np.array([self.ret].copy()))
            rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew)
        self.ret = self.ret * (1-float(news))
        return obs, rews, news, infos

    def _obfilt(self, obs):
        if self.ob_rms:
            self.ob_rms.update(obs)
            obs = np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob, self.clipob)
            return obs
        else:
            return obs

    def reset(self):
        self.ret = np.zeros(())
        obs = self.env.reset()
        return self._obfilt(obs)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PPO agent')
    # Common arguments
    parser.add_argument('--exp-name', type=str, default=os.path.splitext(os.path.basename(__file__))[0],
                       help='the name of this experiment')
    parser.add_argument('--gym-id', type=str, default="MicrortsMining16x16F9-v0",
                       help='the id of the gym environment')
    parser.add_argument('--seed', type=int, default=1,
                       help='seed of the experiment')
    parser.add_argument('--episode-length', type=int, default=0,
                       help='the maximum length of each episode')
    parser.add_argument('--total-timesteps', type=int, default=100000,
                       help='total timesteps of the experiments')
    parser.add_argument('--no-torch-deterministic', action='store_false', dest="torch_deterministic", default=True,
                       help='if toggled, `torch.backends.cudnn.deterministic=False`')
    parser.add_argument('--no-cuda', action='store_false', dest="cuda", default=True,
                       help='if toggled, cuda will not be enabled by default')
    parser.add_argument('--prod-mode', action='store_true', default=False,
                       help='run the script in production mode and use wandb to log outputs')
    parser.add_argument('--capture-video', action='store_true', default=False,
                       help='whether to capture videos of the agent performances (check out `videos` folder)')
    parser.add_argument('--wandb-project-name', type=str, default="cleanRL",
                       help="the wandb's project name")
    parser.add_argument('--wandb-entity', type=str, default=None,
                       help="the entity (team) of wandb's project")

    # Algorithm specific arguments
    parser.add_argument('--batch-size', type=int, default=2048,
                       help='the batch size of ppo')
    parser.add_argument('--minibatch-size', type=int, default=256,
                       help='the mini batch size of ppo')
    parser.add_argument('--gamma', type=float, default=0.99,
                       help='the discount factor gamma')
    parser.add_argument('--gae-lambda', type=float, default=0.97,
                       help='the lambda for the generalized advantage estimation')
    parser.add_argument('--ent-coef', type=float, default=0.01,
                       help="coefficient of the entropy")
    parser.add_argument('--max-grad-norm', type=float, default=0.5,
                       help='the maximum norm for the gradient clipping')
    parser.add_argument('--clip-coef', type=float, default=0.2,
                       help="the surrogate clipping coefficient")
    parser.add_argument('--update-epochs', type=int, default=10,
                        help="the K epochs to update the policy")
    parser.add_argument('--kle-stop', action='store_true', default=False,
                        help='If toggled, the policy updates will be early stopped w.r.t target-kl')
    parser.add_argument('--kle-rollback', action='store_true', default=False,
                        help='If toggled, the policy updates will roll back to previous policy if KL exceeds target-kl')
    parser.add_argument('--target-kl', type=float, default=0.015,
                        help='the KL divergence threshold used by --kle-stop and --kle-rollback')
    parser.add_argument('--gae', action='store_true', default=True,
                        help='Use GAE for advantage computation')
    parser.add_argument('--policy-lr', type=float, default=3e-4,
                        help="the learning rate of the policy optimizer")
    parser.add_argument('--value-lr', type=float, default=3e-4,
                        help="the learning rate of the critic optimizer")
    parser.add_argument('--norm-obs', action='store_true', default=True,
                        help="Toggles observation normalization")
    parser.add_argument('--norm-returns', action='store_true', default=False,
                        help="Toggles returns normalization")
    parser.add_argument('--norm-adv', action='store_true', default=True,
                        help="Toggles advantages normalization")
    parser.add_argument('--obs-clip', type=float, default=10.0,
                        help="Value for observation clipping, as per the paper")
    parser.add_argument('--rew-clip', type=float, default=10.0,
                        help="Value for reward clipping, as per the paper")
    parser.add_argument('--anneal-lr', action='store_true', default=True,
                        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument('--weights-init', default="orthogonal", choices=["xavier", 'orthogonal'],
                        help='Selects the scheme to be used for weights initialization')
    parser.add_argument('--clip-vloss', action="store_true", default=True,
                        help='Toggles whether or not to use a clipped loss for the value function, as per the paper.')
    parser.add_argument('--pol-layer-norm', action='store_true', default=False,
                       help='Enables layer normalization in the policy network')
    args = parser.parse_args()
    if not args.seed:
        args.seed = int(time.time())

args.features_turned_on = sum([args.kle_stop, args.kle_rollback, args.gae, args.norm_obs, args.norm_returns, args.norm_adv, args.anneal_lr, args.clip_vloss, args.pol_layer_norm])
# TRY NOT TO MODIFY: setup the environment
experiment_name = f"{args.gym_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
writer = SummaryWriter(f"runs/{experiment_name}")
writer.add_text('hyperparameters', "|param|value|\n|-|-|\n%s" % (
        '\n'.join([f"|{key}|{value}|" for key, value in vars(args).items()])))

if args.prod_mode:
    import wandb
    wandb.init(project=args.wandb_project_name, entity=args.wandb_entity, tensorboard=True, config=vars(args), name=experiment_name, monitor_gym=True)
    writer = SummaryWriter(f"/tmp/{experiment_name}")
    wandb.save(os.path.abspath(__file__))

# TRY NOT TO MODIFY: seeding
device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu')
env = gym.make(args.gym_id)
# respect the default timelimit
assert isinstance(env.action_space, MultiDiscrete), "only MultiDiscrete action space is supported"
assert isinstance(env, TimeLimit) or int(args.episode_length), "the gym env does not have a built in TimeLimit, please specify by using --episode-length"
if isinstance(env, TimeLimit):
    if int(args.episode_length):
        env._max_episode_steps = int(args.episode_length)
    args.episode_length = env._max_episode_steps
else:
    env = TimeLimit(env, int(args.episode_length))
env = NormalizedEnv(env.env, ob=args.norm_obs, ret=args.norm_returns, clipob=args.obs_clip, cliprew=args.rew_clip, gamma=args.gamma)
env = TimeLimit(env, int(args.episode_length))
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic
env.seed(args.seed)
env.action_space.seed(args.seed)
env.observation_space.seed(args.seed)
if args.capture_video:
    env = Monitor(env, f'videos/{experiment_name}')

# ALGO LOGIC: initialize agent here:
class CategoricalMasked(Categorical):
    def __init__(self, probs=None, logits=None, validate_args=None, masks=[]):
        self.masks = masks
        if len(self.masks) == 0:
            super(CategoricalMasked, self).__init__(probs, logits, validate_args)
        else:
            self.masks = masks.type(torch.BoolTensor).to(device)
            logits = torch.where(self.masks, logits, torch.tensor(-1e+8).to(device))
            super(CategoricalMasked, self).__init__(probs, logits, validate_args)
    
    def entropy(self):
        if len(self.masks) == 0:
            return super(CategoricalMasked, self).entropy()
        p_log_p = self.logits * self.probs
        p_log_p = torch.where(self.masks, p_log_p, torch.tensor(0.).to(device))
        return -p_log_p.sum(-1)

class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(27, 16, kernel_size=3),
            nn.MaxPool2d(1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3),
            nn.MaxPool2d(1),
            nn.ReLU())
        self.fc = nn.Sequential(
            nn.Linear(32*12*12, 128),
            nn.ReLU(),
            nn.Linear(128, env.action_space.nvec.sum())
        )

    def forward(self, x):
        x = torch.Tensor(np.moveaxis(x, -1, 1)).to(device)
        x = self.features(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x

    def get_action(self, x, action=None, invalid_action_masks=None):
        logits = self.forward(x)
        split_logits = torch.split(logits, env.action_space.nvec.tolist(), dim=1)
        
        if invalid_action_masks is not None:
            split_invalid_action_masks = torch.split(invalid_action_masks, env.action_space.nvec.tolist(), dim=1)
            multi_categoricals = [CategoricalMasked(logits=logits, masks=iam) for (logits, iam) in zip(split_logits, split_invalid_action_masks)]
        else:
            multi_categoricals = [Categorical(logits=logits) for logits in split_logits]
        
        if action is None:
            action = torch.stack([categorical.sample() for categorical in multi_categoricals])
        logprob = torch.stack([categorical.log_prob(a) for a, categorical in zip(action, multi_categoricals)])
        return action, logprob, [], multi_categoricals

class Value(nn.Module):
    def __init__(self):
        super(Value, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(27, 16, kernel_size=3),
            nn.MaxPool2d(1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3),
            nn.MaxPool2d(1),
            nn.ReLU())
        self.fc = nn.Sequential(
            nn.Linear(32*12*12, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        x = torch.Tensor(np.moveaxis(x, -1, 1)).to(device)
        x = self.features(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x

def discount_cumsum(x, dones, gamma):
    """
    Compute discounted cumulative sums of a vector, resetting at episode
    boundaries marked by `dones`.
    input:
        vector x,  vector dones,
        [x0,       [0,
         x1,        0,
         x2,        1,
         x3,        0,
         x4]        0]
    output:
        [x0 + gamma * x1 + gamma^2 * x2,
         x1 + gamma * x2,
         x2,
         x3 + gamma * x4,
         x4]
    """
    discount_cumsum = np.zeros_like(x)
    discount_cumsum[-1] = x[-1]
    for t in reversed(range(x.shape[0]-1)):
        discount_cumsum[t] = x[t] + gamma * discount_cumsum[t+1] * (1-dones[t])
    return discount_cumsum

pg = Policy().to(device)
vf = Value().to(device)

# MODIFIED: Separate optimizer and learning rates
pg_optimizer = optim.Adam(list(pg.parameters()), lr=args.policy_lr)
v_optimizer = optim.Adam(list(vf.parameters()), lr=args.value_lr)

# MODIFIED: Initialize the learning rate annealing schedulers when needed
if args.anneal_lr:
    anneal_fn = lambda f: max(0, 1-f / args.total_timesteps)
    pg_lr_scheduler = optim.lr_scheduler.LambdaLR(pg_optimizer, lr_lambda=anneal_fn)
    vf_lr_scheduler = optim.lr_scheduler.LambdaLR(v_optimizer, lr_lambda=anneal_fn)

loss_fn = nn.MSELoss()

def evaluate_with_no_mask():
    evaluate_rewards = []
    evaluate_invalid_action_stats = []
    if args.capture_video:
        env.stats_recorder.done=True
    next_obs = np.array(env.reset())

    # ALGO Logic: Storage for epoch data
    obs = np.empty((args.batch_size,) + env.observation_space.shape)

    actions = np.empty((args.batch_size,) + env.action_space.shape)
    logprobs = torch.zeros((env.action_space.nvec.shape[0], args.batch_size,)).to(device)

    rewards = np.zeros((args.batch_size,))
    raw_rewards = np.zeros((len(env.rfs),args.batch_size,))
    
    real_rewards = []
    invalid_action_stats = []

    dones = np.zeros((args.batch_size,))
    values = torch.zeros((args.batch_size,)).to(device)

    # TRY NOT TO MODIFY: prepare the execution of the game.
    for step in range(args.batch_size):
        env.render()
        obs[step] = next_obs.copy()

        # ALGO LOGIC: put action logic here
        with torch.no_grad():
            values[step] = vf.forward(obs[step:step+1])
            action, logproba, _, probs = pg.get_action(obs[step:step+1])
        
        actions[step] = action[:,0].data.cpu().numpy()
        logprobs[:,[step]] = logproba

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards[step], dones[step], info = env.step(action[:,0].data.cpu().numpy())
        raw_rewards[:,step] = info["rewards"]
        real_rewards += [info['real_reward']]
        invalid_action_stats += [info['invalid_action_stats']]
        next_obs = np.array(next_obs)

        if dones[step]:
            # Episode finished: record the undiscounted episode reward
            writer.add_scalar("charts/episode_reward", np.sum(real_rewards), global_step)
            evaluate_rewards += [np.sum(real_rewards)]
            real_rewards = []
            evaluate_invalid_action_stats += [pd.DataFrame(invalid_action_stats).sum(0)]
            invalid_action_stats = []
            next_obs = np.array(env.reset())
    
    return np.average(evaluate_rewards), pd.DataFrame(evaluate_invalid_action_stats).mean(0)

# TRY NOT TO MODIFY: start the game
global_step = 0
while global_step < args.total_timesteps:
    if args.capture_video:
        env.stats_recorder.done=True
    next_obs = np.array(env.reset())

    # ALGO Logic: Storage for epoch data
    obs = np.empty((args.batch_size,) + env.observation_space.shape)

    actions = np.empty((args.batch_size,) + env.action_space.shape)
    logprobs = torch.zeros((env.action_space.nvec.shape[0], args.batch_size,)).to(device)

    rewards = np.zeros((args.batch_size,))
    raw_rewards = np.zeros((len(env.rfs),args.batch_size,))
    
    real_rewards = []
    invalid_action_stats = []

    dones = np.zeros((args.batch_size,))
    values = torch.zeros((args.batch_size,)).to(device)

    invalid_action_masks = torch.zeros((args.batch_size, env.action_space.nvec.sum()))
    # TRY NOT TO MODIFY: prepare the execution of the game.
    for step in range(args.batch_size):
        env.render()
        global_step += 1
        obs[step] = next_obs.copy()

        # ALGO LOGIC: put action logic here
        invalid_action_mask = torch.ones(env.action_space.nvec.sum())
        invalid_action_mask[0:env.action_space.nvec[0]] = torch.tensor(env.unit_location_mask)
        invalid_action_mask[-env.action_space.nvec[-1]:] = torch.tensor(env.target_unit_location_mask)
        invalid_action_masks[step] = invalid_action_mask
        with torch.no_grad():
            values[step] = vf.forward(obs[step:step+1])
            action, logproba, _, probs = pg.get_action(obs[step:step+1], invalid_action_masks=invalid_action_masks[step:step+1])
        
        actions[step] = action[:,0].data.cpu().numpy()
        logprobs[:,[step]] = logproba

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards[step], dones[step], info = env.step(action[:,0].data.cpu().numpy())
        raw_rewards[:,step] = info["rewards"]
        real_rewards += [info['real_reward']]
        invalid_action_stats += [info['invalid_action_stats']]
        next_obs = np.array(next_obs)

        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            pg_lr_scheduler.step()
            vf_lr_scheduler.step()

        if dones[step]:
            # Episode finished: log the undiscounted episode reward
            writer.add_scalar("charts/episode_reward", np.sum(real_rewards), global_step)
            print(f"global_step={global_step}, episode_reward={np.sum(real_rewards)}")
            for i in range(len(env.rfs)):
                writer.add_scalar(f"charts/episode_reward/{str(env.rfs[i])}", raw_rewards.sum(1)[i], global_step)
            real_rewards = []
            for key, idx in zip(info['invalid_action_stats'], range(len(info['invalid_action_stats']))):
                writer.add_scalar(f"stats/{key}", pd.DataFrame(invalid_action_stats).sum(0)[idx], global_step)
            invalid_action_stats = []
            next_obs = np.array(env.reset())

    # bootstrap from the value estimate if the batch limit was reached mid-episode
    last_value = 0
    if not dones[step]:
        last_value = vf.forward(next_obs.reshape((1,)+next_obs.shape))[0].detach().cpu().numpy()[0]
    bootstrapped_rewards = np.append(rewards, last_value)

    # calculate the returns and advantages
    if args.gae:
        bootstrapped_values = np.append(values.detach().cpu().numpy(), last_value)
        deltas = bootstrapped_rewards[:-1] + args.gamma * bootstrapped_values[1:] * (1-dones) - bootstrapped_values[:-1]
        advantages = discount_cumsum(deltas, dones, args.gamma * args.gae_lambda)
        advantages = torch.Tensor(advantages).to(device)
        returns = advantages + values
    else:
        returns = discount_cumsum(bootstrapped_rewards, dones, args.gamma)[:-1]
        advantages = returns - values.detach().cpu().numpy()
        advantages = torch.Tensor(advantages).to(device)
        returns = torch.Tensor(returns).to(device)

    # Advantage normalization
    if args.norm_adv:
        EPS = 1e-10
        advantages = (advantages - advantages.mean()) / (advantages.std() + EPS)

    # Optimizing policy network
    entropys = []
    target_pg = Policy().to(device)
    inds = np.arange(args.batch_size,)
    for i_epoch_pi in range(args.update_epochs):
        np.random.shuffle(inds)
        for start in range(0, args.batch_size, args.minibatch_size):
            end = start + args.minibatch_size
            minibatch_ind = inds[start:end]
            target_pg.load_state_dict(pg.state_dict())
            
            _, newlogproba, _, _ = pg.get_action(
                obs[minibatch_ind],
                torch.LongTensor(actions[minibatch_ind].astype(np.int64)).to(device).T,
                invalid_action_masks[minibatch_ind])
            ratio = (newlogproba - logprobs[:,minibatch_ind]).exp()

            # Policy loss as in OpenAI SpinUp
            clip_adv = torch.where(advantages[minibatch_ind] > 0,
                                    (1.+args.clip_coef) * advantages[minibatch_ind],
                                    (1.-args.clip_coef) * advantages[minibatch_ind]).to(device)

            # Entropy computation with resampled actions
            entropy = -(newlogproba.exp() * newlogproba).mean()
            entropys.append(entropy.item())

            policy_loss = -torch.min(ratio * advantages[minibatch_ind], clip_adv) + args.ent_coef * entropy
            policy_loss = policy_loss.mean()
            
            pg_optimizer.zero_grad()
            policy_loss.backward()
            nn.utils.clip_grad_norm_(pg.parameters(), args.max_grad_norm)
            pg_optimizer.step()

            approx_kl = (logprobs[:,minibatch_ind] - newlogproba).mean()
            # Optimizing value network
            new_values = vf.forward(obs[minibatch_ind]).view(-1)

            # Value loss clipping
            if args.clip_vloss:
                v_loss_unclipped = ((new_values - returns[minibatch_ind]) ** 2)
                v_clipped = values[minibatch_ind] + torch.clamp(new_values - values[minibatch_ind], -args.clip_coef, args.clip_coef)
                v_loss_clipped = (v_clipped - returns[minibatch_ind])**2
                v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                v_loss = 0.5 * v_loss_max.mean()
            else:
                v_loss = torch.mean((returns[minibatch_ind] - new_values).pow(2))

            v_optimizer.zero_grad()
            v_loss.backward()
            nn.utils.clip_grad_norm_(vf.parameters(), args.max_grad_norm)
            v_optimizer.step()

        if args.kle_stop:
            if approx_kl > args.target_kl:
                break
        if args.kle_rollback:
            if (logprobs[:,minibatch_ind] - 
                pg.get_action(
                    obs[minibatch_ind],
                    torch.LongTensor(actions[minibatch_ind].astype(np.int64)).to(device).T,
                    invalid_action_masks[minibatch_ind])[1]).mean() > args.target_kl:
                pg.load_state_dict(target_pg.state_dict())
                break

    # TRY NOT TO MODIFY: record rewards for plotting purposes
    writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
    writer.add_scalar("charts/policy_learning_rate", pg_optimizer.param_groups[0]['lr'], global_step)
    writer.add_scalar("charts/value_learning_rate", v_optimizer.param_groups[0]['lr'], global_step)
    writer.add_scalar("losses/policy_loss", policy_loss.item(), global_step)
    writer.add_scalar("losses/entropy", np.mean(entropys), global_step)
    writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
    if args.kle_stop or args.kle_rollback:
        writer.add_scalar("debug/pg_stop_iter", i_epoch_pi, global_step)
        
    # evaluate the current policy without invalid action masking
    average_reward, average_invalid_action_stats = evaluate_with_no_mask()
    writer.add_scalar("evals/charts/episode_reward", average_reward, global_step)
    print(f"global_step={global_step}, eval_reward={average_reward}")
    for key, idx in zip(info['invalid_action_stats'], range(len(info['invalid_action_stats']))):
        writer.add_scalar(f"evals/stats/{key}", average_invalid_action_stats[idx], global_step)

env.close()
writer.close()


================================================
FILE: invalid_action_masking/ppo_24x24.py
================================================
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter

from cleanrl.common import preprocess_obs_space, preprocess_ac_space
import argparse
import numpy as np
import gym
import gym_microrts
from gym.wrappers import TimeLimit, Monitor
from gym.spaces import Discrete, Box, MultiBinary, MultiDiscrete, Space
import time
import random
import os
import pandas as pd

# taken from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py
class RunningMeanStd(object):
    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = np.zeros(shape, 'float64')
        self.var = np.ones(shape, 'float64')
        self.count = epsilon

    def update(self, x):
        batch_mean = np.mean([x], axis=0)
        batch_var = np.var([x], axis=0)
        batch_count = 1
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count)

def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
    delta = batch_mean - mean
    tot_count = count + batch_count

    new_mean = mean + delta * batch_count / tot_count
    m_a = var * count
    m_b = batch_var * batch_count
    M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
    new_var = M2 / tot_count
    new_count = tot_count

    return new_mean, new_var, new_count

class NormalizedEnv(gym.core.Wrapper):
    def __init__(self, env, ob=True, ret=True, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8):
        super(NormalizedEnv, self).__init__(env)
        self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
        self.ret_rms = RunningMeanStd(shape=(1,)) if ret else None
        self.clipob = clipob
        self.cliprew = cliprew
        self.ret = np.zeros(())
        self.gamma = gamma
        self.epsilon = epsilon

    def step(self, action):
        obs, rews, news, infos = self.env.step(action)
        infos['real_reward'] = rews
        # print("before", self.ret)
        self.ret = self.ret * self.gamma + rews
        # print("after", self.ret)
        obs = self._obfilt(obs)
        if self.ret_rms:
            self.ret_rms.update(np.array([self.ret].copy()))
            rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew)
        self.ret = self.ret * (1-float(news))
        return obs, rews, news, infos

    def _obfilt(self, obs):
        if self.ob_rms:
            self.ob_rms.update(obs)
            obs = np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob, self.clipob)
            return obs
        else:
            return obs

    def reset(self):
        self.ret = np.zeros(())
        obs = self.env.reset()
        return self._obfilt(obs)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PPO agent')
    # Common arguments
    parser.add_argument('--exp-name', type=str, default=os.path.splitext(os.path.basename(__file__))[0],
                       help='the name of this experiment')
    parser.add_argument('--gym-id', type=str, default="MicrortsMining24x24F9-v0",
                       help='the id of the gym environment')
    parser.add_argument('--seed', type=int, default=1,
                       help='seed of the experiment')
    parser.add_argument('--episode-length', type=int, default=0,
                       help='the maximum length of each episode')
    parser.add_argument('--total-timesteps', type=int, default=100000,
                       help='total timesteps of the experiments')
    parser.add_argument('--no-torch-deterministic', action='store_false', dest="torch_deterministic", default=True,
                       help='if toggled, `torch.backends.cudnn.deterministic=False`')
    parser.add_argument('--no-cuda', action='store_false', dest="cuda", default=True,
                       help='if toggled, cuda will not be enabled by default')
    parser.add_argument('--prod-mode', action='store_true', default=False,
                       help='run the script in production mode and use wandb to log outputs')
    parser.add_argument('--capture-video', action='store_true', default=False,
                       help='whether to capture videos of the agent performances (check out `videos` folder)')
    parser.add_argument('--wandb-project-name', type=str, default="cleanRL",
                       help="the wandb's project name")
    parser.add_argument('--wandb-entity', type=str, default=None,
                       help="the entity (team) of wandb's project")

    # Algorithm specific arguments
    parser.add_argument('--batch-size', type=int, default=2048,
                       help='the batch size of ppo')
    parser.add_argument('--minibatch-size', type=int, default=256,
                       help='the mini batch size of ppo')
    parser.add_argument('--gamma', type=float, default=0.99,
                       help='the discount factor gamma')
    parser.add_argument('--gae-lambda', type=float, default=0.97,
                       help='the lambda for the generalized advantage estimation')
    parser.add_argument('--ent-coef', type=float, default=0.01,
                       help="coefficient of the entropy")
    parser.add_argument('--max-grad-norm', type=float, default=0.5,
                       help='the maximum norm for the gradient clipping')
    parser.add_argument('--clip-coef', type=float, default=0.2,
                       help="the surrogate clipping coefficient")
    parser.add_argument('--update-epochs', type=int, default=10,
                        help="the K epochs to update the policy")
    parser.add_argument('--kle-stop', action='store_true', default=False,
                        help='If toggled, the policy updates will be early stopped w.r.t target-kl')
    parser.add_argument('--kle-rollback', action='store_true', default=False,
                        help='If toggled, the policy updates will roll back to previous policy if KL exceeds target-kl')
    parser.add_argument('--target-kl', type=float, default=0.015,
                        help='the KL divergence threshold used by --kle-stop and --kle-rollback')
    parser.add_argument('--gae', action='store_true', default=True,
                        help='Use GAE for advantage computation')
    parser.add_argument('--policy-lr', type=float, default=3e-4,
                        help="the learning rate of the policy optimizer")
    parser.add_argument('--value-lr', type=float, default=3e-4,
                        help="the learning rate of the critic optimizer")
    parser.add_argument('--norm-obs', action='store_true', default=True,
                        help="Toggles observation normalization")
    parser.add_argument('--norm-returns', action='store_true', default=False,
                        help="Toggles returns normalization")
    parser.add_argument('--norm-adv', action='store_true', default=True,
                        help="Toggles advantages normalization")
    parser.add_argument('--obs-clip', type=float, default=10.0,
                        help="Value for observation clipping, as per the paper")
    parser.add_argument('--rew-clip', type=float, default=10.0,
                        help="Value for reward clipping, as per the paper")
    parser.add_argument('--anneal-lr', action='store_true', default=True,
                        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument('--weights-init', default="orthogonal", choices=["xavier", 'orthogonal'],
                        help='Selects the scheme to be used for weights initialization')
    parser.add_argument('--clip-vloss', action="store_true", default=True,
                        help='Toggles whether or not to use a clipped loss for the value function, as per the paper.')
    parser.add_argument('--pol-layer-norm', action='store_true', default=False,
                       help='Enables layer normalization in the policy network')
    args = parser.parse_args()
    if not args.seed:
        args.seed = int(time.time())

args.features_turned_on = sum([args.kle_stop, args.kle_rollback, args.gae, args.norm_obs, args.norm_returns, args.norm_adv, args.anneal_lr, args.clip_vloss, args.pol_layer_norm])
# TRY NOT TO MODIFY: setup the environment
experiment_name = f"{args.gym_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
writer = SummaryWriter(f"runs/{experiment_name}")
writer.add_text('hyperparameters', "|param|value|\n|-|-|\n%s" % (
        '\n'.join([f"|{key}|{value}|" for key, value in vars(args).items()])))

if args.prod_mode:
    import wandb
    wandb.init(project=args.wandb_project_name, entity=args.wandb_entity, tensorboard=True, config=vars(args), name=experiment_name, monitor_gym=True)
    writer = SummaryWriter(f"/tmp/{experiment_name}")
    wandb.save(os.path.abspath(__file__))

# TRY NOT TO MODIFY: seeding
device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu')
env = gym.make(args.gym_id)
# respect the default timelimit
assert isinstance(env.action_space, MultiDiscrete), "only MultiDiscrete action space is supported"
assert isinstance(env, TimeLimit) or int(args.episode_length), "the gym env does not have a built in TimeLimit, please specify by using --episode-length"
if isinstance(env, TimeLimit):
    if int(args.episode_length):
        env._max_episode_steps = int(args.episode_length)
    args.episode_length = env._max_episode_steps
else:
    env = TimeLimit(env, int(args.episode_length))
env = NormalizedEnv(env.env, ob=args.norm_obs, ret=args.norm_returns, clipob=args.obs_clip, cliprew=args.rew_clip, gamma=args.gamma)
env = TimeLimit(env, int(args.episode_length))
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic
env.seed(args.seed)
env.action_space.seed(args.seed)
env.observation_space.seed(args.seed)
if args.capture_video:
    env = Monitor(env, f'videos/{experiment_name}')

# ALGO LOGIC: initialize agent here:
class CategoricalMasked(Categorical):
    def __init__(self, probs=None, logits=None, validate_args=None, masks=[]):
        self.masks = masks
        if len(self.masks) == 0:
            super(CategoricalMasked, self).__init__(probs, logits, validate_args)
        else:
            self.masks = masks.type(torch.BoolTensor).to(device)
            logits = torch.where(self.masks, logits, torch.tensor(-1e+8).to(device))
            super(CategoricalMasked, self).__init__(probs, logits, validate_args)
    
    def entropy(self):
        if len(self.masks) == 0:
            return super(CategoricalMasked, self).entropy()
        p_log_p = self.logits * self.probs
        p_log_p = torch.where(self.masks, p_log_p, torch.tensor(0.).to(device))
        return -p_log_p.sum(-1)

class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(27, 16, kernel_size=3, stride=1),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=2, stride=1),
            nn.MaxPool2d(2),
            nn.ReLU())
        self.fc = nn.Sequential(
            nn.Linear(32*5*5, 128),
            nn.ReLU(),
            nn.Linear(128, env.action_space.nvec.sum())
        )

    def forward(self, x):
        x = torch.Tensor(np.moveaxis(x, -1, 1)).to(device)
        x = self.features(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x

    def get_action(self, x, action=None, invalid_action_masks=None):
        logits = self.forward(x)
        split_logits = torch.split(logits, env.action_space.nvec.tolist(), dim=1)
        
        if invalid_action_masks is not None:
            split_invalid_action_masks = torch.split(invalid_action_masks, env.action_space.nvec.tolist(), dim=1)
            multi_categoricals = [CategoricalMasked(logits=logits, masks=iam) for (logits, iam) in zip(split_logits, split_invalid_action_masks)]
        else:
            multi_categoricals = [Categorical(logits=logits) for logits in split_logits]
        
        if action is None:
            action = torch.stack([categorical.sample() for categorical in multi_categoricals])
        logprob = torch.stack([categorical.log_prob(a) for a, categorical in zip(action, multi_categoricals)])
        return action, logprob, [], multi_categoricals

class Value(nn.Module):
    def __init__(self):
        super(Value, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(27, 16, kernel_size=3, stride=1),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=2, stride=1),
            nn.MaxPool2d(2),
            nn.ReLU())
        self.fc = nn.Sequential(
            nn.Linear(32*5*5, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        x = torch.Tensor(np.moveaxis(x, -1, 1)).to(device)
        x = self.features(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x

def discount_cumsum(x, dones, gamma):
    """
    Computes the discounted cumulative sum of a vector, resetting the sum wherever dones is 1
    input:
        vector x,  vector dones,
        [x0,       [0,
         x1,        0,
         x2         1,
         x3         0, 
         x4]        0]
    output:
        [x0 + discount * x1 + discount^2 * x2,
         x1 + discount * x2,
         x2,
         x3 + discount * x4,
         x4]
    """
    discount_cumsum = np.zeros_like(x)
    discount_cumsum[-1] = x[-1]
    for t in reversed(range(x.shape[0]-1)):
        discount_cumsum[t] = x[t] + gamma * discount_cumsum[t+1] * (1-dones[t])
    return discount_cumsum

pg = Policy().to(device)
vf = Value().to(device)

# MODIFIED: Separate optimizer and learning rates
pg_optimizer = optim.Adam(list(pg.parameters()), lr=args.policy_lr)
v_optimizer = optim.Adam(list(vf.parameters()), lr=args.value_lr)

# MODIFIED: Initializing learning rate anneal scheduler when needed
if args.anneal_lr:
    anneal_fn = lambda f: max(0, 1-f / args.total_timesteps)
    pg_lr_scheduler = optim.lr_scheduler.LambdaLR(pg_optimizer, lr_lambda=anneal_fn)
    vf_lr_scheduler = optim.lr_scheduler.LambdaLR(v_optimizer, lr_lambda=anneal_fn)

loss_fn = nn.MSELoss()

def evaluate_with_no_mask():
    evaluate_rewards = []
    evaluate_invalid_action_stats = []
    if args.capture_video:
        env.stats_recorder.done=True
    next_obs = np.array(env.reset())

    # ALGO Logic: Storage for epoch data
    obs = np.empty((args.batch_size,) + env.observation_space.shape)

    actions = np.empty((args.batch_size,) + env.action_space.shape)
    logprobs = torch.zeros((env.action_space.nvec.shape[0], args.batch_size,)).to(device)

    rewards = np.zeros((args.batch_size,))
    raw_rewards = np.zeros((len(env.rfs),args.batch_size,))
    
    real_rewards = []
    invalid_action_stats = []

    dones = np.zeros((args.batch_size,))
    values = torch.zeros((args.batch_size,)).to(device)

    # TRY NOT TO MODIFY: prepare the execution of the game.
    for step in range(args.batch_size):
        env.render()
        obs[step] = next_obs.copy()

        # ALGO LOGIC: put action logic here
        with torch.no_grad():
            values[step] = vf.forward(obs[step:step+1])
            action, logproba, _, probs = pg.get_action(obs[step:step+1])
        
        actions[step] = action[:,0].data.cpu().numpy()
        logprobs[:,[step]] = logproba

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards[step], dones[step], info = env.step(action[:,0].data.cpu().numpy())
        raw_rewards[:,step] = info["rewards"]
        real_rewards += [info['real_reward']]
        invalid_action_stats += [info['invalid_action_stats']]
        next_obs = np.array(next_obs)

        if dones[step]:
            # Episode finished: log the undiscounted episode reward
            writer.add_scalar("charts/episode_reward", np.sum(real_rewards), global_step)
            evaluate_rewards += [np.sum(real_rewards)]
            real_rewards = []
            evaluate_invalid_action_stats += [pd.DataFrame(invalid_action_stats).sum(0)]
            invalid_action_stats = []
            next_obs = np.array(env.reset())
    
    return np.average(evaluate_rewards), pd.DataFrame(evaluate_invalid_action_stats).mean(0)

# TRY NOT TO MODIFY: start the game
global_step = 0
while global_step < args.total_timesteps:
    if args.capture_video:
        env.stats_recorder.done=True
    next_obs = np.array(env.reset())

    # ALGO Logic: Storage for epoch data
    obs = np.empty((args.batch_size,) + env.observation_space.shape)

    actions = np.empty((args.batch_size,) + env.action_space.shape)
    logprobs = torch.zeros((env.action_space.nvec.shape[0], args.batch_size,)).to(device)

    rewards = np.zeros((args.batch_size,))
    raw_rewards = np.zeros((len(env.rfs),args.batch_size,))
    
    real_rewards = []
    invalid_action_stats = []

    dones = np.zeros((args.batch_size,))
    values = torch.zeros((args.batch_size,)).to(device)

    invalid_action_masks = torch.zeros((args.batch_size, env.action_space.nvec.sum()))
    # TRY NOT TO MODIFY: prepare the execution of the game.
    for step in range(args.batch_size):
        env.render()
        global_step += 1
        obs[step] = next_obs.copy()

        # ALGO LOGIC: put action logic here
        invalid_action_mask = torch.ones(env.action_space.nvec.sum())
        invalid_action_mask[0:env.action_space.nvec[0]] = torch.tensor(env.unit_location_mask)
        invalid_action_mask[-env.action_space.nvec[-1]:] = torch.tensor(env.target_unit_location_mask)
        invalid_action_masks[step] = invalid_action_mask
        with torch.no_grad():
            values[step] = vf.forward(obs[step:step+1])
            action, logproba, _, probs = pg.get_action(obs[step:step+1], invalid_action_masks=invalid_action_masks[step:step+1])
        
        actions[step] = action[:,0].data.cpu().numpy()
        logprobs[:,[step]] = logproba

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards[step], dones[step], info = env.step(action[:,0].data.cpu().numpy())
        raw_rewards[:,step] = info["rewards"]
        real_rewards += [info['real_reward']]
        invalid_action_stats += [info['invalid_action_stats']]
        next_obs = np.array(next_obs)

        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            pg_lr_scheduler.step()
            vf_lr_scheduler.step()

        if dones[step]:
            # Episode finished: log the undiscounted episode reward
            writer.add_scalar("charts/episode_reward", np.sum(real_rewards), global_step)
            print(f"global_step={global_step}, episode_reward={np.sum(real_rewards)}")
            for i in range(len(env.rfs)):
                writer.add_scalar(f"charts/episode_reward/{str(env.rfs[i])}", raw_rewards.sum(1)[i], global_step)
            real_rewards = []
            for key, idx in zip(info['invalid_action_stats'], range(len(info['invalid_action_stats']))):
                writer.add_scalar(f"stats/{key}", pd.DataFrame(invalid_action_stats).sum(0)[idx], global_step)
            invalid_action_stats = []
            next_obs = np.array(env.reset())

    # bootstrap from the value estimate if the batch limit was reached mid-episode
    last_value = 0
    if not dones[step]:
        last_value = vf.forward(next_obs.reshape((1,)+next_obs.shape))[0].detach().cpu().numpy()[0]
    bootstrapped_rewards = np.append(rewards, last_value)

    # calculate the returns and advantages
    if args.gae:
        bootstrapped_values = np.append(values.detach().cpu().numpy(), last_value)
        deltas = bootstrapped_rewards[:-1] + args.gamma * bootstrapped_values[1:] * (1-dones) - bootstrapped_values[:-1]
        advantages = discount_cumsum(deltas, dones, args.gamma * args.gae_lambda)
        advantages = torch.Tensor(advantages).to(device)
        returns = advantages + values
    else:
        returns = discount_cumsum(bootstrapped_rewards, dones, args.gamma)[:-1]
        advantages = returns - values.detach().cpu().numpy()
        advantages = torch.Tensor(advantages).to(device)
        returns = torch.Tensor(returns).to(device)

    # Advantage normalization
    if args.norm_adv:
        EPS = 1e-10
        advantages = (advantages - advantages.mean()) / (advantages.std() + EPS)

    # Optimizing policy network
    entropys = []
    target_pg = Policy().to(device)
    inds = np.arange(args.batch_size,)
    for i_epoch_pi in range(args.update_epochs):
        np.random.shuffle(inds)
        for start in range(0, args.batch_size, args.minibatch_size):
            end = start + args.minibatch_size
            minibatch_ind = inds[start:end]
            target_pg.load_state_dict(pg.state_dict())
            
            _, newlogproba, _, _ = pg.get_action(
                obs[minibatch_ind],
                torch.LongTensor(actions[minibatch_ind].astype(np.int64)).to(device).T,
                invalid_action_masks[minibatch_ind])
            ratio = (newlogproba - logprobs[:,minibatch_ind]).exp()

            # Policy loss as in OpenAI SpinUp
            clip_adv = torch.where(advantages[minibatch_ind] > 0,
                                    (1.+args.clip_coef) * advantages[minibatch_ind],
                                    (1.-args.clip_coef) * advantages[minibatch_ind]).to(device)

            # Entropy computation with resampled actions
            entropy = -(newlogproba.exp() * newlogproba).mean()
            entropys.append(entropy.item())

            policy_loss = -torch.min(ratio * advantages[minibatch_ind], clip_adv) + args.ent_coef * entropy
            policy_loss = policy_loss.mean()
            
            pg_optimizer.zero_grad()
            policy_loss.backward()
            nn.utils.clip_grad_norm_(pg.parameters(), args.max_grad_norm)
            pg_optimizer.step()

            approx_kl = (logprobs[:,minibatch_ind] - newlogproba).mean()
            # Optimizing value network
            new_values = vf.forward(obs[minibatch_ind]).view(-1)

            # Value loss clipping
            if args.clip_vloss:
                v_loss_unclipped = ((new_values - returns[minibatch_ind]) ** 2)
                v_clipped = values[minibatch_ind] + torch.clamp(new_values - values[minibatch_ind], -args.clip_coef, args.clip_coef)
                v_loss_clipped = (v_clipped - returns[minibatch_ind])**2
                v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                v_loss = 0.5 * v_loss_max.mean()
            else:
                v_loss = torch.mean((returns[minibatch_ind]- new_values).pow(2))

            v_optimizer.zero_grad()
            v_loss.backward()
            nn.utils.clip_grad_norm_(vf.parameters(), args.max_grad_norm)
            v_optimizer.step()

        if args.kle_stop:
            if approx_kl > args.target_kl:
                break
        if args.kle_rollback:
            if (logprobs[:,minibatch_ind] - 
                pg.get_action(
                    obs[minibatch_ind],
                    torch.LongTensor(actions[minibatch_ind].astype(np.int64)).to(device).T,
                    invalid_action_masks[minibatch_ind])[1]).mean() > args.target_kl:
                pg.load_state_dict(target_pg.state_dict())
                break

    # TRY NOT TO MODIFY: record rewards for plotting purposes
    writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
    writer.add_scalar("charts/policy_learning_rate", pg_optimizer.param_groups[0]['lr'], global_step)
    writer.add_scalar("charts/value_learning_rate", v_optimizer.param_groups[0]['lr'], global_step)
    writer.add_scalar("losses/policy_loss", policy_loss.item(), global_step)
    writer.add_scalar("losses/entropy", np.mean(entropys), global_step)
    writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
    if args.kle_stop or args.kle_rollback:
        writer.add_scalar("debug/pg_stop_iter", i_epoch_pi, global_step)
        
    # evaluate no mask
    average_reward, average_invalid_action_stats = evaluate_with_no_mask()
    writer.add_scalar("evals/charts/episode_reward", average_reward, global_step)
    print(f"global_step={global_step}, eval_reward={average_reward}")
    for key, idx in zip(info['invalid_action_stats'], range(len(info['invalid_action_stats']))):
        writer.add_scalar(f"evals/stats/{key}", average_invalid_action_stats[idx], global_step)

env.close()
writer.close()


================================================
FILE: invalid_action_masking/ppo_4x4.py
================================================
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter

from cleanrl.common import preprocess_obs_space, preprocess_ac_space
import argparse
import numpy as np
import gym
import gym_microrts
from gym.wrappers import TimeLimit, Monitor
from gym.spaces import Discrete, Box, MultiBinary, MultiDiscrete, Space
import time
import random
import os
import pandas as pd

# taken from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py
class RunningMeanStd(object):
    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = np.zeros(shape, 'float64')
        self.var = np.ones(shape, 'float64')
        self.count = epsilon

    def update(self, x):
        batch_mean = np.mean([x], axis=0)
        batch_var = np.var([x], axis=0)
        batch_count = 1
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count)

def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
    delta = batch_mean - mean
    tot_count = count + batch_count

    new_mean = mean + delta * batch_count / tot_count
    m_a = var * count
    m_b = batch_var * batch_count
    M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
    new_var = M2 / tot_count
    new_count = tot_count

    return new_mean, new_var, new_count

class NormalizedEnv(gym.core.Wrapper):
    def __init__(self, env, ob=True, ret=True, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8):
        super(NormalizedEnv, self).__init__(env)
        self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
        self.ret_rms = RunningMeanStd(shape=(1,)) if ret else None
        self.clipob = clipob
        self.cliprew = cliprew
        self.ret = np.zeros(())
        self.gamma = gamma
        self.epsilon = epsilon

    def step(self, action):
        obs, rews, news, infos = self.env.step(action)
        infos['real_reward'] = rews
        # print("before", self.ret)
        self.ret = self.ret * self.gamma + rews
        # print("after", self.ret)
        obs = self._obfilt(obs)
        if self.ret_rms:
            self.ret_rms.update(np.array([self.ret].copy()))
            rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew)
        self.ret = self.ret * (1-float(news))
        return obs, rews, news, infos

    def _obfilt(self, obs):
        if self.ob_rms:
            self.ob_rms.update(obs)
            obs = np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob, self.clipob)
            return obs
        else:
            return obs

    def reset(self):
        self.ret = np.zeros(())
        obs = self.env.reset()
        return self._obfilt(obs)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PPO agent')
    # Common arguments
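    # splitext removes the ".py" suffix; rstrip(".py") would strip any trailing '.', 'p' or 'y' characters
    parser.add_argument('--exp-name', type=str, default=os.path.splitext(os.path.basename(__file__))[0],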
                       help='the name of this experiment')
    parser.add_argument('--gym-id', type=str, default="MicrortsMining4x4F9-v0",
                       help='the id of the gym environment')
    parser.add_argument('--seed', type=int, default=1,
                       help='seed of the experiment')
    parser.add_argument('--episode-length', type=int, default=0,
                       help='the maximum length of each episode')
    parser.add_argument('--total-timesteps', type=int, default=100000,
                       help='total timesteps of the experiments')
    parser.add_argument('--no-torch-deterministic', action='store_false', dest="torch_deterministic", default=True,
                       help='if toggled, `torch.backends.cudnn.deterministic=False`')
    parser.add_argument('--no-cuda', action='store_false', dest="cuda", default=True,
                       help='if toggled, cuda will not be enabled by default')
    parser.add_argument('--prod-mode', action='store_true', default=False,
                       help='run the script in production mode and use wandb to log outputs')
    parser.add_argument('--capture-video', action='store_true', default=False,
                       help='whether to capture videos of the agent performances (check out `videos` folder)')
    parser.add_argument('--wandb-project-name', type=str, default="cleanRL",
                       help="the wandb's project name")
    parser.add_argument('--wandb-entity', type=str, default=None,
                       help="the entity (team) of wandb's project")

    # Algorithm specific arguments
    parser.add_argument('--batch-size', type=int, default=2048,
                       help='the batch size of ppo')
    parser.add_argument('--minibatch-size', type=int, default=256,
                       help='the mini batch size of ppo')
    parser.add_argument('--gamma', type=float, default=0.99,
                       help='the discount factor gamma')
    parser.add_argument('--gae-lambda', type=float, default=0.97,
                       help='the lambda for the generalized advantage estimation')
    parser.add_argument('--ent-coef', type=float, default=0.01,
                       help="coefficient of the entropy")
    parser.add_argument('--max-grad-norm', type=float, default=0.5,
                       help='the maximum norm for the gradient clipping')
    parser.add_argument('--clip-coef', type=float, default=0.2,
                       help="the surrogate clipping coefficient")
    parser.add_argument('--update-epochs', type=int, default=10,
                        help="the K epochs to update the policy")
    parser.add_argument('--kle-stop', action='store_true', default=False,
                        help='If toggled, the policy updates will be early stopped w.r.t target-kl')
    parser.add_argument('--kle-rollback', action='store_true', default=False,
                        help='If toggled, the policy updates will roll back to previous policy if KL exceeds target-kl')
    parser.add_argument('--target-kl', type=float, default=0.015,
                        help='the target KL divergence threshold used by --kle-stop and --kle-rollback')
    parser.add_argument('--gae', action='store_true', default=True,
                        help='Use GAE for advantage computation')
    parser.add_argument('--policy-lr', type=float, default=3e-4,
                        help="the learning rate of the policy optimizer")
    parser.add_argument('--value-lr', type=float, default=3e-4,
                        help="the learning rate of the critic optimizer")
    parser.add_argument('--norm-obs', action='store_true', default=True,
                        help="Toggles observation normalization")
    parser.add_argument('--norm-returns', action='store_true', default=False,
                        help="Toggles returns normalization")
    parser.add_argument('--norm-adv', action='store_true', default=True,
                        help="Toggles advantages normalization")
    parser.add_argument('--obs-clip', type=float, default=10.0,
                        help="Value for observation clipping, as per the paper")
    parser.add_argument('--rew-clip', type=float, default=10.0,
                        help="Value for reward clipping, as per the paper")
    parser.add_argument('--anneal-lr', action='store_true', default=True,
                        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument('--weights-init', default="orthogonal", choices=["xavier", 'orthogonal'],
                        help='Selects the scheme to be used for weights initialization')
    parser.add_argument('--clip-vloss', action="store_true", default=True,
                        help='Toggles whether or not to use a clipped loss for the value function, as per the paper.')
    parser.add_argument('--pol-layer-norm', action='store_true', default=False,
                       help='Enables layer normalization in the policy network')
    args = parser.parse_args()
    if not args.seed:
        args.seed = int(time.time())

args.features_turned_on = sum([args.kle_stop, args.kle_rollback, args.gae, args.norm_obs, args.norm_returns, args.norm_adv, args.anneal_lr, args.clip_vloss, args.pol_layer_norm])
# TRY NOT TO MODIFY: setup the environment
experiment_name = f"{args.gym_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
writer = SummaryWriter(f"runs/{experiment_name}")
writer.add_text('hyperparameters', "|param|value|\n|-|-|\n%s" % (
        '\n'.join([f"|{key}|{value}|" for key, value in vars(args).items()])))

if args.prod_mode:
    import wandb
    wandb.init(project=args.wandb_project_name, entity=args.wandb_entity, tensorboard=True, config=vars(args), name=experiment_name, monitor_gym=True)
    writer = SummaryWriter(f"/tmp/{experiment_name}")
    wandb.save(os.path.abspath(__file__))

# TRY NOT TO MODIFY: seeding
device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu')
env = gym.make(args.gym_id)
# respect the default timelimit
assert isinstance(env.action_space, MultiDiscrete), "only MultiDiscrete action space is supported"
assert isinstance(env, TimeLimit) or int(args.episode_length), "the gym env does not have a built in TimeLimit, please specify by using --episode-length"
if isinstance(env, TimeLimit):
    if int(args.episode_length):
        env._max_episode_steps = int(args.episode_length)
    args.episode_length = env._max_episode_steps
else:
    env = TimeLimit(env, int(args.episode_length))
env = NormalizedEnv(env.env, ob=args.norm_obs, ret=args.norm_returns, clipob=args.obs_clip, cliprew=args.rew_clip, gamma=args.gamma)
env = TimeLimit(env, int(args.episode_length))
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic
env.seed(args.seed)
env.action_space.seed(args.seed)
env.observation_space.seed(args.seed)
if args.capture_video:
    env = Monitor(env, f'videos/{experiment_name}')

# ALGO LOGIC: initialize agent here:
class CategoricalMasked(Categorical):
    def __init__(self, probs=None, logits=None, validate_args=None, masks=[]):
        self.masks = masks
        if len(self.masks) == 0:
            super(CategoricalMasked, self).__init__(probs, logits, validate_args)
        else:
            self.masks = masks.type(torch.BoolTensor).to(device)
            logits = torch.where(self.masks, logits, torch.tensor(-1e+8).to(device))
            super(CategoricalMasked, self).__init__(probs, logits, validate_args)
    
    def entropy(self):
        if len(self.masks) == 0:
            return super(CategoricalMasked, self).entropy()
        p_log_p = self.logits * self.probs
        p_log_p = torch.where(self.masks, p_log_p, torch.tensor(0.).to(device))
        return -p_log_p.sum(-1)

class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(27, 16, kernel_size=2,),
            nn.MaxPool2d(1),
            nn.ReLU())
        self.fc = nn.Sequential(
            nn.Linear(16*3*3, 128),
            nn.ReLU(),
            nn.Linear(128, env.action_space.nvec.sum())
        )
    def forward(self, x):
        x = torch.Tensor(np.moveaxis(x, -1, 1)).to(device)
        x = self.features(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x

    def get_action(self, x, action=None, invalid_action_masks=None):
        logits = self.forward(x)
        split_logits = torch.split(logits, env.action_space.nvec.tolist(), dim=1)
        
        if invalid_action_masks is not None:
            split_invalid_action_masks = torch.split(invalid_action_masks, env.action_space.nvec.tolist(), dim=1)
            multi_categoricals = [CategoricalMasked(logits=logits, masks=iam) for (logits, iam) in zip(split_logits, split_invalid_action_masks)]
        else:
            multi_categoricals = [Categorical(logits=logits) for logits in split_logits]
        
        if action is None:
            action = torch.stack([categorical.sample() for categorical in multi_categoricals])
        logprob = torch.stack([categorical.log_prob(a) for a, categorical in zip(action, multi_categoricals)])
        return action, logprob, [], multi_categoricals

class Value(nn.Module):
    def __init__(self):
        super(Value, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(27, 16, kernel_size=2,),
            nn.MaxPool2d(1),
            nn.ReLU())
        self.fc = nn.Sequential(
            nn.Linear(16*3*3, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        x = torch.Tensor(np.moveaxis(x, -1, 1)).to(device)
        x = self.features(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x

def discount_cumsum(x, dones, gamma):
    """
    Computes the discounted cumulative sum of a vector, resetting the sum wherever dones is 1
    input:
        vector x,  vector dones,
        [x0,       [0,
         x1,        0,
         x2         1,
         x3         0, 
         x4]        0]
    output:
        [x0 + discount * x1 + discount^2 * x2,
         x1 + discount * x2,
         x2,
         x3 + discount * x4,
         x4]
    """
    discount_cumsum = np.zeros_like(x)
    discount_cumsum[-1] = x[-1]
    for t in reversed(range(x.shape[0]-1)):
        discount_cumsum[t] = x[t] + gamma * discount_cumsum[t+1] * (1-dones[t])
    return discount_cumsum

pg = Policy().to(device)
vf = Value().to(device)

# MODIFIED: Separate optimizer and learning rates
pg_optimizer = optim.Adam(list(pg.parameters()), lr=args.policy_lr)
v_optimizer = optim.Adam(list(vf.parameters()), lr=args.value_lr)

# MODIFIED: Initialize the learning rate annealing schedulers when needed
if args.anneal_lr:
    anneal_fn = lambda f: max(0, 1-f / args.total_timesteps)
    pg_lr_scheduler = optim.lr_scheduler.LambdaLR(pg_optimizer, lr_lambda=anneal_fn)
    vf_lr_scheduler = optim.lr_scheduler.LambdaLR(v_optimizer, lr_lambda=anneal_fn)

loss_fn = nn.MSELoss()

def evaluate_with_no_mask():
    evaluate_rewards = []
    evaluate_invalid_action_stats = []
    if args.capture_video:
        env.stats_recorder.done=True
    next_obs = np.array(env.reset())

    # ALGO Logic: Storage for epoch data
    obs = np.empty((args.batch_size,) + env.observation_space.shape)

    actions = np.empty((args.batch_size,) + env.action_space.shape)
    logprobs = torch.zeros((env.action_space.nvec.shape[0], args.batch_size,)).to(device)

    rewards = np.zeros((args.batch_size,))
    raw_rewards = np.zeros((len(env.rfs),args.batch_size,))
    
    real_rewards = []
    invalid_action_stats = []

    dones = np.zeros((args.batch_size,))
    values = torch.zeros((args.batch_size,)).to(device)

    # TRY NOT TO MODIFY: prepare the execution of the game.
    for step in range(args.batch_size):
        env.render()
        obs[step] = next_obs.copy()

        # ALGO LOGIC: put action logic here
        with torch.no_grad():
            values[step] = vf.forward(obs[step:step+1])
            action, logproba, _, probs = pg.get_action(obs[step:step+1])
        
        actions[step] = action[:,0].data.cpu().numpy()
        logprobs[:,[step]] = logproba

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards[step], dones[step], info = env.step(action[:,0].data.cpu().numpy())
        raw_rewards[:,step] = info["rewards"]
        real_rewards += [info['real_reward']]
        invalid_action_stats += [info['invalid_action_stats']]
        next_obs = np.array(next_obs)

        if dones[step]:
            # Episode finished: log the undiscounted episode reward and reset per-episode stats
            writer.add_scalar("charts/episode_reward", np.sum(real_rewards), global_step)
            evaluate_rewards += [np.sum(real_rewards)]
            real_rewards = []
            evaluate_invalid_action_stats += [pd.DataFrame(invalid_action_stats).sum(0)]
            invalid_action_stats = []
            next_obs = np.array(env.reset())
    
    return np.average(evaluate_rewards), pd.DataFrame(evaluate_invalid_action_stats).mean(0)

# TRY NOT TO MODIFY: start the game
global_step = 0
while global_step < args.total_timesteps:
    if args.capture_video:
        env.stats_recorder.done=True
    next_obs = np.array(env.reset())

    # ALGO Logic: Storage for epoch data
    obs = np.empty((args.batch_size,) + env.observation_space.shape)

    actions = np.empty((args.batch_size,) + env.action_space.shape)
    logprobs = torch.zeros((env.action_space.nvec.shape[0], args.batch_size,)).to(device)

    rewards = np.zeros((args.batch_size,))
    raw_rewards = np.zeros((len(env.rfs),args.batch_size,))
    
    real_rewards = []
    invalid_action_stats = []

    dones = np.zeros((args.batch_size,))
    values = torch.zeros((args.batch_size,)).to(device)

    invalid_action_masks = torch.zeros((args.batch_size, env.action_space.nvec.sum()))
    # TRY NOT TO MODIFY: prepare the execution of the game.
    for step in range(args.batch_size):
        env.render()
        global_step += 1
        obs[step] = next_obs.copy()

        # ALGO LOGIC: put action logic here
        invalid_action_mask = torch.ones(env.action_space.nvec.sum())
        invalid_action_mask[0:env.action_space.nvec[0]] = torch.tensor(env.unit_location_mask)
        invalid_action_mask[-env.action_space.nvec[-1]:] = torch.tensor(env.target_unit_location_mask)
        invalid_action_masks[step] = invalid_action_mask
        with torch.no_grad():
            values[step] = vf.forward(obs[step:step+1])
            action, logproba, _, probs = pg.get_action(obs[step:step+1], invalid_action_masks=invalid_action_masks[step:step+1])
        
        actions[step] = action[:,0].data.cpu().numpy()
        logprobs[:,[step]] = logproba

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards[step], dones[step], info = env.step(action[:,0].data.cpu().numpy())
        raw_rewards[:,step] = info["rewards"]
        real_rewards += [info['real_reward']]
        invalid_action_stats += [info['invalid_action_stats']]
        next_obs = np.array(next_obs)

        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            pg_lr_scheduler.step()
            vf_lr_scheduler.step()

        if dones[step]:
            # Episode finished: log the undiscounted episode reward and reset per-episode stats
            writer.add_scalar("charts/episode_reward", np.sum(real_rewards), global_step)
            print(f"global_step={global_step}, episode_reward={np.sum(real_rewards)}")
            for i in range(len(env.rfs)):
                writer.add_scalar(f"charts/episode_reward/{str(env.rfs[i])}", raw_rewards.sum(1)[i], global_step)
            real_rewards = []
            for key, idx in zip(info['invalid_action_stats'], range(len(info['invalid_action_stats']))):
                writer.add_scalar(f"stats/{key}", pd.DataFrame(invalid_action_stats).sum(0)[idx], global_step)
            invalid_action_stats = []
            next_obs = np.array(env.reset())

    # bootstrap from the value estimate if the batch limit was reached before the episode ended
    last_value = 0
    if not dones[step]:
        last_value = vf.forward(next_obs.reshape((1,)+next_obs.shape))[0].detach().cpu().numpy()[0]
    bootstrapped_rewards = np.append(rewards, last_value)

    # calculate the returns and advantages
    if args.gae:
        bootstrapped_values = np.append(values.detach().cpu().numpy(), last_value)
        deltas = bootstrapped_rewards[:-1] + args.gamma * bootstrapped_values[1:] * (1-dones) - bootstrapped_values[:-1]
        advantages = discount_cumsum(deltas, dones, args.gamma * args.gae_lambda)
        advantages = torch.Tensor(advantages).to(device)
        returns = advantages + values
    else:
        returns = discount_cumsum(bootstrapped_rewards, dones, args.gamma)[:-1]
        advantages = returns - values.detach().cpu().numpy()
        advantages = torch.Tensor(advantages).to(device)
        returns = torch.Tensor(returns).to(device)

    # Advantage normalization
    if args.norm_adv:
        EPS = 1e-10
        advantages = (advantages - advantages.mean()) / (advantages.std() + EPS)

    # Optimizing policy network
    entropys = []
    target_pg = Policy().to(device)
    inds = np.arange(args.batch_size,)
    for i_epoch_pi in range(args.update_epochs):
        np.random.shuffle(inds)
        for start in range(0, args.batch_size, args.minibatch_size):
            end = start + args.minibatch_size
            minibatch_ind = inds[start:end]
            target_pg.load_state_dict(pg.state_dict())
            
            _, newlogproba, _, _ = pg.get_action(
                obs[minibatch_ind],
                torch.LongTensor(actions[minibatch_ind].astype(np.int64)).to(device).T,  # np.int was removed in NumPy 1.24
                invalid_action_masks[minibatch_ind])
            ratio = (newlogproba - logprobs[:,minibatch_ind]).exp()

            # Policy loss as in OpenAI Spinning Up
            clip_adv = torch.where(advantages[minibatch_ind] > 0,
                                    (1.+args.clip_coef) * advantages[minibatch_ind],
                                    (1.-args.clip_coef) * advantages[minibatch_ind]).to(device)

            # Entropy computation with resampled actions
            entropy = -(newlogproba.exp() * newlogproba).mean()
            entropys.append(entropy.item())

            policy_loss = -torch.min(ratio * advantages[minibatch_ind], clip_adv) + args.ent_coef * entropy
            policy_loss = policy_loss.mean()
            
            pg_optimizer.zero_grad()
            policy_loss.backward()
            nn.utils.clip_grad_norm_(pg.parameters(), args.max_grad_norm)
            pg_optimizer.step()

            approx_kl = (logprobs[:,minibatch_ind] - newlogproba).mean()
            # Optimizing value network
            new_values = vf.forward(obs[minibatch_ind]).view(-1)

            # Value loss clipping
            if args.clip_vloss:
                v_loss_unclipped = ((new_values - returns[minibatch_ind]) ** 2)
                v_clipped = values[minibatch_ind] + torch.clamp(new_values - values[minibatch_ind], -args.clip_coef, args.clip_coef)
                v_loss_clipped = (v_clipped - returns[minibatch_ind])**2
                v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                v_loss = 0.5 * v_loss_max.mean()
            else:
                v_loss = torch.mean((returns[minibatch_ind]- new_values).pow(2))

            v_optimizer.zero_grad()
            v_loss.backward()
            nn.utils.clip_grad_norm_(vf.parameters(), args.max_grad_norm)
            v_optimizer.step()

        if args.kle_stop:
            if approx_kl > args.target_kl:
                break
        if args.kle_rollback:
            if (logprobs[:,minibatch_ind] - 
                pg.get_action(
                    obs[minibatch_ind],
                    torch.LongTensor(actions[minibatch_ind].astype(np.int64)).to(device).T,  # np.int was removed in NumPy 1.24
                    invalid_action_masks[minibatch_ind])[1]).mean() > args.target_kl:
                pg.load_state_dict(target_pg.state_dict())
                break

    # TRY NOT TO MODIFY: record rewards for plotting purposes
    writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
    writer.add_scalar("charts/policy_learning_rate", pg_optimizer.param_groups[0]['lr'], global_step)
    writer.add_scalar("charts/value_learning_rate", v_optimizer.param_groups[0]['lr'], global_step)
    writer.add_scalar("losses/policy_loss", policy_loss.item(), global_step)
    writer.add_scalar("losses/entropy", np.mean(entropys), global_step)
    writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
    if args.kle_stop or args.kle_rollback:
        writer.add_scalar("debug/pg_stop_iter", i_epoch_pi, global_step)
        
    # evaluate no mask
    average_reward, average_invalid_action_stats = evaluate_with_no_mask()
    writer.add_scalar("evals/charts/episode_reward", average_reward, global_step)
    print(f"global_step={global_step}, eval_reward={average_reward}")
    for key, idx in zip(info['invalid_action_stats'], range(len(info['invalid_action_stats']))):
        writer.add_scalar(f"evals/stats/{key}", average_invalid_action_stats[idx], global_step)

env.close()
writer.close()


================================================
FILE: invalid_action_masking/ppo_no_adj_10x10.py
================================================
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter

from cleanrl.common import preprocess_obs_space, preprocess_ac_space
import argparse
import numpy as np
import gym
import gym_microrts
from gym.wrappers import TimeLimit, Monitor
from gym.spaces import Discrete, Box, MultiBinary, MultiDiscrete, Space
import time
import random
import os
import pandas as pd

# taken from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py
class RunningMeanStd(object):
    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = np.zeros(shape, 'float64')
        self.var = np.ones(shape, 'float64')
        self.count = epsilon

    def update(self, x):
        batch_mean = np.mean([x], axis=0)
        batch_var = np.var([x], axis=0)
        batch_count = 1
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count)

def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
    delta = batch_mean - mean
    tot_count = count + batch_count

    new_mean = mean + delta * batch_count / tot_count
    m_a = var * count
    m_b = batch_var * batch_count
    M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
    new_var = M2 / tot_count
    new_count = tot_count

    return new_mean, new_var, new_count


class NormalizedEnv(gym.core.Wrapper):
    def __init__(self, env, ob=True, ret=True, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8):
        super(NormalizedEnv, self).__init__(env)
        self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
        self.ret_rms = RunningMeanStd(shape=(1,)) if ret else None
        self.clipob = clipob
        self.cliprew = cliprew
        self.ret = np.zeros(())
        self.gamma = gamma
        self.epsilon = epsilon

    def step(self, action):
        obs, rews, news, infos = self.env.step(action)
        infos['real_reward'] = rews
        # print("before", self.ret)
        self.ret = self.ret * self.gamma + rews
        # print("after", self.ret)
        obs = self._obfilt(obs)
        if self.ret_rms:
            self.ret_rms.update(np.array([self.ret].copy()))
            rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew)
        self.ret = self.ret * (1-float(news))
        return obs, rews, news, infos

    def _obfilt(self, obs):
        if self.ob_rms:
            self.ob_rms.update(obs)
            obs = np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob, self.clipob)
            return obs
        else:
            return obs

    def reset(self):
        self.ret = np.zeros(())
        obs = self.env.reset()
        return self._obfilt(obs)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PPO agent')
    # Common arguments
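    # splitext removes only the extension; rstrip(".py") strips any trailing '.', 'p' or 'y' characters
    parser.add_argument('--exp-name', type=str, default=os.path.splitext(os.path.basename(__file__))[0],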
                       help='the name of this experiment')
    parser.add_argument('--gym-id', type=str, default="MicrortsMining10x10F9-v0",
                       help='the id of the gym environment')
    parser.add_argument('--seed', type=int, default=1,
                       help='seed of the experiment')
    parser.add_argument('--episode-length', type=int, default=0,
                       help='the maximum length of each episode')
    parser.add_argument('--total-timesteps', type=int, default=100000,
                       help='total timesteps of the experiments')
    parser.add_argument('--no-torch-deterministic', action='store_false', dest="torch_deterministic", default=True,
                       help='if toggled, `torch.backends.cudnn.deterministic=False`')
    parser.add_argument('--no-cuda', action='store_false', dest="cuda", default=True,
                       help='if toggled, cuda will not be enabled by default')
    parser.add_argument('--prod-mode', action='store_true', default=False,
                       help='run the script in production mode and use wandb to log outputs')
    parser.add_argument('--capture-video', action='store_true', default=False,
                       help='whether to capture videos of the agent performances (check out `videos` folder)')
    parser.add_argument('--wandb-project-name', type=str, default="cleanRL",
                       help="the wandb's project name")
    parser.add_argument('--wandb-entity', type=str, default=None,
                       help="the entity (team) of wandb's project")

    # Algorithm specific arguments
    parser.add_argument('--batch-size', type=int, default=2048,
                       help='the batch size of ppo')
    parser.add_argument('--minibatch-size', type=int, default=256,
                       help='the mini batch size of ppo')
    parser.add_argument('--gamma', type=float, default=0.99,
                       help='the discount factor gamma')
    parser.add_argument('--gae-lambda', type=float, default=0.97,
                       help='the lambda for the generalized advantage estimation')
    parser.add_argument('--ent-coef', type=float, default=0.01,
                       help="coefficient of the entropy")
    parser.add_argument('--max-grad-norm', type=float, default=0.5,
                       help='the maximum norm for the gradient clipping')
    parser.add_argument('--clip-coef', type=float, default=0.2,
                       help="the surrogate clipping coefficient")
    parser.add_argument('--update-epochs', type=int, default=10,
                        help="the K epochs to update the policy")
    parser.add_argument('--kle-stop', action='store_true', default=False,
                        help='If toggled, the policy updates will be early stopped w.r.t target-kl')
    parser.add_argument('--kle-rollback', action='store_true', default=False,
                        help='If toggled, the policy updates will roll back to previous policy if KL exceeds target-kl')
    parser.add_argument('--target-kl', type=float, default=0.015,
                        help='the target KL divergence threshold used by --kle-stop and --kle-rollback')
    parser.add_argument('--gae', action='store_true', default=True,
                        help='Use GAE for advantage computation')
    parser.add_argument('--policy-lr', type=float, default=3e-4,
                        help="the learning rate of the policy optimizer")
    parser.add_argument('--value-lr', type=float, default=3e-4,
                        help="the learning rate of the critic optimizer")
    parser.add_argument('--norm-obs', action='store_true', default=True,
                        help="Toggles observation normalization")
    parser.add_argument('--norm-returns', action='store_true', default=False,
                        help="Toggles returns normalization")
    parser.add_argument('--norm-adv', action='store_true', default=True,
                        help="Toggles advantages normalization")
    parser.add_argument('--obs-clip', type=float, default=10.0,
                        help="Value for observation clipping, as per the paper")
    parser.add_argument('--rew-clip', type=float, default=10.0,
                        help="Value for reward clipping, as per the paper")
    parser.add_argument('--anneal-lr', action='store_true', default=True,
                        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument('--weights-init', default="orthogonal", choices=["xavier", 'orthogonal'],
                        help='Selects the scheme to be used for weights initialization')
    parser.add_argument('--clip-vloss', action="store_true", default=True,
                        help='Toggles whether or not to use a clipped loss for the value function, as per the paper.')
    parser.add_argument('--pol-layer-norm', action='store_true', default=False,
                       help='Enables layer normalization in the policy network')
    args = parser.parse_args()
    if not args.seed:
        args.seed = int(time.time())

args.features_turned_on = sum([args.kle_stop, args.kle_rollback, args.gae, args.norm_obs, args.norm_returns, args.norm_adv, args.anneal_lr, args.clip_vloss, args.pol_layer_norm])
# TRY NOT TO MODIFY: setup the environment
experiment_name = f"{args.gym_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
writer = SummaryWriter(f"runs/{experiment_name}")
writer.add_text('hyperparameters', "|param|value|\n|-|-|\n%s" % (
        '\n'.join([f"|{key}|{value}|" for key, value in vars(args).items()])))

if args.prod_mode:
    import wandb
    wandb.init(project=args.wandb_project_name, entity=args.wandb_entity, tensorboard=True, config=vars(args), name=experiment_name, monitor_gym=True)
    writer = SummaryWriter(f"/tmp/{experiment_name}")
    wandb.save(os.path.abspath(__file__))

# TRY NOT TO MODIFY: seeding
device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu')
env = gym.make(args.gym_id)
# respect the default timelimit
assert isinstance(env.action_space, MultiDiscrete), "only MultiDiscrete action space is supported"
assert isinstance(env, TimeLimit) or int(args.episode_length), "the gym env does not have a built in TimeLimit, please specify by using --episode-length"
if isinstance(env, TimeLimit):
    if int(args.episode_length):
        env._max_episode_steps = int(args.episode_length)
    args.episode_length = env._max_episode_steps
else:
    env = TimeLimit(env, int(args.episode_length))
env = NormalizedEnv(env.env, ob=args.norm_obs, ret=args.norm_returns, clipob=args.obs_clip, cliprew=args.rew_clip, gamma=args.gamma)
env = TimeLimit(env, int(args.episode_length))
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic
env.seed(args.seed)
env.action_space.seed(args.seed)
env.observation_space.seed(args.seed)
if args.capture_video:
    env = Monitor(env, f'videos/{experiment_name}')

# ALGO LOGIC: initialize agent here:
class CategoricalMasked(Categorical):
    def __init__(self, probs=None, logits=None, validate_args=None, masks=[]):
        self.masks = masks
        if len(self.masks) == 0:
            super(CategoricalMasked, self).__init__(probs, logits, validate_args)
        else:
            self.masks = masks.type(torch.BoolTensor).to(device)
            logits = torch.where(self.masks, logits, torch.tensor(-1e+8).to(device))
            super(CategoricalMasked, self).__init__(probs, logits, validate_args)
    
    def entropy(self):
        if len(self.masks) == 0:
            return super(CategoricalMasked, self).entropy()
        p_log_p = self.logits * self.probs
        p_log_p = torch.where(self.masks, p_log_p, torch.tensor(0.).to(device))
        return -p_log_p.sum(-1)

class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(27, 16, kernel_size=3,),
            nn.MaxPool2d(1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3),
            nn.MaxPool2d(1),
            nn.ReLU())
        self.fc = nn.Sequential(
            nn.Linear(32*6*6, 128),
            nn.ReLU(),
            nn.Linear(128, env.action_space.nvec.sum())
        )

    def forward(self, x):
        x = torch.Tensor(np.moveaxis(x, -1, 1)).to(device)
        x = self.features(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x

    def get_action(self, x, action=None, invalid_action_masks=None):
        logits = self.forward(x)
        split_logits = torch.split(logits, env.action_space.nvec.tolist(), dim=1)
        
        if invalid_action_masks is not None:
            split_invalid_action_masks = torch.split(invalid_action_masks, env.action_space.nvec.tolist(), dim=1)
            multi_categoricals = [CategoricalMasked(logits=logits, masks=iam) for (logits, iam) in zip(split_logits, split_invalid_action_masks)]
        else:
            multi_categoricals = [Categorical(logits=logits) for logits in split_logits]
        
        if action is None:
            action = torch.stack([categorical.sample() for categorical in multi_categoricals])
        logprob = torch.stack([categorical.log_prob(a) for a, categorical in zip(action, multi_categoricals)])
        return action, logprob, [], multi_categoricals

class Value(nn.Module):
    def __init__(self):
        super(Value, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(27, 16, kernel_size=3,),
            nn.MaxPool2d(1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3),
            nn.MaxPool2d(1),
            nn.ReLU())
        self.fc = nn.Sequential(
            nn.Linear(32*6*6, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        x = torch.Tensor(np.moveaxis(x, -1, 1)).to(device)
        x = self.features(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x

def discount_cumsum(x, dones, gamma):
    """
    Compute discounted cumulative sums of a vector, resetting at episode ends (dones).
    input:
        vector x,  vector dones,
        [x0,       [0,
         x1,        0,
         x2,        1,
         x3,        0,
         x4]        0]
    output (discount is the `gamma` argument):
        [x0 + discount * x1 + discount^2 * x2,
         x1 + discount * x2,
         x2,
         x3 + discount * x4,
         x4]
    """
    discount_cumsum = np.zeros_like(x)
    discount_cumsum[-1] = x[-1]
    for t in reversed(range(x.shape[0]-1)):
        discount_cumsum[t] = x[t] + gamma * discount_cumsum[t+1] * (1-dones[t])
    return discount_cumsum

pg = Policy().to(device)
vf = Value().to(device)

# MODIFIED: Separate optimizer and learning rates
pg_optimizer = optim.Adam(list(pg.parameters()), lr=args.policy_lr)
v_optimizer = optim.Adam(list(vf.parameters()), lr=args.value_lr)

# MODIFIED: Initialize the learning rate annealing schedulers when needed
if args.anneal_lr:
    anneal_fn = lambda f: max(0, 1-f / args.total_timesteps)
    pg_lr_scheduler = optim.lr_scheduler.LambdaLR(pg_optimizer, lr_lambda=anneal_fn)
    vf_lr_scheduler = optim.lr_scheduler.LambdaLR(v_optimizer, lr_lambda=anneal_fn)

loss_fn = nn.MSELoss()

# TRY NOT TO MODIFY: start the game
global_step = 0
while global_step < args.total_timesteps:
    if args.capture_video:
        env.stats_recorder.done=True
    next_obs = np.array(env.reset())

    # ALGO Logic: Storage for epoch data
    obs = np.empty((args.batch_size,) + env.observation_space.shape)

    actions = np.empty((args.batch_size,) + env.action_space.shape)
    logprobs = torch.zeros((env.action_space.nvec.shape[0], args.batch_size,)).to(device)

    rewards = np.zeros((args.batch_size,))
    raw_rewards = np.zeros((len(env.rfs),args.batch_size,))
    
    real_rewards = []
    invalid_action_stats = []

    dones = np.zeros((args.batch_size,))
    values = torch.zeros((args.batch_size,)).to(device)

    invalid_action_masks = torch.zeros((args.batch_size, env.action_space.nvec.sum()))
    # TRY NOT TO MODIFY: prepare the execution of the game.
    for step in range(args.batch_size):
        env.render()
        global_step += 1
        obs[step] = next_obs.copy()

        # ALGO LOGIC: put action logic here
        invalid_action_mask = torch.ones(env.action_space.nvec.sum())
        invalid_action_mask[0:env.action_space.nvec[0]] = torch.tensor(env.unit_location_mask)
        invalid_action_mask[-env.action_space.nvec[-1]:] = torch.tensor(env.target_unit_location_mask)
        invalid_action_masks[step] = invalid_action_mask
        with torch.no_grad():
            values[step] = vf.forward(obs[step:step+1])
            action, logproba, _, probs = pg.get_action(obs[step:step+1], invalid_action_masks=invalid_action_masks[step:step+1])
            
            # CORE LOGIC:
            # keep the action sampled from CategoricalMasked, but do not use
            # the masked log probability. Instead, recompute the log
            # probability of that action with the unmasked Categorical
            action, logproba, _, probs = pg.get_action(obs[step:step+1], action=action)
        
        actions[step] = action[:,0].data.cpu().numpy()
        logprobs[:,[step]] = logproba

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards[step], dones[step], info = env.step(action[:,0].data.cpu().numpy())
        raw_rewards[:,step] = info["rewards"]
        real_rewards += [info['real_reward']]
        invalid_action_stats += [info['invalid_action_stats']]
        next_obs = np.array(next_obs)

        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            pg_lr_scheduler.step()
            vf_lr_scheduler.step()

        if dones[step]:
            # Episode finished: log the undiscounted episode reward and reset per-episode stats
            writer.add_scalar("charts/episode_reward", np.sum(real_rewards), global_step)
            print(f"global_step={global_step}, episode_reward={np.sum(real_rewards)}")
            for i in range(len(env.rfs)):
                writer.add_scalar(f"charts/episode_reward/{str(env.rfs[i])}", raw_rewards.sum(1)[i], global_step)
            real_rewards = []
            for key, idx in zip(info['invalid_action_stats'], range(len(info['invalid_action_stats']))):
                writer.add_scalar(f"stats/{key}", pd.DataFrame(invalid_action_stats).sum(0)[idx], global_step)
            invalid_action_stats = []
            next_obs = np.array(env.reset())

    # bootstrap from the value estimate if the batch limit was reached before the episode ended
    last_value = 0
    if not dones[step]:
        last_value = vf.forward(next_obs.reshape((1,)+next_obs.shape))[0].detach().cpu().numpy()[0]
    bootstrapped_rewards = np.append(rewards, last_value)

    # calculate the returns and advantages
    if args.gae:
        bootstrapped_values = np.append(values.detach().cpu().numpy(), last_value)
        deltas = bootstrapped_rewards[:-1] + args.gamma * bootstrapped_values[1:] * (1-dones) - bootstrapped_values[:-1]
        advantages = discount_cumsum(deltas, dones, args.gamma * args.gae_lambda)
        advantages = torch.Tensor(advantages).to(device)
        returns = advantages + values
    else:
        returns = discount_cumsum(bootstrapped_rewards, dones, args.gamma)[:-1]
        advantages = returns - values.detach().cpu().numpy()
        advantages = torch.Tensor(advantages).to(device)
        returns = torch.Tensor(returns).to(device)

    # Advantage normalization
    if args.norm_adv:
        EPS = 1e-10
        advantages = (advantages - advantages.mean()) / (advantages.std() + EPS)

    # Optimizing policy network
    entropys = []
    target_pg = Policy().to(device)
    inds = np.arange(args.batch_size,)
    for i_epoch_pi in range(args.update_epochs):
        np.random.shuffle(inds)
        for start in range(0, args.batch_size, args.minibatch_size):
            end = start + args.minibatch_size
            minibatch_ind = inds[start:end]
            target_pg.load_state_dict(pg.state_dict())
            
            _, newlogproba, _, _ = pg.get_action(
                obs[minibatch_ind],
                torch.LongTensor(actions[minibatch_ind].astype(np.int64)).to(device).T,)  # np.int was removed in NumPy 1.24
            ratio = (newlogproba - logprobs[:,minibatch_ind]).exp()

            # Policy loss as in OpenAI Spinning Up
            clip_adv = torch.where(advantages[minibatch_ind] > 0,
                                    (1.+args.clip_coef) * advantages[minibatch_ind],
                                    (1.-args.clip_coef) * advantages[minibatch_ind]).to(device)

            # Entropy computation with resampled actions
            entropy = -(newlogproba.exp() * newlogproba).mean()
            entropys.append(entropy.item())

            policy_loss = -torch.min(ratio * advantages[minibatch_ind], clip_adv) + args.ent_coef * entropy
            policy_loss = policy_loss.mean()
            
            pg_optimizer.zero_grad()
            policy_loss.backward()
            nn.utils.clip_grad_norm_(pg.parameters(), args.max_grad_norm)
            pg_optimizer.step()

            approx_kl = (logprobs[:,minibatch_ind] - newlogproba).mean()
            # Optimizing value network
            new_values = vf.forward(obs[minibatch_ind]).view(-1)

            # Value loss clipping
            if args.clip_vloss:
                v_loss_unclipped = ((new_values - returns[minibatch_ind]) ** 2)
                v_clipped = values[minibatch_ind] + torch.clamp(new_values - values[minibatch_ind], -args.clip_coef, args.clip_coef)
                v_loss_clipped = (v_clipped - returns[minibatch_ind])**2
                v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                v_loss = 0.5 * v_loss_max.mean()
            else:
                v_loss = torch.mean((returns[minibatch_ind]- new_values).pow(2))

            v_optimizer.zero_grad()
            v_loss.backward()
            nn.utils.clip_grad_norm_(vf.parameters(), args.max_grad_norm)
            v_optimizer.step()

        if args.kle_stop:
            if approx_kl > args.target_kl:
                break
        if args.kle_rollback:
            if (logprobs[:,minibatch_ind] - 
                pg.get_action(
                    obs[minibatch_ind],
                    torch.LongTensor(actions[minibatch_ind].astype(np.int64)).to(device).T,
                    invalid_action_masks[minibatch_ind])[1]).mean() > args.target_kl:
                pg.load_state_dict(target_pg.state_dict())
                break

    # TRY NOT TO MODIFY: record rewards for plotting purposes
    writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
    writer.add_scalar("charts/policy_learning_rate", pg_optimizer.param_groups[0]['lr'], global_step)
    writer.add_scalar("charts/value_learning_rate", v_optimizer.param_groups[0]['lr'], global_step)
    writer.add_scalar("losses/policy_loss", policy_loss.item(), global_step)
    writer.add_scalar("losses/entropy", np.mean(entropys), global_step)
    writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
    if args.kle_stop or args.kle_rollback:
        writer.add_scalar("debug/pg_stop_iter", i_epoch_pi, global_step)

env.close()
writer.close()


================================================
FILE: invalid_action_masking/ppo_no_adj_16x16.py
================================================
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter

from cleanrl.common import preprocess_obs_space, preprocess_ac_space
import argparse
import numpy as np
import gym
import gym_microrts
from gym.wrappers import TimeLimit, Monitor
from gym.spaces import Discrete, Box, MultiBinary, MultiDiscrete, Space
import time
import random
import os
import pandas as pd

# taken from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py
class RunningMeanStd(object):
    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = np.zeros(shape, 'float64')
        self.var = np.ones(shape, 'float64')
        self.count = epsilon

    def update(self, x):
        batch_mean = np.mean([x], axis=0)
        batch_var = np.var([x], axis=0)
        batch_count = 1
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count)

def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
    delta = batch_mean - mean
    tot_count = count + batch_count

    new_mean = mean + delta * batch_count / tot_count
    m_a = var * count
    m_b = batch_var * batch_count
    M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
    new_var = M2 / tot_count
    new_count = tot_count

    return new_mean, new_var, new_count
class NormalizedEnv(gym.core.Wrapper):
    def __init__(self, env, ob=True, ret=True, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8):
        super(NormalizedEnv, self).__init__(env)
        self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
        self.ret_rms = RunningMeanStd(shape=(1,)) if ret else None
        self.clipob = clipob
        self.cliprew = cliprew
        self.ret = np.zeros(())
        self.gamma = gamma
        self.epsilon = epsilon

    def step(self, action):
        obs, rews, news, infos = self.env.step(action)
        infos['real_reward'] = rews
        # print("before", self.ret)
        self.ret = self.ret * self.gamma + rews
        # print("after", self.ret)
        obs = self._obfilt(obs)
        if self.ret_rms:
            self.ret_rms.update(np.array([self.ret].copy()))
            rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew)
        self.ret = self.ret * (1-float(news))
        return obs, rews, news, infos

    def _obfilt(self, obs):
        if self.ob_rms:
            self.ob_rms.update(obs)
            obs = np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob, self.clipob)
            return obs
        else:
            return obs

    def reset(self):
        self.ret = np.zeros(())
        obs = self.env.reset()
        return self._obfilt(obs)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PPO agent')
    # Common arguments
    parser.add_argument('--exp-name', type=str, default=os.path.splitext(os.path.basename(__file__))[0],
                       help='the name of this experiment')
    parser.add_argument('--gym-id', type=str, default="MicrortsMining16x16F9-v0",
                       help='the id of the gym environment')
    parser.add_argument('--seed', type=int, default=1,
                       help='seed of the experiment')
    parser.add_argument('--episode-length', type=int, default=0,
                       help='the maximum length of each episode')
    parser.add_argument('--total-timesteps', type=int, default=100000,
                       help='total timesteps of the experiments')
    parser.add_argument('--no-torch-deterministic', action='store_false', dest="torch_deterministic", default=True,
                       help='if toggled, `torch.backends.cudnn.deterministic=False`')
    parser.add_argument('--no-cuda', action='store_false', dest="cuda", default=True,
                       help='if toggled, cuda will not be enabled by default')
    parser.add_argument('--prod-mode', action='store_true', default=False,
                       help='run the script in production mode and use wandb to log outputs')
    parser.add_argument('--capture-video', action='store_true', default=False,
                       help='whether to capture videos of the agent performance (check out the `videos` folder)')
    parser.add_argument('--wandb-project-name', type=str, default="cleanRL",
                       help="the wandb's project name")
    parser.add_argument('--wandb-entity', type=str, default=None,
                       help="the entity (team) of wandb's project")

    # Algorithm specific arguments
    parser.add_argument('--batch-size', type=int, default=2048,
                       help='the batch size of ppo')
    parser.add_argument('--minibatch-size', type=int, default=256,
                       help='the mini batch size of ppo')
    parser.add_argument('--gamma', type=float, default=0.99,
                       help='the discount factor gamma')
    parser.add_argument('--gae-lambda', type=float, default=0.97,
                       help='the lambda for the generalized advantage estimation')
    parser.add_argument('--ent-coef', type=float, default=0.01,
                       help="coefficient of the entropy")
    parser.add_argument('--max-grad-norm', type=float, default=0.5,
                       help='the maximum norm for the gradient clipping')
    parser.add_argument('--clip-coef', type=float, default=0.2,
                       help="the surrogate clipping coefficient")
    parser.add_argument('--update-epochs', type=int, default=10,
                        help="the K epochs to update the policy")
    parser.add_argument('--kle-stop', action='store_true', default=False,
                        help='If toggled, the policy updates will be early stopped w.r.t target-kl')
    parser.add_argument('--kle-rollback', action='store_true', default=False,
                        help='If toggled, the policy updates will roll back to previous policy if KL exceeds target-kl')
    parser.add_argument('--target-kl', type=float, default=0.015,
                        help='the target KL divergence threshold used by --kle-stop and --kle-rollback')
    parser.add_argument('--gae', action='store_true', default=True,
                        help='Use GAE for advantage computation')
    parser.add_argument('--policy-lr', type=float, default=3e-4,
                        help="the learning rate of the policy optimizer")
    parser.add_argument('--value-lr', type=float, default=3e-4,
                        help="the learning rate of the critic optimizer")
    parser.add_argument('--norm-obs', action='store_true', default=True,
                        help="Toggles observation normalization")
    parser.add_argument('--norm-returns', action='store_true', default=False,
                        help="Toggles returns normalization")
    parser.add_argument('--norm-adv', action='store_true', default=True,
                        help="Toggles advantages normalization")
    parser.add_argument('--obs-clip', type=float, default=10.0,
                        help="Value for observation clipping, as per the paper")
    parser.add_argument('--rew-clip', type=float, default=10.0,
                        help="Value for reward clipping, as per the paper")
    parser.add_argument('--anneal-lr', action='store_true', default=True,
                        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument('--weights-init', default="orthogonal", choices=["xavier", 'orthogonal'],
                        help='Selects the scheme to be used for weights initialization')
    parser.add_argument('--clip-vloss', action="store_true", default=True,
                        help='Toggles whether or not to use a clipped loss for the value function, as per the paper.')
    parser.add_argument('--pol-layer-norm', action='store_true', default=False,
                       help='Enables layer normalization in the policy network')
    args = parser.parse_args()
    if not args.seed:
        args.seed = int(time.time())

args.features_turned_on = sum([args.kle_stop, args.kle_rollback, args.gae, args.norm_obs, args.norm_returns, args.norm_adv, args.anneal_lr, args.clip_vloss, args.pol_layer_norm])
# TRY NOT TO MODIFY: setup the environment
experiment_name = f"{args.gym_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
writer = SummaryWriter(f"runs/{experiment_name}")
writer.add_text('hyperparameters', "|param|value|\n|-|-|\n%s" % (
        '\n'.join([f"|{key}|{value}|" for key, value in vars(args).items()])))

if args.prod_mode:
    import wandb
    wandb.init(project=args.wandb_project_name, entity=args.wandb_entity, tensorboard=True, config=vars(args), name=experiment_name, monitor_gym=True)
    writer = SummaryWriter(f"/tmp/{experiment_name}")
    wandb.save(os.path.abspath(__file__))

# TRY NOT TO MODIFY: seeding
device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu')
env = gym.make(args.gym_id)
# respect the default timelimit
assert isinstance(env.action_space, MultiDiscrete), "only MultiDiscrete action space is supported"
assert isinstance(env, TimeLimit) or int(args.episode_length), "the gym env does not have a built in TimeLimit, please specify by using --episode-length"
if isinstance(env, TimeLimit):
    if int(args.episode_length):
        env._max_episode_steps = int(args.episode_length)
    args.episode_length = env._max_episode_steps
else:
    env = TimeLimit(env, int(args.episode_length))
env = NormalizedEnv(env.env, ob=args.norm_obs, ret=args.norm_returns, clipob=args.obs_clip, cliprew=args.rew_clip, gamma=args.gamma)
env = TimeLimit(env, int(args.episode_length))
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic
env.seed(args.seed)
env.action_space.seed(args.seed)
env.observation_space.seed(args.seed)
if args.capture_video:
    env = Monitor(env, f'videos/{experiment_name}')

# ALGO LOGIC: initialize agent here:
class CategoricalMasked(Categorical):
    def __init__(self, probs=None, logits=None, validate_args=None, masks=[]):
        self.masks = masks
        if len(self.masks) == 0:
            super(CategoricalMasked, self).__init__(probs, logits, validate_args)
        else:
            self.masks = masks.type(torch.BoolTensor).to(device)
            logits = torch.where(self.masks, logits, torch.tensor(-1e+8).to(device))
            super(CategoricalMasked, self).__init__(probs, logits, validate_args)
    
    def entropy(self):
        if len(self.masks) == 0:
            return super(CategoricalMasked, self).entropy()
        p_log_p = self.logits * self.probs
        p_log_p = torch.where(self.masks, p_log_p, torch.tensor(0.).to(device))
        return -p_log_p.sum(-1)

class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(27, 16, kernel_size=3),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3),
            nn.MaxPool2d(1),
            nn.ReLU())
        self.fc = nn.Sequential(
            nn.Linear(32*5*5, 128),
            nn.ReLU(),
            nn.Linear(128, env.action_space.nvec.sum())
        )

    def forward(self, x):
        x = torch.Tensor(np.moveaxis(x, -1, 1)).to(device)
        x = self.features(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x

    def get_action(self, x, action=None, invalid_action_masks=None):
        logits = self.forward(x)
        split_logits = torch.split(logits, env.action_space.nvec.tolist(), dim=1)
        
        if invalid_action_masks is not None:
            split_invalid_action_masks = torch.split(invalid_action_masks, env.action_space.nvec.tolist(), dim=1)
            multi_categoricals = [CategoricalMasked(logits=logits, masks=iam) for (logits, iam) in zip(split_logits, split_invalid_action_masks)]
        else:
            multi_categoricals = [Categorical(logits=logits) for logits in split_logits]
        
        if action is None:
            action = torch.stack([categorical.sample() for categorical in multi_categoricals])
        logprob = torch.stack([categorical.log_prob(a) for a, categorical in zip(action, multi_categoricals)])
        return action, logprob, [], multi_categoricals

class Value(nn.Module):
    def __init__(self):
        super(Value, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(27, 16, kernel_size=3),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3),
            nn.MaxPool2d(1),
            nn.ReLU())
        self.fc = nn.Sequential(
            nn.Linear(32*5*5, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        x = torch.Tensor(np.moveaxis(x, -1, 1)).to(device)
        x = self.features(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x

def discount_cumsum(x, dones, gamma):
    """
    computing discounted cumulative sums of vectors that resets with dones
    input:
        vector x,  vector dones,
        [x0,       [0,
         x1,        0,
         x2         1,
         x3         0, 
         x4]        0]
    output:
        [x0 + discount * x1 + discount^2 * x2,
         x1 + discount * x2,
         x2,
         x3 + discount * x4,
         x4]
    """
    discount_cumsum = np.zeros_like(x)
    discount_cumsum[-1] = x[-1]
    for t in reversed(range(x.shape[0]-1)):
        discount_cumsum[t] = x[t] + gamma * discount_cumsum[t+1] * (1-dones[t])
    return discount_cumsum

pg = Policy().to(device)
vf = Value().to(device)

# MODIFIED: Separate optimizer and learning rates
pg_optimizer = optim.Adam(list(pg.parameters()), lr=args.policy_lr)
v_optimizer = optim.Adam(list(vf.parameters()), lr=args.value_lr)

# MODIFIED: Initializing learning rate anneal scheduler when needed
if args.anneal_lr:
    anneal_fn = lambda f: max(0, 1-f / args.total_timesteps)
    pg_lr_scheduler = optim.lr_scheduler.LambdaLR(pg_optimizer, lr_lambda=anneal_fn)
    vf_lr_scheduler = optim.lr_scheduler.LambdaLR(v_optimizer, lr_lambda=anneal_fn)

loss_fn = nn.MSELoss()

# TRY NOT TO MODIFY: start the game
global_step = 0
while global_step < args.total_timesteps:
    if args.capture_video:
        env.stats_recorder.done=True
    next_obs = np.array(env.reset())

    # ALGO Logic: Storage for epoch data
    obs = np.empty((args.batch_size,) + env.observation_space.shape)

    actions = np.empty((args.batch_size,) + env.action_space.shape)
    logprobs = torch.zeros((env.action_space.nvec.shape[0], args.batch_size,)).to(device)

    rewards = np.zeros((args.batch_size,))
    raw_rewards = np.zeros((len(env.rfs),args.batch_size,))
    
    real_rewards = []
    invalid_action_stats = []

    dones = np.zeros((args.batch_size,))
    values = torch.zeros((args.batch_size,)).to(device)

    invalid_action_masks = torch.zeros((args.batch_size, env.action_space.nvec.sum()))
    # TRY NOT TO MODIFY: prepare the execution of the game.
    for step in range(args.batch_size):
        env.render()
        global_step += 1
        obs[step] = next_obs.copy()

        # ALGO LOGIC: put action logic here
        invalid_action_mask = torch.ones(env.action_space.nvec.sum())
        invalid_action_mask[0:env.action_space.nvec[0]] = torch.tensor(env.unit_location_mask)
        invalid_action_mask[-env.action_space.nvec[-1]:] = torch.tensor(env.target_unit_location_mask)
        invalid_action_masks[step] = invalid_action_mask
        with torch.no_grad():
            values[step] = vf.forward(obs[step:step+1])
            action, logproba, _, probs = pg.get_action(obs[step:step+1], invalid_action_masks=invalid_action_masks[step:step+1])
            
            # CORE LOGIC:
            # use the action generated by CategoricalMasked, but 
            # don't adjust the logprobability accordingly. Instead, calculate the log
            # probability using Categorical
            action, logproba, _, probs = pg.get_action(obs[step:step+1], action=action)
        
        actions[step] = action[:,0].data.cpu().numpy()
        logprobs[:,[step]] = logproba

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards[step], dones[step], info = env.step(action[:,0].data.cpu().numpy())
        raw_rewards[:,step] = info["rewards"]
        real_rewards += [info['real_reward']]
        invalid_action_stats += [info['invalid_action_stats']]
        next_obs = np.array(next_obs)

        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            pg_lr_scheduler.step()
            vf_lr_scheduler.step()

        if dones[step]:
            # Log the episode reward and the per-reward-function totals
            writer.add_scalar("charts/episode_reward", np.sum(real_rewards), global_step)
            print(f"global_step={global_step}, episode_reward={np.sum(real_rewards)}")
            for i in range(len(env.rfs)):
                writer.add_scalar(f"charts/episode_reward/{str(env.rfs[i])}", raw_rewards.sum(1)[i], global_step)
            real_rewards = []
            for key, idx in zip(info['invalid_action_stats'], range(len(info['invalid_action_stats']))):
                writer.add_scalar(f"stats/{key}", pd.DataFrame(invalid_action_stats).sum(0)[idx], global_step)
            invalid_action_stats = []
            next_obs = np.array(env.reset())

    # bootstrap from the value network if the batch limit was reached mid-episode
    last_value = 0
    if not dones[step]:
        last_value = vf.forward(next_obs.reshape((1,)+next_obs.shape))[0].detach().cpu().numpy()[0]
    bootstrapped_rewards = np.append(rewards, last_value)

    # calculate the returns and advantages
    if args.gae:
        bootstrapped_values = np.append(values.detach().cpu().numpy(), last_value)
        deltas = bootstrapped_rewards[:-1] + args.gamma * bootstrapped_values[1:] * (1-dones) - bootstrapped_values[:-1]
        advantages = discount_cumsum(deltas, dones, args.gamma * args.gae_lambda)
        advantages = torch.Tensor(advantages).to(device)
        returns = advantages + values
    else:
        returns = discount_cumsum(bootstrapped_rewards, dones, args.gamma)[:-1]
        advantages = returns - values.detach().cpu().numpy()
        advantages = torch.Tensor(advantages).to(device)
        returns = torch.Tensor(returns).to(device)

    # Advantage normalization
    if args.norm_adv:
        EPS = 1e-10
        advantages = (advantages - advantages.mean()) / (advantages.std() + EPS)

    # Optimizing policy network
    entropys = []
    target_pg = Policy().to(device)
    inds = np.arange(args.batch_size,)
    for i_epoch_pi in range(args.update_epochs):
        np.random.shuffle(inds)
        for start in range(0, args.batch_size, args.minibatch_size):
            end = start + args.minibatch_size
            minibatch_ind = inds[start:end]
            target_pg.load_state_dict(pg.state_dict())
            
            _, newlogproba, _, _ = pg.get_action(
                obs[minibatch_ind],
                torch.LongTensor(actions[minibatch_ind].astype(np.int64)).to(device).T,)
            ratio = (newlogproba - logprobs[:,minibatch_ind]).exp()

            # Policy loss as in OpenAI SpinUp
            clip_adv = torch.where(advantages[minibatch_ind] > 0,
                                    (1.+args.clip_coef) * advantages[minibatch_ind],
                                    (1.-args.clip_coef) * advantages[minibatch_ind]).to(device)

            # Entropy estimate computed from the log-probabilities of the stored actions
            entropy = -(newlogproba.exp() * newlogproba).mean()
            entropys.append(entropy.item())

            policy_loss = -torch.min(ratio * advantages[minibatch_ind], clip_adv) + args.ent_coef * entropy
            policy_loss = policy_loss.mean()
            
            pg_optimizer.zero_grad()
            policy_loss.backward()
            nn.utils.clip_grad_norm_(pg.parameters(), args.max_grad_norm)
            pg_optimizer.step()

            approx_kl = (logprobs[:,minibatch_ind] - newlogproba).mean()
            # Optimizing value network
            new_values = vf.forward(obs[minibatch_ind]).view(-1)

            # Value loss clipping
            if args.clip_vloss:
                v_loss_unclipped = ((new_values - returns[minibatch_ind]) ** 2)
                v_clipped = values[minibatch_ind] + torch.clamp(new_values - values[minibatch_ind], -args.clip_coef, args.clip_coef)
                v_loss_clipped = (v_clipped - returns[minibatch_ind])**2
                v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
                v_loss = 0.5 * v_loss_max.mean()
            else:
                v_loss = torch.mean((returns[minibatch_ind]- new_values).pow(2))

            v_optimizer.zero_grad()
            v_loss.backward()
            nn.utils.clip_grad_norm_(vf.parameters(), args.max_grad_norm)
            v_optimizer.step()

        if args.kle_stop:
            if approx_kl > args.target_kl:
                break
        if args.kle_rollback:
            if (logprobs[:,minibatch_ind] - 
                pg.get_action(
                    obs[minibatch_ind],
                    torch.LongTensor(actions[minibatch_ind].astype(np.int64)).to(device).T,
                    invalid_action_masks[minibatch_ind])[1]).mean() > args.target_kl:
                pg.load_state_dict(target_pg.state_dict())
                break

    # TRY NOT TO MODIFY: record rewards for plotting purposes
    writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
    writer.add_scalar("charts/policy_learning_rate", pg_optimizer.param_groups[0]['lr'], global_step)
    writer.add_scalar("charts/value_learning_rate", v_optimizer.param_groups[0]['lr'], global_step)
    writer.add_scalar("losses/policy_loss", policy_loss.item(), global_step)
    writer.add_scalar("losses/entropy", np.mean(entropys), global_step)
    writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
    if args.kle_stop or args.kle_rollback:
        writer.add_scalar("debug/pg_stop_iter", i_epoch_pi, global_step)

env.close()
writer.close()


================================================
FILE: invalid_action_masking/ppo_no_adj_24x24.py
================================================
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from torch.utils.tensorboard import SummaryWriter

from cleanrl.common import preprocess_obs_space, preprocess_ac_space
import argparse
import numpy as np
import gym
import gym_microrts
from gym.wrappers import TimeLimit, Monitor
from gym.spaces import Discrete, Box, MultiBinary, MultiDiscrete, Space
import time
import random
import os
import pandas as pd

# taken from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py
class RunningMeanStd(object):
    def __init__(self, epsilon=1e-4, shape=()):
        self.mean = np.zeros(shape, 'float64')
        self.var = np.ones(shape, 'float64')
        self.count = epsilon

    def update(self, x):
        batch_mean = np.mean([x], axis=0)
        batch_var = np.var([x], axis=0)
        batch_count = 1
        self.update_from_moments(batch_mean, batch_var, batch_count)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count)

def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
    delta = batch_mean - mean
    tot_count = count + batch_count

    new_mean = mean + delta * batch_count / tot_count
    m_a = var * count
    m_b = batch_var * batch_count
    M2 = m_a + m_b + np.square(delta) * count * batch_count / tot_count
    new_var = M2 / tot_count
    new_count = tot_count

    return new_mean, new_var, new_count
class NormalizedEnv(gym.core.Wrapper):
    def __init__(self, env, ob=True, ret=True, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8):
        super(NormalizedEnv, self).__init__(env)
        self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
        self.ret_rms = RunningMeanStd(shape=(1,)) if ret else None
        self.clipob = clipob
        self.cliprew = cliprew
        self.ret = np.zeros(())
        self.gamma = gamma
        self.epsilon = epsilon

    def step(self, action):
        obs, rews, news, infos = self.env.step(action)
        infos['real_reward'] = rews
        # print("before", self.ret)
        self.ret = self.ret * self.gamma + rews
        # print("after", self.ret)
        obs = self._obfilt(obs)
        if self.ret_rms:
            self.ret_rms.update(np.array([self.ret].copy()))
            rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew)
        self.ret = self.ret * (1-float(news))
        return obs, rews, news, infos

    def _obfilt(self, obs):
        if self.ob_rms:
            self.ob_rms.update(obs)
            obs = np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob, self.clipob)
            return obs
        else:
            return obs

    def reset(self):
        self.ret = np.zeros(())
        obs = self.env.reset()
        return self._obfilt(obs)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PPO agent')
    # Common arguments
    parser.add_argument('--exp-name', type=str, default=os.path.splitext(os.path.basename(__file__))[0],
                       help='the name of this experiment')
    parser.add_argument('--gym-id', type=str, default="MicrortsMining24x24F9-v0",
                       help='the id of the gym environment')
    parser.add_argument('--seed', type=int, default=1,
                       help='seed of the experiment')
    parser.add_argument('--episode-length', type=int, default=0,
                       help='the maximum length of each episode')
    parser.add_argument('--total-timesteps', type=int, default=100000,
                       help='total timesteps of the experiments')
    parser.add_argument('--no-torch-deterministic', action='store_false', dest="torch_deterministic", default=True,
                       help='if toggled, `torch.backends.cudnn.deterministic=False`')
    parser.add_argument('--no-cuda', action='store_false', dest="cuda", default=True,
                       help='if toggled, cuda will not be enabled by default')
    parser.add_argument('--prod-mode', action='store_true', default=False,
                       help='run the script in production mode and use wandb to log outputs')
    parser.add_argument('--capture-video', action='store_true', default=False,
                       help='whether to capture videos of the agent performance (check out the `videos` folder)')
    parser.add_argument('--wandb-project-name', type=str, default="cleanRL",
                       help="the wandb's project name")
    parser.add_argument('--wandb-entity', type=str, default=None,
                       help="the entity (team) of wandb's project")

    # Algorithm specific arguments
    parser.add_argument('--batch-size', type=int, default=2048,
                       help='the batch size of ppo')
    parser.add_argument('--minibatch-size', type=int, default=256,
                       help='the mini batch size of ppo')
    parser.add_argument('--gamma', type=float, default=0.99,
                       help='the discount factor gamma')
    parser.add_argument('--gae-lambda', type=float, default=0.97,
                       help='the lambda for the generalized advantage estimation')
    parser.add_argument('--ent-coef', type=float, default=0.01,
                       help="coefficient of the entropy")
    parser.add_argument('--max-grad-norm', type=float, default=0.5,
                       help='the maximum norm for the gradient clipping')
    parser.add_argument('--clip-coef', type=float, default=0.2,
                       help="the surrogate clipping coefficient")
    parser.add_argument('--update-epochs', type=int, default=10,
                        help="the K epochs to update the policy")
    parser.add_argument('--kle-stop', action='store_true', default=False,
                        help='If toggled, the policy updates will be early stopped w.r.t target-kl')
    parser.add_argument('--kle-rollback', action='store_true', default=False,
                        help='If toggled, the policy updates will roll back to previous policy if KL exceeds target-kl')
    parser.add_argument('--target-kl', type=float, default=0.015,
                        help='the target KL divergence threshold used by --kle-stop and --kle-rollback')
    parser.add_argument('--gae', action='store_true', default=True,
                        help='Use GAE for advantage computation')
    parser.add_argument('--policy-lr', type=float, default=3e-4,
                        help="the learning rate of the policy optimizer")
    parser.add_argument('--value-lr', type=float, default=3e-4,
                        help="the learning rate of the critic optimizer")
    parser.add_argument('--norm-obs', action='store_true', default=True,
                        help="Toggles observation normalization")
    parser.add_argument('--norm-returns', action='store_true', default=False,
                        help="Toggles returns normalization")
    parser.add_argument('--norm-adv', action='store_true', default=True,
                        help="Toggles advantages normalization")
    parser.add_argument('--obs-clip', type=float, default=10.0,
                        help="Value for observation clipping, as per the paper")
    parser.add_argument('--rew-clip', type=float, default=10.0,
                        help="Value for reward clipping, as per the paper")
    parser.add_argument('--anneal-lr', action='store_true', default=True,
                        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument('--weights-init', default="orthogonal", choices=["xavier", 'orthogonal'],
                        help='Selects the scheme to be used for weights initialization')
    parser.add_argument('--clip-vloss', action="store_true", default=True,
                        help='Toggles whether or not to use a clipped loss for the value function, as per the paper.')
    parser.add_argument('--pol-layer-norm', action='store_true', default=False,
                        help='Enables layer normalization in the policy network')
    args = parser.parse_args()
    if not args.seed:
        args.seed = int(time.time())
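
# Illustrative sketch (not part of the original script): several flags above
# combine action='store_true' with default=True, so passing the flag cannot
# change the value and there is no command-line way to switch it off.
# One possible alternative, assuming a hypothetical helper `_str2bool`, is to
# accept an explicit boolean value while keeping the bare flag meaning True.
def _str2bool(text):
    """Hypothetical helper: map common true/false spellings to a bool."""
    import argparse
    lowered = str(text).strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    if lowered in ("0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {text!r}")


def _demo_bool_flag(argv):
    """Parse `argv` with a toggleable --gae flag that defaults to True."""
    import argparse
    demo = argparse.ArgumentParser()
    demo.add_argument('--gae', type=_str2bool, default=True, nargs='?', const=True,
                      help='Use GAE for advantage computation')
    return demo.parse_args(argv).gae
# For example, _demo_bool_flag([]) and _demo_bool_flag(['--gae']) both give
# True, while _demo_bool_flag(['--gae', 'false']) gives False.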

args.features_turned_on = sum([args.kle_stop, args.kle_rollback, args.gae, args.norm_obs, args.norm_returns, args.norm_adv, args.anneal_lr, args.clip_vloss, args.pol_layer_norm])
# TRY NOT TO MODIFY: setup the environment
experiment_name = f"{args.gym_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
writer = SummaryWriter(f"runs/{experiment_name}")
writer.add_text('hyperparameters', "|param|value|\n|-|-|\n%s" % (
        '\n'.join([f"|{key}|{value}|" for key, value in vars(args).items()])))

if args.prod_mode:
    import wandb
    wandb.init(project=args.wandb_project_name, entity=args.wandb_entity, tensorboard=True, config=vars(args), name=experiment_name, monitor_gym=True)
    writer = SummaryWriter(f"/tmp/{experiment_name}")
    wandb.save(os.path.abspath(__file__))

# TRY NOT TO MODIFY: seeding
device = torch.device('cuda' if torch.cuda.is_available() and args.cuda else 'cpu')
env = gym.make(args.gym_id)
assert isinstance(env.action_space, MultiDiscrete), "only MultiDiscrete action space is supported"
# respect the default timelimit
assert isinstance(env, TimeLimit) or int(args.episode_length), "the gym env does not have a built in TimeLimit, please specify by using --episode-length"
if isinstance(env, TimeLimit):
    if int(args.episode_length):
        env._max_episode_steps = int(args.episode_length)
    args.episode_length = env._max_episode_steps
else:
    env = TimeLimit(env, int(args.episode_length))
env = NormalizedEnv(env.env, ob=args.norm_obs, ret=args.norm_returns, clipob=args.obs_clip, cliprew=args.rew_clip, gamma=args.gamma)
env = TimeLimit(env, int(args.episode_length))
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic
env.seed(args.seed)
env.action_space.seed(args.seed)
env.observation_space.seed(args.seed)
if args.capture_video:
    env = Monitor(env, f'videos/{experiment_name}')

# ALGO LOGIC: initialize agent here:
class CategoricalMasked(Categorical):
    def __init__(self, probs=None, logits=None, validate_args=None, masks=[]):
        self.masks = masks
        if len(self.masks) == 0:
            super(CategoricalMasked, self).__init__(probs, logits, validate_args)
        else:
            self.masks = masks.type(torch.BoolTensor).to(device)
            logits = torch.where(self.masks, logits, torch.tensor(-1e+8).to(device))
            super(CategoricalMasked, self).__init__(probs, logits, validate_args)
    
    def entropy(self):
        if len(self.masks) == 0:
            return super(CategoricalMasked, self).entropy()
        p_log_p = self.logits * self.probs
        p_log_p = torch.where(self.masks, p_log_p, torch.tensor(0.).to(device))
        return -p_log_p.sum(-1)
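
# Illustrative, torch-free sketch (not part of the original script) of what the
# mask above does to one categorical: invalid logits are replaced by a large
# negative constant, so softmax assigns them practically zero probability.
# `_masked_log_probs` is a hypothetical helper; -1e8 mirrors the constant above.
def _masked_log_probs(logits, mask, fill=-1e8):
    """Return log-softmax of `logits` after filling entries where mask is falsy."""
    import math
    filled = [l if m else fill for l, m in zip(logits, mask)]
    peak = max(filled)  # subtract the max for numerical stability
    log_norm = peak + math.log(sum(math.exp(v - peak) for v in filled))
    return [v - log_norm for v in filled]
# With logits [1, 1, 1, 1] and mask [1, 0, 1, 0], each valid action gets
# probability 0.5, whereas the unmasked distribution gives every action 0.25.
# The same sampled action therefore has different masked and unmasked
# log probabilities.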

class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(27, 16, kernel_size=3, stride=1),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=2, stride=1),
            nn.MaxPool2d(2),
            nn.ReLU())
        self.fc = nn.Sequential(
            nn.Linear(32*5*5, 128),
            nn.ReLU(),
            nn.Linear(128, env.action_space.nvec.sum())
        )

    def forward(self, x):
        x = torch.Tensor(np.moveaxis(x, -1, 1)).to(device)
        x = self.features(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x

    def get_action(self, x, action=None, invalid_action_masks=None):
        logits = self.forward(x)
        split_logits = torch.split(logits, env.action_space.nvec.tolist(), dim=1)
        
        if invalid_action_masks is not None:
            split_invalid_action_masks = torch.split(invalid_action_masks, env.action_space.nvec.tolist(), dim=1)
            multi_categoricals = [CategoricalMasked(logits=logits, masks=iam) for (logits, iam) in zip(split_logits, split_invalid_action_masks)]
        else:
            multi_categoricals = [Categorical(logits=logits) for logits in split_logits]
        
        if action is None:
            action = torch.stack([categorical.sample() for categorical in multi_categoricals])
        logprob = torch.stack([categorical.log_prob(a) for a, categorical in zip(action, multi_categoricals)])
        return action, logprob, [], multi_categoricals
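
# Illustrative arithmetic (not part of the original script): where the 32*5*5
# input size of the first linear layer comes from. Assuming Conv2d's default
# padding of 0 and MaxPool2d's default floor rounding, a square input of side
# 24 shrinks 24 -> 22 -> 11 -> 10 -> 5. `_feature_map_side` is a hypothetical
# helper, and the 24x24 input is inferred from the layer sizes, not stated here.
def _feature_map_side(side):
    """Side length after conv(k=3) -> pool(2) -> conv(k=2) -> pool(2)."""
    side = side - 3 + 1   # Conv2d(kernel_size=3, stride=1), no padding
    side = side // 2      # MaxPool2d(2)
    side = side - 2 + 1   # Conv2d(kernel_size=2, stride=1), no padding
    side = side // 2      # MaxPool2d(2)
    return side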

class Value(nn.Module):
    def __init__(self):
        super(Value, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(27, 16, kernel_size=3, stride=1),
            nn.MaxPool2d(2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=2, stride=1),
            nn.MaxPool2d(2),
            nn.ReLU())
        self.fc = nn.Sequential(
            nn.Linear(32*5*5, 128),
            nn.ReLU(),
            nn.Linear(128, 1)
        )

    def forward(self, x):
        x = torch.Tensor(np.moveaxis(x, -1, 1)).to(device)
        x = self.features(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x

def discount_cumsum(x, dones, gamma):
    """
    Compute discounted cumulative sums of a vector, resetting at each done flag.
    input:
        vector x,  vector dones,
        [x0,       [0,
         x1,        0,
         x2,        1,
         x3,        0,
         x4]        0]
    output:
        [x0 + gamma * x1 + gamma^2 * x2,
         x1 + gamma * x2,
         x2,
         x3 + gamma * x4,
         x4]
    """
    discount_cumsum = np.zeros_like(x)
    discount_cumsum[-1] = x[-1]
    for t in reversed(range(x.shape[0]-1)):
        discount_cumsum[t] = x[t] + gamma * discount_cumsum[t+1] * (1-dones[t])
    return discount_cumsum
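
# Worked example (illustrative, not part of the original script) of the
# docstring above, using plain Python lists. `_discount_cumsum_lists` is a
# hypothetical stand-in that applies the same recurrence as discount_cumsum.
def _discount_cumsum_lists(x, dones, gamma):
    """Discounted cumulative sum that restarts after every done flag."""
    out = [0.0] * len(x)
    out[-1] = float(x[-1])
    for t in reversed(range(len(x) - 1)):
        # a done at step t cuts off the contribution from step t+1 onwards
        out[t] = x[t] + gamma * out[t + 1] * (1 - dones[t])
    return out
# With x = [1, 1, 1, 1, 1], dones = [0, 0, 1, 0, 0] and gamma = 0.5 the
# result is [1.75, 1.5, 1.0, 1.5, 1.0]: the sum restarts after index 2.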

pg = Policy().to(device)
vf = Value().to(device)

# MODIFIED: Separate optimizer and learning rates
pg_optimizer = optim.Adam(list(pg.parameters()), lr=args.policy_lr)
v_optimizer = optim.Adam(list(vf.parameters()), lr=args.value_lr)

# MODIFIED: Initialize the learning rate annealing schedulers when needed
if args.anneal_lr:
    anneal_fn = lambda f: max(0, 1-f / args.total_timesteps)
    pg_lr_scheduler = optim.lr_scheduler.LambdaLR(pg_optimizer, lr_lambda=anneal_fn)
    vf_lr_scheduler = optim.lr_scheduler.LambdaLR(v_optimizer, lr_lambda=anneal_fn)
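
# Illustrative sketch (not part of the original script) of the linear schedule
# above: LambdaLR multiplies the base learning rate by this factor, which falls
# from 1 to 0 over total_timesteps scheduler steps and is floored at 0.
# `_linear_anneal_factor` is a hypothetical name for the same expression.
def _linear_anneal_factor(step, total_timesteps):
    """Multiplier applied to the initial learning rate after `step` steps."""
    return max(0, 1 - step / total_timesteps)
# For example, with total_timesteps = 100 the factor is 1.0 at step 0,
# 0.5 at step 50, and 0 from step 100 onwards.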

loss_fn = nn.MSELoss()

# TRY NOT TO MODIFY: start the game
global_step = 0
while global_step < args.total_timesteps:
    if args.capture_video:
        env.stats_recorder.done=True
    next_obs = np.array(env.reset())

    # ALGO Logic: Storage for epoch data
    obs = np.empty((args.batch_size,) + env.observation_space.shape)

    actions = np.empty((args.batch_size,) + env.action_space.shape)
    logprobs = torch.zeros((env.action_space.nvec.shape[0], args.batch_size,)).to(device)

    rewards = np.zeros((args.batch_size,))
    raw_rewards = np.zeros((len(env.rfs),args.batch_size,))
    
    real_rewards = []
    invalid_action_stats = []

    dones = np.zeros((args.batch_size,))
    values = torch.zeros((args.batch_size,)).to(device)

    invalid_action_masks = torch.zeros((args.batch_size, env.action_space.nvec.sum()))
    # TRY NOT TO MODIFY: prepare the execution of the game.
    for step in range(args.batch_size):
        env.render()
        global_step += 1
        obs[step] = next_obs.copy()

        # ALGO LOGIC: put action logic here
        invalid_action_mask = torch.ones(env.action_space.nvec.sum())
        invalid_action_mask[0:env.action_space.nvec[0]] = torch.tensor(env.unit_location_mask)
        invalid_action_mask[-env.action_space.nvec[-1]:] = torch.tensor(env.target_unit_location_mask)
        invalid_action_masks[step] = invalid_action_mask
        with torch.no_grad():
            values[step] = vf.forward(obs[step:step+1])
            action, logproba, _, probs = pg.get_action(obs[step:step+1], invalid_action_masks=invalid_action_masks[step:step+1])
            
            # CORE LOGIC:
            # use the action sampled from CategoricalMasked, but
            # don't adjust the log probability accordingly. Instead, recompute the
            # log probability of that action using the unmasked Categorical
            action, logproba, _, probs = pg.get_action(obs[step:step+1], action=action)
        
        actions[step] = action[:,0].data.cpu().numpy()
        logprobs[:,[step]] = logproba

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards[step], dones[step], info = env.step(action[:,0].data.cpu().numpy())
        raw_rewards[:,step] = info["rewards"]
        real_rewards += [info['real_reward']]
        invalid_action_stats += [info['invalid_action_stats']]
        next_obs = np.array(next_obs)

        # Annealing the rate if instructed to do so.
        if args.anneal_lr:
            pg_lr_scheduler.step()
            vf_lr_scheduler.step()

        if dones[step]:
            # Computing the discounted returns:
            writer.add_scalar("charts/episode_r
SYMBOL INDEX (311 symbols across 17 files)

FILE: gym_vec_api/ppo_multidiscrete.py
  function parse_args (line 17) | def parse_args():
  function make_env (line 81) | def make_env(gym_id, seed, idx, capture_video, run_name):
  function layer_init (line 96) | def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
  class Transpose (line 102) | class Transpose(nn.Module):
    method __init__ (line 103) | def __init__(self, permutation):
    method forward (line 107) | def forward(self, x):
  class Agent (line 111) | class Agent(nn.Module):
    method __init__ (line 112) | def __init__(self, envs):
    method get_value (line 128) | def get_value(self, x):
    method get_action_and_value (line 131) | def get_action_and_value(self, x, action=None):

FILE: gym_vec_api/ppo_multidiscrete_mask.py
  function parse_args (line 17) | def parse_args():
  function make_env (line 81) | def make_env(gym_id, seed, idx, capture_video, run_name):
  function layer_init (line 96) | def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
  class Transpose (line 102) | class Transpose(nn.Module):
    method __init__ (line 103) | def __init__(self, permutation):
    method forward (line 107) | def forward(self, x):
  class CategoricalMasked (line 111) | class CategoricalMasked(Categorical):
    method __init__ (line 112) | def __init__(self, probs=None, logits=None, validate_args=None, masks=...
    method entropy (line 121) | def entropy(self):
  class Agent (line 129) | class Agent(nn.Module):
    method __init__ (line 130) | def __init__(self, envs):
    method get_value (line 146) | def get_value(self, x):
    method get_action_and_value (line 149) | def get_action_and_value(self, x, action_mask, action=None):

FILE: invalid_action_masking/ppo_10x10.py
  class RunningMeanStd (line 21) | class RunningMeanStd(object):
    method __init__ (line 22) | def __init__(self, epsilon=1e-4, shape=()):
    method update (line 27) | def update(self, x):
    method update_from_moments (line 33) | def update_from_moments(self, batch_mean, batch_var, batch_count):
  function update_mean_var_count_from_moments (line 37) | def update_mean_var_count_from_moments(mean, var, count, batch_mean, bat...
  class NormalizedEnv (line 49) | class NormalizedEnv(gym.core.Wrapper):
    method __init__ (line 50) | def __init__(self, env, ob=True, ret=True, clipob=10., cliprew=10., ga...
    method step (line 60) | def step(self, action):
    method _obfilt (line 73) | def _obfilt(self, obs):
    method reset (line 81) | def reset(self):
  class CategoricalMasked (line 201) | class CategoricalMasked(Categorical):
    method __init__ (line 202) | def __init__(self, probs=None, logits=None, validate_args=None, masks=...
    method entropy (line 211) | def entropy(self):
  class Policy (line 218) | class Policy(nn.Module):
    method __init__ (line 219) | def __init__(self):
    method forward (line 234) | def forward(self, x):
    method get_action (line 241) | def get_action(self, x, action=None, invalid_action_masks=None):
  class Value (line 256) | class Value(nn.Module):
    method __init__ (line 257) | def __init__(self):
    method forward (line 272) | def forward(self, x):
  function discount_cumsum (line 279) | def discount_cumsum(x, dones, gamma):
  function evaluate_with_no_mask (line 317) | def evaluate_with_no_mask():

FILE: invalid_action_masking/ppo_16x16.py
  class RunningMeanStd (line 21) | class RunningMeanStd(object):
    method __init__ (line 22) | def __init__(self, epsilon=1e-4, shape=()):
    method update (line 27) | def update(self, x):
    method update_from_moments (line 33) | def update_from_moments(self, batch_mean, batch_var, batch_count):
  function update_mean_var_count_from_moments (line 37) | def update_mean_var_count_from_moments(mean, var, count, batch_mean, bat...
  class NormalizedEnv (line 49) | class NormalizedEnv(gym.core.Wrapper):
    method __init__ (line 50) | def __init__(self, env, ob=True, ret=True, clipob=10., cliprew=10., ga...
    method step (line 60) | def step(self, action):
    method _obfilt (line 73) | def _obfilt(self, obs):
    method reset (line 81) | def reset(self):
  class CategoricalMasked (line 201) | class CategoricalMasked(Categorical):
    method __init__ (line 202) | def __init__(self, probs=None, logits=None, validate_args=None, masks=...
    method entropy (line 211) | def entropy(self):
  class Policy (line 218) | class Policy(nn.Module):
    method __init__ (line 219) | def __init__(self):
    method forward (line 234) | def forward(self, x):
    method get_action (line 241) | def get_action(self, x, action=None, invalid_action_masks=None):
  class Value (line 256) | class Value(nn.Module):
    method __init__ (line 257) | def __init__(self):
    method forward (line 272) | def forward(self, x):
  function discount_cumsum (line 279) | def discount_cumsum(x, dones, gamma):
  function evaluate_with_no_mask (line 317) | def evaluate_with_no_mask():

FILE: invalid_action_masking/ppo_24x24.py
  class RunningMeanStd (line 21) | class RunningMeanStd(object):
    method __init__ (line 22) | def __init__(self, epsilon=1e-4, shape=()):
    method update (line 27) | def update(self, x):
    method update_from_moments (line 33) | def update_from_moments(self, batch_mean, batch_var, batch_count):
  function update_mean_var_count_from_moments (line 37) | def update_mean_var_count_from_moments(mean, var, count, batch_mean, bat...
  class NormalizedEnv (line 49) | class NormalizedEnv(gym.core.Wrapper):
    method __init__ (line 50) | def __init__(self, env, ob=True, ret=True, clipob=10., cliprew=10., ga...
    method step (line 60) | def step(self, action):
    method _obfilt (line 73) | def _obfilt(self, obs):
    method reset (line 81) | def reset(self):
  class CategoricalMasked (line 201) | class CategoricalMasked(Categorical):
    method __init__ (line 202) | def __init__(self, probs=None, logits=None, validate_args=None, masks=...
    method entropy (line 211) | def entropy(self):
  class Policy (line 218) | class Policy(nn.Module):
    method __init__ (line 219) | def __init__(self):
    method forward (line 234) | def forward(self, x):
    method get_action (line 241) | def get_action(self, x, action=None, invalid_action_masks=None):
  class Value (line 256) | class Value(nn.Module):
    method __init__ (line 257) | def __init__(self):
    method forward (line 272) | def forward(self, x):
  function discount_cumsum (line 279) | def discount_cumsum(x, dones, gamma):
  function evaluate_with_no_mask (line 317) | def evaluate_with_no_mask():

FILE: invalid_action_masking/ppo_4x4.py
  class RunningMeanStd (line 21) | class RunningMeanStd(object):
    method __init__ (line 22) | def __init__(self, epsilon=1e-4, shape=()):
    method update (line 27) | def update(self, x):
    method update_from_moments (line 33) | def update_from_moments(self, batch_mean, batch_var, batch_count):
  function update_mean_var_count_from_moments (line 37) | def update_mean_var_count_from_moments(mean, var, count, batch_mean, bat...
  class NormalizedEnv (line 49) | class NormalizedEnv(gym.core.Wrapper):
    method __init__ (line 50) | def __init__(self, env, ob=True, ret=True, clipob=10., cliprew=10., ga...
    method step (line 60) | def step(self, action):
    method _obfilt (line 73) | def _obfilt(self, obs):
    method reset (line 81) | def reset(self):
  class CategoricalMasked (line 201) | class CategoricalMasked(Categorical):
    method __init__ (line 202) | def __init__(self, probs=None, logits=None, validate_args=None, masks=...
    method entropy (line 211) | def entropy(self):
  class Policy (line 218) | class Policy(nn.Module):
    method __init__ (line 219) | def __init__(self):
    method forward (line 230) | def forward(self, x):
    method get_action (line 237) | def get_action(self, x, action=None, invalid_action_masks=None):
  class Value (line 252) | class Value(nn.Module):
    method __init__ (line 253) | def __init__(self):
    method forward (line 265) | def forward(self, x):
  function discount_cumsum (line 272) | def discount_cumsum(x, dones, gamma):
  function evaluate_with_no_mask (line 310) | def evaluate_with_no_mask():

FILE: invalid_action_masking/ppo_no_adj_10x10.py
  class RunningMeanStd (line 21) | class RunningMeanStd(object):
    method __init__ (line 22) | def __init__(self, epsilon=1e-4, shape=()):
    method update (line 27) | def update(self, x):
    method update_from_moments (line 33) | def update_from_moments(self, batch_mean, batch_var, batch_count):
  function update_mean_var_count_from_moments (line 37) | def update_mean_var_count_from_moments(mean, var, count, batch_mean, bat...
  class NormalizedEnv (line 49) | class NormalizedEnv(gym.core.Wrapper):
    method __init__ (line 50) | def __init__(self, env, ob=True, ret=True, clipob=10., cliprew=10., ga...
    method step (line 60) | def step(self, action):
    method _obfilt (line 73) | def _obfilt(self, obs):
    method reset (line 81) | def reset(self):
  class CategoricalMasked (line 201) | class CategoricalMasked(Categorical):
    method __init__ (line 202) | def __init__(self, probs=None, logits=None, validate_args=None, masks=...
    method entropy (line 211) | def entropy(self):
  class Policy (line 218) | class Policy(nn.Module):
    method __init__ (line 219) | def __init__(self):
    method forward (line 234) | def forward(self, x):
    method get_action (line 241) | def get_action(self, x, action=None, invalid_action_masks=None):
  class Value (line 256) | class Value(nn.Module):
    method __init__ (line 257) | def __init__(self):
    method forward (line 272) | def forward(self, x):
  function discount_cumsum (line 279) | def discount_cumsum(x, dones, gamma):

FILE: invalid_action_masking/ppo_no_adj_16x16.py
  class RunningMeanStd (line 21) | class RunningMeanStd(object):
    method __init__ (line 22) | def __init__(self, epsilon=1e-4, shape=()):
    method update (line 27) | def update(self, x):
    method update_from_moments (line 33) | def update_from_moments(self, batch_mean, batch_var, batch_count):
  function update_mean_var_count_from_moments (line 37) | def update_mean_var_count_from_moments(mean, var, count, batch_mean, bat...
  class NormalizedEnv (line 49) | class NormalizedEnv(gym.core.Wrapper):
    method __init__ (line 50) | def __init__(self, env, ob=True, ret=True, clipob=10., cliprew=10., ga...
    method step (line 60) | def step(self, action):
    method _obfilt (line 73) | def _obfilt(self, obs):
    method reset (line 81) | def reset(self):
  class CategoricalMasked (line 201) | class CategoricalMasked(Categorical):
    method __init__ (line 202) | def __init__(self, probs=None, logits=None, validate_args=None, masks=...
    method entropy (line 211) | def entropy(self):
  class Policy (line 218) | class Policy(nn.Module):
    method __init__ (line 219) | def __init__(self):
    method forward (line 234) | def forward(self, x):
    method get_action (line 241) | def get_action(self, x, action=None, invalid_action_masks=None):
  class Value (line 256) | class Value(nn.Module):
    method __init__ (line 257) | def __init__(self):
    method forward (line 272) | def forward(self, x):
  function discount_cumsum (line 279) | def discount_cumsum(x, dones, gamma):

FILE: invalid_action_masking/ppo_no_adj_24x24.py
  class RunningMeanStd (line 21) | class RunningMeanStd(object):
    method __init__ (line 22) | def __init__(self, epsilon=1e-4, shape=()):
    method update (line 27) | def update(self, x):
    method update_from_moments (line 33) | def update_from_moments(self, batch_mean, batch_var, batch_count):
  function update_mean_var_count_from_moments (line 37) | def update_mean_var_count_from_moments(mean, var, count, batch_mean, bat...
  class NormalizedEnv (line 49) | class NormalizedEnv(gym.core.Wrapper):
    method __init__ (line 50) | def __init__(self, env, ob=True, ret=True, clipob=10., cliprew=10., ga...
    method step (line 60) | def step(self, action):
    method _obfilt (line 73) | def _obfilt(self, obs):
    method reset (line 81) | def reset(self):
  class CategoricalMasked (line 201) | class CategoricalMasked(Categorical):
    method __init__ (line 202) | def __init__(self, probs=None, logits=None, validate_args=None, masks=...
    method entropy (line 211) | def entropy(self):
  class Policy (line 218) | class Policy(nn.Module):
    method __init__ (line 219) | def __init__(self):
    method forward (line 234) | def forward(self, x):
    method get_action (line 241) | def get_action(self, x, action=None, invalid_action_masks=None):
  class Value (line 256) | class Value(nn.Module):
    method __init__ (line 257) | def __init__(self):
    method forward (line 272) | def forward(self, x):
  function discount_cumsum (line 279) | def discount_cumsum(x, dones, gamma):

FILE: invalid_action_masking/ppo_no_adj_4x4.py
  class RunningMeanStd (line 21) | class RunningMeanStd(object):
    method __init__ (line 22) | def __init__(self, epsilon=1e-4, shape=()):
    method update (line 27) | def update(self, x):
    method update_from_moments (line 33) | def update_from_moments(self, batch_mean, batch_var, batch_count):
  function update_mean_var_count_from_moments (line 37) | def update_mean_var_count_from_moments(mean, var, count, batch_mean, bat...
  class NormalizedEnv (line 49) | class NormalizedEnv(gym.core.Wrapper):
    method __init__ (line 50) | def __init__(self, env, ob=True, ret=True, clipob=10., cliprew=10., ga...
    method step (line 60) | def step(self, action):
    method _obfilt (line 73) | def _obfilt(self, obs):
    method reset (line 81) | def reset(self):
  class CategoricalMasked (line 201) | class CategoricalMasked(Categorical):
    method __init__ (line 202) | def __init__(self, probs=None, logits=None, validate_args=None, masks=...
    method entropy (line 211) | def entropy(self):
  class Policy (line 218) | class Policy(nn.Module):
    method __init__ (line 219) | def __init__(self):
    method forward (line 231) | def forward(self, x):
    method get_action (line 238) | def get_action(self, x, action=None, invalid_action_masks=None):
  class Value (line 253) | class Value(nn.Module):
    method __init__ (line 254) | def __init__(self):
    method forward (line 266) | def forward(self, x):
  function discount_cumsum (line 273) | def discount_cumsum(x, dones, gamma):

FILE: invalid_action_masking/ppo_no_mask_10x10.py
  class RunningMeanStd (line 21) | class RunningMeanStd(object):
    method __init__ (line 22) | def __init__(self, epsilon=1e-4, shape=()):
    method update (line 27) | def update(self, x):
    method update_from_moments (line 33) | def update_from_moments(self, batch_mean, batch_var, batch_count):
  function update_mean_var_count_from_moments (line 37) | def update_mean_var_count_from_moments(mean, var, count, batch_mean, bat...
  class NormalizedEnv (line 49) | class NormalizedEnv(gym.core.Wrapper):
    method __init__ (line 50) | def __init__(self, env, ob=True, ret=True, clipob=10., cliprew=10., ga...
    method step (line 60) | def step(self, action):
    method _obfilt (line 73) | def _obfilt(self, obs):
    method reset (line 81) | def reset(self):
  class CategoricalMasked (line 203) | class CategoricalMasked(Categorical):
    method __init__ (line 204) | def __init__(self, probs=None, logits=None, validate_args=None, masks=...
    method entropy (line 213) | def entropy(self):
  class Policy (line 220) | class Policy(nn.Module):
    method __init__ (line 221) | def __init__(self):
    method forward (line 236) | def forward(self, x):
    method get_action (line 243) | def get_action(self, x, action=None):
  class Value (line 255) | class Value(nn.Module):
    method __init__ (line 256) | def __init__(self):
    method forward (line 271) | def forward(self, x):
  function discount_cumsum (line 278) | def discount_cumsum(x, dones, gamma):

FILE: invalid_action_masking/ppo_no_mask_16x16.py
  class RunningMeanStd (line 21) | class RunningMeanStd(object):
    method __init__ (line 22) | def __init__(self, epsilon=1e-4, shape=()):
    method update (line 27) | def update(self, x):
    method update_from_moments (line 33) | def update_from_moments(self, batch_mean, batch_var, batch_count):
  function update_mean_var_count_from_moments (line 37) | def update_mean_var_count_from_moments(mean, var, count, batch_mean, bat...
  class NormalizedEnv (line 49) | class NormalizedEnv(gym.core.Wrapper):
    method __init__ (line 50) | def __init__(self, env, ob=True, ret=True, clipob=10., cliprew=10., ga...
    method step (line 60) | def step(self, action):
    method _obfilt (line 73) | def _obfilt(self, obs):
    method reset (line 81) | def reset(self):
  class CategoricalMasked (line 203) | class CategoricalMasked(Categorical):
    method __init__ (line 204) | def __init__(self, probs=None, logits=None, validate_args=None, masks=...
    method entropy (line 213) | def entropy(self):
  class Policy (line 220) | class Policy(nn.Module):
    method __init__ (line 221) | def __init__(self):
    method forward (line 236) | def forward(self, x):
    method get_action (line 243) | def get_action(self, x, action=None):
  class Value (line 255) | class Value(nn.Module):
    method __init__ (line 256) | def __init__(self):
    method forward (line 271) | def forward(self, x):
  function discount_cumsum (line 278) | def discount_cumsum(x, dones, gamma):

FILE: invalid_action_masking/ppo_no_mask_24x24.py
  class RunningMeanStd (line 21) | class RunningMeanStd(object):
    method __init__ (line 22) | def __init__(self, epsilon=1e-4, shape=()):
    method update (line 27) | def update(self, x):
    method update_from_moments (line 33) | def update_from_moments(self, batch_mean, batch_var, batch_count):
  function update_mean_var_count_from_moments (line 37) | def update_mean_var_count_from_moments(mean, var, count, batch_mean, bat...
  class NormalizedEnv (line 49) | class NormalizedEnv(gym.core.Wrapper):
    method __init__ (line 50) | def __init__(self, env, ob=True, ret=True, clipob=10., cliprew=10., ga...
    method step (line 60) | def step(self, action):
    method _obfilt (line 73) | def _obfilt(self, obs):
    method reset (line 81) | def reset(self):
  class CategoricalMasked (line 203) | class CategoricalMasked(Categorical):
    method __init__ (line 204) | def __init__(self, probs=None, logits=None, validate_args=None, masks=...
    method entropy (line 213) | def entropy(self):
  class Policy (line 220) | class Policy(nn.Module):
    method __init__ (line 221) | def __init__(self):
    method forward (line 236) | def forward(self, x):
    method get_action (line 243) | def get_action(self, x, action=None):
  class Value (line 255) | class Value(nn.Module):
    method __init__ (line 256) | def __init__(self):
    method forward (line 271) | def forward(self, x):
  function discount_cumsum (line 278) | def discount_cumsum(x, dones, gamma):

FILE: invalid_action_masking/ppo_no_mask_4x4.py
  class RunningMeanStd (line 21) | class RunningMeanStd(object):
    method __init__ (line 22) | def __init__(self, epsilon=1e-4, shape=()):
    method update (line 27) | def update(self, x):
    method update_from_moments (line 33) | def update_from_moments(self, batch_mean, batch_var, batch_count):
  function update_mean_var_count_from_moments (line 37) | def update_mean_var_count_from_moments(mean, var, count, batch_mean, bat...
  class NormalizedEnv (line 49) | class NormalizedEnv(gym.core.Wrapper):
    method __init__ (line 50) | def __init__(self, env, ob=True, ret=True, clipob=10., cliprew=10., ga...
    method step (line 60) | def step(self, action):
    method _obfilt (line 73) | def _obfilt(self, obs):
    method reset (line 81) | def reset(self):
  class CategoricalMasked (line 203) | class CategoricalMasked(Categorical):
    method __init__ (line 204) | def __init__(self, probs=None, logits=None, validate_args=None, masks=...
    method entropy (line 213) | def entropy(self):
  class Policy (line 220) | class Policy(nn.Module):
    method __init__ (line 221) | def __init__(self):
    method forward (line 233) | def forward(self, x):
    method get_action (line 240) | def get_action(self, x, action=None):
  class Value (line 252) | class Value(nn.Module):
    method __init__ (line 253) | def __init__(self):
    method forward (line 265) | def forward(self, x):
  function discount_cumsum (line 272) | def discount_cumsum(x, dones, gamma):

FILE: plots/approx_kl.py
  function smooth (line 140) | def smooth(scalars, weight):  # Weight between 0 and 1
  function get_df_for_env (line 158) | def get_df_for_env(gym_id):
  function export_legend (line 180) | def export_legend(ax, filename="legend.pdf"):
  function _smooth (line 237) | def _smooth(df):

FILE: plots/episode_reward.py
  function smooth (line 156) | def smooth(scalars, weight):  # Weight between 0 and 1
  function get_df_for_env (line 174) | def get_df_for_env(gym_id):
  function export_legend (line 196) | def export_legend(ax, filename="legend.pdf"):
  function _smooth (line 253) | def _smooth(df):

FILE: ppo.py
  class ImageToPyTorch (line 90) | class ImageToPyTorch(gym.ObservationWrapper):
    method __init__ (line 91) | def __init__(self, env):
    method observation (line 101) | def observation(self, observation):
  class VecPyTorch (line 104) | class VecPyTorch(VecEnvWrapper):
    method __init__ (line 105) | def __init__(self, venv, device):
    method reset (line 109) | def reset(self):
    method step_async (line 114) | def step_async(self, actions):
    method step_wait (line 118) | def step_wait(self):
  class MicroRTSStatsRecorder (line 124) | class MicroRTSStatsRecorder(gym.Wrapper):
    method reset (line 126) | def reset(self, **kwargs):
    method step (line 131) | def step(self, action):
  function make_env (line 158) | def make_env(gym_id, seed, idx):
  class CategoricalMasked (line 181) | class CategoricalMasked(Categorical):
    method __init__ (line 182) | def __init__(self, probs=None, logits=None, validate_args=None, masks=...
    method entropy (line 191) | def entropy(self):
  class Scale (line 198) | class Scale(nn.Module):
    method __init__ (line 199) | def __init__(self, scale):
    method forward (line 203) | def forward(self, x):
  function layer_init (line 206) | def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
  class Agent (line 211) | class Agent(nn.Module):
    method __init__ (line 212) | def __init__(self, frames=4):
    method forward (line 225) | def forward(self, x):
    method get_action (line 228) | def get_action(self, x, action=None, invalid_action_masks=None):
    method get_value (line 244) | def get_value(self, x):
Condensed preview: 40 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".gitignore",
    "chars": 1790,
    "preview": "**.tfevents.**\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distri"
  },
  {
    "path": ".python-version",
    "chars": 33,
    "preview": "3.9.5/envs/invalid-action-masking"
  },
  {
    "path": "LICENSE",
    "chars": 1078,
    "preview": "MIT License\n\nCopyright (c) 2020 neurips2020submission\n\nPermission is hereby granted, free of charge, to any person obtai"
  },
  {
    "path": "README.MD",
    "chars": 1863,
    "preview": "# A Closer Look at Invalid Action Masking in Policy Gradient Algorithms\n\nThis repo contains the source code to reproduce"
  },
  {
    "path": "build.sh",
    "chars": 75,
    "preview": "docker build -t invalid_action_masking:latest -f sharedmemory.Dockerfile .\n"
  },
  {
    "path": "gym_vec_api/ppo_multidiscrete.py",
    "chars": 15277,
    "preview": "import argparse\nimport os\nimport random\nimport time\nfrom distutils.util import strtobool\n\nimport gym\nimport gym_microrts"
  },
  {
    "path": "gym_vec_api/ppo_multidiscrete_mask.py",
    "chars": 16679,
    "preview": "import argparse\nimport os\nimport random\nimport time\nfrom distutils.util import strtobool\n\nimport gym\nimport gym_microrts"
  },
  {
    "path": "invalid_action_masking/ppo_10x10.py",
    "chars": 24788,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.nn.functional as F\nfrom torch.distributions."
  },
  {
    "path": "invalid_action_masking/ppo_16x16.py",
    "chars": 24786,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.nn.functional as F\nfrom torch.distributions."
  },
  {
    "path": "invalid_action_masking/ppo_24x24.py",
    "chars": 24822,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.nn.functional as F\nfrom torch.distributions."
  },
  {
    "path": "invalid_action_masking/ppo_4x4.py",
    "chars": 24585,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.nn.functional as F\nfrom torch.distributions."
  },
  {
    "path": "invalid_action_masking/ppo_no_adj_10x10.py",
    "chars": 22548,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.nn.functional as F\nfrom torch.distributions."
  },
  {
    "path": "invalid_action_masking/ppo_no_adj_16x16.py",
    "chars": 22546,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.nn.functional as F\nfrom torch.distributions."
  },
  {
    "path": "invalid_action_masking/ppo_no_adj_24x24.py",
    "chars": 22586,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.nn.functional as F\nfrom torch.distributions."
  },
  {
    "path": "invalid_action_masking/ppo_no_adj_4x4.py",
    "chars": 22350,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.nn.functional as F\nfrom torch.distributions."
  },
  {
    "path": "invalid_action_masking/ppo_no_mask_10x10.py",
    "chars": 21558,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.nn.functional as F\nfrom torch.distributions."
  },
  {
    "path": "invalid_action_masking/ppo_no_mask_16x16.py",
    "chars": 21556,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.nn.functional as F\nfrom torch.distributions."
  },
  {
    "path": "invalid_action_masking/ppo_no_mask_24x24.py",
    "chars": 21596,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.nn.functional as F\nfrom torch.distributions."
  },
  {
    "path": "invalid_action_masking/ppo_no_mask_4x4.py",
    "chars": 21360,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.nn.functional as F\nfrom torch.distributions."
  },
  {
    "path": "plots/analysis.py",
    "chars": 6732,
    "preview": "import wandb\nimport numpy as np\nimport pandas as pd \napi = wandb.Api()\nwandb_entity = os.environ['WANDB_ENTITY']\n\n# Proj"
  },
  {
    "path": "plots/approx_kl.py",
    "chars": 12982,
    "preview": "from os import path\nimport pickle\nimport wandb\nimport pandas as pd\nimport numpy as np\nimport seaborn as sns\nimport matpl"
  },
  {
    "path": "plots/episode_reward.py",
    "chars": 13917,
    "preview": "from os import path\nimport pickle\nimport wandb\nimport pandas as pd\nimport numpy as np\nimport seaborn as sns\nimport matpl"
  },
  {
    "path": "ppo.py",
    "chars": 19408,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.nn.functional as F\nfrom torch.distributions."
  },
  {
    "path": "pyproject.toml",
    "chars": 675,
    "preview": "[tool.poetry]\nname = \"invalid-action-masking\"\nversion = \"0.1.0\"\ndescription = \"\"\nauthors = [\"Costa Huang <costa.huang@ou"
  },
  {
    "path": "requirements.txt",
    "chars": 887,
    "preview": "absl-py==0.13.0\ncachetools==4.2.2\ncertifi==2021.5.30\ncharset-normalizer==2.0.1; python_version >= \"3\"\ncleanrl @ git+http"
  },
  {
    "path": "test.py",
    "chars": 3296,
    "preview": "# suppose action 1 is invalid\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torch.nn.functional"
  }
]

// ... and 14 more files (download for full content)

