[
  {
    "path": ".gitignore",
    "content": "# Weights and biases temp dir\nwandb/\n\n# Video files\n*.mp4\n\n# Editor files\n.vscode/\n\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n*.py,cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow\n__pypackages__/\n\n# Celery stuff\ncelerybeat-schedule\ncelerybeat.pid\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License.\n"
  },
  {
    "path": "README.md",
    "content": "# MarlGrid\nGridworld for MARL experiments, based on [MiniGrid](https://github.com/maximecb/gym-minigrid).\n\n[![Three agents navigating a cluttered MarlGrid environment.](https://img.youtube.com/vi/e0xL6KB6RBA/0.jpg)](https://youtube.com/watch?v=e0xL6KB6RBA)\n<video src=\"https://kam.al/images/extra/cluttered_multigrid_example.mp4\" id=\"spinning-video\" controls preload loop style=\"width:400px; max-width:100%; display:block; margin-left:auto; margin-right:auto; margin-bottom:20px;\"></video>\n\n## Training multiple independent learners\n\n### Pre-built environment\n\nMarlGrid comes with a few pre-built environments (see marlgrid/envs):\n- `MarlGrid-3AgentCluttered11x11-v0`\n- `MarlGrid-3AgentCluttered15x15-v0`\n- `MarlGrid-2AgentEmpty9x9-v0`\n- `MarlGrid-3AgentEmpty9x9-v0`\n- `MarlGrid-4AgentEmpty9x9-v0`\n\n(as of v0.0.2)\n\n### Custom environment\n\nCreate an RL agent (e.g. `TestRLAgent` subclassing `marlgrid.agents.LearningAgent`) that implements:\n - `action_step(self, obs)`,\n - `save_step(self, *transition_values)`,\n - `start_episode(self)` (optional),\n - `end_episode(self)` (optional).\n \nThen multiple such agents can be trained in a MarlGrid environment like `ClutteredMultiGrid`:\n\n```\nagents = marlgrid.agents.IndependentLearners(\n    TestRLAgent(),\n    TestRLAgent(),\n    TestRLAgent()\n)\n\nenv = ClutteredMultiGrid(agents, grid_size=15, n_clutter=10)\n\n\nfor i_episode in range(N_episodes):\n\n    obs_array = env.reset()\n\n    with agents.episode():\n\n        episode_over = False\n\n        while not episode_over:\n            # env.render()\n\n            # Get an array with actions for each agent.\n            action_array = agents.action_step(obs_array)\n\n            # Step the multi-agent environment\n            next_obs_array, reward_array, done, _ = env.step(action_array)\n\n            # Save the transition data to replay buffers, if necessary\n            agents.save_step(obs_array, action_array, next_obs_array, 
reward_array, done)\n\n            obs_array = next_obs_array\n\n            episode_over = done\n            # or if \"done\" is per-agent:\n            episode_over = all(done) # or any(done)\n            \n```"
  },
  {
    "path": "examples/human_player.py",
    "content": "import numpy as np\nimport marlgrid\n\nfrom marlgrid.rendering import InteractivePlayerWindow\nfrom marlgrid.agents import GridAgentInterface\nfrom marlgrid.envs import env_from_config\n\nclass HumanPlayer:\n    def __init__(self):\n        self.player_window = InteractivePlayerWindow(\n            caption=\"interactive marlgrid\"\n        )\n        self.episode_count = 0\n\n    def action_step(self, obs):\n        return self.player_window.get_action(obs.astype(np.uint8))\n\n    def save_step(self, obs, act, rew, done):\n        print(f\"   step {self.step_count:<4d}: reward {rew} (episode total {self.cumulative_reward})\")\n        self.cumulative_reward += rew\n        self.step_count += 1\n\n    def start_episode(self):\n        self.cumulative_reward = 0\n        self.step_count = 0\n    \n    def end_episode(self):\n        print(\n            f\"Finished episode {self.episode_count} after {self.step_count} steps.\"\n            f\"  Episode return was {self.cumulative_reward}.\"\n        )\n        self.episode_count += 1\n\n\nenv_config =  {\n    \"env_class\": \"ClutteredGoalCycleEnv\",\n    \"grid_size\": 13,\n    \"max_steps\": 250,\n    \"clutter_density\": 0.15,\n    \"respawn\": True,\n    \"ghost_mode\": True,\n    \"reward_decay\": False,\n    \"n_bonus_tiles\": 3,\n    \"initial_reward\": True,\n    \"penalty\": -1.5\n}\n\nplayer_interface_config = {\n    \"view_size\": 7,\n    \"view_offset\": 1,\n    \"view_tile_size\": 11,\n    \"observation_style\": \"rich\",\n    \"see_through_walls\": False,\n    \"color\": \"prestige\"\n}\n\n# Add the player/agent config to the environment config (as expected by \"env_from_config\" below)\nenv_config['agents'] = [player_interface_config]\n\n# Create the environment based on the combined env/player config\nenv = env_from_config(env_config)\n\n# Create a human player interface per the class defined above\nhuman = HumanPlayer()\n\n# Start an episode!\n# Each observation from the environment 
contains a list of observations for each agent.\n# In this case there's only one agent so the list will be of length one.\nobs_list = env.reset()\n\nhuman.start_episode()\ndone = False\nwhile not done:\n\n    env.render() # OPTIONAL: render the whole scene + bird's-eye view\n    \n    player_action = human.action_step(obs_list[0]['pov'])\n    # The environment expects a list of actions, so put the player action into a list\n    agent_actions = [player_action]\n\n    next_obs_list, rew_list, done, _ = env.step(agent_actions)\n    \n    human.save_step(\n        obs_list[0], player_action, rew_list[0], done\n    )\n\n    obs_list = next_obs_list\n\nhuman.end_episode()\n"
  },
  {
    "path": "examples/video_test.py",
    "content": "from marlgrid.utils.video import GridRecorder\nimport gym_minigrid\n\nenv = gym_minigrid.envs.empty.EmptyEnv(size=10)\nenv.max_steps = 200\n\nenv = GridRecorder(env, render_kwargs={\"tile_size\": 11})\n\nobs = env.reset()\nenv.recording = True\n\ncount = 0\ndone = False\n\nwhile not done:\n    act = env.action_space.sample()\n    obs, rew, done, _ = env.step(act)\n    count += 1\n\nenv.export_video(\"test_minigrid.mp4\")\n"
  },
  {
    "path": "marlgrid/__init__.py",
    "content": ""
  },
  {
    "path": "marlgrid/agents.py",
    "content": "import gym\nimport numpy as np\nfrom enum import IntEnum\nimport warnings\nimport numba\n\nfrom .objects import GridAgent, BonusTile\n\nclass GridAgentInterface(GridAgent):\n    class actions(IntEnum):\n        left = 0  # Rotate left\n        right = 1  # Rotate right\n        forward = 2  # Move forward\n        pickup = 3  # Pick up an object\n        drop = 4  # Drop an object\n        toggle = 5  # Toggle/activate an object\n        done = 6  # Done completing task\n\n    def __init__(\n            self,\n            view_size=7,\n            view_tile_size=5,\n            view_offset=0,\n            observation_style='image',\n            observe_rewards=False,\n            observe_position=False,\n            observe_orientation=False,\n            restrict_actions=False,\n            see_through_walls=False,\n            hide_item_types=[],\n            prestige_beta=0.95,\n            prestige_scale=2,\n            allow_negative_prestige=False,\n            spawn_delay=0,\n            **kwargs):\n        super().__init__(**kwargs)\n\n        self.view_size = view_size\n        self.view_tile_size = view_tile_size\n        self.view_offset = view_offset\n        self.observation_style = observation_style\n        self.observe_rewards = observe_rewards\n        self.observe_position = observe_position\n        self.observe_orientation = observe_orientation\n        self.hide_item_types = hide_item_types\n        self.see_through_walls = see_through_walls\n        self.init_kwargs = kwargs\n        self.restrict_actions = restrict_actions\n        self.prestige_beta = prestige_beta\n        self.prestige_scale = prestige_scale\n        self.allow_negative_prestige = allow_negative_prestige\n        self.spawn_delay = spawn_delay\n\n        if self.prestige_beta > 1:\n            # warnings.warn(\"prestige_beta must be between 0 and 1. 
Using default 0.99\")\n            self.prestige_beta = 0.95\n            \n        image_space = gym.spaces.Box(\n            low=0,\n            high=255,\n            shape=(view_tile_size * view_size, view_tile_size * view_size, 3),\n            dtype=\"uint8\",\n        )\n        if observation_style == 'image':\n            self.observation_space = image_space\n        elif observation_style == 'rich':\n            obs_space = {\n                'pov': image_space,\n            }\n            if self.observe_rewards:\n                obs_space['reward'] = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(), dtype=np.float32)\n            if self.observe_position:\n                obs_space['position'] = gym.spaces.Box(low=0, high=1, shape=(2,), dtype=np.float32)\n            if self.observe_orientation:\n                obs_space['orientation'] = gym.spaces.Discrete(n=4)\n            self.observation_space = gym.spaces.Dict(obs_space)\n        else:\n            raise ValueError(f\"{self.__class__.__name__} kwarg 'observation_style' must be one of 'image', 'rich'.\")\n\n        if self.restrict_actions:\n            self.action_space = gym.spaces.Discrete(3)\n        else:\n            self.action_space = gym.spaces.Discrete(len(self.actions))\n\n        self.metadata = {\n            **self.metadata,\n            'view_size': view_size,\n            'view_tile_size': view_tile_size,\n        }\n        self.reset(new_episode=True)\n\n    def render_post(self, tile):\n        if not self.active:\n            return tile\n\n        blue = np.array([0,0,255])\n        red = np.array([255,0,0])\n\n        if self.color == 'prestige':\n            # Compute a scaled prestige value between 0 and 1 that will be used to \n            #   interpolate between the low-prestige (red) and high-prestige (blue)\n            #   colors.\n            if self.allow_negative_prestige:\n                prestige_scaled = 1/(1 + np.exp(-self.prestige/self.prestige_scale))\n        
    else:\n                prestige_scaled = np.tanh(self.prestige/self.prestige_scale)\n\n            new_color = (\n                    prestige_scaled * blue +\n                    (1.-prestige_scaled) * red\n                ).astype(np.int)\n\n            grey_pixels = (np.diff(tile, axis=-1)==0).all(axis=-1)\n\n            alpha = tile[...,0].astype(np.uint16)[...,None]\n            tile = np.right_shift(alpha * new_color, 8).astype(np.uint8)\n            return tile\n        else:\n            return tile\n\n    def clone(self):\n        ret =  self.__class__(\n            view_size = self.view_size,\n            view_offset=self.view_offset,\n            view_tile_size = self.view_tile_size,\n            observation_style = self.observation_style,\n            observe_rewards = self.observe_rewards,\n            observe_position = self.observe_position,\n            observe_orientation = self.observe_orientation,\n            hide_item_types = self.hide_item_types,\n            restrict_actions = self.restrict_actions,\n            see_through_walls=self.see_through_walls,\n            prestige_beta = self.prestige_beta,\n            prestige_scale = self.prestige_scale,\n            allow_negative_prestige = self.allow_negative_prestige,\n            spawn_delay = self.spawn_delay,\n            **self.init_kwargs\n        )\n        return ret\n\n    def on_step(self, obj):\n        if isinstance(obj, BonusTile):\n            self.bonuses.append((obj.bonus_id, self.prestige))\n        self.prestige *= self.prestige_beta\n\n    def reward(self, rew):\n        if self.allow_negative_prestige:\n            self.rew += rew\n        else:\n            if rew >= 0:\n                self.prestige += rew\n            else: # rew < 0\n                self.prestige = 0\n\n    def activate(self):\n        self.active = True\n\n    def deactivate(self):\n        self.active = False\n\n    def reset(self, new_episode=False):\n        self.done = False\n        
self.active = False\n        self.pos = None\n        self.carrying = None\n        self.mission = \"\"\n        if new_episode:\n            self.prestige = 0\n            self.bonus_state = None\n            self.bonuses = []\n\n    def render(self, img):\n        if self.active:\n            super().render(img)\n\n    @property\n    def dir_vec(self):\n        \"\"\"\n        Get the direction vector for the agent, pointing in the direction\n        of forward movement.\n        \"\"\"\n        assert self.dir >= 0 and self.dir < 4\n        return np.array([[1, 0], [0, 1], [-1, 0], [0, -1]])[self.dir]\n\n    @property\n    def right_vec(self):\n        \"\"\"\n        Get the vector pointing to the right of the agent.\n        \"\"\"\n        dx, dy = self.dir_vec\n        return np.array((-dy, dx))\n\n    @property\n    def front_pos(self):\n        \"\"\"\n        Get the position of the cell that is right in front of the agent\n        \"\"\"\n        return np.add(self.pos, self.dir_vec)\n\n    def get_view_coords(self, i, j):\n        \"\"\"\n        Translate and rotate absolute grid coordinates (i, j) into the\n        agent's partially observable view (sub-grid). 
Note that the resulting\n        coordinates may be negative or outside of the agent's view size.\n        \"\"\"\n\n        ax, ay = self.pos\n        dx, dy = self.dir_vec\n        rx, ry = self.right_vec\n\n        \n        ax -= 2*self.view_offset*dx\n        ay -= 2*self.view_offset*dy\n\n\n        # Compute the absolute coordinates of the top-left view corner\n        sz = self.view_size\n        hs = self.view_size // 2\n        tx = ax + (dx * (sz - 1)) - (rx * hs)\n        ty = ay + (dy * (sz - 1)) - (ry * hs)\n\n        lx = i - tx\n        ly = j - ty\n\n        # Project the coordinates of the object relative to the top-left\n        # corner onto the agent's own coordinate system\n        vx = rx * lx + ry * ly\n        vy = -(dx * lx + dy * ly)\n\n        return vx, vy\n\n        \n    def get_view_pos(self):\n        return (self.view_size // 2, self.view_size - 1 - self.view_offset)\n\n\n    def get_view_exts(self):\n        \"\"\"\n        Get the extents of the square set of tiles visible to the agent\n        Note: the bottom extent indices are not included in the set\n        \"\"\"\n\n        dir = self.dir\n        # Facing right\n        if dir == 0:  # 1\n            topX = self.pos[0] - self.view_offset\n            topY = self.pos[1] - self.view_size // 2\n        # Facing down\n        elif dir == 1:  # 0\n            topX = self.pos[0] - self.view_size // 2\n            topY = self.pos[1] - self.view_offset\n        # Facing left\n        elif dir == 2:  # 3\n            topX = self.pos[0] - self.view_size + 1 + self.view_offset\n            topY = self.pos[1] - self.view_size // 2\n        # Facing up\n        elif dir == 3:  # 2\n            topX = self.pos[0] - self.view_size // 2\n            topY = self.pos[1] - self.view_size + 1 + self.view_offset\n        else:\n            assert False, \"invalid agent direction\"\n\n        botX = topX + self.view_size\n        botY = topY + self.view_size\n\n        return (topX, topY, botX, 
botY)\n\n    def relative_coords(self, x, y):\n        \"\"\"\n        Check if a grid position belongs to the agent's field of view, and returns the corresponding coordinates\n        \"\"\"\n\n        vx, vy = self.get_view_coords(x, y)\n\n        if vx < 0 or vy < 0 or vx >= self.view_size or vy >= self.view_size:\n            return None\n\n        return vx, vy\n\n    def in_view(self, x, y):\n        \"\"\"\n        check if a grid position is visible to the agent\n        \"\"\"\n\n        return self.relative_coords(x, y) is not None\n\n    def sees(self, x, y):\n        raise NotImplementedError\n\n    def process_vis(self, opacity_grid):\n        assert len(opacity_grid.shape) == 2\n        if not self.see_through_walls:\n            return occlude_mask(~opacity_grid, self.get_view_pos())\n        else:\n            return np.full(opacity_grid.shape, 1, dtype=np.bool)\n    \n\n@numba.njit\ndef occlude_mask(grid, agent_pos):\n    mask = np.zeros(grid.shape[:2]).astype(numba.boolean)\n    mask[agent_pos[0], agent_pos[1]] = True\n    width, height = grid.shape[:2]\n\n    for j in range(agent_pos[1]+1,0,-1):\n        for i in range(agent_pos[0], width):\n            if mask[i,j] and grid[i,j]:\n                if i < width - 1:\n                    mask[i + 1, j] = True\n                if j > 0:\n                    mask[i, j - 1] = True\n                    if i < width - 1:\n                        mask[i + 1, j - 1] = True\n\n        for i in range(agent_pos[0]+1,0,-1):\n            if mask[i,j] and grid[i,j]:    \n                if i > 0:\n                    mask[i - 1, j] = True\n                if j > 0:\n                    mask[i, j - 1] = True\n                    if i > 0:\n                        mask[i - 1, j - 1] = True\n\n\n    for j in range(agent_pos[1], height):\n        for i in range(agent_pos[0], width):\n            if mask[i,j] and grid[i,j]:\n                if i < width - 1:\n                    mask[i + 1, j] = True\n               
 if j < height-1:\n                    mask[i, j + 1] = True\n                    if i < width - 1:\n                        mask[i + 1, j + 1] = True\n\n        for i in range(agent_pos[0]+1,0,-1):\n            if mask[i,j] and grid[i,j]:\n                if i > 0:\n                    mask[i - 1, j] = True\n                if j < height-1:\n                    mask[i, j + 1] = True\n                    if i > 0:\n                        mask[i - 1, j + 1] = True\n                    \n    return mask"
  },
  {
    "path": "marlgrid/base.py",
    "content": "# Multi-agent gridworld.\n# Based on MiniGrid: https://github.com/maximecb/gym-minigrid.\n\nimport gym\nimport numpy as np\nimport gym_minigrid\nfrom enum import IntEnum\nimport math\nimport warnings\n\nfrom .objects import WorldObj, Wall, Goal, Lava, GridAgent, BonusTile, BulkObj, COLORS\nfrom .agents import GridAgentInterface\nfrom .rendering import SimpleImageViewer\nfrom gym_minigrid.rendering import fill_coords, point_in_rect, downsample, highlight_img\n\nTILE_PIXELS = 32\n\n\nclass ObjectRegistry:\n    '''\n    This class contains dicts that map objects to numeric keys and vice versa.\n    Used so that grid worlds can represent objects using numerical arrays rather \n        than lists of lists of generic objects.\n    '''\n    def __init__(self, objs=[], max_num_objects=1000):\n        self.key_to_obj_map = {}\n        self.obj_to_key_map = {}\n        self.max_num_objects = max_num_objects\n        for obj in objs:\n            self.add_object(obj)\n\n    def get_next_key(self):\n        for k in range(self.max_num_objects):\n            if k not in self.key_to_obj_map:\n                break\n        else:\n            raise ValueError(\"Object registry full.\")\n        return k\n\n    def __len__(self):\n        return len(self.id_to_obj_map)\n\n    def add_object(self, obj):\n        new_key = self.get_next_key()\n        self.key_to_obj_map[new_key] = obj\n        self.obj_to_key_map[obj] = new_key\n        return new_key\n\n    def contains_object(self, obj):\n        return obj in self.obj_to_key_map\n\n    def contains_key(self, key):\n        return key in self.key_to_obj_map\n\n    def get_key(self, obj):\n        if obj in self.obj_to_key_map:\n            return self.obj_to_key_map[obj]\n        else:\n            return self.add_object(obj)\n\n    # 5/4/2020 This gets called A LOT. Replaced calls to this function with direct dict gets\n    #           in an attempt to speed things up. 
Probably didn't make a big difference.\n    def obj_of_key(self, key):\n        return self.key_to_obj_map[key]\n\n\ndef rotate_grid(grid, rot_k):\n    '''\n    This function basically replicates np.rot90 (with the correct args for rotating images).\n    But it's faster.\n    '''\n    rot_k = rot_k % 4\n    if rot_k==3:\n        return np.moveaxis(grid[:,::-1], 0, 1)\n    elif rot_k==1:\n        return np.moveaxis(grid[::-1,:], 0, 1)\n    elif rot_k==2:\n        return grid[::-1,::-1]\n    else:\n        return grid\n\n\nclass MultiGrid:\n\n    tile_cache = {}\n\n    def __init__(self, shape, obj_reg=None, orientation=0):\n        self.orientation = orientation\n        if isinstance(shape, tuple):\n            self.width, self.height = shape\n            self.grid = np.zeros((self.width, self.height), dtype=np.uint8)  # w,h\n        elif isinstance(shape, np.ndarray):\n            self.width, self.height = shape.shape\n            self.grid = shape\n        else:\n            raise ValueError(\"Must create grid from shape tuple or array.\")\n\n        if self.width < 3 or self.height < 3:\n            raise ValueError(\"Grid needs width, height >= 3\")\n\n        self.obj_reg = ObjectRegistry(objs=[None]) if obj_reg is None else obj_reg\n\n    @property\n    def opacity(self):\n        transparent_fun = np.vectorize(lambda k: (self.obj_reg.key_to_obj_map[k].see_behind() if hasattr(self.obj_reg.key_to_obj_map[k], 'see_behind') else True))\n        return ~transparent_fun(self.grid)\n\n    def __getitem__(self, *args, **kwargs):\n        return self.__class__(\n            np.ndarray.__getitem__(self.grid, *args, **kwargs),\n            obj_reg=self.obj_reg,\n            orientation=self.orientation,\n        )\n\n    def rotate_left(self, k=1):\n        return self.__class__(\n            rotate_grid(self.grid, rot_k=k), # np.rot90(self.grid, k=k),\n            obj_reg=self.obj_reg,\n            orientation=(self.orientation - k) % 4,\n        )\n\n\n    def 
slice(self, topX, topY, width, height, rot_k=0):\n        \"\"\"\n        Get a subset of the grid\n        \"\"\"\n        sub_grid = self.__class__(\n            (width, height),\n            obj_reg=self.obj_reg,\n            orientation=(self.orientation - rot_k) % 4,\n        )\n        x_min = max(0, topX)\n        x_max = min(topX + width, self.width)\n        y_min = max(0, topY)\n        y_max = min(topY + height, self.height)\n\n        x_offset = x_min - topX\n        y_offset = y_min - topY\n        sub_grid.grid[\n            x_offset : x_max - x_min + x_offset, y_offset : y_max - y_min + y_offset\n        ] = self.grid[x_min:x_max, y_min:y_max]\n\n        sub_grid.grid = rotate_grid(sub_grid.grid, rot_k)\n\n        sub_grid.width, sub_grid.height = sub_grid.grid.shape\n\n        return sub_grid\n\n    def set(self, i, j, obj):\n        assert i >= 0 and i < self.width\n        assert j >= 0 and j < self.height\n        self.grid[i, j] = self.obj_reg.get_key(obj)\n\n    def get(self, i, j):\n        assert i >= 0 and i < self.width\n        assert j >= 0 and j < self.height\n\n        return self.obj_reg.key_to_obj_map[self.grid[i, j]]\n\n    def horz_wall(self, x, y, length=None, obj_type=Wall):\n        if length is None:\n            length = self.width - x\n        for i in range(0, length):\n            self.set(x + i, y, obj_type())\n\n    def vert_wall(self, x, y, length=None, obj_type=Wall):\n        if length is None:\n            length = self.height - y\n        for j in range(0, length):\n            self.set(x, y + j, obj_type())\n\n    def wall_rect(self, x, y, w, h, obj_type=Wall):\n        self.horz_wall(x, y, w, obj_type=obj_type)\n        self.horz_wall(x, y + h - 1, w, obj_type=obj_type)\n        self.vert_wall(x, y, h, obj_type=obj_type)\n        self.vert_wall(x + w - 1, y, h, obj_type=obj_type)\n\n    def __str__(self):\n        render = (\n            lambda x: \"  \"\n            if x is None or not hasattr(x, \"str_render\")\n  
          else x.str_render(dir=self.orientation)\n        )\n        hstars = \"*\" * (2 * self.width + 2)\n        return (\n            hstars\n            + \"\\n\"\n            + \"\\n\".join(\n                \"*\" + \"\".join(render(self.get(i, j)) for i in range(self.width)) + \"*\"\n                for j in range(self.height)\n            )\n            + \"\\n\"\n            + hstars\n        )\n\n    def encode(self, vis_mask=None):\n        \"\"\"\n        Produce a compact numpy encoding of the grid\n        \"\"\"\n\n        if vis_mask is None:\n            vis_mask = np.ones((self.width, self.height), dtype=bool)\n\n        array = np.zeros((self.width, self.height, 3), dtype=\"uint8\")\n\n        for i in range(self.width):\n            for j in range(self.height):\n                if vis_mask[i, j]:\n                    v = self.get(i, j)\n                    if v is None:\n                        array[i, j, :] = 0\n                    else:\n                        array[i, j, :] = v.encode()\n        return array\n\n    @classmethod\n    def decode(cls, array):\n        raise NotImplementedError\n        width, height, channels = array.shape\n        assert channels == 3\n        vis_mask[i, j] = np.ones(shape=(width, height), dtype=np.bool)\n        grid = cls((width, height))\n\n    \n    @classmethod\n    def cache_render_fun(cls, key, f, *args, **kwargs):\n        if key not in cls.tile_cache:\n            cls.tile_cache[key] = f(*args, **kwargs)\n        return np.copy(cls.tile_cache[key])\n\n    @classmethod\n    def cache_render_obj(cls, obj, tile_size, subdivs):\n        if obj is None:\n            return cls.cache_render_fun((tile_size, None), cls.empty_tile, tile_size, subdivs)\n        else:\n            img = cls.cache_render_fun(\n                (tile_size, obj.__class__.__name__, *obj.encode()),\n                cls.render_object, obj, tile_size, subdivs\n            )\n            if hasattr(obj, 'render_post'):\n               
 return obj.render_post(img)\n            else:\n                return img\n\n    @classmethod\n    def empty_tile(cls, tile_size, subdivs):\n        alpha = max(0, min(20, tile_size-10))\n        img = np.full((tile_size, tile_size, 3), alpha, dtype=np.uint8)\n        img[1:,:-1] = 0\n        return img\n\n    @classmethod\n    def render_object(cls, obj, tile_size, subdivs):\n        img = np.zeros((tile_size*subdivs,tile_size*subdivs, 3), dtype=np.uint8)\n        obj.render(img)\n        # if 'Agent' not in obj.type and len(obj.agents) > 0:\n        #     obj.agents[0].render(img)\n        return downsample(img, subdivs).astype(np.uint8)\n\n    @classmethod\n    def blend_tiles(cls, img1, img2):\n        '''\n        This function renders one \"tile\" on top of another. Kinda janky, works surprisingly well.\n        Assumes img2 is a downscaled monochromatic with a black (0,0,0) background.\n        '''\n        alpha = img2.sum(2, keepdims=True)\n        max_alpha = alpha.max()\n        if max_alpha == 0:\n            return img1\n        return (\n            ((img1 * (max_alpha-alpha))+(img2*alpha)\n            )/max_alpha\n        ).astype(img1.dtype)\n\n    @classmethod\n    def render_tile(cls, obj, tile_size=TILE_PIXELS, subdivs=3, top_agent=None):\n        subdivs = 3\n\n        if obj is None:\n            img = cls.cache_render_obj(obj, tile_size, subdivs)\n        else:\n            if ('Agent' in obj.type) and (top_agent in obj.agents):\n                # If the tile is a stack of agents that includes the top agent, then just render the top agent.\n                img = cls.cache_render_obj(top_agent, tile_size, subdivs)\n            else: \n                # Otherwise, render (+ downsize) the item in the tile.\n                img = cls.cache_render_obj(obj, tile_size, subdivs)\n                # If the base obj isn't an agent but has agents on top, render an agent and blend it in.\n                if len(obj.agents)>0 and 'Agent' not in 
obj.type:\n                    if top_agent in obj.agents:\n                        img_agent = cls.cache_render_obj(top_agent, tile_size, subdivs)\n                    else:\n                        img_agent = cls.cache_render_obj(obj.agents[0], tile_size, subdivs)\n                    img = cls.blend_tiles(img, img_agent)\n\n            # Render the tile border if any of the corners are black.\n            if (img[([0,0,-1,-1],[0,-1,0,-1])]==0).all(axis=-1).any():\n                img = img + cls.cache_render_fun((tile_size, None), cls.empty_tile, tile_size, subdivs)\n        return img\n\n    def render(self, tile_size, highlight_mask=None, visible_mask=None, top_agent=None):\n        width_px = self.width * tile_size\n        height_px = self.height * tile_size\n\n        img = np.zeros(shape=(height_px, width_px), dtype=np.uint8)[...,None]+COLORS['shadow']\n\n        for j in range(0, self.height):\n            for i in range(0, self.width):\n                if visible_mask is not None and not visible_mask[i,j]:\n                    continue\n                obj = self.get(i, j)\n\n                tile_img = MultiGrid.render_tile(\n                    obj,\n                    tile_size=tile_size,\n                    top_agent=top_agent\n                )\n\n                ymin = j * tile_size\n                ymax = (j + 1) * tile_size\n                xmin = i * tile_size\n                xmax = (i + 1) * tile_size\n\n                img[ymin:ymax, xmin:xmax, :] = rotate_grid(tile_img, self.orientation)\n        \n        if highlight_mask is not None:\n            hm = np.kron(highlight_mask.T, np.full((tile_size, tile_size), 255, dtype=np.uint16)\n                )[...,None] # arcane magic.\n            img = np.right_shift(img.astype(np.uint16)*8+hm*2, 3).clip(0,255).astype(np.uint8)\n\n        return img\n\n\nclass MultiGridEnv(gym.Env):\n    def __init__(\n        self,\n        agents = [],\n        grid_size=None,\n        width=None,\n        
height=None,\n        max_steps=100,\n        reward_decay=True,\n        seed=1337,\n        respawn=False,\n        ghost_mode=True,\n        agent_spawn_kwargs = {}\n    ):\n\n        if grid_size is not None:\n            assert width == None and height == None\n            width, height = grid_size, grid_size\n\n        self.respawn = respawn\n\n        self.window = None\n\n        self.width = width\n        self.height = height\n        self.max_steps = max_steps\n        self.reward_decay = reward_decay\n        self.seed(seed=seed)\n        self.agent_spawn_kwargs = agent_spawn_kwargs\n        self.ghost_mode = ghost_mode\n\n        self.agents = []\n        for agent in agents:\n            self.add_agent(agent)\n\n        self.reset()\n\n    def seed(self, seed=1337):\n        # Seed the random number generator\n        self.np_random, _ = gym.utils.seeding.np_random(seed)\n        return [seed]\n\n    @property\n    def action_space(self):\n        return gym.spaces.Tuple(\n            [agent.action_space for agent in self.agents]\n        )\n\n    @property\n    def observation_space(self):\n        return gym.spaces.Tuple(\n            [agent.observation_space for agent in self.agents]\n        )\n\n    @property\n    def num_agents(self):\n        return len(self.agents)\n    \n    def add_agent(self, agent_interface):\n        if isinstance(agent_interface, dict):\n            self.agents.append(GridAgentInterface(**agent_interface))\n        elif isinstance(agent_interface, GridAgentInterface):\n            self.agents.append(agent_interface)\n        else:\n            raise ValueError(\n                \"To add an agent to a marlgrid environment, call add_agent with either a GridAgentInterface object \"\n                \" or a dictionary that can be used to initialize one.\")\n\n    def reset(self, **kwargs):\n        for agent in self.agents:\n            agent.agents = []\n            agent.reset(new_episode=True)\n\n        
self._gen_grid(self.width, self.height)\n\n        for agent in self.agents:\n            if agent.spawn_delay == 0:\n                self.place_obj(agent, **self.agent_spawn_kwargs)\n                agent.activate()\n\n        self.step_count = 0\n        obs = self.gen_obs()\n        return obs\n\n    def gen_obs_grid(self, agent):\n        # If the agent is inactive, return an empty grid and a visibility mask that hides everything.\n        if not agent.active:\n            # below, not sure orientation is correct but as of 6/27/2020 that doesn't matter because\n            # agent views are usually square and this grid won't be used for anything.\n            grid = MultiGrid((agent.view_size, agent.view_size), orientation=agent.dir+1)\n            # Builtin bool: the np.bool alias was removed in NumPy 1.24.\n            vis_mask = np.zeros((agent.view_size, agent.view_size), dtype=bool)\n            return grid, vis_mask\n\n        topX, topY, botX, botY = agent.get_view_exts()\n\n        grid = self.grid.slice(\n            topX, topY, agent.view_size, agent.view_size, rot_k=agent.dir + 1\n        )\n\n        # Process occluders and visibility\n        # Note that this incurs some slight performance cost\n        vis_mask = agent.process_vis(grid.opacity)\n\n        # Warning about the rest of the function:\n        #  Allows masking away objects that the agent isn't supposed to see.\n        #  But breaks consistency between the states of the grid objects in the partial views\n        #   and the grid objects overall.\n        if len(getattr(agent, 'hide_item_types', []))>0:\n            for i in range(grid.width):\n                for j in range(grid.height):\n                    item = grid.get(i,j)\n                    if (item is not None) and (item is not agent) and (item.type in agent.hide_item_types):\n                        if len(item.agents) > 0:\n                            grid.set(i,j,item.agents[0])\n                        else:\n                            grid.set(i,j,None)\n\n        return grid, vis_mask\n\n 
   def gen_agent_obs(self, agent):\n        \"\"\"\n        Generate the agent's view (partially observable, low-resolution encoding)\n        \"\"\"\n        grid, vis_mask = self.gen_obs_grid(agent)\n        grid_image = grid.render(tile_size=agent.view_tile_size, visible_mask=vis_mask, top_agent=agent)\n        if agent.observation_style=='image':\n            return grid_image\n        else:\n            ret = {'pov': grid_image}\n            if agent.observe_rewards:\n                ret['reward'] = getattr(agent, 'step_reward', 0)\n            if agent.observe_position:\n                agent_pos = agent.pos if agent.pos is not None else (0,0)\n                ret['position'] = np.array(agent_pos)/np.array([self.width, self.height], dtype=float)\n            if agent.observe_orientation:\n                agent_dir = agent.dir if agent.dir is not None else 0\n                ret['orientation'] = agent_dir\n            return ret\n\n    def gen_obs(self):\n        return [self.gen_agent_obs(agent) for agent in self.agents]\n\n    def __str__(self):\n        return self.grid.__str__()\n\n    def check_agent_position_integrity(self, title=''):\n        '''\n        This function checks whether each agent is present in the grid in exactly one place.\n        This is particularly helpful for validating the world state when ghost_mode=False and\n        agents can stack, since the logic for moving them around gets a bit messy.\n        Prints a message and drops into pdb if there's an inconsistency.\n        '''\n        agent_locs = [[] for _ in range(len(self.agents))]\n        for i in range(self.grid.width):\n            for j in range(self.grid.height):\n                x = self.grid.get(i,j)\n                for k,agent in enumerate(self.agents):\n                    if x==agent:\n                        agent_locs[k].append(('top', (i,j)))\n                    if hasattr(x, 'agents') and agent in x.agents:\n                        
