SYMBOL INDEX (3614 symbols across 359 files) FILE: adversarial_robustness/jax/attacks.py function untargeted_cross_entropy (line 104) | def untargeted_cross_entropy(logits: chex.Array, function untargeted_kl_divergence (line 113) | def untargeted_kl_divergence(logits: chex.Array, function untargeted_margin (line 123) | def untargeted_margin(logits: chex.Array, class UntargetedAttack (line 134) | class UntargetedAttack: method __init__ (line 137) | def __init__(self, method __call__ (line 153) | def __call__(self, method expects_labels (line 164) | def expects_labels(self): method expects_probabilities (line 167) | def expects_probabilities(self): class StepOptimizer (line 171) | class StepOptimizer: method __init__ (line 174) | def __init__(self, method init (line 178) | def init(self, method minimize (line 184) | def minimize( class SGD (line 196) | class SGD(StepOptimizer): method __init__ (line 199) | def __init__(self, class IteratedFGSM (line 218) | class IteratedFGSM(SGD): method __init__ (line 221) | def __init__(self, class Adam (line 226) | class Adam(StepOptimizer): method __init__ (line 229) | def __init__( class PGD (line 253) | class PGD: method __init__ (line 256) | def __init__(self, method __call__ (line 270) | def __call__(self, function linf_project_fn (line 289) | def linf_project_fn(epsilon: float, bounds: Tuple[float, float]) -> Proj... function linf_initialize_fn (line 296) | def linf_initialize_fn(epsilon: float) -> InitializeFn: function gradients_fn (line 303) | def gradients_fn(loss_fn: LossFn, FILE: adversarial_robustness/jax/datasets.py function cifar10_preprocess (line 38) | def cifar10_preprocess(mode: str = 'train'): function cifar10_normalize (line 59) | def cifar10_normalize(image: chex.Array) -> chex.Array: function mnist_normalize (line 65) | def mnist_normalize(image: chex.Array) -> chex.Array: function cifar100_normalize (line 71) | def cifar100_normalize(image: chex.Array) -> chex.Array: function load_cifar10 (line 77) | def load_cifar10(batch_sizes: Sequence[int], function load_extra (line 105) | def load_extra(batch_sizes: Sequence[int], function load_dummy_data (line 132) | def load_dummy_data(batch_sizes: Sequence[int], function _random_jitter (line 151) | def _random_jitter(image: tf.Tensor, pad: int, crop: int) -> tf.Tensor: function _repeat_batch (line 158) | def _repeat_batch(batch_sizes: Sequence[int], FILE: adversarial_robustness/jax/eval.py function main (line 48) | def main(unused_argv): FILE: adversarial_robustness/jax/experiment.py function get_config (line 44) | def get_config(): class Experiment (line 153) | class Experiment(experiment.AbstractExperiment): method __init__ (line 163) | def __init__(self, mode, config, init_rng): method step (line 195) | def step(self, global_step, rng, *unused_args, **unused_kwargs): method _train_fn (line 244) | def _train_fn(self, params, avg_params, state, opt_state, global_step, method _cross_entropy_loss_fn (line 300) | def _cross_entropy_loss_fn(self, params, state, images, adv_images, la... method _trades_loss_fn (line 314) | def _trades_loss_fn(self, params, state, images, adv_images, labels, method evaluate (line 350) | def evaluate(self, global_step, rng, *unused_args, **unused_kwargs): method eval_epoch (line 359) | def eval_epoch(self, params, state, rng): method _eval_fn (line 380) | def _eval_fn(self, params, state, inputs, rng): method _initialize_training (line 409) | def _initialize_training(self, rng): method _initialize_evaluation (line 461) | def _initialize_evaluation(self): method _supervised_train_dataset (line 470) | def _supervised_train_dataset(self) -> tfds.typing.Tree[np.ndarray]: method _extra_train_dataset (line 481) | def _extra_train_dataset(self) -> tfds.typing.Tree[np.ndarray]: method _get_model (line 494) | def _get_model(self) -> Callable[..., chex.Array]: method concatenate (line 501) | def concatenate( function _dataset (line 530) | def _dataset(load_fn, function _merge_eval_scalars (line 576) | def _merge_eval_scalars(a, b): FILE: adversarial_robustness/jax/experiment_test.py function test_experiment (line 25) | def test_experiment(unused_argv): FILE: adversarial_robustness/jax/model_zoo.py class _WideResNetBlock (line 25) | class _WideResNetBlock(hk.Module): method __init__ (line 28) | def __init__(self, num_filters, stride=1, projection_shortcut=False, method __call__ (line 63) | def __call__(self, inputs, **norm_kwargs): class WideResNet (line 80) | class WideResNet(hk.Module): method __init__ (line 83) | def __init__(self, method __call__ (line 131) | def __call__(self, inputs: chex.Array, **norm_kwargs) -> chex.Array: FILE: adversarial_robustness/jax/utils.py function get_cosine_schedule (line 28) | def get_cosine_schedule( function get_step_schedule (line 44) | def get_step_schedule( function sgd_momentum (line 61) | def sgd_momentum(learning_rate_fn: optax.Schedule, function cross_entropy (line 70) | def cross_entropy(logits: chex.Array, labels: chex.Array) -> chex.Array: function kl_divergence (line 74) | def kl_divergence(q_logits: chex.Array, function accuracy (line 81) | def accuracy(logits: chex.Array, labels: chex.Array) -> chex.Array: function weight_decay (line 87) | def weight_decay(params: hk.Params, function ema_update (line 109) | def ema_update(step: chex.Array, function cutmix (line 130) | def cutmix(rng: chex.PRNGKey, function _random_box (line 170) | def _random_box(rng: chex.PRNGKey, function _compose_two_images (line 188) | def _compose_two_images(images: chex.Array, function _window_mask (line 199) | def _window_mask(destination_box: chex.Array, FILE: adversarial_robustness/pytorch/eval.py function main (line 45) | def main(unused_argv): FILE: adversarial_robustness/pytorch/model_zoo.py class _Swish (line 30) | class _Swish(torch.autograd.Function): method forward (line 34) | def forward(ctx, i): method backward (line 40) | def backward(ctx, grad_output): class Swish (line 46) | class Swish(nn.Module): method forward (line 49) | def forward(self, input_tensor): class _Block (line 53) | class _Block(nn.Module): method __init__ (line 56) | def __init__(self, in_planes, out_planes, stride, activation_fn=nn.ReLU): method forward (line 76) | def forward(self, x): class _BlockGroup (line 95) | class _BlockGroup(nn.Module): method __init__ (line 98) | def __init__(self, num_blocks, in_planes, out_planes, stride, method forward (line 110) | def forward(self, x): class WideResNet (line 114) | class WideResNet(nn.Module): method __init__ (line 117) | def __init__(self, method forward (line 149) | def forward(self, x): class _PreActBlock (line 167) | class _PreActBlock(nn.Module): method __init__ (line 170) | def __init__(self, in_planes, out_planes, stride, activation_fn=nn.ReLU): method _pad (line 188) | def _pad(self, x): method forward (line 197) | def forward(self, x): class PreActResNet (line 205) | class PreActResNet(nn.Module): method __init__ (line 208) | def __init__(self, method _make_layer (line 241) | def _make_layer(self, in_planes, out_planes, num_blocks, stride, method forward (line 252) | def forward(self, x): FILE: alphafold_casp13/asa_output.py class ASAOutputLayer (line 19) | class ASAOutputLayer(object): method __init__ (line 22) | def __init__(self, name='asa'): method compute_asa_output (line 25) | def compute_asa_output(self, activations): FILE: alphafold_casp13/config_dict.py class ConfigDict (line 19) | class ConfigDict(dict): method __init__ (line 22) | def __init__(self, *args, **kwargs): method _add (line 31) | def _add(self, key, value): method __getattr__ (line 37) | def __getattr__(self, attr): method __setattr__ (line 43) | def __setattr__(self, key, value): method __setitem__ (line 46) | def __setitem__(self, key, value): method __delattr__ (line 50) | def __delattr__(self, item): method __delitem__ (line 53) | def __delitem__(self, key): method to_json (line 57) | def to_json(self): method from_json (line 61) | def from_json(cls, json_string): FILE: alphafold_casp13/contacts.py function evaluate (line 60) | def evaluate(crop_size_x, crop_size_y, feature_normalization, checkpoint... function _run_evaluation (line 93) | def _run_evaluation( function compute_one_prediction (line 172) | def compute_one_prediction( function compute_one_patch (line 291) | def compute_one_patch(sess, experiment, output_fetches, inputs_1d, function main (line 373) | def main(argv): FILE: alphafold_casp13/contacts_dataset.py class FeatureType (line 29) | class FeatureType(enum.Enum): function shape (line 83) | def shape(feature_name, num_residues, features=None): function dim (line 107) | def dim(feature_name): function _concat_or_zeros (line 136) | def _concat_or_zeros(tensor_list, axis, tensor_shape, name): function parse_tfexample (line 143) | def parse_tfexample(raw_data, features): function create_tf_dataset (line 179) | def create_tf_dataset(tf_record_filename, features): function normalize_from_stats_file (line 201) | def normalize_from_stats_file( function convert_to_legacy_proteins_dataset_format (line 256) | def convert_to_legacy_proteins_dataset_format( FILE: alphafold_casp13/contacts_experiment.py function _int_ph (line 23) | def _int_ph(shape, name): function _float_ph (line 28) | def _float_ph(shape, name): class Contacts (line 33) | class Contacts(object): method __init__ (line 36) | def __init__( method model (line 70) | def model(self): method _get_feature_normalization (line 73) | def _get_feature_normalization(self, features): method _build_evaluation_graph (line 78) | def _build_evaluation_graph(self, tfrecord, stats_file): method get_one_example (line 198) | def get_one_example(self, sess): FILE: alphafold_casp13/contacts_network.py function call_on_tuple (line 26) | def call_on_tuple(f): class ContactsNet (line 41) | class ContactsNet(sonnet.AbstractModule): method __init__ (line 44) | def __init__(self, method quant_threshold (line 106) | def quant_threshold(self, threshold=8.0): method _build (line 118) | def _build(self, crop_size_x=0, crop_size_y=0, placeholders=None): method compute_outputs (line 150) | def compute_outputs(self, inputs_1d, residue_index, inputs_2d, crop_x,... method _concatenate_2d (line 222) | def _concatenate_2d(hidden_1d, residue_index, hidden_2d, crop_x, crop_y, method _build_2d_embedding (line 308) | def _build_2d_embedding(self, hidden_1d, residue_index, inputs_2d, method _output_from_pre_logits (line 386) | def _output_from_pre_logits(self, contact_pre_logits, features_forward, method update_crop_fetches (line 431) | def update_crop_fetches(self, fetches): function build_crops_biases (line 441) | def build_crops_biases(bias_size, raw_biases, crop_x, crop_y, back_prop): FILE: alphafold_casp13/distogram_io.py function save_rr_file (line 36) | def save_rr_file(filename, probs, domain, sequence, function save_torsions (line 50) | def save_torsions(torsions_dir, filebase, sequence, torsions_probs): function save_distance_histogram (line 58) | def save_distance_histogram( function save_distance_histogram_from_dict (line 71) | def save_distance_histogram_from_dict(filename, dh_dict): function contact_map_from_distogram (line 86) | def contact_map_from_distogram(distogram_dict): FILE: alphafold_casp13/ensemble_contact_maps.py function ensemble_distance_histograms (line 42) | def ensemble_distance_histograms(pickle_dirs, weights, output_dir): function ensemble_one_distance_histogram (line 75) | def ensemble_one_distance_histogram(pickle_files, weights): function main (line 108) | def main(argv): FILE: alphafold_casp13/parsers.py function distance_histogram_dict (line 21) | def distance_histogram_dict(f): function parse_distance_histogram_dict (line 63) | def parse_distance_histogram_dict(filepath): FILE: alphafold_casp13/paste_contact_maps.py function generate_domains (line 39) | def generate_domains(target, sequence, crop_sizes, crop_step): function get_weights (line 60) | def get_weights(path): function paste_distance_histograms (line 80) | def paste_distance_histograms( function main (line 189) | def main(argv): FILE: alphafold_casp13/secstruct.py function make_q3_matrices (line 31) | def make_q3_matrices(): class Secstruct (line 44) | class Secstruct(object): method __init__ (line 50) | def __init__(self, name='secstruct'): method make_layer_new (line 54) | def make_layer_new(self, activations): method get_q8_probs (line 63) | def get_q8_probs(self): function save_secstructs (line 67) | def save_secstructs(dump_dir_path, name, index, sequence, probs, FILE: alphafold_casp13/two_dim_convnet.py function weight_variable (line 20) | def weight_variable(shape, stddev=0.01): function bias_variable (line 27) | def bias_variable(shape): function conv2d (line 32) | def conv2d(x, w, atrou_rate=1, data_format='NHWC'): function make_conv_sep2d_layer (line 45) | def make_conv_sep2d_layer(input_node, function batch_norm_layer (line 90) | def batch_norm_layer(h_conv, layer_name, is_training=True, data_format='... function make_conv_layer (line 102) | def make_conv_layer(input_node, FILE: alphafold_casp13/two_dim_resnet.py function make_sep_res_layer (line 22) | def make_sep_res_layer( function make_two_dim_resnet (line 119) | def make_two_dim_resnet( FILE: avae/checkpointer.py class Checkpointer (line 27) | class Checkpointer: method __init__ (line 30) | def __init__(self, checkpoint_dir: str, filename: str): method save_checkpoint (line 43) | def save_checkpoint( method load_checkpoint (line 71) | def load_checkpoint( FILE: avae/data_iterators.py class Dataset (line 28) | class Dataset(enum.Enum): class MnistDataIterator (line 32) | class MnistDataIterator(object): method __init__ (line 38) | def __init__(self, subset: str, batch_size: int): method __iter__ (line 56) | def __iter__(self): method __next__ (line 59) | def __next__(self) -> types.LabelledData: class ColorMnistDataIterator (line 63) | class ColorMnistDataIterator(MnistDataIterator): method __next__ (line 73) | def __next__(self) -> types.LabelledData: FILE: avae/decoders.py class DecoderBase (line 26) | class DecoderBase(hk.Module): method __init__ (line 29) | def __init__(self, obs_var: float): method __call__ (line 39) | def __call__(self, z: jnp.ndarray) -> jnp.ndarray: method data_fidelity (line 48) | def data_fidelity( class ColorMnistMLPDecoder (line 67) | class ColorMnistMLPDecoder(DecoderBase): method __call__ (line 73) | def __call__(self, z: jnp.ndarray) -> jnp.ndarray: FILE: avae/encoders.py class EncoderBase (line 29) | class EncoderBase(hk.Module, Generic[_Params]): method __init__ (line 32) | def __init__(self, latent_dim: int): method __call__ (line 42) | def __call__(self, input_data: jnp.ndarray) -> _Params: method sample (line 53) | def sample(self, posterior: _Params, key: jnp.ndarray) -> jnp.ndarray: class ColorMnistMLPEncoder (line 66) | class ColorMnistMLPEncoder(EncoderBase[types.NormalParams]): method __call__ (line 71) | def __call__( method sample (line 88) | def sample( function _normal_params_from_logits (line 107) | def _normal_params_from_logits( FILE: avae/kl.py function kl_p_with_uniform_normal (line 21) | def kl_p_with_uniform_normal(mean: jnp.ndarray, FILE: avae/train.py function train (line 29) | def train( FILE: avae/train_main.py class Model (line 30) | class Model(enum.Enum): class EncoderArch (line 35) | class EncoderArch(enum.Enum): class DecoderArch (line 39) | class DecoderArch(enum.Enum): function main (line 80) | def main(_): FILE: avae/types.py class ELBOOutputs (line 26) | class ELBOOutputs: class LabelledData (line 33) | class LabelledData: class NormalParams (line 45) | class NormalParams: FILE: avae/vae.py class VAE (line 29) | class VAE: method __init__ (line 36) | def __init__(self, encoder: encoders.EncoderBase, method vae_elbo (line 49) | def vae_elbo( method avae_elbo (line 70) | def avae_elbo( method __call__ (line 117) | def __call__( FILE: box_arrangement/dmlab_assets.py class SkyBox (line 22) | class SkyBox(composer.Entity): method _build (line 25) | def _build(self, style): method mjcf_model (line 35) | def mjcf_model(self): method texture (line 39) | def texture(self): class WallTextures (line 43) | class WallTextures(composer.Entity): method _build (line 46) | def _build(self, style): method mjcf_model (line 56) | def mjcf_model(self): method textures (line 60) | def textures(self): class FloorTextures (line 64) | class FloorTextures(composer.Entity): method _build (line 67) | def _build(self, style): method mjcf_model (line 77) | def mjcf_model(self): method textures (line 81) | def textures(self): FILE: box_arrangement/explore.py function main (line 37) | def main(unused_argv): FILE: box_arrangement/predicate_task.py function _generate_target_permutation (line 42) | def _generate_target_permutation(num_targets, random_state): class PredicateTask (line 48) | class PredicateTask(composer.Task): method __init__ (line 51) | def __init__(self, method _create_per_walker_observables (line 155) | def _create_per_walker_observables(self, walker): method observables (line 222) | def observables(self): method name (line 226) | def name(self): method root_entity (line 230) | def root_entity(self): method _regenerate_positions (line 233) | def _regenerate_positions(self): method initialize_episode_mjcf (line 259) | def initialize_episode_mjcf(self, random_state): method _set_active_predicates (line 268) | def _set_active_predicates(self, random_state): method _choose_random_predicates (line 291) | def _choose_random_predicates(self, random_state, num_predicates): method _set_random_colors (line 305) | def _set_random_colors(self, random_state): method initialize_episode (line 342) | def initialize_episode(self, physics, random_state): method before_step (line 395) | def before_step(self, physics, actions, random_state): method after_step (line 406) | def after_step(self, physics, random_state): method get_reward (line 412) | def get_reward(self, physics): method _all_predicates_satisfied (line 425) | def _all_predicates_satisfied(self): method should_terminate_episode (line 428) | def should_terminate_episode(self, physics): method get_discount (line 432) | def get_discount(self, physics): method get_reward_spec (line 437) | def get_reward_spec(self): method get_discount_spec (line 440) | def get_discount_spec(self): FILE: box_arrangement/predicate_task_test.py class PredicateTaskTest (line 40) | class PredicateTaskTest(absltest.TestCase): method _setup_basic_gtt_task (line 42) | def _setup_basic_gtt_task(self, num_targets=1, reward_scale=1.0): method test_observables (line 70) | def test_observables(self): method test_termination_and_discount (line 78) | def test_termination_and_discount(self): method test_reward_scaling (line 114) | def test_reward_scaling(self): method test_too_few_predicates_raises_exception (line 137) | def test_too_few_predicates_raises_exception(self): method test_error_too_few_targets (line 167) | def test_error_too_few_targets(self): method test_error_if_no_predicates_found (line 200) | def test_error_if_no_predicates_found(self): FILE: box_arrangement/predicates.py class BasePredicate (line 34) | class BasePredicate(object, metaclass=abc.ABCMeta): method __init__ (line 37) | def __init__(self, walker): method reinitialize (line 41) | def reinitialize(self, random_state): method activate_predicate (line 58) | def activate_predicate(self): method objects_in_use (line 69) | def objects_in_use(self): method observation_value (line 74) | def observation_value(self): method is_active (line 79) | def is_active(self, physics): method inactive_observation_value (line 91) | def inactive_observation_value(self): class MoveWalkerToTarget (line 106) | class MoveWalkerToTarget(BasePredicate): method __init__ (line 109) | def __init__(self, walker, target, target_index=0): method reinitialize (line 122) | def reinitialize(self, random_state): method activate_predicate (line 125) | def activate_predicate(self): method objects_in_use (line 130) | def objects_in_use(self): method observation_value (line 134) | def observation_value(self): method is_active (line 140) | def is_active(self, physics): class MoveWalkerToRandomTarget (line 144) | class MoveWalkerToRandomTarget(BasePredicate): method __init__ (line 147) | def __init__(self, walker, targets=None): method reinitialize (line 159) | def reinitialize(self, random_state): method activate_predicate (line 165) | def activate_predicate(self): method objects_in_use (line 171) | def objects_in_use(self): method observation_value (line 175) | def observation_value(self): method is_active (line 181) | def is_active(self, physics): class MoveWalkerToBox (line 185) | class MoveWalkerToBox(BasePredicate): method __init__ (line 188) | def __init__(self, walker, box, box_index=0, detection_region=None): method reinitialize (line 206) | def reinitialize(self, random_state): method activate_predicate (line 211) | def activate_predicate(self): method objects_in_use (line 215) | def objects_in_use(self): method observation_value (line 219) | def observation_value(self): method is_active (line 225) | def is_active(self, physics): method _is_walker_contacting_box (line 234) | def _is_walker_contacting_box(self, physics): class MoveBoxToBox (line 246) | class MoveBoxToBox(BasePredicate): method __init__ (line 249) | def __init__(self, method reinitialize (line 278) | def reinitialize(self, random_state): method activate_predicate (line 283) | def activate_predicate(self): method objects_in_use (line 287) | def objects_in_use(self): method observation_value (line 291) | def observation_value(self): method is_active (line 297) | def is_active(self, physics): method _are_boxes_in_contact (line 307) | def _are_boxes_in_contact(self, physics): class MoveBoxToTarget (line 316) | class MoveBoxToTarget(BasePredicate): method __init__ (line 319) | def __init__(self, walker, box, target, box_index=0, target_index=0): method reinitialize (line 340) | def reinitialize(self, random_state): method _get_box_properties (line 344) | def _get_box_properties(self, random_state): method activate_predicate (line 351) | def activate_predicate(self): method objects_in_use (line 357) | def objects_in_use(self): method observation_value (line 361) | def observation_value(self): method is_active (line 367) | def is_active(self, physics): class MoveBoxToRandomTarget (line 371) | class MoveBoxToRandomTarget(BasePredicate): method __init__ (line 374) | def __init__(self, walker, box, box_index=0, targets=None): method reinitialize (line 394) | def reinitialize(self, random_state): method _get_box_properties (line 401) | def _get_box_properties(self, random_state): method activate_predicate (line 408) | def activate_predicate(self): method objects_in_use (line 414) | def objects_in_use(self): method observation_value (line 418) | def observation_value(self): method is_active (line 425) | def is_active(self, physics): FILE: box_arrangement/task_examples.py function _make_predicate_task (line 33) | def _make_predicate_task(n_boxes, n_targets, function go_to_k_targets (line 96) | def go_to_k_targets(n_targets=3, function move_box (line 110) | def move_box(n_targets=3, function move_box_or_gtt (line 124) | def move_box_or_gtt(n_targets=3, function move_box_and_gtt (line 138) | def move_box_and_gtt(n_targets=3, FILE: byol/byol_experiment.py class _ByolExperimentState (line 45) | class _ByolExperimentState(NamedTuple): class ByolExperiment (line 54) | class ByolExperiment: method __init__ (line 57) | def __init__( method _forward (line 116) | def _forward( method _optimizer (line 189) | def _optimizer(self, learning_rate: float) -> optax.GradientTransforma... method loss_fn (line 197) | def loss_fn( method _should_transpose_images (line 288) | def _should_transpose_images(self): method _update_fn (line 293) | def _update_fn( method _make_initial_state (line 353) | def _make_initial_state( method step (line 395) | def step(self, *, method save_checkpoint (line 413) | def save_checkpoint(self, step: int, rng: jnp.ndarray): method load_checkpoint (line 417) | def load_checkpoint(self) -> Union[Tuple[int, jnp.ndarray], None]: method _initialize_train (line 424) | def _initialize_train(self): method _build_train_input (line 447) | def _build_train_input(self) -> Generator[dataset.Batch, None, None]: method _eval_batch (line 464) | def _eval_batch( method _eval_epoch (line 497) | def _eval_epoch(self, subset: Text, batch_size: int): method evaluate (line 526) | def evaluate(self, global_step, **unused_args): FILE: byol/configs/byol.py function get_config (line 26) | def get_config(num_epochs: int, batch_size: int): FILE: byol/configs/eval.py function get_config (line 22) | def get_config(checkpoint_to_evaluate: Text, batch_size: int): FILE: byol/eval_experiment.py class _EvalExperimentState (line 43) | class _EvalExperimentState(NamedTuple): class EvalExperiment (line 51) | class EvalExperiment: method __init__ (line 54) | def __init__( method _should_transpose_images (line 121) | def _should_transpose_images(self): method _backbone_fn (line 126) | def _backbone_fn( method _classif_fn (line 147) | def _classif_fn( method step (line 161) | def step(self, *, method save_checkpoint (line 176) | def save_checkpoint(self, step: int, rng: jnp.ndarray): method load_checkpoint (line 181) | def load_checkpoint(self) -> Union[Tuple[int, jnp.ndarray], None]: method _initialize_train (line 188) | def _initialize_train(self, rng): method _make_initial_state (line 254) | def _make_initial_state( method _build_train_input (line 282) | def _build_train_input(self) -> Generator[dataset.Batch, None, None]: method _optimizer (line 299) | def _optimizer(self, learning_rate: float): method _loss_fn (line 303) | def _loss_fn( method _update_func (line 334) | def _update_func( method evaluate (line 413) | def evaluate(self, global_step, **unused_args): method _eval_batch (line 422) | def _eval_batch( method _eval_epoch (line 444) | def _eval_epoch(self, subset: Text, batch_size: int): FILE: byol/main_loop.py function train_loop (line 50) | def train_loop(experiment_class: Experiment, config: Mapping[Text, Any]): function eval_loop (line 98) | def eval_loop(experiment_class: Experiment, config: Mapping[Text, Any]): function main (line 133) | def main(_): FILE: byol/main_loop_test.py class MainLoopTest (line 31) | class MainLoopTest(absltest.TestCase): method test_pretrain (line 33) | def test_pretrain(self): method test_linear_eval (line 50) | def test_linear_eval(self): method test_pipeline (line 66) | def test_pipeline(self): FILE: byol/utils/augmentations.py function postprocess (line 69) | def postprocess(inputs: JaxBatch, rng: jnp.ndarray): function _maybe_apply (line 107) | def _maybe_apply(apply_fn, inputs, rng, apply_prob): function _depthwise_conv2d (line 112) | def _depthwise_conv2d(inputs, kernel, strides, padding): function _gaussian_blur_single_image (line 133) | def _gaussian_blur_single_image(image, kernel_size, padding, sigma): function _random_gaussian_blur (line 154) | def _random_gaussian_blur(image, rng, kernel_size, padding, sigma_min, function rgb_to_hsv (line 172) | def rgb_to_hsv(r, g, b): function hsv_to_rgb (line 205) | def hsv_to_rgb(h, s, v): function adjust_brightness (line 239) | def adjust_brightness(rgb_tuple, delta): function adjust_contrast (line 243) | def adjust_contrast(image, factor): function adjust_saturation (line 250) | def adjust_saturation(h, s, v, factor): function adjust_hue (line 254) | def adjust_hue(h, s, v, delta): function _random_brightness (line 261) | def _random_brightness(rgb_tuple, rng, max_delta): function _random_contrast (line 266) | def _random_contrast(rgb_tuple, rng, max_delta): function _random_saturation (line 272) | def _random_saturation(rgb_tuple, rng, max_delta): function _random_hue (line 279) | def _random_hue(rgb_tuple, rng, max_delta): function _to_grayscale (line 285) | def _to_grayscale(image): function _color_transform_single_image (line 291) | def _color_transform_single_image(image, rng, brightness, contrast, satu... function _random_flip_single_image (line 350) | def _random_flip_single_image(image, rng): function random_flip (line 357) | def random_flip(images, rng): function color_transform (line 362) | def color_transform(images, function gaussian_blur (line 403) | def gaussian_blur(images, function _solarize_single_image (line 434) | def _solarize_single_image(image, rng, threshold, apply_prob): function solarize (line 442) | def solarize(images, rng, threshold=0.5, apply_prob=1.0): FILE: byol/utils/checkpointing.py class Checkpointer (line 29) | class Checkpointer: method __init__ (line 32) | def __init__( method maybe_save_checkpoint (line 52) | def maybe_save_checkpoint( method maybe_load_checkpoint (line 82) | def maybe_load_checkpoint( function load_checkpoint (line 97) | def load_checkpoint(checkpoint_path): FILE: byol/utils/dataset.py class Split (line 29) | class Split(enum.Enum): method from_string (line 37) | def from_string(cls, name: Text) -> 'Split': method num_examples (line 47) | def num_examples(self): class PreprocessMode (line 56) | class PreprocessMode(enum.Enum): function normalize_images (line 63) | def normalize_images(images: jnp.ndarray) -> jnp.ndarray: function load (line 72) | def load(split: Split, function _to_tfds_split (line 158) | def _to_tfds_split(split: Split) -> tfds.Split: function _shard (line 170) | def _shard(split: Split, shard_index: int, num_shards: int) -> Tuple[int... function _preprocess_image (line 184) | def _preprocess_image( function _decode_and_random_crop (line 206) | def _decode_and_random_crop(image_bytes: tf.Tensor) -> tf.Tensor: function transpose_images (line 236) | def transpose_images(batch: Batch): function _decode_and_center_crop (line 247) | def _decode_and_center_crop( FILE: byol/utils/helpers.py function topk_accuracy (line 23) | def topk_accuracy( function softmax_cross_entropy (line 44) | def softmax_cross_entropy( function l2_normalize (line 74) | def l2_normalize( function l2_weight_regularizer (line 85) | def l2_weight_regularizer(params): function regression_loss (line 109) | def regression_loss(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: function bcast_local_devices (line 115) | def bcast_local_devices(value): function get_first (line 127) | def get_first(xs): FILE: byol/utils/networks.py class MLP (line 25) | class MLP(hk.Module): method __init__ (line 28) | def __init__( method __call__ (line 40) | def __call__(self, inputs: jnp.ndarray, is_training: bool) -> jnp.ndar... function check_length (line 48) | def check_length(length, value, name): class ResNetTorso (line 53) | class ResNetTorso(hk.Module): method __init__ (line 56) | def __init__( method __call__ (line 132) | def __call__(self, inputs, is_training, test_local_stats=False): class TinyResNet (line 154) | class TinyResNet(ResNetTorso): method __init__ (line 157) | def __init__(self, class ResNet18 (line 184) | class ResNet18(ResNetTorso): method __init__ (line 187) | def __init__(self, class ResNet34 (line 214) | class ResNet34(ResNetTorso): method __init__ (line 217) | def __init__(self, class ResNet50 (line 244) | class ResNet50(ResNetTorso): method __init__ (line 247) | def __init__(self, class ResNet101 (line 273) | class ResNet101(ResNetTorso): method __init__ (line 276) | def __init__(self, class ResNet152 (line 302) | class ResNet152(ResNetTorso): method __init__ (line 305) | def __init__(self, class ResNet200 (line 331) | class ResNet200(ResNetTorso): method __init__ (line 334) | def __init__(self, FILE: byol/utils/optimizers.py function exclude_bias_and_norm (line 30) | def exclude_bias_and_norm(path: Tuple[Any], val: jnp.ndarray) -> jnp.nda... function _partial_update (line 38) | def _partial_update(updates: optax.Updates, class ScaleByLarsState (line 57) | class ScaleByLarsState(NamedTuple): function scale_by_lars (line 61) | def scale_by_lars( class AddWeightDecayState (line 108) | class AddWeightDecayState(NamedTuple): function add_weight_decay (line 112) | def add_weight_decay( function lars (line 144) | def lars( FILE: byol/utils/schedules.py function target_ema (line 20) | def target_ema(global_step: jnp.ndarray, function learning_schedule (line 27) | def learning_schedule(global_step: jnp.ndarray, function _cosine_decay (line 46) | def _cosine_decay(global_step: jnp.ndarray, FILE: catch_carry/arm_opener.py class _ArmPropContactRemover (line 28) | class _ArmPropContactRemover(object): method __init__ (line 31) | def __init__(self, physics, arm_root, prop, gap): method _contact_pair_is_relevant (line 53) | def _contact_pair_is_relevant(self, contact): method _forward_and_find_next_contact (line 59) | def _forward_and_find_next_contact(self, physics): method _remove_contact_ik_iteration (line 69) | def _remove_contact_ik_iteration(self, physics, contact): method _override_margins_and_gaps (line 139) | def _override_margins_and_gaps(self, physics): method remove_contacts (line 151) | def remove_contacts(self, physics): function open_arms_for_prop (line 166) | def open_arms_for_prop(physics, left_arm_root, right_arm_root, prop, gap): FILE: catch_carry/ball_toss.py class BallToss (line 39) | class BallToss(composer.Task): method __init__ (line 42) | def __init__(self, walker, method root_entity (line 145) | def root_entity(self): method task_observables (line 149) | def task_observables(self): method name (line 153) | def name(self): method initialize_episode_mjcf (line 156) | def initialize_episode_mjcf(self, random_state): method initialize_episode (line 176) | def initialize_episode(self, physics, random_state): method after_step (line 230) | def after_step(self, physics, random_state): method get_reward (line 275) | def get_reward(self, physics): method get_discount (line 278) | def get_discount(self, physics): method should_terminate_episode (line 281) | def should_terminate_episode(self, physics): method _evaluate_contacts (line 284) | def _evaluate_contacts(self, physics): FILE: catch_carry/explore.py function main (line 32) | def main(unused_argv): FILE: catch_carry/mocap_data.py class Flag (line 36) | class Flag(enum.IntEnum): function _get_clip_info (line 51) | def _get_clip_info(loader, clip_number, flags): function _get_all_clip_infos_if_necessary (line 60) | def _get_all_clip_infos_if_necessary(): function _assert_partitions_all_clips (line 106) | def _assert_partitions_all_clips(*args): function all_clips (line 140) | def all_clips(): function floor_level (line 145) | def floor_level(): function medium_pedestal (line 150) | def medium_pedestal(): function high_pedestal (line 155) | def high_pedestal(): function light_prop (line 160) | def light_prop(): function heavy_prop (line 165) | def heavy_prop(): function small_box (line 170) | def small_box(): function large_box (line 175) | def large_box(): function small_ball (line 180) | def small_ball(): function large_ball (line 185) | def large_ball(): FILE: catch_carry/props.py class Pedestal (line 21) | class Pedestal(composer.Entity): method _build (line 24) | def _build(self, size=(.2, .3, .05), rgba=(0, .5, 0, 1), name='pedesta... method mjcf_model (line 30) | def mjcf_model(self): method geom (line 34) | def geom(self): method after_compile (line 37) | def after_compile(self, physics, unused_random_state): method body_geom_ids (line 44) | def body_geom_ids(self): class Bucket (line 48) | class Bucket(composer.Entity): method _build (line 51) | def _build(self, size=(.2, .3, .05), rgba=(0, .5, 0, 1), name='pedesta... method mjcf_model (line 70) | def mjcf_model(self): method geom (line 74) | def geom(self): method after_compile (line 77) | def after_compile(self, physics, unused_random_state): method body_geom_ids (line 84) | def body_geom_ids(self): FILE: catch_carry/task_examples.py function build_vision_warehouse (line 26) | def build_vision_warehouse(random_state=None): function build_vision_toss (line 54) | def build_vision_toss(random_state=None): FILE: catch_carry/trajectories.py class ClipSegment (line 31) | class ClipSegment(enum.Enum): function _get_rotated_bounding_box (line 57) | def _get_rotated_bounding_box(size, quaternion): function _get_prop_z_extent (line 77) | def _get_prop_z_extent(prop_proto, quaternion): class WarehouseTrajectory (line 100) | class WarehouseTrajectory(trajectory.Trajectory): method infer_pedestal_positions (line 103) | def infer_pedestal_positions(self, num_averaged_steps=30, method get_props_z_extent (line 128) | def get_props_z_extent(self, physics): class SinglePropCarrySegmentedTrajectory (line 137) | class SinglePropCarrySegmentedTrajectory(WarehouseTrajectory): method __init__ (line 147) | def __init__(self, method _generate_segments (line 161) | def _generate_segments(self): method segment_interval (line 218) | def segment_interval(self, segment): method get_random_timestep_in_segment (line 222) | def get_random_timestep_in_segment(self, segment, random_step): FILE: catch_carry/warehouse.py function _is_same_state (line 76) | def _is_same_state(state_1, state_2): function _singleton_or_none (line 85) | def _singleton_or_none(iterable): function _generate_pedestal_colors (line 93) | def _generate_pedestal_colors(num_pedestals): function _rotate_vector_by_quaternion (line 109) | def _rotate_vector_by_quaternion(vec, quat): class WarehousePhase (line 116) | class WarehousePhase(enum.Enum): function _find_random_free_pedestal_id (line 124) | def _find_random_free_pedestal_id(target_state, random_state): function _find_random_occupied_pedestal_id (line 130) | def _find_random_occupied_pedestal_id(target_state, random_state): function one_hot (line 136) | def one_hot(values, num_unique): class SinglePropFourPhases (line 140) | class SinglePropFourPhases(object): method __init__ (line 143) | def __init__(self, fixed_initialization_phase=None): method initialize_episode (line 147) | def initialize_episode(self, target_state, random_state): method on_success (line 192) | def on_success(self, target_state, random_state): method phase (line 232) | def phase(self): method prop_id (line 236) | def prop_id(self): method pedestal_id (line 240) | def pedestal_id(self): class PhasedBoxCarry (line 244) | class PhasedBoxCarry(composer.Task): method __init__ (line 247) | def __init__( method root_entity (line 403) | def root_entity(self): method task_observables (line 407) | def task_observables(self): method name (line 411) | def name(self): method initialize_episode_mjcf (line 414) | def initialize_episode_mjcf(self, random_state): method _settle_props (line 448) | def _settle_props(self, physics): method initialize_episode (line 463) | def initialize_episode(self, physics, random_state): method _move_arms_if_necessary (line 574) | def _move_arms_if_necessary(self, physics): method after_step (line 584) | def after_step(self, physics, random_state): method _on_transition (line 612) | def _on_transition(self, physics): method get_reward (line 625) | def get_reward(self, physics): method get_discount (line 628) | def get_discount(self, physics): method should_terminate_episode (line 631) | def should_terminate_episode(self, physics): method _update_current_state (line 634) | def _update_current_state(self, physics): method _evaluate_target_state (line 680) | def _evaluate_target_state(self): FILE: counterfactual_fairness/adult.py function _read_data (line 37) | def _read_data( function _combine_category_coding (line 48) | def _combine_category_coding(df_1, df_2): function read_all_data (line 68) | def read_all_data(root_dir, remove_missing=True): FILE: counterfactual_fairness/adult_pscf.py function build_input (line 54) | def build_input(train_data: pd.DataFrame, batch_size: int, class CausalNetOutput (line 64) | class CausalNetOutput(NamedTuple): function build_causal_graph (line 73) | def build_causal_graph(train_data: pd.DataFrame, column_names: List[str], function build_forward_fn (line 184) | def build_forward_fn(train_data: pd.DataFrame, column_names: List[str], function _loss_fn (line 367) | def _loss_fn( function _evaluate (line 398) | def _evaluate( function _loss_klqp (line 413) | def _loss_klqp(outputs: CausalNetOutput, beta: float) -> jnp.ndarray: class Updater (line 437) | class Updater: method __init__ (line 443) | def __init__(self, net_init, loss_fn, eval_fn, method init (line 453) | def init(self, init_rng, data): method update (line 466) | def update(self, state: Mapping[str, Any], data: jnp.ndarray): method evaluate (line 492) | def evaluate(self, state: Mapping[str, Any], inputs: jnp.ndarray, function main (line 502) | def main(_): FILE: counterfactual_fairness/adult_pscf_config.py function get_config (line 20) | def get_config(): FILE: counterfactual_fairness/causal_network.py class Node (line 27) | class Node: method __init__ (line 42) | def __init__(self, distribution_module, parents=(), hidden=False): method __repr__ (line 77) | def __repr__(self): method dim (line 81) | def dim(self): method name (line 86) | def name(self): method hidden (line 90) | def hidden(self): method observed_value (line 94) | def observed_value(self): method find_ancestor (line 97) | def find_ancestor(self, name): method parents (line 107) | def parents(self): method distribution_module (line 111) | def distribution_module(self): method distribution (line 115) | def distribution(self): method make_distribution (line 119) | def make_distribution(self, node_to_replacement=None): method populate (line 150) | def populate(self, data, node_to_replacement=None): class DistributionModule (line 193) | class DistributionModule(hk.Module): method __init__ (line 200) | def __init__(self, column, index, dim): method dim (line 216) | def dim(self): method column (line 221) | def column(self): method index (line 225) | def index(self): method prepare_data (line 228) | def prepare_data(self, data): method _package_args (line 243) | def _package_args(self, args): class Gaussian (line 258) | class Gaussian(DistributionModule): method __init__ (line 261) | def __init__(self, column, index, dim=1, hidden_shape=(), method __call__ (line 272) | def __call__(self, *args): method prepare_data (line 293) | def prepare_data(self, data): class GaussianMixture (line 305) | class GaussianMixture(DistributionModule): method __init__ (line 308) | def __init__(self, column, num_components, dim=1): method __call__ (line 321) | def __call__(self, *args): class MLPMultinomial (line 353) | class MLPMultinomial(DistributionModule): method __init__ (line 356) | def __init__(self, column, index, dim, hidden_shape=(), method from_frame (line 375) | def from_frame(cls, data, column, hidden_shape=()): method __call__ (line 385) | def __call__(self, *args): method prepare_data (line 393) | def prepare_data(self, data): function populate (line 402) | def populate(nodes, dataframe, node_to_replacement=None): FILE: counterfactual_fairness/utils.py function get_dataset (line 29) | def get_dataset(dataset: pd.DataFrame, function multinomial_mode (line 45) | def multinomial_mode( function multinomial_class (line 74) | def multinomial_class( function multinomial_mode_ndarray (line 94) | def multinomial_mode_ndarray(probs: jax.Array) -> jax.Array: function multinomial_accuracy (line 112) | def multinomial_accuracy(distribution_or_probs: tfd.Distribution, function softmax_ndarray (line 130) | def softmax_ndarray(logits: jax.Array) -> jax.Array: function get_samples (line 139) | def get_samples(distribution, num_samples, seed=None): function mmd_loss (line 169) | def mmd_loss(distribution: tfd.Distribution, function mmd_loss_exact (line 238) | def mmd_loss_exact(distribution_a, distribution_b, num_samples, gamma=1.): function scalar_log_prob (line 281) | def scalar_log_prob(distribution, val): FILE: counterfactual_fairness/variational.py class Variational (line 27) | class Variational(hk.Module): method __init__ (line 33) | def __init__(self, method __call__ (line 61) | def __call__(self, *args) -> tfd.Distribution: FILE: cs_gan/cs.py class CS (line 27) | class CS(object): method __init__ (line 30) | def __init__(self, metric_net, generator, method connect (line 53) | def connect(self, data, generator_inputs): method _get_rip_loss (line 101) | def _get_rip_loss(self, img1, img2): method _get_measurement_error (line 120) | def _get_measurement_error(self, target_img, sample_img): method gen_loss_fn (line 128) | def gen_loss_fn(self, data, samples): method _build_optimization_components (line 132) | def _build_optimization_components( function _get_and_check_variables (line 146) | def _get_and_check_variables(module): FILE: cs_gan/file_utils.py class FileExporter (line 22) | class FileExporter(object): method __init__ (line 25) | def __init__(self, path, grid_height=None, zoom=1): method _reshape (line 42) | def _reshape(self, data): method save (line 76) | def save(self, data, name): FILE: cs_gan/gan.py class GAN (line 24) | class GAN(object): method __init__ (line 41) | def __init__(self, discriminator, generator, method connect (line 68) | def connect(self, data, generator_inputs): method gen_loss_fn (line 124) | def gen_loss_fn(self, data, samples): method _build_optimization_components (line 133) | def _build_optimization_components( method get_variables (line 153) | def get_variables(self): function _get_and_check_variables (line 159) | def _get_and_check_variables(module): FILE: cs_gan/image_metrics.py function get_image_metrics_for_samples (line 22) | def get_image_metrics_for_samples( FILE: cs_gan/main.py function main (line 79) | def main(argv): FILE: cs_gan/main_cs.py function main (line 70) | def main(argv): FILE: cs_gan/main_ode.py function _copy_vars (line 82) | def _copy_vars(v_list): function _restore_vars (line 90) | def _restore_vars(v_list, t_list): function _scale_vars (line 98) | def _scale_vars(s, v_list): function _acc_grads (line 103) | def _acc_grads(g_sum, g_w, g): function _compute_reg_grads (line 108) | def _compute_reg_grads(gen_grads, disc_vars): function run_model (line 115) | def run_model(prior, images, model, disc_reg_weight): function update_model (line 151) | def update_model(model, disc_grads, gen_grads, disc_opt, gen_opt, function main (line 172) | def main(argv): FILE: cs_gan/nets.py function _sn_custom_getter (line 24) | def _sn_custom_getter(): class ConvGenNet (line 33) | class ConvGenNet(snt.AbstractModule): method __init__ (line 36) | def __init__(self, name='conv_gen'): method _build (line 39) | def _build(self, inputs, is_training): class ConvMetricNet (line 60) | class ConvMetricNet(snt.AbstractModule): method __init__ (line 63) | def __init__(self, num_outputs=2, use_sn=True, name='sn_metric'): method _build (line 68) | def _build(self, inputs): class MLPGeneratorNet (line 90) | class MLPGeneratorNet(snt.AbstractModule): method __init__ (line 93) | def __init__(self, name='mlp_generator'): method _build (line 96) | def _build(self, inputs, is_training=True): class MLPMetricNet (line 104) | class MLPMetricNet(snt.AbstractModule): method __init__ (line 107) | def __init__(self, num_outputs=2, name='mlp_metric'): method _build (line 111) | def _build(self, inputs): FILE: cs_gan/tests/gan_test.py class DummyGenerator (line 22) | class DummyGenerator(snt.AbstractModule): method __init__ (line 24) | def __init__(self): method _build (line 27) | def _build(self, inputs, is_training): class GanTest (line 31) | class GanTest(tf.test.TestCase): method testConnect (line 33) | def testConnect(self): FILE: cs_gan/utils.py class ModelOutputs (line 29) | class ModelOutputs( class OptimizationComponent (line 47) | class OptimizationComponent( function cross_entropy_loss (line 63) | def cross_entropy_loss(logits, expected): function optimise_and_sample (line 88) | def optimise_and_sample(init_z, module, data, is_training): function get_optimisation_cost (line 117) | def get_optimisation_cost(initial_z, optimised_z): function _project_z (line 123) | def _project_z(z, project_method='clip'): class DataProcessor (line 134) | class DataProcessor(object): method preprocess (line 136) | def preprocess(self, x): method postprocess (line 139) | def postprocess(self, x): function _get_np_data (line 143) | def _get_np_data(data_processor, dataset, split='train'): function make_output_dir (line 165) | def make_output_dir(output_dir): function get_ckpt_dir (line 171) | def get_ckpt_dir(output_dir): function get_real_data_for_eval (line 178) | def get_real_data_for_eval(num_eval_samples, dataset, split='valid'): function get_summaries (line 184) | def get_summaries(ops): function get_train_dataset (line 196) | def get_train_dataset(data_processor, dataset, batch_size): function get_generator (line 213) | def get_generator(dataset): function get_metric_net (line 220) | def get_metric_net(dataset, num_outputs=2, use_sn=True): function make_prior (line 227) | def make_prior(num_latents): FILE: curl/layers.py class ResidualStack (line 25) | class ResidualStack(snt.AbstractModule): method __init__ (line 28) | def __init__(self, method _build (line 47) | def _build(self, h): class SharedConvModule (line 73) | class SharedConvModule(snt.AbstractModule): method __init__ (line 76) | def __init__(self, method _build (line 91) | def _build(self, x, is_training=True): FILE: curl/model.py class SharedEncoder (line 33) | class SharedEncoder(snt.AbstractModule): method __init__ (line 36) | def __init__(self, encoder_type, n_enc, enc_strides, name='shared_enco... method _build (line 63) | def _build(self, x, is_training): function cluster_encoder_fn (line 74) | def cluster_encoder_fn(hiddens, n_y_active, n_y, is_training=True): function latent_encoder_fn (line 107) | def latent_encoder_fn(hiddens, y, n_y, n_z, is_training=True): function data_decoder_fn (line 141) | def data_decoder_fn(z, function latent_decoder_fn (line 242) | def latent_decoder_fn(y, n_z, is_training=True): class Curl (line 272) | class Curl(object): method __init__ (line 275) | def __init__(self, method sample (line 298) | def sample(self, sample_shape=(), y=None, mean=False): method reconstruct (line 332) | def reconstruct(self, x, use_mode=True, use_mean=False): method log_prob (line 356) | def log_prob(self, x): method log_prob_elbo (line 361) | def log_prob_elbo(self, x): method log_prob_elbo_components (line 366) | def log_prob_elbo_components(self, x, y=None, reduce_op=tf.reduce_sum): method _kl_and_qy (line 477) | def _kl_and_qy(self, hiddens): method _kl_and_z (line 520) | def _kl_and_z(self, hiddens, y): method infer_latent (line 561) | def infer_latent(self, hiddens, y=None, use_mean_y=False): method generate_latent (line 594) | def generate_latent(self, y): method get_shared_rep (line 607) | def get_shared_rep(self, x, is_training): method infer_cluster (line 622) | def infer_cluster(self, hiddens): method predict (line 636) | def predict(self, z, y): method compute_prior (line 654) | def compute_prior(self): class UpsampleModule (line 665) | class UpsampleModule(snt.AbstractModule): method __init__ (line 687) | def __init__(self, method _build (line 720) | def _build(self, z, is_training=True, test_local_stats=True, use_bn=Fa... FILE: curl/train_main.py function main (line 43) | def main(unused_argv): FILE: curl/train_sup.py function main (line 28) | def main(unused_argv): FILE: curl/train_unsup.py function main (line 28) | def main(unused_argv): FILE: curl/training.py function compute_purity (line 48) | def compute_purity(confusion): function process_dataset (line 52) | def process_dataset(iterator, function get_data_sources (line 102) | def get_data_sources(dataset, dataset_kwargs, batch_size, test_batch_size, function setup_training_and_eval_graphs (line 227) | def setup_training_and_eval_graphs(x, label, y, n_y, curl_model, function get_generated_data (line 285) | def get_generated_data(sess, gen_op, y_input, gen_buffer_size, function setup_dynamic_ops (line 330) | def setup_dynamic_ops(n_y): function copy_component_params (line 431) | def copy_component_params(ind_from, ind_to, sess, ind_from_ph, ind_to_ph, function run_training (line 479) | def run_training( FILE: curl/unit_test.py class TrainingTest (line 27) | class TrainingTest(absltest.TestCase): method testRunTraining (line 29) | def testRunTraining(self): FILE: curl/utils.py function generate_gaussian (line 27) | def generate_gaussian(logits, sigma_nonlin, sigma_param): function construct_prior_probs (line 47) | def construct_prior_probs(batch_size, n_y, n_y_active): function maybe_center_crop (line 69) | def maybe_center_crop(layer, target_hw): FILE: density_functional_approximation_dm21/cc/dm21_aot_compiled_example.cc function run_dm21_compiled_functional (line 22) | void run_dm21_compiled_functional() { FILE: density_functional_approximation_dm21/cc/run_dm21_aot_compiled_example.cc function main (line 21) | int main(int argc, char** argv) { FILE: density_functional_approximation_dm21/density_functional_approximation_dm21/compute_hfx_density.py function _evaluate_nu_slow (line 81) | def _evaluate_nu_slow(mol: mole.Mole, function _evaluate_nu (line 97) | def _evaluate_nu(mol: mole.Mole, function _nu_chunk (line 115) | def _nu_chunk(mol: mole.Mole, function _compute_exx_block (line 151) | def _compute_exx_block(nu: np.ndarray, function _compute_jk_block (line 170) | def _compute_jk_block(nu: np.ndarray, fxx: np.ndarray, dm: np.ndarray, class HFDensityResult (line 185) | class HFDensityResult: function get_hf_density (line 214) | def get_hf_density( FILE: density_functional_approximation_dm21/density_functional_approximation_dm21/compute_hfx_density_test.py class ComputeHfxDensityTest (line 28) | class ComputeHfxDensityTest(parameterized.TestCase): method setUp (line 30) | def setUp(self): method test_closed_shell (line 41) | def test_closed_shell(self, omega): method test_hf_density_on_open_shell (line 77) | def test_hf_density_on_open_shell(self, omega): function _nu_test_systems (line 118) | def _nu_test_systems(): class NuTest (line 183) | class NuTest(parameterized.TestCase): method setUp (line 185) | def setUp(self): method test_nu_integrals (line 191) | def test_nu_integrals(self, atom, charge, spin, basis, num_grids, hermi): method test_range_separated_nu (line 205) | def test_range_separated_nu(self): FILE: density_functional_approximation_dm21/density_functional_approximation_dm21/export_saved_model.py function export (line 38) | def export( function main (line 55) | def main(argv: Sequence[str]) -> None: FILE: density_functional_approximation_dm21/density_functional_approximation_dm21/neural_numint.py class Functional (line 38) | class Functional(enum.Enum): class FunctionalInputs (line 64) | class FunctionalInputs: class _GridState (line 109) | class _GridState: class _SystemState (line 131) | class _SystemState: function _get_number_of_density_matrices (line 147) | def _get_number_of_density_matrices(dms): class NeuralNumInt (line 155) | class NeuralNumInt(numint.NumInt): method __init__ (line 170) | def __init__(self, method _build_graph (line 207) | def _build_graph(self, batch_dim: Optional[int] = None): method export_functional_and_derivatives (line 326) | def export_functional_and_derivatives( method rsh_coeff (line 360) | def rsh_coeff(self, *args): method hybrid_coeff (line 364) | def hybrid_coeff(self, *args, **kwargs): method _xc_type (line 368) | def _xc_type(self, *args, **kwargs): method nr_rks (line 371) | def nr_rks(self, method nr_uks (line 435) | def nr_uks(self, method block_loop (line 509) | def block_loop( method construct_functional_inputs (line 564) | def construct_functional_inputs( method eval_xc (line 646) | def eval_xc( FILE: density_functional_approximation_dm21/density_functional_approximation_dm21/neural_numint_test.py class NeuralNumintTest (line 29) | class NeuralNumintTest(tf.test.TestCase, parameterized.TestCase): method setUp (line 31) | def setUp(self): method test_rks (line 56) | def test_rks(self, functional, expected_energy): method test_uks (line 88) | def test_uks(self, functional, expected_energy): method test_exported_model (line 103) | def test_exported_model(self): FILE: enformer/attention_module.py class TransformerBlock (line 66) | class TransformerBlock(snt.Module): method __init__ (line 69) | def __init__( method __call__ (line 87) | def __call__(self, inputs: tf.Tensor, is_training: bool) -> tf.Tensor: class MultiheadAttention (line 104) | class MultiheadAttention(snt.Module): method __init__ (line 107) | def __init__(self, method _multihead_output (line 207) | def _multihead_output(self, linear, inputs): method __call__ (line 218) | def __call__(self, function relative_shift (line 283) | def relative_shift(x): function get_positional_feature_function (line 297) | def get_positional_feature_function(name): function positional_features_all (line 312) | def positional_features_all(positions: tf.Tensor, function _prepend_dims (line 372) | def _prepend_dims(x, num_dims): function positional_features_exponential (line 376) | def positional_features_exponential(positions: tf.Tensor, function positional_features_central_mask (line 409) | def positional_features_central_mask(positions: tf.Tensor, function gamma_pdf (line 426) | def gamma_pdf(x, concentration, rate): function positional_features_gamma (line 434) | def positional_features_gamma(positions: tf.Tensor, function positional_features_cosine (line 463) | def positional_features_cosine(positions: tf.Tensor, function positional_features_linear_masks (line 479) | def positional_features_linear_masks(positions: tf.Tensor, function positional_features_sin_cos (line 496) | def positional_features_sin_cos(positions: tf.Tensor, FILE: enformer/enformer.py class Enformer (line 42) | class Enformer(snt.Module): method __init__ (line 45) | def __init__(self, method trunk (line 164) | def trunk(self): method heads (line 168) | def heads(self): method __call__ (line 171) | def __call__(self, inputs: tf.Tensor, method predict_on_batch (line 181) | def predict_on_batch(self, x): class TargetLengthCrop1D (line 186) | class TargetLengthCrop1D(snt.Module): method __init__ (line 189) | def __init__(self, method __call__ (line 195) | def __call__(self, inputs): class Sequential (line 207) | class Sequential(snt.Module): method __init__ (line 210) | def __init__(self, method __call__ (line 223) | def __call__(self, inputs: tf.Tensor, is_training: bool, **kwargs): function pooling_module (line 233) | def pooling_module(kind, pool_size): class SoftmaxPooling1D (line 244) | class SoftmaxPooling1D(snt.Module): method __init__ (line 247) | def __init__(self, method _initialize (line 270) | def _initialize(self, num_features): method __call__ (line 276) | def __call__(self, inputs): class Residual (line 287) | class Residual(snt.Module): method __init__ (line 290) | def __init__(self, module: snt.Module, name='residual'): method __call__ (line 294) | def __call__(self, inputs: tf.Tensor, is_training: bool, *args, function gelu (line 299) | def gelu(x: tf.Tensor) -> tf.Tensor: function one_hot_encode (line 313) | def one_hot_encode(sequence: str, function exponential_linspace_int (line 328) | def exponential_linspace_int(start, end, num, divisible_by=1): function accepts_is_training (line 337) | def accepts_is_training(module): FILE: enformer/enformer_test.py class TestEnformer (line 28) | class TestEnformer(unittest.TestCase): method test_enformer (line 30) | def test_enformer(self): function _get_random_input (line 38) | def _get_random_input(): FILE: fusion_tcv/agent.py class AbstractAgent (line 24) | class AbstractAgent(abc.ABC): method reset (line 27) | def reset(self): method step (line 31) | def step(self, timestep: dm_env.TimeStep) -> np.ndarray: class ZeroAgent (line 35) | class ZeroAgent(AbstractAgent): method step (line 38) | def step(self, timestep: dm_env.TimeStep) -> np.ndarray: FILE: fusion_tcv/combiners.py class AbstractCombiner (line 27) | class AbstractCombiner(targets.AbstractTarget): method __call__ (line 31) | def __call__(self, values: List[float], # pytype: disable=signature-m... method outputs (line 36) | def outputs(self) -> int: method _clean_values_weights (line 41) | def _clean_values_weights( class Mean (line 60) | class Mean(AbstractCombiner): method __call__ (line 66) | def __call__(self, values: List[float], function _multiply (line 74) | def _multiply(values, weights, mean): class Multiply (line 96) | class Multiply(AbstractCombiner): method __call__ (line 108) | def __call__(self, values: List[float], class GeometricMean (line 116) | class GeometricMean(AbstractCombiner): method __call__ (line 125) | def __call__(self, values: List[float], class Min (line 133) | class Min(AbstractCombiner): method __call__ (line 136) | def __call__(self, values: List[float], class Max (line 144) | class Max(AbstractCombiner): method __call__ (line 147) | def __call__(self, values: List[float], class LNorm (line 156) | class LNorm(AbstractCombiner): method __call__ (line 178) | def __call__(self, values: List[float], class SmoothMax (line 190) | class SmoothMax(AbstractCombiner): method __call__ (line 208) | def __call__(self, values: List[float], FILE: fusion_tcv/combiners_test.py class CombinersTest (line 25) | class CombinersTest(absltest.TestCase): method assertNan (line 27) | def assertNan(self, value): method test_errors (line 31) | def test_errors(self): method test_mean (line 40) | def test_mean(self): method test_geometric_mean (line 49) | def test_geometric_mean(self): method test_multiply (line 63) | def test_multiply(self): method test_min (line 76) | def test_min(self): method test_max (line 86) | def test_max(self): method test_lnorm (line 96) | def test_lnorm(self): method test_smoothmax (line 122) | def test_smoothmax(self): FILE: fusion_tcv/environment.py class Environment (line 38) | class Environment(auto_reset_environment.AutoResetEnvironment): method __init__ (line 47) | def __init__( method observation_spec (line 94) | def observation_spec(self): method action_spec (line 98) | def action_spec(self) -> specs.BoundedArray: method _reset (line 102) | def _reset(self) -> dm_env.TimeStep: method _simulator_voltages_from_voltages (line 114) | def _simulator_voltages_from_voltages(self, voltages): method _step (line 128) | def _step(self, action: np.ndarray) -> dm_env.TimeStep: method _extract_observation (line 151) | def _extract_observation( FILE: fusion_tcv/experiments.py function fundamental_capability (line 22) | def fundamental_capability() -> environment.Environment: function elongation (line 31) | def elongation() -> environment.Environment: function iter (line 40) | def iter() -> environment.Environment: # pylint: disable=redefined-builtin function negative_triangularity (line 49) | def negative_triangularity() -> environment.Environment: function snowflake (line 58) | def snowflake() -> environment.Environment: function droplet (line 67) | def droplet() -> environment.Environment: FILE: fusion_tcv/experiments_test.py class FundamentalCapabilityTest (line 25) | class FundamentalCapabilityTest(test_utils.EnvironmentTestMixin, method make_object_under_test (line 28) | def make_object_under_test(self): class ElongationTest (line 32) | class ElongationTest(test_utils.EnvironmentTestMixin, absltest.TestCase): method make_object_under_test (line 34) | def make_object_under_test(self): class IterTest (line 38) | class IterTest(test_utils.EnvironmentTestMixin, absltest.TestCase): method make_object_under_test (line 40) | def make_object_under_test(self): class NegativeTriangularityTest (line 44) | class NegativeTriangularityTest(test_utils.EnvironmentTestMixin, method make_object_under_test (line 47) | def make_object_under_test(self): class SnowflakeTest (line 51) | class SnowflakeTest(test_utils.EnvironmentTestMixin, absltest.TestCase): method make_object_under_test (line 53) | def make_object_under_test(self): class DropletTest (line 57) | class DropletTest(test_utils.EnvironmentTestMixin, absltest.TestCase): method make_object_under_test (line 59) | def make_object_under_test(self): class ExperimentsTest (line 63) | class ExperimentsTest(parameterized.TestCase): method test_env (line 73) | def test_env(self, env_fn): FILE: fusion_tcv/fge_octave.py class ShotCondition (line 28) | class ShotCondition: class FGESimulatorOctave (line 34) | class FGESimulatorOctave: method __init__ (line 40) | def __init__( method reset (line 63) | def reset(self, variation: param_variation.Settings) -> fge_state.FGES... method step (line 70) | def step(self, voltages: np.ndarray) -> fge_state.FGEState: FILE: fusion_tcv/fge_state.py class StopSignalException (line 25) | class StopSignalException(Exception): # pylint: disable=g-bad-exception... class InvalidSolutionError (line 30) | class InvalidSolutionError(RuntimeError): class UnhandledOctaveError (line 35) | class UnhandledOctaveError(Exception): class FGEState (line 40) | class FGEState: method __init__ (line 47) | def __init__(self, num_plasmas): method num_plasmas (line 51) | def num_plasmas(self) -> int: method rzip_d (line 55) | def rzip_d(self) -> Tuple[List[float], List[float], List[float]]: method get_coil_currents_by_type (line 62) | def get_coil_currents_by_type(self, coil_type) -> np.ndarray: method get_lcfs_points (line 66) | def get_lcfs_points(self, domain: int) -> shape.ShapePoints: method get_observation_vector (line 70) | def get_observation_vector(self) -> np.ndarray: method elongation (line 74) | def elongation(self) -> List[float]: method triangularity (line 78) | def triangularity(self) -> List[float]: method radius (line 82) | def radius(self) -> List[float]: method limit_point_d (line 86) | def limit_point_d(self) -> List[shape.Point]: method is_diverted_d (line 90) | def is_diverted_d(self) -> List[bool]: method x_points (line 94) | def x_points(self) -> shape.ShapePoints: method flux (line 98) | def flux(self) -> np.ndarray: method magnetic_axis_flux_strength (line 103) | def magnetic_axis_flux_strength(self) -> float: method lcfs_flux_strength (line 108) | def lcfs_flux_strength(self) -> float: method r_coordinates (line 113) | def r_coordinates(self) -> np.ndarray: method z_coordinates (line 119) | def z_coordinates(self): FILE: fusion_tcv/named_array.py function lengths_to_ranges (line 21) | def lengths_to_ranges( class NamedRanges (line 32) | class NamedRanges: method __init__ (line 35) | def __init__(self, counts: Mapping[str, int]): method __getitem__ (line 39) | def __getitem__(self, name) -> List[int]: method __contains__ (line 42) | def __contains__(self, name) -> bool: method set_range (line 45) | def set_range(self, name: str, value: List[int]): method range (line 49) | def range(self, name: str) -> List[int]: method index (line 52) | def index(self, name: str) -> int: method count (line 58) | def count(self, name: str) -> int: method names (line 61) | def names(self) -> Iterable[str]: method ranges (line 64) | def ranges(self) -> Iterable[Tuple[str, List[int]]]: method counts (line 67) | def counts(self) -> Mapping[str, int]: method size (line 71) | def size(self) -> int: method named_array (line 74) | def named_array(self, array: np.ndarray) -> "NamedArray": method new_named_array (line 77) | def new_named_array(self) -> "NamedArray": method new_random_named_array (line 80) | def new_random_named_array(self) -> "NamedArray": class NamedArray (line 84) | class NamedArray: method __init__ (line 87) | def __init__(self, array: np.ndarray, names: NamedRanges): method __getitem__ (line 93) | def __getitem__( method __setitem__ (line 109) | def __setitem__( method array (line 119) | def array(self) -> np.ndarray: method names (line 123) | def names(self) -> NamedRanges: method to_dict (line 126) | def to_dict(self) -> Mapping[str, np.ndarray]: FILE: fusion_tcv/named_array_test.py class NamedRangesTest (line 22) | class NamedRangesTest(absltest.TestCase): method test_lengths_to_ranges (line 24) | def test_lengths_to_ranges(self): method test_named_ranges (line 28) | def test_named_ranges(self): class NamedArrayTest (line 52) | class NamedArrayTest(absltest.TestCase): method test_name_array (line 54) | def test_name_array(self): FILE: fusion_tcv/noise.py class Noise (line 22) | class Noise: method __init__ (line 25) | def __init__(self, method use_zero_noise (line 56) | def use_zero_noise(cls): method use_default_noise (line 69) | def use_default_noise(cls, scale=1): method add_action_noise (line 102) | def add_action_noise(self, action): method add_measurement_noise (line 108) | def add_measurement_noise(self, measurement_vec): FILE: fusion_tcv/param_variation.py class Settings (line 33) | class Settings: method _psu_voltage_offset_string (line 55) | def _psu_voltage_offset_string(self) -> str: class ParamGenerator (line 74) | class ParamGenerator: method __init__ (line 81) | def __init__(self, method generate (line 107) | def generate(self) -> Settings: function loguniform_rv (line 121) | def loguniform_rv(lower, upper): FILE: fusion_tcv/ref_gen.py class AbstractReferenceGenerator (line 27) | class AbstractReferenceGenerator(abc.ABC): method reset (line 31) | def reset(self) -> named_array.NamedArray: method step (line 35) | def step(self) -> named_array.NamedArray: class LinearTransition (line 40) | class LinearTransition: class LinearTransitionReferenceGenerator (line 46) | class LinearTransitionReferenceGenerator(AbstractReferenceGenerator): method __init__ (line 49) | def __init__(self, start_offset: int = 0): method _next_transition (line 55) | def _next_transition(self) -> LinearTransition: method reset (line 58) | def reset(self) -> named_array.NamedArray: method _reset_counters (line 65) | def _reset_counters(self): method step (line 70) | def step(self) -> named_array.NamedArray: class FixedReferenceGenerator (line 98) | class FixedReferenceGenerator(LinearTransitionReferenceGenerator): method __init__ (line 101) | def __init__(self, transitions: List[LinearTransition], method reset (line 107) | def reset(self) -> named_array.NamedArray: method _next_transition (line 111) | def _next_transition(self) -> LinearTransition: class TimedTransition (line 122) | class TimedTransition: class ParametrizedShapeTimedTarget (line 128) | class ParametrizedShapeTimedTarget: class PresetShapePointsReferenceGenerator (line 134) | class PresetShapePointsReferenceGenerator(FixedReferenceGenerator): method __init__ (line 137) | def __init__( class ShapeFromShot (line 151) | class ShapeFromShot(PresetShapePointsReferenceGenerator): method __init__ (line 154) | def __init__( class RZIpTarget (line 188) | class RZIpTarget: function make_symmetric_multidomain_rzip_reference (line 194) | def make_symmetric_multidomain_rzip_reference( FILE: fusion_tcv/references.py function fundamental_capability (line 25) | def fundamental_capability() -> ref_gen.AbstractReferenceGenerator: function elongation (line 341) | def elongation() -> ref_gen.AbstractReferenceGenerator: function negative_triangularity (line 366) | def negative_triangularity() -> ref_gen.AbstractReferenceGenerator: function snowflake (line 386) | def snowflake() -> ref_gen.AbstractReferenceGenerator: function iter (line 700) | def iter() -> ref_gen.AbstractReferenceGenerator: # pylint: disable=red... function droplet (line 1565) | def droplet() -> ref_gen.AbstractReferenceGenerator: FILE: fusion_tcv/references_main.py function print_ref (line 36) | def print_ref(step: int, ref: named_array.NamedArray): function main (line 42) | def main(argv: Sequence[str]) -> None: FILE: fusion_tcv/rewards.py class AbstractMeasure (line 32) | class AbstractMeasure(abc.ABC): method __call__ (line 35) | def __call__(self, targets: List[targets_lib.Target]) -> List[float]: class AbsDist (line 39) | class AbsDist(AbstractMeasure): method __call__ (line 43) | def __call__(targets: List[targets_lib.Target]) -> List[float]: class MeasureDetails (line 48) | class MeasureDetails: class RewardDetails (line 55) | class RewardDetails: class AbstractReward (line 62) | class AbstractReward(abc.ABC): method reward (line 66) | def reward( method terminal_reward (line 75) | def terminal_reward(self) -> float: class Component (line 84) | class Component: class Reward (line 92) | class Reward(AbstractReward): method __init__ (line 111) | def __init__(self, method terminal_reward (line 136) | def terminal_reward(self) -> float: method reward (line 139) | def reward( FILE: fusion_tcv/run_loop.py function run_loop (line 22) | def run_loop(env: environment.Environment, agent, FILE: fusion_tcv/shape.py class Point (line 29) | class Point(NamedTuple): method to_polar (line 34) | def to_polar(self) -> "PolarPoint": method __neg__ (line 38) | def __neg__(self): method __add__ (line 41) | def __add__(self, pt_or_val: Union["Point", float]): method __sub__ (line 47) | def __sub__(self, pt_or_val: Union["Point", float]): method __mul__ (line 53) | def __mul__(self, pt_or_val: Union["Point", float]): method __truediv__ (line 59) | def __truediv__(self, pt_or_val: Union["Point", float]): function dist (line 68) | def dist(p1: Union[Point, np.ndarray], p2: Union[Point, np.ndarray]) -> ... function to_shape_points (line 75) | def to_shape_points(array: np.ndarray) -> ShapePoints: function center_point (line 79) | def center_point(points: ShapePoints) -> Point: class ShapeSide (line 83) | class ShapeSide(enum.Enum): class PolarPoint (line 89) | class PolarPoint(NamedTuple): method to_point (line 93) | def to_point(self) -> Point: function evenly_spaced_angles (line 97) | def evenly_spaced_angles(num: int): function angle_aligned_dists (line 101) | def angle_aligned_dists(points: np.ndarray, angles: np.ndarray) -> np.nd... function angle_aligned_points (line 109) | def angle_aligned_points(points: np.ndarray, num_points: int, function dist_angle_to_surface (line 118) | def dist_angle_to_surface(points: np.ndarray, angle: float) -> float: function dist_angle_to_segment (line 127) | def dist_angle_to_segment(p1, p2, angle: float) -> Optional[float]: function dist_point_to_surface (line 149) | def dist_point_to_surface(points: np.ndarray, point: np.ndarray) -> float: function dist_point_to_segment (line 155) | def dist_point_to_segment(v: np.ndarray, w: np.ndarray, p: np.ndarray) -... function sort_by_angle (line 170) | def sort_by_angle(points: ShapePoints) -> ShapePoints: function spline_interpolate_points (line 175) | def spline_interpolate_points( class ParametrizedShape (line 211) | class ParametrizedShape: method uniform_random_shape (line 222) | def uniform_random_shape( method gen_points (line 242) | def gen_points(self, num_points: int) -> Tuple[ShapePoints, Point]: function trim_zero_points (line 264) | def trim_zero_points(points: ShapePoints) -> Optional[ShapePoints]: class Diverted (line 269) | class Diverted(enum.Enum): method from_refs (line 276) | def from_refs(cls, references: named_array.NamedArray) -> "Diverted": class Shape (line 289) | class Shape: method from_references (line 300) | def from_references(cls, references: named_array.NamedArray) -> "Shape": method gen_references (line 327) | def gen_references(self) -> named_array.NamedArray: method canonical (line 369) | def canonical(self) -> "Shape": function points_from_references (line 408) | def points_from_references( class ReferenceTimeSlice (line 418) | class ReferenceTimeSlice: method __post_init__ (line 423) | def __post_init__(self): function canonicalize_reference_series (line 428) | def canonicalize_reference_series( FILE: fusion_tcv/targets.py class TargetError (line 30) | class TargetError(Exception): class Target (line 35) | class Target: method invalid (line 40) | def invalid(cls): class AbstractTarget (line 45) | class AbstractTarget(abc.ABC): method name (line 49) | def name(self) -> str: method outputs (line 54) | def outputs(self) -> int: method __call__ (line 58) | def __call__( class AbstractSingleValuePerDomainTarget (line 67) | class AbstractSingleValuePerDomainTarget(AbstractTarget): method __post_init__ (line 72) | def __post_init__(self): method outputs (line 80) | def outputs(self) -> int: method name (line 84) | def name(self) -> str: class R (line 89) | class R(AbstractSingleValuePerDomainTarget): method __call__ (line 92) | def __call__(self, class Z (line 105) | class Z(AbstractSingleValuePerDomainTarget): method __call__ (line 108) | def __call__(self, class Ip (line 121) | class Ip(AbstractSingleValuePerDomainTarget): method __call__ (line 124) | def __call__(self, class OHCurrentsClose (line 136) | class OHCurrentsClose(AbstractTarget): method outputs (line 140) | def outputs(self) -> int: method __call__ (line 143) | def __call__(self, class EFCurrents (line 152) | class EFCurrents(AbstractTarget): method outputs (line 156) | def outputs(self) -> int: method __call__ (line 159) | def __call__(self, class VoltageOOB (line 169) | class VoltageOOB(AbstractTarget): method outputs (line 174) | def outputs(self) -> int: method __call__ (line 177) | def __call__(self, class ShapeElongation (line 190) | class ShapeElongation(AbstractSingleValuePerDomainTarget): method __call__ (line 193) | def __call__( class ShapeTriangularity (line 207) | class ShapeTriangularity(AbstractSingleValuePerDomainTarget): method __call__ (line 210) | def __call__( class ShapeRadius (line 224) | class ShapeRadius(AbstractSingleValuePerDomainTarget): method __call__ (line 227) | def __call__( class AbstractPointsTarget (line 241) | class AbstractPointsTarget(AbstractTarget): method __post_init__ (line 247) | def __post_init__(self): method outputs (line 263) | def outputs(self) -> int: method _target_points (line 266) | def _target_points( function splined_lcfs_points (line 275) | def splined_lcfs_points( class ShapeLCFSDistance (line 297) | class ShapeLCFSDistance(AbstractPointsTarget): method __call__ (line 305) | def __call__( function flux_at_points (line 321) | def flux_at_points(state: fge_state.FGEState, points: np.ndarray) -> np.... class ShapeNormalizedLCFSFlux (line 336) | class ShapeNormalizedLCFSFlux(AbstractPointsTarget): method __call__ (line 345) | def __call__( class LegsNormalizedFlux (line 361) | class LegsNormalizedFlux(ShapeNormalizedLCFSFlux): class AbstractXPointTarget (line 367) | class AbstractXPointTarget(AbstractPointsTarget): class XPointFluxGradient (line 373) | class XPointFluxGradient(AbstractXPointTarget): method __call__ (line 376) | def __call__( function _dist (line 400) | def _dist(p1: shape.Point, p2: shape.Point): function _min_dist (line 404) | def _min_dist(pt: shape.Point, points: shape.ShapePoints, class XPointDistance (line 416) | class XPointDistance(AbstractXPointTarget): method __call__ (line 432) | def __call__( class XPointFar (line 451) | class XPointFar(AbstractXPointTarget): method __call__ (line 472) | def __call__( class XPointNormalizedFlux (line 502) | class XPointNormalizedFlux(AbstractXPointTarget): method __call__ (line 510) | def __call__( class XPointCount (line 542) | class XPointCount(AbstractTarget): method outputs (line 547) | def outputs(self) -> int: method __call__ (line 550) | def __call__( class Diverted (line 565) | class Diverted(AbstractTarget): method outputs (line 570) | def outputs(self) -> int: method __call__ (line 573) | def __call__( class LimitPoint (line 592) | class LimitPoint(AbstractPointsTarget): method __call__ (line 599) | def __call__( FILE: fusion_tcv/tcv_common.py function observation_spec (line 183) | def observation_spec(): function measurements_to_dict (line 204) | def measurements_to_dict(measurements): function dict_to_measurement (line 224) | def dict_to_measurement(measurement_dict): function action_spec (line 251) | def action_spec(): function get_coil_spec (line 255) | def get_coil_spec(coil_names: Sequence[Text], FILE: fusion_tcv/terminations.py class Abstract (line 25) | class Abstract(abc.ABC): method terminate (line 29) | def terminate(self, state: fge_state.FGEState) -> Optional[str]: class CoilCurrentSaturation (line 33) | class CoilCurrentSaturation(Abstract): method terminate (line 36) | def terminate(self, state: fge_state.FGEState) -> Optional[str]: class OHTooDifferent (line 48) | class OHTooDifferent(Abstract): method __init__ (line 51) | def __init__(self, max_diff: float): method terminate (line 54) | def terminate(self, state: fge_state.FGEState) -> Optional[str]: class IPTooLow (line 65) | class IPTooLow(Abstract): method __init__ (line 68) | def __init__(self, singlet_threshold: float, droplet_threshold: float): method terminate (line 72) | def terminate(self, state: fge_state.FGEState) -> Optional[str]: class AnyTermination (line 84) | class AnyTermination(Abstract): method __init__ (line 87) | def __init__(self, terminators: List[Abstract]): method terminate (line 90) | def terminate(self, state: fge_state.FGEState) -> Optional[str]: FILE: fusion_tcv/trajectory.py class Trajectory (line 23) | class Trajectory: method stack (line 31) | def stack(cls, series: List["Trajectory"]) -> "Trajectory": FILE: fusion_tcv/transforms.py class AbstractTransform (line 31) | class AbstractTransform(abc.ABC): method __call__ (line 34) | def __call__(self, errors: List[float]) -> List[float]: method outputs (line 38) | def outputs(self) -> Optional[int]: function clip (line 42) | def clip(value: float, low: float, high: float) -> float: function scale (line 50) | def scale(v: float, a: float, b: float, c: float, d: float) -> float: function logistic (line 56) | def logistic(v: float) -> float: class Equal (line 63) | class Equal(AbstractTransform): method __call__ (line 67) | def __call__(self, errors: List[float]) -> List[float]: class Abs (line 79) | class Abs(AbstractTransform): method __call__ (line 83) | def __call__(errors: List[float]) -> List[float]: class Neg (line 87) | class Neg(AbstractTransform): method __call__ (line 91) | def __call__(errors: List[float]) -> List[float]: class Pow (line 96) | class Pow(AbstractTransform): method __call__ (line 100) | def __call__(self, errors: List[float]) -> List[float]: class Log (line 105) | class Log(AbstractTransform): method __call__ (line 109) | def __call__(self, errors: List[float]) -> List[float]: class ClippedLinear (line 114) | class ClippedLinear(AbstractTransform): method __call__ (line 119) | def __call__(self, errors: List[float]) -> List[float]: class SoftPlus (line 125) | class SoftPlus(AbstractTransform): method __call__ (line 141) | def __call__(self, errors: List[float]) -> List[float]: class NegExp (line 147) | class NegExp(AbstractTransform): method __call__ (line 163) | def __call__(self, errors: List[float]) -> List[float]: class Sigmoid (line 169) | class Sigmoid(AbstractTransform): method __call__ (line 180) | def __call__(self, errors: List[float]) -> List[float]: FILE: fusion_tcv/transforms_test.py class TransformsTest (line 25) | class TransformsTest(absltest.TestCase): method assertNan (line 27) | def assertNan(self, value: float): method test_clip (line 30) | def test_clip(self): method test_scale (line 36) | def test_scale(self): method test_logistic (line 54) | def test_logistic(self): method test_exp_scaled (line 63) | def test_exp_scaled(self): method test_neg (line 81) | def test_neg(self): method test_abs (line 86) | def test_abs(self): method test_pow (line 91) | def test_pow(self): method test_log (line 96) | def test_log(self): method test_clipped_linear (line 101) | def test_clipped_linear(self): method test_softplus (line 120) | def test_softplus(self): method test_sigmoid (line 134) | def test_sigmoid(self): method test_equal (line 150) | def test_equal(self): FILE: galaxy_mergers/antennae_helpers.py function norm_antennae_images (line 27) | def norm_antennae_images(images, scale=1000): function renorm_antennae (line 31) | def renorm_antennae(images): function get_antennae_images (line 37) | def get_antennae_images(antennae_fits_dir): function preprocess_antennae_images (line 74) | def preprocess_antennae_images(antennae_images): FILE: galaxy_mergers/config.py function get_config (line 21) | def get_config(filter_time_intervals=None): FILE: galaxy_mergers/evaluator.py class GalaxyMergeClassifierEvaluator (line 30) | class GalaxyMergeClassifierEvaluator(): method __init__ (line 33) | def __init__(self, strategy, optimizer_config, total_train_batch_size, method build_eval_input (line 62) | def build_eval_input(self, additional_lambdas=None): method run_test_model_ensemble (line 130) | def run_test_model_ensemble(self, images, physical_features, augmentat... method checkpoint_items (line 170) | def checkpoint_items(self): function run_model_on_dataset (line 174) | def run_model_on_dataset(evaluator, dataset, config, n_batches=16): function get_config_dataset_evaluator (line 230) | def get_config_dataset_evaluator(filter_time_intervals, FILE: galaxy_mergers/helpers.py function restore_checkpoint (line 28) | def restore_checkpoint(checkpoint_dir, experiment): function sum_average_transformed_mu_and_sigma (line 37) | def sum_average_transformed_mu_and_sigma(mu, log_sigma_sq): function aggregate_regression_ensemble (line 74) | def aggregate_regression_ensemble(logits_or_times, ensemble_size, function aggregate_classification_ensemble (line 92) | def aggregate_classification_ensemble(logits_or_times, ensemble_size, function unpack_evaluator_output (line 107) | def unpack_evaluator_output(data, return_seq_info=False, return_redshift... function process_data_into_myrs (line 128) | def process_data_into_myrs(redshifts, *data_lists): function print_rmse_and_class_accuracy (line 147) | def print_rmse_and_class_accuracy(mus, regression_targets, redshifts): function print_stats (line 162) | def print_stats(vec, do_print=True): function get_image_from_fits (line 169) | def get_image_from_fits(base_dir, seq='475_31271', time='497', axis=2): function stack_desired_galaxy_images (line 188) | def stack_desired_galaxy_images(base_dir, seq, n_time_slices): function draw_galaxy_image (line 205) | def draw_galaxy_image(image, target_size=None, color_map='viridis'): function collect_merger_sequence (line 216) | def collect_merger_sequence(ds, seq=b'370_11071', n_examples_to_sift=5000): function take_samples (line 227) | def take_samples(sample_idxs, *data_lists): FILE: galaxy_mergers/interpretability_helpers.py function rotate_by_right_angle_multiple (line 22) | def rotate_by_right_angle_multiple(image, rot=90): function compute_gradient (line 39) | def compute_gradient(images, evaluator, is_training=False): function compute_grads_for_rotations (line 48) | def compute_grads_for_rotations(images, evaluator, is_training=False): function compute_grads_for_rotations_and_flips (line 60) | def compute_grads_for_rotations_and_flips(images, evaluator): FILE: galaxy_mergers/losses.py function normalize_regression_loss (line 36) | def normalize_regression_loss(regression_loss, predictions): function equal32 (line 46) | def equal32(x, y): function mse_loss (line 50) | def mse_loss(predicted, targets): function get_std_factor_from_confidence_percent (line 54) | def get_std_factor_from_confidence_percent(percent): function get_all_metric_names (line 60) | def get_all_metric_names(task_type, model_uncertainty, loss_config, # p... function compute_loss_and_metrics (line 80) | def compute_loss_and_metrics(mu, log_sigma_sq, FILE: galaxy_mergers/main.py function main (line 35) | def main(_) -> None: FILE: galaxy_mergers/model.py class ResNet (line 24) | class ResNet(snt.Module): method __init__ (line 27) | def __init__(self, method __call__ (line 137) | def __call__(self, inputs, features, is_training): class LinearBNReLU (line 178) | class LinearBNReLU(snt.Module): method __init__ (line 181) | def __init__(self, output_size=64, method __call__ (line 198) | def __call__(self, x, is_training): FILE: galaxy_mergers/preprocessing.py function _make_padding_sizes (line 41) | def _make_padding_sizes(pad_size, random_centering): function resize_and_pad (line 51) | def resize_and_pad(image, target_size, random_centering): function resize_and_extract (line 69) | def resize_and_extract(image, target_size, random_centering): function resize_and_center (line 92) | def resize_and_center(image, target_size, random_centering): function random_rotation_and_flip (line 99) | def random_rotation_and_flip(image): function get_all_rotations_and_flips (line 104) | def get_all_rotations_and_flips(images): function random_rescaling (line 115) | def random_rescaling(image, random_centering): function get_all_rescalings (line 126) | def get_all_rescalings(images, image_width, random_centering): function move_repeats_to_batch (line 140) | def move_repeats_to_batch(image, n_repeats): function get_classification_label (line 147) | def get_classification_label(dataset_row, class_boundaries): function get_regression_label (line 157) | def get_regression_label(dataset_row, task_type): function get_normalized_time_target (line 171) | def get_normalized_time_target(dataset_row): function apply_time_filter (line 175) | def apply_time_filter(dataset_row, time_interval): function normalize_physical_feature (line 182) | def normalize_physical_feature(name, dataset_row): function prepare_dataset (line 188) | def prepare_dataset(ds, target_size, crop_type, n_repeats, augmentations, FILE: gated_linear_networks/base.py function _l2_normalize (line 37) | def _l2_normalize(x: Array, axis: int) -> Array: function _wrapped_fn_argnames (line 41) | def _wrapped_fn_argnames(fun): function _vmap (line 46) | def _vmap(fun, in_axes=0, out_axes=0, parameters=None): class NormalizedRandomNormal (line 70) | class NormalizedRandomNormal(hk.initializers.RandomNormal): method __init__ (line 73) | def __init__(self, method __call__ (line 80) | def __call__(self, shape: Shape, dtype: DType) -> Array: class ShapeScaledConstant (line 88) | class ShapeScaledConstant(hk.initializers.Initializer): method __call__ (line 91) | def __call__(self, shape: Shape, dtype: DType) -> jnp.ndarray: class LocalUpdateModule (line 96) | class LocalUpdateModule(hk.Module): method __init__ (line 99) | def __init__(self, name: Optional[str] = None): method inference (line 106) | def inference(self, *args, **kwargs): method update (line 110) | def update(self, *args, **kwargs): method output_sizes (line 115) | def output_sizes(self) -> Shape: class GatedLinearNetwork (line 119) | class GatedLinearNetwork(LocalUpdateModule): method __init__ (line 122) | def __init__(self, method _add_bias (line 152) | def _add_bias(self, inputs): method inference (line 155) | def inference(self, inputs: Array, side_info: Array, *args, method update (line 167) | def update(self, inputs, side_info, target, learning_rate, *args, **kw... method output_sizes (line 191) | def output_sizes(self): method _compute_context (line 195) | def _compute_context( class _GatedLinearLayer (line 210) | class _GatedLinearLayer(LocalUpdateModule): method __init__ (line 213) | def __init__(self, method _get_weights (line 235) | def _get_weights(self, input_size): method _get_hyperplanes (line 246) | def _get_hyperplanes(self, side_info_size): method inference (line 264) | def inference(self, inputs: Array, side_info: Array, *args, method update (line 281) | def update(self, inputs: Array, side_info: Array, target: Array, method output_sizes (line 305) | def output_sizes(self): class Mutator (line 309) | class Mutator(LocalUpdateModule): method __init__ (line 312) | def __init__( method output_sizes (line 322) | def output_sizes(self): class LastNeuronAggregator (line 326) | class LastNeuronAggregator(Mutator): method __init__ (line 329) | def __init__( method inference (line 340) | def inference(self, *args, **kwargs) -> Array: method update (line 344) | def update(self, *args, **kwargs) -> Tuple[Array, Array, Array]: FILE: gated_linear_networks/bernoulli.py class GatedLinearNetwork (line 35) | class GatedLinearNetwork(base.GatedLinearNetwork): method __init__ (line 38) | def __init__(self, method _add_bias (line 52) | def _add_bias(self, inputs): method _inference_fn (line 56) | def _inference_fn( method _update_fn (line 74) | def _update_fn( class LastNeuronAggregator (line 104) | class LastNeuronAggregator(base.LastNeuronAggregator): FILE: gated_linear_networks/bernoulli_test.py function _get_dataset (line 28) | def _get_dataset(input_size, batch_size=None): class GatedLinearNetworkTest (line 42) | class GatedLinearNetworkTest(parameterized.TestCase): method setUp (line 45) | def setUp(self): method test_shapes (line 92) | def test_shapes(self, batch_size): method test_update (line 125) | def test_update(self, batch_size): method test_batch_consistency (line 168) | def test_batch_consistency(self): FILE: gated_linear_networks/examples/bernoulli_mnist.py function main (line 76) | def main(unused_argv): FILE: gated_linear_networks/examples/utils.py function _moments (line 30) | def _moments(image): function _deskew (line 44) | def _deskew(image): function _deskew_dataset (line 54) | def _deskew_dataset(dataset): function load_deskewed_mnist (line 65) | def load_deskewed_mnist(*a, **k): class MeanStdEstimator (line 73) | class MeanStdEstimator(hk.Module): method __call__ (line 76) | def __call__(self, sample: jax.Array) -> Tuple[Array, Array]: FILE: gated_linear_networks/examples/utils_test.py class MeanStdEstimator (line 25) | class MeanStdEstimator(absltest.TestCase): method test_statistics (line 27) | def test_statistics(self): FILE: gated_linear_networks/gaussian.py function _unpack_inputs (line 36) | def _unpack_inputs(inputs: Array) -> Tuple[Array, Array]: function _pack_inputs (line 43) | def _pack_inputs(mu: Array, sigma_sq: Array) -> Array: class GatedLinearNetwork (line 50) | class GatedLinearNetwork(base.GatedLinearNetwork): method __init__ (line 53) | def __init__( method _add_bias (line 75) | def _add_bias(self, inputs): method _inference_fn (line 83) | def _inference_fn( method _project_weights (line 108) | def _project_weights(inputs: Array, # [input_size] method _update_fn (line 135) | def _update_fn( class ConstantInputSigma (line 168) | class ConstantInputSigma(base.Mutator): method __init__ (line 171) | def __init__( method inference (line 180) | def inference(self, inputs, *args, **kwargs): method update (line 187) | def update(self, inputs, *args, **kwargs): class LastNeuronAggregator (line 194) | class LastNeuronAggregator(base.LastNeuronAggregator): FILE: gated_linear_networks/gaussian_test.py function _get_dataset (line 28) | def _get_dataset(input_size, batch_size=None): class UtilsTest (line 42) | class UtilsTest(absltest.TestCase): method test_packing_identity (line 44) | def test_packing_identity(self): class GatedLinearNetworkTest (line 55) | class GatedLinearNetworkTest(parameterized.TestCase): method setUp (line 58) | def setUp(self): method test_shapes (line 111) | def test_shapes(self, batch_size): method test_update (line 145) | def test_update(self, batch_size): method test_batch_consistency (line 187) | def test_batch_consistency(self): FILE: geomancer/data_writer.py function get_normal (line 46) | def get_normal(x): function render (line 52) | def render(quat, light, mesh='bunny', meshdir='data'): function get_tangent (line 100) | def get_tangent(quat, light, mesh='bunny', meshdir='data', function main (line 144) | def main(_): FILE: geomancer/geomancer.py function sym_op (line 29) | def sym_op(x, zero_trace=False): function vec_to_sym (line 62) | def vec_to_sym(x, n, zero_trace=False): function ffdiag (line 74) | def ffdiag(data, lr=1.0, tol=1e-10, verbose=False, eig_init=False): function avg_angle_between_subspaces (line 115) | def avg_angle_between_subspaces(xs, ys): function make_nearest_neighbors_graph (line 132) | def make_nearest_neighbors_graph(data, k, n=1000): function make_tangents (line 164) | def make_tangents(data, neighbor_graph, k): function make_connection (line 175) | def make_connection(tangents, neighbor_graph): function make_laplacian (line 190) | def make_laplacian(connection, neighbor_graph, sym=True, zero_trace=True): function cluster_subspaces (line 221) | def cluster_subspaces(omega): function fit (line 242) | def fit(data, k, gamma=None, nnbrs=None, neig=10, shard_size=1000): function eval_aligned (line 282) | def eval_aligned(tangents, true_tangents): function eval_unaligned (line 292) | def eval_unaligned(data, tangents, true_data, true_tangents, k=10, n=1000): FILE: geomancer/geomancer_test.py class GeomancerTest (line 25) | class GeomancerTest(parameterized.TestCase): method test_sym_op (line 30) | def test_sym_op(self, zero_trace): method test_ffdiag (line 47) | def test_ffdiag(self): method test_make_nearest_neighbor_graph (line 65) | def test_make_nearest_neighbor_graph(self): FILE: geomancer/train.py function make_so_tangent (line 41) | def make_so_tangent(q): function make_sphere_tangent (line 60) | def make_sphere_tangent(x): function make_true_tangents (line 65) | def make_true_tangents(spec, data): function make_product_manifold (line 97) | def make_product_manifold(specification, npts): function main (line 143) | def main(_): FILE: glassy_dynamics/apply_binary.py function main (line 44) | def main(argv): FILE: glassy_dynamics/graph_model.py function make_graph_from_static_structure (line 32) | def make_graph_from_static_structure( function apply_random_rotation (line 79) | def apply_random_rotation(graph: graphs.GraphsTuple) -> graphs.GraphsTuple: class GraphBasedModel (line 103) | class GraphBasedModel(snt.AbstractModule): method __init__ (line 115) | def __init__(self, method _build (line 166) | def _build(self, graphs_tuple: graphs.GraphsTuple) -> tf.Tensor: FILE: glassy_dynamics/graph_model_test.py class GraphModelTest (line 27) | class GraphModelTest(tf.test.TestCase, parameterized.TestCase): method setUp (line 29) | def setUp(self): method _get_graphs_tuple (line 82) | def _get_graphs_tuple(self): method test_make_graph_from_static_structure (line 93) | def test_make_graph_from_static_structure(self): method _is_equal_up_to_rotation (line 108) | def _is_equal_up_to_rotation(self, x, y): method test_apply_random_rotation (line 115) | def test_apply_random_rotation(self): method test_GraphModel (line 131) | def test_GraphModel(self, n_recurrences, mlp_sizes): FILE: glassy_dynamics/train.py class ParticleType (line 40) | class ParticleType(enum.IntEnum): function get_targets (line 51) | def get_targets( function load_data (line 68) | def load_data( function get_loss_ops (line 106) | def get_loss_ops( function get_minimize_op (line 130) | def get_minimize_op( function _log_stats_and_return_mean_correlation (line 150) | def _log_stats_and_return_mean_correlation( function train_model (line 170) | def train_model(train_file_pattern: Text, function apply_model (line 324) | def apply_model(checkpoint_path: Text, FILE: glassy_dynamics/train_binary.py function main (line 49) | def main(argv): FILE: glassy_dynamics/train_test.py class TrainTest (line 25) | class TrainTest(tf.test.TestCase): method test_get_targets (line 27) | def test_get_targets(self): method test_load_data (line 38) | def test_load_data(self): class TensorflowTrainTest (line 62) | class TensorflowTrainTest(tf.test.TestCase): method test_get_loss_op (line 64) | def test_get_loss_op(self): method test_get_minimize_op (line 75) | def test_get_minimize_op(self): method test_train_model (line 87) | def test_train_model(self): method test_apply_model (line 106) | def test_apply_model(self): FILE: glassy_dynamics/train_using_jax.py class ParticleType (line 35) | class ParticleType(enum.IntEnum): function make_graph_from_static_structure (line 46) | def make_graph_from_static_structure(positions, types, box, edge_thresho... function get_targets (line 88) | def get_targets(initial_positions, trajectory_target_positions): function load_data (line 103) | def load_data(file_pattern, time_index, max_files_to_load=None): function apply_random_rotation (line 137) | def apply_random_rotation(graph): function network_definition (line 162) | def network_definition(graph): function train_model (line 204) | def train_model(train_file_pattern, FILE: hierarchical_probabilistic_unet/geco_utils.py class MovingAverage (line 29) | class MovingAverage(snt.AbstractModule): method __init__ (line 37) | def __init__(self, decay, local=True, differentiable=False, method _build (line 44) | def _build(self, inputs): class LagrangeMultiplier (line 50) | class LagrangeMultiplier(snt.AbstractModule): method __init__ (line 53) | def __init__(self, method _build (line 66) | def _build(self, ma_constraint): function _sample_gumbel (line 85) | def _sample_gumbel(shape, eps=1e-20): function _topk_mask (line 92) | def _topk_mask(score, k): function ce_loss (line 100) | def ce_loss(logits, labels, mask=None, top_k_percentage=None, FILE: hierarchical_probabilistic_unet/model.py class _HierarchicalCore (line 28) | class _HierarchicalCore(snt.AbstractModule): method __init__ (line 36) | def __init__(self, latent_dims, channels_per_block, method _build (line 82) | def _build(self, inputs, mean=False, z_q=None): class _StitchingDecoder (line 185) | class _StitchingDecoder(snt.AbstractModule): method __init__ (line 192) | def __init__(self, latent_dims, channels_per_block, num_classes, method _build (line 238) | def _build(self, encoder_features, decoder_features): class HierarchicalProbUNet (line 273) | class HierarchicalProbUNet(snt.AbstractModule): method __init__ (line 276) | def __init__(self, method _build (line 394) | def _build(self, seg, img): method sample (line 424) | def sample(self, img, mean=False, z_q=None): method reconstruct (line 448) | def reconstruct(self, seg, img, mean=False): method rec_loss (line 469) | def rec_loss(self, seg, img, mask=None, top_k_percentage=None, method kl (line 490) | def kl(self, seg, img): method loss (line 517) | def loss(self, seg, img, mask): FILE: hierarchical_probabilistic_unet/model_test.py function _get_placeholders (line 37) | def _get_placeholders(): class HierarchicalProbUNetTest (line 44) | class HierarchicalProbUNetTest(tf.test.TestCase): method test_shape_of_sample (line 46) | def test_shape_of_sample(self): method test_shape_of_reconstruction (line 55) | def test_shape_of_reconstruction(self): method test_shapes_in_prior (line 64) | def test_shapes_in_prior(self): method test_shape_of_kl (line 100) | def test_shape_of_kl(self): FILE: hierarchical_probabilistic_unet/unet_utils.py function res_block (line 25) | def res_block(input_features, n_channels, n_down_channels=None, function resize_up (line 74) | def resize_up(input_features, scale=2): function resize_down (line 93) | def resize_down(input_features, scale=2): FILE: hierarchical_transformer_memory/hierarchical_attention/htm_attention.py class HierarchicalMemory (line 31) | class HierarchicalMemory(NamedTuple): function sinusoid_position_encoding (line 43) | def sinusoid_position_encoding( class HierarchicalMemoryAttention (line 69) | class HierarchicalMemoryAttention(hk.Module): method __init__ (line 72) | def __init__(self, method num_heads (line 100) | def num_heads(self): method _singlehead_linear (line 104) | def _singlehead_linear(self, method __call__ (line 116) | def __call__( function hk_vmap (line 224) | def hk_vmap(*args, **kwargs): FILE: hierarchical_transformer_memory/hierarchical_attention/htm_attention_test.py function _build_queries_and_memory (line 26) | def _build_queries_and_memory(query_length, num_memories, mem_chunk_size, class HierarchicalAttentionTest (line 44) | class HierarchicalAttentionTest(parameterized.TestCase): method test_output_shapes (line 61) | def test_output_shapes(self, query_length, num_memories, mem_chunk_size, method test_masking (line 80) | def test_masking(self): FILE: hierarchical_transformer_memory/pycolab_ballet/ballet_environment.py function _generate_template (line 75) | def _generate_template(object_name): function get_scrolling_cropper (line 170) | def get_scrolling_cropper(rows=9, cols=9, crop_pad_char=" "): class BalletEnvironment (line 177) | class BalletEnvironment(dm_env.Environment): method __init__ (line 180) | def __init__(self, num_dancers, dance_delay, max_steps, rng=None): method _game_factory (line 210) | def _game_factory(self): method _render_observation (line 242) | def _render_observation(self, observation): method reset (line 261) | def reset(self): method step (line 281) | def step(self, action): method observation_spec (line 309) | def observation_spec(self): method action_spec (line 323) | def action_spec(self): method _is_game_over (line 329) | def _is_game_over(self): method _clear_state (line 334) | def _clear_state(self): function simple_builder (line 342) | def simple_builder(level_name): function main (line 365) | def main(argv): FILE: hierarchical_transformer_memory/pycolab_ballet/ballet_environment_core.py class DIRECTIONS (line 51) | class DIRECTIONS(enum.IntEnum): class DancerSprite (line 143) | class DancerSprite(prefab_sprites.MazeWalker): method __init__ (line 146) | def __init__(self, corner, position, character, motion, color, shape, method update (line 157) | def update(self, actions, board, layers, backdrop, things, the_plot): class PlayerSprite (line 190) | class PlayerSprite(prefab_sprites.MazeWalker): method __init__ (line 196) | def __init__(self, corner, position, character): method update (line 200) | def update(self, actions, board, layers, backdrop, things, the_plot): function make_game (line 233) | def make_game(dancers_and_properties, dance_delay=16): function main (line 292) | def main(argv): FILE: hierarchical_transformer_memory/pycolab_ballet/ballet_environment_test.py class BalletEnvironmentTest (line 25) | class BalletEnvironmentTest(parameterized.TestCase): method test_full_wrapper (line 27) | def test_full_wrapper(self): method test_simple_builder (line 63) | def test_simple_builder(self, level_name): FILE: iodine/configurations.py function clevr6 (line 19) | def clevr6(): function multi_dsprites (line 133) | def multi_dsprites(): function tetrominoes (line 258) | def tetrominoes(): FILE: iodine/main.py function default_config (line 42) | def default_config(): function build (line 71) | def build(identifier, _config): function get_train_step (line 76) | def get_train_step(model, dataset, optimizer): function get_checkpoint_dir (line 101) | def get_checkpoint_dir(continue_run, checkpoint_dir, _run, _log): function get_session (line 120) | def get_session(chkp_dir, loss, stop_after_steps, save_summaries_steps, function load_checkpoint (line 140) | def load_checkpoint(use_placeholder=False, session=None): function main (line 173) | def main(save_summaries_steps): FILE: iodine/modules/data.py class IODINEDataset (line 30) | class IODINEDataset(snt.AbstractModule): method __init__ (line 36) | def __init__( method _build (line 60) | def _build(self, subset="train"): method filter_by_num_objects (line 85) | def filter_by_num_objects(self, d): method preprocess (line 101) | def preprocess(self, data): method preprocess_factors (line 148) | def preprocess_factors(self, data, sg): method get_placeholders (line 154) | def get_placeholders(self, batch_size=None): class CLEVR (line 177) | class CLEVR(IODINEDataset): method __init__ (line 188) | def __init__( method preprocess_factors (line 204) | def preprocess_factors(self, data, sg): class MultiDSprites (line 215) | class MultiDSprites(IODINEDataset): method __init__ (line 227) | def __init__( class Tetrominoes (line 241) | class Tetrominoes(IODINEDataset): method __init__ (line 250) | def __init__(self, path, image_dim=(35, 35), name="tetrominoes", **kwa... method preprocess_factors (line 254) | def preprocess_factors(self, data, sg): FILE: iodine/modules/decoder.py class ComponentDecoder (line 21) | class ComponentDecoder(snt.AbstractModule): method __init__ (line 23) | def __init__(self, pixel_decoder, name="component_decoder"): method set_output_shapes (line 28) | def set_output_shapes(self, pixel, mask): method _build (line 33) | def _build(self, z): FILE: iodine/modules/distributions.py class DistributionModule (line 33) | class DistributionModule(snt.AbstractModule): method __init__ (line 36) | def __init__(self, name="distribution"): method set_output_shape (line 40) | def set_output_shape(self, shape): method output_shape (line 44) | def output_shape(self): method input_shapes (line 48) | def input_shapes(self): method get_default_prior (line 51) | def get_default_prior(self, batch_dim=(1,)): class BernoulliOutput (line 56) | class BernoulliOutput(DistributionModule): method __init__ (line 58) | def __init__(self, name="bernoulli_output"): method input_shapes (line 62) | def input_shapes(self): method _build (line 65) | def _build(self, params): class LocScaleDistribution (line 71) | class LocScaleDistribution(DistributionModule): method __init__ (line 99) | def __init__( method input_shapes (line 121) | def input_shapes(self): method _build (line 128) | def _build(self, params): class MaskedMixture (line 163) | class MaskedMixture(DistributionModule): method __init__ (line 165) | def __init__( method set_output_shape (line 190) | def set_output_shape(self, shape): method _build (line 194) | def _build(self, pixel, mask): method input_shapes (line 213) | def input_shapes(self): method get_default_prior (line 218) | def get_default_prior(self, batch_dim=(1,)): FILE: iodine/modules/factor_eval.py class FactorRegressor (line 27) | class FactorRegressor(snt.AbstractModule): method __init__ (line 30) | def __init__(self, mapping=None, name="repres_content"): method _build (line 44) | def _build(self, z, latent, visibility, pred_mask, true_mask): method predict (line 132) | def predict(self, z): method get_error_func (line 147) | def get_error_func(factor): method get_metric (line 157) | def get_metric(factor): method one_hot (line 166) | def one_hot(f, nr_categories): method angle_to_vector (line 170) | def angle_to_vector(theta): method get_preprocessing (line 174) | def get_preprocessing(factor): function sse (line 186) | def sse(true, pred): function accuracy (line 191) | def accuracy(labels, logits, assignment, mean_var_tot, num_vis): function r2 (line 199) | def r2(labels, pred, assignment, mean_var_tot, num_vis): FILE: iodine/modules/iodine.py class IODINE (line 56) | class IODINE(snt.AbstractModule): method __init__ (line 101) | def __init__( method _build (line 154) | def _build(self, data): method encode (line 225) | def encode(self, images): method decode (line 266) | def decode(self, z): method eval (line 276) | def eval(self, data): method get_sample_images (line 409) | def get_sample_images(self, nr_samples=16): method get_overview_images (line 417) | def get_overview_images(self, data, nr_images=4, mask_components=False): method _get_initial_z (line 448) | def _get_initial_z(self): method _parse_iter_loss_weights (line 466) | def _parse_iter_loss_weights(self, iter_loss_weight): method _propagate_shape_info (line 479) | def _propagate_shape_info(self, image_shape): method _get_image_for_iter (line 489) | def _get_image_for_iter(self, images, t): method _get_mask_posterior (line 497) | def _get_mask_posterior(out_dist, img): method _get_inputs_for (line 502) | def _get_inputs_for(self, out_params, out_dist, img, z_dist, zp, loss): method _apply_preprocessing (line 575) | def _apply_preprocessing(self, name, val): method _get_coord_channels (line 594) | def _get_coord_channels(self): method _raw_kl (line 615) | def _raw_kl(self, z_dist): method _reconstruction_error (line 618) | def _reconstruction_error(self, x_dist, img): method _get_monitored_scalars (line 622) | def _get_monitored_scalars(self, out_dist, data): FILE: iodine/modules/networks.py class CNN (line 25) | class CNN(snt.AbstractModule): method __init__ (line 33) | def __init__(self, cnn_opt, mlp_opt, mode="flatten", name="cnn"): method set_output_shapes (line 53) | def set_output_shapes(self, shape): method _build (line 59) | def _build(self, image): class MLP (line 90) | class MLP(snt.AbstractModule): method __init__ (line 93) | def __init__(self, name="mlp", **mlp_opt): method set_output_shapes (line 100) | def set_output_shapes(self, shape): method _build (line 105) | def _build(self, data): class DeConv (line 118) | class DeConv(snt.AbstractModule): method __init__ (line 125) | def __init__(self, mlp_opt, cnn_opt, name="deconv"): method set_output_shapes (line 144) | def set_output_shapes(self, shape): method _build (line 148) | def _build(self, z): class BroadcastConv (line 162) | class BroadcastConv(snt.AbstractModule): method __init__ (line 171) | def __init__( method set_output_shapes (line 208) | def set_output_shapes(self, shape): method _build (line 212) | def _build(self, z): method append_coordinate_channels (line 242) | def append_coordinate_channels(self, output): class LSTM (line 266) | class LSTM(snt.RNNCore): method __init__ (line 273) | def __init__(self, hidden_sizes, name="lstm"): method initial_state (line 279) | def initial_state(self, batch_size, **kwargs): method _build (line 284) | def _build(self, data, prev_states): FILE: iodine/modules/plotting.py function clean_ax (line 27) | def clean_ax(ax, color=None, lw=4.0): function optional_ax (line 36) | def optional_ax(fn): function optional_clean_ax (line 48) | def optional_clean_ax(fn): function show_img (line 65) | def show_img(img, mask=None, ax=None, norm=False): function show_mask (line 76) | def show_mask(m, ax): function show_mat (line 83) | def show_mat(m, ax, vmin=None, vmax=None, cmap="viridis"): function show_coords (line 89) | def show_coords(m, ax): function example_plot (line 97) | def example_plot(rinfo, function iterations_plot (line 131) | def iterations_plot(rinfo, b=0, mask_components=False, size=2): function inputs_plot (line 173) | def inputs_plot(rinfo, b=0, t=0, size=2): FILE: iodine/modules/refinement.py class RefinementCore (line 22) | class RefinementCore(snt.RNNCore): method __init__ (line 36) | def __init__(self, method initial_state (line 47) | def initial_state(self, batch_size, **unused_kwargs): method _build (line 50) | def _build(self, inputs, prev_state): method prepare_spatial_inputs (line 68) | def prepare_spatial_inputs(self, inputs): method prepare_flat_inputs (line 80) | def prepare_flat_inputs(self, hidden, inputs): class ResHead (line 90) | class ResHead(snt.AbstractModule): method __init__ (line 93) | def __init__(self, name="residual_head"): method _build (line 96) | def _build(self, zp_old, inputs): class PredictorCorrectorHead (line 110) | class PredictorCorrectorHead(snt.AbstractModule): method __init__ (line 121) | def __init__( method _build (line 135) | def _build(self, zp_old, inputs): FILE: iodine/modules/utils.py function get_act_func (line 43) | def get_act_func(name_or_func): function get_distribution (line 66) | def get_distribution(name_or_dist): function get_mask_plot_colors (line 76) | def get_mask_plot_colors(nr_colors): function color_transform (line 84) | def color_transform(masks): function construct_diagnostic_image (line 91) | def construct_diagnostic_image( function construct_reconstr_image (line 156) | def construct_reconstr_image(images, recons, border_width=2, function construct_iterations_image (line 186) | def construct_iterations_image( function get_overview_image (line 234) | def get_overview_image(image, output_dist, mask_components=False): class OnlineMeanVarEstimator (line 249) | class OnlineMeanVarEstimator(snt.AbstractModule): method __init__ (line 252) | def __init__(self, axis=None, ddof=0.0, name="online_mean_var"): method _build (line 257) | def _build(self, x, weights=None): function print_shapes (line 306) | def print_shapes(name, value, indent=""): function _pad_images (line 332) | def _pad_images(images, image_border_value=0.5, border_width=2): function images_to_grid (line 362) | def images_to_grid( function flatten_all_but_last (line 448) | def flatten_all_but_last(tensor, n_dims=1): function ensure_3d (line 460) | def ensure_3d(tensor): function build (line 474) | def build(plan, identifier): function _resolve_constructor (line 494) | def _resolve_constructor(plan_subsection): FILE: kfac_ferminet_alpha/curvature_blocks.py class CurvatureBlock (line 29) | class CurvatureBlock(utils.Stateful, abc.ABC): method __init__ (line 32) | def __init__(self, layer_tag_eq: tgm.jax_core.JaxprEqn): method layer_tag_primitive (line 37) | def layer_tag_primitive(self) -> tgm.tags.LayerTag: method outputs_shapes (line 42) | def outputs_shapes(self) -> Sequence[Sequence[int]]: method inputs_shapes (line 48) | def inputs_shapes(self) -> Sequence[Sequence[int]]: method params_shapes (line 54) | def params_shapes(self) -> Sequence[Sequence[int]]: method init (line 60) | def init(self, rng: jnp.ndarray) -> MutableMapping[str, Any]: method update_curvature_matrix_estimate (line 75) | def update_curvature_matrix_estimate( method update_curvature_inverse_estimate (line 86) | def update_curvature_inverse_estimate( method multiply_matpower (line 94) | def multiply_matpower( class NaiveDiagonal (line 107) | class NaiveDiagonal(CurvatureBlock): method init (line 111) | def init(self, rng: jnp.ndarray) -> Dict[str, Any]: method update_curvature_matrix_estimate (line 118) | def update_curvature_matrix_estimate( method update_curvature_inverse_estimate (line 131) | def update_curvature_inverse_estimate( method multiply_matpower (line 138) | def multiply_matpower( class TwoKroneckerFactored (line 154) | class TwoKroneckerFactored(CurvatureBlock, abc.ABC): method has_bias (line 163) | def has_bias(self) -> bool: method input_size (line 167) | def input_size(self) -> int: method output_size (line 171) | def output_size(self) -> int: method compute_extra_scale (line 174) | def compute_extra_scale(self) -> Optional[Union[int, float, jnp.ndarra... method init (line 177) | def init(self, rng: jnp.ndarray) -> Dict[str, Any]: method update_curvature_inverse_estimate (line 191) | def update_curvature_inverse_estimate( method multiply_matpower (line 211) | def multiply_matpower( class DenseTwoKroneckerFactored (line 246) | class DenseTwoKroneckerFactored(TwoKroneckerFactored): method input_size (line 249) | def input_size(self) -> int: method output_size (line 255) | def output_size(self) -> int: method update_curvature_matrix_estimate (line 258) | def update_curvature_matrix_estimate( class ScaleAndShiftDiagonal (line 280) | class ScaleAndShiftDiagonal(CurvatureBlock): method has_scale (line 286) | def has_scale(self) -> bool: method has_shift (line 290) | def has_shift(self) -> bool: method init (line 293) | def init(self, rng: jnp.ndarray) -> Dict[str, Any]: method update_curvature_matrix_estimate (line 321) | def update_curvature_matrix_estimate( method update_curvature_inverse_estimate (line 352) | def update_curvature_inverse_estimate( method multiply_matpower (line 359) | def multiply_matpower( class ScaleAndShiftFull (line 383) | class ScaleAndShiftFull(CurvatureBlock): method _has_scale (line 389) | def _has_scale(self) -> bool: method _has_shift (line 393) | def _has_shift(self) -> bool: method init (line 396) | def init(self, rng: jnp.ndarray) -> Dict[str, Any]: method update_curvature_matrix_estimate (line 404) | def update_curvature_matrix_estimate( method update_curvature_inverse_estimate (line 439) | def update_curvature_inverse_estimate( method multiply_matpower (line 448) | def multiply_matpower( function copy_default_tag_to_block (line 484) | def copy_default_tag_to_block() -> MutableMapping[str, CurvatureBlockCtor]: function get_default_tag_to_block (line 488) | def get_default_tag_to_block(tag_name: str) -> CurvatureBlockCtor: function set_default_tag_to_block (line 492) | def set_default_tag_to_block( FILE: kfac_ferminet_alpha/distributions.py class MultivariateNormalDiag (line 21) | class MultivariateNormalDiag: method __init__ (line 24) | def __init__( method loc (line 39) | def loc(self) -> jnp.ndarray: method scale_diag (line 44) | def scale_diag(self) -> jnp.ndarray: method _num_dims (line 48) | def _num_dims(self) -> int: method _standardize (line 52) | def _standardize(self, value: jnp.ndarray) -> jnp.ndarray: method log_prob (line 55) | def log_prob(self, value: jnp.ndarray) -> jnp.ndarray: method mean (line 61) | def mean(self) -> jnp.ndarray: method sample (line 65) | def sample(self, seed: jnp.ndarray) -> jnp.ndarray: FILE: kfac_ferminet_alpha/estimator.py class CurvatureEstimator (line 37) | class CurvatureEstimator(utils.Stateful): method __init__ (line 42) | def __init__(self, method diagonal_weight (line 114) | def diagonal_weight(self) -> jnp.ndarray: method vectors_to_blocks (line 117) | def vectors_to_blocks( method blocks_to_vectors (line 142) | def blocks_to_vectors(self, per_block_vectors: Sequence[BlockVector]) ... method init (line 158) | def init( method mat_type (line 172) | def mat_type(self) -> str: method vec_block_apply (line 175) | def vec_block_apply( method multiply_inverse (line 190) | def multiply_inverse(self, parameter_structured_vector: Any) -> Any: method multiply (line 202) | def multiply(self, parameter_structured_vector: Any) -> Any: method multiply_matpower (line 214) | def multiply_matpower( method update_curvature_matrix_estimate (line 236) | def update_curvature_matrix_estimate( method update_curvature_estimate_inverse (line 324) | def update_curvature_estimate_inverse( FILE: kfac_ferminet_alpha/example.py function glorot_uniform (line 47) | def glorot_uniform(shape, key): function fully_connected_layer (line 54) | def fully_connected_layer(params, x): function model_init (line 59) | def model_init(rng_key, batch, encoder_sizes=(1000, 500, 250, 30)): function model_loss (line 74) | def model_loss(params, inputs, l2_reg): function random_data (line 90) | def random_data(multi_device, batch_shape, rng): function main (line 100) | def main(argv): FILE: kfac_ferminet_alpha/layers_and_loss_tags.py class LossTag (line 30) | class LossTag(jax_core.Primitive): method __init__ (line 34) | def __init__(self, cls, num_inputs: int, num_targets: int = 1): method num_inputs (line 49) | def num_inputs(self) -> int: method num_targets (line 53) | def num_targets(self) -> int: method loss (line 56) | def loss(self, *args, weight: float = 1.0, **kwargs): method loss_evaluate (line 59) | def loss_evaluate(self, *args, weight: float = 1.0, **kwargs): method get_outputs (line 62) | def get_outputs(self, *args, weight: float, return_loss: bool, **kwargs): method impl (line 79) | def impl(self, *args, weight: float, return_loss: bool, **kwargs): method abstract_eval (line 82) | def abstract_eval(self, *args, weight: float, return_loss: bool, **kwa... method xla_translation (line 91) | def xla_translation( method jvp (line 105) | def jvp( method batching (line 128) | def batching(self, batched_args, batched_dims, **kwargs): class LayerTag (line 132) | class LayerTag(jax_core.Primitive): method __init__ (line 135) | def __init__(self, name: str, num_inputs: int, num_outputs: int): method num_outputs (line 153) | def num_outputs(self) -> int: method num_inputs (line 157) | def num_inputs(self) -> int: method split_all_inputs (line 160) | def split_all_inputs( method get_outputs (line 170) | def get_outputs(self, *operands: _T, **kwargs) -> _T: method xla_translation (line 174) | def xla_translation(self, c, *operands: _T, **kwargs) -> _T: method transpose (line 178) | def transpose(cotangent, *operands, **kwargs): method impl (line 181) | def impl(self, *operands, **kwargs): method abstract_eval (line 184) | def abstract_eval(self, *abstract_operands, **kwargs): method batching (line 192) | def batching(self, batched_operands, batched_dims, **kwargs): function register_generic (line 208) | def register_generic(parameter: _T) -> _T: function register_dense (line 223) | def register_dense(y, x, w, b=None): function dense_func (line 229) | def dense_func(x, params): function dense_tagging (line 240) | def dense_tagging(jaxpr, inverse_map, values_map): function register_conv2d (line 259) | def register_conv2d(y, x, w, b=None, **kwargs): function conv2d_func (line 265) | def conv2d_func(x, params): function conv2d_tagging (line 281) | def conv2d_tagging(jaxpr, inverse_map, values_map): function register_scale_and_shift (line 305) | def register_scale_and_shift(y, args, has_scale: bool, has_shift: bool): function scale_and_shift_func (line 312) | def scale_and_shift_func(x, params, has_scale: bool, has_shift: bool): function scale_and_shift_tagging (line 325) | def scale_and_shift_tagging( function batch_norm_func (line 340) | def batch_norm_func( function batch_norm_tagging_func (line 351) | def batch_norm_tagging_func( FILE: kfac_ferminet_alpha/loss_functions.py class LossFunction (line 31) | class LossFunction(abc.ABC): method __init__ (line 40) | def __init__(self, weight: FloatArray): method weight (line 44) | def weight(self) -> FloatArray: method targets (line 49) | def targets(self) -> Optional[jnp.ndarray]: method inputs (line 59) | def inputs(self) -> Sequence[jnp.ndarray]: method copy_with_different_inputs (line 64) | def copy_with_different_inputs(self, inputs: Sequence[jnp.ndarray]): method evaluate (line 67) | def evaluate( method _evaluate (line 88) | def _evaluate(self, targets: jnp.ndarray) -> jnp.ndarray: method grad_of_evaluate (line 99) | def grad_of_evaluate( method multiply_ggn (line 120) | def multiply_ggn(self, vector: jnp.ndarray) -> jnp.ndarray: method multiply_ggn_unweighted (line 137) | def multiply_ggn_unweighted(self, vector: jnp.ndarray) -> jnp.ndarray: method multiply_ggn_factor (line 141) | def multiply_ggn_factor(self, vector: jnp.ndarray) -> jnp.ndarray: method multiply_ggn_factor_unweighted (line 164) | def multiply_ggn_factor_unweighted(self, vector: jnp.ndarray) -> jnp.n... method multiply_ggn_factor_transpose (line 168) | def multiply_ggn_factor_transpose(self, vector: jnp.ndarray) -> jnp.nd... method multiply_ggn_factor_transpose_unweighted (line 192) | def multiply_ggn_factor_transpose_unweighted( method multiply_ggn_factor_replicated_one_hot (line 199) | def multiply_ggn_factor_replicated_one_hot(self, index: Index) -> jnp.... method multiply_ggn_factor_replicated_one_hot_unweighted (line 228) | def multiply_ggn_factor_replicated_one_hot_unweighted( method ggn_factor_inner_shape (line 236) | def ggn_factor_inner_shape(self) -> Sequence[int]: class NegativeLogProbLoss (line 241) | class NegativeLogProbLoss(LossFunction): method inputs (line 245) | def inputs(self): method params (line 250) | def params(self): method multiply_fisher (line 254) | def multiply_fisher(self, vector: jnp.ndarray) -> jnp.ndarray: method multiply_fisher_unweighted (line 269) | def multiply_fisher_unweighted(self, vector: jnp.ndarray) -> jnp.ndarray: method multiply_fisher_factor (line 272) | def multiply_fisher_factor(self, vector: jnp.ndarray) -> jnp.ndarray: method multiply_fisher_factor_unweighted (line 297) | def multiply_fisher_factor_unweighted( method multiply_fisher_factor_transpose (line 303) | def multiply_fisher_factor_transpose( method multiply_fisher_factor_transpose_unweighted (line 332) | def multiply_fisher_factor_transpose_unweighted( method multiply_fisher_factor_replicated_one_hot (line 338) | def multiply_fisher_factor_replicated_one_hot( method multiply_fisher_factor_replicated_one_hot_unweighted (line 372) | def multiply_fisher_factor_replicated_one_hot_unweighted( method fisher_factor_inner_shape (line 380) | def fisher_factor_inner_shape(self) -> Sequence[int]: method sample (line 385) | def sample(self, rng_key: jnp.ndarray) -> jnp.ndarray: method grad_of_evaluate_on_sample (line 389) | def grad_of_evaluate_on_sample( class NaturalParamsNegativeLogProbLoss (line 407) | class NaturalParamsNegativeLogProbLoss(NegativeLogProbLoss, abc.ABC): method multiply_ggn_unweighted (line 418) | def multiply_ggn_unweighted(self, vector: jnp.ndarray) -> jnp.ndarray: method multiply_ggn_factor_unweighted (line 421) | def multiply_ggn_factor_unweighted(self, vector: jnp.ndarray) -> jnp.n... method multiply_ggn_factor_transpose_unweighted (line 424) | def multiply_ggn_factor_transpose_unweighted( method multiply_ggn_factor_replicated_one_hot_unweighted (line 430) | def multiply_ggn_factor_replicated_one_hot_unweighted( method ggn_factor_inner_shape (line 437) | def ggn_factor_inner_shape(self) -> Sequence[int]: class DistributionNegativeLogProbLoss (line 441) | class DistributionNegativeLogProbLoss(NegativeLogProbLoss): method dist (line 446) | def dist(self): method _evaluate (line 450) | def _evaluate(self, targets: jnp.ndarray): method sample (line 453) | def sample(self, rng_key: jnp.ndarray): method fisher_factor_inner_shape (line 457) | def fisher_factor_inner_shape(self) -> Sequence[int]: class NormalMeanNegativeLogProbLoss (line 461) | class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss, method __init__ (line 475) | def __init__( method targets (line 490) | def targets(self) -> Optional[jnp.ndarray]: method dist (line 494) | def dist(self): method params (line 499) | def params(self): method copy_with_different_inputs (line 502) | def copy_with_different_inputs(self, inputs: Sequence[jnp.ndarray]): method multiply_fisher_unweighted (line 511) | def multiply_fisher_unweighted(self, vector: jnp.ndarray) -> jnp.ndarray: method multiply_fisher_factor_unweighted (line 514) | def multiply_fisher_factor_unweighted( method multiply_fisher_factor_transpose_unweighted (line 520) | def multiply_fisher_factor_transpose_unweighted( method multiply_fisher_factor_replicated_one_hot_unweighted (line 526) | def multiply_fisher_factor_replicated_one_hot_unweighted( function insert_slice_in_zeros (line 537) | def insert_slice_in_zeros( function register_normal_predictive_distribution (line 590) | def register_normal_predictive_distribution( function register_squared_error_loss (line 629) | def register_squared_error_loss( FILE: kfac_ferminet_alpha/optimizer.py class Optimizer (line 35) | class Optimizer(utils.Stateful): method __init__ (line 41) | def __init__( method finalize (line 215) | def finalize( method _init (line 288) | def _init(self, rng: jnp.ndarray) -> State: method verify_args_and_get_step_counter (line 296) | def verify_args_and_get_step_counter( method _burnin (line 330) | def _burnin( method _step (line 365) | def _step( method init (line 489) | def init( method step (line 501) | def step( method propose_directions (line 569) | def propose_directions( method velocities_and_delta (line 599) | def velocities_and_delta( FILE: kfac_ferminet_alpha/tag_graph_matcher.py function match_nodes (line 38) | def match_nodes(g1, g2, mapping, node1, node2): function generate_candidates (line 62) | def generate_candidates(g1, g2, mapping, node1, node2): function find_mappings (line 74) | def find_mappings(pattern, graph, mapping, terminals): function match_pattern (line 106) | def match_pattern(pattern, graph): function read_env (line 144) | def read_env(env, var): function write_env (line 151) | def write_env(env, var, val): function abstract_single_value (line 155) | def abstract_single_value(value): function abstract_args (line 163) | def abstract_args(args): function _extract_call_jaxpr (line 167) | def _extract_call_jaxpr(primitive, params): function evaluate_eqn (line 175) | def evaluate_eqn(eqn, in_values, write_func): function clean_jaxpr_eqns (line 195) | def clean_jaxpr_eqns(jaxpr, preserve_tags=True): function broadcast_merger (line 220) | def broadcast_merger(f): class JaxGraph (line 267) | class JaxGraph(NamedTuple): function default_compare (line 281) | def default_compare(node1, node2): function reshape_compare (line 293) | def reshape_compare(node1, node2): function broadcast_in_dim_compare (line 302) | def broadcast_in_dim_compare(node1, node2): function conv_compare (line 308) | def conv_compare(node1, node2): function kfac_node_match (line 339) | def kfac_node_match(node1, node2): function var_to_str (line 359) | def var_to_str(var): function extract_param_vars_flat (line 376) | def extract_param_vars_flat(jaxpr, in_tree, params_index): function fill_jaxpr_to_graph (line 385) | def fill_jaxpr_to_graph(graph, jaxpr, in_vars=None, out_vars=None): function create_digraph (line 422) | def create_digraph(jaxpr, params): function function_to_jax_graph (line 436) | def function_to_jax_graph(func, args, params_index, tagging_func=None): function print_nice_jaxpr (line 457) | def print_nice_jaxpr(jaxpr): function auto_register_tags (line 462) | def auto_register_tags(func, function register_function (line 617) | def register_function(name, func, tagging_func, example_args, params_index, function get_graph_patterns (line 658) | def get_graph_patterns(): FILE: kfac_ferminet_alpha/tests/common.py function fully_connected_layer (line 23) | def fully_connected_layer(params, x): function init_autoencoder (line 28) | def init_autoencoder(key, data_shape): function autoencoder (line 44) | def autoencoder(all_params, x_in): FILE: kfac_ferminet_alpha/tests/graph_matcher_test.py function tagged_autoencoder (line 28) | def tagged_autoencoder(all_params, x_in): class TestGraphMatcher (line 48) | class TestGraphMatcher(unittest.TestCase): method _test_jaxpr (line 51) | def _test_jaxpr(self, init_func, model_func, tagged_model, data_shape): method test_autoencoder (line 78) | def test_autoencoder(self): FILE: kfac_ferminet_alpha/tests/tracer_test.py function autoencoder_aux (line 30) | def autoencoder_aux(all_aux, all_params, x_in): class TestTracer (line 49) | class TestTracer(unittest.TestCase): method generate_data (line 53) | def generate_data(init_func, func, data_shape, rng_key): method assertStructureAllClose (line 70) | def assertStructureAllClose(self, x, y, rtol=1E-5, atol=1E-5, **kwargs): method test_tacer_jvp (line 79) | def test_tacer_jvp(self): method test_tracer_vjp (line 104) | def test_tracer_vjp(self): method test_tracer_hvp (line 130) | def test_tracer_hvp(self): method test_trace_estimator (line 156) | def test_trace_estimator(self): FILE: kfac_ferminet_alpha/tracer.py function extract_tags (line 32) | def extract_tags( function construct_compute_losses_inputs (line 47) | def construct_compute_losses_inputs( function _unbox_loss_tag (line 89) | def _unbox_loss_tag(jaxpr_eqn: core.JaxprEqn) -> tags.LossTag: function _unbox_layer_tag (line 94) | def _unbox_layer_tag(jaxpr_eqn: core.JaxprEqn) -> tags.LayerTag: function trace_losses_matrix_vector_vjp (line 99) | def trace_losses_matrix_vector_vjp(tagged_func: _Function, function trace_losses_matrix_vector_jvp (line 138) | def trace_losses_matrix_vector_jvp( function trace_losses_matrix_vector_hvp (line 160) | def trace_losses_matrix_vector_hvp(tagged_func, params_index=0): function trace_estimator_vjp (line 191) | def trace_estimator_vjp(tagged_func: _Function) -> _Function: FILE: kfac_ferminet_alpha/utils.py function wrap_if_pmap (line 29) | def wrap_if_pmap(p_func): function get_first (line 47) | def get_first(obj: T) -> T: function get_mean (line 51) | def get_mean(obj: T) -> T: function get_sum (line 55) | def get_sum(obj: T) -> T: function replicate_all_local_devices (line 62) | def replicate_all_local_devices(obj: T) -> T: function make_different_rng_key_on_all_devices (line 68) | def make_different_rng_key_on_all_devices(rng: jnp.ndarray) -> jnp.ndarray: function scalar_mul (line 77) | def scalar_mul(obj: T, scalar: Union[float, jnp.ndarray]) -> T: function scalar_div (line 81) | def scalar_div(obj: T, scalar: Union[float, jnp.ndarray]) -> T: function make_func_args (line 85) | def make_func_args(params, func_state, rng, batch, has_state: bool, function extract_func_outputs (line 102) | def extract_func_outputs( function inner_product (line 120) | def inner_product(obj1: T, obj2: T) -> jnp.ndarray: function psd_inv_cholesky (line 127) | def psd_inv_cholesky(matrix: jnp.ndarray, damping: jnp.ndarray) -> jnp.n... function solve_maybe_small (line 134) | def solve_maybe_small(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: function pi_adjusted_inverse (line 152) | def pi_adjusted_inverse( function convert_value_and_grad_to_value_func (line 218) | def convert_value_and_grad_to_value_func( function check_structure_shapes_and_dtype (line 234) | def check_structure_shapes_and_dtype(obj1: T, obj2: T) -> None: function check_first_dim_is_batch_size (line 242) | def check_first_dim_is_batch_size(batch_size: int, *args: jnp.ndarray) -... function py_tree_registered_dataclass (line 250) | def py_tree_registered_dataclass(cls, *args, **kwargs): class WeightedMovingAverage (line 262) | class WeightedMovingAverage: method __init__ (line 265) | def __init__(self, weight: jnp.ndarray, array: jnp.ndarray): method zero (line 270) | def zero(shape: Sequence[int]) -> "WeightedMovingAverage": method weight (line 274) | def weight(self) -> jnp.ndarray: method value (line 278) | def value(self) -> jnp.ndarray: method raw_value (line 282) | def raw_value(self) -> jnp.ndarray: method update (line 285) | def update(self, value: jnp.ndarray, old_weight_multiplier: float, method sync (line 290) | def sync(self, pmap_axis_name: str) -> None: method __str__ (line 293) | def __str__(self) -> str: method __repr__ (line 297) | def __repr__(self) -> str: class Stateful (line 308) | class Stateful: method __init__ (line 311) | def __init__(self, stateful_fields_names: Optional[Sequence[str]] = ()): method _add_stateful_fields_names (line 314) | def _add_stateful_fields_names(self, value: Sequence[str]) -> None: method get_state (line 317) | def get_state(self) -> Mapping[str, Any]: method set_state (line 324) | def set_state(self, value): method clear_state (line 332) | def clear_state(self) -> None: method pop_state (line 338) | def pop_state(self) -> Mapping[str, Any]: method _get_state_from_instance (line 345) | def _get_state_from_instance(obj): method _set_state_to_instance (line 362) | def _set_state_to_instance(obj, value): method _clear_state_from_instance (line 392) | def _clear_state_from_instance(obj): method infer_class_state (line 409) | def infer_class_state(class_type): function compute_sq_norm_relative_abs_diff (line 443) | def compute_sq_norm_relative_abs_diff(obj, pmap_axis_name): function product (line 451) | def product(iterable_object): FILE: learning_to_simulate/connectivity_utils.py function _compute_connectivity (line 25) | def _compute_connectivity(positions, radius, add_self_edges): function _compute_connectivity_for_batch (line 54) | def _compute_connectivity_for_batch( function compute_connectivity_for_batch_pyfunc (line 106) | def compute_connectivity_for_batch_pyfunc( FILE: learning_to_simulate/graph_network.py function build_mlp (line 43) | def build_mlp( class EncodeProcessDecode (line 50) | class EncodeProcessDecode(snt.AbstractModule): method __init__ (line 53) | def __init__( method _build (line 89) | def _build(self, input_graph: gn.graphs.GraphsTuple) -> tf.Tensor: method _networks_builder (line 101) | def _networks_builder(self): method _encode (line 135) | def _encode( method _process (line 150) | def _process( method _process_step (line 166) | def _process_step( method _decode (line 180) | def _decode(self, latent_graph: gn.graphs.GraphsTuple) -> tf.Tensor: FILE: learning_to_simulate/learned_simulator.py class LearnedSimulator (line 37) | class LearnedSimulator(snt.AbstractModule): method __init__ (line 40) | def __init__( method _build (line 84) | def _build(self, position_sequence, n_particles_per_example, method _encoder_preprocessor (line 114) | def _encoder_preprocessor( method _decoder_postprocessor (line 188) | def _decoder_postprocessor(self, normalized_acceleration, position_seq... method get_predicted_and_target_normalized_accelerations (line 206) | def get_predicted_and_target_normalized_accelerations( method _inverse_decoder_postprocessor (line 249) | def _inverse_decoder_postprocessor(self, next_position, position_seque... function time_diff (line 263) | def time_diff(input_sequence): FILE: learning_to_simulate/model_demo.py function sample_random_position_sequence (line 59) | def sample_random_position_sequence(): function main (line 68) | def main(): FILE: learning_to_simulate/noise_utils.py function get_random_walk_noise_for_position_sequence (line 23) | def get_random_walk_noise_for_position_sequence( FILE: learning_to_simulate/reading_utils.py function convert_to_tensor (line 48) | def convert_to_tensor(x, encoded_dtype): function parse_serialized_simulation_example (line 59) | def parse_serialized_simulation_example(example_proto, metadata): function split_trajectory (line 109) | def split_trajectory(context, features, window_length=7): FILE: learning_to_simulate/render_rollout.py function main (line 53) | def main(unused_argv): FILE: learning_to_simulate/train.py function get_kinematic_mask (line 76) | def get_kinematic_mask(particle_types): function prepare_inputs (line 81) | def prepare_inputs(tensor_dict): function prepare_rollout_inputs (line 130) | def prepare_rollout_inputs(context, features): function batch_concat (line 148) | def batch_concat(dataset, batch_size): function get_input_fn (line 173) | def get_input_fn(data_path, batch_size, mode, split): function rollout (line 219) | def rollout(simulator, features, num_steps): function _combine_std (line 272) | def _combine_std(std_x, std_y): function _get_simulator (line 276) | def _get_simulator(model_kwargs, metadata, acc_noise_std, vel_noise_std): function get_one_step_estimator_fn (line 304) | def get_one_step_estimator_fn(data_path, function get_rollout_estimator_fn (line 385) | def get_rollout_estimator_fn(data_path, function _read_metadata (line 427) | def _read_metadata(data_path): function main (line 432) | def main(_): FILE: meshgraphnets/cfd_eval.py function _rollout (line 23) | def _rollout(model, initial_state, num_steps): function evaluate (line 46) | def evaluate(model, inputs): FILE: meshgraphnets/cfd_model.py class Model (line 26) | class Model(snt.AbstractModule): method __init__ (line 29) | def __init__(self, learned_model, name='Model'): method _build_graph (line 40) | def _build_graph(self, inputs, is_training): method _build (line 63) | def _build(self, inputs): method loss (line 69) | def loss(self, inputs): method _update (line 88) | def _update(self, inputs, per_node_network_output): FILE: meshgraphnets/cloth_eval.py function _rollout (line 23) | def _rollout(model, initial_state, num_steps): function evaluate (line 45) | def evaluate(model, inputs): FILE: meshgraphnets/cloth_model.py class Model (line 26) | class Model(snt.AbstractModule): method __init__ (line 29) | def __init__(self, learned_model, name='Model'): method _build_graph (line 40) | def _build_graph(self, inputs, is_training): method _build (line 68) | def _build(self, inputs): method loss (line 74) | def loss(self, inputs): method _update (line 92) | def _update(self, inputs, per_node_network_output): FILE: meshgraphnets/common.py class NodeType (line 22) | class NodeType(enum.IntEnum): function triangles_to_edges (line 33) | def triangles_to_edges(faces): FILE: meshgraphnets/core_model.py class GraphNetBlock (line 28) | class GraphNetBlock(snt.AbstractModule): method __init__ (line 31) | def __init__(self, model_fn, name='GraphNetBlock'): method _update_edge_features (line 35) | def _update_edge_features(self, node_features, edge_set): method _update_node_features (line 43) | def _update_node_features(self, node_features, edge_sets): method _build (line 54) | def _build(self, graph): class EncodeProcessDecode (line 75) | class EncodeProcessDecode(snt.AbstractModule): method __init__ (line 78) | def __init__(self, method _make_mlp (line 90) | def _make_mlp(self, output_size, layer_norm=True): method _encoder (line 98) | def _encoder(self, graph): method _decoder (line 108) | def _decoder(self, graph): method _build (line 114) | def _build(self, graph): FILE: meshgraphnets/dataset.py function _parse (line 27) | def _parse(proto, meta): function load_dataset (line 48) | def load_dataset(path, split): function add_targets (line 58) | def add_targets(ds, fields, add_history): function split_and_preprocess (line 72) | def split_and_preprocess(ds, noise_field, noise_scale, noise_gamma): function batch_dataset (line 91) | def batch_dataset(ds, batch_size): FILE: meshgraphnets/normalization.py class Normalizer (line 22) | class Normalizer(snt.AbstractModule): method __init__ (line 25) | def __init__(self, size, max_accumulations=10**6, std_epsilon=1e-8, method _build (line 38) | def _build(self, batched_data, accumulate=True): method inverse (line 50) | def inverse(self, normalized_batch_data): method _accumulate (line 54) | def _accumulate(self, batched_data): method _mean (line 65) | def _mean(self): method _std_with_epsilon (line 69) | def _std_with_epsilon(self): FILE: meshgraphnets/plot_cfd.py function main (line 30) | def main(unused_argv): FILE: meshgraphnets/plot_cloth.py function main (line 30) | def main(unused_argv): FILE: meshgraphnets/run_model.py function learner (line 54) | def learner(model, params): function evaluator (line 88) | def evaluator(model, params): function main (line 113) | def main(argv): FILE: mmv/config.py function get_model_config (line 19) | def get_model_config(ckpt_path): FILE: mmv/eval_ucf101.py function get_sampling_offset (line 68) | def get_sampling_offset(sequence: tf.Tensor, function sample_or_pad_sequence_indices (line 104) | def sample_or_pad_sequence_indices(sequence: tf.Tensor, function random_sample_sequence (line 158) | def random_sample_sequence(sequence: tf.Tensor, function sample_linspace_sequence (line 175) | def sample_linspace_sequence(sequence: tf.Tensor, function resize_smallest (line 221) | def resize_smallest(frames: tf.Tensor, min_resize: int) -> tf.Tensor: function process_samples (line 252) | def process_samples(features_dict, num_frames=32, stride=1, is_training=... function space_to_depth_batch (line 285) | def space_to_depth_batch(features_dict): function reshape_windows (line 295) | def reshape_windows(features_dict, num_frames): function compute_accuracy_metrics (line 302) | def compute_accuracy_metrics(pred, gt, prefix=''): function forward_fn (line 313) | def forward_fn(images: jnp.ndarray, function main (line 336) | def main(argv): FILE: mmv/models/mm_embeddings.py function _setkey_if_not_exists (line 55) | def _setkey_if_not_exists(d, key, value): class AudioTextVideoEmbedding (line 60) | class AudioTextVideoEmbedding(hk.Module): method __init__ (line 63) | def __init__( method _get_pair_embedding_heads (line 153) | def _get_pair_embedding_heads(self, method _activate_interaction (line 183) | def _activate_interaction(self, inputs, activation_fn, is_training, method __call__ (line 199) | def __call__(self, class EmbeddingModule (line 353) | class EmbeddingModule(hk.Module): method __init__ (line 356) | def __init__(self, method __call__ (line 378) | def __call__(self, input_feature, is_training): class VisualModule (line 407) | class VisualModule(hk.Module): method __init__ (line 410) | def __init__(self, method __call__ (line 440) | def __call__(self, images, is_training): class AudioModule (line 446) | class AudioModule(hk.Module): method __init__ (line 449) | def __init__(self, method __call__ (line 478) | def __call__(self, class TextModule (line 492) | class TextModule(hk.Module): method __init__ (line 495) | def __init__(self, method __call__ (line 511) | def __call__(self, word_ids, is_training): FILE: mmv/models/normalization.py class _BatchNorm (line 26) | class _BatchNorm(hk.BatchNorm): method __init__ (line 29) | def __init__(self, method __call__ (line 50) | def __call__(self, class _CrossReplicaBatchNorm (line 57) | class _CrossReplicaBatchNorm(hk.BatchNorm): method __init__ (line 60) | def __init__(self, method __call__ (line 82) | def __call__(self, class _LayerNorm (line 89) | class _LayerNorm(hk.LayerNorm): method __init__ (line 92) | def __init__(self, method __call__ (line 102) | def __call__(self, # pytype: disable=signature-mismatch # overriding... function get_normalize_fn (line 116) | def get_normalize_fn( FILE: mmv/models/resnet.py class BottleneckBlock (line 31) | class BottleneckBlock(hk.Module): method __init__ (line 35) | def __init__(self, method __call__ (line 80) | def __call__(self, class BasicBlock (line 99) | class BasicBlock(hk.Module): method __init__ (line 103) | def __init__(self, method __call__ (line 140) | def __call__(self, class ResNetUnit (line 159) | class ResNetUnit(hk.Module): method __init__ (line 163) | def __init__(self, method __call__ (line 179) | def __call__(self, class ResNetV2 (line 208) | class ResNetV2(hk.Module): method __init__ (line 223) | def __init__(self, method __call__ (line 294) | def __call__(self, inputs, is_training, final_endpoint='output'): FILE: mmv/models/s3d.py class _MaxPool (line 28) | class _MaxPool(hk.MaxPool): method __call__ (line 31) | def __call__(self, function self_gating (line 38) | def self_gating(inputs: types.TensorLike) -> jnp.ndarray: class SUnit3D (line 66) | class SUnit3D(hk.Module): method __init__ (line 69) | def __init__( method __call__ (line 142) | def __call__( class InceptionBlockV13D (line 167) | class InceptionBlockV13D(hk.Module): method __init__ (line 177) | def __init__(self, method __call__ (line 216) | def __call__( class S3D (line 298) | class S3D(hk.Module): method __init__ (line 329) | def __init__(self, method __call__ (line 377) | def __call__(self, FILE: mmv/models/s3d_test.py class _CallableS3D (line 29) | class _CallableS3D: method __init__ (line 32) | def __init__(self, *args, **kwargs): method init (line 41) | def init(self, inputs, **kwargs): method __call__ (line 45) | def __call__(self, inputs, **kwargs): class S3DTest (line 53) | class S3DTest(parameterized.TestCase): method test_endpoint_expected_output_dimensions (line 75) | def test_endpoint_expected_output_dimensions(self, endpoint, expected_... method test_space_to_depth (line 81) | def test_space_to_depth(self): FILE: mmv/models/tsm_resnet.py class TSMResNetBlock (line 34) | class TSMResNetBlock(hk.Module): method __init__ (line 42) | def __init__(self, method __call__ (line 75) | def __call__(self, class TSMResNetUnit (line 154) | class TSMResNetUnit(hk.Module): method __init__ (line 157) | def __init__(self, method __call__ (line 189) | def __call__(self, class TSMResNetV2 (line 217) | class TSMResNetV2(hk.Module): method __init__ (line 231) | def __init__(self, method __call__ (line 279) | def __call__( FILE: mmv/models/tsm_resnet_test.py class TSMResNetTest (line 28) | class TSMResNetTest(parameterized.TestCase): method test_output_dimension (line 39) | def test_output_dimension(self, final_endpoint, expected_shape): method test_tpu_mode (line 51) | def test_tpu_mode(self): FILE: mmv/models/tsm_utils.py function prepare_inputs (line 26) | def prepare_inputs( function prepare_outputs (line 42) | def prepare_outputs(outputs: types.TensorLike, function apply_temporal_shift (line 59) | def apply_temporal_shift( function temporal_shift_gpu (line 75) | def temporal_shift_gpu( function temporal_shift_tpu (line 109) | def temporal_shift_tpu( FILE: mmv/models/tsm_utils_test.py class TsmUtilsTest (line 27) | class TsmUtilsTest(parameterized.TestCase): method test_prepare_inputs (line 33) | def test_prepare_inputs(self, input_shape, expected_mode, expected_shape, method test_prepare_outputs (line 42) | def test_prepare_outputs(self): method test_apply_tsm (line 51) | def test_apply_tsm(self): FILE: mmv/utils/checkpoint.py function load_checkpoint (line 22) | def load_checkpoint(checkpoint_path): FILE: mmv/utils/ucf101_dataset.py class ModUcf101 (line 47) | class ModUcf101(tfds.video.Ucf101): method _info (line 51) | def _info(self): FILE: nfnets/agc_optax.py function compute_norm (line 21) | def compute_norm(x, axis, keepdims): function unitwise_norm (line 26) | def unitwise_norm(x): function my_clip (line 42) | def my_clip(g_norm, max_norm, grad): function adaptive_grad_clip (line 51) | def adaptive_grad_clip(clip, eps=1e-3) -> optax.GradientTransformation: FILE: nfnets/autoaugment.py function policy_v0 (line 36) | def policy_v0(): function policy_vtest (line 71) | def policy_vtest(): function blend (line 82) | def blend(image1, image2, factor): function cutout (line 125) | def cutout(image, pad_size, replace=0): function solarize (line 176) | def solarize(image, threshold=128): function solarize_add (line 183) | def solarize_add(image, addition=0, threshold=128): function color (line 193) | def color(image, factor): function contrast (line 199) | def contrast(image, factor): function brightness (line 216) | def brightness(image, factor): function posterize (line 222) | def posterize(image, bits): function rotate (line 228) | def rotate(image, degrees, replace): function translate_x (line 253) | def translate_x(image, pixels, replace): function translate_y (line 259) | def translate_y(image, pixels, replace): function shear_x (line 265) | def shear_x(image, level, replace): function shear_y (line 276) | def shear_y(image, level, replace): function autocontrast (line 287) | def autocontrast(image): function sharpness (line 326) | def sharpness(image, factor): function equalize (line 358) | def equalize(image): function invert (line 398) | def invert(image): function wrap (line 404) | def wrap(image): function unwrap (line 412) | def unwrap(image, replace): function _randomly_negate_tensor (line 470) | def _randomly_negate_tensor(tensor): function _rotate_level_to_arg (line 477) | def _rotate_level_to_arg(level): function _shrink_level_to_arg (line 483) | def _shrink_level_to_arg(level): function _enhance_level_to_arg (line 492) | def _enhance_level_to_arg(level): function _shear_level_to_arg (line 496) | def _shear_level_to_arg(level): function _translate_level_to_arg (line 503) | def _translate_level_to_arg(level, translate_const): function level_to_arg (line 510) | def level_to_arg(hparams): function _parse_policy_info (line 535) | def _parse_policy_info(name, prob, level, replace_value, augmentation_hp... function _apply_func_with_prob (line 558) | def _apply_func_with_prob(func, image, args, prob): function select_and_apply_random_policy (line 579) | def select_and_apply_random_policy(policies, image): function build_and_apply_nas_policy (line 592) | def build_and_apply_nas_policy(policies, image, function distort_image_with_autoaugment (line 641) | def distort_image_with_autoaugment(image, augmentation_name): function distort_image_with_randaugment (line 672) | def distort_image_with_randaugment(image, num_layers, magnitude): FILE: nfnets/base.py class WSConv2D (line 121) | class WSConv2D(hk.Conv2D): method standardize_weight (line 125) | def standardize_weight(self, weight, eps=1e-4): method __call__ (line 138) | def __call__(self, inputs: jnp.ndarray, eps: float = 1e-4) -> jnp.ndar... function signal_metrics (line 157) | def signal_metrics(x, i): function count_conv_flops (line 167) | def count_conv_flops(in_ch, conv, h, w): class SqueezeExcite (line 177) | class SqueezeExcite(hk.Module): method __init__ (line 180) | def __init__(self, in_ch, out_ch, se_ratio=0.5, method __call__ (line 195) | def __call__(self, x): class StochDepth (line 202) | class StochDepth(hk.Module): method __init__ (line 205) | def __init__(self, drop_rate, scale_by_keep=False, name=None): method __call__ (line 210) | def __call__(self, x, is_training) -> jnp.ndarray: FILE: nfnets/dataset.py class Split (line 39) | class Split(enum.Enum): method from_string (line 47) | def from_string(cls, name: Text) -> 'Split': method num_examples (line 53) | def num_examples(self): function load (line 58) | def load( function cutmix_padding (line 226) | def cutmix_padding(h, w): function my_cutmix (line 266) | def my_cutmix(batch): function my_mixup (line 279) | def my_mixup(batch): function mixup_or_cutmix (line 292) | def mixup_or_cutmix(batch): function my_mixup_cutmix (line 301) | def my_mixup_cutmix(batch): function _to_tfds_split (line 324) | def _to_tfds_split(split: Split) -> tfds.Split: function _shard (line 333) | def _shard(split: Split, shard_index: int, num_shards: int) -> Tuple[int... function _preprocess_image (line 347) | def _preprocess_image( function _augment_image (line 377) | def _augment_image( function _normalize_image (line 407) | def _normalize_image(image: tf.Tensor) -> tf.Tensor: function _distorted_bounding_box_crop (line 414) | def _distorted_bounding_box_crop( function _decode_and_random_crop (line 442) | def _decode_and_random_crop(image_bytes: tf.Tensor, function _decode_and_center_crop (line 464) | def _decode_and_center_crop( function get_shape (line 488) | def get_shape(image_bytes): function crop (line 497) | def crop(image_bytes, crop_window): function _decode_and_resize_then_crop (line 508) | def _decode_and_resize_then_crop( FILE: nfnets/experiment.py function get_config (line 41) | def get_config(): class Experiment (line 116) | class Experiment(experiment.AbstractExperiment): method __init__ (line 126) | def __init__(self, mode, config, init_rng): method _initialize_train (line 167) | def _initialize_train(self): method _make_opt (line 182) | def _make_opt(self): method _forward_fn (line 211) | def _forward_fn(self, inputs, is_training): method _one_hot (line 225) | def _one_hot(self, value): method _loss_fn (line 230) | def _loss_fn(self, params, state, inputs, rng): method _train_fn (line 252) | def _train_fn(self, params, states, opt_states, inputs, rng, global_step, method step (line 291) | def step(self, global_step, rng, *unused_args, **unused_kwargs): method _build_train_input (line 310) | def _build_train_input(self): method evaluate (line 336) | def evaluate(self, global_step, **unused_args): method _eval_epoch (line 344) | def _eval_epoch(self, params, state): method _eval_fn (line 361) | def _eval_fn(self, params, state, inputs): method _build_eval_input (line 371) | def _build_eval_input(self): FILE: nfnets/experiment_nf_regnets.py function get_config (line 21) | def get_config(): FILE: nfnets/experiment_nfnets.py function get_config (line 30) | def get_config(): class Experiment (line 102) | class Experiment(experiment.Experiment): method _make_opt (line 105) | def _make_opt(self): FILE: nfnets/fixup_resnet.py class FixUp_ResNet (line 31) | class FixUp_ResNet(hk.Module): method __init__ (line 42) | def __init__(self, num_classes, variant='ResNet50', width=4, method __call__ (line 88) | def __call__(self, x, is_training=True, return_metrics=False): method count_flops (line 115) | def count_flops(self, h, w): class ResBlock (line 129) | class ResBlock(hk.Module): method __init__ (line 132) | def __init__(self, in_ch, out_ch, num_blocks, bottleneck_ratio=0.25, method __call__ (line 172) | def __call__(self, x, is_training): method count_flops (line 200) | def count_flops(self, h, w): FILE: nfnets/nf_regnet.py class NF_RegNet (line 24) | class NF_RegNet(hk.Module): method __init__ (line 29) | def __init__(self, num_classes, variant='B0', method __call__ (line 93) | def __call__(self, x, is_training=True, return_metrics=False): method count_flops (line 116) | def count_flops(self, h, w): class NFBlock (line 133) | class NFBlock(hk.Module): method __init__ (line 136) | def __init__(self, in_ch, out_ch, expansion=2.25, se_ratio=0.5, method __call__ (line 177) | def __call__(self, x, is_training): method count_flops (line 201) | def count_flops(self, h, w): FILE: nfnets/nf_resnet.py class NF_ResNet (line 24) | class NF_ResNet(hk.Module): method __init__ (line 35) | def __init__(self, num_classes, variant='ResNet50', width=4, method __call__ (line 94) | def __call__(self, x, is_training=True, return_metrics=False): method count_flops (line 118) | def count_flops(self, h, w): class NFResBlock (line 132) | class NFResBlock(hk.Module): method __init__ (line 135) | def __init__(self, in_ch, out_ch, bottleneck_ratio=0.25, method __call__ (line 177) | def __call__(self, x, is_training): method count_flops (line 197) | def count_flops(self, h, w): FILE: nfnets/nfnet.py class NFNet (line 30) | class NFNet(hk.Module): method __init__ (line 40) | def __init__(self, num_classes, variant='F0', method __call__ (line 131) | def __call__(self, x, is_training=True, return_metrics=False): method count_flops (line 154) | def count_flops(self, h, w): class NFBlock (line 176) | class NFBlock(hk.Module): method __init__ (line 179) | def __init__(self, in_ch, out_ch, expansion=0.5, se_ratio=0.5, method __call__ (line 227) | def __call__(self, x, is_training): method count_flops (line 253) | def count_flops(self, h, w): FILE: nfnets/optim.py class Optimizer (line 25) | class Optimizer(object): method __init__ (line 28) | def __init__(self, params, defaults): method add_hyperparam_group (line 49) | def add_hyperparam_group(self, group, suffix, defaults): method create_param_groups (line 62) | def create_param_groups(self, params, defaults): method create_buffers (line 74) | def create_buffers(self, name, params): method get_opt_params (line 78) | def get_opt_params(self, param_name, itr): method get_hyper (line 91) | def get_hyper(self, param_name, hyper_name): method plugin (line 96) | def plugin(self, states): method states (line 99) | def states(self): method broadcast (line 102) | def broadcast(self): method gather (line 108) | def gather(self): method __setattr__ (line 117) | def __setattr__(self, name, value): method __getattr__ (line 126) | def __getattr__(self, name): method step (line 135) | def step(self, params, grads, states, itr=None): function _is_non_empty_two_level_mapping (line 154) | def _is_non_empty_two_level_mapping(obj): class Schedule (line 164) | class Schedule(object): class CosineDecay (line 168) | class CosineDecay(Schedule): method __init__ (line 171) | def __init__(self, min_val, max_val, num_steps): method __call__ (line 176) | def __call__(self, itr): class WarmupCosineDecay (line 181) | class WarmupCosineDecay(Schedule): method __init__ (line 184) | def __init__(self, start_val, min_val, max_val, num_steps, warmup_steps): method __call__ (line 191) | def __call__(self, itr): class WarmupExpDecay (line 203) | class WarmupExpDecay(Schedule): method __init__ (line 206) | def __init__(self, start_val, max_val, warmup_steps, method __call__ (line 214) | def __call__(self, itr): class SGD (line 226) | class SGD(Optimizer): method __init__ (line 242) | def __init__(self, params, lr, weight_decay=None, method create_buffers (line 249) | def create_buffers(self, name, param): method update_param (line 256) | def update_param(self, param, grad, state, opt_params): class Adam (line 275) | class Adam(Optimizer): method __init__ (line 297) | def __init__(self, params, lr, beta1=0.9, beta2=0.999, method create_buffers (line 305) | def create_buffers(self, name, param): method update_param (line 314) | def update_param(self, param, grad, state, opt_params): class RMSProp (line 346) | class RMSProp(Optimizer): method __init__ (line 368) | def __init__(self, params, lr, decay, momentum, weight_decay=None, eps... method create_buffers (line 374) | def create_buffers(self, name, param): method update_param (line 382) | def update_param(self, param, grad, state, opt_params): class Fromage (line 405) | class Fromage(Optimizer): method __init__ (line 419) | def __init__(self, params, lr, weight_decay=None, eps=1e-5): method create_buffers (line 423) | def create_buffers(self, name, param): # pylint: disable=unused-argument method update_param (line 427) | def update_param(self, param, grad, state, opt_params): function compute_norm (line 440) | def compute_norm(x, axis, keepdims): function unitwise_norm (line 446) | def unitwise_norm(x): class SGD_AGC (line 462) | class SGD_AGC(Optimizer): # pylint:disable=invalid-name method __init__ (line 472) | def __init__(self, params, lr, weight_decay=None, method create_buffers (line 481) | def create_buffers(self, name, param): method update_param (line 484) | def update_param(self, param, grad, state, opt_params): class Hybrid (line 502) | class Hybrid(Optimizer): method __init__ (line 515) | def __init__(self, param_groups): method create_buffers (line 521) | def create_buffers(self, name, param): method update_param (line 524) | def update_param(self, param, grad, state, opt_params): FILE: nfnets/resnet.py class ResNet (line 23) | class ResNet(hk.Module): method __init__ (line 34) | def __init__(self, width, num_classes, method __call__ (line 83) | def __call__(self, x, is_training, test_local_stats=False, class ResBlockV2 (line 114) | class ResBlockV2(hk.Module): method __init__ (line 117) | def __init__(self, out_ch, stride=1, use_projection=False, method __call__ (line 150) | def __call__(self, x, is_training, test_local_stats): class ResBlockV1 (line 167) | class ResBlockV1(ResBlockV2): method __call__ (line 170) | def __call__(self, x, is_training, test_local_stats): FILE: nfnets/skipinit_resnet.py class SkipInit_ResNet (line 31) | class SkipInit_ResNet(hk.Module): method __init__ (line 42) | def __init__(self, num_classes, variant='ResNet50', width=4, method __call__ (line 87) | def __call__(self, x, is_training=True, return_metrics=False): method count_flops (line 111) | def count_flops(self, h, w): class NFResBlock (line 125) | class NFResBlock(hk.Module): method __init__ (line 128) | def __init__(self, in_ch, out_ch, bottleneck_ratio=0.25, method __call__ (line 160) | def __call__(self, x, is_training): method count_flops (line 177) | def count_flops(self, h, w): FILE: nfnets/test.py function test_experiment (line 22) | def test_experiment(): function test_nfnet_experiment (line 42) | def test_nfnet_experiment(): FILE: nfnets/utils.py function reduce_fn (line 22) | def reduce_fn(x, mode): function softmax_cross_entropy (line 34) | def softmax_cross_entropy(logits, labels, reduction='sum'): function topk_correct (line 53) | def topk_correct(logits, labels, mask=None, prefix='', topk=(1, 5)): function any_in (line 68) | def any_in(prediction, target): function tf1_ema (line 73) | def tf1_ema(ema_value, current_value, decay, step): function ema (line 79) | def ema(ema_value, current_value, decay, step): function _replicate (line 89) | def _replicate(x, devices=None): function broadcast (line 97) | def broadcast(obj): function split_tree (line 105) | def split_tree(tuple_tree, base_tree, n): function load_haiku_file (line 111) | def load_haiku_file(filename): function flatten_haiku_tree (line 118) | def flatten_haiku_tree(haiku_dict): FILE: object_attention_for_reasoning/model.py function append_ids (line 41) | def append_ids(tensor, id_vector, axis): class ClevrerTransformerModel (line 51) | class ClevrerTransformerModel(object): method __init__ (line 54) | def __init__(self, use_relative_positions, shuffle_objects, method _apply_transformers (line 79) | def _apply_transformers(self, lang_embedding, vision_embedding): method apply_model_descriptive (line 104) | def apply_model_descriptive(self, inputs): method apply_model_mc (line 140) | def apply_model_mc(self, inputs): FILE: object_attention_for_reasoning/run_model.py function load_monet_latents (line 36) | def load_monet_latents(base_dir, scene_index): function _split_string (line 42) | def _split_string(s): function _pad (line 47) | def _pad(array, length): function encode_sentence (line 52) | def encode_sentence(token_map, sentence, pad_length): function encode_choices (line 60) | def encode_choices(token_map, choices): function main (line 68) | def main(unused_argv): FILE: object_attention_for_reasoning/transformer.py function rel_shift (line 57) | def rel_shift(position_logits): function _layer_norm (line 105) | def _layer_norm(inputs): function _concat_and_slice (line 112) | def _concat_and_slice(prev_memory, new_memory): function simple_attention (line 119) | def simple_attention(queries, keys, values): class ResidualDropoutWrapper (line 125) | class ResidualDropoutWrapper(base.AbstractModule): method __init__ (line 131) | def __init__(self, method _build (line 141) | def _build(self, inputs, *args, **kwargs): function future_mask (line 162) | def future_mask(chunk_size, dtype): function _memory_size (line 172) | def _memory_size(state): function create_mask (line 180) | def create_mask(inputs, state, equal_window): function default_mlp (line 216) | def default_mlp(hidden_sizes, activate_final=False, init_std=2., **kwargs): function get_position_encodings (line 228) | def get_position_encodings(sequence_length, class MultiheadAttention (line 254) | class MultiheadAttention(base.AbstractModule): method __init__ (line 257) | def __init__(self, method multihead_linear (line 306) | def multihead_linear(self, inputs, name): method _build (line 318) | def _build(self, class TransformerTower (line 463) | class TransformerTower(base.AbstractModule): method __init__ (line 470) | def __init__(self, method get_sublayers (line 537) | def get_sublayers(self, is_training): method _build (line 559) | def _build(self, method attention_module (line 665) | def attention_module(self, i): FILE: ogb_lsc/mag/batching_utils.py function dynamically_batch (line 26) | def dynamically_batch(graphs_tuple_iterator: Iterator[jraph.GraphsTuple], function _batch_np (line 103) | def _batch_np(graphs: Sequence[jraph.GraphsTuple]) -> jraph.GraphsTuple: function _get_graph_size (line 123) | def _get_graph_size(graph: jraph.GraphsTuple) -> Tuple[int, int, int]: function _is_over_batch_size (line 130) | def _is_over_batch_size( FILE: ogb_lsc/mag/config.py function get_config (line 21) | def get_config(debug: bool = False) -> config_dict.ConfigDict: FILE: ogb_lsc/mag/csr_builder.py function _read_edge_data (line 54) | def _read_edge_data(path): function _build_coo (line 63) | def _build_coo(edges_data, use_boolean=False): function _get_output_paths (line 74) | def _get_output_paths(directory, content_names, use_boolean): function _write_csr (line 86) | def _write_csr(path, csr): function main (line 92) | def main(argv): FILE: ogb_lsc/mag/data_utils.py function get_raw_directory (line 96) | def get_raw_directory(data_root): function get_preprocessed_directory (line 100) | def get_preprocessed_directory(data_root): function _log_path_decorator (line 104) | def _log_path_decorator(fn): function load_csr (line 114) | def load_csr(path, debug=False): function load_npy (line 122) | def load_npy(path): function get_arrays (line 127) | def get_arrays(data_root="/data/", function add_nodes_year (line 215) | def add_nodes_year(graph, paper_year): function add_nodes_label (line 224) | def add_nodes_label(graph, paper_label): function add_nodes_embedding_from_array (line 233) | def add_nodes_embedding_from_array(graph, array): function get_graph_subsampling_dataset (line 246) | def get_graph_subsampling_dataset( function paper_features_to_author_features (line 292) | def paper_features_to_author_features( function author_features_to_institution_features (line 308) | def author_features_to_institution_features( function generate_fused_paper_adjacency_matrix (line 324) | def generate_fused_paper_adjacency_matrix(neighbor_indices, neighbor_dis... function generate_k_fold_splits (line 387) | def generate_k_fold_splits( function get_train_and_valid_idx_for_split (line 410) | def get_train_and_valid_idx_for_split( function generate_fused_node_labels (line 421) | def generate_fused_node_labels(neighbor_indices, neighbor_distances, function _pad_to_shape (line 446) | def _pad_to_shape( function _fix_adjacency_shapes (line 476) | def _fix_adjacency_shapes( FILE: ogb_lsc/mag/datasets.py class Batch (line 44) | class Batch(NamedTuple): function build_dataset_iterator (line 54) | def build_dataset_iterator( function _get_bitstring_year_representation (line 244) | def _get_bitstring_year_representation(year: np.ndarray): function _np_one_hot (line 252) | def _np_one_hot(targets: np.ndarray, nb_classes: int): function _get_one_hot_year_representation (line 258) | def _get_one_hot_year_representation( function _add_one_hot_features_to_batch (line 278) | def _add_one_hot_features_to_batch(batch: Batch) -> Batch: function _add_embeddings_to_batch (line 293) | def _add_embeddings_to_batch(batch: Batch, embeddings: np.ndarray) -> Ba... FILE: ogb_lsc/mag/download_mag.py class DataCorruptionError (line 56) | class DataCorruptionError(Exception): function _get_gcs_root (line 60) | def _get_gcs_root(): function _get_gcs_bucket (line 64) | def _get_gcs_bucket(): function _write_blob_to_destination (line 69) | def _write_blob_to_destination(blob, task_root, ignore_existing=True): function main (line 90) | def main(unused_argv): FILE: ogb_lsc/mag/ensemble_predictions.py function _np_one_hot (line 45) | def _np_one_hot(targets: np.ndarray, nb_classes: int): function ensemble_predictions (line 51) | def ensemble_predictions( function load_predictions (line 116) | def load_predictions(predictions_path, split): function generate_ensembled_predictions (line 153) | def generate_ensembled_predictions( function evaluate_validation (line 181) | def evaluate_validation(valid_predictions): function save_test_submission_file (line 193) | def save_test_submission_file(test_predictions, output_dir): function main (line 200) | def main(argv): FILE: ogb_lsc/mag/experiment.py class Experiment (line 138) | class Experiment(experiment.AbstractExperiment): method __init__ (line 149) | def __init__( method _train_init (line 181) | def _train_init(self): method _eval_init (line 195) | def _eval_init(self): method step (line 213) | def step( method _build_numpy_dataset_iterator (line 244) | def _build_numpy_dataset_iterator(self, split: str, is_training: bool): method _initialize_experiment_state (line 256) | def _initialize_experiment_state( method _get_learning_rate (line 286) | def _get_learning_rate(self, global_step: jnp.ndarray) -> jnp.ndarray: method _optimizer (line 292) | def _optimizer( method _forward_fn (line 302) | def _forward_fn( method _bgrl_loss (line 314) | def _bgrl_loss( method _loss (line 414) | def _loss( method _update_func (line 434) | def _update_func( method evaluate (line 506) | def evaluate(self, global_step, rng, **unused_kwargs): method _evaluate_with_ensemble (line 526) | def _evaluate_with_ensemble( method _maybe_save_predictions (line 546) | def _maybe_save_predictions(self, predictions, global_step): method _evaluate_params (line 559) | def _evaluate_params( method _log_results (line 607) | def _log_results(self, prefix, results): function _restore_state_to_in_memory_checkpointer (line 614) | def _restore_state_to_in_memory_checkpointer(restore_path): function _get_step_date_label (line 641) | def _get_step_date_label(global_step): function _save_state_from_in_memory_checkpointer (line 647) | def _save_state_from_in_memory_checkpointer( function _setup_signals (line 673) | def _setup_signals(save_model_fn): function main (line 700) | def main(argv, experiment_class: experiment.AbstractExperiment): FILE: ogb_lsc/mag/generate_validation_splits.py function main (line 35) | def main(argv): FILE: ogb_lsc/mag/losses.py class Predictions (line 29) | class Predictions(NamedTuple): function node_classification_loss (line 36) | def node_classification_loss( function get_predictions_labels_and_logits (line 68) | def get_predictions_labels_and_logits( function topk_correct (line 81) | def topk_correct( function ensemble_predictions_by_probability_average (line 94) | def ensemble_predictions_by_probability_average( function get_accuracy_dict (line 108) | def get_accuracy_dict(predictions: Predictions) -> Dict[str, float]: function bgrl_loss (line 123) | def bgrl_loss( function get_corrupted_view (line 150) | def get_corrupted_view( function _assert_consistent_predictions (line 184) | def _assert_consistent_predictions(predictions_list: Sequence[Prediction... function _l2_normalize (line 193) | def _l2_normalize( FILE: ogb_lsc/mag/models.py class ModelOutput (line 35) | class ModelOutput(NamedTuple): function build_update_fn (line 42) | def build_update_fn( function build_gn (line 82) | def build_gn( function _get_activation_fn (line 127) | def _get_activation_fn(name: str) -> Callable[[jnp.ndarray], jnp.ndarray]: class NodePropertyEncodeProcessDecode (line 137) | class NodePropertyEncodeProcessDecode(hk.Module): method __init__ (line 140) | def __init__( method _dropout_graph (line 168) | def _dropout_graph(self, graph: jraph.GraphsTuple) -> jraph.GraphsTuple: method _encode (line 176) | def _encode( method _process (line 201) | def _process( method _node_mlp (line 224) | def _node_mlp( method __call__ (line 241) | def __call__( FILE: ogb_lsc/mag/neighbor_builder.py function _read_paper_pca_features (line 38) | def _read_paper_pca_features(): function _read_adjacency_indices (line 45) | def _read_adjacency_indices(): function build_annoy_index (line 55) | def build_annoy_index(features): function _get_annoy_index_path (line 70) | def _get_annoy_index_path(): function save_annoy_index (line 74) | def save_annoy_index(annoy_index): function read_annoy_index (line 81) | def read_annoy_index(features): function compute_neighbor_indices_and_distances (line 89) | def compute_neighbor_indices_and_distances(features): function _write_neighbors (line 108) | def _write_neighbors(neighbor_indices, neighbor_distances): function _write_fused_edges (line 122) | def _write_fused_edges(fused_paper_adjacency_matrix): function _write_fused_nodes (line 135) | def _write_fused_nodes(fused_node_labels): function main (line 143) | def main(unused_argv): FILE: ogb_lsc/mag/pca_builder.py function _sample_vectors (line 40) | def _sample_vectors(vectors, num_samples, seed=0): function _pca (line 47) | def _pca(feat): function _read_raw_paper_features (line 54) | def _read_raw_paper_features(): function _get_principal_components (line 65) | def _get_principal_components(features, function _project_features_onto_principal_components (line 78) | def _project_features_onto_principal_components(features, function _read_adjacency_indices (line 103) | def _read_adjacency_indices(): function _compute_author_pca_features (line 113) | def _compute_author_pca_features(paper_pca_features, index_arrays): function _compute_institution_pca_features (line 118) | def _compute_institution_pca_features(author_pca_features, index_arrays): function _write_array (line 123) | def _write_array(path, array): function main (line 129) | def main(unused_argv): FILE: ogb_lsc/mag/schedules.py function apply_ema_decay (line 20) | def apply_ema_decay( function ema_decay_schedule (line 29) | def ema_decay_schedule( function _cosine_decay (line 42) | def _cosine_decay( function learning_schedule (line 54) | def learning_schedule( FILE: ogb_lsc/mag/split_and_save_indices.py function main (line 37) | def main(argv) -> None: FILE: ogb_lsc/mag/sub_sampler.py function get_or_sample_row (line 23) | def get_or_sample_row(node_id: int, function get_neighbours (line 50) | def get_neighbours(node_id: int, function get_senders (line 76) | def get_senders(neighbour_type: int, function make_edge_type_feature (line 90) | def make_edge_type_feature(node_type: int, neighbour_type: int): function subsample_graph (line 97) | def subsample_graph(paper_id: int, FILE: ogb_lsc/pcq/batching_utils.py function dynamically_batch (line 26) | def dynamically_batch(graphs_tuple_iterator: Iterator[jraph.GraphsTuple], function _batch_np (line 103) | def _batch_np(graphs: Sequence[jraph.GraphsTuple]) -> jraph.GraphsTuple: function _get_graph_size (line 123) | def _get_graph_size(graph: jraph.GraphsTuple) -> Tuple[int, int, int]: function _is_over_batch_size (line 130) | def _is_over_batch_size( FILE: ogb_lsc/pcq/config.py function get_config (line 21) | def get_config(debug: bool = False) -> config_dict.ConfigDict: FILE: ogb_lsc/pcq/conformer_utils.py function generate_conformers (line 27) | def generate_conformers( function atom_to_feature_vector (line 83) | def atom_to_feature_vector( function compute_conformer (line 102) | def compute_conformer(smile: str, max_iter: int = -1) -> np.ndarray: function get_random_rotation_matrix (line 145) | def get_random_rotation_matrix(include_mirror_symmetry: bool) -> tf.Tensor: function rotate (line 155) | def rotate(vectors: tf.Tensor, rotation_matrix: tf.Tensor) -> tf.Tensor: function _embed_conformers (line 160) | def _embed_conformers( function _minimize_by_mmff (line 223) | def _minimize_by_mmff( function _minimize_by_uff (line 251) | def _minimize_by_uff( function _get_symmetry_rotation_matrix (line 274) | def _get_symmetry_rotation_matrix(sign: tf.Tensor) -> tf.Tensor: function _quaternion_to_rotation_matrix (line 289) | def _quaternion_to_rotation_matrix(quaternion: tf.Tensor) -> tf.Tensor: function _get_random_rotation_3d (line 312) | def _get_random_rotation_3d() -> tf.Tensor: function _get_random_mirror_symmetry (line 320) | def _get_random_mirror_symmetry() -> tf.Tensor: FILE: ogb_lsc/pcq/dataset_utils.py function build_dataset_iterator (line 40) | def build_dataset_iterator( function _get_conformer_filter (line 161) | def _get_conformer_filter(with_nans: bool): function _numpy_to_tensor_spec (line 189) | def _numpy_to_tensor_spec(arr: np.ndarray) -> tf.TensorSpec: function _sample_uniform_categorical (line 199) | def _sample_uniform_categorical(num: int, size: int) -> tf.Tensor: function _downcast_ints (line 204) | def _downcast_ints(x): function _one_hot_atoms (line 210) | def _one_hot_atoms(atoms: tf.Tensor) -> tf.Tensor: function _sample_one_hot_atoms (line 218) | def _sample_one_hot_atoms(atoms: tf.Tensor) -> tf.Tensor: function _one_hot_bonds (line 229) | def _one_hot_bonds(bonds: tf.Tensor) -> tf.Tensor: function _sample_one_hot_bonds (line 237) | def _sample_one_hot_bonds(bonds: tf.Tensor) -> tf.Tensor: function _maybe_one_hot_atoms_with_noise (line 248) | def _maybe_one_hot_atoms_with_noise( function _load_smiles (line 283) | def _load_smiles( function _convert_ogb_graph_to_graphs_tuple (line 306) | def _convert_ogb_graph_to_graphs_tuple(ogb_graph): function _load_conformers (line 326) | def _load_conformers(indices: List[int], function _add_conformer_features (line 341) | def _add_conformer_features( function _get_pcq_graph_generator (line 384) | def _get_pcq_graph_generator(indices, smiles, labels, conformers): FILE: ogb_lsc/pcq/datasets.py function load_splits (line 31) | def load_splits() -> Dict[str, List[int]]: function load_kth_fold_indices (line 37) | def load_kth_fold_indices(data_root: str, k_fold_split_id: int) -> List[... function load_all_except_kth_fold_indices (line 43) | def load_all_except_kth_fold_indices(data_root: str, k_fold_split_id: int, function load_smile_strings (line 55) | def load_smile_strings( function load_cached_conformers (line 71) | def load_cached_conformers(cached_fname: str) -> Dict[str, np.ndarray]: function _get_pcq_dataset (line 77) | def _get_pcq_dataset(only_smiles: bool): function _load_pickle (line 81) | def _load_pickle(fname: str): FILE: ogb_lsc/pcq/download_pcq.py class DataCorruptionError (line 45) | class DataCorruptionError(Exception): function _get_gcs_root (line 49) | def _get_gcs_root(): function _get_gcs_bucket (line 53) | def _get_gcs_bucket(): function _write_blob_to_destination (line 58) | def _write_blob_to_destination(blob, task_root, ignore_existing=True): function main (line 79) | def main(unused_argv): FILE: ogb_lsc/pcq/ensemble_predictions.py class _Predictions (line 57) | class _Predictions(NamedTuple): function _load_dill (line 62) | def _load_dill(fname) -> bytes: function _sort_by_indices (line 67) | def _sort_by_indices(predictions: _Predictions) -> _Predictions: function load_predictions (line 74) | def load_predictions(path: str, split: str) -> _Predictions: function mean_mae_distance (line 84) | def mean_mae_distance(x, y): function _load_valid_labels (line 88) | def _load_valid_labels() -> np.ndarray: function evaluate_valid_predictions (line 93) | def evaluate_valid_predictions(ensembled_predictions: _Predictions): function clip_predictions (line 104) | def clip_predictions(predictions: _Predictions) -> _Predictions: function _generate_test_prediction_file (line 109) | def _generate_test_prediction_file(test_predictions: np.ndarray, function merge_complementary_results (line 133) | def merge_complementary_results(split: str, results_a: _Predictions, function ensemble_valid_predictions (line 153) | def ensemble_valid_predictions( function ensemble_test_predictions (line 177) | def ensemble_test_predictions( function create_submission_from_predictions (line 189) | def create_submission_from_predictions( function merge_predictions (line 201) | def merge_predictions(split: str) -> List[_Predictions]: function main (line 223) | def main(_): FILE: ogb_lsc/pcq/experiment.py function _get_step_date_label (line 51) | def _get_step_date_label(global_step: int): class _Predictions (line 57) | class _Predictions(NamedTuple): function tf1_ema (line 62) | def tf1_ema(ema_value, current_value, decay, step): function _sort_predictions_by_indices (line 68) | def _sort_predictions_by_indices(predictions: _Predictions): class Experiment (line 75) | class Experiment(experiment.AbstractExperiment): method __init__ (line 86) | def __init__(self, mode, init_rng, config): method step (line 117) | def step(self, global_step: jnp.ndarray, rng: jnp.ndarray, **unused_ar... method _construct_loss_config (line 136) | def _construct_loss_config(self): method _train_init (line 146) | def _train_init(self): method _loss (line 169) | def _loss( method _maybe_save_predictions (line 178) | def _maybe_save_predictions( method _build_numpy_dataset_iterator (line 195) | def _build_numpy_dataset_iterator(self, split: str, is_training: bool): method _update_parameters (line 208) | def _update_parameters( method evaluate (line 244) | def evaluate(self, global_step: jnp.ndarray, rng: jnp.ndarray, method _sum_regression_scalars (line 267) | def _sum_regression_scalars(self, preds: jnp.ndarray, method _get_prediction (line 280) | def _get_prediction( method _get_predictions (line 292) | def _get_predictions( method _eval_init (line 332) | def _eval_init(self): method _forward (line 336) | def _forward(self, **graph: Mapping[str, chex.ArrayTree]) -> chex.Arra... function _restore_state_to_in_memory_checkpointer (line 344) | def _restore_state_to_in_memory_checkpointer(restore_path): function _save_state_from_in_memory_checkpointer (line 371) | def _save_state_from_in_memory_checkpointer( function _setup_signals (line 397) | def _setup_signals(save_model_fn): function main (line 424) | def main(argv, experiment_class: experiment.AbstractExperiment): FILE: ogb_lsc/pcq/generate_conformer_features.py function generate_conformer_features (line 43) | def generate_conformer_features(smiles: List[str]) -> List[np.ndarray]: function main (line 52) | def main(_): FILE: ogb_lsc/pcq/generate_validation_splits.py function main (line 37) | def main(argv): FILE: ogb_lsc/pcq/model.py class RegressionLossConfig (line 39) | class RegressionLossConfig: function _sigmoid_cross_entropy (line 48) | def _sigmoid_cross_entropy( function _softmax_cross_entropy (line 57) | def _softmax_cross_entropy( function _regression_loss (line 65) | def _regression_loss( function _build_mlp (line 80) | def _build_mlp( function _compute_relative_displacement_and_distance (line 99) | def _compute_relative_displacement_and_distance( function _broadcast_global_to_nodes (line 123) | def _broadcast_global_to_nodes( function _broadcast_global_to_edges (line 134) | def _broadcast_global_to_edges( class GraphPropertyEncodeProcessDecode (line 145) | class GraphPropertyEncodeProcessDecode(hk.Module): method __init__ (line 148) | def __init__( method __call__ (line 206) | def __call__(self, graph: jraph.GraphsTuple) -> chex.ArrayTree: method get_loss (line 215) | def get_loss( method _prepare_features (line 268) | def _prepare_features(self, graph: jraph.GraphsTuple) -> jraph.GraphsT... method _encoder (line 308) | def _encoder( method _processor (line 335) | def _processor( method _decoder (line 412) | def _decoder( method _forward (line 441) | def _forward(self, graph: jraph.GraphsTuple, is_training: bool): method _get_node_auxiliary_loss (line 451) | def _get_node_auxiliary_loss( method _get_edge_auxiliary_loss (line 462) | def _get_edge_auxiliary_loss( method _get_loss (line 473) | def _get_loss(self, pred, targets, is_regression): function get_utilization_scalars (line 482) | def get_utilization_scalars( function sum_with_mask (line 495) | def sum_with_mask(array: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray: function _mean_with_mask (line 499) | def _mean_with_mask(array: jnp.ndarray, mask: jnp.ndarray) -> jnp.ndarray: function _mask_out_padding_graph (line 504) | def _mask_out_padding_graph( function _aggregate_nodes_to_globals (line 519) | def _aggregate_nodes_to_globals(graph, node_features): FILE: option_keyboard/auto_reset_environment.py class Base (line 30) | class Base(dm_env.Environment): method __init__ (line 37) | def __init__(self): method _reset (line 41) | def _reset(self): method _step (line 45) | def _step(self, action): method reset (line 48) | def reset(self): method step (line 52) | def step(self, action): FILE: option_keyboard/configs.py function get_task_config (line 19) | def get_task_config(): function get_pretrain_config (line 31) | def get_pretrain_config(): function get_fig4_task_config (line 43) | def get_fig4_task_config(): function get_fig5_task_config (line 55) | def get_fig5_task_config(default_w): FILE: option_keyboard/dqn_agent.py class Agent (line 23) | class Agent(): method __init__ (line 26) | def __init__( method _extract_observation (line 96) | def _extract_observation(self, obs): method step (line 99) | def step(self, timestep, is_training=False): method update (line 109) | def update(self, step_tm1, action, step_t): class ValueNet (line 127) | class ValueNet(snt.AbstractModule): method __init__ (line 130) | def __init__(self, method _build (line 145) | def _build(self, observation): method num_actions (line 153) | def num_actions(self): function _batched_index (line 157) | def _batched_index(values, indices): FILE: option_keyboard/environment_wrappers.py class EnvironmentWithLogging (line 32) | class EnvironmentWithLogging(dm_env.Environment): method __init__ (line 35) | def __init__(self, env): method reset (line 39) | def reset(self): method step (line 43) | def step(self, action): method episode_return (line 55) | def episode_return(self): method action_spec (line 58) | def action_spec(self): method observation_spec (line 61) | def observation_spec(self): method __getattr__ (line 64) | def __getattr__(self, name): class EnvironmentWithKeyboard (line 68) | class EnvironmentWithKeyboard(dm_env.Environment): method __init__ (line 71) | def __init__(self, method _compute_reward (line 103) | def _compute_reward(self, option, obs): method reset (line 106) | def reset(self): method step (line 109) | def step(self, option): method _should_terminate (line 145) | def _should_terminate(self, option, obs): method action_spec (line 155) | def action_spec(self): method _extract_observation (line 159) | def _extract_observation(self, obs): method observation_spec (line 162) | def observation_spec(self): method __getattr__ (line 165) | def __getattr__(self, name): class EnvironmentWithKeyboardDirect (line 169) | class EnvironmentWithKeyboardDirect(dm_env.Environment): method __init__ (line 178) | def __init__(self, method _compute_reward (line 205) | def _compute_reward(self, option, obs): method reset (line 209) | def reset(self): method step (line 212) | def step(self, option): method _should_terminate (line 248) | def _should_terminate(self, option, obs): method action_spec (line 258) | def action_spec(self): method _extract_observation (line 265) | def _extract_observation(self, obs): method observation_spec (line 268) | def observation_spec(self): method __getattr__ (line 271) | def __getattr__(self, name): function _discretize_actions (line 275) | def _discretize_actions(num_actions_per_dim, class EnvironmentWithLearnedPhi (line 310) | class EnvironmentWithLearnedPhi(dm_env.Environment): method __init__ (line 313) | def __init__(self, env, model_path): method reset (line 337) | def reset(self): method step (line 341) | def step(self, action): method action_spec (line 356) | def action_spec(self): method observation (line 359) | def observation(self): method observation_spec (line 364) | def observation_spec(self): method __getattr__ (line 374) | def __getattr__(self, name): FILE: option_keyboard/experiment.py function _ema (line 24) | def _ema(base, val, decay=0.995): function run (line 28) | def run(env, agent, num_episodes, report_every=200, num_eval_reps=1): function run_episode (line 74) | def run_episode(environment, agent, is_training=False): function write_returns_to_file (line 93) | def write_returns_to_file(path, returns): FILE: option_keyboard/gpe_gpi_experiments/eval_keyboard_fig5.py function evaluate_keyboard (line 83) | def evaluate_keyboard(keyboard_path, weights_to_sweep): function main (line 124) | def main(argv): FILE: option_keyboard/gpe_gpi_experiments/regressed_agent.py class Agent (line 22) | class Agent(): method __init__ (line 25) | def __init__( method step (line 68) | def step(self, timestep, is_training=False): method update (line 78) | def update(self, step_tm1, action, step_t): method get_logs (line 93) | def get_logs(self): FILE: option_keyboard/gpe_gpi_experiments/run_dqn_fig4b.py function main (line 39) | def main(argv): FILE: option_keyboard/gpe_gpi_experiments/run_dqn_fig5.py function main (line 40) | def main(argv): FILE: option_keyboard/gpe_gpi_experiments/run_regressed_w_fig4b.py function main (line 56) | def main(argv): FILE: option_keyboard/gpe_gpi_experiments/run_regressed_w_fig4c.py function main (line 56) | def main(argv): FILE: option_keyboard/gpe_gpi_experiments/run_regressed_w_with_phi_fig4c.py function main (line 66) | def main(argv): FILE: option_keyboard/gpe_gpi_experiments/run_true_w_fig4.py function main (line 56) | def main(argv): FILE: option_keyboard/gpe_gpi_experiments/run_true_w_fig6.py function main (line 57) | def main(argv): FILE: option_keyboard/gpe_gpi_experiments/train_keyboard.py function main (line 35) | def main(argv): FILE: option_keyboard/gpe_gpi_experiments/train_keyboard_with_phi.py function main (line 36) | def main(argv): FILE: option_keyboard/gpe_gpi_experiments/train_phi_model.py function collect_experience (line 51) | def collect_experience(env, num_episodes, verbose=False): class PhiModel (line 96) | class PhiModel(snt.AbstractModule): method __init__ (line 99) | def __init__(self, method _build (line 111) | def _build(self, observation, actions): function create_ph (line 128) | def create_ph(tensor): function main (line 132) | def main(argv): FILE: option_keyboard/keyboard_agent.py class Agent (line 27) | class Agent(): method __init__ (line 30) | def __init__( method keyboard (line 134) | def keyboard(self): method _extract_observation (line 137) | def _extract_observation(self, obs): method step (line 140) | def step(self, timestep, is_training=False): method update (line 152) | def update(self, step_tm1, action, step_t): method export (line 169) | def export(self, path): class OptionValueNet (line 176) | class OptionValueNet(snt.AbstractModule): method __init__ (line 179) | def __init__(self, method _build (line 200) | def _build(self, observation): method gpi (line 212) | def gpi(self, observation, cumulant_weights): method num_cumulants (line 222) | def num_cumulants(self): method num_policies (line 226) | def num_policies(self): method num_actions (line 230) | def num_actions(self): function _batched_index (line 234) | def _batched_index(values, indices): FILE: option_keyboard/keyboard_utils.py function create_and_train_keyboard (line 27) | def create_and_train_keyboard(num_episodes, function create_and_train_keyboard_with_phi (line 59) | def create_and_train_keyboard_with_phi(num_episodes, FILE: option_keyboard/run_dqn.py function main (line 36) | def main(argv): FILE: option_keyboard/run_dqn_test.py class RunDQNTest (line 28) | class RunDQNTest(absltest.TestCase): method test_run (line 30) | def test_run(self): FILE: option_keyboard/run_ok.py function main (line 44) | def main(argv): FILE: option_keyboard/run_ok_test.py class RunDQNTest (line 28) | class RunDQNTest(absltest.TestCase): method test_run (line 30) | def test_run(self): FILE: option_keyboard/scavenger.py class Action (line 31) | class Action(enum.IntEnum): function _one_hot (line 39) | def _one_hot(indices, depth): function _random_pos (line 43) | def _random_pos(arena_size): class Scavenger (line 47) | class Scavenger(auto_reset_environment.Base): method __init__ (line 50) | def __init__(self, method state (line 84) | def state(self): method set_state (line 93) | def set_state(self, state): method player_pos (line 102) | def player_pos(self): method _reset (line 105) | def _reset(self): method _step (line 139) | def _step(self, action): method observation (line 183) | def observation(self, force_non_egocentric=False): method observation_spec (line 217) | def observation_spec(self): method action_spec (line 245) | def action_spec(self): class SequentialCollectionRewarder (line 249) | class SequentialCollectionRewarder(object): method get_reward (line 252) | def get_reward(self, state, consumed): class BalancedCollectionRewarder (line 268) | class BalancedCollectionRewarder(object): method get_reward (line 271) | def get_reward(self, state, consumed): FILE: option_keyboard/smart_module.py function _getcallargs (line 31) | def _getcallargs(signature, *args, **kwargs): function _to_placeholder (line 39) | def _to_placeholder(arg): class SmartModuleExport (line 47) | class SmartModuleExport(object): method __init__ (line 50) | def __init__(self, object_factory): method _create_captured_method (line 57) | def _create_captured_method(self, method_name): method __getattr__ (line 79) | def __getattr__(self, name): method __call__ (line 94) | def __call__(self, *args, **kwargs): method export (line 97) | def export(self, path, session, overwrite=False): class SmartModuleImport (line 171) | class SmartModuleImport(object): method __init__ (line 174) | def __init__(self, module): method _create_wrapped_method (line 183) | def _create_wrapped_method(self, method): method __getattr__ (line 216) | def __getattr__(self, name): method __call__ (line 225) | def __call__(self, *args, **kwargs): FILE: perceiver/bytes_tokenizer.py class BytesTokenizer (line 20) | class BytesTokenizer: method __init__ (line 23) | def __init__(self): method to_string (line 26) | def to_string(self, inputs: np.ndarray) -> str: method to_int (line 32) | def to_int(self, inputs: Union[str, bytes]) -> np.ndarray: method vocab_size (line 40) | def vocab_size(self) -> int: method pad_token (line 44) | def pad_token(self) -> int: method bos_token (line 48) | def bos_token(self) -> int: method eos_token (line 52) | def eos_token(self) -> int: method mask_token (line 56) | def mask_token(self) -> int: method cls_token (line 60) | def cls_token(self) -> int: method sep_token (line 64) | def sep_token(self) -> int: FILE: perceiver/io_processors.py function reverse_space_to_depth (line 36) | def reverse_space_to_depth( function space_to_depth (line 55) | def space_to_depth( function extract_patches (line 74) | def extract_patches(images: jnp.ndarray, function patches_for_flow (line 124) | def patches_for_flow(inputs: jnp.ndarray) -> jnp.ndarray: class Conv2DDownsample (line 145) | class Conv2DDownsample(hk.Module): method __init__ (line 148) | def __init__( method __call__ (line 193) | def __call__(self, inputs: jnp.ndarray, *, class Conv2DUpsample (line 209) | class Conv2DUpsample(hk.Module): method __init__ (line 212) | def __init__( method __call__ (line 241) | def __call__(self, inputs: jnp.ndarray, *, class Conv3DUpsample (line 252) | class Conv3DUpsample(hk.Module): method __init__ (line 255) | def __init__(self, method __call__ (line 267) | def __call__(self, x: jnp.ndarray, *, is_training: bool) -> jnp.ndarray: class ImagePreprocessor (line 291) | class ImagePreprocessor(hk.Module): method __init__ (line 294) | def __init__( method _build_network_inputs (line 354) | def _build_network_inputs( method __call__ (line 388) | def __call__( class ImagePostprocessor (line 441) | class ImagePostprocessor(hk.Module): method __init__ (line 444) | def __init__( method __call__ (line 487) | def __call__( class OneHotPreprocessor (line 513) | class OneHotPreprocessor(hk.Module): method __init__ (line 516) | def __init__(self, name: Optional[str] = None): method __call__ (line 519) | def __call__(self, inputs: jnp.ndarray, *, class AudioPreprocessor (line 531) | class AudioPreprocessor(hk.Module): method __init__ (line 534) | def __init__( method _build_network_inputs (line 565) | def _build_network_inputs( method __call__ (line 588) | def __call__(self, inputs: jnp.ndarray, *, class AudioPostprocessor (line 600) | class AudioPostprocessor(hk.Module): method __init__ (line 603) | def __init__( method __call__ (line 617) | def __call__(self, inputs: jnp.ndarray, *, class IdentityPostprocessor (line 625) | class IdentityPostprocessor(hk.Module): method __init__ (line 628) | def __init__(self, name: Optional[str] = None): method __call__ (line 631) | def __call__(self, inputs: jnp.ndarray, *, function restructure (line 638) | def restructure(modality_sizes: ModalitySizeT, class MultimodalPreprocessor (line 659) | class MultimodalPreprocessor(hk.Module): method __init__ (line 666) | def __init__( method __call__ (line 688) | def __call__(self, inputs: jnp.ndarray, *, class MultimodalPostprocessor (line 736) | class MultimodalPostprocessor(hk.Module): method __init__ (line 739) | def __init__( method __call__ (line 757) | def __call__( class ClassificationPostprocessor (line 773) | class ClassificationPostprocessor(hk.Module): method __init__ (line 776) | def __init__( method __call__ (line 783) | def __call__(self, inputs: jnp.ndarray, *, class ProjectionPostprocessor (line 791) | class ProjectionPostprocessor(hk.Module): method __init__ (line 794) | def __init__( method __call__ (line 801) | def __call__(self, inputs: jnp.ndarray, *, class EmbeddingDecoder (line 809) | class EmbeddingDecoder(hk.Module): method __init__ (line 812) | def __init__(self, embedding_matrix: jnp.ndarray, name='embedding_deco... method __call__ (line 823) | def __call__(self, embeddings: jnp.ndarray) -> jnp.ndarray: FILE: perceiver/io_processors_test.py function _create_test_image (line 23) | def _create_test_image(shape): function test_space_to_depth_image (line 28) | def test_space_to_depth_image(): function test_space_to_depth_video (line 35) | def test_space_to_depth_video(): function test_reverse_space_to_depth_image (line 43) | def test_reverse_space_to_depth_image(): function test_reverse_space_to_depth_video (line 50) | def test_reverse_space_to_depth_video(): function test_extract_patches (line 58) | def test_extract_patches(): FILE: perceiver/perceiver.py function attend (line 33) | def attend(q, k, v, dropout_prob=0.0, attention_mask=None): function conv_1d (line 81) | def conv_1d( function layer_norm (line 94) | def layer_norm(x, name=None): function make_cross_attention_mask (line 99) | def make_cross_attention_mask(query_mask, kv_mask): class Attention (line 112) | class Attention(hk.Module): method __init__ (line 115) | def __init__(self, method __call__ (line 137) | def __call__(self, inputs_q, inputs_kv, attention_mask=None): class MLP (line 180) | class MLP(hk.Module): method __init__ (line 183) | def __init__(self, method __call__ (line 193) | def __call__(self, x, *, is_training): class SelfAttention (line 206) | class SelfAttention(hk.Module): method __init__ (line 209) | def __init__(self, method __call__ (line 229) | def __call__(self, class CrossAttention (line 257) | class CrossAttention(hk.Module): method __init__ (line 260) | def __init__(self, method __call__ (line 284) | def __call__(self, class Perceiver (line 340) | class Perceiver(hk.Module): method __init__ (line 343) | def __init__( method __call__ (line 358) | def __call__(self, inputs, *, is_training, subsampled_output_points=None, class PerceiverEncoder (line 392) | class PerceiverEncoder(hk.Module): method __init__ (line 395) | def __init__( method latents (line 457) | def latents(self, inputs): method __call__ (line 461) | def __call__(self, inputs, z, *, is_training, input_mask=None): class AbstractPerceiverDecoder (line 475) | class AbstractPerceiverDecoder(hk.Module, metaclass=abc.ABCMeta): method decoder_query (line 479) | def decoder_query(self, inputs, modality_sizes=None, inputs_without_po... method output_shape (line 484) | def output_shape(self, inputs): method __call__ (line 488) | def __call__(self, query, z, *, is_training, query_mask=None): class ProjectionDecoder (line 492) | class ProjectionDecoder(AbstractPerceiverDecoder): method __init__ (line 495) | def __init__( method decoder_query (line 506) | def decoder_query(self, inputs, modality_sizes=None, inputs_without_po... method output_shape (line 510) | def output_shape(self, inputs): method __call__ (line 513) | def __call__(self, query, z, *, is_training, query_mask=None): class BasicDecoder (line 521) | class BasicDecoder(AbstractPerceiverDecoder): method __init__ (line 524) | def __init__(self, method output_shape (line 566) | def output_shape(self, inputs): method decoder_query (line 570) | def decoder_query(self, inputs, modality_sizes=None, method __call__ (line 597) | def __call__(self, query, z, *, is_training, class ClassificationDecoder (line 627) | class ClassificationDecoder(AbstractPerceiverDecoder): method __init__ (line 633) | def __init__(self, method decoder_query (line 645) | def decoder_query(self, inputs, modality_sizes=None, method output_shape (line 651) | def output_shape(self, inputs): method __call__ (line 654) | def __call__(self, query, z, *, is_training, query_mask=None): class MultimodalDecoder (line 660) | class MultimodalDecoder(AbstractPerceiverDecoder): method __init__ (line 669) | def __init__(self, modalities, num_outputs, output_num_channels, method decoder_query (line 685) | def decoder_query(self, inputs, modality_sizes, inputs_without_pos=Non... method output_shape (line 723) | def output_shape(self, inputs): method __call__ (line 731) | def __call__(self, query, z, *, is_training, query_mask=None): class BasicVideoAutoencodingDecoder (line 736) | class BasicVideoAutoencodingDecoder(AbstractPerceiverDecoder): method __init__ (line 742) | def __init__(self, method decoder_query (line 759) | def decoder_query(self, inputs, modality_sizes=None, method output_shape (line 766) | def output_shape(self, inputs): method __call__ (line 770) | def __call__(self, query, z, *, is_training, query_mask=None): class FlowDecoder (line 777) | class FlowDecoder(AbstractPerceiverDecoder): method __init__ (line 780) | def __init__(self, method output_shape (line 795) | def output_shape(self, inputs): method decoder_query (line 801) | def decoder_query( method __call__ (line 809) | def __call__(self, query, z, *, is_training, query_mask=None): FILE: perceiver/position_encoding.py function generate_fourier_features (line 25) | def generate_fourier_features( function build_linear_positions (line 77) | def build_linear_positions(index_dims, output_range=(-1.0, 1.0)): class AbstractPositionEncoding (line 99) | class AbstractPositionEncoding(hk.Module, metaclass=abc.ABCMeta): method __call__ (line 103) | def __call__(self, batch_size, pos): class TrainablePositionEncoding (line 107) | class TrainablePositionEncoding(AbstractPositionEncoding): method __init__ (line 110) | def __init__(self, index_dim, num_channels=128, init_scale=0.02, name=... method __call__ (line 116) | def __call__(self, batch_size, pos=None): function _check_or_build_spatial_positions (line 128) | def _check_or_build_spatial_positions(pos, index_dims, batch_size): class FourierPositionEncoding (line 153) | class FourierPositionEncoding(AbstractPositionEncoding): method __init__ (line 156) | def __init__(self, index_dims, num_bands, concat_pos=True, method __call__ (line 166) | def __call__(self, batch_size, pos=None): class PositionEncodingProjector (line 177) | class PositionEncodingProjector(AbstractPositionEncoding): method __init__ (line 180) | def __init__(self, output_size, base_position_encoding, name=None): method __call__ (line 185) | def __call__(self, batch_size, pos=None): function build_position_encoding (line 191) | def build_position_encoding( FILE: perceiver/train/autoaugment.py function policy_v0 (line 38) | def policy_v0(): function policy_vtest (line 73) | def policy_vtest(): function blend (line 84) | def blend(image1, image2, factor): function cutout (line 127) | def cutout(image, pad_size, replace=0): function solarize (line 178) | def solarize(image, threshold=128): function solarize_add (line 185) | def solarize_add(image, addition=0, threshold=128): function color (line 195) | def color(image, factor): function contrast (line 201) | def contrast(image, factor): function brightness (line 218) | def brightness(image, factor): function posterize (line 224) | def posterize(image, bits): function rotate (line 230) | def rotate(image, degrees, replace): function translate_x (line 255) | def translate_x(image, pixels, replace): function translate_y (line 261) | def translate_y(image, pixels, replace): function shear_x (line 267) | def shear_x(image, level, replace): function shear_y (line 278) | def shear_y(image, level, replace): function autocontrast (line 289) | def autocontrast(image): function sharpness (line 328) | def sharpness(image, factor): function equalize (line 360) | def equalize(image): function invert (line 400) | def invert(image): function wrap (line 406) | def wrap(image): function unwrap (line 414) | def unwrap(image, replace): function _randomly_negate_tensor (line 472) | def _randomly_negate_tensor(tensor): function _rotate_level_to_arg (line 479) | def _rotate_level_to_arg(level): function _shrink_level_to_arg (line 485) | def _shrink_level_to_arg(level): function _enhance_level_to_arg (line 494) | def _enhance_level_to_arg(level): function _shear_level_to_arg (line 498) | def _shear_level_to_arg(level): function _translate_level_to_arg (line 505) | def _translate_level_to_arg(level, translate_const): function level_to_arg (line 512) | def level_to_arg(hparams): function _parse_policy_info (line 537) | def _parse_policy_info(name, prob, level, replace_value, augmentation_hp... function _apply_func_with_prob (line 560) | def _apply_func_with_prob(func, image, args, prob): function select_and_apply_random_policy (line 581) | def select_and_apply_random_policy(policies, image): function build_and_apply_nas_policy (line 594) | def build_and_apply_nas_policy(policies, image, function distort_image_with_autoaugment (line 643) | def distort_image_with_autoaugment(image, augmentation_name): function distort_image_with_randaugment (line 674) | def distort_image_with_randaugment(image, num_layers, magnitude): FILE: perceiver/train/dataset.py class Split (line 41) | class Split(enum.Enum): method from_string (line 49) | def from_string(cls, name: Text) -> 'Split': method num_examples (line 55) | def num_examples(self): function load (line 60) | def load( function cutmix_padding (line 155) | def cutmix_padding(h, w): function my_cutmix (line 194) | def my_cutmix(batch): function my_mixup (line 207) | def my_mixup(batch): function my_mixup_cutmix (line 220) | def my_mixup_cutmix(batch): function _to_tfds_split (line 243) | def _to_tfds_split(split: Split) -> tfds.Split: function _shard (line 256) | def _shard( function _preprocess_image (line 271) | def _preprocess_image( function _normalize_image (line 307) | def _normalize_image(image: tf.Tensor) -> tf.Tensor: function _distorted_bounding_box_crop (line 314) | def _distorted_bounding_box_crop( function _decode_whole_image (line 350) | def _decode_whole_image(image_bytes: tf.Tensor) -> Tuple[tf.Tensor, tf.T... function _decode_and_random_crop (line 356) | def _decode_and_random_crop( function _center_crop (line 382) | def _center_crop(image, crop_dim): function _decode_and_center_crop (line 392) | def _decode_and_center_crop( FILE: perceiver/train/experiment.py function get_training_steps (line 62) | def get_training_steps(batch_size, n_epochs): function get_config (line 66) | def get_config(): class Experiment (line 225) | class Experiment(experiment.AbstractExperiment): method __init__ (line 237) | def __init__(self, mode, init_rng, config): method _forward_fn (line 265) | def _forward_fn( method step (line 294) | def step(self, global_step: int, rng: jnp.ndarray, # pytype: disable=... method _initialize_train (line 311) | def _initialize_train(self): method _load_data (line 342) | def _load_data(self, split, is_training, batch_dims): method _build_train_input (line 353) | def _build_train_input(self) -> Generator[dataset.Batch, None, None]: method _one_hot (line 371) | def _one_hot(self, value): method _loss_fn (line 376) | def _loss_fn( method _update_func (line 422) | def _update_func( method evaluate (line 468) | def evaluate(self, global_step, rng, **unused_args): method _eval_batch (line 476) | def _eval_batch( method _build_eval_input (line 505) | def _build_eval_input(self) -> Generator[dataset.Batch, None, None]: method _eval_epoch (line 513) | def _eval_epoch(self, rng): FILE: perceiver/train/utils.py function any_in (line 38) | def any_in(prediction, target): function topk_correct (line 43) | def topk_correct(logits, labels, mask=None, prefix='', topk=(1, 5)): function softmax_cross_entropy (line 57) | def softmax_cross_entropy(logits, labels): function _get_batch_scaled_lr (line 70) | def _get_batch_scaled_lr(total_batch_size, lr, scale_by_batch=True): function get_learning_rate_schedule (line 80) | def get_learning_rate_schedule( function _weight_decay_exclude (line 139) | def _weight_decay_exclude( class AddWeightDecayState (line 167) | class AddWeightDecayState(NamedTuple): function add_weight_decay (line 171) | def add_weight_decay( function make_optimizer (line 202) | def make_optimizer(optimizer_config, lr_schedule): FILE: physics_inspired_models/eval_metric.py function quit_function (line 30) | def quit_function(fn_name): function exit_after (line 36) | def exit_after(s): function do_grid_search (line 55) | def do_grid_search(data_x_exp, data_y, clf, parameters, cv): function symplectic_matrix (line 63) | def symplectic_matrix(dim): function create_latent_mask (line 73) | def create_latent_mask(z0, dist_std_threshold=0.5): function standardize_data (line 118) | def standardize_data(data): function find_best_polynomial (line 125) | def find_best_polynomial(data_x, data_y, max_poly_order, rsq_threshold, function eval_monomial_grad (line 204) | def eval_monomial_grad(feature, x, w, grad_acc): function compute_jacobian_manual (line 229) | def compute_jacobian_manual(x, polynomial_features, weight_matrix, toler... function calculate_jacobian_prod (line 244) | def calculate_jacobian_prod(jacobian, noise_eps=1e-6): function normalise_jacobian_prods (line 260) | def normalise_jacobian_prods(jacobian_preds): function calculate_symetric_score (line 272) | def calculate_symetric_score( FILE: physics_inspired_models/integrators.py function solve_ivp_dt (line 86) | def solve_ivp_dt( function solve_ivp_dt_two_directions (line 211) | def solve_ivp_dt_two_directions( function solve_ivp_t_eval (line 260) | def solve_ivp_t_eval( class RungaKutta (line 374) | class RungaKutta(GeneralIntegrator): method __init__ (line 377) | def __init__( method __call__ (line 394) | def __call__( class GeneralEuler (line 419) | class GeneralEuler(RungaKutta): method __init__ (line 422) | def __init__(self): class RungaKutta2 (line 431) | class RungaKutta2(RungaKutta): method __init__ (line 434) | def __init__(self): class RungaKutta4 (line 443) | class RungaKutta4(RungaKutta): method __init__ (line 446) | def __init__(self): class RungaKutta38 (line 457) | class RungaKutta38(RungaKutta): method __init__ (line 460) | def __init__(self): function solve_hamiltonian_ivp_dt (line 500) | def solve_hamiltonian_ivp_dt( function solve_hamiltonian_ivp_t_eval (line 598) | def solve_hamiltonian_ivp_t_eval( class CompositionSymplectic (line 692) | class CompositionSymplectic(SymplecticIntegrator): method __init__ (line 707) | def __init__( method __call__ (line 725) | def __call__( class SymplecticEuler (line 756) | class SymplecticEuler(CompositionSymplectic): method __init__ (line 767) | def __init__(self, position_first=True): class SymmetricCompositionSymplectic (line 782) | class SymmetricCompositionSymplectic(CompositionSymplectic): method __init__ (line 793) | def __init__( function symmetrize_coefficients (line 818) | def symmetrize_coefficients( class LeapFrog (line 832) | class LeapFrog(SymmetricCompositionSymplectic): method __init__ (line 845) | def __init__(self, position_first=False): class Ruth4 (line 862) | class Ruth4(SymmetricCompositionSymplectic): method __init__ (line 865) | def __init__(self): class Symmetric4 (line 884) | class Symmetric4(SymmetricCompositionSymplectic): method __init__ (line 887) | def __init__(self): class Symmetric6 (line 904) | class Symmetric6(SymmetricCompositionSymplectic): method __init__ (line 907) | def __init__(self): function coefficients_based_on_composing_second_order (line 926) | def coefficients_based_on_composing_second_order( class SymmetricSo4 (line 941) | class SymmetricSo4(SymmetricCompositionSymplectic): method __init__ (line 944) | def __init__(self, position_first: bool = False): class SymmetricSo6 (line 958) | class SymmetricSo6(SymmetricCompositionSymplectic): method __init__ (line 961) | def __init__(self, position_first: bool = False): class SymmetricSo8 (line 976) | class SymmetricSo8(SymmetricCompositionSymplectic): method __init__ (line 979) | def __init__(self, position_first: bool = False): function get_integrator (line 1028) | def get_integrator( FILE: physics_inspired_models/jaxline_configs.py function get_config (line 24) | def get_config(arg_string): function sym_metric_hgn_plus_plus_sweep (line 280) | def sym_metric_hgn_plus_plus_sweep(): function sym_metric_hgn_sweep (line 303) | def sym_metric_hgn_sweep(): function benchmark_hgn_overlap_sweep (line 310) | def benchmark_hgn_overlap_sweep(): function benchmark_lgn_sweep (line 327) | def benchmark_lgn_sweep(): function benchmark_ode_sweep (line 346) | def benchmark_ode_sweep(): function benchmark_rgn_sweep (line 365) | def benchmark_rgn_sweep(): function benchmark_ar_sweep (line 382) | def benchmark_ar_sweep(): FILE: physics_inspired_models/jaxline_train.py class HGNExperiment (line 37) | class HGNExperiment(experiment.AbstractExperiment): method __init__ (line 48) | def __init__(self, mode, init_rng, config): method _process_stats (line 85) | def _process_stats(self, stats, axis_name=None): method step (line 107) | def step(self, global_step, rng, **unused_args): method _initialize_train (line 131) | def _initialize_train(self): method _build_train_input (line 155) | def _build_train_input(self): method _jax_train_step_fn (line 171) | def _jax_train_step_fn(self, params, state, opt_state, rng_key, batch,... method _jax_burnin_fn (line 188) | def _jax_burnin_fn(self, params, state, rng_key, batch): method evaluate (line 203) | def evaluate(self, global_step, rng, writer): method _eval_epoch (line 222) | def _eval_epoch(self, step, rng): method _eval_epoch_metric (line 233) | def _eval_epoch_metric(self, step, rng): method _eval_epoch_vpt (line 270) | def _eval_epoch_vpt(self, step, rng): method _reconstruct_and_align (line 281) | def _reconstruct_and_align(self, rng_key, full_trajectory, prefix, suf... method _initialize_eval (line 341) | def _initialize_eval(self): method _initialize_eval_metric (line 368) | def _initialize_eval_metric(self): method _initialize_eval_vpt (line 394) | def _initialize_eval_vpt(self): method _jax_eval_step_fn (line 424) | def _jax_eval_step_fn(self, params, state, rng_key, batch, step): method _eval_batch_vpt (line 435) | def _eval_batch_vpt(self, params, state, rng_key, batch): method _eval_batch_metric (line 477) | def _eval_batch_metric(self, params, rng, batch, eval_seq_len=200): method _get_gt_and_model_phase_space_for_eval (line 530) | def _get_gt_and_model_phase_space_for_eval(self, params, rng, batch, FILE: physics_inspired_models/metrics.py function calculate_small_latents (line 32) | def calculate_small_latents(dist, threshold=0.5): function compute_scale (line 43) | def compute_scale( function compute_data_domain_stats (line 56) | def compute_data_domain_stats( function compute_vae_stats (line 84) | def compute_vae_stats( function training_statistics (line 102) | def training_statistics( function evaluation_only_statistics (line 130) | def evaluation_only_statistics( function geco_objective (line 210) | def geco_objective( function elbo_objective (line 237) | def elbo_objective(neg_log_p_x, kl, final_beta, beta_delay, step): FILE: physics_inspired_models/models/autoregressive.py class TeacherForcingAutoregressiveModel (line 31) | class TeacherForcingAutoregressiveModel(base.SequenceModel): method __init__ (line 34) | def __init__( method process_inputs_for_encoder (line 89) | def process_inputs_for_encoder(self, x: jnp.ndarray) -> jnp.ndarray: method process_latents_for_dynamics (line 92) | def process_latents_for_dynamics(self, z: jnp.ndarray) -> jnp.ndarray: method process_latents_for_decoder (line 95) | def process_latents_for_decoder(self, z: jnp.ndarray) -> jnp.ndarray: method inferred_index (line 99) | def inferred_index(self) -> int: method train_sequence_length (line 103) | def train_sequence_length(self) -> int: method train_data_split (line 106) | def train_data_split( method unroll_without_inputs (line 118) | def unroll_without_inputs( method unroll_latent_dynamics (line 147) | def unroll_latent_dynamics( method _models_core (line 188) | def _models_core( method training_objectives (line 223) | def training_objectives( # pytype: disable=signature-mismatch # jax-... method reconstruct (line 279) | def reconstruct( method gt_state_and_latents (line 303) | def gt_state_and_latents( # pytype: disable=signature-mismatch # jax... method _init_non_model_params_and_state (line 333) | def _init_non_model_params_and_state( method _init_latent_system (line 339) | def _init_latent_system( # pytype: disable=signature-mismatch # jax-... FILE: physics_inspired_models/models/base.py class SequenceModel (line 33) | class SequenceModel(abc.ABC, Generic[T]): method __init__ (line 36) | def __init__( method train_sequence_length (line 141) | def train_sequence_length(self) -> int: method train_data_split (line 146) | def train_data_split( method decode_latents (line 153) | def decode_latents( method apply_latent_transform (line 173) | def apply_latent_transform( method process_inputs_for_encoder (line 186) | def process_inputs_for_encoder(self, x: jnp.ndarray) -> jnp.ndarray: method process_latents_for_dynamics (line 190) | def process_latents_for_dynamics(self, z: jnp.ndarray) -> T: method process_latents_for_decoder (line 194) | def process_latents_for_decoder(self, z: T) -> jnp.ndarray: method unroll_latent_dynamics (line 198) | def unroll_latent_dynamics( method reconstruct (line 213) | def reconstruct( method training_objectives (line 224) | def training_objectives( method inferred_index (line 239) | def inferred_index(self): method inferred_right_offset (line 249) | def inferred_right_offset(self): method gt_state_and_latents (line 253) | def gt_state_and_latents( method _init_non_model_params_and_state (line 267) | def _init_non_model_params_and_state( method _init_latent_system (line 275) | def _init_latent_system( method _init (line 284) | def _init( method init (line 346) | def init( FILE: physics_inspired_models/models/common.py function construct_model (line 29) | def construct_model( FILE: physics_inspired_models/models/deterministic_vae.py class DeterministicLatentsGenerativeModel (line 34) | class DeterministicLatentsGenerativeModel(base.SequenceModel[_ArrayOrPha... method __init__ (line 37) | def __init__( method process_inputs_for_encoder (line 148) | def process_inputs_for_encoder(self, x: jnp.ndarray) -> jnp.ndarray: method process_latents_for_dynamics (line 151) | def process_latents_for_dynamics(self, z: jnp.ndarray) -> _ArrayOrPhase: method process_latents_for_decoder (line 156) | def process_latents_for_decoder(self, z: _ArrayOrPhase) -> jnp.ndarray: method inferred_index (line 162) | def inferred_index(self) -> int: method targets_index_offset (line 172) | def targets_index_offset(self) -> int: method targets_length (line 183) | def targets_length(self) -> int: method train_sequence_length (line 189) | def train_sequence_length(self) -> int: method train_data_split (line 209) | def train_data_split( method prior (line 256) | def prior(self) -> distrax.Distribution: method sample_latent_from_prior (line 274) | def sample_latent_from_prior( method sample_trajectories_from_prior (line 286) | def sample_trajectories_from_prior( method verify_unroll_args (line 312) | def verify_unroll_args( method unroll_latent_dynamics (line 330) | def unroll_latent_dynamics( # pytype: disable=signature-mismatch # j... method _models_core (line 354) | def _models_core( method training_objectives (line 396) | def training_objectives( # pytype: disable=signature-mismatch # jax-... method reconstruct (line 498) | def reconstruct( method gt_state_and_latents (line 535) | def gt_state_and_latents( # pytype: disable=signature-mismatch # jax... method _init_non_model_params_and_state (line 589) | def _init_non_model_params_and_state( method _init_latent_system (line 602) | def _init_latent_system( FILE: physics_inspired_models/models/dynamics.py class PhysicsSimulationNetwork (line 33) | class PhysicsSimulationNetwork(hk.Module): method __init__ (line 36) | def __init__( method sum_per_dim_energy (line 239) | def sum_per_dim_energy(self, energy: jnp.ndarray) -> jnp.ndarray: method feature_matrix_vector (line 244) | def feature_matrix_vector(self, m, v): method mass_matrix_mul (line 249) | def mass_matrix_mul( method mass_matrix_inv_mul (line 318) | def mass_matrix_inv_mul( method momentum_from_velocity (line 387) | def momentum_from_velocity( method velocity_from_momentum (line 400) | def velocity_from_momentum( method kinetic_energy_velocity (line 413) | def kinetic_energy_velocity( method kinetic_energy_momentum (line 444) | def kinetic_energy_momentum( method potential_energy_velocity (line 475) | def potential_energy_velocity( method potential_energy_momentum (line 493) | def potential_energy_momentum( method hamiltonian (line 511) | def hamiltonian( method lagrangian (line 523) | def lagrangian( method energy_from_momentum (line 535) | def energy_from_momentum( method energy_from_velocity (line 543) | def energy_from_velocity( method velocity_and_acceleration (line 554) | def velocity_and_acceleration( method simulate (line 580) | def simulate( method __call__ (line 673) | def __call__(self, *args, **kwargs): class OdeNetwork (line 677) | class OdeNetwork(hk.Module): method __init__ (line 680) | def __init__( method simulate (line 707) | def simulate( method __call__ (line 755) | def __call__(self, *args, **kwargs): class DiscreteDynamicsNetwork (line 759) | class DiscreteDynamicsNetwork(hk.Module): method __init__ (line 762) | def __init__( method simulate (line 779) | def simulate( method __call__ (line 838) | def __call__(self, *args, **kwargs): FILE: physics_inspired_models/models/networks.py class DenseNet (line 28) | class DenseNet(hk.Module): method __init__ (line 31) | def __init__( method __call__ (line 52) | def __call__(self, inputs: jnp.ndarray, is_training: bool): class Conv2DNet (line 61) | class Conv2DNet(hk.Module): method __init__ (line 64) | def __init__( method __call__ (line 119) | def __call__(self, inputs: jnp.ndarray, is_training: bool): class SpatialConvEncoder (line 132) | class SpatialConvEncoder(hk.Module): method __init__ (line 135) | def __init__( method spatial_aggregation (line 201) | def spatial_aggregation(self, x: jnp.ndarray) -> jnp.ndarray: method make_distribution (line 214) | def make_distribution(self, net_output: jnp.ndarray) -> distrax.Distri... method __call__ (line 229) | def __call__( class SpatialConvDecoder (line 252) | class SpatialConvDecoder(hk.Module): method __init__ (line 255) | def __init__( method spatial_de_aggregation (line 316) | def spatial_de_aggregation(self, x: jnp.ndarray) -> jnp.ndarray: method add_constant_channels (line 347) | def add_constant_channels(self, inputs: jnp.ndarray) -> jnp.ndarray: method make_distribution (line 365) | def make_distribution(self, net_output: jnp.ndarray) -> distrax.Distri... method __call__ (line 376) | def __call__( function make_flexible_net (line 401) | def make_flexible_net( function make_flexible_recurrent_net (line 452) | def make_flexible_recurrent_net( FILE: physics_inspired_models/utils.py function filter_only_scalar_stats (line 37) | def filter_only_scalar_stats(stats): function to_numpy (line 41) | def to_numpy(obj): function geco_lagrange_product (line 46) | def geco_lagrange_product(lagrange_multiplier, constraint_ema, constrain... function bcast_if (line 76) | def bcast_if(x, t, n): function stack_time_into_channels (line 80) | def stack_time_into_channels( function stack_device_dim_into_batch (line 90) | def stack_device_dim_into_batch(obj): function nearest_neighbour_upsampling (line 94) | def nearest_neighbour_upsampling(x, scale, data_format="NHWC"): function get_activation (line 111) | def get_activation(arg: Union[_Activation, str]) -> _Activation: function merge_first_dims (line 128) | def merge_first_dims(x: jnp.ndarray, num_dims_to_merge: int = 2) -> jnp.... function extract_image (line 132) | def extract_image( function extract_gt_state (line 147) | def extract_gt_state(inputs: Any) -> jnp.ndarray: function reshape_latents_conv_to_flat (line 156) | def reshape_latents_conv_to_flat(conv_latents, axis_n_to_keep=1): function triu_matrix_from_v (line 165) | def triu_matrix_from_v(x, ndim): function flatten_dict (line 175) | def flatten_dict(d, parent_key: str = "", sep: str = "_") -> Dict[str, A... function convert_to_pytype (line 186) | def convert_to_pytype(target, reference): function func_if_not_scalar (line 193) | def func_if_not_scalar(func): class MultiBatchAccumulator (line 206) | class MultiBatchAccumulator(object): method __init__ (line 209) | def __init__(self): method add (line 215) | def add(self, averaged_values, num_samples): method value (line 231) | def value(self): method max (line 234) | def max(self): method min (line 237) | def min(self): method sum (line 240) | def sum(self): function inner_product (line 251) | def inner_product(x: Any, y: Any) -> jnp.ndarray: function wrap_if_pmap (line 263) | def wrap_if_pmap(p_func): FILE: physics_planning_games/board_games/__init__.py function get_environments_by_tag (line 30) | def get_environments_by_tag(tag): function load (line 42) | def load(environment_name, FILE: physics_planning_games/board_games/_internal/arenas.py class Standard (line 28) | class Standard(composer.Arena): method _build (line 31) | def _build(self, name=None): method _build_observables (line 112) | def _build_observables(self): method front_camera (line 116) | def front_camera(self): method front_camera_2 (line 120) | def front_camera_2(self): method top_down_camera (line 124) | def top_down_camera(self): method attach_offset (line 127) | def attach_offset(self, entity, offset, attach_site=None): class ArenaObservables (line 143) | class ArenaObservables(composer.Observables): method front_camera (line 147) | def front_camera(self): method front_camera_2 (line 151) | def front_camera_2(self): method top_down_camera (line 155) | def top_down_camera(self): method top_down_camera_invisible_robot (line 159) | def top_down_camera_invisible_robot(self): FILE: physics_planning_games/board_games/_internal/boards.py function _make_checkerboard (line 34) | def _make_checkerboard(rows, function _make_goboard (line 73) | def _make_goboard(boardsize, class CheckerBoard (line 155) | class CheckerBoard(composer.Entity): method __init__ (line 158) | def __init__(self, *args, **kwargs): method _build (line 162) | def _build(self, rows=3, columns=3, square_halfwidth=0.05): method mjcf_model (line 177) | def mjcf_model(self): method before_substep (line 180) | def before_substep(self, physics, random_state): method validate_finger_touch (line 187) | def validate_finger_touch(self, physics, row, col, hand): method get_contact_pos (line 203) | def get_contact_pos(self, physics, row, col): method get_contact_indices (line 226) | def get_contact_indices(self, physics): method sample_pos_inside_touch_sensor (line 235) | def sample_pos_inside_touch_sensor(self, physics, random_state, row, c... class GoBoard (line 241) | class GoBoard(CheckerBoard): method _build (line 244) | def _build(self, boardsize=7, square_halfwidth=0.05): # pytype: disab... method get_contact_indices (line 263) | def get_contact_indices(self, physics): method validate_finger_touch (line 277) | def validate_finger_touch(self, physics, row, col, hand): method sample_pos_inside_touch_sensor (line 296) | def sample_pos_inside_touch_sensor(self, physics, random_state, row, c... FILE: physics_planning_games/board_games/_internal/observations.py class ObservableSpec (line 22) | class ObservableSpec(collections.namedtuple( class CameraObservableSpec (line 30) | class CameraObservableSpec(collections.namedtuple( class ObservationSettings (line 36) | class ObservationSettings(collections.namedtuple( class ObservableNames (line 43) | class ObservableNames(collections.namedtuple( method __new__ (line 49) | def __new__(cls, proprio=(), ftt=(), prop_pose=(), board_state=(), cam... function make_options (line 104) | def make_options(obs_settings, obs_names): FILE: physics_planning_games/board_games/_internal/pieces.py class Markers (line 36) | class Markers(composer.Entity): method _build (line 39) | def _build(self, method _build_observables (line 78) | def _build_observables(self): method mjcf_model (line 82) | def mjcf_model(self): method markers (line 87) | def markers(self): method initialize_episode (line 96) | def initialize_episode(self, physics, random_state): method _reset (line 101) | def _reset(self, physics): method make_all_invisible (line 111) | def make_all_invisible(self, physics): method make_visible_by_bpos (line 117) | def make_visible_by_bpos(self, physics, player_id, all_bpos): method mark (line 124) | def mark(self, physics, player_id, pos, bpos=None): class MarkersObservables (line 155) | class MarkersObservables(composer.Observables): method position (line 159) | def position(self): FILE: physics_planning_games/board_games/_internal/pieces_test.py class MarkersTest (line 25) | class MarkersTest(absltest.TestCase): method test_position_observable (line 27) | def test_position_observable(self): method test_invalid_player_id (line 45) | def test_invalid_player_id(self): method test_too_many_moves (line 53) | def test_too_many_moves(self): FILE: physics_planning_games/board_games/_internal/registry.py function done_importing_tasks (line 35) | def done_importing_tasks(): FILE: physics_planning_games/board_games/board_games_test.py class GoTest (line 24) | class GoTest(test_utils.EnvironmentTestMixin, absltest.TestCase): method make_object_under_test (line 26) | def make_object_under_test(self): class TicTacToeTest (line 30) | class TicTacToeTest(test_utils.EnvironmentTestMixin, absltest.TestCase): method make_object_under_test (line 32) | def make_object_under_test(self): FILE: physics_planning_games/board_games/go.py class Go (line 35) | class Go(jaco_arm_board_game.JacoArmBoardGame): method __init__ (line 38) | def __init__(self, board_size, observation_settings, opponent=None, method name (line 78) | def name(self): method control_timestep (line 82) | def control_timestep(self): method after_substep (line 85) | def after_substep(self, physics, random_state): function go_7x7 (line 152) | def go_7x7(): FILE: physics_planning_games/board_games/go_logic.py function _go_marker_to_int (line 41) | def _go_marker_to_int(go_marker, board_size): function _int_to_go_marker (line 57) | def _int_to_go_marker(move_int, board_size): function _go_marker_to_str (line 77) | def _go_marker_to_str(go_marker): function _str_to_go_marker (line 85) | def _str_to_go_marker(move_str): function _get_gnugo_ref_config (line 110) | def _get_gnugo_ref_config(level=1, binary_path=None): class Stone (line 140) | class Stone(enum.Enum): method __lt__ (line 145) | def __lt__(self, other): function gtp_to_sgf_point (line 150) | def gtp_to_sgf_point(gtp_point, board_size): class Gtp (line 163) | class Gtp(object): method __init__ (line 168) | def __init__(self, checkpoint_file=None): method set_board_size (line 185) | def set_board_size(self, size): method set_komi (line 190) | def set_komi(self, komi): method set_free_handicap (line 194) | def set_free_handicap(self, vertices): method place_free_handicap (line 198) | def place_free_handicap(self, n): method make_move (line 202) | def make_move(self, move, record=True): method set_byo_yomi_time (line 207) | def set_byo_yomi_time(self, t): method num_moves (line 210) | def num_moves(self): method clear_board (line 213) | def clear_board(self): method generate_move (line 218) | def generate_move(self, color): method board (line 226) | def board(self): method quit (line 232) | def quit(self): method final_status (line 235) | def final_status(self, status): method fixed_handicap (line 239) | def fixed_handicap(self, handicap): method undo (line 243) | def undo(self, num_moves): method _record_move (line 249) | def _record_move(self, move, stderr=None): method to_sgf (line 257) | def to_sgf(self): method _format_sgf_move (line 266) | def _format_sgf_move(self, move): method _sgf_escape (line 272) | def _sgf_escape(self, text): method gtp_command (line 276) | def gtp_command(self, command, log=True): class GtpError (line 291) | class GtpError(Exception): method __init__ (line 293) | def __init__(self, response): method __str__ (line 297) | def __str__(self): class GoEngine (line 301) | class GoEngine(Gtp): method __init__ (line 309) | def __init__(self, command='', checkpoint_file=None, extra_flags=None): method gtp_command (line 319) | def gtp_command(self, command, log=True): class GoGameLogic (line 340) | class GoGameLogic(logic_base.OpenSpielBasedLogic): method __init__ (line 343) | def __init__(self, board_size, gnugo_level=1, komi=5.5): method board_size (line 356) | def board_size(self): method get_gtp_player (line 359) | def get_gtp_player(self): method reset (line 362) | def reset(self): method show_board (line 376) | def show_board(self): method get_gtp_reward (line 379) | def get_gtp_reward(self): method get_board_state (line 382) | def get_board_state(self): method set_state_from_history (line 397) | def set_state_from_history(self, move_history): method get_move_history (line 407) | def get_move_history(self): method apply (line 411) | def apply(self, player, action): function gen_move (line 449) | def gen_move(game_logic, player): function gen_random_move (line 460) | def gen_random_move(game_logic, random_state): class GoGTPOpponent (line 471) | class GoGTPOpponent(logic_base.Opponent): method __init__ (line 474) | def __init__(self, board_size, mixture_p=0.0): method reset (line 484) | def reset(self): method policy (line 487) | def policy(self, game_logic, player, random_state): class GoRandomOpponent (line 504) | class GoRandomOpponent(logic_base.Opponent): method __init__ (line 507) | def __init__(self, board_size): method reset (line 510) | def reset(self): method policy (line 514) | def policy(self, game_logic, player, random_state): FILE: physics_planning_games/board_games/go_logic_test.py class GoGameLogicTest (line 23) | class GoGameLogicTest(parameterized.TestCase): method setUp (line 25) | def setUp(self): method test_valid_move_sequence (line 31) | def test_valid_move_sequence(self): method test_pass (line 39) | def test_pass(self): method test_invalid_move_sequence (line 47) | def test_invalid_move_sequence(self): method test_random_opponent_vs_gnugo (line 71) | def test_random_opponent_vs_gnugo(self): method test_go_marker_to_int (line 120) | def test_go_marker_to_int(self, row, col): method test_go_marker_to_str (line 133) | def test_go_marker_to_str(self, row, col): FILE: physics_planning_games/board_games/jaco_arm_board_game.py function _uniform_downward_rotation (line 42) | def _uniform_downward_rotation(): class JacoArmBoardGame (line 49) | class JacoArmBoardGame(composer.Task): method __init__ (line 52) | def __init__(self, observation_settings, opponent, game_logic, board, method root_entity (line 103) | def root_entity(self): method arm (line 107) | def arm(self): method hand (line 111) | def hand(self): method task_observables (line 115) | def task_observables(self): method get_reward (line 118) | def get_reward(self, physics): method should_terminate_episode (line 122) | def should_terminate_episode(self, physics): method initialize_episode (line 125) | def initialize_episode(self, physics, random_state): method before_step (line 130) | def before_step(self, physics, action, random_state): method after_substep (line 134) | def after_substep(self, physics, random_state): FILE: physics_planning_games/board_games/logic_base.py class GameLogic (line 24) | class GameLogic(ABC): method __init__ (line 29) | def __init__(self): method reset (line 33) | def reset(self): method is_game_over (line 37) | def is_game_over(self): method get_reward (line 41) | def get_reward(self): method get_board_state (line 45) | def get_board_state(self): method apply (line 49) | def apply(self, player, action): class OpenSpielBasedLogic (line 61) | class OpenSpielBasedLogic(GameLogic): method is_game_over (line 66) | def is_game_over(self): method get_reward (line 71) | def get_reward(self): method open_spiel_state (line 91) | def open_spiel_state(self): class Opponent (line 96) | class Opponent(ABC): method __init__ (line 100) | def __init__(self): method reset (line 104) | def reset(self): method policy (line 108) | def policy(self, game_logic, random_state): FILE: physics_planning_games/board_games/tic_tac_toe.py class TicTacToe (line 27) | class TicTacToe(jaco_arm_board_game.JacoArmBoardGame): method __init__ (line 30) | def __init__(self, observation_settings, opponent=None, method control_timestep (line 57) | def control_timestep(self): method after_substep (line 60) | def after_substep(self, physics, random_state): function tic_tac_toe_markers_features (line 94) | def tic_tac_toe_markers_features(**unused_kwargs): function tic_tac_toe_mixture_opponent_markers_features (line 99) | def tic_tac_toe_mixture_opponent_markers_features(mixture_p=0.25): function tic_tac_toe_optimal_opponent_markers_features (line 108) | def tic_tac_toe_optimal_opponent_markers_features(**unused_kwargs): FILE: physics_planning_games/board_games/tic_tac_toe_logic.py class TicTacToeGameLogic (line 33) | class TicTacToeGameLogic(logic_base.OpenSpielBasedLogic): method __init__ (line 36) | def __init__(self): method reset (line 39) | def reset(self): method get_board_state (line 56) | def get_board_state(self): method apply (line 71) | def apply(self, player, action): class TicTacToeRandomOpponent (line 94) | class TicTacToeRandomOpponent(logic_base.Opponent): method __init__ (line 97) | def __init__(self): method reset (line 100) | def reset(self): method policy (line 104) | def policy(self, game_logic, random_state): class TicTacToeMixtureOpponent (line 124) | class TicTacToeMixtureOpponent(logic_base.Opponent): method __init__ (line 130) | def __init__(self, mixture_p): method reset (line 143) | def reset(self): method policy (line 146) | def policy(self, game_logic, random_state): class TicTacToeOptimalOpponent (line 153) | class TicTacToeOptimalOpponent(logic_base.Opponent): method __init__ (line 159) | def __init__(self): method reset (line 162) | def reset(self): method policy (line 165) | def policy(self, game_logic, random_state): function numpy_array_to_open_spiel_state (line 170) | def numpy_array_to_open_spiel_state(board_state): function open_spiel_move_to_single_marker_action (line 198) | def open_spiel_move_to_single_marker_action(action): function tic_tac_toe_random_move (line 203) | def tic_tac_toe_random_move(state, random_state): function tic_tac_toe_minimax (line 227) | def tic_tac_toe_minimax(state, random_state): FILE: physics_planning_games/board_games/tic_tac_toe_logic_test.py class TicTacToeGameLogicTest (line 23) | class TicTacToeGameLogicTest(parameterized.TestCase): method setUp (line 25) | def setUp(self): method test_valid_move_sequence (line 31) | def test_valid_move_sequence(self): method test_invalid_move_sequence (line 51) | def test_invalid_move_sequence(self): method test_reward_and_termination (line 102) | def test_reward_and_termination(self, move_sequence, winner_id): method test_random_opponent_vs_optimal (line 117) | def test_random_opponent_vs_optimal(self): method test_minimax_policy (line 179) | def test_minimax_policy(self, move_sequence, optimal_move): FILE: physics_planning_games/explore.py function main (line 43) | def main(argv): FILE: physics_planning_games/mujoban/boxoban.py function boxoban_level_generator (line 31) | def boxoban_level_generator(levels_set="unfiltered", data_split="valid"): class Boxoban (line 38) | class Boxoban(object): method __init__ (line 41) | def __init__(self, method get_data (line 65) | def get_data(self): FILE: physics_planning_games/mujoban/mujoban.py function _round_positions (line 65) | def _round_positions(boxes, walker, last_round_walker): class Mujoban (line 75) | class Mujoban(composer.Task): method __init__ (line 85) | def __init__(self, method name (line 197) | def name(self): method root_entity (line 201) | def root_entity(self): method _regenerate_positions (line 204) | def _regenerate_positions(self): method initialize_episode_mjcf (line 218) | def initialize_episode_mjcf(self, random_state): method initialize_episode (line 291) | def initialize_episode(self, physics, random_state): method before_step (line 320) | def before_step(self, physics, actions, random_state): method _get_object_positions_in_grid (line 329) | def _get_object_positions_in_grid(self, physics): method _update_entity_pixel_layers (line 337) | def _update_entity_pixel_layers(self, physics): method after_step (line 416) | def after_step(self, physics, random_state): method get_reward (line 428) | def get_reward(self, physics): method get_discount (line 439) | def get_discount(self, physics): method should_terminate_episode (line 442) | def should_terminate_episode(self, physics): method get_reward_spec (line 446) | def get_reward_spec(self): method task_observables (line 450) | def task_observables(self): FILE: physics_planning_games/mujoban/mujoban_level.py function single_level_generator (line 55) | def single_level_generator(level=_DEFAULT_LEVEL): function _ascii_to_text_grid_level (line 60) | def _ascii_to_text_grid_level(ascii_level): class MujobanLevel (line 91) | class MujobanLevel(labmaze.BaseMaze): method __init__ (line 94) | def __init__(self, ascii_level_generator=single_level_generator): method regenerate (line 106) | def regenerate(self): method num_boxes (line 119) | def num_boxes(self): method num_targets (line 123) | def num_targets(self): method entity_layer (line 127) | def entity_layer(self): method variations_layer (line 131) | def variations_layer(self): method height (line 135) | def height(self): method width (line 139) | def width(self): FILE: physics_planning_games/mujoban/mujoban_level_test.py class MujobanLevelTest (line 45) | class MujobanLevelTest(absltest.TestCase): method test_ascii_to_text_grid_level (line 47) | def test_ascii_to_text_grid_level(self): FILE: physics_planning_games/mujoban/mujoban_pad.py function _get_activator_box (line 24) | def _get_activator_box(pad_xpos, pad_size, boxes, tolerance=0.0): class MujobanPad (line 42) | class MujobanPad(composer.Entity): method _build (line 45) | def _build(self, rgba=None, pressed_rgba=None, method rgba (line 62) | def rgba(self): method pressed_rgba (line 66) | def pressed_rgba(self): method register_box (line 69) | def register_box(self, box_entity): method site (line 73) | def site(self): method boxes (line 77) | def boxes(self): method activator (line 81) | def activator(self): method mjcf_model (line 85) | def mjcf_model(self): method initialize_episode_mjcf (line 88) | def initialize_episode_mjcf(self, unused_random_state): method initialize_episode (line 91) | def initialize_episode(self, physics, unused_random_state): method _update_activation (line 94) | def _update_activation(self, physics): method before_step (line 113) | def before_step(self, physics, unused_random_state): method after_substep (line 116) | def after_substep(self, physics, unused_random_state): method activated (line 120) | def activated(self): method reset (line 124) | def reset(self, physics): FILE: physics_planning_games/mujoban/mujoban_test.py class MujobanTest (line 33) | class MujobanTest(absltest.TestCase): method test (line 35) | def test(self): FILE: physics_planning_games/mujoban/props.py class Box (line 25) | class Box(props.Primitive): method _build (line 28) | def _build(self, half_lengths=None, mass=None, name='box'): class BoxWithSites (line 36) | class BoxWithSites(Box): method _build (line 39) | def _build(self, half_lengths=None, mass=None, name='box'): method corner_sites (line 61) | def corner_sites(self): FILE: polygen/data_utils.py function random_shift (line 28) | def random_shift(vertices, shift_factor=0.25): function make_vertex_model_dataset (line 44) | def make_vertex_model_dataset(ds, apply_random_shift=False): function make_face_model_dataset (line 69) | def make_face_model_dataset( function read_obj_file (line 107) | def read_obj_file(obj_file): function read_obj (line 144) | def read_obj(obj_path): function write_obj (line 151) | def write_obj(vertices, faces, file_path, transpose=True, scale=1.): function quantize_verts (line 172) | def quantize_verts(verts, n_bits=8): function dequantize_verts (line 182) | def dequantize_verts(verts, n_bits=8, add_noise=False): function face_to_cycles (line 194) | def face_to_cycles(face): function flatten_faces (line 203) | def flatten_faces(faces): function unflatten_faces (line 213) | def unflatten_faces(flat_faces): function center_vertices (line 229) | def center_vertices(vertices): function normalize_vertices_scale (line 237) | def normalize_vertices_scale(vertices): function quantize_process_mesh (line 246) | def quantize_process_mesh(vertices, faces, tris=None, quantization_bits=8): function process_mesh (line 302) | def process_mesh(vertices, faces, quantization_bits=8): function load_process_mesh (line 330) | def load_process_mesh(mesh_obj_path, quantization_bits=8): function plot_meshes (line 337) | def plot_meshes(mesh_list, FILE: polygen/model_test.py function _get_vertex_model_batch (line 40) | def _get_vertex_model_batch(): function _get_face_model_batch (line 49) | def _get_face_model_batch(): class VertexModelTest (line 61) | class VertexModelTest(tf.test.TestCase): method setUp (line 63) | def setUp(self): method test_model_runs (line 74) | def test_model_runs(self): method test_sample_outputs_range (line 86) | def test_sample_outputs_range(self): class FaceModelTest (line 101) | class FaceModelTest(tf.test.TestCase): method setUp (line 103) | def setUp(self): method test_model_runs (line 115) | def test_model_runs(self): method test_sample_outputs_range (line 133) | def test_sample_outputs_range(self): FILE: polygen/modules.py function dequantize_verts (line 27) | def dequantize_verts(verts, n_bits, add_noise=False): function quantize_verts (line 39) | def quantize_verts(verts, n_bits): function top_k_logits (line 49) | def top_k_logits(logits, k): function top_p_logits (line 61) | def top_p_logits(logits, p): function multihead_self_attention_memory_efficient (line 87) | def multihead_self_attention_memory_efficient(x, class TransformerEncoder (line 225) | class TransformerEncoder(snt.AbstractModule): method __init__ (line 237) | def __init__(self, method _build (line 272) | def _build(self, inputs, is_training=False): class TransformerDecoder (line 349) | class TransformerDecoder(snt.AbstractModule): method __init__ (line 361) | def __init__(self, method _build (line 397) | def _build(self, method create_init_cache (line 526) | def create_init_cache(self, batch_size): function conv_residual_block (line 549) | def conv_residual_block(inputs, class ResNet (line 620) | class ResNet(snt.AbstractModule): method __init__ (line 623) | def __init__(self, method _build (line 648) | def _build(self, inputs, is_training=False): class VertexModel (line 704) | class VertexModel(snt.AbstractModule): method __init__ (line 716) | def __init__(self, method _embed_class_label (line 749) | def _embed_class_label(self, labels): method _prepare_context (line 760) | def _prepare_context(self, context, is_training=False): method _embed_inputs (line 769) | def _embed_inputs(self, vertices, global_context_embedding=None): method _project_to_logits (line 822) | def _project_to_logits(self, inputs): method _create_dist (line 832) | def _create_dist(self, method _build (line 862) | def _build(self, batch, is_training=False): method sample (line 883) | def sample(self, class ImageToVertexModel (line 1026) | class ImageToVertexModel(VertexModel): method __init__ (line 1039) | def __init__(self, method _prepare_context (line 1069) | def _prepare_context(self, context, is_training=False): class VoxelToVertexModel (line 1094) | class VoxelToVertexModel(VertexModel): method __init__ (line 1107) | def __init__(self, method _prepare_context (line 1137) | def _prepare_context(self, context, is_training=False): class FaceModel (line 1170) | class FaceModel(snt.AbstractModule): method __init__ (line 1186) | def __init__(self, method _embed_class_label (line 1226) | def _embed_class_label(self, labels): method _prepare_context (line 1237) | def _prepare_context(self, context, is_training=False): method _embed_vertices (line 1257) | def _embed_vertices(self, vertices, vertices_mask, is_training=False): method _embed_inputs (line 1289) | def _embed_inputs(self, faces_long, vertex_embeddings, method _project_to_pointers (line 1320) | def _project_to_pointers(self, inputs): method _create_dist (line 1331) | def _create_dist(self, method _build (line 1372) | def _build(self, batch, is_training=False): method sample (line 1399) | def sample(self, FILE: rapid_task_solving/memory_planning_game.py class MemoryPlanningGame (line 24) | class MemoryPlanningGame(dm_env.Environment): method __init__ (line 37) | def __init__(self, method _one_hot (line 67) | def _one_hot(self, node): method step (line 72) | def step(self, action): method _observation (line 102) | def _observation(self): method observation_spec (line 108) | def observation_spec(self): method action_spec (line 116) | def action_spec(self): method take_random_action (line 119) | def take_random_action(self): method reset (line 122) | def reset(self): method _respawn (line 134) | def _respawn(self): method _set_new_goal (line 138) | def _set_new_goal(self): method position (line 148) | def position(self): method goal (line 152) | def goal(self): method previous_action (line 156) | def previous_action(self): method episode_reward (line 160) | def episode_reward(self): method draw_maze (line 163) | def draw_maze(self, ax=None): FILE: rapid_task_solving/one_shot_streetlearn.py function deg_to_rad (line 23) | def deg_to_rad(x): function rad_to_deg (line 28) | def rad_to_deg(x): class OneShotStreetLearn (line 33) | class OneShotStreetLearn(dm_env.Environment): method __init__ (line 44) | def __init__(self, dataset_path, max_episode_steps, num_junctions=8, method reset (line 57) | def reset(self): method _current_edge (line 71) | def _current_edge(self): method _set_new_goal (line 74) | def _set_new_goal(self): method _one_hot (line 81) | def _one_hot(self, edge): method _observation (line 86) | def _observation(self): method observation_spec (line 92) | def observation_spec(self): method action_spec (line 100) | def action_spec(self): method step (line 103) | def step(self, action): method randomize_observations (line 132) | def randomize_observations(self, subgraph): method _calculate_bearing (line 138) | def _calculate_bearing(self, node, neighbor): method _neighbors_bearings (line 150) | def _neighbors_bearings(self, subgraph, node): method _sort_neighbors (line 158) | def _sort_neighbors(self, node, neighbour): method _move (line 168) | def _move(self, action): method _all_next_junctions (line 185) | def _all_next_junctions(self, subgraph, node): method _get_next_junction (line 191) | def _get_next_junction(self, subgraph, initial_node, next_node): method get_random_subgraph (line 200) | def get_random_subgraph(self): method draw_subgraph (line 224) | def draw_subgraph(self, ax=None): FILE: rl_unplugged/atari.py function _decode_frames (line 108) | def _decode_frames(pngs: tf.Tensor): function _make_reverb_sample (line 124) | def _make_reverb_sample(o_t: tf.Tensor, function _tf_example_to_reverb_sample (line 155) | def _tf_example_to_reverb_sample(tf_example: tf.train.Example function dataset (line 188) | def dataset(path: str, class AtariDopamineWrapper (line 207) | class AtariDopamineWrapper(dm_env.Environment): method __init__ (line 210) | def __init__(self, env, max_episode_steps=108000): method reset (line 216) | def reset(self): method step (line 222) | def step(self, action): method observation_spec (line 239) | def observation_spec(self): method action_spec (line 243) | def action_spec(self): function environment (line 247) | def environment(game: str) -> dm_env.Environment: FILE: rl_unplugged/atari_example.py function main (line 36) | def main(_): FILE: rl_unplugged/dm_control_suite.py function _build_rodent_escape_env (line 51) | def _build_rodent_escape_env(): function _build_rodent_maze_env (line 72) | def _build_rodent_maze_env(): function _build_rodent_corridor_gaps (line 115) | def _build_rodent_corridor_gaps(): function _build_rodent_two_touch_env (line 148) | def _build_rodent_two_touch_env(): function _build_humanoid_walls_env (line 179) | def _build_humanoid_walls_env(): function _build_humanoid_corridor_env (line 213) | def _build_humanoid_corridor_env(): function _build_humanoid_corridor_gaps (line 236) | def _build_humanoid_corridor_gaps(): class MujocoActionNormalizer (line 264) | class MujocoActionNormalizer(wrappers.EnvironmentWrapper): method __init__ (line 272) | def __init__(self, environment, rescale='clip'): method step (line 276) | def step(self, action): class NormilizeActionSpecWrapper (line 287) | class NormilizeActionSpecWrapper(wrappers.EnvironmentWrapper): method __init__ (line 290) | def __init__(self, environment): method _from_normal_actions (line 308) | def _from_normal_actions(self, actions): method step (line 313) | def step(self, action): method action_spec (line 317) | def action_spec(self): class FilterObservationsWrapper (line 321) | class FilterObservationsWrapper(wrappers.EnvironmentWrapper): method __init__ (line 324) | def __init__(self, environment, observations_to_keep): method _filter_observation (line 331) | def _filter_observation(self, timestep): method step (line 336) | def step(self, action): method reset (line 339) | def reset(self): method observation_spec (line 342) | def observation_spec(self): class ControlSuite (line 346) | class ControlSuite: method __init__ (line 349) | def __init__(self, task_name='humanoid_run'): method shapes (line 492) | def shapes(self): method data_path (line 496) | def data_path(self): method uint8_features (line 500) | def uint8_features(self): method environment (line 504) | def environment(self): class CmuThirdParty (line 519) | class CmuThirdParty: method __init__ (line 522) | def __init__(self, task_name='humanoid_walls'): method get_pixel_keys (line 570) | def get_pixel_keys(): method uint8_features (line 574) | def uint8_features(self): method shapes (line 578) | def shapes(self): method data_path (line 582) | def data_path(self): method environment (line 586) | def environment(self): class Rodent (line 608) | class Rodent: method __init__ (line 611) | def __init__(self, task_name='rodent_gaps'): method get_pixel_keys (line 654) | def get_pixel_keys(): method shapes (line 658) | def shapes(self): method uint8_features (line 662) | def uint8_features(self): method data_path (line 666) | def data_path(self): method environment (line 670) | def environment(self): function _parse_seq_tf_example (line 693) | def _parse_seq_tf_example(example, uint8_features, shapes): function _build_sequence_example (line 729) | def _build_sequence_example(sequences): function _build_sarsa_example (line 748) | def _build_sarsa_example(sequences): function _padded_batch (line 767) | def _padded_batch(example_ds, batch_size, shapes, drop_remainder=False): function dataset (line 785) | def dataset(root_path: str, FILE: rl_unplugged/dm_control_suite_example.py function main (line 42) | def main(_): FILE: rl_unplugged/networks.py function instance_norm_and_elu (line 24) | def instance_norm_and_elu(x): class ControlNetwork (line 32) | class ControlNetwork(snt.Module): method __init__ (line 36) | def __init__(self, method __call__ (line 55) | def __call__(self, inputs, action: tf.Tensor = None, task=None): FILE: rl_unplugged/rwrl.py function _decombine_key (line 57) | def _decombine_key(k: str, delimiter: str = DELIMITER) -> Sequence[str]: function tf_example_to_feature_description (line 61) | def tf_example_to_feature_description(example, function tree_deflatten_with_delimiter (line 80) | def tree_deflatten_with_delimiter( function get_slice_of_nested (line 104) | def get_slice_of_nested(nested: Dict[str, Any], start: int, function repeat_last_and_append_to_nested (line 109) | def repeat_last_and_append_to_nested(nested: Dict[str, Any]) -> Dict[str... function tf_example_to_reverb_sample (line 114) | def tf_example_to_reverb_sample(example, function dataset (line 135) | def dataset(path: str, function environment (line 179) | def environment( FILE: rl_unplugged/rwrl_example.py function main (line 36) | def main(_): FILE: scratchgan/discriminator_nets.py class LSTMEmbedDiscNet (line 24) | class LSTMEmbedDiscNet(snt.AbstractModule): method __init__ (line 27) | def __init__(self, method _build (line 49) | def _build(self, sequence, sequence_length, is_training=True): FILE: scratchgan/eval_metrics.py function fid (line 24) | def fid(generated_sentences, real_sentences): FILE: scratchgan/experiment.py function main (line 73) | def main(_): function train (line 101) | def train(config): function evaluate_pair (line 302) | def evaluate_pair(config, batch_size, checkpoint_path, data_dir, dataset, FILE: scratchgan/generators.py class LSTMGen (line 26) | class LSTMGen(snt.AbstractModule): method __init__ (line 32) | def __init__(self, method _build (line 60) | def _build(self, is_training=True, temperature=1.0): FILE: scratchgan/losses.py function sequential_cross_entropy_loss (line 23) | def sequential_cross_entropy_loss(logits, expected): function reinforce_loss (line 45) | def reinforce_loss(disc_logits, gen_logprobs, gamma, decay): FILE: scratchgan/reader.py function tokenize (line 44) | def tokenize(sentence): function _build_vocab (line 49) | def _build_vocab(json_data): function string_sequence_to_sequence (line 76) | def string_sequence_to_sequence(string_sequence, word_to_id): function _integerize (line 86) | def _integerize(json_data, word_to_id, dataset): function get_raw_data (line 102) | def get_raw_data(data_path, dataset, truncate_vocab=20000): function iterator (line 147) | def iterator(raw_data, batch_size, random=False): FILE: scratchgan/utils.py function _get_embedding_initializer (line 31) | def _get_embedding_initializer(vocab_file, embedding_source, vocab_size): function append_position_signal (line 61) | def append_position_signal(embeddings, position_dim=8): function get_position_signal (line 74) | def get_position_signal(sequence_length, position_dim=8): function get_mask_by_length (line 107) | def get_mask_by_length(lengths, max_length): function get_mask_past_symbol (line 127) | def get_mask_past_symbol(reference, symbol, optimize_for_tpu=False): function get_first_occurrence_indices (line 144) | def get_first_occurrence_indices(reference, symbol, optimize_for_tpu=Fal... function sequence_to_sentence (line 204) | def sequence_to_sentence(sequence, id_to_word): function batch_sequences_to_sentences (line 215) | def batch_sequences_to_sentences(sequences, id_to_word): function write_eval_results (line 219) | def write_eval_results(checkpoint_dir, all_gen_sentences, checkpoint_name, function maybe_pick_models_to_evaluate (line 240) | def maybe_pick_models_to_evaluate(checkpoint_dir): function get_embedding_path (line 265) | def get_embedding_path(data_dir, dataset): function make_partially_trainable_embeddings (line 270) | def make_partially_trainable_embeddings(vocab_file, embedding_source, FILE: side_effects_penalties/agent.py class EpsilonGreedyPolicy (line 27) | class EpsilonGreedyPolicy(object): method __init__ (line 30) | def __init__(self, value_function, actions): method get_action (line 46) | def get_action(self, epsilon, state): class QLearning (line 68) | class QLearning(object): method __init__ (line 71) | def __init__(self, actions, alpha=0.1, epsilon=0.1, q_initialisation=0.0, method begin_episode (line 97) | def begin_episode(self): method _timestep_to_state (line 102) | def _timestep_to_state(self, timestep): method step (line 105) | def step(self, timestep): method _calculate_reward (line 120) | def _calculate_reward(self, timestep, unused_state): method _update (line 125) | def _update(self, timestep, state): method end_episode (line 144) | def end_episode(self, timestep): method value_function (line 150) | def value_function(self): FILE: side_effects_penalties/agent_with_penalties.py class QLearningSE (line 26) | class QLearningSE(agent.QLearning): method __init__ (line 29) | def __init__( method begin_episode (line 111) | def begin_episode(self): method _calculate_reward (line 116) | def _calculate_reward(self, timestep, state): FILE: side_effects_penalties/file_loading.py function filename (line 25) | def filename(env_name, noops, dev_measure, dev_fun, baseline, beta, function load_files (line 40) | def load_files(baseline, dev_measure, dev_fun, value_discount, beta, env... FILE: side_effects_penalties/results_summary.py function beta_choice (line 63) | def beta_choice(baseline, dev_measure, dev_fun, value_discount, env_name, function penalty_label (line 85) | def penalty_label(dev_measure, dev_fun, value_discount): function make_summary_data_frame (line 99) | def make_summary_data_frame( function main (line 169) | def main(unused_argv): FILE: side_effects_penalties/run_experiment.py function run_experiment (line 74) | def run_experiment( function _smooth (line 122) | def _smooth(values, window=100): function add_smoothed_data (line 126) | def add_smoothed_data(df, groupby='seed', window=100): function main (line 134) | def main(unused_argv): FILE: side_effects_penalties/side_effects_penalty_test.py class SideEffectsTestCase (line 33) | class SideEffectsTestCase(parameterized.TestCase): method _timestep_to_state (line 35) | def _timestep_to_state(self, timestep): method _env_to_action_range (line 38) | def _env_to_action_range(self, env): class BaselineTestCase (line 44) | class BaselineTestCase(SideEffectsTestCase): method _create_baseline (line 46) | def _create_baseline(self, env_name): method _test_trajectory (line 54) | def _test_trajectory(self, actions, key): class StartBaselineTest (line 74) | class StartBaselineTest(BaselineTestCase): method testInit (line 77) | def testInit(self, env_name): method testTenNoops (line 82) | def testTenNoops(self, env_name): class InactionBaselineTest (line 87) | class InactionBaselineTest(BaselineTestCase): method testStaticEnvOneAction (line 94) | def testStaticEnvOneAction(self, action): method testStaticEnvRandomActions (line 98) | def testStaticEnvRandomActions(self): method testInactionPolicy (line 106) | def testInactionPolicy(self, env_name): class StepwiseBaselineTest (line 113) | class StepwiseBaselineTest(BaselineTestCase): method testStaticEnvRandomActions (line 115) | def testStaticEnvRandomActions(self): method testInactionPolicy (line 123) | def testInactionPolicy(self, env_name): method testInactionRollout (line 130) | def testInactionRollout(self, env_name): method testStaticRollouts (line 147) | def testStaticRollouts(self): method testConveyorRollouts (line 170) | def testConveyorRollouts(self, which_rollout, env_name): class NoDeviationTest (line 192) | class NoDeviationTest(SideEffectsTestCase): method _random_initial_transition (line 194) | def _random_initial_transition(self): method testNoDeviation (line 204) | def testNoDeviation(self): method testNoDeviationUpdate (line 209) | def testNoDeviationUpdate(self): class UnreachabilityTest (line 216) | class UnreachabilityTest(SideEffectsTestCase): method testUnreachabilityCycle (line 219) | def testUnreachabilityCycle(self, gamma): FILE: sketchy/dataset_example.py function main (line 28) | def main(argv): FILE: sketchy/metadata_schema.py class Episode (line 49) | class Episode(Base): class Tag (line 82) | class Tag(Base): class RewardSequence (line 95) | class RewardSequence(Base): class ArchiveFile (line 121) | class ArchiveFile(Base): FILE: sketchy/reward_example.py function main (line 30) | def main(argv): FILE: sketchy/sketchy.py function load_frames (line 20) | def load_frames(filenames, num_parallel_reads=1, num_map_threads=None): function _parse_example (line 77) | def _parse_example(example): function _decode_images (line 81) | def _decode_images(record): FILE: synthetic_returns/synthetic_returns.py class EpisodicMemory (line 27) | class EpisodicMemory(hk.RNNCore): method __init__ (line 30) | def __init__(self, memory_size, capacity, name="episodic_memory"): method __call__ (line 43) | def __call__(self, inputs, prev_state): method initial_state (line 68) | def initial_state(self, batch_size): class SyntheticReturnsCoreWrapper (line 87) | class SyntheticReturnsCoreWrapper(hk.RNNCore): method __init__ (line 90) | def __init__(self, core, memory_size, capacity, hidden_layers, alpha, ... method initial_state (line 133) | def initial_state(self, batch_size): method __call__ (line 139) | def __call__(self, inputs, prev_state): FILE: tandem_dqn/gym_atari.py function _game_id (line 36) | def _game_id(game, sticky_actions): function _register_atari_environments (line 40) | def _register_atari_environments(): class GymAtari (line 72) | class GymAtari(dm_env.Environment): method __init__ (line 75) | def __init__(self, game, sticky_actions, seed): method reset (line 80) | def reset(self) -> dm_env.TimeStep: method step (line 88) | def step(self, action: np.int32) -> dm_env.TimeStep: method observation_spec (line 119) | def observation_spec(self) -> Tuple[specs.Array, specs.Array]: method action_spec (line 124) | def action_spec(self) -> specs.DiscreteArray: method close (line 129) | def close(self): class RandomNoopsEnvironmentWrapper (line 133) | class RandomNoopsEnvironmentWrapper(dm_env.Environment): method __init__ (line 136) | def __init__(self, method reset (line 151) | def reset(self): method step (line 169) | def step(self, action): method _apply_random_noops (line 188) | def _apply_random_noops(self, initial_timestep): method observation_spec (line 204) | def observation_spec(self): method action_spec (line 207) | def action_spec(self): method reward_spec (line 210) | def reward_spec(self): method discount_spec (line 213) | def discount_spec(self): method close (line 216) | def close(self): FILE: tandem_dqn/losses.py function _mc_learning (line 34) | def _mc_learning( function _qr_loss (line 48) | def _qr_loss(q_tm1, q_t, q_target_t, transitions, rng_key): function _sarsa_loss (line 68) | def _sarsa_loss(q_tm1, q_t, transitions, rng_key): function _mc_loss (line 86) | def _mc_loss(q_tm1, transitions, rng_key): function _double_q_loss (line 95) | def _double_q_loss(q_tm1, q_t, q_target_t, transitions, rng_key): function _q_regression_loss (line 113) | def _q_regression_loss(q_tm1, q_tm1_target): function make_loss_fn (line 120) | def make_loss_fn(loss_type: str, active: bool) -> Callable[..., Any]: FILE: tandem_dqn/networks.py class QNetworkOutputs (line 30) | class QNetworkOutputs(typing.NamedTuple): class QRNetworkOutputs (line 34) | class QRNetworkOutputs(typing.NamedTuple): function _dqn_default_initializer (line 42) | def _dqn_default_initializer( function make_quantiles (line 65) | def make_quantiles(): function conv (line 70) | def conv( function linear (line 95) | def linear(num_outputs: int, with_bias=True, name=None) -> NetworkFn: function linear_with_shared_bias (line 112) | def linear_with_shared_bias(num_outputs: int, name=None) -> NetworkFn: function dqn_torso (line 128) | def dqn_torso() -> NetworkFn: function dqn_value_head (line 154) | def dqn_value_head(num_actions: int, shared_bias: bool = False) -> Netwo... function qr_atari_network (line 171) | def qr_atari_network(num_actions: int, quantiles: jnp.ndarray) -> Networ... function double_dqn_atari_network (line 192) | def double_dqn_atari_network(num_actions: int) -> NetworkFn: function make_network (line 206) | def make_network(network_type: str, num_actions: int) -> Network: FILE: tandem_dqn/processors.py function reset (line 43) | def reset(processor: Processor[[Any], Any]) -> None: function trailing_zero_pad (line 52) | def trailing_zero_pad( function none_to_zero_pad (line 66) | def none_to_zero_pad(values: List[Optional[NamedTuple]]) -> List[NamedTu... function named_tuple_sequence_stack (line 79) | def named_tuple_sequence_stack(values: Sequence[NamedTuple]) -> NamedTuple: class Deque (line 88) | class Deque: method __init__ (line 91) | def __init__(self, max_length: int, initial_values=None): method reset (line 95) | def reset(self) -> None: method __call__ (line 99) | def __call__(self, value: Any) -> collections.deque: class FixedPaddedBuffer (line 104) | class FixedPaddedBuffer: method __init__ (line 124) | def __init__(self, length: int, initial_index: int): method reset (line 131) | def reset(self) -> None: method __call__ (line 135) | def __call__(self, value: Any) -> Sequence[Any]: class ConditionallySubsample (line 145) | class ConditionallySubsample: method __init__ (line 148) | def __init__(self, condition: Processor[[Any], bool]): method reset (line 151) | def reset(self) -> None: method __call__ (line 154) | def __call__(self, value: Any) -> Optional[Any]: class TimestepBufferCondition (line 158) | class TimestepBufferCondition: method __init__ (line 169) | def __init__(self, period: int): method reset (line 174) | def reset(self): method __call__ (line 178) | def __call__(self, timesteps: Iterable[dm_env.TimeStep]) -> bool: class ApplyToNamedTupleField (line 215) | class ApplyToNamedTupleField: method __init__ (line 218) | def __init__(self, field: Text, *processors: Processor[[Any], Any]): method reset (line 222) | def reset(self) -> None: method __call__ (line 226) | def __call__(self, value: NamedTuple) -> NamedTuple: class Maybe (line 234) | class Maybe: method __init__ (line 237) | def __init__(self, processor: Processor[[Any], Any]): method reset (line 240) | def reset(self) -> None: method __call__ (line 243) | def __call__(self, value: Optional[Any]) -> Optional[Any]: class Sequential (line 250) | class Sequential: method __init__ (line 253) | def __init__(self, *processors: Processor[[Any], Any]): method reset (line 256) | def reset(self) -> None: method __call__ (line 260) | def __call__(self, value: Any) -> Any: class ZeroDiscountOnLifeLoss (line 266) | class ZeroDiscountOnLifeLoss: method __init__ (line 273) | def __init__(self): method reset (line 276) | def reset(self) -> None: method __call__ (line 279) | def __call__(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep: function reduce_step_type (line 288) | def reduce_step_type(step_types: Sequence[StepType], function aggregate_rewards (line 313) | def aggregate_rewards(rewards: Sequence[Optional[float]], function aggregate_discounts (line 328) | def aggregate_discounts(discounts: Sequence[Optional[float]], function rgb2y (line 350) | def rgb2y(array: np.ndarray) -> np.ndarray: function resize (line 358) | def resize(shape: Tuple[int, ...]) -> Processor[[np.ndarray], np.ndarray]: function select_rgb_observation (line 372) | def select_rgb_observation(timestep: dm_env.TimeStep) -> dm_env.TimeStep: function apply_additional_discount (line 377) | def apply_additional_discount( function clip_reward (line 383) | def clip_reward(bound: float) -> Processor[[Optional[float]], Optional[f... function show (line 392) | def show(prefix: Text) -> Processor[[Any], Any]: function atari (line 402) | def atari( class AtariEnvironmentWrapper (line 491) | class AtariEnvironmentWrapper(dm_env.Environment): method __init__ (line 500) | def __init__( method reset (line 541) | def reset(self) -> dm_env.TimeStep: method step (line 550) | def step(self, action: int) -> dm_env.TimeStep: method action_spec (line 567) | def action_spec(self) -> specs.DiscreteArray: method observation_spec (line 570) | def observation_spec(self) -> specs.Array: class AtariSimpleActionEnvironmentWrapper (line 577) | class AtariSimpleActionEnvironmentWrapper(dm_env.Environment): method __init__ (line 583) | def __init__(self, environment: dm_env.Environment): method reset (line 590) | def reset(self) -> dm_env.TimeStep: method step (line 593) | def step(self, action: int) -> dm_env.TimeStep: method action_spec (line 596) | def action_spec(self) -> specs.DiscreteArray: method observation_spec (line 603) | def observation_spec(self) -> specs.Array: