Repository: johan-gras/MuZero
Branch: master
Commit: 4f53f0c3e6b8
Files: 20
Total size: 44.4 KB
Directory structure:
gitextract_mhh1v2wx/
├── .gitignore
├── README.rst
└── muzero/
    ├── __init__.py
    ├── config.py
    ├── game/
    │   ├── __init__.py
    │   ├── cartpole.py
    │   ├── game.py
    │   └── gym_wrappers.py
    ├── muzero.py
    ├── networks/
    │   ├── __init__.py
    │   ├── cartpole_network.py
    │   ├── network.py
    │   └── shared_storage.py
    ├── self_play/
    │   ├── __init__.py
    │   ├── mcts.py
    │   ├── self_play.py
    │   └── utils.py
    └── training/
        ├── __init__.py
        ├── replay_buffer.py
        └── training.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# Pycharm
.idea/
================================================
FILE: README.rst
================================================
.. |copy| unicode:: 0xA9
.. |---| unicode:: U+02014

======
MuZero
======

This repository is a Python implementation of the MuZero algorithm.
It is based upon the `pre-print paper`__ and the `pseudocode`__ describing the MuZero framework.
Neural computations are implemented with TensorFlow.

Out of the box, you can train your own MuZero on single-player environments with non-image observations (such as `CartPole`__).
If you wish to train MuZero on other kinds of environments, this codebase can be used with slight modifications.

__ https://arxiv.org/abs/1911.08265
__ https://arxiv.org/src/1911.08265v1/anc/pseudocode.py
__ https://gym.openai.com/envs/CartPole-v1/

**DISCLAIMER**: this code is early research code. This means that:

- Silent bugs may exist.
- It may not work reliably on other environments or with other hyper-parameters.
- The code quality and documentation are quite lacking, and much of the code may still feel "in progress".
- The training and testing pipeline is not very advanced.

Dependencies
============

We run this code using:

- Conda **4.7.12**
- Python **3.7**
- TensorFlow **2.0.0**
- NumPy **1.17.3**
- OpenAI Gym (version not specified)

Training your MuZero
====================

This code must be run from the main function in ``muzero.py`` (don't forget to configure your conda environment first).

Training a CartPole-v1 bot
--------------------------

To train a model, follow these steps:

1) Create a new configuration or modify an existing one in ``config.py``.
2) Select that configuration in the ``__main__`` block of ``muzero.py``.
3) Run the main script from inside the ``muzero/`` directory (the modules use imports such as ``from config import ...``): ``python muzero.py``.

Training on another environment
-------------------------------

To train on an environment other than CartPole-v1, follow these additional steps:

1) Create a class that extends ``AbstractGame``; this class implements the behavior of your environment.
   For instance, the ``CartPole`` class extends ``AbstractGame`` and wraps `gym CartPole-v1`__.
   You can use the ``CartPole`` class as a template for any gym environment.

   __ https://gym.openai.com/envs/CartPole-v1/

2) **This step is optional** (only if you want a different network architecture or value/reward transform).
   Create a class that extends ``BaseNetwork``; this class implements the different networks (representation, value, policy, reward and dynamics) and the value/reward transforms.
   For instance, the ``CartPoleNetwork`` class extends ``BaseNetwork`` and implements fully connected networks.
3) **This step is optional** (only if you use a different value/reward transform).
   Implement the inverse of your new value/reward transform by modifying the ``loss_value`` and ``loss_reward`` functions inside ``training.py``.

Differences from the paper
==========================

This implementation differs from the original paper in the following ways:

- We use fully connected layers instead of convolutional ones. This is due to the nature of our environment (CartPole-v1), whose observation vector has no spatial structure.
- We don't scale the hidden state to [0, 1] using min-max normalization. Instead, we use a tanh activation that maps the hidden state into [-1, 1].
- We use a slightly simpler invertible transform for the value prediction, obtained by removing the linear term.
- During training, samples are drawn uniformly from the replay buffer instead of using prioritized replay.
- Like the paper, we scale the loss of each head by 1/K (with K the number of unrolled steps). However, we treat K as constant, even though this does not always hold.
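
The hidden-state and value-transform differences above can be compared side by side in a small NumPy sketch. ``h`` and ``h_inv`` implement the invertible transform h(x) = sign(x)(sqrt(|x| + 1) - 1) + eps * x with eps = 0.001, as used for value targets in the MuZero paper. ``simple_transform`` is only our assumption of the forward mapping that pairs with ``CartPoleNetwork._value_transform`` (which squares the decoded value); the training code that applies it is not shown here. ``min_max_scale`` is the paper's hidden-state normalization that this repository replaces with tanh. All function names are hypothetical.

```python
import numpy as np

EPS = 1e-3  # epsilon of the paper's value transform


def h(x):
    """Paper transform: sign(x) * (sqrt(|x| + 1) - 1) + eps * x."""
    return np.sign(x) * (np.sqrt(np.abs(x) + 1) - 1) + EPS * x


def h_inv(y):
    """Closed-form inverse of h, obtained by solving the quadratic in sqrt(|x| + 1)."""
    root = (np.sqrt(1 + 4 * EPS * (np.abs(y) + 1 + EPS)) - 1) / (2 * EPS)
    return np.sign(y) * (root ** 2 - 1)


def simple_transform(x):
    """Assumed simplified forward transform for non-negative values (no linear term)."""
    return np.sqrt(x)


def simple_inverse(y):
    """Mirrors CartPoleNetwork._value_transform, which squares the decoded value."""
    return y ** 2


def min_max_scale(hidden):
    """Paper: rescale a hidden state into [0, 1] using its own min and max."""
    lo, hi = hidden.min(), hidden.max()
    return (hidden - lo) / (hi - lo + 1e-8)  # the 1e-8 guard is our own addition


values = np.array([0.0, 1.0, 42.0, 500.0])
print(h_inv(h(values)))                          # recovers `values` up to float error
print(simple_inverse(simple_transform(values)))  # same round trip, simplified transform
hidden = np.array([-3.0, 0.5, 2.0, 7.0])
print(min_max_scale(hidden), np.tanh(hidden))    # paper scaling vs. this repository's tanh
```

Both transforms compress large returns so that a small discrete support can represent them; dropping the linear term only matters for the inverse, which becomes a plain square.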
================================================
FILE: muzero/__init__.py
================================================
================================================
FILE: muzero/config.py
================================================
import collections
from typing import Optional, Dict
import tensorflow as tf
from game.cartpole import CartPole
from game.game import AbstractGame
from networks.cartpole_network import CartPoleNetwork
from networks.network import BaseNetwork, UniformNetwork
KnownBounds = collections.namedtuple('KnownBounds', ['min', 'max'])
class MuZeroConfig(object):
    def __init__(self,
                 game,
                 nb_training_loop: int,
                 nb_episodes: int,
                 nb_epochs: int,
                 network_args: Dict,
                 network,
                 action_space_size: int,
                 max_moves: int,
                 discount: float,
                 dirichlet_alpha: float,
                 num_simulations: int,
                 batch_size: int,
                 td_steps: int,
                 visit_softmax_temperature_fn,
                 lr: float,
                 known_bounds: Optional[KnownBounds] = None):
        ### Environment
        self.game = game
        ### Self-Play
        self.action_space_size = action_space_size
        # self.num_actors = num_actors
        self.visit_softmax_temperature_fn = visit_softmax_temperature_fn
        self.max_moves = max_moves
        self.num_simulations = num_simulations
        self.discount = discount
        # Root prior exploration noise.
        self.root_dirichlet_alpha = dirichlet_alpha
        self.root_exploration_fraction = 0.25
        # UCB formula
        self.pb_c_base = 19652
        self.pb_c_init = 1.25
        # If we already have some information about which values occur in the
        # environment, we can use them to initialize the rescaling.
        # This is not strictly necessary, but establishes identical behaviour to
        # AlphaZero in board games.
        self.known_bounds = known_bounds
        ### Training
        self.nb_training_loop = nb_training_loop
        self.nb_episodes = nb_episodes  # Number of self-play episodes per training loop
        self.nb_epochs = nb_epochs  # Number of training epochs per training loop
        # self.training_steps = int(1000e3)
        # self.checkpoint_interval = int(1e3)
        self.window_size = int(1e6)
        self.batch_size = batch_size
        self.num_unroll_steps = 5
        self.td_steps = td_steps
        self.weight_decay = 1e-4
        self.momentum = 0.9
        self.network_args = network_args
        self.network = network
        self.lr = lr
        # Exponential learning rate schedule (currently unused: a constant lr is passed to SGD)
        # self.lr_init = lr_init
        # self.lr_decay_rate = 0.1
        # self.lr_decay_steps = lr_decay_steps
    def new_game(self) -> AbstractGame:
        return self.game(self.discount)
    def new_network(self) -> BaseNetwork:
        return self.network(**self.network_args)
    def uniform_network(self) -> UniformNetwork:
        return UniformNetwork(self.action_space_size)
    def new_optimizer(self) -> tf.keras.optimizers.Optimizer:
        return tf.keras.optimizers.SGD(learning_rate=self.lr, momentum=self.momentum)
def make_cartpole_config() -> MuZeroConfig:
    def visit_softmax_temperature(num_moves, training_steps):
        return 1.0
    return MuZeroConfig(
        game=CartPole,
        nb_training_loop=50,
        nb_episodes=20,
        nb_epochs=20,
        network_args={'action_size': 2,
                      'state_size': 4,
                      'representation_size': 4,
                      'max_value': 500},
        network=CartPoleNetwork,
        action_space_size=2,
        max_moves=1000,
        discount=0.99,
        dirichlet_alpha=0.25,
        num_simulations=11,  # An odd number avoids visit-count ties between the 2 actions in eval ('max') mode
        batch_size=512,
        td_steps=10,
        visit_softmax_temperature_fn=visit_softmax_temperature,
        lr=0.05)
"""
Legacy configs from the DeepMind's pseudocode.
def make_board_game_config(action_space_size: int, max_moves: int,
dirichlet_alpha: float,
lr_init: float) -> MuZeroConfig:
def visit_softmax_temperature(num_moves, training_steps):
if num_moves < 30:
return 1.0
else:
return 0.0 # Play according to the max.
return MuZeroConfig(
action_space_size=action_space_size,
max_moves=max_moves,
discount=1.0,
dirichlet_alpha=dirichlet_alpha,
num_simulations=800,
batch_size=2048,
td_steps=max_moves, # Always use Monte Carlo return.
num_actors=3000,
lr_init=lr_init,
lr_decay_steps=400e3,
visit_softmax_temperature_fn=visit_softmax_temperature,
known_bounds=KnownBounds(-1, 1))
def make_go_config() -> MuZeroConfig:
return make_board_game_config(
action_space_size=362, max_moves=722, dirichlet_alpha=0.03, lr_init=0.01)
def make_chess_config() -> MuZeroConfig:
return make_board_game_config(
action_space_size=4672, max_moves=512, dirichlet_alpha=0.3, lr_init=0.1)
def make_shogi_config() -> MuZeroConfig:
return make_board_game_config(
action_space_size=11259, max_moves=512, dirichlet_alpha=0.15, lr_init=0.1)
def make_atari_config() -> MuZeroConfig:
def visit_softmax_temperature(num_moves, training_steps):
if training_steps < 500e3:
return 1.0
elif training_steps < 750e3:
return 0.5
else:
return 0.25
return MuZeroConfig(
action_space_size=18,
max_moves=27000, # Half an hour at action repeat 4.
discount=0.997,
dirichlet_alpha=0.25,
num_simulations=50,
batch_size=1024,
td_steps=10,
num_actors=350,
lr_init=0.05,
lr_decay_steps=350e3,
visit_softmax_temperature_fn=visit_softmax_temperature)
"""
================================================
FILE: muzero/game/__init__.py
================================================
================================================
FILE: muzero/game/cartpole.py
================================================
from typing import List
import gym
from game.game import Action, AbstractGame
from game.gym_wrappers import ScalingObservationWrapper
class CartPole(AbstractGame):
    """The Gym CartPole environment"""
    def __init__(self, discount: float):
        super().__init__(discount)
        self.env = gym.make('CartPole-v1')
        self.env = ScalingObservationWrapper(self.env, low=[-2.4, -2.0, -0.42, -3.5], high=[2.4, 2.0, 0.42, 3.5])
        self.actions = [Action(i) for i in range(self.env.action_space.n)]
        self.observations = [self.env.reset()]
        self.done = False
    @property
    def action_space_size(self) -> int:
        """Return the size of the action space."""
        return len(self.actions)
    def step(self, action) -> float:
        """Execute one step of the game conditioned by the given action."""
        observation, reward, done, _ = self.env.step(action.index)
        self.observations += [observation]
        self.done = done
        return reward
    def terminal(self) -> bool:
        """Is the game finished?"""
        return self.done
    def legal_actions(self) -> List[Action]:
        """Return the legal actions available at this instant."""
        return self.actions
    def make_image(self, state_index: int):
        """Return the (scaled) observation at the given index, used as input of the representation network."""
        return self.observations[state_index]
================================================
FILE: muzero/game/game.py
================================================
from abc import abstractmethod, ABC
from typing import List
from self_play.utils import Node
class Action(object):
    """Class that represents an action of a game."""
    def __init__(self, index: int):
        self.index = index
    def __hash__(self):
        return self.index
    def __eq__(self, other):
        return self.index == other.index
    def __gt__(self, other):
        return self.index > other.index
class Player(object):
    """
    A single-player placeholder.
    It has no behaviour of its own; it is kept for legacy purposes and for a potential two-player adaptation of MuZero.
    """
    def __eq__(self, other):
        return True
class ActionHistory(object):
    """
    Simple history container used inside the search.
    Only used to keep track of the actions executed.
    """
    def __init__(self, history: List[Action], action_space_size: int):
        self.history = list(history)
        self.action_space_size = action_space_size
    def clone(self):
        return ActionHistory(self.history, self.action_space_size)
    def add_action(self, action: Action):
        self.history.append(action)
    def last_action(self) -> Action:
        return self.history[-1]
    def action_space(self) -> List[Action]:
        return [Action(i) for i in range(self.action_space_size)]
    def to_play(self) -> Player:
        return Player()
class AbstractGame(ABC):
    """
    Abstract base class for implementing a game.
    One instance represents a single episode of interaction with the environment.
    """
    def __init__(self, discount: float):
        self.history = []
        self.rewards = []
        self.child_visits = []
        self.root_values = []
        self.discount = discount
    def apply(self, action: Action):
        """Apply an action onto the environment."""
        reward = self.step(action)
        self.rewards.append(reward)
        self.history.append(action)
    def store_search_statistics(self, root: Node):
        """After each MCTS run, store the statistics generated by the search."""
        sum_visits = sum(child.visit_count for child in root.children.values())
        action_space = (Action(index) for index in range(self.action_space_size))
        self.child_visits.append([
            root.children[a].visit_count / sum_visits if a in root.children else 0
            for a in action_space
        ])
        self.root_values.append(root.value())
    def make_target(self, state_index: int, num_unroll_steps: int, td_steps: int, to_play: Player):
        """Generate targets to learn from during the network training."""
        # The value target is the discounted root value of the search tree N steps
        # into the future, plus the discounted sum of all rewards until then.
        targets = []
        for current_index in range(state_index, state_index + num_unroll_steps + 1):
            bootstrap_index = current_index + td_steps
            if bootstrap_index < len(self.root_values):
                value = self.root_values[bootstrap_index] * self.discount ** td_steps
            else:
                value = 0
            for i, reward in enumerate(self.rewards[current_index:bootstrap_index]):
                value += reward * self.discount ** i
            if current_index < len(self.root_values):
                targets.append((value, self.rewards[current_index], self.child_visits[current_index]))
            else:
                # States past the end of games are treated as absorbing states.
                targets.append((0, 0, []))
        return targets
    def to_play(self) -> Player:
        """Return the current player."""
        return Player()
    def action_history(self) -> ActionHistory:
        """Return the history of actions played so far, used as the starting point of the search."""
        return ActionHistory(self.history, self.action_space_size)
    # Methods to be implemented by subclasses
    @property
    @abstractmethod
    def action_space_size(self) -> int:
        """Return the size of the action space."""
        pass
    @abstractmethod
    def step(self, action) -> float:
        """Execute one step of the game conditioned by the given action."""
        pass
    @abstractmethod
    def terminal(self) -> bool:
        """Is the game finished?"""
        pass
    @abstractmethod
    def legal_actions(self) -> List[Action]:
        """Return the legal actions available at this instant."""
        pass
    @abstractmethod
    def make_image(self, state_index: int):
        """Return the observation at the given index, used as input of the representation network."""
        pass
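The n-step value target computed by `make_target` can be checked in isolation. The sketch below is a standalone re-statement of that loop on plain lists, so it runs without the game classes; the function name `n_step_targets` is hypothetical, and the arguments mirror the `AbstractGame` attributes (`rewards`, `root_values`, `child_visits`, `discount`).

```python
from typing import List


def n_step_targets(rewards: List[float], root_values: List[float],
                   child_visits: List[List[float]], state_index: int,
                   num_unroll_steps: int, td_steps: int, discount: float):
    """Return (value, reward, policy) targets for each unrolled position."""
    targets = []
    for current_index in range(state_index, state_index + num_unroll_steps + 1):
        bootstrap_index = current_index + td_steps
        # Bootstrap from the search value td_steps ahead, if that state exists.
        if bootstrap_index < len(root_values):
            value = root_values[bootstrap_index] * discount ** td_steps
        else:
            value = 0.0
        # Add the discounted rewards collected before the bootstrap point.
        for i, reward in enumerate(rewards[current_index:bootstrap_index]):
            value += reward * discount ** i
        if current_index < len(root_values):
            targets.append((value, rewards[current_index], child_visits[current_index]))
        else:
            targets.append((0.0, 0.0, []))  # absorbing state past the end of the game
    return targets


# A 5-step episode with reward 1 per step, discount 0.5 and 2-step bootstrapping.
print(n_step_targets([1.0] * 5, [10.0, 9.0, 8.0, 7.0, 6.0], [[0.5, 0.5]] * 5,
                     state_index=0, num_unroll_steps=2, td_steps=2, discount=0.5))
```

For `current_index = 0` the target is 8 * 0.5^2 + 1 + 0.5 = 3.5; positions near the end of the episode fall back to pure discounted reward sums and then to absorbing `(0, 0, [])` targets.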
================================================
FILE: muzero/game/gym_wrappers.py
================================================
import gym
import numpy as np
class ScalingObservationWrapper(gym.ObservationWrapper):
    """
    Wrapper that applies a min-max scaling of observations, mapping [low, high] onto [-1, 1] per dimension.
    Values outside the bounds are not clipped.
    """
    def __init__(self, env, low=None, high=None):
        super().__init__(env)
        assert isinstance(env.observation_space, gym.spaces.Box)
        low = np.array(self.observation_space.low if low is None else low)
        high = np.array(self.observation_space.high if high is None else high)
        self.mean = (high + low) / 2
        self.max = high - self.mean
    def observation(self, observation):
        return (observation - self.mean) / self.max
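The arithmetic of `ScalingObservationWrapper` can be reproduced without Gym. In this sketch `make_scaler` is a hypothetical helper that applies the same centring (`mean`) and half-range division (`self.max`) used by the wrapper, here with the CartPole bounds hard-coded in `cartpole.py`.

```python
import numpy as np


def make_scaler(low, high):
    """Return a function mapping [low, high] linearly onto [-1, 1] per dimension."""
    low, high = np.asarray(low, dtype=float), np.asarray(high, dtype=float)
    mean = (high + low) / 2        # centre of each interval
    half_range = high - mean       # stored as `self.max` in ScalingObservationWrapper
    return lambda obs: (np.asarray(obs, dtype=float) - mean) / half_range


# Bounds used for CartPole in cartpole.py
scale = make_scaler(low=[-2.4, -2.0, -0.42, -3.5], high=[2.4, 2.0, 0.42, 3.5])
# The last entry exceeds 1 because values outside the bounds are not clipped.
print(scale([2.4, -1.0, 0.0, 7.0]))
```

If bounded inputs are required downstream, wrapping the result in `np.clip(..., -1.0, 1.0)` would be one option; the original wrapper does not do this.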
================================================
FILE: muzero/muzero.py
================================================
from config import MuZeroConfig, make_cartpole_config
from networks.shared_storage import SharedStorage
from self_play.self_play import run_selfplay, run_eval
from training.replay_buffer import ReplayBuffer
from training.training import train_network
def muzero(config: MuZeroConfig):
    """
    MuZero training is split into two independent parts: network training and
    self-play data generation.
    These two parts only communicate by transferring the latest network checkpoint
    from the training to the self-play, and the finished games from the self-play
    to the training.
    In contrast to the original MuZero algorithm, this version doesn't use
    multiple threads; training and self-play are therefore done alternately.
    """
    storage = SharedStorage(config.new_network(), config.uniform_network(), config.new_optimizer())
    replay_buffer = ReplayBuffer(config)
    for loop in range(config.nb_training_loop):
        print("Training loop", loop)
        score_train = run_selfplay(config, storage, replay_buffer, config.nb_episodes)
        train_network(config, storage, replay_buffer, config.nb_epochs)
        print("Train score:", score_train)
        print("Eval score:", run_eval(config, storage, 50))
        print(f"MuZero played {config.nb_episodes * (loop + 1)} "
              f"episodes and trained for {config.nb_epochs * (loop + 1)} epochs.\n")
    return storage.latest_network()
if __name__ == '__main__':
    config = make_cartpole_config()
    muzero(config)
================================================
FILE: muzero/networks/__init__.py
================================================
================================================
FILE: muzero/networks/cartpole_network.py
================================================
import math
import numpy as np
from tensorflow.keras import regularizers
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
from game.game import Action
from networks.network import BaseNetwork
class CartPoleNetwork(BaseNetwork):
    def __init__(self,
                 state_size: int,
                 action_size: int,
                 representation_size: int,
                 max_value: int,
                 hidden_neurons: int = 64,
                 weight_decay: float = 1e-4,
                 representation_activation: str = 'tanh'):
        self.state_size = state_size
        self.action_size = action_size
        # One support bin per integer in 0..ceil(sqrt(max_value)): the support lives in sqrt-space.
        self.value_support_size = math.ceil(math.sqrt(max_value)) + 1
        regularizer = regularizers.l2(weight_decay)
        representation_network = Sequential([Dense(hidden_neurons, activation='relu', kernel_regularizer=regularizer),
                                             Dense(representation_size, activation=representation_activation,
                                                   kernel_regularizer=regularizer)])
        value_network = Sequential([Dense(hidden_neurons, activation='relu', kernel_regularizer=regularizer),
                                    Dense(self.value_support_size, kernel_regularizer=regularizer)])
        policy_network = Sequential([Dense(hidden_neurons, activation='relu', kernel_regularizer=regularizer),
                                     Dense(action_size, kernel_regularizer=regularizer)])
        dynamic_network = Sequential([Dense(hidden_neurons, activation='relu', kernel_regularizer=regularizer),
                                      Dense(representation_size, activation=representation_activation,
                                            kernel_regularizer=regularizer)])
        reward_network = Sequential([Dense(16, activation='relu', kernel_regularizer=regularizer),
                                     Dense(1, kernel_regularizer=regularizer)])
        super().__init__(representation_network, value_network, policy_network, dynamic_network, reward_network)
    def _value_transform(self, value_support: np.ndarray) -> float:
        """
        The value is obtained by first computing the expected value from the discrete support.
        Then the inverse transform (the square function) is applied.
        """
        value = self._softmax(value_support)
        value = np.dot(value, range(self.value_support_size))
        value = value.item() ** 2  # .item() replaces np.asscalar, which was removed in NumPy 1.23
        return value
    def _reward_transform(self, reward: np.ndarray) -> float:
        return reward.item()
    def _conditioned_hidden_state(self, hidden_state: np.ndarray, action: Action) -> np.ndarray:
        # Concatenate the hidden state with a one-hot encoding of the action, then add a batch dimension.
        conditioned_hidden = np.concatenate((hidden_state, np.eye(self.action_size)[action.index]))
        return np.expand_dims(conditioned_hidden, axis=0)
    def _softmax(self, values):
        """Compute softmax using numerical stability tricks."""
        values_exp = np.exp(values - np.max(values))
        return values_exp / np.sum(values_exp)
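The support arithmetic of `CartPoleNetwork` (bin count, softmax expectation, squaring) and its one-hot action conditioning can be exercised with NumPy alone. The helper names below are hypothetical; unlike the class, which receives batched `(1, n)` arrays from Keras, this sketch works on 1-D vectors for readability.

```python
import math

import numpy as np


def support_size(max_value):
    # Same rule as CartPoleNetwork: ceil(sqrt(max_value)) + 1 bins.
    return math.ceil(math.sqrt(max_value)) + 1


def softmax(logits):
    exp = np.exp(logits - np.max(logits))  # subtract the max for numerical stability
    return exp / np.sum(exp)


def decode_value(logits):
    # Expected support index, then square it (inverse of the sqrt-space support).
    probs = softmax(logits)
    expected = float(np.dot(probs, np.arange(len(probs))))
    return expected ** 2


def condition_hidden(hidden_state, action_index, action_size):
    # Hidden state followed by a one-hot action, with a leading batch dimension.
    one_hot = np.eye(action_size)[action_index]
    return np.expand_dims(np.concatenate((hidden_state, one_hot)), axis=0)


n = support_size(500)           # 24 bins for CartPole's max_value of 500
logits = np.full(n, -1e9)
logits[10] = 0.0                # almost all probability mass on bin 10
print(n, decode_value(logits))  # bin 10 decodes to a value of 100
print(condition_hidden(np.array([0.1, -0.2, 0.3, 0.0]), 1, 2).shape)
```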
================================================
FILE: muzero/networks/network.py
================================================
import typing
from abc import ABC, abstractmethod
from typing import Dict, List, Callable
import numpy as np
from tensorflow.keras.models import Model
from game.game import Action
class NetworkOutput(typing.NamedTuple):
    value: float
    reward: float
    policy_logits: Dict[Action, float]
    hidden_state: typing.Optional[List[float]]
    @staticmethod
    def build_policy_logits(policy_logits):
        return {Action(i): logit for i, logit in enumerate(policy_logits[0])}
class AbstractNetwork(ABC):
    def __init__(self):
        self.training_steps = 0
    @abstractmethod
    def initial_inference(self, image) -> NetworkOutput:
        pass
    @abstractmethod
    def recurrent_inference(self, hidden_state, action) -> NetworkOutput:
        pass
class UniformNetwork(AbstractNetwork):
    """policy -> uniform, value -> 0, reward -> 0"""
    def __init__(self, action_size: int):
        super().__init__()
        self.action_size = action_size
    def initial_inference(self, image) -> NetworkOutput:
        return NetworkOutput(0, 0, {Action(i): 1 / self.action_size for i in range(self.action_size)}, None)
    def recurrent_inference(self, hidden_state, action) -> NetworkOutput:
        return NetworkOutput(0, 0, {Action(i): 1 / self.action_size for i in range(self.action_size)}, None)
class InitialModel(Model):
    """Model that combines the representation and prediction (value+policy) networks."""
    def __init__(self, representation_network: Model, value_network: Model, policy_network: Model):
        super(InitialModel, self).__init__()
        self.representation_network = representation_network
        self.value_network = value_network
        self.policy_network = policy_network
    def call(self, image):
        hidden_representation = self.representation_network(image)
        value = self.value_network(hidden_representation)
        policy_logits = self.policy_network(hidden_representation)
        return hidden_representation, value, policy_logits
class RecurrentModel(Model):
    """Model that combines the dynamics, reward and prediction (value+policy) networks."""
    def __init__(self, dynamic_network: Model, reward_network: Model, value_network: Model, policy_network: Model):
        super(RecurrentModel, self).__init__()
        self.dynamic_network = dynamic_network
        self.reward_network = reward_network
        self.value_network = value_network
        self.policy_network = policy_network
    def call(self, conditioned_hidden):
        hidden_representation = self.dynamic_network(conditioned_hidden)
        reward = self.reward_network(conditioned_hidden)
        value = self.value_network(hidden_representation)
        policy_logits = self.policy_network(hidden_representation)
        return hidden_representation, reward, value, policy_logits
class BaseNetwork(AbstractNetwork):
    """Base class that contains all the networks and models of MuZero."""
    def __init__(self, representation_network: Model, value_network: Model, policy_network: Model,
                 dynamic_network: Model, reward_network: Model):
        super().__init__()
        # Networks blocks
        self.representation_network = representation_network
        self.value_network = value_network
        self.policy_network = policy_network
        self.dynamic_network = dynamic_network
        self.reward_network = reward_network
        # Models for inference and training
        self.initial_model = InitialModel(self.representation_network, self.value_network, self.policy_network)
        self.recurrent_model = RecurrentModel(self.dynamic_network, self.reward_network, self.value_network,
                                              self.policy_network)
    def initial_inference(self, image: np.ndarray) -> NetworkOutput:
        """representation + prediction function"""
        hidden_representation, value, policy_logits = self.initial_model.predict(np.expand_dims(image, 0))
        output = NetworkOutput(value=self._value_transform(value),
                               reward=0.,
                               policy_logits=NetworkOutput.build_policy_logits(policy_logits),
                               hidden_state=hidden_representation[0])
        return output
    def recurrent_inference(self, hidden_state: np.ndarray, action: Action) -> NetworkOutput:
        """dynamics + prediction function"""
        conditioned_hidden = self._conditioned_hidden_state(hidden_state, action)
        hidden_representation, reward, value, policy_logits = self.recurrent_model.predict(conditioned_hidden)
        output = NetworkOutput(value=self._value_transform(value),
                               reward=self._reward_transform(reward),
                               policy_logits=NetworkOutput.build_policy_logits(policy_logits),
                               hidden_state=hidden_representation[0])
        return output
    @abstractmethod
    def _value_transform(self, value: np.ndarray) -> float:
        pass
    @abstractmethod
    def _reward_transform(self, reward: np.ndarray) -> float:
        pass
    @abstractmethod
    def _conditioned_hidden_state(self, hidden_state: np.ndarray, action: Action) -> np.ndarray:
        pass
    def cb_get_variables(self) -> Callable:
        """Return a callback that returns the trainable variables of the network."""
        def get_variables():
            networks = (self.representation_network, self.value_network, self.policy_network,
                        self.dynamic_network, self.reward_network)
            return [variables
                    for variables_list in map(lambda n: n.weights, networks)
                    for variables in variables_list]
        return get_variables
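The data flow through `InitialModel` and `RecurrentModel` can be traced with toy NumPy stand-ins. This is a hypothetical sketch, not the Keras models: each `dense` is a single random linear layer (the real blocks are two-layer `Sequential` stacks), and the sizes are the CartPole ones from `make_cartpole_config`. As in `RecurrentModel.call`, the reward is predicted from the conditioned hidden state, while value and policy are predicted from the next hidden state.

```python
import numpy as np

rng = np.random.default_rng(0)
STATE, HIDDEN, ACTIONS, SUPPORT = 4, 4, 2, 24  # CartPole-like sizes


def dense(in_dim, out_dim):
    """Hypothetical single linear layer standing in for a Keras Sequential block."""
    weights = rng.normal(scale=0.1, size=(in_dim, out_dim))
    return lambda x: x @ weights


representation = dense(STATE, HIDDEN)
dynamics = dense(HIDDEN + ACTIONS, HIDDEN)
reward_net = dense(HIDDEN + ACTIONS, 1)
value_net = dense(HIDDEN, SUPPORT)
policy_net = dense(HIDDEN, ACTIONS)


def initial_inference(observation):
    # Representation: observation -> hidden state, then prediction: hidden -> (value, policy).
    hidden = np.tanh(representation(observation[None, :]))
    return hidden, value_net(hidden), policy_net(hidden)


def recurrent_inference(hidden, action_index):
    # Dynamics: (hidden state, one-hot action) -> (next hidden state, reward), then prediction.
    one_hot = np.eye(ACTIONS)[action_index][None, :]
    conditioned = np.concatenate([hidden, one_hot], axis=1)
    next_hidden = np.tanh(dynamics(conditioned))
    return next_hidden, reward_net(conditioned), value_net(next_hidden), policy_net(next_hidden)


hidden, value_logits, policy_logits = initial_inference(np.zeros(STATE))
for action in (0, 1, 1):  # unroll three imagined steps without touching the environment
    hidden, reward, value_logits, policy_logits = recurrent_inference(hidden, action)
print(hidden.shape, reward.shape, value_logits.shape, policy_logits.shape)  # (1, 4) (1, 1) (1, 24) (1, 2)
```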
================================================
FILE: muzero/networks/shared_storage.py
================================================
import tensorflow as tf
from networks.network import BaseNetwork, UniformNetwork, AbstractNetwork
class SharedStorage(object):
    """Store the successive versions of the network, keyed by training step."""
    def __init__(self, network: BaseNetwork, uniform_network: UniformNetwork, optimizer: tf.keras.optimizers.Optimizer):
        self._networks = {}
        self.current_network = network
        self.uniform_network = uniform_network
        self.optimizer = optimizer
    def latest_network(self) -> AbstractNetwork:
        if self._networks:
            return self._networks[max(self._networks.keys())]
        else:
            # policy -> uniform, value -> 0, reward -> 0
            return self.uniform_network
    def save_network(self, step: int, network: BaseNetwork):
        self._networks[step] = network
================================================
FILE: muzero/self_play/__init__.py
================================================
================================================
FILE: muzero/self_play/mcts.py
================================================
"""MCTS module: where MuZero thinks inside the tree."""
import math
import random
from typing import List
import numpy
from config import MuZeroConfig
from game.game import Player, Action, ActionHistory
from networks.network import NetworkOutput, BaseNetwork
from self_play.utils import MinMaxStats, Node, softmax_sample
def add_exploration_noise(config: MuZeroConfig, node: Node):
"""
At the start of each search, we add dirichlet noise to the prior of the root
to encourage the search to explore new actions.
"""
actions = list(node.children.keys())
noise = numpy.random.dirichlet([config.root_dirichlet_alpha] * len(actions))
frac = config.root_exploration_fraction
for a, n in zip(actions, noise):
node.children[a].prior = node.children[a].prior * (1 - frac) + n * frac
def run_mcts(config: MuZeroConfig, root: Node, action_history: ActionHistory, network: BaseNetwork):
"""
Core Monte Carlo Tree Search algorithm.
To decide on an action, we run N simulations, always starting at the root of
the search tree and traversing the tree according to the UCB formula until we
reach a leaf node.
"""
min_max_stats = MinMaxStats(config.known_bounds)
for _ in range(config.num_simulations):
history = action_history.clone()
node = root
search_path = [node]
while node.expanded():
action, node = select_child(config, node, min_max_stats)
history.add_action(action)
search_path.append(node)
# Inside the search tree we use the dynamics function to obtain the next
# hidden state given an action and the previous hidden state.
parent = search_path[-2]
network_output = network.recurrent_inference(parent.hidden_state, history.last_action())
expand_node(node, history.to_play(), history.action_space(), network_output)
backpropagate(search_path, network_output.value, history.to_play(), config.discount, min_max_stats)
def select_child(config: MuZeroConfig, node: Node, min_max_stats: MinMaxStats):
"""
Select the child with the highest UCB score.
"""
# When the parent visit count is zero, all ucb scores are zeros, therefore we return a random child
if node.visit_count == 0:
return random.sample(node.children.items(), 1)[0]
_, action, child = max(
(ucb_score(config, node, child, min_max_stats), action,
child) for action, child in node.children.items())
return action, child
def ucb_score(config: MuZeroConfig, parent: Node, child: Node,
              min_max_stats: MinMaxStats) -> float:
    """
    The score for a node is based on its value, plus an exploration bonus based on
    the prior.
    """
    pb_c = math.log((parent.visit_count + config.pb_c_base + 1) / config.pb_c_base) + config.pb_c_init
    pb_c *= math.sqrt(parent.visit_count) / (child.visit_count + 1)
    prior_score = pb_c * child.prior
    value_score = min_max_stats.normalize(child.value())
    return prior_score + value_score
def expand_node(node: Node, to_play: Player, actions: List[Action],
                network_output: NetworkOutput):
    """
    We expand a node using the value, reward and policy prediction obtained from
    the neural networks.
    """
    node.to_play = to_play
    node.hidden_state = network_output.hidden_state
    node.reward = network_output.reward
    policy = {a: math.exp(network_output.policy_logits[a]) for a in actions}
    policy_sum = sum(policy.values())
    for action, p in policy.items():
        node.children[action] = Node(p / policy_sum)
def backpropagate(search_path: List[Node], value: float, to_play: Player,
                  discount: float, min_max_stats: MinMaxStats):
    """
    At the end of a simulation, we propagate the evaluation all the way up the
    tree to the root.
    """
    for node in search_path[::-1]:
        node.value_sum += value if node.to_play == to_play else -value
        node.visit_count += 1
        min_max_stats.update(node.value())
        value = node.reward + discount * value
def select_action(config: MuZeroConfig, num_moves: int, node: Node, network: AbstractNetwork, mode: str = 'softmax'):
    """
    After running the MCTS simulations, we select an action based on the visit counts of the root's children.
    During training we use a softmax sample for exploration.
    During evaluation we select the most visited child.
    """
    visit_counts = [child.visit_count for child in node.children.values()]
    actions = list(node.children.keys())
    action = None
    if mode == 'softmax':
        t = config.visit_softmax_temperature_fn(
            num_moves=num_moves, training_steps=network.training_steps)
        action = softmax_sample(visit_counts, actions, t)
    elif mode == 'max':
        action, _ = max(node.children.items(), key=lambda item: item[1].visit_count)
    return action
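The pUCT score in `ucb_score` relies on `MinMaxStats` from `self_play/utils.py`, which is not shown in this section. The standalone sketch below assumes a `MinMaxStats` modelled on DeepMind's pseudocode (track the extreme values seen in the tree, normalise only once a spread exists) and re-states the score with the constants of `MuZeroConfig`; both the class and the flattened `ucb_score` signature are illustrative, not the repository's actual utilities.

```python
import math
from typing import Optional, Tuple


class MinMaxStats:
    """Assumed helper: tracks the min/max values seen in the tree to normalise Q-values."""

    def __init__(self, known_bounds: Optional[Tuple[float, float]] = None):
        self.minimum = known_bounds[0] if known_bounds else float('inf')
        self.maximum = known_bounds[1] if known_bounds else -float('inf')

    def update(self, value: float):
        self.maximum = max(self.maximum, value)
        self.minimum = min(self.minimum, value)

    def normalize(self, value: float) -> float:
        if self.maximum > self.minimum:
            return (value - self.minimum) / (self.maximum - self.minimum)
        return value  # no spread observed yet: leave the value unchanged


def ucb_score(parent_visits: int, child_visits: int, child_prior: float,
              child_value: float, stats: MinMaxStats,
              pb_c_base: float = 19652, pb_c_init: float = 1.25) -> float:
    """pUCT: prior-weighted exploration bonus plus the normalised child value."""
    pb_c = math.log((parent_visits + pb_c_base + 1) / pb_c_base) + pb_c_init
    pb_c *= math.sqrt(parent_visits) / (child_visits + 1)
    return pb_c * child_prior + stats.normalize(child_value)


stats = MinMaxStats()
for v in (1.0, 3.0):
    stats.update(v)
# Two unvisited children of a root visited 10 times: with equal values, the higher prior wins.
print(ucb_score(10, 0, 0.6, 2.0, stats), ucb_score(10, 0, 0.4, 2.0, stats))
```

With `parent_visits = 0` the exploration term vanishes, which is why `select_child` falls back to a random child at an unvisited root.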
================================================
FILE: muzero/self_play/self_play.py
================================================
"""Self-Play module: where the games are played."""
from config import MuZeroConfig
from game.game import AbstractGame
from networks.network import AbstractNetwork
from networks.shared_storage import SharedStorage
from self_play.mcts import run_mcts, select_action, expand_node, add_exploration_noise
from self_play.utils import Node
from training.replay_buffer import ReplayBuffer


def run_selfplay(config: MuZeroConfig, storage: SharedStorage, replay_buffer: ReplayBuffer, train_episodes: int):
    """Take the latest network, produce multiple games and save them in the shared replay buffer."""
    network = storage.latest_network()
    returns = []
    for _ in range(train_episodes):
        game = play_game(config, network)
        replay_buffer.save_game(game)
        returns.append(sum(game.rewards))
    return sum(returns) / train_episodes if train_episodes else 0


def run_eval(config: MuZeroConfig, storage: SharedStorage, eval_episodes: int):
    """Evaluate MuZero without noise added to the prior of the root and without softmax action selection."""
    network = storage.latest_network()
    returns = []
    for _ in range(eval_episodes):
        game = play_game(config, network, train=False)
        returns.append(sum(game.rewards))
    return sum(returns) / eval_episodes if eval_episodes else 0


def play_game(config: MuZeroConfig, network: AbstractNetwork, train: bool = True) -> AbstractGame:
    """
    Each game is produced by starting at the initial board position, then
    repeatedly executing a Monte Carlo Tree Search to generate moves until the end
    of the game is reached.
    """
    game = config.new_game()
    mode_action_select = 'softmax' if train else 'max'

    while not game.terminal() and len(game.history) < config.max_moves:
        # At the root of the search tree we use the representation function to
        # obtain a hidden state given the current observation.
        root = Node(0)
        current_observation = game.make_image(-1)
        expand_node(root, game.to_play(), game.legal_actions(), network.initial_inference(current_observation))
        if train:
            add_exploration_noise(config, root)

        # We then run a Monte Carlo Tree Search using only action sequences and the
        # model learned by the networks.
        run_mcts(config, root, game.action_history(), network)
        action = select_action(config, len(game.history), root, network, mode=mode_action_select)
        game.apply(action)
        game.store_search_statistics(root)
    return game
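The control flow of `play_game`, `run_selfplay` and `run_eval` can be sketched without MCTS by replacing the search with a hypothetical `choose_action` callback. `ToyGame`, `play_toy_game` and `average_return` are illustrative names, not part of the repository; the sketch only shows the stopping rule (terminal state or `max_moves`, whichever comes first) and the averaging of episode returns.

```python
import random
from typing import Callable, List


class ToyGame:
    """Hypothetical game: every move earns reward 1.0 until `horizon` moves are played."""

    def __init__(self, horizon: int):
        self.horizon = horizon
        self.history: List[int] = []
        self.rewards: List[float] = []

    def terminal(self) -> bool:
        return len(self.history) >= self.horizon

    def legal_actions(self) -> List[int]:
        return [0, 1]

    def apply(self, action: int) -> None:
        self.history.append(action)
        self.rewards.append(1.0)


def play_toy_game(choose_action: Callable[[List[int]], int], horizon: int, max_moves: int) -> ToyGame:
    game = ToyGame(horizon)
    # Same stopping rule as play_game: a terminal state or the move cap.
    while not game.terminal() and len(game.history) < max_moves:
        game.apply(choose_action(game.legal_actions()))
    return game


def average_return(episodes: int, choose_action: Callable[[List[int]], int],
                   horizon: int, max_moves: int) -> float:
    returns = [sum(play_toy_game(choose_action, horizon, max_moves).rewards) for _ in range(episodes)]
    # Guard against division by zero when no episodes are requested.
    return sum(returns) / episodes if episodes else 0.0


print(average_return(3, random.choice, horizon=10, max_moves=5))  # 5.0
print(average_return(0, random.choice, horizon=10, max_moves=5))  # 0.0
```

Here the move cap ends every episode after 5 moves, so each return is 5.0 regardless of which actions `random.choice` picks.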
================================================
FILE: muzero/self_play/utils.py
================================================
"""Helpers for the MCTS"""
from typing import Optional
import numpy as np
MAXIMUM_FLOAT_VALUE = float('inf')


class MinMaxStats(object):
    """A class that holds the min-max values of the tree."""

    def __init__(self, known_bounds):
        self.maximum = known_bounds.max if known_bounds else -MAXIMUM_FLOAT_VALUE
        self.minimum = known_bounds.min if known_bounds else MAXIMUM_FLOAT_VALUE

    def update(self, value: float):
        if value is None:
            raise ValueError("MinMaxStats cannot be updated with a None value")

        self.maximum = max(self.maximum, value)
        self.minimum = min(self.minimum, value)

    def normalize(self, value: float) -> float:
        # If the value is unknown, by default we set it to the minimum possible value
        if value is None:
            return 0.0

        if self.maximum > self.minimum:
            # We normalize only when we have set the maximum and minimum values.
            return (value - self.minimum) / (self.maximum - self.minimum)
        return value


class Node(object):
    """A class that represents nodes inside the MCTS tree."""

    def __init__(self, prior: float):
        self.visit_count = 0
        self.to_play = -1
        self.prior = prior
        self.value_sum = 0
        self.children = {}
        self.hidden_state = None
        self.reward = 0

    def expanded(self) -> bool:
        return len(self.children) > 0

    def value(self) -> Optional[float]:
        if self.visit_count == 0:
            return None
        return self.value_sum / self.visit_count
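

def softmax_sample(visit_counts, actions, t):
    # Apply the temperature inside the exponent: multiplying by 1 / t after np.exp
    # would cancel out in the normalization and leave t without effect.
    logits = np.asarray(visit_counts, dtype=np.float64) / t
    # Subtract the maximum for numerical stability (the probabilities are unchanged).
    counts_exp = np.exp(logits - np.max(logits))
    probs = counts_exp / np.sum(counts_exp)
    action_idx = np.random.choice(len(actions), p=probs)
    return actions[action_idx]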
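A temperature only reshapes a softmax if it divides the logits before exponentiation; multiplying the exponentiated values by `1 / t` is undone by the normalization. The sketch below compares the two forms on made-up visit counts. `visit_count_probs` and `scaled_after_exp` are hypothetical helper names, not repository functions.

```python
import numpy as np


def visit_count_probs(visit_counts, t: float) -> np.ndarray:
    """Softmax over visit counts with the temperature applied inside the exponent."""
    logits = np.asarray(visit_counts, dtype=np.float64) / t
    exps = np.exp(logits - logits.max())  # shift by the max for numerical stability
    return exps / exps.sum()


def scaled_after_exp(visit_counts, t: float) -> np.ndarray:
    """For comparison: multiplying by 1/t after np.exp cancels out when normalizing."""
    exps = np.exp(np.asarray(visit_counts, dtype=np.float64)) * (1 / t)
    return exps / exps.sum()


counts = [1, 2, 3]
print(visit_count_probs(counts, 0.1))    # nearly all mass on the most visited action
print(visit_count_probs(counts, 100.0))  # close to uniform
print(np.allclose(scaled_after_exp(counts, 0.1), scaled_after_exp(counts, 100.0)))  # True
```

A low temperature drives the sample towards the most visited action, while a high temperature approaches uniform exploration; the max shift also keeps large counts divided by a small `t` from overflowing to `inf`.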
================================================
FILE: muzero/training/__init__.py
================================================
================================================
FILE: muzero/training/replay_buffer.py
================================================
import random
from itertools import zip_longest
from typing import List
from config import MuZeroConfig
from game.game import AbstractGame


class ReplayBuffer(object):

    def __init__(self, config: MuZeroConfig):
        self.window_size = config.window_size
        self.batch_size = config.batch_size
        self.buffer = []

    def save_game(self, game):
        if len(self.buffer) > self.window_size:
            self.buffer.pop(0)
        self.buffer.append(game)

    def sample_batch(self, num_unroll_steps: int, td_steps: int):
        # Sample a batch of training data
        games = self.sample_games()
        game_pos = [(g, self.sample_position(g)) for g in games]
        game_data = [(g.make_image(i), g.history[i:i + num_unroll_steps],
                      g.make_target(i, num_unroll_steps, td_steps, g.to_play()))
                     for (g, i) in game_pos]

        # Pre-process the batch
        image_batch, actions_time_batch, targets_batch = zip(*game_data)
        targets_init_batch, *targets_time_batch = zip(*targets_batch)
        actions_time_batch = list(zip_longest(*actions_time_batch, fillvalue=None))

        # Build the batch of valid actions and a dynamic mask for the hidden representations during BPTT
        mask_time_batch = []
        dynamic_mask_time_batch = []
        last_mask = [True] * len(image_batch)
        for i, actions_batch in enumerate(actions_time_batch):
            mask = [action is not None for action in actions_batch]
            dynamic_mask = [now for last, now in zip(last_mask, mask) if last]
            mask_time_batch.append(mask)
            dynamic_mask_time_batch.append(dynamic_mask)
            last_mask = mask
            actions_time_batch[i] = [action.index for action in actions_batch if action is not None]

        batch = image_batch, targets_init_batch, targets_time_batch, actions_time_batch, mask_time_batch, dynamic_mask_time_batch
        return batch

    def sample_games(self) -> List[AbstractGame]:
        # Sample games from the buffer either uniformly or according to some priority.
        return random.choices(self.buffer, k=self.batch_size)

    def sample_position(self, game: AbstractGame) -> int:
        # Sample a position from the game either uniformly or according to some priority.
        return random.randint(0, len(game.history))
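The mask bookkeeping in `sample_batch` is easiest to see on plain integer actions. The sketch below reproduces it for three hypothetical action sequences of different lengths; `build_time_masks` is an illustrative name, and integers stand in for the repository's `Action` objects.

```python
from itertools import zip_longest
from typing import List, Tuple


def build_time_masks(action_seqs: List[List[int]]) -> Tuple[List[List[int]], List[List[bool]], List[List[bool]]]:
    """Transpose variable-length action sequences to time-major order and build the masks."""
    time_major = list(zip_longest(*action_seqs, fillvalue=None))
    actions, masks, dynamic_masks = [], [], []
    last_mask = [True] * len(action_seqs)
    for step in time_major:
        # `is not None` rather than bool(a): with plain int actions, index 0 would be falsy.
        mask = [a is not None for a in step]
        # Only rows still alive at the previous step remain in the hidden-state batch,
        # so the dynamic mask is the current mask restricted to those rows.
        dynamic_mask = [now for last, now in zip(last_mask, mask) if last]
        actions.append([a for a in step if a is not None])
        masks.append(mask)
        dynamic_masks.append(dynamic_mask)
        last_mask = mask
    return actions, masks, dynamic_masks


actions, masks, dynamic_masks = build_time_masks([[0, 1, 1], [1], [0, 0]])
print(actions)        # [[0, 1, 0], [1, 0], [1]]
print(dynamic_masks)  # [[True, True, True], [True, False, True], [True, False]]
```

The full mask has one entry per sampled game at every step and is applied to the targets, while the dynamic mask shrinks with the hidden-state batch; this works because `zip_longest` only pads at the end, so a sequence never resumes once it has stopped.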
================================================
FILE: muzero/training/training.py
================================================
"""Training module: this is where the MuZero networks are trained."""
import numpy as np
import tensorflow_core as tf
from tensorflow_core.python.keras.losses import MSE
from config import MuZeroConfig
from networks.network import BaseNetwork
from networks.shared_storage import SharedStorage
from training.replay_buffer import ReplayBuffer


def train_network(config: MuZeroConfig, storage: SharedStorage, replay_buffer: ReplayBuffer, epochs: int):
    network = storage.current_network
    optimizer = storage.optimizer

    for _ in range(epochs):
        batch = replay_buffer.sample_batch(config.num_unroll_steps, config.td_steps)
        update_weights(optimizer, network, batch)
        storage.save_network(network.training_steps, network)


def update_weights(optimizer: tf.keras.optimizers.Optimizer, network: BaseNetwork, batch):
    def scale_gradient(tensor, scale: float):
        """Scale the gradient flowing through `tensor` by `scale` without changing its forward value."""
        return (1. - scale) * tf.stop_gradient(tensor) + scale * tensor

    def loss():
        loss = 0
        image_batch, targets_init_batch, targets_time_batch, actions_time_batch, mask_time_batch, dynamic_mask_time_batch = batch

        # Initial step, from the real observation: representation + prediction networks
        representation_batch, value_batch, policy_batch = network.initial_model(np.array(image_batch))

        # Only update the elements with a policy target
        target_value_batch, _, target_policy_batch = zip(*targets_init_batch)
        mask_policy = list(map(lambda l: bool(l), target_policy_batch))
        target_policy_batch = list(filter(lambda l: bool(l), target_policy_batch))
        policy_batch = tf.boolean_mask(policy_batch, mask_policy)

        # Compute the loss of the first pass
        loss += tf.math.reduce_mean(loss_value(target_value_batch, value_batch, network.value_support_size))
        loss += tf.math.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=policy_batch, labels=target_policy_batch))

        # Recurrent steps, from action and previous hidden state.
        for actions_batch, targets_batch, mask, dynamic_mask in zip(actions_time_batch, targets_time_batch,
                                                                    mask_time_batch, dynamic_mask_time_batch):
            target_value_batch, target_reward_batch, target_policy_batch = zip(*targets_batch)

            # Only execute BPTT for elements with an action
            representation_batch = tf.boolean_mask(representation_batch, dynamic_mask)
            target_value_batch = tf.boolean_mask(target_value_batch, mask)
            target_reward_batch = tf.boolean_mask(target_reward_batch, mask)
            # Create the conditioned representation: concatenate the representations with the one-hot actions
            actions_batch = tf.one_hot(actions_batch, network.action_size)

            # Recurrent step from conditioned representation: recurrent + prediction networks
            conditioned_representation_batch = tf.concat((representation_batch, actions_batch), axis=1)
            representation_batch, reward_batch, value_batch, policy_batch = network.recurrent_model(
                conditioned_representation_batch)

            # Only execute BPTT for elements with a policy target
            target_policy_batch = [policy for policy, b in zip(target_policy_batch, mask) if b]
            mask_policy = list(map(lambda l: bool(l), target_policy_batch))
            target_policy_batch = tf.convert_to_tensor([policy for policy in target_policy_batch if policy])
            policy_batch = tf.boolean_mask(policy_batch, mask_policy)

            # Compute the partial loss
            l = (tf.math.reduce_mean(loss_value(target_value_batch, value_batch, network.value_support_size)) +
                 MSE(target_reward_batch, tf.squeeze(reward_batch)) +
                 tf.math.reduce_mean(
                     tf.nn.softmax_cross_entropy_with_logits(logits=policy_batch, labels=target_policy_batch)))

            # Scale the gradient of the partial loss by 1 / (number of unrolled steps)
            gradient_scale = 1. / len(actions_time_batch)
            loss += scale_gradient(l, gradient_scale)

            # Halve the gradient of the representation
            representation_batch = scale_gradient(representation_batch, 0.5)

        return loss

    optimizer.minimize(loss=loss, var_list=network.cb_get_variables())
    network.training_steps += 1


def loss_value(target_value_batch, value_batch, value_support_size: int):
    # Encode each (non-negative) target value as a two-hot distribution over the support:
    # the mass is split between floor(sqrt(v)) and floor(sqrt(v)) + 1.
    batch_size = len(target_value_batch)
    targets = np.zeros((batch_size, value_support_size))
    sqrt_value = np.sqrt(target_value_batch)
    floor_value = np.floor(sqrt_value).astype(int)
    rest = sqrt_value - floor_value
    targets[range(batch_size), floor_value] = 1 - rest
    targets[range(batch_size), floor_value + 1] = rest

    return tf.nn.softmax_cross_entropy_with_logits(logits=value_batch, labels=targets)
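`loss_value` turns each scalar target into a categorical target over `value_support_size` buckets indexed by √v, splitting the mass between ⌊√v⌋ and ⌊√v⌋ + 1 so that the expected bucket index equals √v. The sketch below reproduces that encoding in NumPy and adds a decoding step for checking. `encode_value_support` and `decode_value_support` are hypothetical names, and the decoder is an assumption, not taken from `CartPoleNetwork._value_transform`. Unlike the repository code, it clips the upper bucket so that √v equal to the last bucket index stays in range, and it rejects values that do not fit on the support.

```python
import numpy as np


def encode_value_support(values, support_size: int) -> np.ndarray:
    """Two-hot encoding of non-negative values on a square-root support, as in loss_value."""
    sqrt_value = np.sqrt(np.asarray(values, dtype=np.float64))
    if np.any(sqrt_value > support_size - 1):
        raise ValueError("sqrt(value) must not exceed support_size - 1")
    floor_value = np.floor(sqrt_value).astype(int)
    rest = sqrt_value - floor_value
    rows = np.arange(len(sqrt_value))
    # Clip the upper bucket; when it is clipped, rest is 0 so no mass is lost.
    upper = np.minimum(floor_value + 1, support_size - 1)
    targets = np.zeros((len(sqrt_value), support_size))
    targets[rows, floor_value] += 1 - rest
    targets[rows, upper] += rest
    return targets


def decode_value_support(probs: np.ndarray) -> np.ndarray:
    """Hypothetical inverse: expected bucket index (i.e. sqrt-value), squared."""
    expected_sqrt = probs @ np.arange(probs.shape[1])
    return expected_sqrt ** 2


targets = encode_value_support([2.25, 9.0, 16.0], support_size=5)
print(targets[0])                     # mass 0.5 on buckets 1 and 2
print(decode_value_support(targets))  # recovers 2.25, 9.0 and 16.0
```

Because the encoding is linear in √v, the cross-entropy target keeps the exact scalar recoverable, while compressing large returns into a small number of buckets.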
SYMBOL INDEX (106 symbols across 13 files)
FILE: muzero/config.py
class MuZeroConfig (line 14) | class MuZeroConfig(object):
method __init__ (line 16) | def __init__(self,
method new_game (line 82) | def new_game(self) -> AbstractGame:
method new_network (line 85) | def new_network(self) -> BaseNetwork:
method uniform_network (line 88) | def uniform_network(self) -> UniformNetwork:
method new_optimizer (line 91) | def new_optimizer(self) -> tf.keras.optimizers:
function make_cartpole_config (line 95) | def make_cartpole_config() -> MuZeroConfig:
FILE: muzero/game/cartpole.py
class CartPole (line 9) | class CartPole(AbstractGame):
method __init__ (line 12) | def __init__(self, discount: float):
method action_space_size (line 21) | def action_space_size(self) -> int:
method step (line 25) | def step(self, action) -> int:
method terminal (line 33) | def terminal(self) -> bool:
method legal_actions (line 37) | def legal_actions(self) -> List[Action]:
method make_image (line 41) | def make_image(self, state_index: int):
FILE: muzero/game/game.py
class Action (line 7) | class Action(object):
method __init__ (line 10) | def __init__(self, index: int):
method __hash__ (line 13) | def __hash__(self):
method __eq__ (line 16) | def __eq__(self, other):
method __gt__ (line 19) | def __gt__(self, other):
class Player (line 23) | class Player(object):
method __eq__ (line 29) | def __eq__(self, other):
class ActionHistory (line 33) | class ActionHistory(object):
method __init__ (line 39) | def __init__(self, history: List[Action], action_space_size: int):
method clone (line 43) | def clone(self):
method add_action (line 46) | def add_action(self, action: Action):
method last_action (line 49) | def last_action(self) -> Action:
method action_space (line 52) | def action_space(self) -> List[Action]:
method to_play (line 55) | def to_play(self) -> Player:
class AbstractGame (line 59) | class AbstractGame(ABC):
method __init__ (line 65) | def __init__(self, discount: float):
method apply (line 72) | def apply(self, action: Action):
method store_search_statistics (line 79) | def store_search_statistics(self, root: Node):
method make_target (line 90) | def make_target(self, state_index: int, num_unroll_steps: int, td_step...
method to_play (line 113) | def to_play(self) -> Player:
method action_history (line 117) | def action_history(self) -> ActionHistory:
method action_space_size (line 124) | def action_space_size(self) -> int:
method step (line 129) | def step(self, action) -> int:
method terminal (line 134) | def terminal(self) -> bool:
method legal_actions (line 139) | def legal_actions(self) -> List[Action]:
method make_image (line 144) | def make_image(self, state_index: int):
FILE: muzero/game/gym_wrappers.py
class ScalingObservationWrapper (line 5) | class ScalingObservationWrapper(gym.ObservationWrapper):
method __init__ (line 10) | def __init__(self, env, low=None, high=None):
method observation (line 20) | def observation(self, observation):
FILE: muzero/muzero.py
function muzero (line 8) | def muzero(config: MuZeroConfig):
FILE: muzero/networks/cartpole_network.py
class CartPoleNetwork (line 12) | class CartPoleNetwork(BaseNetwork):
method __init__ (line 14) | def __init__(self,
method _value_transform (line 42) | def _value_transform(self, value_support: np.array) -> float:
method _reward_transform (line 53) | def _reward_transform(self, reward: np.array) -> float:
method _conditioned_hidden_state (line 56) | def _conditioned_hidden_state(self, hidden_state: np.array, action: Ac...
method _softmax (line 60) | def _softmax(self, values):
FILE: muzero/networks/network.py
class NetworkOutput (line 11) | class NetworkOutput(typing.NamedTuple):
method build_policy_logits (line 18) | def build_policy_logits(policy_logits):
class AbstractNetwork (line 22) | class AbstractNetwork(ABC):
method __init__ (line 24) | def __init__(self):
method initial_inference (line 28) | def initial_inference(self, image) -> NetworkOutput:
method recurrent_inference (line 32) | def recurrent_inference(self, hidden_state, action) -> NetworkOutput:
class UniformNetwork (line 36) | class UniformNetwork(AbstractNetwork):
method __init__ (line 39) | def __init__(self, action_size: int):
method initial_inference (line 43) | def initial_inference(self, image) -> NetworkOutput:
method recurrent_inference (line 46) | def recurrent_inference(self, hidden_state, action) -> NetworkOutput:
class InitialModel (line 50) | class InitialModel(Model):
method __init__ (line 53) | def __init__(self, representation_network: Model, value_network: Model...
method call (line 59) | def call(self, image):
class RecurrentModel (line 66) | class RecurrentModel(Model):
method __init__ (line 69) | def __init__(self, dynamic_network: Model, reward_network: Model, valu...
method call (line 76) | def call(self, conditioned_hidden):
class BaseNetwork (line 84) | class BaseNetwork(AbstractNetwork):
method __init__ (line 87) | def __init__(self, representation_network: Model, value_network: Model...
method initial_inference (line 102) | def initial_inference(self, image: np.array) -> NetworkOutput:
method recurrent_inference (line 112) | def recurrent_inference(self, hidden_state: np.array, action: Action) ...
method _value_transform (line 124) | def _value_transform(self, value: np.array) -> float:
method _reward_transform (line 128) | def _reward_transform(self, reward: np.array) -> float:
method _conditioned_hidden_state (line 132) | def _conditioned_hidden_state(self, hidden_state: np.array, action: Ac...
method cb_get_variables (line 135) | def cb_get_variables(self) -> Callable:
FILE: muzero/networks/shared_storage.py
class SharedStorage (line 6) | class SharedStorage(object):
method __init__ (line 9) | def __init__(self, network: BaseNetwork, uniform_network: UniformNetwo...
method latest_network (line 15) | def latest_network(self) -> AbstractNetwork:
method save_network (line 22) | def save_network(self, step: int, network: BaseNetwork):
FILE: muzero/self_play/mcts.py
function add_exploration_noise (line 15) | def add_exploration_noise(config: MuZeroConfig, node: Node):
function run_mcts (line 27) | def run_mcts(config: MuZeroConfig, root: Node, action_history: ActionHis...
function select_child (line 55) | def select_child(config: MuZeroConfig, node: Node, min_max_stats: MinMax...
function ucb_score (line 69) | def ucb_score(config: MuZeroConfig, parent: Node, child: Node,
function expand_node (line 83) | def expand_node(node: Node, to_play: Player, actions: List[Action],
function backpropagate (line 98) | def backpropagate(search_path: List[Node], value: float, to_play: Player,
function select_action (line 112) | def select_action(config: MuZeroConfig, num_moves: int, node: Node, netw...
FILE: muzero/self_play/self_play.py
function run_selfplay (line 12) | def run_selfplay(config: MuZeroConfig, storage: SharedStorage, replay_bu...
function run_eval (line 23) | def run_eval(config: MuZeroConfig, storage: SharedStorage, eval_episodes...
function play_game (line 33) | def play_game(config: MuZeroConfig, network: AbstractNetwork, train: boo...
FILE: muzero/self_play/utils.py
class MinMaxStats (line 9) | class MinMaxStats(object):
method __init__ (line 12) | def __init__(self, known_bounds):
method update (line 16) | def update(self, value: float):
method normalize (line 23) | def normalize(self, value: float) -> float:
class Node (line 34) | class Node(object):
method __init__ (line 37) | def __init__(self, prior: float):
method expanded (line 46) | def expanded(self) -> bool:
method value (line 49) | def value(self) -> Optional[float]:
function softmax_sample (line 55) | def softmax_sample(visit_counts, actions, t):
FILE: muzero/training/replay_buffer.py
class ReplayBuffer (line 9) | class ReplayBuffer(object):
method __init__ (line 11) | def __init__(self, config: MuZeroConfig):
method save_game (line 16) | def save_game(self, game):
method sample_batch (line 21) | def sample_batch(self, num_unroll_steps: int, td_steps: int):
method sample_games (line 49) | def sample_games(self) -> List[AbstractGame]:
method sample_position (line 53) | def sample_position(self, game: AbstractGame) -> int:
FILE: muzero/training/training.py
function train_network (line 13) | def train_network(config: MuZeroConfig, storage: SharedStorage, replay_b...
function update_weights (line 23) | def update_weights(optimizer: tf.keras.optimizers, network: BaseNetwork,...
function loss_value (line 88) | def loss_value(target_value_batch, value_batch, value_support_size: int):
About this extraction
This page contains the full source code of the johan-gras/MuZero GitHub repository, extracted and formatted as plain text. The extraction includes 20 files (44.4 KB), approximately 10.4k tokens, and a symbol index with 106 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.