agent_locs[k].append(('stacked', (i,j)))\n        if not all([len(x)==1 for x in agent_locs]):\n            print(f\"{title} > Failed integrity test!\")\n            for a, al in zip(self.agents, agent_locs):\n                print(\" > \", a.color,'-', al)\n            import pdb; pdb.set_trace()\n\n    def step(self, actions):\n        # Spawn agents if it's time.\n        for agent in self.agents:\n            if not agent.active and not agent.done and self.step_count >= agent.spawn_delay:\n                self.place_obj(agent, **self.agent_spawn_kwargs)\n                agent.activate()\n                \n        assert len(actions) == len(self.agents)\n\n        step_rewards = np.zeros((len(self.agents,)), dtype=float)\n\n        self.step_count += 1\n\n        iter_agents = list(enumerate(zip(self.agents, actions)))\n        iter_order = np.arange(len(iter_agents))\n        self.np_random.shuffle(iter_order)\n        for shuffled_ix in iter_order:\n            agent_no, (agent, action) = iter_agents[shuffled_ix]\n            agent.step_reward = 0\n\n            if agent.active:\n\n                cur_pos = agent.pos[:]\n                cur_cell = self.grid.get(*cur_pos)\n                fwd_pos = agent.front_pos[:]\n                fwd_cell = self.grid.get(*fwd_pos)\n                agent_moved = False\n\n                # Rotate left\n                if action == agent.actions.left:\n                    agent.dir = (agent.dir - 1) % 4\n\n                # Rotate right\n                elif action == agent.actions.right:\n                    agent.dir = (agent.dir + 1) % 4\n\n                # Move forward\n                elif action == agent.actions.forward:\n                    # Under the following conditions, the agent can move forward.\n                    can_move = fwd_cell is None or fwd_cell.can_overlap()\n                    if self.ghost_mode is False and isinstance(fwd_cell, GridAgent):\n                        can_move = False\n\n                
    if can_move:\n                        agent_moved = True\n                        # Add agent to new cell\n                        if fwd_cell is None:\n                            self.grid.set(*fwd_pos, agent)\n                            agent.pos = fwd_pos\n                        else:\n                            fwd_cell.agents.append(agent)\n                            agent.pos = fwd_pos\n\n                        # Remove agent from old cell\n                        if cur_cell == agent:\n                            self.grid.set(*cur_pos, None)\n                        else:\n                            assert cur_cell.can_overlap()\n                            cur_cell.agents.remove(agent)\n\n                        # Add agent's agents to old cell\n                        for left_behind in agent.agents:\n                            cur_obj = self.grid.get(*cur_pos)\n                            if cur_obj is None:\n                                self.grid.set(*cur_pos, left_behind)\n                            elif cur_obj.can_overlap():\n                                cur_obj.agents.append(left_behind)\n                            else: # How was \"agent\" there in the first place?\n                                raise ValueError(\"?!?!?!\")\n\n                        # After moving, the agent shouldn't contain any other agents.\n                        agent.agents = [] \n                        # test_integrity(f\"After moving {agent.color} fellow\")\n\n                        # Rewards can be got iff. 
fwd_cell has a \"get_reward\" method\n                        if hasattr(fwd_cell, 'get_reward'):\n                            rwd = fwd_cell.get_reward(agent)\n                            if bool(self.reward_decay):\n                                rwd *= (1.0-0.9*(self.step_count/self.max_steps))\n                            step_rewards[agent_no] += rwd\n                            agent.reward(rwd)\n                            \n\n                        if isinstance(fwd_cell, (Lava, Goal)):\n                            agent.done = True\n\n                # TODO: verify pickup/drop/toggle logic in an environment that \n                #  supports the relevant interactions.\n                # Pick up an object\n                elif action == agent.actions.pickup:\n                    if fwd_cell and fwd_cell.can_pickup():\n                        if agent.carrying is None:\n                            agent.carrying = fwd_cell\n                            agent.carrying.cur_pos = np.array([-1, -1])\n                            self.grid.set(*fwd_pos, None)\n                    else:\n                        pass\n\n                # Drop an object\n                elif action == agent.actions.drop:\n                    if not fwd_cell and agent.carrying:\n                        self.grid.set(*fwd_pos, agent.carrying)\n                        agent.carrying.cur_pos = fwd_pos\n                        agent.carrying = None\n                    else:\n                        pass\n\n                # Toggle/activate an object\n                elif action == agent.actions.toggle:\n                    if fwd_cell:\n                        wasted = bool(fwd_cell.toggle(agent, fwd_pos))\n                    else:\n                        pass\n\n                # Done action (not used by default)\n                elif action == agent.actions.done:\n                    pass\n\n                else:\n                    raise ValueError(f\"Environment can't handle 
action {action}.\")\n\n                agent.on_step(fwd_cell if agent_moved else None)\n\n        \n        # If any of the agents individually are \"done\" (hit lava or in some cases a goal) \n        #   but the env requires respawning, then respawn those agents.\n        for agent in self.agents:\n            if agent.done:\n                if self.respawn:\n                    resting_place_obj = self.grid.get(*agent.pos)\n                    if resting_place_obj == agent:\n                        if agent.agents:\n                            self.grid.set(*agent.pos, agent.agents[0])\n                            agent.agents[0].agents += agent.agents[1:]\n                        else:\n                            self.grid.set(*agent.pos, None)\n                    else:\n                        resting_place_obj.agents.remove(agent)\n                        resting_place_obj.agents += agent.agents[:]\n                        agent.agents = []\n                        \n                    agent.reset(new_episode=False)\n                    self.place_obj(agent, **self.agent_spawn_kwargs)\n                    agent.activate()\n                else: # if the agent shouldn't be respawned, then deactivate it.\n                    agent.deactivate()\n\n        # The episode overall is done if all the agents are done, or if it exceeds the step limit.\n        done = (self.step_count >= self.max_steps) or all([agent.done for agent in self.agents])\n\n        obs = [self.gen_agent_obs(agent) for agent in self.agents]\n\n        return obs, step_rewards, done, {}\n\n    def put_obj(self, obj, i, j):\n        \"\"\"\n        Put an object at a specific position in the grid. 
Replace anything that is already there.\n        \"\"\"\n        self.grid.set(i, j, obj)\n        if obj is not None:\n            obj.set_position((i,j))\n        return True\n\n    def try_place_obj(self,obj, pos):\n        ''' Try to place an object at a certain position in the grid.\n        If it is possible, then do so and return True.\n        Otherwise do nothing and return False. '''\n        # grid_obj: whatever object is already at pos.\n        grid_obj = self.grid.get(*pos)\n\n        # If the target position is empty, then the object can always be placed.\n        if grid_obj is None:\n            self.grid.set(*pos, obj)\n            obj.set_position(pos)\n            return True\n\n        # Otherwise only agents can be placed, and only if the target position can_overlap.\n        if not (grid_obj.can_overlap() and obj.is_agent):\n            return False\n\n        # If ghost mode is off and there's already an agent at the target cell, the agent can't\n        #   be placed there.\n        if (not self.ghost_mode) and (grid_obj.is_agent or (len(grid_obj.agents)>0)):\n            return False\n\n        grid_obj.agents.append(obj)\n        obj.set_position(pos)\n        return True\n\n    def place_obj(self, obj, top=(0,0), size=None, reject_fn=None, max_tries=1e5):\n        max_tries = int(max(1, min(max_tries, 1e5)))\n        top = (max(top[0], 0), max(top[1], 0))\n        if size is None:\n            size = (self.grid.width, self.grid.height)\n        bottom = (min(top[0] + size[0], self.grid.width), min(top[1] + size[1], self.grid.height))\n\n        # agent_positions = [tuple(agent.pos) if agent.pos is not None else None for agent in self.agents]\n        for try_no in range(max_tries):\n            pos = self.np_random.randint(top, bottom)\n            if (reject_fn is not None) and reject_fn(pos):\n                continue\n            else:\n                if self.try_place_obj(obj, pos):\n                    break\n        else:\n        
    raise RecursionError(\"Rejection sampling failed in place_obj.\")\n\n        return pos\n\n    def place_agents(self, top=None, size=None, rand_dir=True, max_tries=1000):\n        # warnings.warn(\"Placing agents with the function place_agents is deprecated.\")\n        pass\n\n    def render(\n        self,\n        mode=\"human\",\n        close=False,\n        highlight=True,\n        tile_size=TILE_PIXELS,\n        show_agent_views=True,\n        max_agents_per_col=3,\n        agent_col_width_frac = 0.3,\n        agent_col_padding_px = 2,\n        pad_grey = 100\n    ):\n        \"\"\"\n        Render the whole-grid human view\n        \"\"\"\n\n        if close:\n            if self.window:\n                self.window.close()\n            return\n\n        if mode == \"human\" and not self.window:\n            # from gym.envs.classic_control.rendering import SimpleImageViewer\n\n            self.window = SimpleImageViewer(caption=\"Marlgrid\")\n\n        # Compute which cells are visible to the agent\n        highlight_mask = np.full((self.width, self.height), False, dtype=bool)\n        for agent in self.agents:\n            if agent.active:\n                xlow, ylow, xhigh, yhigh = agent.get_view_exts()\n                dxlow, dylow = max(0, 0-xlow), max(0, 0-ylow)\n                dxhigh, dyhigh = max(0, xhigh-self.grid.width), max(0, yhigh-self.grid.height)\n                if agent.see_through_walls:\n                    highlight_mask[xlow+dxlow:xhigh-dxhigh, ylow+dylow:yhigh-dyhigh] = True\n                else:\n                    a,b = self.gen_obs_grid(agent)\n                    highlight_mask[xlow+dxlow:xhigh-dxhigh, ylow+dylow:yhigh-dyhigh] |= (\n                        rotate_grid(b, a.orientation)[dxlow:(xhigh-xlow)-dxhigh, dylow:(yhigh-ylow)-dyhigh]\n                    )\n\n\n        # Render the whole grid\n        img = self.grid.render(\n            tile_size, highlight_mask=highlight_mask if highlight else None\n        )\n      
  rescale = lambda X, rescale_factor=2: np.kron(\n            X, np.ones((int(rescale_factor), int(rescale_factor), 1))\n        )\n\n        if show_agent_views:\n\n            target_partial_width = int(img.shape[0]*agent_col_width_frac-2*agent_col_padding_px)\n            target_partial_height = (img.shape[1]-2*agent_col_padding_px)//max_agents_per_col\n\n            agent_views = [self.gen_agent_obs(agent) for agent in self.agents]\n            agent_views = [view['pov'] if isinstance(view, dict) else view for view in agent_views]\n            agent_views = [rescale(view, min(target_partial_width/view.shape[0], target_partial_height/view.shape[1])) for view in agent_views]\n            # import pdb; pdb.set_trace()\n            agent_views = [agent_views[pos:pos+max_agents_per_col] for pos in range(0, len(agent_views), max_agents_per_col)]\n\n            f_offset = lambda view: np.array([target_partial_height - view.shape[1], target_partial_width - view.shape[0]])//2\n            \n            cols = []\n            for col_views in agent_views:\n                col = np.full(( img.shape[0],target_partial_width+2*agent_col_padding_px,3), pad_grey, dtype=np.uint8)\n                for k, view in enumerate(col_views):\n                    offset = f_offset(view) + agent_col_padding_px\n                    offset[0] += k*target_partial_height\n                    col[offset[0]:offset[0]+view.shape[0], offset[1]:offset[1]+view.shape[1],:] = view\n                cols.append(col)\n\n            img = np.concatenate((img, *cols), axis=1)\n\n        if mode == \"human\":\n            if not self.window.isopen:\n                self.window.imshow(img)\n                self.window.window.set_caption(\"Marlgrid\")\n            else:\n                self.window.imshow(img)\n\n        return img\n"
  },
  {
    "path": "marlgrid/envs/__init__.py",
    "content": "from ..base import MultiGridEnv\n\nfrom .empty import EmptyMultiGrid\nfrom .doorkey import DoorKeyEnv\nfrom .cluttered import ClutteredMultiGrid\nfrom .goalcycle import ClutteredGoalCycleEnv\nfrom .viz_test import VisibilityTestEnv\n\nfrom ..agents import GridAgentInterface\nfrom gym.envs.registration import register as gym_register\n\nimport sys\nimport inspect\nimport random\n\nthis_module = sys.modules[__name__]\nregistered_envs = []\n\n\ndef register_marl_env(\n    env_name,\n    env_class,\n    n_agents,\n    grid_size,\n    view_size,\n    view_tile_size=8,\n    view_offset=0,\n    agent_color=None,\n    env_kwargs={},\n):\n    '''\n    Register the gym environment ID env_name so that it builds env_class with n_agents agents.\n\n    Every agent is a GridAgentInterface created with the given view_size, view_tile_size and\n    view_offset. If agent_color is given it is used for all agents; otherwise each agent gets\n    its own color from a fixed list. env_kwargs is forwarded to env_class unchanged.\n    '''\n    colors = [\"red\", \"blue\", \"purple\", \"orange\", \"olive\", \"pink\"]\n    assert n_agents <= len(colors)\n\n    class RegEnv(env_class):\n        def __new__(cls):\n            instance = super(env_class, RegEnv).__new__(env_class)\n            instance.__init__(\n                agents=[\n                    GridAgentInterface(\n                        color=c if agent_color is None else agent_color,\n                        view_size=view_size,\n                        # Pass the caller's tile size through instead of a hard-coded 8.\n                        view_tile_size=view_tile_size,\n                        view_offset=view_offset,\n                        )\n                    for c in colors[:n_agents]\n                ],\n                grid_size=grid_size,\n                **env_kwargs,\n            )\n            return instance\n\n    env_class_name = f\"env_{len(registered_envs)}\"\n    setattr(this_module, env_class_name, RegEnv)\n    registered_envs.append(env_name)\n    gym_register(env_name, entry_point=f\"marlgrid.envs:{env_class_name}\")\n\n\ndef env_from_config(env_config, randomize_seed=True):\n    # Build the MultiGridEnv subclass named by env_config['env_class']; every other key is passed as a kwarg.\n    possible_envs = {k:v for k,v in globals().items() if inspect.isclass(v) and issubclass(v, MultiGridEnv)}\n    \n    env_class = possible_envs[env_config['env_class']]\n    \n    env_kwargs = {k:v for k,v in env_config.items() if k != 'env_class'}\n    if randomize_seed:\n        env_kwargs['seed'] 
= env_kwargs.get('seed', 0) + random.randint(0, 1337*1337)\n    \n    return env_class(**env_kwargs)\n\n\n# NOTE(review): this ID says 15x15 but grid_size is 11 -- confirm which one is intended.\nregister_marl_env(\n    \"MarlGrid-1AgentCluttered15x15-v0\",\n    ClutteredMultiGrid,\n    n_agents=1,\n    grid_size=11,\n    view_size=5,\n    env_kwargs={'n_clutter':30}\n)\n\nregister_marl_env(\n    \"MarlGrid-3AgentCluttered11x11-v0\",\n    ClutteredMultiGrid,\n    n_agents=3,\n    grid_size=11,\n    view_size=7,\n    env_kwargs={'clutter_density':0.15}\n)\n\nregister_marl_env(\n    \"MarlGrid-3AgentCluttered15x15-v0\",\n    ClutteredMultiGrid,\n    n_agents=3,\n    grid_size=15,\n    view_size=7,\n    env_kwargs={'clutter_density':0.15}\n)\n\nregister_marl_env(\n    \"MarlGrid-2AgentEmpty9x9-v0\", EmptyMultiGrid, n_agents=2, grid_size=9, view_size=7\n)\n\nregister_marl_env(\n    \"MarlGrid-3AgentEmpty9x9-v0\", EmptyMultiGrid, n_agents=3, grid_size=9, view_size=7\n)\n\nregister_marl_env(\n    \"MarlGrid-4AgentEmpty9x9-v0\", EmptyMultiGrid, n_agents=4, grid_size=9, view_size=7\n)\n\nregister_marl_env(\n    \"Goalcycle-demo-solo-v0\", \n    ClutteredGoalCycleEnv, \n    n_agents=1, \n    grid_size=13,\n    view_size=7,\n    view_tile_size=5,\n    view_offset=1,\n    env_kwargs={\n        'clutter_density':0.1,\n        'n_bonus_tiles': 3\n    }\n)"
  },
  {
    "path": "marlgrid/envs/cluttered.py",
    "content": "from ..base import MultiGridEnv, MultiGrid\nfrom ..objects import *\n\n\nclass ClutteredMultiGrid(MultiGridEnv):\n    mission = \"get to the green square\"\n    metadata = {}\n\n    def __init__(self, *args, n_clutter=None, clutter_density=None, randomize_goal=False, **kwargs):\n        if (n_clutter is None) == (clutter_density is None):\n            raise ValueError(\"Must provide n_clutter xor clutter_density in environment config.\")\n\n        super().__init__(*args, **kwargs)\n\n        if clutter_density is not None:\n            self.n_clutter = int(clutter_density * (self.width-2)*(self.height-2))\n        else:\n            self.n_clutter = n_clutter\n\n        self.randomize_goal = randomize_goal\n\n        # self.reset()\n\n\n    def _gen_grid(self, width, height):\n        self.grid = MultiGrid((width, height))\n        self.grid.wall_rect(0, 0, width, height)\n        if getattr(self, 'randomize_goal', True):\n            self.place_obj(Goal(color=\"green\", reward=1), max_tries=100)\n        else:\n            self.put_obj(Goal(color=\"green\", reward=1), width - 2, height - 2)\n        for _ in range(getattr(self, 'n_clutter', 0)):\n            self.place_obj(Wall(), max_tries=100)\n\n        self.agent_spawn_kwargs = {}\n        self.place_agents(**self.agent_spawn_kwargs)\n"
  },
  {
    "path": "marlgrid/envs/doorkey.py",
    "content": "from ..base import MultiGridEnv, MultiGrid\nfrom ..objects import *\n\n\nclass DoorKeyEnv(MultiGridEnv):\n    \"\"\"\n    Environment with a door and key, sparse reward.\n    Similar to DoorKeyEnv in \n        https://github.com/maximecb/gym-minigrid/blob/master/gym_minigrid/envs/doorkey.py\n    \"\"\"\n\n    mission = \"use the key to open the door and then get to the goal\"\n    metadata = {}\n\n    def _gen_grid(self, width, height):\n        # Create an empty grid\n        self.grid = MultiGrid((width, height))\n\n        # Generate the surrounding walls\n        self.grid.wall_rect(0, 0, width, height)\n\n        # Place a goal in the bottom-right corner\n        self.put_obj(Goal(color=\"green\", reward=1), width - 2, height - 2)\n\n        # Create a vertical splitting wall\n        splitIdx = self._rand_int(2, width - 2)\n        self.grid.vert_wall(splitIdx, 0)\n\n        # Place the agent at a random position and orientation\n        # on the left side of the splitting wall\n        # self.place_agent(size=(splitIdx, height))\n\n        # Place a door in the wall\n        doorIdx = self._rand_int(1, width - 2)\n        self.put_obj(Door(color=\"yellow\", state=Door.states.locked), splitIdx, doorIdx)\n\n        # Place a yellow key on the left side\n        self.place_obj(obj=Key(\"yellow\"), top=(0, 0), size=(splitIdx, height))\n\n        self.agent_spawn_kwargs = {}\n        self.place_agents(**self.agent_spawn_kwargs)\n"
  },
  {
    "path": "marlgrid/envs/empty.py",
    "content": "from ..base import MultiGridEnv, MultiGrid\nfrom ..objects import *\n\n\nclass EmptyMultiGrid(MultiGridEnv):\n    mission = \"get to the green square\"\n    metadata = {}\n\n    def _gen_grid(self, width, height):\n        self.grid = MultiGrid((width, height))\n        self.grid.wall_rect(0, 0, width, height)\n        self.put_obj(Goal(color=\"green\", reward=1), width - 2, height - 2)\n\n\n        self.agent_spawn_kwargs = {}\n        self.place_agents(**self.agent_spawn_kwargs)\n"
  },
  {
    "path": "marlgrid/envs/goalcycle.py",
    "content": "from ..base import MultiGridEnv, MultiGrid\nfrom ..objects import *\n\n\nclass ClutteredGoalCycleEnv(MultiGridEnv):\n    mission = \"Cycle between yellow goal tiles.\"\n    metadata = {}\n\n    def __init__(self, *args, reward=1, penalty=0.0, n_clutter=None, clutter_density=None, n_bonus_tiles=3, initial_reward=True, cycle_reset=False, reset_on_mistake=False, reward_decay=False, **kwargs):\n        if (n_clutter is None) == (clutter_density is None):\n            raise ValueError(\"Must provide n_clutter xor clutter_density in environment config.\")\n\n        # Overwrite the default reward_decay for goal cycle environments.\n        super().__init__(*args, **{**kwargs, 'reward_decay': reward_decay})\n\n        if clutter_density is not None:\n            self.n_clutter = int(clutter_density * (self.width-2)*(self.height-2))\n        else:\n            self.n_clutter = n_clutter\n        \n        self.reward = reward\n        self.penalty = penalty\n\n        self.initial_reward = initial_reward\n        self.n_bonus_tiles = n_bonus_tiles\n        self.reset_on_mistake = reset_on_mistake\n\n        self.bonus_tiles = []\n\n    def _gen_grid(self, width, height):\n        self.grid = MultiGrid((width, height))\n        self.grid.wall_rect(0, 0, width, height)\n\n        for bonus_id in range(getattr(self, 'n_bonus_tiles', 0)):\n            self.place_obj(\n                BonusTile(\n                    color=\"yellow\",\n                    reward=self.reward,\n                    penalty=self.penalty,\n                    bonus_id=bonus_id,\n                    n_bonus=self.n_bonus_tiles,\n                    initial_reward=self.initial_reward,\n                    reset_on_mistake=self.reset_on_mistake,\n                ),\n                max_tries=100\n            )\n        for _ in range(getattr(self, 'n_clutter', 0)):\n            self.place_obj(Wall(), max_tries=100)\n\n        self.agent_spawn_kwargs = {}\n        
self.place_agents(**self.agent_spawn_kwargs)\n"
  },
  {
    "path": "marlgrid/envs/viz_test.py",
    "content": "from ..base import MultiGridEnv, MultiGrid\nfrom ..objects import *\n\n\nclass VisibilityTestEnv(MultiGridEnv):\n    mission = \"\"\n    metadata = {}\n\n    def _gen_grid(self, width, height):\n        self.grid = MultiGrid((width, height))\n        self.grid.wall_rect(0, 0, width, height)\n        self.grid.horz_wall(0, height // 2, width - 3, obj_type=Wall)\n\n        self.agent_spawn_kwargs = {}\n        self.place_agents(**self.agent_spawn_kwargs)\n"
  },
  {
    "path": "marlgrid/objects.py",
    "content": "import numpy as np\nfrom enum import IntEnum\nfrom gym_minigrid.rendering import (\n    fill_coords,\n    point_in_rect,\n    point_in_triangle,\n    rotate_fn,\n)\n\n# Map of color names to RGB values\nCOLORS = {\n    \"red\": np.array([255, 0, 0]),\n    \"orange\": np.array([255, 165, 0]),\n    \"green\": np.array([0, 255, 0]),\n    \"blue\": np.array([0, 0, 255]),\n    \"cyan\": np.array([0, 139, 139]),\n    \"purple\": np.array([112, 39, 195]),\n    \"yellow\": np.array([255, 255, 0]),\n    \"olive\": np.array([128, 128, 0]),\n    \"grey\": np.array([100, 100, 100]),\n    \"worst\": np.array([74, 65, 42]),  # https://en.wikipedia.org/wiki/Pantone_448_C\n    \"pink\": np.array([255, 0, 189]),\n    \"white\": np.array([255,255,255]),\n    \"prestige\": np.array([255,255,255]),\n    \"shadow\": np.array([35,25,30]), # nice dark purpley color for cells agents can't see.\n}\n\n# Used to map colors to integers\nCOLOR_TO_IDX = dict({v: k for k, v in enumerate(COLORS.keys())})\n\nOBJECT_TYPES = []\n\nclass RegisteredObjectType(type):\n    def __new__(meta, name, bases, class_dict):\n        cls = type.__new__(meta, name, bases, class_dict)\n        if name not in OBJECT_TYPES:\n            OBJECT_TYPES.append(cls)\n\n        def get_recursive_subclasses():\n            return OBJECT_TYPES\n\n        cls.recursive_subclasses = staticmethod(get_recursive_subclasses)\n        return cls\n\n\nclass WorldObj(metaclass=RegisteredObjectType):\n    def __init__(self, color=\"worst\", state=0):\n        self.color = color\n        self.state = state\n        self.contains = None\n\n        self.agents = [] # Some objects can have agents on top (e.g. 
floor, open doors, etc).\n        \n        self.pos_init = None\n        self.pos = None\n        self.is_agent = False\n\n    @property\n    def dir(self):\n        return None\n\n    def set_position(self, pos):\n        if self.pos_init is None:\n            self.pos_init = pos\n        self.pos = pos\n\n    @property\n    def numeric_color(self):\n        return COLORS[self.color]\n    \n    @property\n    def type(self):\n        return self.__class__.__name__\n\n    def can_overlap(self):\n        return False\n\n    def can_pickup(self):\n        return False\n\n    def can_contain(self):\n        return False\n\n    def see_behind(self):\n        return True\n\n    def toggle(self, env, pos):\n        return False\n\n    def encode(self, str_class=False):\n        # Note 5/29/20: Commented out the condition below; was causing agents to \n        #  render incorrectly in partial views. In particular, if there were N red agents,\n        #  agents {i != k} would render as blue (rather than red) in agent k's partial view.\n        # # if len(self.agents)>0:\n        # #     return self.agents[0].encode(str_class=str_class)\n        # # else:\n        enc_class = self.type if bool(str_class) else self.recursive_subclasses().index(self.__class__)\n        enc_color = self.color if isinstance(self.color, int) else COLOR_TO_IDX[self.color]\n        return (enc_class, enc_color, self.state)\n\n    def describe(self):\n        return f\"Obj: {self.type}({self.color}, {self.state})\"\n\n    @classmethod\n    def decode(cls, type, color, state):\n        if isinstance(type, str):\n            cls_subclasses = {c.__name__: c for c in cls.recursive_subclasses()}\n            if type not in cls_subclasses:\n                raise ValueError(\n                    f\"Not sure how to construct a {cls} of (sub)type {type}\"\n                )\n            return cls_subclasses[type](color, state)\n        elif isinstance(type, int):\n            subclass = 
cls.recursive_subclasses()[type]\n            return subclass(color, state)\n\n    def render(self, img):\n        raise NotImplementedError\n\n    def str_render(self, dir=0):\n        return \"??\"\n\n\nclass GridAgent(WorldObj):\n    def __init__(self, *args, color='red', **kwargs):\n        super().__init__(*args, **{'color':color, **kwargs})\n        self.metadata = {\n            'color': color,\n        }\n        self.is_agent = True\n\n    @property\n    def dir(self):\n        return self.state % 4\n\n    @property\n    def type(self):\n        return 'Agent'\n\n    @dir.setter\n    def dir(self, dir):\n        self.state = self.state // 4 + dir % 4\n\n    def str_render(self, dir=0):\n        return [\">>\", \"VV\", \"<<\", \"^^\"][(self.dir + dir) % 4]\n\n    def can_overlap(self):\n        return True\n\n    def render(self, img):\n        tri_fn = point_in_triangle((0.12, 0.19), (0.87, 0.50), (0.12, 0.81),)\n        tri_fn = rotate_fn(tri_fn, cx=0.5, cy=0.5, theta=0.5 * np.pi * (self.dir))\n        fill_coords(img, tri_fn, COLORS[self.color])\n\n\nclass BulkObj(WorldObj, metaclass=RegisteredObjectType):\n    # Todo: special behavior for hash, eq if the object has an agent.\n    def __hash__(self):\n        return hash((self.__class__, self.color, self.state, tuple(self.agents)))\n\n    def __eq__(self, other):\n        return hash(self) == hash(other)\n\nclass BonusTile(WorldObj):\n    def __init__(self, reward, penalty=-0.1, bonus_id=0, n_bonus=1, initial_reward=True, reset_on_mistake=False, color='yellow', *args, **kwargs):\n        super().__init__(*args, **{'color': color, **kwargs, 'state': bonus_id})\n        self.reward = reward\n        self.penalty = penalty\n        self.n_bonus = n_bonus\n        self.bonus_id = bonus_id\n        self.initial_reward = initial_reward\n        self.reset_on_mistake = reset_on_mistake\n\n    def can_overlap(self):\n        return True\n\n    def str_render(self, dir=0):\n        return \"BB\"\n\n    def 
get_reward(self, agent):\n        # If the agent hasn't hit any bonus tiles, set its bonus state so that\n        #  it'll get a reward from hitting this tile.\n        first_bonus = False\n        if agent.bonus_state is None:\n            agent.bonus_state = (self.bonus_id - 1) % self.n_bonus\n            first_bonus = True\n\n        if agent.bonus_state == self.bonus_id:\n            # This is the last bonus tile the agent hit\n            rew = -np.abs(self.penalty)\n        elif (agent.bonus_state + 1)%self.n_bonus == self.bonus_id:\n            # The agent hit the previous bonus tile before this one\n            agent.bonus_state = self.bonus_id\n            # rew = agent.bonus_value\n            rew = self.reward\n        else:\n            # The agent hit any other bonus tile before this one\n            rew = -np.abs(self.penalty)\n\n        if self.reset_on_mistake:\n            agent.bonus_state = self.bonus_id\n\n        if first_bonus and not bool(self.initial_reward):\n            return 0\n        else:\n            return rew\n\n    def render(self, img):\n        fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])\n\nclass Goal(WorldObj):\n    def __init__(self, reward, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.reward = reward\n\n    def can_overlap(self):\n        return True\n\n    def get_reward(self, agent):\n        return self.reward\n\n    def str_render(self, dir=0):\n        return \"GG\"\n\n    def render(self, img):\n        fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])\n\n\nclass Floor(WorldObj):\n    def can_overlap(self):\n        return True# and self.agent is None\n\n    def str_render(self, dir=0):\n        return \"FF\"\n\n    def render(self, img):\n        # Give the floor a pale color\n        c = COLORS[self.color]\n        img.setLineColor(100, 100, 100, 0)\n        img.setColor(*c / 2)\n        # img.drawPolygon([\n        #     (1          , TILE_PIXELS),\n     
   #     (TILE_PIXELS, TILE_PIXELS),\n        #     (TILE_PIXELS,           1),\n        #     (1          ,           1)\n        # ])\n\n\nclass EmptySpace(WorldObj):\n    def can_overlap(self):\n        return True\n\n    def str_render(self, dir=0):\n        return \"  \"\n\n\nclass Lava(WorldObj):\n    def can_overlap(self):\n        return True# and self.agent is None\n\n    def str_render(self, dir=0):\n        return \"VV\"\n\n    def render(self, img):\n        c = (255, 128, 0)\n\n        # Background color\n        fill_coords(img, point_in_rect(0, 1, 0, 1), c)\n\n        # Little waves\n        for i in range(3):\n            ylo = 0.3 + 0.2 * i\n            yhi = 0.4 + 0.2 * i\n            fill_coords(img, point_in_line(0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0))\n            fill_coords(img, point_in_line(0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0))\n            fill_coords(img, point_in_line(0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0))\n            fill_coords(img, point_in_line(0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0))\n\n\nclass Wall(BulkObj):\n    def see_behind(self):\n        return False\n\n    def str_render(self, dir=0):\n        return \"WW\"\n\n    def render(self, img):\n        fill_coords(img, point_in_rect(0, 1, 0, 1), COLORS[self.color])\n\n\nclass Key(WorldObj):\n    def can_pickup(self):\n        return True\n\n    def str_render(self, dir=0):\n        return \"KK\"\n\n    def render(self, img):\n        c = COLORS[self.color]\n\n        # Vertical quad\n        fill_coords(img, point_in_rect(0.50, 0.63, 0.31, 0.88), c)\n\n        # Teeth\n        fill_coords(img, point_in_rect(0.38, 0.50, 0.59, 0.66), c)\n        fill_coords(img, point_in_rect(0.38, 0.50, 0.81, 0.88), c)\n\n        # Ring\n        fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.190), c)\n        fill_coords(img, point_in_circle(cx=0.56, cy=0.28, r=0.064), (0, 0, 0))\n\n\nclass Ball(WorldObj):\n    def can_pickup(self):\n        return True\n\n    def str_render(self, 
dir=0):\n        return \"AA\"\n\n    def render(self, img):\n        fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])\n\n\nclass Door(WorldObj):\n    states = IntEnum(\"door_state\", \"open closed locked\")\n\n    def can_overlap(self):\n        return self.state == self.states.open# and self.agent is None  # is open\n\n    def see_behind(self):\n        return self.state == self.states.open  # is open\n\n    def toggle(self, agent, pos):\n        if self.state == self.states.locked:  # is locked\n            # If the agent is carrying a key of matching color\n            if (\n                agent.carrying is not None\n                and isinstance(agent.carrying, Key)\n                and agent.carrying.color == self.color\n            ):\n                self.state = self.states.closed\n        elif self.state == self.states.closed:  # is unlocked but closed\n            self.state = self.states.open\n        elif self.state == self.states.open:  # is open\n            self.state = self.states.closed\n        return True\n\n    def render(self, img):\n        c = COLORS[self.color]\n\n        if self.state == self.states.open:\n            fill_coords(img, point_in_rect(0.88, 1.00, 0.00, 1.00), c)\n            fill_coords(img, point_in_rect(0.92, 0.96, 0.04, 0.96), (0, 0, 0))\n            return\n\n        # Door frame and door\n        if self.state == self.states.locked:\n            fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)\n            fill_coords(img, point_in_rect(0.06, 0.94, 0.06, 0.94), 0.45 * np.array(c))\n\n            # Draw key slot\n            fill_coords(img, point_in_rect(0.52, 0.75, 0.50, 0.56), c)\n        else:\n            fill_coords(img, point_in_rect(0.00, 1.00, 0.00, 1.00), c)\n            fill_coords(img, point_in_rect(0.04, 0.96, 0.04, 0.96), (0, 0, 0))\n            fill_coords(img, point_in_rect(0.08, 0.92, 0.08, 0.92), c)\n            fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), (0, 
0, 0))\n\n            # Draw door handle\n            fill_coords(img, point_in_circle(cx=0.75, cy=0.50, r=0.08), c)\n\n\nclass Box(WorldObj):\n    def __init__(self, color=0, state=0, contains=None):\n        super().__init__(color, state)\n        self.contains = contains\n\n    def can_pickup(self):\n        return True\n\n    def toggle(self):\n        raise NotImplementedError\n\n    def str_render(self, dir=0):\n        return \"BB\"\n\n    def render(self, img):\n        c = COLORS[self.color]\n\n        # Outline\n        fill_coords(img, point_in_rect(0.12, 0.88, 0.12, 0.88), c)\n        fill_coords(img, point_in_rect(0.18, 0.82, 0.18, 0.82), (0, 0, 0))\n\n        # Horizontal slit\n        fill_coords(img, point_in_rect(0.16, 0.84, 0.47, 0.53), c)\n"
  },
  {
    "path": "marlgrid/rendering.py",
    "content": "import pyglet\nfrom pyglet.gl import *\nimport sys\n\nclass SimpleImageViewer(object):\n    def __init__(self, display=None, caption=None, maxwidth=500):\n        self.window = None\n        self.isopen = False\n        self.display = display\n        self.maxwidth = maxwidth\n        self.caption = caption\n\n    def imshow(self, arr):\n        if self.window is None:\n            height, width, _channels = arr.shape\n            if width > self.maxwidth:\n                scale = self.maxwidth / width\n                width = int(scale * width)\n                height = int(scale * height)\n            self.window = pyglet.window.Window(width=width, height=height,\n                display=self.display, vsync=False, resizable=True, caption=self.caption)\n            self.width = width\n            self.height = height\n            self.isopen = True\n\n            @self.window.event\n            def on_resize(width, height):\n                self.width = width\n                self.height = height\n\n            @self.window.event\n            def on_close():\n                self.isopen = False\n\n        assert len(arr.shape) == 3, \"You passed in an image with the wrong number shape\"\n\n        image = pyglet.image.ImageData(arr.shape[1], arr.shape[0],\n            'RGB', arr.tobytes(), pitch=arr.shape[1]*-3)\n        gl.glTexParameteri(gl.GL_TEXTURE_2D,\n            gl.GL_TEXTURE_MAG_FILTER, gl.GL_NEAREST)\n        texture = image.get_texture()\n\n        aspect_ratio = arr.shape[1]/arr.shape[0]\n        forced_width = min(self.width, self.height * aspect_ratio)\n        texture.height = int(forced_width / aspect_ratio)\n        texture.width = int(forced_width)\n\n        self.window.clear()\n        self.window.switch_to()\n        self.window.dispatch_events()\n        texture.blit(0, 0) # draw\n        self.window.flip()\n        \n    def close(self):\n        if self.isopen and sys.meta_path:\n            # ^^^ check sys.meta_path to 
avoid 'ImportError: sys.meta_path is None, Python is likely shutting down'\n            self.window.close()\n            self.isopen = False\n\n    def __del__(self):\n        self.close()\n\n\nclass InteractivePlayerWindow(SimpleImageViewer):\n    def __init__(self, display=None, caption=None, maxwidth=500):\n        super().__init__(display=display, caption=caption, maxwidth=maxwidth)\n        self.key = None\n        self.action_count = 0\n\n        self.action_map = {\n            pyglet.window.key._0:0,\n            pyglet.window.key._1:1,\n            pyglet.window.key._2:2,\n            pyglet.window.key._3:3,\n            pyglet.window.key._4:4,\n            pyglet.window.key._5:5,\n            pyglet.window.key._6:6,\n            pyglet.window.key.LEFT:0,\n            pyglet.window.key.RIGHT:1,\n            pyglet.window.key.UP:2,\n            # pyglet.window.key.Q:-1,\n        }\n\n    def get_action(self, obs):\n        if self.window is None:\n            self.imshow(obs)\n\n            @self.window.event\n            def on_key_press(symbol, modifiers):\n                self.key = symbol\n\n            return self.get_action(obs)\n    \n        self.imshow(obs)\n        self.key = None\n        while self.key not in self.action_map:\n            self.window.dispatch_events()\n            pyglet.clock.tick()\n\n        return self.action_map[self.key]"
  },
  {
    "path": "marlgrid/utils/__init__.py",
    "content": ""
  },
  {
    "path": "marlgrid/utils/video.py",
    "content": "import gym\nimport numpy as np\nimport os\nimport tqdm\n\n\ndef export_video(X, outfile, fps=30, rescale_factor=2):\n\n    try:\n        import moviepy.editor as mpy\n    except:\n        raise ImportError(\n            \"GridRecorder requires moviepy library. Try installing:\\n $ pip install moviepy\"\n        )\n\n    if isinstance(X, list):\n        X = np.stack(X)\n\n    if isinstance(X, np.float) and X.max() < 1:\n        X = (X * 255).astype(np.uint8).clip(0, 255)\n\n    if rescale_factor is not None and rescale_factor != 1:\n        X = np.kron(X, np.ones((1, rescale_factor, rescale_factor, 1)))\n\n    def make_frame(i):\n        out = X[i]\n        return out\n\n    getframe = lambda t: make_frame(min(int(t * fps), len(X) - 1))\n    clip = mpy.VideoClip(getframe, duration=len(X) / fps)\n\n    outfile = os.path.abspath(os.path.expanduser(outfile))\n    if not os.path.isdir(os.path.dirname(outfile)):\n        os.makedirs(os.path.dirname(outfile))\n    clip.write_videofile(outfile, fps=fps)\n\n\ndef render_frames(X, path, ext=\"png\"):\n    try:\n        from PIL import Image\n    except ImportError as e:\n        raise ImportError(\n            \"Error importing from PIL in export_frames. 
Try installing PIL:\\n $ pip install Pillow\"\n        )\n\n    # If the path has a file extension, dump frames in a new directory with = path minus extension\n    if \".\" in os.path.basename(path):\n        path = os.path.splitext(path)[0]\n    if not os.path.isdir(path):\n        os.makedirs(path)\n\n    for k, frame in tqdm.tqdm(enumerate(X), total=len(X)):\n        Image.fromarray(frame, \"RGB\").save(os.path.join(path, f\"frame_{k}.{ext}\"))\n\nclass GridRecorder(gym.core.Wrapper):\n    default_max_len = 1000\n    default_video_kwargs = {\n        'fps': 20,\n        'rescale_factor': 1,\n    }\n    def __init__(\n            self,\n            env,\n            save_root,\n            max_steps=1000,\n            auto_save_images=True,\n            auto_save_videos=True,\n            auto_save_interval=None,\n            render_kwargs={},\n            video_kwargs={}\n            ):\n        super().__init__(env)\n\n        self.frames = None\n        self.ptr = 0\n        self.reset_count = 0\n        self.last_save = -10000\n        self.recording = False\n        self.save_root = self.fix_path(save_root)\n        self.auto_save_videos = auto_save_videos\n        self.auto_save_images = auto_save_images\n        self.auto_save_interval = auto_save_interval\n        self.render_kwargs = render_kwargs\n        self.video_kwargs = {**self.default_video_kwargs, **video_kwargs}\n        self.n_parallel = getattr(env, 'num_envs', 1)\n\n        if max_steps is None:\n            if hasattr(env, \"max_steps\") and env.max_steps != 0:\n                self.max_steps = env.max_steps + 1\n            else:\n                self.max_steps = self.default_max_steps + 1\n        else:\n            self.max_steps = max_steps + 1\n    \n    @staticmethod\n    def fix_path(path):\n        return os.path.abspath(os.path.expanduser(path))\n\n    @property\n    def should_record(self):\n        if self.recording:\n            return True\n        if self.auto_save_interval is 
None:\n            return False\n        return (self.reset_count - self.last_save) >= self.auto_save_interval\n\n    def export_frames(self,  episode_id=None, save_root=None):\n        if save_root is None:\n            save_root = self.save_root\n        if episode_id is None:\n            episode_id = f'frames_{self.reset_count}'\n        render_frames(self.frames[:self.ptr], os.path.join(self.fix_path(save_root), episode_id))\n\n    def export_video(self, episode_id=None, save_root=None):\n        if save_root is None:\n            save_root = self.save_root\n        if episode_id is None:\n            episode_id = f'video_{self.reset_count}.mp4'\n        export_video(self.frames[:self.ptr],  os.path.join(self.fix_path(save_root), episode_id), **self.video_kwargs)\n\n    def export_both(self, episode_id, save_root=None):\n        self.export_frames(f'{episode_id}_frames', save_root=save_root)\n        self.export_video(f'{episode_id}.mp4', save_root=save_root)\n\n    def reset(self, **kwargs):\n        if self.should_record and self.ptr>0:\n            self.append_current_frame()\n            if self.auto_save_images:\n                self.export_frames()\n            if self.auto_save_videos:\n                self.export_video()\n            self.last_save = self.reset_count\n        del self.frames\n        self.frames = None\n        self.ptr = 0\n        self.reset_count += self.n_parallel\n        return self.env.reset(**kwargs)\n\n    def append_current_frame(self):\n        if self.should_record:\n            new_frame = self.env.render(mode=\"rgb_array\", **self.render_kwargs)\n            if isinstance(new_frame, list) or len(new_frame.shape)>3:\n                new_frame = new_frame[0]\n            if self.frames is None:\n                self.frames = np.zeros(\n                    (self.max_steps, *new_frame.shape), dtype=new_frame.dtype\n                )\n            self.frames[self.ptr] = new_frame\n            self.ptr += 1\n\n    def 
step(self, action):\n        self.append_current_frame()\n        obs, rew, done, info = self.env.step(action)\n        return obs, rew, done, info\n\n    # def export_video(\n    #     self,\n    #     output_path,\n    #     fps=20,\n    #     rescale_factor=1,\n    #     render_last=True,\n    #     render_frame_images=True,\n    #     **kwargs,\n    # ):\n    #     if self.should_record:\n    #         if render_last:\n    #             self.frames[self.ptr] = self.env.render(\n    #                 mode=\"rgb_array\", **self.render_kwargs\n    #             )\n    #         if render_frame_images:\n    #             render_frames(self.frames[: self.ptr + 1], output_path)\n    #         return export_video(\n    #             self.frames[: self.ptr + 1],\n    #             output_path,\n    #             fps=fps,\n    #             rescale_factor=rescale_factor,\n    #             **kwargs,\n    #         )\n"
  },
  {
    "path": "setup.py",
    "content": "from setuptools import setup, find_packages\n\nsetup(\n    name=\"marlgrid\",\n    version=\"0.0.5\",\n    packages=find_packages(),\n    install_requires=[\"numpy\", \"tqdm\", \"gym\", \"gym-minigrid\", \"numba\"],\n)\n"
  }
]