[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Pythond\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\npip-wheel-metadata/\nshare/python-wheels/\n*.egg-info/\n.installed.cfg\n*.egg\nMANIFEST\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.nox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n.pytest_cache/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\ndb.sqlite3\ndb.sqlite3-journal\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# IPython\nprofile_default/\nipython_config.py\n\n# pyenv\n.python-version\n\n# pipenv\n#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.\n#   However, in case of collaboration, if having platform-specific dependencies or dependencies\n#   having no cross-platform support, pipenv may install dependencies that don't work, or not\n#   install all needed dependencies.\n#Pipfile.lock\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# Environments\n.env\n.venv\nenv/\nvenv/\nENV/\nenv.bak/\nvenv.bak/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n.dmypy.json\ndmypy.json\n\n# Pyre type checker\n.pyre/\n\n# Pycharm\n.idea/\n"
  },
  {
    "path": "README.rst",
    "content": ".. |copy| unicode:: 0xA9\n.. |---| unicode:: U+02014\n\n======\nMuZero\n======\n\nThis repository is a Python implementation of the MuZero algorithm.\nIt is based upon the `pre-print paper`__ and the `pseudocode`__ describing the Muzero framework.\nNeural computations are implemented with Tensorflow.\n\nYou can easily train your own MuZero, more specifically for one player and non-image based environments (such as `CartPole`__).\nIf you wish to train Muzero on other kinds of environments, this codebase can be used with slight modifications.\n\n__ https://arxiv.org/abs/1911.08265\n__ https://arxiv.org/src/1911.08265v1/anc/pseudocode.py\n__ https://gym.openai.com/envs/CartPole-v1/\n\n\n**DISCLAIMER**: this code is early research code. What this means is:\n\n- Silent bugs may exist.\n- It may not work reliably on other environments or with other hyper-parameters.\n- The code quality and documentation are quite lacking, and much of the code might still feel \"in-progress\".\n- The training and testing pipeline is not very advanced.\n\nDependencies\n============\n\nWe run this code using:\n\n- Conda **4.7.12**\n- Python **3.7**\n- Tensorflow **2.0.0**\n- Numpy **1.17.3**\n\nTraining your MuZero\n====================\n\nThis code must be run from the main function in ``muzero.py`` (don't forget to first configure your conda environment).\n\nTraining a Cartpole-v1 bot\n--------------------------\n\nTo train a model, please follow these steps:\n\n1) Create or modify an existing configuration of Muzero in ``config.py``.\n\n2) Call the right configuration inside the main of ``muzero.py``.\n\n3) Run the main function: ``python muzero.py``.\n\nTraining on an other environment\n--------------------------------\n\nTo train on a different environment than Cartpole-v1, please follow these additional steps:\n\n1) Create a class that extends ``AbstractGame``, this class should implement the behavior of your environment.\nFor instance, the ``CartPole`` class extends ``AbstractGame`` and works as a wrapper upon `gym CartPole-v1`__.\nYou can use the ``CartPole`` class as a template for any gym environment.\n\n__ https://gym.openai.com/envs/CartPole-v1/\n\n2) **This step is optional** (only if you want to use a different kind of network architecture or value/reward transform).\nCreate a class that extends ``BaseNetwork``, this class should implement the different networks (representation, value, policy, reward and dynamic) and value/reward transforms.\nFor instance, the ``CartPoleNetwork`` class extends ``BaseNetwork`` and implements fully connected networks.\n\n3) **This step is optional** (only if you use a different value/reward transform).\nYou should implement the corresponding inverse value/reward transform by modifying the ``loss_value`` and ``loss_reward`` function inside ``training.py``.\n\nDifferences from the paper\n==========================\n\nThis implementation differ from the original paper in the following manners:\n\n- We use fully connected layers instead of convolutional ones. This is due to the nature of our environment (Cartpole-v1) which as no spatial correlation in the observation vector.\n- We don't scale the hidden state between 0 and 1 using min-max normalization. 
Instead we use a tanh function that maps any values in a range between -1 and 1.\n- We do use a slightly simple invertible transform for the value prediction by removing the linear term.\n- During training, samples are drawn from a uniform distribution instead of using prioritized replay.\n- We also scale the loss of each head by 1/K (with K the number of unrolled steps). But, instead we consider that K is always constant (even if it is not always true).\n"
  },
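  {
    "path": "muzero/examples/custom_game_example.py",
    "content": "\"\"\"Illustrative sketch, not part of the original training pipeline.\n\nThis file sketches how the ``AbstractGame`` interface described in the README\n(section 'Training on another environment', step 1) could be implemented for\nanother gym environment, mirroring the ``CartPole`` wrapper. The environment id\n'MountainCar-v0' and its observation bounds are assumptions used purely for\nillustration; because MountainCar returns negative values, it would also need a\ncustom network and value transform (README steps 2 and 3), since\n``CartPoleNetwork`` assumes non-negative values.\n\nRun from the ``muzero/`` directory (e.g. ``python -m examples.custom_game_example``)\nso that the ``game`` package is importable.\n\"\"\"\nfrom typing import List\n\nimport gym\n\nfrom game.game import Action, AbstractGame\nfrom game.gym_wrappers import ScalingObservationWrapper\n\n\nclass MountainCar(AbstractGame):\n    \"\"\"Hypothetical wrapper around the gym MountainCar-v0 environment.\"\"\"\n\n    def __init__(self, discount: float):\n        super().__init__(discount)\n        self.env = gym.make('MountainCar-v0')\n        # Observation bounds of MountainCar-v0: position in [-1.2, 0.6], velocity in [-0.07, 0.07].\n        self.env = ScalingObservationWrapper(self.env, low=[-1.2, -0.07], high=[0.6, 0.07])\n        self.actions = list(map(lambda i: Action(i), range(self.env.action_space.n)))\n        self.observations = [self.env.reset()]\n        self.done = False\n\n    @property\n    def action_space_size(self) -> int:\n        \"\"\"Return the size of the action space.\"\"\"\n        return len(self.actions)\n\n    def step(self, action) -> int:\n        \"\"\"Execute one step of the game conditioned by the given action.\"\"\"\n        observation, reward, done, _ = self.env.step(action.index)\n        self.observations += [observation]\n        self.done = done\n        return reward\n\n    def terminal(self) -> bool:\n        \"\"\"Is the game finished?\"\"\"\n        return self.done\n\n    def legal_actions(self) -> List[Action]:\n        \"\"\"Return the legal actions available at this instant.\"\"\"\n        return self.actions\n\n    def make_image(self, state_index: int):\n        \"\"\"Compute the state of the game.\"\"\"\n        return self.observations[state_index]\n"
  },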
  {
    "path": "muzero/__init__.py",
    "content": ""
  },
  {
    "path": "muzero/config.py",
    "content": "import collections\nfrom typing import Optional, Dict\n\nimport tensorflow_core as tf\n\nfrom game.cartpole import CartPole\nfrom game.game import AbstractGame\nfrom networks.cartpole_network import CartPoleNetwork\nfrom networks.network import BaseNetwork, UniformNetwork\n\nKnownBounds = collections.namedtuple('KnownBounds', ['min', 'max'])\n\n\nclass MuZeroConfig(object):\n\n    def __init__(self,\n                 game,\n                 nb_training_loop: int,\n                 nb_episodes: int,\n                 nb_epochs: int,\n                 network_args: Dict,\n                 network,\n                 action_space_size: int,\n                 max_moves: int,\n                 discount: float,\n                 dirichlet_alpha: float,\n                 num_simulations: int,\n                 batch_size: int,\n                 td_steps: int,\n                 visit_softmax_temperature_fn,\n                 lr: float,\n                 known_bounds: Optional[KnownBounds] = None):\n        ### Environment\n        self.game = game\n\n        ### Self-Play\n        self.action_space_size = action_space_size\n        # self.num_actors = num_actors\n\n        self.visit_softmax_temperature_fn = visit_softmax_temperature_fn\n        self.max_moves = max_moves\n        self.num_simulations = num_simulations\n        self.discount = discount\n\n        # Root prior exploration noise.\n        self.root_dirichlet_alpha = dirichlet_alpha\n        self.root_exploration_fraction = 0.25\n\n        # UCB formula\n        self.pb_c_base = 19652\n        self.pb_c_init = 1.25\n\n        # If we already have some information about which values occur in the\n        # environment, we can use them to initialize the rescaling.\n        # This is not strictly necessary, but establishes identical behaviour to\n        # AlphaZero in board games.\n        self.known_bounds = known_bounds\n\n        ### Training\n        self.nb_training_loop = nb_training_loop\n        self.nb_episodes = nb_episodes  # Nb of episodes per training loop\n        self.nb_epochs = nb_epochs  # Nb of epochs per training loop\n\n        # self.training_steps = int(1000e3)\n        # self.checkpoint_interval = int(1e3)\n        self.window_size = int(1e6)\n        self.batch_size = batch_size\n        self.num_unroll_steps = 5\n        self.td_steps = td_steps\n\n        self.weight_decay = 1e-4\n        self.momentum = 0.9\n\n        self.network_args = network_args\n        self.network = network\n        self.lr = lr\n        # Exponential learning rate schedule\n        # self.lr_init = lr_init\n        # self.lr_decay_rate = 0.1\n        # self.lr_decay_steps = lr_decay_steps\n\n    def new_game(self) -> AbstractGame:\n        return self.game(self.discount)\n\n    def new_network(self) -> BaseNetwork:\n        return self.network(**self.network_args)\n\n    def uniform_network(self) -> UniformNetwork:\n        return UniformNetwork(self.action_space_size)\n\n    def new_optimizer(self) -> tf.keras.optimizers:\n        return tf.keras.optimizers.SGD(learning_rate=self.lr, momentum=self.momentum)\n\n\ndef make_cartpole_config() -> MuZeroConfig:\n    def visit_softmax_temperature(num_moves, training_steps):\n        return 1.0\n\n    return MuZeroConfig(\n        game=CartPole,\n        nb_training_loop=50,\n        nb_episodes=20,\n        nb_epochs=20,\n        network_args={'action_size': 2,\n                      'state_size': 4,\n                      'representation_size': 4,\n                      
'max_value': 500},\n        network=CartPoleNetwork,\n        action_space_size=2,\n        max_moves=1000,\n        discount=0.99,\n        dirichlet_alpha=0.25,\n        num_simulations=11,  # Odd number perform better in eval mode\n        batch_size=512,\n        td_steps=10,\n        visit_softmax_temperature_fn=visit_softmax_temperature,\n        lr=0.05)\n\n\n\"\"\"\nLegacy configs from the DeepMind's pseudocode.\n\ndef make_board_game_config(action_space_size: int, max_moves: int,\n                           dirichlet_alpha: float,\n                           lr_init: float) -> MuZeroConfig:\n    def visit_softmax_temperature(num_moves, training_steps):\n        if num_moves < 30:\n            return 1.0\n        else:\n            return 0.0  # Play according to the max.\n\n    return MuZeroConfig(\n        action_space_size=action_space_size,\n        max_moves=max_moves,\n        discount=1.0,\n        dirichlet_alpha=dirichlet_alpha,\n        num_simulations=800,\n        batch_size=2048,\n        td_steps=max_moves,  # Always use Monte Carlo return.\n        num_actors=3000,\n        lr_init=lr_init,\n        lr_decay_steps=400e3,\n        visit_softmax_temperature_fn=visit_softmax_temperature,\n        known_bounds=KnownBounds(-1, 1))\n\n\ndef make_go_config() -> MuZeroConfig:\n    return make_board_game_config(\n        action_space_size=362, max_moves=722, dirichlet_alpha=0.03, lr_init=0.01)\n\n\ndef make_chess_config() -> MuZeroConfig:\n    return make_board_game_config(\n        action_space_size=4672, max_moves=512, dirichlet_alpha=0.3, lr_init=0.1)\n\n\ndef make_shogi_config() -> MuZeroConfig:\n    return make_board_game_config(\n        action_space_size=11259, max_moves=512, dirichlet_alpha=0.15, lr_init=0.1)\n\n\ndef make_atari_config() -> MuZeroConfig:\n    def visit_softmax_temperature(num_moves, training_steps):\n        if training_steps < 500e3:\n            return 1.0\n        elif training_steps < 750e3:\n            return 0.5\n        else:\n            return 0.25\n\n    return MuZeroConfig(\n        action_space_size=18,\n        max_moves=27000,  # Half an hour at action repeat 4.\n        discount=0.997,\n        dirichlet_alpha=0.25,\n        num_simulations=50,\n        batch_size=1024,\n        td_steps=10,\n        num_actors=350,\n        lr_init=0.05,\n        lr_decay_steps=350e3,\n        visit_softmax_temperature_fn=visit_softmax_temperature)\n\"\"\"\n"
  },
  {
    "path": "muzero/game/__init__.py",
    "content": ""
  },
  {
    "path": "muzero/game/cartpole.py",
    "content": "from typing import List\n\nimport gym\n\nfrom game.game import Action, AbstractGame\nfrom game.gym_wrappers import ScalingObservationWrapper\n\n\nclass CartPole(AbstractGame):\n    \"\"\"The Gym CartPole environment\"\"\"\n\n    def __init__(self, discount: float):\n        super().__init__(discount)\n        self.env = gym.make('CartPole-v1')\n        self.env = ScalingObservationWrapper(self.env, low=[-2.4, -2.0, -0.42, -3.5], high=[2.4, 2.0, 0.42, 3.5])\n        self.actions = list(map(lambda i: Action(i), range(self.env.action_space.n)))\n        self.observations = [self.env.reset()]\n        self.done = False\n\n    @property\n    def action_space_size(self) -> int:\n        \"\"\"Return the size of the action space.\"\"\"\n        return len(self.actions)\n\n    def step(self, action) -> int:\n        \"\"\"Execute one step of the game conditioned by the given action.\"\"\"\n\n        observation, reward, done, _ = self.env.step(action.index)\n        self.observations += [observation]\n        self.done = done\n        return reward\n\n    def terminal(self) -> bool:\n        \"\"\"Is the game is finished?\"\"\"\n        return self.done\n\n    def legal_actions(self) -> List[Action]:\n        \"\"\"Return the legal actions available at this instant.\"\"\"\n        return self.actions\n\n    def make_image(self, state_index: int):\n        \"\"\"Compute the state of the game.\"\"\"\n        return self.observations[state_index]\n"
  },
  {
    "path": "muzero/game/game.py",
    "content": "from abc import abstractmethod, ABC\nfrom typing import List\n\nfrom self_play.utils import Node\n\n\nclass Action(object):\n    \"\"\" Class that represent an action of a game.\"\"\"\n\n    def __init__(self, index: int):\n        self.index = index\n\n    def __hash__(self):\n        return self.index\n\n    def __eq__(self, other):\n        return self.index == other.index\n\n    def __gt__(self, other):\n        return self.index > other.index\n\n\nclass Player(object):\n    \"\"\"\n    A one player class.\n    This class is useless, it's here for legacy purpose and for potential adaptations for a two players MuZero.\n    \"\"\"\n\n    def __eq__(self, other):\n        return True\n\n\nclass ActionHistory(object):\n    \"\"\"\n    Simple history container used inside the search.\n    Only used to keep track of the actions executed.\n    \"\"\"\n\n    def __init__(self, history: List[Action], action_space_size: int):\n        self.history = list(history)\n        self.action_space_size = action_space_size\n\n    def clone(self):\n        return ActionHistory(self.history, self.action_space_size)\n\n    def add_action(self, action: Action):\n        self.history.append(action)\n\n    def last_action(self) -> Action:\n        return self.history[-1]\n\n    def action_space(self) -> List[Action]:\n        return [Action(i) for i in range(self.action_space_size)]\n\n    def to_play(self) -> Player:\n        return Player()\n\n\nclass AbstractGame(ABC):\n    \"\"\"\n    Abstract class that allows to implement a game.\n    One instance represent a single episode of interaction with the environment.\n    \"\"\"\n\n    def __init__(self, discount: float):\n        self.history = []\n        self.rewards = []\n        self.child_visits = []\n        self.root_values = []\n        self.discount = discount\n\n    def apply(self, action: Action):\n        \"\"\"Apply an action onto the environment.\"\"\"\n\n        reward = self.step(action)\n        self.rewards.append(reward)\n        self.history.append(action)\n\n    def store_search_statistics(self, root: Node):\n        \"\"\"After each MCTS run, store the statistics generated by the search.\"\"\"\n\n        sum_visits = sum(child.visit_count for child in root.children.values())\n        action_space = (Action(index) for index in range(self.action_space_size))\n        self.child_visits.append([\n            root.children[a].visit_count / sum_visits if a in root.children else 0\n            for a in action_space\n        ])\n        self.root_values.append(root.value())\n\n    def make_target(self, state_index: int, num_unroll_steps: int, td_steps: int, to_play: Player):\n        \"\"\"Generate targets to learn from during the network training.\"\"\"\n\n        # The value target is the discounted root value of the search tree N steps\n        # into the future, plus the discounted sum of all rewards until then.\n        targets = []\n        for current_index in range(state_index, state_index + num_unroll_steps + 1):\n            bootstrap_index = current_index + td_steps\n            if bootstrap_index < len(self.root_values):\n                value = self.root_values[bootstrap_index] * self.discount ** td_steps\n            else:\n                value = 0\n\n            for i, reward in enumerate(self.rewards[current_index:bootstrap_index]):\n                value += reward * self.discount ** i\n\n            if current_index < len(self.root_values):\n                targets.append((value, self.rewards[current_index], 
self.child_visits[current_index]))\n            else:\n                # States past the end of games are treated as absorbing states.\n                targets.append((0, 0, []))\n        return targets\n\n    def to_play(self) -> Player:\n        \"\"\"Return the current player.\"\"\"\n        return Player()\n\n    def action_history(self) -> ActionHistory:\n        \"\"\"Return the actions executed inside the search.\"\"\"\n        return ActionHistory(self.history, self.action_space_size)\n\n    # Methods to be implemented by the children class\n    @property\n    @abstractmethod\n    def action_space_size(self) -> int:\n        \"\"\"Return the size of the action space.\"\"\"\n        pass\n\n    @abstractmethod\n    def step(self, action) -> int:\n        \"\"\"Execute one step of the game conditioned by the given action.\"\"\"\n        pass\n\n    @abstractmethod\n    def terminal(self) -> bool:\n        \"\"\"Is the game is finished?\"\"\"\n        pass\n\n    @abstractmethod\n    def legal_actions(self) -> List[Action]:\n        \"\"\"Return the legal actions available at this instant.\"\"\"\n        pass\n\n    @abstractmethod\n    def make_image(self, state_index: int):\n        \"\"\"Compute the state of the game.\"\"\"\n        pass\n"
  },
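  {
    "path": "muzero/examples/make_target_demo.py",
    "content": "\"\"\"Illustrative sketch, not part of the original training pipeline.\n\nA small worked example of ``AbstractGame.make_target``: the value target is the\ndiscounted root value of the search tree ``td_steps`` into the future, plus the\ndiscounted sum of the rewards collected until then. The rewards, root values,\nchild visits and hyper-parameters below are made-up numbers chosen only to make\nthe arithmetic easy to follow.\n\nRun from the ``muzero/`` directory (e.g. ``python -m examples.make_target_demo``)\nso that the ``game`` package is importable.\n\"\"\"\nfrom typing import List\n\nfrom game.game import AbstractGame, Action, Player\n\n\nclass _DummyGame(AbstractGame):\n    \"\"\"Minimal concrete game, used only to call make_target on hand-made data.\"\"\"\n\n    @property\n    def action_space_size(self) -> int:\n        return 2\n\n    def step(self, action) -> int:\n        return 0\n\n    def terminal(self) -> bool:\n        return True\n\n    def legal_actions(self) -> List[Action]:\n        return [Action(0), Action(1)]\n\n    def make_image(self, state_index: int):\n        return None\n\n\nif __name__ == '__main__':\n    game = _DummyGame(discount=0.9)\n    # Pretend the self-play episode lasted 6 moves with a reward of 1 at every step.\n    game.rewards = [1.0] * 6\n    game.root_values = [5.0] * 6\n    game.child_visits = [[0.5, 0.5]] * 6\n\n    # Targets for state 0, unrolled over 2 steps, bootstrapping after td_steps=3:\n    # value(0) = 5.0 * 0.9**3 + 1.0 + 1.0 * 0.9 + 1.0 * 0.9**2 = 6.355\n    targets = game.make_target(state_index=0, num_unroll_steps=2, td_steps=3, to_play=Player())\n    for step, (value, reward, policy) in enumerate(targets):\n        print(f\"unroll step {step}: value target={value:.3f}, reward target={reward}, policy target={policy}\")\n"
  },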
  {
    "path": "muzero/game/gym_wrappers.py",
    "content": "import gym\nimport numpy as np\n\n\nclass ScalingObservationWrapper(gym.ObservationWrapper):\n    \"\"\"\n    Wrapper that apply a min-max scaling of observations.\n    \"\"\"\n\n    def __init__(self, env, low=None, high=None):\n        super().__init__(env)\n        assert isinstance(env.observation_space, gym.spaces.Box)\n\n        low = np.array(self.observation_space.low if low is None else low)\n        high = np.array(self.observation_space.high if high is None else high)\n\n        self.mean = (high + low) / 2\n        self.max = high - self.mean\n\n    def observation(self, observation):\n        return (observation - self.mean) / self.max\n"
  },
  {
    "path": "muzero/muzero.py",
    "content": "from config import MuZeroConfig, make_cartpole_config\nfrom networks.shared_storage import SharedStorage\nfrom self_play.self_play import run_selfplay, run_eval\nfrom training.replay_buffer import ReplayBuffer\nfrom training.training import train_network\n\n\ndef muzero(config: MuZeroConfig):\n    \"\"\"\n    MuZero training is split into two independent parts: Network training and\n    self-play data generation.\n    These two parts only communicate by transferring the latest networks checkpoint\n    from the training to the self-play, and the finished games from the self-play\n    to the training.\n    In contrast to the original MuZero algorithm this version doesn't works with\n    multiple threads, therefore the training and self-play is done alternately.\n    \"\"\"\n    storage = SharedStorage(config.new_network(), config.uniform_network(), config.new_optimizer())\n    replay_buffer = ReplayBuffer(config)\n\n    for loop in range(config.nb_training_loop):\n        print(\"Training loop\", loop)\n        score_train = run_selfplay(config, storage, replay_buffer, config.nb_episodes)\n        train_network(config, storage, replay_buffer, config.nb_epochs)\n\n        print(\"Train score:\", score_train)\n        print(\"Eval score:\", run_eval(config, storage, 50))\n        print(f\"MuZero played {config.nb_episodes * (loop + 1)} \"\n              f\"episodes and trained for {config.nb_epochs * (loop + 1)} epochs.\\n\")\n\n    return storage.latest_network()\n\n\nif __name__ == '__main__':\n    config = make_cartpole_config()\n    muzero(config)\n"
  },
  {
    "path": "muzero/networks/__init__.py",
    "content": ""
  },
  {
    "path": "muzero/networks/cartpole_network.py",
    "content": "import math\n\nimport numpy as np\nfrom tensorflow_core.python.keras import regularizers\nfrom tensorflow_core.python.keras.layers.core import Dense\nfrom tensorflow_core.python.keras.models import Sequential\n\nfrom game.game import Action\nfrom networks.network import BaseNetwork\n\n\nclass CartPoleNetwork(BaseNetwork):\n\n    def __init__(self,\n                 state_size: int,\n                 action_size: int,\n                 representation_size: int,\n                 max_value: int,\n                 hidden_neurons: int = 64,\n                 weight_decay: float = 1e-4,\n                 representation_activation: str = 'tanh'):\n        self.state_size = state_size\n        self.action_size = action_size\n        self.value_support_size = math.ceil(math.sqrt(max_value)) + 1\n\n        regularizer = regularizers.l2(weight_decay)\n        representation_network = Sequential([Dense(hidden_neurons, activation='relu', kernel_regularizer=regularizer),\n                                             Dense(representation_size, activation=representation_activation,\n                                                   kernel_regularizer=regularizer)])\n        value_network = Sequential([Dense(hidden_neurons, activation='relu', kernel_regularizer=regularizer),\n                                    Dense(self.value_support_size, kernel_regularizer=regularizer)])\n        policy_network = Sequential([Dense(hidden_neurons, activation='relu', kernel_regularizer=regularizer),\n                                     Dense(action_size, kernel_regularizer=regularizer)])\n        dynamic_network = Sequential([Dense(hidden_neurons, activation='relu', kernel_regularizer=regularizer),\n                                      Dense(representation_size, activation=representation_activation,\n                                            kernel_regularizer=regularizer)])\n        reward_network = Sequential([Dense(16, activation='relu', kernel_regularizer=regularizer),\n                                     Dense(1, kernel_regularizer=regularizer)])\n\n        super().__init__(representation_network, value_network, policy_network, dynamic_network, reward_network)\n\n    def _value_transform(self, value_support: np.array) -> float:\n        \"\"\"\n        The value is obtained by first computing the expected value from the discrete support.\n        Second, the inverse transform is then apply (the square function).\n        \"\"\"\n\n        value = self._softmax(value_support)\n        value = np.dot(value, range(self.value_support_size))\n        value = np.asscalar(value) ** 2\n        return value\n\n    def _reward_transform(self, reward: np.array) -> float:\n        return np.asscalar(reward)\n\n    def _conditioned_hidden_state(self, hidden_state: np.array, action: Action) -> np.array:\n        conditioned_hidden = np.concatenate((hidden_state, np.eye(self.action_size)[action.index]))\n        return np.expand_dims(conditioned_hidden, axis=0)\n\n    def _softmax(self, values):\n        \"\"\"Compute softmax using numerical stability tricks.\"\"\"\n        values_exp = np.exp(values - np.max(values))\n        return values_exp / np.sum(values_exp)\n"
  },
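  {
    "path": "muzero/examples/value_transform_demo.py",
    "content": "\"\"\"Illustrative sketch, not part of the original training pipeline.\n\nThis demo walks through the invertible value transform mentioned in the README\n('a slightly simpler invertible transform ... by removing the linear term').\nUsing plain numpy, it mirrors the support encoding used by ``loss_value`` in\n``training.py`` (a scalar is spread over two adjacent bins of a sqrt-sized\nsupport) and the decoding step of ``CartPoleNetwork._value_transform``\n(expectation over the support, then squaring). The scalar 7.3 below is an\narbitrary number chosen for illustration.\n\"\"\"\nimport math\n\nimport numpy as np\n\n\ndef encode_value(value: float, support_size: int) -> np.ndarray:\n    \"\"\"Spread sqrt(value) over the two nearest integer bins (mirrors loss_value).\"\"\"\n    target = np.zeros(support_size)\n    sqrt_value = math.sqrt(value)\n    floor_value = int(math.floor(sqrt_value))\n    rest = sqrt_value - floor_value\n    target[floor_value] = 1 - rest\n    target[floor_value + 1] = rest\n    return target\n\n\ndef decode_value(support: np.ndarray) -> float:\n    \"\"\"Expected bin index, then the inverse transform, i.e. the square function.\"\"\"\n    expectation = float(np.dot(support, range(len(support))))\n    return expectation ** 2\n\n\nif __name__ == '__main__':\n    max_value = 500\n    support_size = math.ceil(math.sqrt(max_value)) + 1  # same sizing rule as CartPoleNetwork\n\n    value = 7.3\n    support = encode_value(value, support_size)\n    recovered = decode_value(support)\n\n    print(\"support size:\", support_size)\n    print(\"non-zero bins:\", {i: round(p, 3) for i, p in enumerate(support) if p > 0})\n    print(\"original value:\", value, \"-> recovered value:\", round(recovered, 6))\n"
  },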
  {
    "path": "muzero/networks/network.py",
    "content": "import typing\nfrom abc import ABC, abstractmethod\nfrom typing import Dict, List, Callable\n\nimport numpy as np\nfrom tensorflow_core.python.keras.models import Model\n\nfrom game.game import Action\n\n\nclass NetworkOutput(typing.NamedTuple):\n    value: float\n    reward: float\n    policy_logits: Dict[Action, float]\n    hidden_state: typing.Optional[List[float]]\n\n    @staticmethod\n    def build_policy_logits(policy_logits):\n        return {Action(i): logit for i, logit in enumerate(policy_logits[0])}\n\n\nclass AbstractNetwork(ABC):\n\n    def __init__(self):\n        self.training_steps = 0\n\n    @abstractmethod\n    def initial_inference(self, image) -> NetworkOutput:\n        pass\n\n    @abstractmethod\n    def recurrent_inference(self, hidden_state, action) -> NetworkOutput:\n        pass\n\n\nclass UniformNetwork(AbstractNetwork):\n    \"\"\"policy -> uniform, value -> 0, reward -> 0\"\"\"\n\n    def __init__(self, action_size: int):\n        super().__init__()\n        self.action_size = action_size\n\n    def initial_inference(self, image) -> NetworkOutput:\n        return NetworkOutput(0, 0, {Action(i): 1 / self.action_size for i in range(self.action_size)}, None)\n\n    def recurrent_inference(self, hidden_state, action) -> NetworkOutput:\n        return NetworkOutput(0, 0, {Action(i): 1 / self.action_size for i in range(self.action_size)}, None)\n\n\nclass InitialModel(Model):\n    \"\"\"Model that combine the representation and prediction (value+policy) network.\"\"\"\n\n    def __init__(self, representation_network: Model, value_network: Model, policy_network: Model):\n        super(InitialModel, self).__init__()\n        self.representation_network = representation_network\n        self.value_network = value_network\n        self.policy_network = policy_network\n\n    def call(self, image):\n        hidden_representation = self.representation_network(image)\n        value = self.value_network(hidden_representation)\n        policy_logits = self.policy_network(hidden_representation)\n        return hidden_representation, value, policy_logits\n\n\nclass RecurrentModel(Model):\n    \"\"\"Model that combine the dynamic, reward and prediction (value+policy) network.\"\"\"\n\n    def __init__(self, dynamic_network: Model, reward_network: Model, value_network: Model, policy_network: Model):\n        super(RecurrentModel, self).__init__()\n        self.dynamic_network = dynamic_network\n        self.reward_network = reward_network\n        self.value_network = value_network\n        self.policy_network = policy_network\n\n    def call(self, conditioned_hidden):\n        hidden_representation = self.dynamic_network(conditioned_hidden)\n        reward = self.reward_network(conditioned_hidden)\n        value = self.value_network(hidden_representation)\n        policy_logits = self.policy_network(hidden_representation)\n        return hidden_representation, reward, value, policy_logits\n\n\nclass BaseNetwork(AbstractNetwork):\n    \"\"\"Base class that contains all the networks and models of MuZero.\"\"\"\n\n    def __init__(self, representation_network: Model, value_network: Model, policy_network: Model,\n                 dynamic_network: Model, reward_network: Model):\n        super().__init__()\n        # Networks blocks\n        self.representation_network = representation_network\n        self.value_network = value_network\n        self.policy_network = policy_network\n        self.dynamic_network = dynamic_network\n        self.reward_network = 
reward_network\n\n        # Models for inference and training\n        self.initial_model = InitialModel(self.representation_network, self.value_network, self.policy_network)\n        self.recurrent_model = RecurrentModel(self.dynamic_network, self.reward_network, self.value_network,\n                                              self.policy_network)\n\n    def initial_inference(self, image: np.array) -> NetworkOutput:\n        \"\"\"representation + prediction function\"\"\"\n\n        hidden_representation, value, policy_logits = self.initial_model.predict(np.expand_dims(image, 0))\n        output = NetworkOutput(value=self._value_transform(value),\n                               reward=0.,\n                               policy_logits=NetworkOutput.build_policy_logits(policy_logits),\n                               hidden_state=hidden_representation[0])\n        return output\n\n    def recurrent_inference(self, hidden_state: np.array, action: Action) -> NetworkOutput:\n        \"\"\"dynamics + prediction function\"\"\"\n\n        conditioned_hidden = self._conditioned_hidden_state(hidden_state, action)\n        hidden_representation, reward, value, policy_logits = self.recurrent_model.predict(conditioned_hidden)\n        output = NetworkOutput(value=self._value_transform(value),\n                               reward=self._reward_transform(reward),\n                               policy_logits=NetworkOutput.build_policy_logits(policy_logits),\n                               hidden_state=hidden_representation[0])\n        return output\n\n    @abstractmethod\n    def _value_transform(self, value: np.array) -> float:\n        pass\n\n    @abstractmethod\n    def _reward_transform(self, reward: np.array) -> float:\n        pass\n\n    @abstractmethod\n    def _conditioned_hidden_state(self, hidden_state: np.array, action: Action) -> np.array:\n        pass\n\n    def cb_get_variables(self) -> Callable:\n        \"\"\"Return a callback that return the trainable variables of the network.\"\"\"\n\n        def get_variables():\n            networks = (self.representation_network, self.value_network, self.policy_network,\n                        self.dynamic_network, self.reward_network)\n            return [variables\n                    for variables_list in map(lambda n: n.weights, networks)\n                    for variables in variables_list]\n\n        return get_variables\n"
  },
  {
    "path": "muzero/networks/shared_storage.py",
    "content": "import tensorflow_core as tf\n\nfrom networks.network import BaseNetwork, UniformNetwork, AbstractNetwork\n\n\nclass SharedStorage(object):\n    \"\"\"Save the different versions of the network.\"\"\"\n\n    def __init__(self, network: BaseNetwork, uniform_network: UniformNetwork, optimizer: tf.keras.optimizers):\n        self._networks = {}\n        self.current_network = network\n        self.uniform_network = uniform_network\n        self.optimizer = optimizer\n\n    def latest_network(self) -> AbstractNetwork:\n        if self._networks:\n            return self._networks[max(self._networks.keys())]\n        else:\n            # policy -> uniform, value -> 0, reward -> 0\n            return self.uniform_network\n\n    def save_network(self, step: int, network: BaseNetwork):\n        self._networks[step] = network\n"
  },
  {
    "path": "muzero/self_play/__init__.py",
    "content": ""
  },
  {
    "path": "muzero/self_play/mcts.py",
    "content": "\"\"\"MCTS module: where MuZero thinks inside the tree.\"\"\"\n\nimport math\nimport random\nfrom typing import List\n\nimport numpy\n\nfrom config import MuZeroConfig\nfrom game.game import Player, Action, ActionHistory\nfrom networks.network import NetworkOutput, BaseNetwork\nfrom self_play.utils import MinMaxStats, Node, softmax_sample\n\n\ndef add_exploration_noise(config: MuZeroConfig, node: Node):\n    \"\"\"\n    At the start of each search, we add dirichlet noise to the prior of the root\n    to encourage the search to explore new actions.\n    \"\"\"\n    actions = list(node.children.keys())\n    noise = numpy.random.dirichlet([config.root_dirichlet_alpha] * len(actions))\n    frac = config.root_exploration_fraction\n    for a, n in zip(actions, noise):\n        node.children[a].prior = node.children[a].prior * (1 - frac) + n * frac\n\n\ndef run_mcts(config: MuZeroConfig, root: Node, action_history: ActionHistory, network: BaseNetwork):\n    \"\"\"\n    Core Monte Carlo Tree Search algorithm.\n    To decide on an action, we run N simulations, always starting at the root of\n    the search tree and traversing the tree according to the UCB formula until we\n    reach a leaf node.\n    \"\"\"\n    min_max_stats = MinMaxStats(config.known_bounds)\n\n    for _ in range(config.num_simulations):\n        history = action_history.clone()\n        node = root\n        search_path = [node]\n\n        while node.expanded():\n            action, node = select_child(config, node, min_max_stats)\n            history.add_action(action)\n            search_path.append(node)\n\n        # Inside the search tree we use the dynamics function to obtain the next\n        # hidden state given an action and the previous hidden state.\n        parent = search_path[-2]\n        network_output = network.recurrent_inference(parent.hidden_state, history.last_action())\n        expand_node(node, history.to_play(), history.action_space(), network_output)\n\n        backpropagate(search_path, network_output.value, history.to_play(), config.discount, min_max_stats)\n\n\ndef select_child(config: MuZeroConfig, node: Node, min_max_stats: MinMaxStats):\n    \"\"\"\n    Select the child with the highest UCB score.\n    \"\"\"\n    # When the parent visit count is zero, all ucb scores are zeros, therefore we return a random child\n    if node.visit_count == 0:\n        return random.sample(node.children.items(), 1)[0]\n\n    _, action, child = max(\n        (ucb_score(config, node, child, min_max_stats), action,\n         child) for action, child in node.children.items())\n    return action, child\n\n\ndef ucb_score(config: MuZeroConfig, parent: Node, child: Node,\n              min_max_stats: MinMaxStats) -> float:\n    \"\"\"\n    The score for a node is based on its value, plus an exploration bonus based on\n    the prior.\n    \"\"\"\n    pb_c = math.log((parent.visit_count + config.pb_c_base + 1) / config.pb_c_base) + config.pb_c_init\n    pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)\n\n    prior_score = pb_c * child.prior\n    value_score = min_max_stats.normalize(child.value())\n    return prior_score + value_score\n\n\ndef expand_node(node: Node, to_play: Player, actions: List[Action],\n                network_output: NetworkOutput):\n    \"\"\"\n    We expand a node using the value, reward and policy prediction obtained from\n    the neural networks.\n    \"\"\"\n    node.to_play = to_play\n    node.hidden_state = network_output.hidden_state\n    node.reward = 
network_output.reward\n    policy = {a: math.exp(network_output.policy_logits[a]) for a in actions}\n    policy_sum = sum(policy.values())\n    for action, p in policy.items():\n        node.children[action] = Node(p / policy_sum)\n\n\ndef backpropagate(search_path: List[Node], value: float, to_play: Player,\n                  discount: float, min_max_stats: MinMaxStats):\n    \"\"\"\n    At the end of a simulation, we propagate the evaluation all the way up the\n    tree to the root.\n    \"\"\"\n    for node in search_path[::-1]:\n        node.value_sum += value if node.to_play == to_play else -value\n        node.visit_count += 1\n        min_max_stats.update(node.value())\n\n        value = node.reward + discount * value\n\n\ndef select_action(config: MuZeroConfig, num_moves: int, node: Node, network: BaseNetwork, mode: str = 'softmax'):\n    \"\"\"\n    After running simulations inside in MCTS, we select an action based on the root's children visit counts.\n    During training we use a softmax sample for exploration.\n    During evaluation we select the most visited child.\n    \"\"\"\n    visit_counts = [child.visit_count for child in node.children.values()]\n    actions = [action for action in node.children.keys()]\n    action = None\n    if mode == 'softmax':\n        t = config.visit_softmax_temperature_fn(\n            num_moves=num_moves, training_steps=network.training_steps)\n        action = softmax_sample(visit_counts, actions, t)\n    elif mode == 'max':\n        action, _ = max(node.children.items(), key=lambda item: item[1].visit_count)\n    return action\n"
  },
  {
    "path": "muzero/self_play/self_play.py",
    "content": "\"\"\"Self-Play module: where the games are played.\"\"\"\n\nfrom config import MuZeroConfig\nfrom game.game import AbstractGame\nfrom networks.network import AbstractNetwork\nfrom networks.shared_storage import SharedStorage\nfrom self_play.mcts import run_mcts, select_action, expand_node, add_exploration_noise\nfrom self_play.utils import Node\nfrom training.replay_buffer import ReplayBuffer\n\n\ndef run_selfplay(config: MuZeroConfig, storage: SharedStorage, replay_buffer: ReplayBuffer, train_episodes: int):\n    \"\"\"Take the latest network, produces multiple games and save them in the shared replay buffer\"\"\"\n    network = storage.latest_network()\n    returns = []\n    for _ in range(train_episodes):\n        game = play_game(config, network)\n        replay_buffer.save_game(game)\n        returns.append(sum(game.rewards))\n    return sum(returns) / train_episodes\n\n\ndef run_eval(config: MuZeroConfig, storage: SharedStorage, eval_episodes: int):\n    \"\"\"Evaluate MuZero without noise added to the prior of the root and without softmax action selection\"\"\"\n    network = storage.latest_network()\n    returns = []\n    for _ in range(eval_episodes):\n        game = play_game(config, network, train=False)\n        returns.append(sum(game.rewards))\n    return sum(returns) / eval_episodes if eval_episodes else 0\n\n\ndef play_game(config: MuZeroConfig, network: AbstractNetwork, train: bool = True) -> AbstractGame:\n    \"\"\"\n    Each game is produced by starting at the initial board position, then\n    repeatedly executing a Monte Carlo Tree Search to generate moves until the end\n    of the game is reached.\n    \"\"\"\n    game = config.new_game()\n    mode_action_select = 'softmax' if train else 'max'\n\n    while not game.terminal() and len(game.history) < config.max_moves:\n        # At the root of the search tree we use the representation function to\n        # obtain a hidden state given the current observation.\n        root = Node(0)\n        current_observation = game.make_image(-1)\n        expand_node(root, game.to_play(), game.legal_actions(), network.initial_inference(current_observation))\n        if train:\n            add_exploration_noise(config, root)\n\n        # We then run a Monte Carlo Tree Search using only action sequences and the\n        # model learned by the networks.\n        run_mcts(config, root, game.action_history(), network)\n        action = select_action(config, len(game.history), root, network, mode=mode_action_select)\n        game.apply(action)\n        game.store_search_statistics(root)\n    return game\n"
  },
  {
    "path": "muzero/self_play/utils.py",
    "content": "\"\"\"Helpers for the MCTS\"\"\"\nfrom typing import Optional\n\nimport numpy as np\n\nMAXIMUM_FLOAT_VALUE = float('inf')\n\n\nclass MinMaxStats(object):\n    \"\"\"A class that holds the min-max values of the tree.\"\"\"\n\n    def __init__(self, known_bounds):\n        self.maximum = known_bounds.max if known_bounds else -MAXIMUM_FLOAT_VALUE\n        self.minimum = known_bounds.min if known_bounds else MAXIMUM_FLOAT_VALUE\n\n    def update(self, value: float):\n        if value is None:\n            raise ValueError\n\n        self.maximum = max(self.maximum, value)\n        self.minimum = min(self.minimum, value)\n\n    def normalize(self, value: float) -> float:\n        # If the value is unknow, by default we set it to the minimum possible value\n        if value is None:\n            return 0.0\n\n        if self.maximum > self.minimum:\n            # We normalize only when we have set the maximum and minimum values.\n            return (value - self.minimum) / (self.maximum - self.minimum)\n        return value\n\n\nclass Node(object):\n    \"\"\"A class that represent nodes inside the MCTS tree\"\"\"\n\n    def __init__(self, prior: float):\n        self.visit_count = 0\n        self.to_play = -1\n        self.prior = prior\n        self.value_sum = 0\n        self.children = {}\n        self.hidden_state = None\n        self.reward = 0\n\n    def expanded(self) -> bool:\n        return len(self.children) > 0\n\n    def value(self) -> Optional[float]:\n        if self.visit_count == 0:\n            return None\n        return self.value_sum / self.visit_count\n\n\ndef softmax_sample(visit_counts, actions, t):\n    counts_exp = np.exp(visit_counts) * (1 / t)\n    probs = counts_exp / np.sum(counts_exp, axis=0)\n    action_idx = np.random.choice(len(actions), p=probs)\n    return actions[action_idx]\n"
  },
  {
    "path": "muzero/training/__init__.py",
    "content": ""
  },
  {
    "path": "muzero/training/replay_buffer.py",
    "content": "import random\nfrom itertools import zip_longest\nfrom typing import List\n\nfrom config import MuZeroConfig\nfrom game.game import AbstractGame\n\n\nclass ReplayBuffer(object):\n\n    def __init__(self, config: MuZeroConfig):\n        self.window_size = config.window_size\n        self.batch_size = config.batch_size\n        self.buffer = []\n\n    def save_game(self, game):\n        if len(self.buffer) > self.window_size:\n            self.buffer.pop(0)\n        self.buffer.append(game)\n\n    def sample_batch(self, num_unroll_steps: int, td_steps: int):\n        # Generate some sample of data to train on\n        games = self.sample_games()\n        game_pos = [(g, self.sample_position(g)) for g in games]\n        game_data = [(g.make_image(i), g.history[i:i + num_unroll_steps],\n                      g.make_target(i, num_unroll_steps, td_steps, g.to_play()))\n                     for (g, i) in game_pos]\n\n        # Pre-process the batch\n        image_batch, actions_time_batch, targets_batch = zip(*game_data)\n        targets_init_batch, *targets_time_batch = zip(*targets_batch)\n        actions_time_batch = list(zip_longest(*actions_time_batch, fillvalue=None))\n\n        # Building batch of valid actions and a dynamic mask for hidden representations during BPTT\n        mask_time_batch = []\n        dynamic_mask_time_batch = []\n        last_mask = [True] * len(image_batch)\n        for i, actions_batch in enumerate(actions_time_batch):\n            mask = list(map(lambda a: bool(a), actions_batch))\n            dynamic_mask = [now for last, now in zip(last_mask, mask) if last]\n            mask_time_batch.append(mask)\n            dynamic_mask_time_batch.append(dynamic_mask)\n            last_mask = mask\n            actions_time_batch[i] = [action.index for action in actions_batch if action]\n\n        batch = image_batch, targets_init_batch, targets_time_batch, actions_time_batch, mask_time_batch, dynamic_mask_time_batch\n        return batch\n\n    def sample_games(self) -> List[AbstractGame]:\n        # Sample game from buffer either uniformly or according to some priority.\n        return random.choices(self.buffer, k=self.batch_size)\n\n    def sample_position(self, game: AbstractGame) -> int:\n        # Sample position from game either uniformly or according to some priority.\n        return random.randint(0, len(game.history))\n"
  },
  {
    "path": "muzero/training/training.py",
    "content": "\"\"\"Training module: this is where MuZero neurons are trained.\"\"\"\n\nimport numpy as np\nimport tensorflow_core as tf\nfrom tensorflow_core.python.keras.losses import MSE\n\nfrom config import MuZeroConfig\nfrom networks.network import BaseNetwork\nfrom networks.shared_storage import SharedStorage\nfrom training.replay_buffer import ReplayBuffer\n\n\ndef train_network(config: MuZeroConfig, storage: SharedStorage, replay_buffer: ReplayBuffer, epochs: int):\n    network = storage.current_network\n    optimizer = storage.optimizer\n\n    for _ in range(epochs):\n        batch = replay_buffer.sample_batch(config.num_unroll_steps, config.td_steps)\n        update_weights(optimizer, network, batch)\n        storage.save_network(network.training_steps, network)\n\n\ndef update_weights(optimizer: tf.keras.optimizers, network: BaseNetwork, batch):\n    def scale_gradient(tensor, scale: float):\n        \"\"\"Trick function to scale the gradient in tensorflow\"\"\"\n        return (1. - scale) * tf.stop_gradient(tensor) + scale * tensor\n\n    def loss():\n        loss = 0\n        image_batch, targets_init_batch, targets_time_batch, actions_time_batch, mask_time_batch, dynamic_mask_time_batch = batch\n\n        # Initial step, from the real observation: representation + prediction networks\n        representation_batch, value_batch, policy_batch = network.initial_model(np.array(image_batch))\n\n        # Only update the element with a policy target\n        target_value_batch, _, target_policy_batch = zip(*targets_init_batch)\n        mask_policy = list(map(lambda l: bool(l), target_policy_batch))\n        target_policy_batch = list(filter(lambda l: bool(l), target_policy_batch))\n        policy_batch = tf.boolean_mask(policy_batch, mask_policy)\n\n        # Compute the loss of the first pass\n        loss += tf.math.reduce_mean(loss_value(target_value_batch, value_batch, network.value_support_size))\n        loss += tf.math.reduce_mean(\n            tf.nn.softmax_cross_entropy_with_logits(logits=policy_batch, labels=target_policy_batch))\n\n        # Recurrent steps, from action and previous hidden state.\n        for actions_batch, targets_batch, mask, dynamic_mask in zip(actions_time_batch, targets_time_batch,\n                                                                    mask_time_batch, dynamic_mask_time_batch):\n            target_value_batch, target_reward_batch, target_policy_batch = zip(*targets_batch)\n\n            # Only execute BPTT for elements with an action\n            representation_batch = tf.boolean_mask(representation_batch, dynamic_mask)\n            target_value_batch = tf.boolean_mask(target_value_batch, mask)\n            target_reward_batch = tf.boolean_mask(target_reward_batch, mask)\n            # Creating conditioned_representation: concatenate representations with actions batch\n            actions_batch = tf.one_hot(actions_batch, network.action_size)\n\n            # Recurrent step from conditioned representation: recurrent + prediction networks\n            conditioned_representation_batch = tf.concat((representation_batch, actions_batch), axis=1)\n            representation_batch, reward_batch, value_batch, policy_batch = network.recurrent_model(\n                conditioned_representation_batch)\n\n            # Only execute BPTT for elements with a policy target\n            target_policy_batch = [policy for policy, b in zip(target_policy_batch, mask) if b]\n            mask_policy = list(map(lambda l: bool(l), target_policy_batch))\n   
         target_policy_batch = tf.convert_to_tensor([policy for policy in target_policy_batch if policy])\n            policy_batch = tf.boolean_mask(policy_batch, mask_policy)\n\n            # Compute the partial loss\n            l = (tf.math.reduce_mean(loss_value(target_value_batch, value_batch, network.value_support_size)) +\n                 MSE(target_reward_batch, tf.squeeze(reward_batch)) +\n                 tf.math.reduce_mean(\n                     tf.nn.softmax_cross_entropy_with_logits(logits=policy_batch, labels=target_policy_batch)))\n\n            # Scale the gradient of the loss by the average number of actions unrolled\n            gradient_scale = 1. / len(actions_time_batch)\n            loss += scale_gradient(l, gradient_scale)\n\n            # Half the gradient of the representation\n            representation_batch = scale_gradient(representation_batch, 0.5)\n\n        return loss\n\n    optimizer.minimize(loss=loss, var_list=network.cb_get_variables())\n    network.training_steps += 1\n\n\ndef loss_value(target_value_batch, value_batch, value_support_size: int):\n    batch_size = len(target_value_batch)\n    targets = np.zeros((batch_size, value_support_size))\n    sqrt_value = np.sqrt(target_value_batch)\n    floor_value = np.floor(sqrt_value).astype(int)\n    rest = sqrt_value - floor_value\n    targets[range(batch_size), floor_value.astype(int)] = 1 - rest\n    targets[range(batch_size), floor_value.astype(int) + 1] = rest\n\n    return tf.nn.softmax_cross_entropy_with_logits(logits=value_batch, labels=targets)\n"
  }
]