Repository: kandouss/marlgrid
Branch: master
Commit: e88c40bad076
Files: 19
Total size: 87.4 KB
Directory structure:
gitextract_hq47jw9t/
├── .gitignore
├── LICENSE
├── README.md
├── examples/
│ ├── human_player.py
│ └── video_test.py
├── marlgrid/
│ ├── __init__.py
│ ├── agents.py
│ ├── base.py
│ ├── envs/
│ │ ├── __init__.py
│ │ ├── cluttered.py
│ │ ├── doorkey.py
│ │ ├── empty.py
│ │ ├── goalcycle.py
│ │ └── viz_test.py
│ ├── objects.py
│ ├── rendering.py
│ └── utils/
│ ├── __init__.py
│ └── video.py
└── setup.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Weights and biases temp dir
wandb/
# Video files
*.mp4
# Editor files
.vscode/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# MarlGrid
Gridworld for MARL experiments, based on [MiniGrid](https://github.com/maximecb/gym-minigrid).
[Watch the example video on YouTube](https://youtube.com/watch?v=e0xL6KB6RBA)
<video src="https://kam.al/images/extra/cluttered_multigrid_example.mp4" id="spinning-video" controls preload loop style="width:400px; max-width:100%; display:block; margin-left:auto; margin-right:auto; margin-bottom:20px;"></video>
## Training multiple independent learners
### Pre-built environment
MarlGrid comes with a few pre-built environments (see marlgrid/envs):
- `MarlGrid-3AgentCluttered11x11-v0`
- `MarlGrid-3AgentCluttered15x15-v0`
- `MarlGrid-2AgentEmpty9x9-v0`
- `MarlGrid-3AgentEmpty9x9-v0`
- `MarlGrid-4AgentEmpty9x9-v0`
(as of v0.0.2)
### Custom environment
Create an RL agent (e.g. `TestRLAgent` subclassing `marlgrid.agents.LearningAgent`) that implements:
- `action_step(self, obs)`
- `save_step(self, *transition_values)`
- `start_episode(self)` (optional)
- `end_episode(self)` (optional)
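For illustration, here is a minimal sketch of an agent that satisfies this interface. `RandomAgent` is a hypothetical name. It does not subclass `marlgrid.agents.LearningAgent`, so the example stays self-contained, and it samples actions uniformly at random instead of learning. The default of 7 actions is an assumption matching the size of `GridAgentInterface.actions` (`left` through `done`).

```python
import random


class RandomAgent:
    """Hypothetical agent implementing action_step/save_step/start_episode/end_episode.

    It ignores observations and samples actions uniformly; a real learner
    would update its policy from the stored transitions.
    """

    def __init__(self, n_actions=7, seed=None):
        # Assumed: 7 discrete actions, as in GridAgentInterface.actions.
        self.n_actions = n_actions
        self.rng = random.Random(seed)
        self.transitions = []

    def action_step(self, obs):
        # Return one discrete action index for this agent's observation.
        return self.rng.randrange(self.n_actions)

    def save_step(self, *transition_values):
        # Store e.g. (obs, action, next_obs, reward, done) for later learning.
        self.transitions.append(transition_values)

    def start_episode(self):
        # Optional hook: clear per-episode storage.
        self.transitions = []

    def end_episode(self):
        # Optional hook: a learning agent could run an update over self.transitions here.
        pass
```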
Multiple such agents can then be trained together in a MarlGrid environment such as `ClutteredMultiGrid`:
```python
agents = marlgrid.agents.IndependentLearners(
    TestRLAgent(),
    TestRLAgent(),
    TestRLAgent()
)

env = ClutteredMultiGrid(agents, grid_size=15, n_clutter=10)

for i_episode in range(N_episodes):
    obs_array = env.reset()

    with agents.episode():
        episode_over = False

        while not episode_over:
            # env.render()

            # Get an array with actions for each agent.
            action_array = agents.action_step(obs_array)

            # Step the multi-agent environment
            next_obs_array, reward_array, done, _ = env.step(action_array)

            # Save the transition data to replay buffers, if necessary
            agents.save_step(obs_array, action_array, next_obs_array, reward_array, done)

            obs_array = next_obs_array

            episode_over = done
            # or, if "done" is per-agent:
            # episode_over = all(done) # or any(done)
```
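The loop above relies on a container that fans each call out to the individual agents and wraps `start_episode`/`end_episode` in a context manager. The sketch below shows one hypothetical way to build such a container. `IndependentLearnersSketch` and `episode_finished` are illustrative names, not marlgrid's `IndependentLearners`. It also reduces a per-agent `done` list to a single flag. It assumes `done` is either a scalar or a Python list/tuple.

```python
from contextlib import contextmanager


class IndependentLearnersSketch:
    """Hypothetical container that forwards calls to independent agents."""

    def __init__(self, *agents):
        self.agents = list(agents)

    def action_step(self, obs_array):
        # One observation per agent in, one action per agent out.
        return [agent.action_step(obs) for agent, obs in zip(self.agents, obs_array)]

    def save_step(self, obs, act, next_obs, rew, done):
        # Broadcast a scalar "done" so that every agent receives a flag.
        if not isinstance(done, (list, tuple)):
            done = [done] * len(self.agents)
        for i, agent in enumerate(self.agents):
            agent.save_step(obs[i], act[i], next_obs[i], rew[i], done[i])

    @contextmanager
    def episode(self):
        # start_episode/end_episode are optional, so only call them if defined.
        for agent in self.agents:
            if hasattr(agent, "start_episode"):
                agent.start_episode()
        try:
            yield self
        finally:
            for agent in self.agents:
                if hasattr(agent, "end_episode"):
                    agent.end_episode()


def episode_finished(done, require_all=True):
    """Reduce a scalar or per-agent (list/tuple) done flag to a single bool."""
    if isinstance(done, (list, tuple)):
        return all(done) if require_all else any(done)
    return bool(done)
```

Using `try`/`finally` means `end_episode` still runs if an exception interrupts the episode. Whether to end on `all(done)` or `any(done)` depends on whether the episode should continue while some agents are still active.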
================================================
FILE: examples/human_player.py
================================================
import numpy as np
import marlgrid

from marlgrid.rendering import InteractivePlayerWindow
from marlgrid.agents import GridAgentInterface
from marlgrid.envs import env_from_config


class HumanPlayer:
    def __init__(self):
        self.player_window = InteractivePlayerWindow(
            caption="interactive marlgrid"
        )
        self.episode_count = 0

    def action_step(self, obs):
        return self.player_window.get_action(obs.astype(np.uint8))

    def save_step(self, obs, act, rew, done):
        # Add this step's reward first so the printed episode total includes it.
        self.cumulative_reward += rew
        print(f" step {self.step_count:<4d}: reward {rew} (episode total {self.cumulative_reward})")
        self.step_count += 1

    def start_episode(self):
        self.cumulative_reward = 0
        self.step_count = 0

    def end_episode(self):
        print(
            f"Finished episode {self.episode_count} after {self.step_count} steps."
            f" Episode return was {self.cumulative_reward}."
        )
        self.episode_count += 1

env_config = {
    "env_class": "ClutteredGoalCycleEnv",
    "grid_size": 13,
    "max_steps": 250,
    "clutter_density": 0.15,
    "respawn": True,
    "ghost_mode": True,
    "reward_decay": False,
    "n_bonus_tiles": 3,
    "initial_reward": True,
    "penalty": -1.5
}

player_interface_config = {
    "view_size": 7,
    "view_offset": 1,
    "view_tile_size": 11,
    "observation_style": "rich",
    "see_through_walls": False,
    "color": "prestige"
}

# Add the player/agent config to the environment config (as expected by "env_from_config" below)
env_config['agents'] = [player_interface_config]

# Create the environment based on the combined env/player config
env = env_from_config(env_config)

# Create a human player interface per the class defined above
human = HumanPlayer()

# Start an episode!
# The environment returns a list containing one observation per agent.
# In this case there's only one agent, so the list has length one.
obs_list = env.reset()

human.start_episode()
done = False
while not done:
    env.render()  # OPTIONAL: render the whole scene + bird's-eye view

    player_action = human.action_step(obs_list[0]['pov'])

    # The environment expects a list of actions, so put the player action into a list
    agent_actions = [player_action]

    next_obs_list, rew_list, done, _ = env.step(agent_actions)

    human.save_step(
        obs_list[0], player_action, rew_list[0], done
    )

    obs_list = next_obs_list

human.end_episode()
================================================
FILE: examples/video_test.py
================================================
from marlgrid.utils.video import GridRecorder
import gym_minigrid
env = gym_minigrid.envs.empty.EmptyEnv(size=10)
env.max_steps = 200
env = GridRecorder(env, render_kwargs={"tile_size": 11})
obs = env.reset()
env.recording = True
count = 0
done = False
while not done:
    act = env.action_space.sample()
    obs, rew, done, _ = env.step(act)
    count += 1
env.export_video("test_minigrid.mp4")
================================================
FILE: marlgrid/__init__.py
================================================
================================================
FILE: marlgrid/agents.py
================================================
import gym
import numpy as np
from enum import IntEnum
import warnings
import numba
from .objects import GridAgent, BonusTile


class GridAgentInterface(GridAgent):
    class actions(IntEnum):
        left = 0  # Rotate left
        right = 1  # Rotate right
        forward = 2  # Move forward
        pickup = 3  # Pick up an object
        drop = 4  # Drop an object
        toggle = 5  # Toggle/activate an object
        done = 6  # Done completing task

    def __init__(
            self,
            view_size=7,
            view_tile_size=5,
            view_offset=0,
            observation_style='image',
            observe_rewards=False,
            observe_position=False,
            observe_orientation=False,
            restrict_actions=False,
            see_through_walls=False,
            hide_item_types=[],
            prestige_beta=0.95,
            prestige_scale=2,
            allow_negative_prestige=False,
            spawn_delay=0,
            **kwargs):
        super().__init__(**kwargs)

        self.view_size = view_size
        self.view_tile_size = view_tile_size
        self.view_offset = view_offset
        self.observation_style = observation_style
        self.observe_rewards = observe_rewards
        self.observe_position = observe_position
        self.observe_orientation = observe_orientation
        self.hide_item_types = hide_item_types
        self.see_through_walls = see_through_walls
        self.init_kwargs = kwargs
        self.restrict_actions = restrict_actions
        self.prestige_beta = prestige_beta
        self.prestige_scale = prestige_scale
        self.allow_negative_prestige = allow_negative_prestige
        self.spawn_delay = spawn_delay

        if self.prestige_beta > 1:
            # warnings.warn("prestige_beta must be between 0 and 1. Using default 0.95")
            self.prestige_beta = 0.95

        image_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(view_tile_size * view_size, view_tile_size * view_size, 3),
            dtype="uint8",
        )
        if observation_style == 'image':
            self.observation_space = image_space
        elif observation_style == 'rich':
            obs_space = {
                'pov': image_space,
            }
            if self.observe_rewards:
                obs_space['reward'] = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(), dtype=np.float32)
            if self.observe_position:
                obs_space['position'] = gym.spaces.Box(low=0, high=1, shape=(2,), dtype=np.float32)
            if self.observe_orientation:
                obs_space['orientation'] = gym.spaces.Discrete(n=4)
            self.observation_space = gym.spaces.Dict(obs_space)
        else:
            raise ValueError(f"{self.__class__.__name__} kwarg 'observation_style' must be one of 'image', 'rich'.")

        if self.restrict_actions:
            self.action_space = gym.spaces.Discrete(3)
        else:
            self.action_space = gym.spaces.Discrete(len(self.actions))

        self.metadata = {
            **self.metadata,
            'view_size': view_size,
            'view_tile_size': view_tile_size,
        }

        self.reset(new_episode=True)

    def render_post(self, tile):
        if not self.active:
            return tile

        blue = np.array([0, 0, 255])
        red = np.array([255, 0, 0])

        if self.color == 'prestige':
            # Compute a scaled prestige value between 0 and 1 that will be used to
            # interpolate between the low-prestige (red) and high-prestige (blue)
            # colors.
            if self.allow_negative_prestige:
                prestige_scaled = 1/(1 + np.exp(-self.prestige/self.prestige_scale))
            else:
                prestige_scaled = np.tanh(self.prestige/self.prestige_scale)

            # np.int was removed in NumPy 1.24; the builtin int is equivalent here.
            new_color = (
                prestige_scaled * blue +
                (1.-prestige_scaled) * red
            ).astype(int)

            grey_pixels = (np.diff(tile, axis=-1)==0).all(axis=-1)

            alpha = tile[...,0].astype(np.uint16)[...,None]
            tile = np.right_shift(alpha * new_color, 8).astype(np.uint8)
            return tile
        else:
            return tile

    def clone(self):
        ret = self.__class__(
            view_size=self.view_size,
            view_offset=self.view_offset,
            view_tile_size=self.view_tile_size,
            observation_style=self.observation_style,
            observe_rewards=self.observe_rewards,
            observe_position=self.observe_position,
            observe_orientation=self.observe_orientation,
            hide_item_types=self.hide_item_types,
            restrict_actions=self.restrict_actions,
            see_through_walls=self.see_through_walls,
            prestige_beta=self.prestige_beta,
            prestige_scale=self.prestige_scale,
            allow_negative_prestige=self.allow_negative_prestige,
            spawn_delay=self.spawn_delay,
            **self.init_kwargs
        )
        return ret

    def on_step(self, obj):
        if isinstance(obj, BonusTile):
            self.bonuses.append((obj.bonus_id, self.prestige))
        self.prestige *= self.prestige_beta

    def reward(self, rew):
        if self.allow_negative_prestige:
            # Negative rewards may push prestige below zero.
            self.prestige += rew
        else:
            if rew >= 0:
                self.prestige += rew
            else:  # rew < 0
                self.prestige = 0

    def activate(self):
        self.active = True

    def deactivate(self):
        self.active = False

    def reset(self, new_episode=False):
        self.done = False
        self.active = False
        self.pos = None
        self.carrying = None
        self.mission = ""
        if new_episode:
            self.prestige = 0
            self.bonus_state = None
            self.bonuses = []

    def render(self, img):
        if self.active:
            super().render(img)

    @property
    def dir_vec(self):
        """
        Get the direction vector for the agent, pointing in the direction
        of forward movement.
        """
        assert self.dir >= 0 and self.dir < 4
        return np.array([[1, 0], [0, 1], [-1, 0], [0, -1]])[self.dir]

    @property
    def right_vec(self):
        """
        Get the vector pointing to the right of the agent.
        """
        dx, dy = self.dir_vec
        return np.array((-dy, dx))

    @property
    def front_pos(self):
        """
        Get the position of the cell that is right in front of the agent.
        """
        return np.add(self.pos, self.dir_vec)

    def get_view_coords(self, i, j):
        """
        Translate and rotate absolute grid coordinates (i, j) into the
        agent's partially observable view (sub-grid). Note that the resulting
        coordinates may be negative or outside of the agent's view size.
        """
        ax, ay = self.pos
        dx, dy = self.dir_vec
        rx, ry = self.right_vec

        ax -= 2*self.view_offset*dx
        ay -= 2*self.view_offset*dy

        # Compute the absolute coordinates of the top-left view corner
        sz = self.view_size
        hs = self.view_size // 2
        tx = ax + (dx * (sz - 1)) - (rx * hs)
        ty = ay + (dy * (sz - 1)) - (ry * hs)

        lx = i - tx
        ly = j - ty

        # Project the coordinates of the object relative to the top-left
        # corner onto the agent's own coordinate system
        vx = rx * lx + ry * ly
        vy = -(dx * lx + dy * ly)

        return vx, vy

    def get_view_pos(self):
        return (self.view_size // 2, self.view_size - 1 - self.view_offset)

    def get_view_exts(self):
        """
        Get the extents of the square set of tiles visible to the agent.
        Note: the bottom extent indices are not included in the set.
        """
        dir = self.dir
        # Facing right
        if dir == 0:  # 1
            topX = self.pos[0] - self.view_offset
            topY = self.pos[1] - self.view_size // 2
        # Facing down
        elif dir == 1:  # 0
            topX = self.pos[0] - self.view_size // 2
            topY = self.pos[1] - self.view_offset
        # Facing left
        elif dir == 2:  # 3
            topX = self.pos[0] - self.view_size + 1 + self.view_offset
            topY = self.pos[1] - self.view_size // 2
        # Facing up
        elif dir == 3:  # 2
            topX = self.pos[0] - self.view_size // 2
            topY = self.pos[1] - self.view_size + 1 + self.view_offset
        else:
            assert False, "invalid agent direction"

        botX = topX + self.view_size
        botY = topY + self.view_size

        return (topX, topY, botX, botY)

    def relative_coords(self, x, y):
        """
        Check whether a grid position is in the agent's field of view, and if so
        return the corresponding view coordinates (otherwise return None).
        """
        vx, vy = self.get_view_coords(x, y)

        if vx < 0 or vy < 0 or vx >= self.view_size or vy >= self.view_size:
            return None

        return vx, vy

    def in_view(self, x, y):
        """
        Check if a grid position is visible to the agent.
        """
        return self.relative_coords(x, y) is not None

    def sees(self, x, y):
        raise NotImplementedError

    def process_vis(self, opacity_grid):
        assert len(opacity_grid.shape) == 2
        if not self.see_through_walls:
            return occlude_mask(~opacity_grid, self.get_view_pos())
        else:
            # np.bool was removed in NumPy 1.24; use the builtin bool dtype.
            return np.full(opacity_grid.shape, 1, dtype=bool)


@numba.njit
def occlude_mask(grid, agent_pos):
    # `grid` is True for transparent cells. Visibility spreads outward from the
    # agent: every neighbor of a visible, transparent cell becomes visible, while
    # opaque cells can be marked visible but do not propagate further.
    mask = np.zeros(grid.shape[:2]).astype(numba.boolean)
    mask[agent_pos[0], agent_pos[1]] = True
    width, height = grid.shape[:2]

    # Sweep rows from the agent's row up to row 0 (inclusive, and never past the
    # last valid row index).
    for j in range(agent_pos[1], -1, -1):
        for i in range(agent_pos[0], width):
            if mask[i, j] and grid[i, j]:
                if i < width - 1:
                    mask[i + 1, j] = True
                if j > 0:
                    mask[i, j - 1] = True
                    if i < width - 1:
                        mask[i + 1, j - 1] = True

        for i in range(agent_pos[0], -1, -1):
            if mask[i, j] and grid[i, j]:
                if i > 0:
                    mask[i - 1, j] = True
                if j > 0:
                    mask[i, j - 1] = True
                    if i > 0:
                        mask[i - 1, j - 1] = True

    # Sweep rows from the agent's row down to the last row.
    for j in range(agent_pos[1], height):
        for i in range(agent_pos[0], width):
            if mask[i, j] and grid[i, j]:
                if i < width - 1:
                    mask[i + 1, j] = True
                if j < height - 1:
                    mask[i, j + 1] = True
                    # Guarded by j < height - 1 so the diagonal stays in bounds.
                    if i < width - 1:
                        mask[i + 1, j + 1] = True

        for i in range(agent_pos[0], -1, -1):
            if mask[i, j] and grid[i, j]:
                if i > 0:
                    mask[i - 1, j] = True
                if j < height - 1:
                    mask[i, j + 1] = True
                    if i > 0:
                        mask[i - 1, j + 1] = True

    return mask
================================================
FILE: marlgrid/base.py
================================================
# Multi-agent gridworld.
# Based on MiniGrid: https://github.com/maximecb/gym-minigrid.
import gym
import numpy as np
import gym_minigrid
from enum import IntEnum
import math
import warnings
from .objects import WorldObj, Wall, Goal, Lava, GridAgent, BonusTile, BulkObj, COLORS
from .agents import GridAgentInterface
from .rendering import SimpleImageViewer
from gym_minigrid.rendering import fill_coords, point_in_rect, downsample, highlight_img
TILE_PIXELS = 32


class ObjectRegistry:
    '''
    This class contains dicts that map objects to numeric keys and vice versa.
    It is used so that grid worlds can represent objects using numerical arrays
    rather than lists of lists of generic objects.
    '''
    def __init__(self, objs=[], max_num_objects=1000):
        self.key_to_obj_map = {}
        self.obj_to_key_map = {}
        self.max_num_objects = max_num_objects
        for obj in objs:
            self.add_object(obj)

    def get_next_key(self):
        for k in range(self.max_num_objects):
            if k not in self.key_to_obj_map:
                break
        else:
            raise ValueError("Object registry full.")
        return k

    def __len__(self):
        return len(self.key_to_obj_map)

    def add_object(self, obj):
        new_key = self.get_next_key()
        self.key_to_obj_map[new_key] = obj
        self.obj_to_key_map[obj] = new_key
        return new_key

    def contains_object(self, obj):
        return obj in self.obj_to_key_map

    def contains_key(self, key):
        return key in self.key_to_obj_map

    def get_key(self, obj):
        if obj in self.obj_to_key_map:
            return self.obj_to_key_map[obj]
        else:
            return self.add_object(obj)

    # 5/4/2020 This gets called A LOT. Replaced calls to this function with direct dict gets
    # in an attempt to speed things up. Probably didn't make a big difference.
    def obj_of_key(self, key):
        return self.key_to_obj_map[key]


def rotate_grid(grid, rot_k):
    '''
    Rotate the first two axes of `grid` by rot_k quarter turns. This is
    equivalent to np.rot90(grid, k=-rot_k, axes=(0, 1)), i.e. np.rot90 with the
    arguments needed for rotating images, but faster.
    '''
    rot_k = rot_k % 4
    if rot_k == 3:
        return np.moveaxis(grid[:, ::-1], 0, 1)
    elif rot_k == 1:
        return np.moveaxis(grid[::-1, :], 0, 1)
    elif rot_k == 2:
        return grid[::-1, ::-1]
    else:
        return grid


class MultiGrid:

    tile_cache = {}

    def __init__(self, shape, obj_reg=None, orientation=0):
        self.orientation = orientation
        if isinstance(shape, tuple):
            self.width, self.height = shape
            self.grid = np.zeros((self.width, self.height), dtype=np.uint8)  # w,h
        elif isinstance(shape, np.ndarray):
            self.width, self.height = shape.shape
            self.grid = shape
        else:
            raise ValueError("Must create grid from shape tuple or array.")

        if self.width < 3 or self.height < 3:
            raise ValueError("Grid needs width, height >= 3")

        self.obj_reg = ObjectRegistry(objs=[None]) if obj_reg is None else obj_reg

    @property
    def opacity(self):
        transparent_fun = np.vectorize(lambda k: (self.obj_reg.key_to_obj_map[k].see_behind() if hasattr(self.obj_reg.key_to_obj_map[k], 'see_behind') else True))
        return ~transparent_fun(self.grid)

    def __getitem__(self, *args, **kwargs):
        return self.__class__(
            np.ndarray.__getitem__(self.grid, *args, **kwargs),
            obj_reg=self.obj_reg,
            orientation=self.orientation,
        )

    def rotate_left(self, k=1):
        return self.__class__(
            rotate_grid(self.grid, rot_k=k),  # np.rot90(self.grid, k=k),
            obj_reg=self.obj_reg,
            orientation=(self.orientation - k) % 4,
        )

    def slice(self, topX, topY, width, height, rot_k=0):
        """
        Get a subset of the grid
        """
        sub_grid = self.__class__(
            (width, height),
            obj_reg=self.obj_reg,
            orientation=(self.orientation - rot_k) % 4,
        )
        x_min = max(0, topX)
        x_max = min(topX + width, self.width)
        y_min = max(0, topY)
        y_max = min(topY + height, self.height)

        x_offset = x_min - topX
        y_offset = y_min - topY
        sub_grid.grid[
            x_offset : x_max - x_min + x_offset, y_offset : y_max - y_min + y_offset
        ] = self.grid[x_min:x_max, y_min:y_max]

        sub_grid.grid = rotate_grid(sub_grid.grid, rot_k)

        sub_grid.width, sub_grid.height = sub_grid.grid.shape

        return sub_grid

    def set(self, i, j, obj):
        assert i >= 0 and i < self.width
        assert j >= 0 and j < self.height
        self.grid[i, j] = self.obj_reg.get_key(obj)

    def get(self, i, j):
        assert i >= 0 and i < self.width
        assert j >= 0 and j < self.height

        return self.obj_reg.key_to_obj_map[self.grid[i, j]]

    def horz_wall(self, x, y, length=None, obj_type=Wall):
        if length is None:
            length = self.width - x
        for i in range(0, length):
            self.set(x + i, y, obj_type())

    def vert_wall(self, x, y, length=None, obj_type=Wall):
        if length is None:
            length = self.height - y
        for j in range(0, length):
            self.set(x, y + j, obj_type())

    def wall_rect(self, x, y, w, h, obj_type=Wall):
        self.horz_wall(x, y, w, obj_type=obj_type)
        self.horz_wall(x, y + h - 1, w, obj_type=obj_type)
        self.vert_wall(x, y, h, obj_type=obj_type)
        self.vert_wall(x + w - 1, y, h, obj_type=obj_type)

    def __str__(self):
        # Each cell renders as two characters, matching the border width
        # of 2 * self.width + 2 stars.
        render = (
            lambda x: "  "
            if x is None or not hasattr(x, "str_render")
            else x.str_render(dir=self.orientation)
        )
        hstars = "*" * (2 * self.width + 2)
        return (
            hstars
            + "\n"
            + "\n".join(
                "*" + "".join(render(self.get(i, j)) for i in range(self.width)) + "*"
                for j in range(self.height)
            )
            + "\n"
            + hstars
        )

    def encode(self, vis_mask=None):
        """
        Produce a compact numpy encoding of the grid
        """
        if vis_mask is None:
            vis_mask = np.ones((self.width, self.height), dtype=bool)

        array = np.zeros((self.width, self.height, 3), dtype="uint8")

        for i in range(self.width):
            for j in range(self.height):
                if vis_mask[i, j]:
                    v = self.get(i, j)
                    if v is None:
                        array[i, j, :] = 0
                    else:
                        array[i, j, :] = v.encode()
        return array

    @classmethod
    def decode(cls, array):
        # Decoding an encoded array back into a grid of objects is not implemented.
        raise NotImplementedError

    @classmethod
    def cache_render_fun(cls, key, f, *args, **kwargs):
        if key not in cls.tile_cache:
            cls.tile_cache[key] = f(*args, **kwargs)
        return np.copy(cls.tile_cache[key])

    @classmethod
    def cache_render_obj(cls, obj, tile_size, subdivs):
        if obj is None:
            return cls.cache_render_fun((tile_size, None), cls.empty_tile, tile_size, subdivs)
        else:
            img = cls.cache_render_fun(
                (tile_size, obj.__class__.__name__, *obj.encode()),
                cls.render_object, obj, tile_size, subdivs
            )
            if hasattr(obj, 'render_post'):
                return obj.render_post(img)
            else:
                return img

    @classmethod
    def empty_tile(cls, tile_size, subdivs):
        alpha = max(0, min(20, tile_size-10))
        img = np.full((tile_size, tile_size, 3), alpha, dtype=np.uint8)
        img[1:, :-1] = 0
        return img

    @classmethod
    def render_object(cls, obj, tile_size, subdivs):
        img = np.zeros((tile_size*subdivs, tile_size*subdivs, 3), dtype=np.uint8)
        obj.render(img)
        # if 'Agent' not in obj.type and len(obj.agents) > 0:
        #     obj.agents[0].render(img)
        return downsample(img, subdivs).astype(np.uint8)
@classmethod
def blend_tiles(cls, img1, img2):
'''
This function renders one "tile" on top of another. Kinda janky, but works surprisingly well.
Assumes img2 is a downscaled monochromatic image with a black (0,0,0) background.
'''
alpha = img2.sum(2, keepdims=True)
max_alpha = alpha.max()
if max_alpha == 0:
return img1
return (
((img1 * (max_alpha-alpha))+(img2*alpha)
)/max_alpha
).astype(img1.dtype)
@classmethod
def render_tile(cls, obj, tile_size=TILE_PIXELS, subdivs=3, top_agent=None):
if obj is None:
img = cls.cache_render_obj(obj, tile_size, subdivs)
else:
if ('Agent' in obj.type) and (top_agent in obj.agents):
# If the tile is a stack of agents that includes the top agent, then just render the top agent.
img = cls.cache_render_obj(top_agent, tile_size, subdivs)
else:
# Otherwise, render (+ downsize) the item in the tile.
img = cls.cache_render_obj(obj, tile_size, subdivs)
# If the base obj isn't an agent but has agents on top, render an agent and blend it in.
if len(obj.agents)>0 and 'Agent' not in obj.type:
if top_agent in obj.agents:
img_agent = cls.cache_render_obj(top_agent, tile_size, subdivs)
else:
img_agent = cls.cache_render_obj(obj.agents[0], tile_size, subdivs)
img = cls.blend_tiles(img, img_agent)
# Render the tile border if any of the corners are black.
if (img[([0,0,-1,-1],[0,-1,0,-1])]==0).all(axis=-1).any():
img = img + cls.cache_render_fun((tile_size, None), cls.empty_tile, tile_size, subdivs)
return img
def render(self, tile_size, highlight_mask=None, visible_mask=None, top_agent=None):
width_px = self.width * tile_size
height_px = self.height * tile_size
img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8) + COLORS['shadow'].astype(np.uint8)
for j in range(0, self.height):
for i in range(0, self.width):
if visible_mask is not None and not visible_mask[i,j]:
continue
obj = self.get(i, j)
tile_img = MultiGrid.render_tile(
obj,
tile_size=tile_size,
top_agent=top_agent
)
ymin = j * tile_size
ymax = (j + 1) * tile_size
xmin = i * tile_size
xmax = (i + 1) * tile_size
img[ymin:ymax, xmin:xmax, :] = rotate_grid(tile_img, self.orientation)
if highlight_mask is not None:
# Brighten highlighted cells: (8*img + 2*255) >> 3 == img + 255/4 (floored), clipped to 255.
hm = np.kron(highlight_mask.T, np.full((tile_size, tile_size), 255, dtype=np.uint16))[..., None]
img = np.right_shift(img.astype(np.uint16)*8+hm*2, 3).clip(0,255).astype(np.uint8)
return img
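# Illustrative sketch (hypothetical helper, not part of the original marlgrid code):
# the per-channel arithmetic that MultiGrid.render applies when highlight_mask is given,
# written for a single pixel value v in [0, 255]. Highlighted pixels become roughly
# v + 255/4 (floored, clipped to 255); pixels outside the mask are left unchanged.
def _highlight_channel(v, highlighted):
    hm = 255 if highlighted else 0  # mask value used by render() for highlighted cells
    return min(max((v * 8 + hm * 2) >> 3, 0), 255)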
class MultiGridEnv(gym.Env):
def __init__(
self,
agents = [],
grid_size=None,
width=None,
height=None,
max_steps=100,
reward_decay=True,
seed=1337,
respawn=False,
ghost_mode=True,
agent_spawn_kwargs = {}
):
if grid_size is not None:
assert width is None and height is None
width, height = grid_size, grid_size
self.respawn = respawn
self.window = None
self.width = width
self.height = height
self.max_steps = max_steps
self.reward_decay = reward_decay
self.seed(seed=seed)
self.agent_spawn_kwargs = agent_spawn_kwargs
self.ghost_mode = ghost_mode
self.agents = []
for agent in agents:
self.add_agent(agent)
self.reset()
def seed(self, seed=1337):
# Seed the random number generator
self.np_random, _ = gym.utils.seeding.np_random(seed)
return [seed]
@property
def action_space(self):
return gym.spaces.Tuple(
[agent.action_space for agent in self.agents]
)
@property
def observation_space(self):
return gym.spaces.Tuple(
[agent.observation_space for agent in self.agents]
)
@property
def num_agents(self):
return len(self.agents)
def add_agent(self, agent_interface):
if isinstance(agent_interface, dict):
self.agents.append(GridAgentInterface(**agent_interface))
elif isinstance(agent_interface, GridAgentInterface):
self.agents.append(agent_interface)
else:
raise ValueError(
"To add an agent to a marlgrid environment, call add_agent with either a GridAgentInterface object "
"or a dictionary that can be used to initialize one.")
def reset(self, **kwargs):
for agent in self.agents:
agent.agents = []
agent.reset(new_episode=True)
self._gen_grid(self.width, self.height)
for agent in self.agents:
if agent.spawn_delay == 0:
self.place_obj(agent, **self.agent_spawn_kwargs)
agent.activate()
self.step_count = 0
obs = self.gen_obs()
return obs
def gen_obs_grid(self, agent):
# If the agent is inactive, return an empty grid and a visibility mask that hides everything.
if not agent.active:
# below, not sure orientation is correct but as of 6/27/2020 that doesn't matter because
# agent views are usually square and this grid won't be used for anything.
grid = MultiGrid((agent.view_size, agent.view_size), orientation=agent.dir+1)
vis_mask = np.zeros((agent.view_size, agent.view_size), dtype=bool)
return grid, vis_mask
topX, topY, botX, botY = agent.get_view_exts()
grid = self.grid.slice(
topX, topY, agent.view_size, agent.view_size, rot_k=agent.dir + 1
)
# Process occluders and visibility
# Note that this incurs some slight performance cost
vis_mask = agent.process_vis(grid.opacity)
# Warning about the rest of the function:
# Allows masking away objects that the agent isn't supposed to see.
# But breaks consistency between the states of the grid objects in the partial views
# and the grid objects overall.
if len(getattr(agent, 'hide_item_types', []))>0:
for i in range(grid.width):
for j in range(grid.height):
item = grid.get(i,j)
if (item is not None) and (item is not agent) and (item.type in agent.hide_item_types):
if len(item.agents) > 0:
grid.set(i,j,item.agents[0])
else:
grid.set(i,j,None)
return grid, vis_mask
def gen_agent_obs(self, agent):
"""
Generate the agent's view (partially observable, low-resolution encoding)
"""
grid, vis_mask = self.gen_obs_grid(agent)
grid_image = grid.render(tile_size=agent.view_tile_size, visible_mask=vis_mask, top_agent=agent)
if agent.observation_style=='image':
return grid_image
else:
ret = {'pov': grid_image}
if agent.observe_rewards:
ret['reward'] = getattr(agent, 'step_reward', 0)
if agent.observe_position:
agent_pos = agent.pos if agent.pos is not None else (0,0)
ret['position'] = np.array(agent_pos)/np.array([self.width, self.height], dtype=float)
if agent.observe_orientation:
agent_dir = agent.dir if agent.dir is not None else 0
ret['orientation'] = agent_dir
return ret
def gen_obs(self):
return [self.gen_agent_obs(agent) for agent in self.agents]
def __str__(self):
return self.grid.__str__()
def check_agent_position_integrity(self, title=''):
'''
This function checks whether each agent is present in the grid in exactly one place.
This is particularly helpful for validating the world state when ghost_mode=False and
agents can stack, since the logic for moving them around gets a bit messy.
Prints a message and drops into pdb if there's an inconsistency.
'''
agent_locs = [[] for _ in range(len(self.agents))]
for i in range(self.grid.width):
for j in range(self.grid.height):
x = self.grid.get(i,j)
for k,agent in enumerate(self.agents):
if x==agent:
agent_locs[k].append(('top', (i,j)))
if hasattr(x, 'agents') and agent in x.agents:
agent_locs[k].append(('stacked', (i,j)))
if not all([len(x)==1 for x in agent_locs]):
print(f"{title} > Failed integrity test!")
for a, al in zip(self.agents, agent_locs):
print(" > ", a.color,'-', al)
import pdb; pdb.set_trace()
def step(self, actions):
# Spawn agents if it's time.
for agent in self.agents:
if not agent.active and not agent.done and self.step_count >= agent.spawn_delay:
self.place_obj(agent, **self.agent_spawn_kwargs)
agent.activate()
assert len(actions) == len(self.agents)
step_rewards = np.zeros((len(self.agents),), dtype=float)
self.step_count += 1
iter_agents = list(enumerate(zip(self.agents, actions)))
iter_order = np.arange(len(iter_agents))
self.np_random.shuffle(iter_order)
for shuffled_ix in iter_order:
agent_no, (agent, action) = iter_agents[shuffled_ix]
agent.step_reward = 0
if agent.active:
cur_pos = agent.pos[:]
cur_cell = self.grid.get(*cur_pos)
fwd_pos = agent.front_pos[:]
fwd_cell = self.grid.get(*fwd_pos)
agent_moved = False
# Rotate left
if action == agent.actions.left:
agent.dir = (agent.dir - 1) % 4
# Rotate right
elif action == agent.actions.right:
agent.dir = (agent.dir + 1) % 4
# Move forward
elif action == agent.actions.forward:
# Under the following conditions, the agent can move forward.
can_move = fwd_cell is None or fwd_cell.can_overlap()
if self.ghost_mode is False and isinstance(fwd_cell, GridAgent):
can_move = False
if can_move:
agent_moved = True
# Add agent to new cell
if fwd_cell is None:
self.grid.set(*fwd_pos, agent)
agent.pos = fwd_pos
else:
fwd_cell.agents.append(agent)
agent.pos = fwd_pos
# Remove agent from old cell
if cur_cell == agent:
self.grid.set(*cur_pos, None)
else:
assert cur_cell.can_overlap()
cur_cell.agents.remove(agent)
# Add agent's agents to old cell
for left_behind in agent.agents:
cur_obj = self.grid.get(*cur_pos)
if cur_obj is None:
self.grid.set(*cur_pos, left_behind)
elif cur_obj.can_overlap():
cur_obj.agents.append(left_behind)
else: # How was "agent" there in the first place?
raise ValueError("Inconsistent grid state: agents were stacked on a cell that cannot be overlapped.")
# After moving, the agent shouldn't contain any other agents.
agent.agents = []
# test_integrity(f"After moving {agent.color} fellow")
# Rewards are given iff fwd_cell has a "get_reward" method.
if hasattr(fwd_cell, 'get_reward'):
rwd = fwd_cell.get_reward(agent)
if bool(self.reward_decay):
rwd *= (1.0-0.9*(self.step_count/self.max_steps))
step_rewards[agent_no] += rwd
agent.reward(rwd)
if isinstance(fwd_cell, (Lava, Goal)):
agent.done = True
# TODO: verify pickup/drop/toggle logic in an environment that
# supports the relevant interactions.
# Pick up an object
elif action == agent.actions.pickup:
if fwd_cell and fwd_cell.can_pickup():
if agent.carrying is None:
agent.carrying = fwd_cell
agent.carrying.pos = None
self.grid.set(*fwd_pos, None)
else:
pass
# Drop an object
elif action == agent.actions.drop:
if not fwd_cell and agent.carrying:
self.grid.set(*fwd_pos, agent.carrying)
agent.carrying.set_position(fwd_pos)
agent.carrying = None
else:
pass
# Toggle/activate an object
elif action == agent.actions.toggle:
if fwd_cell:
wasted = bool(fwd_cell.toggle(agent, fwd_pos))
else:
pass
# Done action (not used by default)
elif action == agent.actions.done:
pass
else:
raise ValueError(f"Environment can't handle action {action}.")
agent.on_step(fwd_cell if agent_moved else None)
# If any of the agents individually are "done" (hit lava or in some cases a goal)
# but the env requires respawning, then respawn those agents.
for agent in self.agents:
if agent.done:
if self.respawn:
resting_place_obj = self.grid.get(*agent.pos)
if resting_place_obj == agent:
if agent.agents:
self.grid.set(*agent.pos, agent.agents[0])
agent.agents[0].agents += agent.agents[1:]
else:
self.grid.set(*agent.pos, None)
else:
resting_place_obj.agents.remove(agent)
resting_place_obj.agents += agent.agents[:]
agent.agents = []
agent.reset(new_episode=False)
self.place_obj(agent, **self.agent_spawn_kwargs)
agent.activate()
else: # if the agent shouldn't be respawned, then deactivate it.
agent.deactivate()
# The episode overall is done if all the agents are done, or if it exceeds the step limit.
done = (self.step_count >= self.max_steps) or all([agent.done for agent in self.agents])
obs = [self.gen_agent_obs(agent) for agent in self.agents]
return obs, step_rewards, done, {}
def put_obj(self, obj, i, j):
"""
Put an object at a specific position in the grid. Replace anything that is already there.
"""
self.grid.set(i, j, obj)
if obj is not None:
obj.set_position((i,j))
return True
def try_place_obj(self,obj, pos):
''' Try to place an object at a certain position in the grid.
If it is possible, then do so and return True.
Otherwise do nothing and return False. '''
# grid_obj: whatever object is already at pos.
grid_obj = self.grid.get(*pos)
# If the target position is empty, then the object can always be placed.
if grid_obj is None:
self.grid.set(*pos, obj)
obj.set_position(pos)
return True
# Otherwise only agents can be placed, and only if the target position can_overlap.
if not (grid_obj.can_overlap() and obj.is_agent):
return False
# If ghost mode is off and there's already an agent at the target cell, the agent can't
# be placed there.
if (not self.ghost_mode) and (grid_obj.is_agent or (len(grid_obj.agents)>0)):
return False
grid_obj.agents.append(obj)
obj.set_position(pos)
return True
def place_obj(self, obj, top=(0,0), size=None, reject_fn=None, max_tries=1e5):
max_tries = int(max(1, min(max_tries, 1e5)))
top = (max(top[0], 0), max(top[1], 0))
if size is None:
size = (self.grid.width, self.grid.height)
bottom = (min(top[0] + size[0], self.grid.width), min(top[1] + size[1], self.grid.height))
# agent_positions = [tuple(agent.pos) if agent.pos is not None else None for agent in self.agents]
for try_no in range(max_tries):
pos = self.np_random.randint(top, bottom)
if (reject_fn is not None) and reject_fn(pos):
continue
else:
if self.try_place_obj(obj, pos):
break
else:
raise RecursionError("Rejection sampling failed in place_obj.")
return pos
def place_agents(self, top=None, size=None, rand_dir=True, max_tries=1000):
# warnings.warn("Placing agents with the function place_agents is deprecated.")
pass
def render(
self,
mode="human",
close=False,
highlight=True,
tile_size=TILE_PIXELS,
show_agent_views=True,
max_agents_per_col=3,
agent_col_width_frac = 0.3,
agent_col_padding_px = 2,
pad_grey = 100
):
"""
Render the whole-grid human view
"""
if close:
if self.window:
self.window.close()
return
if mode == "human" and not self.window:
# from gym.envs.classic_control.rendering import SimpleImageViewer
self.window = SimpleImageViewer(caption="Marlgrid")
# Compute which cells are visible to the agent
highlight_mask = np.full((self.width, self.height), False, dtype=bool)
for agent in self.agents:
if agent.active:
xlow, ylow, xhigh, yhigh = agent.get_view_exts()
dxlow, dylow = max(0, 0-xlow), max(0, 0-ylow)
dxhigh, dyhigh = max(0, xhigh-self.grid.width), max(0, yhigh-self.grid.height)
if agent.see_through_walls:
highlight_mask[xlow+dxlow:xhigh-dxhigh, ylow+dylow:yhigh-dyhigh] = True
else:
a,b = self.gen_obs_grid(agent)
highlight_mask[xlow+dxlow:xhigh-dxhigh, ylow+dylow:yhigh-dyhigh] |= (
rotate_grid(b, a.orientation)[dxlow:(xhigh-xlow)-dxhigh, dylow:(yhigh-ylow)-dyhigh]
)
# Render the whole grid
img = self.grid.render(
tile_size, highlight_mask=highlight_mask if highlight else None
)
rescale = lambda X, rescale_factor=2: np.kron(
X, np.ones((int(rescale_factor), int(rescale_factor), 1))
)
if show_agent_views:
target_partial_width = int(img.shape[0]*agent_col_width_frac-2*agent_col_padding_px)
target_partial_height = (img.shape[1]-2*agent_col_padding_px)//max_agents_per_col
agent_views = [self.gen_agent_obs(agent) for agent in self.agents]
agent_views = [view['pov'] if isinstance(view, dict) else view for view in agent_views]
agent_views = [rescale(view, min(target_partial_width/view.shape[0], target_partial_height/view.shape[1])) for view in agent_views]
# import pdb; pdb.set_trace()
agent_views = [agent_views[pos:pos+max_agents_per_col] for pos in range(0, len(agent_views), max_agents_per_col)]
f_offset = lambda view: np.array([target_partial_height - view.shape[1], target_partial_width - view.shape[0]])//2
cols = []
for col_views in agent_views:
col = np.full(( img.shape[0],target_partial_width+2*agent_col_padding_px,3), pad_grey, dtype=np.uint8)
for k, view in enumerate(col_views):
offset = f_offset(view) + agent_col_padding_px
offset[0] += k*target_partial_height
col[offset[0]:offset[0]+view.shape[0], offset[1]:offset[1]+view.shape[1],:] = view
cols.append(col)
img = np.concatenate((img, *cols), axis=1)
if mode == "human":
if not self.window.isopen:
self.window.imshow(img)
self.window.window.set_caption("Marlgrid")
else:
self.window.imshow(img)
return img
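# Illustrative sketch (hypothetical helper, not part of the original marlgrid code):
# the linear reward decay applied in MultiGridEnv.step when reward_decay is truthy.
# A reward collected at step_count keeps a fraction 1 - 0.9 * step_count / max_steps,
# so it falls from the full value toward 10% of it at the step limit. Note that step()
# increments step_count before computing rewards, so the first step uses step_count=1.
def _decayed_reward(rwd, step_count, max_steps):
    return rwd * (1.0 - 0.9 * (step_count / max_steps))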
================================================
FILE: marlgrid/envs/__init__.py
================================================
from ..base import MultiGridEnv
from .empty import EmptyMultiGrid
from .doorkey import DoorKeyEnv
from .cluttered import ClutteredMultiGrid
from .goalcycle import ClutteredGoalCycleEnv
from .viz_test import VisibilityTestEnv
from ..agents import GridAgentInterface
from gym.envs.registration import register as gym_register
import sys
import inspect
import random
this_module = sys.modules[__name__]
registered_envs = []
def register_marl_env(
env_name,
env_class,
n_agents,
grid_size,
view_size,
view_tile_size=8,
view_offset=0,
agent_color=None,
env_kwargs={},
):
colors = ["red", "blue", "purple", "orange", "olive", "pink"]
assert n_agents <= len(colors)
class RegEnv(env_class):
def __new__(cls):
instance = super(env_class, RegEnv).__new__(env_class)
instance.__init__(
agents=[
GridAgentInterface(
color=c if agent_color is None else agent_color,
view_size=view_size,
view_tile_size=view_tile_size,
view_offset=view_offset,
)
for c in colors[:n_agents]
],
grid_size=grid_size,
**env_kwargs,
)
return instance
env_class_name = f"env_{len(registered_envs)}"
setattr(this_module, env_class_name, RegEnv)
registered_envs.append(env_name)
gym_register(env_name, entry_point=f"marlgrid.envs:{env_class_name}")
def env_from_config(env_config, randomize_seed=True):
possible_envs = {k:v for k,v in globals().items() if inspect.isclass(v) and issubclass(v, MultiGridEnv)}
env_class = possible_envs[env_config['env_class']]
env_kwargs = {k:v for k,v in env_config.items() if k != 'env_class'}
if randomize_seed:
env_kwargs['seed'] = env_kwargs.get('seed', 0) + random.randint(0, 1337*1337)
return env_class(**env_kwargs)
register_marl_env(
"MarlGrid-1AgentCluttered15x15-v0",
ClutteredMultiGrid,
n_agents=1,
grid_size=11,
view_size=5,
env_kwargs={'n_clutter':30}
)
register_marl_env(
"MarlGrid-3AgentCluttered11x11-v0",
ClutteredMultiGrid,
n_agents=3,
grid_size=11,
view_size=7,
env_kwargs={'clutter_density':0.15}
)
register_marl_env(
"MarlGrid-3AgentCluttered15x15-v0",
ClutteredMultiGrid,
n_agents=3,
grid_size=15,
view_size=7,
env_kwargs={'clutter_density':0.15}
)
register_marl_env(
"MarlGrid-2AgentEmpty9x9-v0", EmptyMultiGrid, n_agents=2, grid_size=9, view_size=7
)
register_marl_env(
"MarlGrid-3AgentEmpty9x9-v0", EmptyMultiGrid, n_agents=3, grid_size=9, view_size=7
)
register_marl_env(
"MarlGrid-4AgentEmpty9x9-v0", EmptyMultiGrid, n_agents=4, grid_size=9, view_size=7
)
register_marl_env(
"Goalcycle-demo-solo-v0",
ClutteredGoalCycleEnv,
n_agents=1,
grid_size=13,
view_size=7,
view_tile_size=5,
view_offset=1,
env_kwargs={
'clutter_density':0.1,
'n_bonus_tiles': 3
}
)
================================================
FILE: marlgrid/envs/cluttered.py
================================================
from ..base import MultiGridEnv, MultiGrid
from ..objects import *
class ClutteredMultiGrid(MultiGridEnv):
mission = "get to the green square"
metadata = {}
def __init__(self, *args, n_clutter=None, clutter_density=None, randomize_goal=False, **kwargs):
if (n_clutter is None) == (clutter_density is None):
raise ValueError("Must provide n_clutter xor clutter_density in environment config.")
super().__init__(*args, **kwargs)
if clutter_density is not None:
self.n_clutter = int(clutter_density * (self.width-2)*(self.height-2))
else:
self.n_clutter = n_clutter
self.randomize_goal = randomize_goal
# self.reset()
def _gen_grid(self, width, height):
self.grid = MultiGrid((width, height))
self.grid.wall_rect(0, 0, width, height)
if getattr(self, 'randomize_goal', True):
self.place_obj(Goal(color="green", reward=1), max_tries=100)
else:
self.put_obj(Goal(color="green", reward=1), width - 2, height - 2)
for _ in range(getattr(self, 'n_clutter', 0)):
self.place_obj(Wall(), max_tries=100)
self.agent_spawn_kwargs = {}
self.place_agents(**self.agent_spawn_kwargs)
================================================
FILE: marlgrid/envs/doorkey.py
================================================
from ..base import MultiGridEnv, MultiGrid
from ..objects import *
class DoorKeyEnv(MultiGridEnv):
"""
Environment with a door and key, sparse reward.
Similar to DoorKeyEnv in
https://github.com/maximecb/gym-minigrid/blob/master/gym_minigrid/envs/doorkey.py
"""
mission = "use the key to open the door and then get to the goal"
metadata = {}
def _gen_grid(self, width, height):
# Create an empty grid
self.grid = MultiGrid((width, height))
# Generate the surrounding walls
self.grid.wall_rect(0, 0, width, height)
# Place a goal in the bottom-right corner
self.put_obj(Goal(color="green", reward=1), width - 2, height - 2)
# Create a vertical splitting wall
splitIdx = self.np_random.randint(2, width - 2)
self.grid.vert_wall(splitIdx, 0)
# Place the agent at a random position and orientation
# on the left side of the splitting wall
# self.place_agent(size=(splitIdx, height))
# Place a door in the wall
doorIdx = self.np_random.randint(1, height - 2)
self.put_obj(Door(color="yellow", state=Door.states.locked), splitIdx, doorIdx)
# Place a yellow key on the left side
self.place_obj(obj=Key("yellow"), top=(0, 0), size=(splitIdx, height))
self.agent_spawn_kwargs = {}
self.place_agents(**self.agent_spawn_kwargs)
================================================
FILE: marlgrid/envs/empty.py
================================================
from ..base import MultiGridEnv, MultiGrid
from ..objects import *
class EmptyMultiGrid(MultiGridEnv):
mission = "get to the green square"
metadata = {}
def _gen_grid(self, width, height):
self.grid = MultiGrid((width, height))
self.grid.wall_rect(0, 0, width, height)
self.put_obj(Goal(color="green", reward=1), width - 2, height - 2)
self.agent_spawn_kwargs = {}
self.place_agents(**self.agent_spawn_kwargs)
================================================
FILE: marlgrid/envs/goalcycle.py
================================================
from ..base import MultiGridEnv, MultiGrid
from ..objects import *
class ClutteredGoalCycleEnv(MultiGridEnv):
mission = "Cycle between yellow goal tiles."
metadata = {}
def __init__(self, *args, reward=1, penalty=0.0, n_clutter=None, clutter_density=None, n_bonus_tiles=3, initial_reward=True, cycle_reset=False, reset_on_mistake=False, reward_decay=False, **kwargs):
if (n_clutter is None) == (clutter_density is None):
raise ValueError("Must provide n_clutter xor clutter_density in environment config.")
# Overwrite the default reward_decay for goal cycle environments.
super().__init__(*args, **{**kwargs, 'reward_decay': reward_decay})
if clutter_density is not None:
self.n_clutter = int(clutter_density * (self.width-2)*(self.height-2))
else:
self.n_clutter = n_clutter
self.reward = reward
self.penalty = penalty
self.initial_reward = initial_reward
self.n_bonus_tiles = n_bonus_tiles
self.reset_on_mistake = reset_on_mistake
self.bonus_tiles = []
def _gen_grid(self, width, height):
self.grid = MultiGrid((width, height))
self.grid.wall_rect(0, 0, width, height)
for bonus_id in range(getattr(self, 'n_bonus_tiles', 0)):
self.place_obj(
BonusTile(
color="yellow",
reward=self.reward,
penalty=self.penalty,
bonus_id=bonus_id,
n_bonus=self.n_bonus_tiles,
initial_reward=self.initial_reward,
reset_on_mistake=self.reset_on_mistake,
),
max_tries=100
)
for _ in range(getattr(self, 'n_clutter', 0)):
self.place_obj(Wall(), max_tries=100)
self.agent_spawn_kwargs = {}
self.place_agents(**self.agent_spawn_kwargs)
================================================
FILE: marlgrid/envs/viz_test.py
================================================
from ..base import MultiGridEnv, MultiGrid
from ..objects import *
class VisibilityTestEnv(MultiGridEnv):
mission = ""
metadata = {}
def _gen_grid(self, width, height):
self.grid = MultiGrid((width, height))
self.grid.wall_rect(0, 0, width, height)
self.grid.horz_wall(0, height // 2, width - 3, obj_type=Wall)
self.agent_spawn_kwargs = {}
self.place_agents(**self.agent_spawn_kwargs)
================================================
FILE: marlgrid/objects.py
================================================
import numpy as np
from enum import IntEnum
from gym_minigrid.rendering import (
fill_coords,
point_in_circle,
point_in_line,
point_in_rect,
point_in_triangle,
rotate_fn,
)
# Map of color names to RGB values
COLORS = {
"red": np.array([255, 0, 0]),
"orange": np.array([255, 165, 0]),
"green": np.array([0, 255, 0]),
"blue": np.array([0, 0, 255]),
"cyan": np.array([0, 139, 139]),
"purple": np.array([112, 39, 195]),
"yellow": np.array([255, 255, 0]),
"olive": np.array([128, 128, 0]),
"grey": np.array([100, 100, 100]),
"worst": np.array([74, 65, 42]), # https://en.wikipedia.org/wiki/Pantone_448_C
"pink": np.array([255, 0, 189]),
"white": np.array([255,255,255]),
"prestige": np.array([255,255,255]),
"shadow": np.array([35,25,30]), # nice dark purpley color for cells agents can't see.
}
# Used to map colors to integers
COLOR_TO_IDX = dict({v: k for k, v in enumerate(COLORS.keys())})
OBJECT_TYPES = []
class RegisteredObjectType(type):
def __new__(meta, name, bases, class_dict):
cls = type.__new__(meta, name, bases, class_dict)
if name not in [t.__name__ for t in OBJECT_TYPES]:
OBJECT_TYPES.append(cls)
def get_recursive_subclasses():
return OBJECT_TYPES
cls.recursive_subclasses = staticmethod(get_recursive_subclasses)
return cls
class WorldObj(metaclass=RegisteredObjectType):
def __init__(self, color="worst", state=0):
self.color = color
self.state = state
self.contains = None
self.agents = [] # Some objects can have agents on top (e.g. floor, open doors, etc).
self.pos_init = None
self.pos = None
self.is_agent = False
@property
def dir(self):
return None
def set_position(self, pos):
if self.pos_init is None:
self.pos_init = pos
self.pos = pos
@property
def numeric_color(self):
return COLORS[self.color]
@property
def type(self):
return self.__class__.__name__
def can_overlap(self):
return False
def can_pickup(self):
return False
def can_contain(self):
return False
def see_behind(self):
return True
def toggle(self, agent, pos):
return False
def encode(self, str_class=False):
# Note 5/29/20: Commented out the condition below; was causing agents to
# render incorrectly in partial views. In particular, if there were N red agents,
# agents {i != k} would render as blue (rather than red) in agent k's partial view.
# # if len(self.agents)>0:
# # return self.agents[0].encode(str_class=str_class)
# # else:
enc_class = self.type if bool(str_class) else self.recursive_subclasses().index(self.__class__)
enc_color = self.color if isinstance(self.color, int) else COLOR_TO_IDX[self.color]
return (enc_class, enc_color, self.state)
def describe(self):
return f"Obj: {self.type}({self.color}, {self.state})"
@classmethod
def decode(cls, type, color, state):
if isinstance(type, str):
cls_subclasses = {c.__name__: c for c in cls.recursive_subclasses()}
if type not in cls_subclasses:
raise ValueError(
f"Not sure how to construct a {cls} of (sub)type {type}"
)
return cls_subclasses[type](color, state)
elif isinstance(type, int):
subclass = cls.recursive_subclasses()[type]
return subclass(color, state)
def render(self, img):
raise NotImplementedError
def str_render(self, dir=0):
return "??"
class GridAgent(WorldObj):
def __init__(self, *args, color='red', **kwargs):
super().__init__(*args, **{'color':color, **kwargs})
self.metadata = {
'color': color,
}
self.is_agent = True
@property
def dir(self):
return self.state % 4
@property
def type(self):
return 'Agent'
@dir.setter
def dir(self, dir):
self.state = (self.state // 4) * 4 + dir % 4
def str_render(self, dir=0):
return [">>", "VV", "<<", "^^"][(self.dir + dir) % 4]
def can_overlap(self):
return True
def render(self, img):
tri_fn = point_in_triangle((0.12, 0.19), (0.87, 0.50), (0.12, 0.81),)
tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * np.pi * (self.dir))
fill_coords(img, tri_fn, COLORS[self.color])
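# Illustrative sketch (hypothetical helpers, not part of the original marlgrid code):
# GridAgent keeps its facing direction in the low part of self.state (state % 4).
# Setting a new direction should leave the higher part, state // 4, untouched;
# these helpers show that packing on plain integers.
def _get_dir(state):
    return state % 4

def _set_dir(state, new_dir):
    # Keep state // 4 and replace only the direction component.
    return (state // 4) * 4 + new_dir % 4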
class BulkObj(WorldObj, metaclass=RegisteredObjectType):
# Todo: special behavior for hash, eq if the object has an agent.
def __hash__(self):
return hash((self.__class__, self.color, self.state, tuple(self.agents)))
def __eq__(self, other):
return hash(self) == hash(other)
class BonusTile(WorldObj):
def __init__(self, reward, penalty=-0.1, bonus_id=0, n_bonus=1, initial_reward=True, reset_on_mistake=False, color='yellow', *args, **kwargs):
super().__init__(*args, **{'color': color, **kwargs, 'state': bonus_id})
self.reward = reward
self.penalty = penalty
self.n_bonus = n_bonus
self.bonus_id = bonus_id
self.initial_reward = initial_reward
self.reset_on_mistake = reset_on_mistake
def can_overlap(self):
return True
def str_render(self, dir=0):
return "BB"
def get_reward(self, agent):
# If the agent hasn't hit any bonus tiles, set its bonus state so that
# it'll get a reward from hitting this tile.
first_bonus = False
if agent.bonus_state is None:
agent.bonus_state = (self.bonus_id - 1) % self.n_bonus
first_bonus = True
if agent.bonus_state == self.bonus_id:
# This is the last bonus tile the agent hit
rew = -np.abs(self.penalty)
elif (agent.bonus_state + 1)%self.n_bonus == self.bonus_id:
# The agent hit the previous bonus tile before this one
agent.bonus_state = self.bonus_id
# rew = agent.bonus_value
rew = self.reward
else:
# The agent hit any other bonus tile before this one
rew = -np.abs(self.penalty)
if self.reset_on_mistake:
agent.bonus_state = self.bonus_id
if first_bonus and not bool(self.initial_reward):
return 0
else:
return rew
def render(self, img):
fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
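# Illustrative sketch (hypothetical helper, not part of the original marlgrid code):
# replays the BonusTile.get_reward rule above for a single agent stepping onto tiles
# in the given order of bonus_ids, assuming initial_reward=True and
# reset_on_mistake=False. Visiting tiles in cyclic order (k, k+1, ...) earns `reward`;
# revisiting the last tile or skipping ahead earns -|penalty|.
def _bonus_cycle_rewards(visits, n_bonus, reward=1.0, penalty=0.0):
    state = None  # mirrors agent.bonus_state before any bonus tile is hit
    rewards = []
    for tile in visits:
        if state is None:
            # First tile: pretend the previous tile in the cycle was just visited.
            state = (tile - 1) % n_bonus
        if state == tile:
            rewards.append(-abs(penalty))
        elif (state + 1) % n_bonus == tile:
            state = tile
            rewards.append(reward)
        else:
            rewards.append(-abs(penalty))
    return rewards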
class Goal(WorldObj):
def __init__(self, reward, *args, **kwargs):
super().__init__(*args, **kwargs)
self.reward = reward
def can_overlap(self):
return True
def get_reward(self, agent):
return self.reward
def str_render(self, dir=0):
return "GG"
def render(self, img):
fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])
class Floor(WorldObj):
def can_overlap(self):
return True# and self.agent is None
def str_render(self, dir=0):
return "FF"
def render(self, img):
# Give the floor a pale color
c = COLORS[self.color]
fill_coords(img, point_in_rect(0.031, 1, 0.031, 1), c / 2)
class EmptySpace(WorldObj):
def can_overlap(self):
return True
def str_render(self, dir=0):
return "  "
class Lava(WorldObj):
    def can_overlap(self):
        return True  # and self.agent is None

    def str_render(self, dir=0):
        return "VV"

    def render(self, img):
        c = (255, 128, 0)

        # Background color
        fill_coords(img, point_in_rect(0, 1, 0, 1), c)

        # Little waves
        for i in range(3):
            ylo = 0.3 + 0.2 * i
            yhi = 0.4 + 0.2 * i
            fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))
            fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))
            fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))
            fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))


class Wall(BulkObj):
    def see_behind(self):
        return False

    def str_render(self, dir=0):
        return "WW"

    def render(self, img):
        fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])


class Key(WorldObj):
    def can_pickup(self):
        return True

    def str_render(self, dir=0):
        return "KK"

    def render(self, img):
        c = COLORS[self.color]

        # Vertical quad
        fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c)

        # Teeth
        fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)
        fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c)

        # Ring
        fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)
        fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0, 0, 0))


class Ball(WorldObj):
    def can_pickup(self):
        return True

    def str_render(self, dir=0):
        return "AA"

    def render(self, img):
        fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])

class Door(WorldObj):
    states = IntEnum("door_state", "open closed locked")

    def can_overlap(self):
        return self.state == self.states.open  # and self.agent is None

    def see_behind(self):
        return self.state == self.states.open

    def toggle(self, agent, pos):
        if self.state == self.states.locked:
            # Unlock (but do not open) only if the agent carries a key of matching color
            if (
                agent.carrying is not None
                and isinstance(agent.carrying, Key)
                and agent.carrying.color == self.color
            ):
                self.state = self.states.closed
        elif self.state == self.states.closed:  # unlocked but closed
            self.state = self.states.open
        elif self.state == self.states.open:
            self.state = self.states.closed
        return True

    def render(self, img):
        c = COLORS[self.color]

        if self.state == self.states.open:
            fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)
            fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0, 0, 0))
            return

        # Door frame and door
        if self.state == self.states.locked:
            fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
            fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))

            # Draw key slot
            fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)
        else:
            fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)
            fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0, 0, 0))
            fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)
            fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0, 0, 0))

            # Draw door handle
            fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)

class Box(WorldObj):
    def __init__(self, color=0, state=0, contains=None):
        super().__init__(color, state)
        self.contains = contains

    def can_pickup(self):
        return True

    def toggle(self, agent, pos):
        # Same signature as Door.toggle so callers can toggle any object uniformly
        raise NotImplementedError

    def str_render(self, dir=0):
        return "BB"

    def render(self, img):
        c = COLORS[self.color]

        # Outline
        fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)
        fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0, 0, 0))

        # Horizontal slit
        fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)
================================================
FILE: marlgrid/rendering.py
================================================
import pyglet
from pyglet.gl import *
import sys

class SimpleImageViewer(object):
    def __init__(self, display=None, caption=None, maxwidth=500):
        self.window = None
        self.isopen = False
        self.display = display
        self.maxwidth = maxwidth
        self.caption = caption

    def imshow(self, arr):
        # Validate before unpacking arr.shape below
        assert len(arr.shape) == 3, "Expected an image array of shape (height, width, 3)"
        if self.window is None:
            height, width, _channels = arr.shape
            if width > self.maxwidth:
                scale = self.maxwidth / width
                width = int(scale * width)
                height = int(scale * height)
            self.window = pyglet.window.Window(width=width, height=height,
                display=self.display, vsync=False, resizable=True, caption=self.caption)
            self.width = width
            self.height = height
            self.isopen = True

            @self.window.event
            def on_resize(width, height):
                self.width = width
                self.height = height

            @self.window.event
            def on_close():
                self.isopen = False

        # Negative pitch: rows are stored top-to-bottom, as in the numpy array
        image = pyglet.image.ImageData(arr.shape[1], arr.shape[0],
            'RGB', arr.tobytes(), pitch=arr.shape[1]*-3)
        gl.glTexParameteri(gl.GL_TEXTURE_2D,
            gl.GL_TEXTURE_MAG_FILTER, gl.GL_NEAREST)
        texture = image.get_texture()
        # Scale the texture to fit the window while keeping the image's aspect ratio
        aspect_ratio = arr.shape[1]/arr.shape[0]
        forced_width = min(self.width, self.height * aspect_ratio)
        texture.height = int(forced_width / aspect_ratio)
        texture.width = int(forced_width)
        self.window.clear()
        self.window.switch_to()
        self.window.dispatch_events()
        texture.blit(0, 0)  # draw
        self.window.flip()

    def close(self):
        if self.isopen and sys.meta_path:
            # ^^^ check sys.meta_path to avoid 'ImportError: sys.meta_path is None, Python is likely shutting down'
            self.window.close()
            self.isopen = False

    def __del__(self):
        self.close()


class InteractivePlayerWindow(SimpleImageViewer):
    def __init__(self, display=None, caption=None, maxwidth=500):
        super().__init__(display=display, caption=caption, maxwidth=maxwidth)
        self.key = None
        self.action_count = 0

        # Number keys select an action index directly; arrows map to left/right/forward
        self.action_map = {
            pyglet.window.key._0: 0,
            pyglet.window.key._1: 1,
            pyglet.window.key._2: 2,
            pyglet.window.key._3: 3,
            pyglet.window.key._4: 4,
            pyglet.window.key._5: 5,
            pyglet.window.key._6: 6,
            pyglet.window.key.LEFT: 0,
            pyglet.window.key.RIGHT: 1,
            pyglet.window.key.UP: 2,
            # pyglet.window.key.Q: -1,
        }

    def get_action(self, obs):
        if self.window is None:
            self.imshow(obs)

            @self.window.event
            def on_key_press(symbol, modifiers):
                self.key = symbol

            return self.get_action(obs)

        self.imshow(obs)
        self.key = None
        # Block until the player presses a mapped key
        while self.key not in self.action_map:
            self.window.dispatch_events()
            pyglet.clock.tick()

        return self.action_map[self.key]
================================================
FILE: marlgrid/utils/__init__.py
================================================
================================================
FILE: marlgrid/utils/video.py
================================================
import gym
import numpy as np
import os
import tqdm


def export_video(X, outfile, fps=30, rescale_factor=2):
    try:
        import moviepy.editor as mpy
    except ImportError:
        raise ImportError(
            "GridRecorder requires moviepy library. Try installing:\n $ pip install moviepy"
        )

    if isinstance(X, list):
        X = np.stack(X)

    # Convert float frames in [0, 1] to uint8 in [0, 255]
    if np.issubdtype(X.dtype, np.floating) and X.max() <= 1:
        X = (X * 255).clip(0, 255).astype(np.uint8)

    if rescale_factor is not None and rescale_factor != 1:
        # Nearest-neighbour upscaling: each pixel becomes a rescale_factor x rescale_factor block
        X = np.kron(X, np.ones((1, rescale_factor, rescale_factor, 1), dtype=X.dtype))

    def make_frame(i):
        out = X[i]
        return out

    getframe = lambda t: make_frame(min(int(t * fps), len(X) - 1))
    clip = mpy.VideoClip(getframe, duration=len(X) / fps)

    outfile = os.path.abspath(os.path.expanduser(outfile))
    if not os.path.isdir(os.path.dirname(outfile)):
        os.makedirs(os.path.dirname(outfile))
    clip.write_videofile(outfile, fps=fps)


def render_frames(X, path, ext="png"):
    try:
        from PIL import Image
    except ImportError as e:
        raise ImportError(
            "Error importing from PIL in render_frames. Try installing PIL:\n $ pip install Pillow"
        ) from e

    # If the path has a file extension, dump frames into a new directory named path minus the extension
    if "." in os.path.basename(path):
        path = os.path.splitext(path)[0]
    if not os.path.isdir(path):
        os.makedirs(path)
    for k, frame in tqdm.tqdm(enumerate(X), total=len(X)):
        Image.fromarray(frame, "RGB").save(os.path.join(path, f"frame_{k}.{ext}"))


class GridRecorder(gym.core.Wrapper):
    default_max_len = 1000
    default_video_kwargs = {
        'fps': 20,
        'rescale_factor': 1,
    }

    def __init__(
        self,
        env,
        save_root,
        max_steps=1000,
        auto_save_images=True,
        auto_save_videos=True,
        auto_save_interval=None,
        render_kwargs={},
        video_kwargs={}
    ):
        super().__init__(env)

        self.frames = None
        self.ptr = 0
        self.reset_count = 0
        self.last_save = -10000
        self.recording = False
        self.save_root = self.fix_path(save_root)
        self.auto_save_videos = auto_save_videos
        self.auto_save_images = auto_save_images
        self.auto_save_interval = auto_save_interval
        self.render_kwargs = render_kwargs
        self.video_kwargs = {**self.default_video_kwargs, **video_kwargs}
        self.n_parallel = getattr(env, 'num_envs', 1)

        # One extra buffer slot for the final frame appended in reset()
        if max_steps is None:
            if hasattr(env, "max_steps") and env.max_steps != 0:
                self.max_steps = env.max_steps + 1
            else:
                self.max_steps = self.default_max_len + 1
        else:
            self.max_steps = max_steps + 1

    @staticmethod
    def fix_path(path):
        return os.path.abspath(os.path.expanduser(path))

    @property
    def should_record(self):
        if self.recording:
            return True
        if self.auto_save_interval is None:
            return False
        return (self.reset_count - self.last_save) >= self.auto_save_interval

    def export_frames(self, episode_id=None, save_root=None):
        if save_root is None:
            save_root = self.save_root
        if episode_id is None:
            episode_id = f'frames_{self.reset_count}'
        render_frames(self.frames[:self.ptr], os.path.join(self.fix_path(save_root), episode_id))

    def export_video(self, episode_id=None, save_root=None):
        if save_root is None:
            save_root = self.save_root
        if episode_id is None:
            episode_id = f'video_{self.reset_count}.mp4'
        export_video(self.frames[:self.ptr], os.path.join(self.fix_path(save_root), episode_id), **self.video_kwargs)

    def export_both(self, episode_id, save_root=None):
        self.export_frames(f'{episode_id}_frames', save_root=save_root)
        self.export_video(f'{episode_id}.mp4', save_root=save_root)

    def reset(self, **kwargs):
        if self.should_record and self.ptr > 0:
            self.append_current_frame()
            if self.auto_save_images:
                self.export_frames()
            if self.auto_save_videos:
                self.export_video()
            self.last_save = self.reset_count
        del self.frames
        self.frames = None
        self.ptr = 0

        self.reset_count += self.n_parallel
        return self.env.reset(**kwargs)

    def append_current_frame(self):
        if self.should_record:
            new_frame = self.env.render(mode="rgb_array", **self.render_kwargs)
            # Vectorized envs return one frame per sub-env; keep only the first
            if isinstance(new_frame, list) or len(new_frame.shape) > 3:
                new_frame = new_frame[0]
            if self.frames is None:
                self.frames = np.zeros(
                    (self.max_steps, *new_frame.shape), dtype=new_frame.dtype
                )
            self.frames[self.ptr] = new_frame
            self.ptr += 1

    def step(self, action):
        self.append_current_frame()
        obs, rew, done, info = self.env.step(action)
        return obs, rew, done, info

# def export_video(
# self,
# output_path,
# fps=20,
# rescale_factor=1,
# render_last=True,
# render_frame_images=True,
# **kwargs,
# ):
# if self.should_record:
# if render_last:
# self.frames[self.ptr] = self.env.render(
# mode="rgb_array", **self.render_kwargs
# )
# if render_frame_images:
# render_frames(self.frames[: self.ptr + 1], output_path)
# return export_video(
# self.frames[: self.ptr + 1],
# output_path,
# fps=fps,
# rescale_factor=rescale_factor,
# **kwargs,
# )
================================================
FILE: setup.py
================================================
from setuptools import setup, find_packages

setup(
    name="marlgrid",
    version="0.0.5",
    packages=find_packages(),
    install_requires=["numpy", "tqdm", "gym", "gym-minigrid", "numba"],
)
SYMBOL INDEX (187 symbols across 12 files)
FILE: examples/human_player.py
class HumanPlayer (line 8) | class HumanPlayer:
method __init__ (line 9) | def __init__(self):
method action_step (line 15) | def action_step(self, obs):
method save_step (line 18) | def save_step(self, obs, act, rew, done):
method start_episode (line 23) | def start_episode(self):
method end_episode (line 27) | def end_episode(self):
FILE: marlgrid/agents.py
class GridAgentInterface (line 9) | class GridAgentInterface(GridAgent):
class actions (line 10) | class actions(IntEnum):
method __init__ (line 19) | def __init__(
method render_post (line 92) | def render_post(self, tile):
method clone (line 121) | def clone(self):
method on_step (line 141) | def on_step(self, obj):
method reward (line 146) | def reward(self, rew):
method activate (line 155) | def activate(self):
method deactivate (line 158) | def deactivate(self):
method reset (line 161) | def reset(self, new_episode=False):
method render (line 172) | def render(self, img):
method dir_vec (line 177) | def dir_vec(self):
method right_vec (line 186) | def right_vec(self):
method front_pos (line 194) | def front_pos(self):
method get_view_coords (line 200) | def get_view_coords(self, i, j):
method get_view_pos (line 233) | def get_view_pos(self):
method get_view_exts (line 237) | def get_view_exts(self):
method relative_coords (line 268) | def relative_coords(self, x, y):
method in_view (line 280) | def in_view(self, x, y):
method sees (line 287) | def sees(self, x, y):
method process_vis (line 290) | def process_vis(self, opacity_grid):
function occlude_mask (line 299) | def occlude_mask(grid, agent_pos):
FILE: marlgrid/base.py
class ObjectRegistry (line 19) | class ObjectRegistry:
method __init__ (line 25) | def __init__(self, objs=[], max_num_objects=1000):
method get_next_key (line 32) | def get_next_key(self):
method __len__ (line 40) | def __len__(self):
method add_object (line 43) | def add_object(self, obj):
method contains_object (line 49) | def contains_object(self, obj):
method contains_key (line 52) | def contains_key(self, key):
method get_key (line 55) | def get_key(self, obj):
method obj_of_key (line 63) | def obj_of_key(self, key):
function rotate_grid (line 67) | def rotate_grid(grid, rot_k):
class MultiGrid (line 83) | class MultiGrid:
method __init__ (line 87) | def __init__(self, shape, obj_reg=None, orientation=0):
method opacity (line 104) | def opacity(self):
method __getitem__ (line 108) | def __getitem__(self, *args, **kwargs):
method rotate_left (line 115) | def rotate_left(self, k=1):
method slice (line 123) | def slice(self, topX, topY, width, height, rot_k=0):
method set (line 149) | def set(self, i, j, obj):
method get (line 154) | def get(self, i, j):
method horz_wall (line 160) | def horz_wall(self, x, y, length=None, obj_type=Wall):
method vert_wall (line 166) | def vert_wall(self, x, y, length=None, obj_type=Wall):
method wall_rect (line 172) | def wall_rect(self, x, y, w, h, obj_type=Wall):
method __str__ (line 178) | def __str__(self):
method encode (line 196) | def encode(self, vis_mask=None):
method decode (line 217) | def decode(cls, array):
method cache_render_fun (line 226) | def cache_render_fun(cls, key, f, *args, **kwargs):
method cache_render_obj (line 232) | def cache_render_obj(cls, obj, tile_size, subdivs):
method empty_tile (line 246) | def empty_tile(cls, tile_size, subdivs):
method render_object (line 253) | def render_object(cls, obj, tile_size, subdivs):
method blend_tiles (line 261) | def blend_tiles(cls, img1, img2):
method render_tile (line 276) | def render_tile(cls, obj, tile_size=TILE_PIXELS, subdivs=3, top_agent=...
method render (line 301) | def render(self, tile_size, highlight_mask=None, visible_mask=None, to...
class MultiGridEnv (line 334) | class MultiGridEnv(gym.Env):
method __init__ (line 335) | def __init__(
method seed (line 371) | def seed(self, seed=1337):
method action_space (line 377) | def action_space(self):
method observation_space (line 383) | def observation_space(self):
method num_agents (line 389) | def num_agents(self):
method add_agent (line 392) | def add_agent(self, agent_interface):
method reset (line 402) | def reset(self, **kwargs):
method gen_obs_grid (line 418) | def gen_obs_grid(self, agent):
method gen_agent_obs (line 453) | def gen_agent_obs(self, agent):
method gen_obs (line 473) | def gen_obs(self):
method __str__ (line 476) | def __str__(self):
method check_agent_position_integrity (line 479) | def check_agent_position_integrity(self, title=''):
method step (line 501) | def step(self, actions):
method put_obj (line 655) | def put_obj(self, obj, i, j):
method try_place_obj (line 664) | def try_place_obj(self,obj, pos):
method place_obj (line 690) | def place_obj(self, obj, top=(0,0), size=None, reject_fn=None, max_tri...
method place_agents (line 710) | def place_agents(self, top=None, size=None, rand_dir=True, max_tries=1...
method render (line 714) | def render(
FILE: marlgrid/envs/__init__.py
function register_marl_env (line 20) | def register_marl_env(
function env_from_config (line 58) | def env_from_config(env_config, randomize_seed=True):
FILE: marlgrid/envs/cluttered.py
class ClutteredMultiGrid (line 5) | class ClutteredMultiGrid(MultiGridEnv):
method __init__ (line 9) | def __init__(self, *args, n_clutter=None, clutter_density=None, random...
method _gen_grid (line 25) | def _gen_grid(self, width, height):
FILE: marlgrid/envs/doorkey.py
class DoorKeyEnv (line 5) | class DoorKeyEnv(MultiGridEnv):
method _gen_grid (line 15) | def _gen_grid(self, width, height):
FILE: marlgrid/envs/empty.py
class EmptyMultiGrid (line 5) | class EmptyMultiGrid(MultiGridEnv):
method _gen_grid (line 9) | def _gen_grid(self, width, height):
FILE: marlgrid/envs/goalcycle.py
class ClutteredGoalCycleEnv (line 5) | class ClutteredGoalCycleEnv(MultiGridEnv):
method __init__ (line 9) | def __init__(self, *args, reward=1, penalty=0.0, n_clutter=None, clutt...
method _gen_grid (line 30) | def _gen_grid(self, width, height):
FILE: marlgrid/envs/viz_test.py
class VisibilityTestEnv (line 5) | class VisibilityTestEnv(MultiGridEnv):
method _gen_grid (line 9) | def _gen_grid(self, width, height):
FILE: marlgrid/objects.py
class RegisteredObjectType (line 33) | class RegisteredObjectType(type):
method __new__ (line 34) | def __new__(meta, name, bases, class_dict):
class WorldObj (line 46) | class WorldObj(metaclass=RegisteredObjectType):
method __init__ (line 47) | def __init__(self, color="worst", state=0):
method dir (line 59) | def dir(self):
method set_position (line 62) | def set_position(self, pos):
method numeric_color (line 68) | def numeric_color(self):
method type (line 72) | def type(self):
method can_overlap (line 75) | def can_overlap(self):
method can_pickup (line 78) | def can_pickup(self):
method can_contain (line 81) | def can_contain(self):
method see_behind (line 84) | def see_behind(self):
method toggle (line 87) | def toggle(self, env, pos):
method encode (line 90) | def encode(self, str_class=False):
method describe (line 101) | def describe(self):
method decode (line 105) | def decode(cls, type, color, state):
method render (line 117) | def render(self, img):
method str_render (line 120) | def str_render(self, dir=0):
class GridAgent (line 124) | class GridAgent(WorldObj):
method __init__ (line 125) | def __init__(self, *args, color='red', **kwargs):
method dir (line 133) | def dir(self):
method type (line 137) | def type(self):
method dir (line 141) | def dir(self, dir):
method str_render (line 144) | def str_render(self, dir=0):
method can_overlap (line 147) | def can_overlap(self):
method render (line 150) | def render(self, img):
class BulkObj (line 156) | class BulkObj(WorldObj, metaclass=RegisteredObjectType):
method __hash__ (line 158) | def __hash__(self):
method __eq__ (line 161) | def __eq__(self, other):
class BonusTile (line 164) | class BonusTile(WorldObj):
method __init__ (line 165) | def __init__(self, reward, penalty=-0.1, bonus_id=0, n_bonus=1, initia...
method can_overlap (line 174) | def can_overlap(self):
method str_render (line 177) | def str_render(self, dir=0):
method get_reward (line 180) | def get_reward(self, agent):
method render (line 208) | def render(self, img):
class Goal (line 211) | class Goal(WorldObj):
method __init__ (line 212) | def __init__(self, reward, *args, **kwargs):
method can_overlap (line 216) | def can_overlap(self):
method get_reward (line 219) | def get_reward(self, agent):
method str_render (line 222) | def str_render(self, dir=0):
method render (line 225) | def render(self, img):
class Floor (line 229) | class Floor(WorldObj):
method can_overlap (line 230) | def can_overlap(self):
method str_render (line 233) | def str_render(self, dir=0):
method render (line 236) | def render(self, img):
class EmptySpace (line 249) | class EmptySpace(WorldObj):
method can_verlap (line 250) | def can_verlap(self):
method str_render (line 253) | def str_render(self, dir=0):
class Lava (line 257) | class Lava(WorldObj):
method can_overlap (line 258) | def can_overlap(self):
method str_render (line 261) | def str_render(self, dir=0):
method render (line 264) | def render(self, img):
class Wall (line 280) | class Wall(BulkObj):
method see_behind (line 281) | def see_behind(self):
method str_render (line 284) | def str_render(self, dir=0):
method render (line 287) | def render(self, img):
class Key (line 291) | class Key(WorldObj):
method can_pickup (line 292) | def can_pickup(self):
method str_render (line 295) | def str_render(self, dir=0):
method render (line 298) | def render(self, img):
class Ball (line 313) | class Ball(WorldObj):
method can_pickup (line 314) | def can_pickup(self):
method str_render (line 317) | def str_render(self, dir=0):
method render (line 320) | def render(self, img):
class Door (line 324) | class Door(WorldObj):
method can_overlap (line 327) | def can_overlap(self):
method see_behind (line 330) | def see_behind(self):
method toggle (line 333) | def toggle(self, agent, pos):
method render (line 348) | def render(self, img):
class Box (line 373) | class Box(WorldObj):
method __init__ (line 374) | def __init__(self, color=0, state=0, contains=None):
method can_pickup (line 378) | def can_pickup(self):
method toggle (line 381) | def toggle(self):
method str_render (line 384) | def str_render(self, dir=0):
method render (line 387) | def render(self, img):
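Door.toggle above is a small three-state machine. The sketch below reproduces its transitions without importing marlgrid (which needs gym and gym-minigrid). StubKey, StubAgent and toggle_door are hypothetical stand-ins that expose only what Door.toggle reads; DoorState uses the same member order as Door.states, so its values (open=1, closed=2, locked=3) match the functional IntEnum.

```python
from enum import IntEnum


class DoorState(IntEnum):
    # Same order as IntEnum("door_state", "open closed locked"), which starts at 1
    open = 1
    closed = 2
    locked = 3


class StubKey:
    """Hypothetical stand-in for marlgrid.objects.Key; only .color is used."""

    def __init__(self, color):
        self.color = color


class StubAgent:
    """Hypothetical stand-in for an agent; toggle only reads .carrying."""

    def __init__(self, carrying=None):
        self.carrying = carrying


def toggle_door(state, door_color, agent):
    """Return the next door state, following the branches of Door.toggle."""
    if state == DoorState.locked:
        key = agent.carrying
        if isinstance(key, StubKey) and key.color == door_color:
            return DoorState.closed  # unlocking leaves the door closed, not open
        return state  # wrong key or empty hands: stays locked
    if state == DoorState.closed:
        return DoorState.open
    return DoorState.closed  # open -> closed


s = DoorState.locked
s = toggle_door(s, "red", StubAgent(StubKey("blue")))  # still locked
s = toggle_door(s, "red", StubAgent(StubKey("red")))   # now closed
s = toggle_door(s, "red", StubAgent())                 # now open
```

An agent therefore needs two toggles to pass a locked door: the first, with the matching key, unlocks it and the second opens it. Opening an unlocked door needs no key.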
FILE: marlgrid/rendering.py
class SimpleImageViewer (line 5) | class SimpleImageViewer(object):
method __init__ (line 6) | def __init__(self, display=None, caption=None, maxwidth=500):
method imshow (line 13) | def imshow(self, arr):
method close (line 54) | def close(self):
method __del__ (line 60) | def __del__(self):
class InteractivePlayerWindow (line 64) | class InteractivePlayerWindow(SimpleImageViewer):
method __init__ (line 65) | def __init__(self, display=None, caption=None, maxwidth=500):
method get_action (line 84) | def get_action(self, obs):
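SimpleImageViewer.imshow sizes its window once, shrinking images wider than maxwidth, and on every call scales the texture to the largest size that fits the current window without distorting the image. A minimal sketch of that arithmetic as two pure functions; initial_window_size and fit_texture are hypothetical names, not part of marlgrid.rendering.

```python
def initial_window_size(img_width, img_height, maxwidth=500):
    """Window size picked on the first imshow() call: wide images shrink to maxwidth."""
    width, height = img_width, img_height
    if width > maxwidth:
        scale = maxwidth / width
        width = int(scale * width)
        height = int(scale * height)
    return width, height


def fit_texture(win_width, win_height, img_width, img_height):
    """Largest (width, height) that fits the window and keeps the image aspect ratio."""
    aspect_ratio = img_width / img_height
    # Either the window width or the window height is the limiting side
    forced_width = min(win_width, win_height * aspect_ratio)
    return int(forced_width), int(forced_width / aspect_ratio)


# A 4:1 image in a 1000x200 window is height-limited: fit_texture(1000, 200, 400, 100) -> (800, 200)
print(fit_texture(1000, 200, 400, 100))
```

Because the texture is recomputed from self.width and self.height on every call, resizing the window (tracked by on_resize) rescales the image on the next frame.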
FILE: marlgrid/utils/video.py
function export_video (line 7) | def export_video(X, outfile, fps=30, rescale_factor=2):
function render_frames (line 38) | def render_frames(X, path, ext="png"):
class GridRecorder (line 55) | class GridRecorder(gym.core.Wrapper):
method __init__ (line 61) | def __init__(
method fix_path (line 96) | def fix_path(path):
method should_record (line 100) | def should_record(self):
method export_frames (line 107) | def export_frames(self, episode_id=None, save_root=None):
method export_video (line 114) | def export_video(self, episode_id=None, save_root=None):
method export_both (line 121) | def export_both(self, episode_id, save_root=None):
method reset (line 125) | def reset(self, **kwargs):
method append_current_frame (line 139) | def append_current_frame(self):
method step (line 151) | def step(self, action):
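export_video prepares frames before handing them to moviepy: a list is stacked into a (T, H, W, C) array, float frames in [0, 1] are scaled to uint8, and np.kron against a block of ones performs nearest-neighbour upscaling. A standalone sketch of those two steps that needs only NumPy; to_uint8_frames and upscale_frames are hypothetical helper names.

```python
import numpy as np


def to_uint8_frames(X):
    """Stack HxWxC frames and map float values in [0, 1] to uint8 in [0, 255]."""
    X = np.stack(X) if isinstance(X, list) else np.asarray(X)
    if np.issubdtype(X.dtype, np.floating) and X.max() <= 1:
        X = (X * 255).clip(0, 255).astype(np.uint8)
    return X


def upscale_frames(X, factor):
    """Nearest-neighbour upscaling of a (T, H, W, C) stack by an integer factor."""
    if factor is None or factor == 1:
        return X
    # kron with a (1, f, f, 1) block of ones turns each pixel into an f x f block;
    # using X's dtype for the ones keeps uint8 frames as uint8.
    return np.kron(X, np.ones((1, factor, factor, 1), dtype=X.dtype))


frames = to_uint8_frames([np.zeros((4, 4, 3)), np.ones((4, 4, 3))])
big = upscale_frames(frames, 2)
print(big.shape, big.dtype)  # (2, 8, 8, 3) uint8
```

The kron result is the same as `X.repeat(f, axis=1).repeat(f, axis=2)`; either form works, and matching the dtype avoids handing float64 frames to the video writer.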
About this extraction
This page contains the full source code of the kandouss/marlgrid GitHub repository, extracted and formatted as plain text. The extraction includes 19 files (87.4 KB), approximately 21.7k tokens, and a symbol index with 187 extracted functions, classes, methods, constants, and types.