SYMBOL INDEX (698 symbols across 81 files) FILE: rigl/cifar_resnet/data_helper.py function pad_input (line 29) | def pad_input(x, crop_dim=4): function preprocess_train (line 47) | def preprocess_train(x, width, height): function input_fn (line 64) | def input_fn(params): FILE: rigl/cifar_resnet/data_helper_test.py class DataHelperTest (line 39) | class DataHelperTest(tf.test.TestCase, parameterized.TestCase): method get_next (line 41) | def get_next(self): method testInputPipeline (line 55) | def testInputPipeline(self): method testTrainingStep (line 78) | def testTrainingStep(self, training_method): FILE: rigl/cifar_resnet/resnet_model.py class WideResNetModel (line 33) | class WideResNetModel(object): method __init__ (line 36) | def __init__(self, method build (line 70) | def build(self, inputs, depth, width, num_classes, name=None): method _batch_norm (line 127) | def _batch_norm(self, net, name=None): method _dense (line 150) | def _dense(self, net, num_units, name=None, sparsity_technique='baseli... method _conv (line 158) | def _conv(self, method _residual_block (line 183) | def _residual_block(self, net, name, output_size, subsample, blocks): FILE: rigl/cifar_resnet/resnet_train_eval.py function create_eval_metrics (line 141) | def create_eval_metrics(labels, logits): function train_fn (line 171) | def train_fn(training_method, global_step, total_loss, train_dir, accuracy, function build_model (line 299) | def build_model(mode, function wide_resnet_w_pruning (line 370) | def wide_resnet_w_pruning(features, labels, mode, params): function main (line 474) | def main(argv): FILE: rigl/experimental/jax/datasets/cifar10.py class CIFAR10Dataset (line 27) | class CIFAR10Dataset(dataset_base.ImageDataset): method __init__ (line 39) | def __init__(self, method preprocess (line 66) | def preprocess( FILE: rigl/experimental/jax/datasets/cifar10_test.py class CIFAR10DatasetTest (line 23) | class CIFAR10DatasetTest(absltest.TestCase): method setUp (line 26) | def setUp(self): method test_create_dataset (line 38) | def test_create_dataset(self): method test_train_image_dims_content (line 42) | def test_train_image_dims_content(self): method test_test_image_dims_content (line 68) | def test_test_image_dims_content(self): method test_train_data_length (line 94) | def test_train_data_length(self): method test_test_data_length (line 102) | def test_test_data_length(self): method test_dataset_nonevenly_divisible_batch_size (line 110) | def test_dataset_nonevenly_divisible_batch_size(self): FILE: rigl/experimental/jax/datasets/dataset_base.py class Dataset (line 29) | class Dataset(metaclass=abc.ABCMeta): method __init__ (line 47) | def __init__(self, method _dataset_dir (line 95) | def _dataset_dir(self): method get_train (line 99) | def get_train(self): method get_train_len (line 103) | def get_train_len(self): method get_test (line 107) | def get_test(self): method get_test_len (line 111) | def get_test_len(self): method preprocess (line 115) | def preprocess( method augment (line 130) | def augment( class ImageDataset (line 147) | class ImageDataset(Dataset): method preprocess (line 152) | def preprocess( FILE: rigl/experimental/jax/datasets/dataset_base_test.py class DummyDataset (line 22) | class DummyDataset(dataset_base.ImageDataset): method __init__ (line 30) | def __init__(self, class DummyDatasetTest (line 50) | class DummyDatasetTest(absltest.TestCase): method setUp (line 53) | def setUp(self): method test_create_dataset (line 64) | def test_create_dataset(self): method test_train_image_dims_content (line 68) | def test_train_image_dims_content(self): method test_test_image_dims_content (line 86) | def test_test_image_dims_content(self): method test_train_data_length (line 104) | def test_train_data_length(self): method test_test_data_length (line 112) | def test_test_data_length(self): FILE: rigl/experimental/jax/datasets/dataset_factory.py function create_dataset (line 38) | def create_dataset(name, *args, **kwargs): FILE: rigl/experimental/jax/datasets/dataset_factory_test.py class DatasetCommonTest (line 24) | class DatasetCommonTest(parameterized.TestCase): method setUp (line 26) | def setUp(self): method _create_dataset (line 32) | def _create_dataset(self, dataset_name): method test_dataset_supported (line 40) | def test_dataset_supported(self): method test_dataset_train_iterators (line 48) | def test_dataset_train_iterators(self, dataset_name): method test_dataset_test_iterators (line 76) | def test_dataset_test_iterators(self, dataset_name): method test_dataset_unsupported (line 103) | def test_dataset_unsupported(self): FILE: rigl/experimental/jax/datasets/mnist.py class MNISTDataset (line 27) | class MNISTDataset(dataset_base.ImageDataset): method __init__ (line 35) | def __init__(self, method preprocess (line 54) | def preprocess( FILE: rigl/experimental/jax/datasets/mnist_test.py class MNISTDatasetTest (line 23) | class MNISTDatasetTest(absltest.TestCase): method setUp (line 26) | def setUp(self): method test_create_dataset (line 38) | def test_create_dataset(self): method test_train_image_dims_content (line 42) | def test_train_image_dims_content(self): method test_test_image_dims_content (line 67) | def test_test_image_dims_content(self): method test_train_data_length (line 93) | def test_train_data_length(self): method test_test_data_length (line 101) | def test_test_data_length(self): FILE: rigl/experimental/jax/fixed_param.py function main (line 195) | def main(argv: List[str]): FILE: rigl/experimental/jax/fixed_param_test.py class FixedParamTest (line 27) | class FixedParamTest(absltest.TestCase): method test_run (line 29) | def test_run(self): FILE: rigl/experimental/jax/models/cifar10_cnn.py class CIFAR10CNN (line 32) | class CIFAR10CNN(flax.deprecated.nn.Module): method apply (line 35) | def apply(self, FILE: rigl/experimental/jax/models/cifar10_cnn_test.py class CIFAR10CNNTest (line 25) | class CIFAR10CNNTest(absltest.TestCase): method setUp (line 28) | def setUp(self): method test_output_shapes (line 36) | def test_output_shapes(self): method test_invalid_spatial_dimensions (line 48) | def test_invalid_spatial_dimensions(self): method test_invalid_masks_depth (line 56) | def test_invalid_masks_depth(self): FILE: rigl/experimental/jax/models/mnist_cnn.py class MNISTCNN (line 32) | class MNISTCNN(flax.deprecated.nn.Module): method apply (line 35) | def apply(self, FILE: rigl/experimental/jax/models/mnist_cnn_test.py class MNISTCNNTest (line 25) | class MNISTCNNTest(absltest.TestCase): method setUp (line 28) | def setUp(self): method test_output_shapes (line 36) | def test_output_shapes(self): method test_invalid_depth (line 48) | def test_invalid_depth(self): FILE: rigl/experimental/jax/models/mnist_fc.py function feature_dim_for_param (line 32) | def feature_dim_for_param(input_len, class MNISTFC (line 81) | class MNISTFC(flax.deprecated.nn.Module): method apply (line 84) | def apply(self, FILE: rigl/experimental/jax/models/mnist_fc_test.py class MNISTFCTest (line 31) | class MNISTFCTest(parameterized.TestCase): method setUp (line 34) | def setUp(self): method test_output_shapes (line 44) | def test_output_shapes(self): method test_invalid_masks_depth (line 56) | def test_invalid_masks_depth(self): method _create_model (line 73) | def _create_model(self, features): method test_feature_dim_for_param_depth (line 83) | def test_feature_dim_for_param_depth(self, depth): FILE: rigl/experimental/jax/models/model_factory.py function create_model (line 37) | def create_model( function update_model (line 66) | def update_model(model, FILE: rigl/experimental/jax/models/model_factory_test.py class ModelCommonTest (line 26) | class ModelCommonTest(parameterized.TestCase): method setUp (line 29) | def setUp(self): method _create_model (line 35) | def _create_model(self, model_name): method test_model_supported (line 42) | def test_model_supported(self, model_name): method test_model_unsupported (line 52) | def test_model_unsupported(self): FILE: rigl/experimental/jax/prune.py function main (line 166) | def main(argv: List[str]): FILE: rigl/experimental/jax/prune_test.py class PruneTest (line 26) | class PruneTest(absltest.TestCase): method test_prune_fixed_schedule (line 28) | def test_prune_fixed_schedule(self): method test_prune_global_pruning_schedule (line 45) | def test_prune_global_pruning_schedule(self): method test_prune_local_pruning_schedule (line 62) | def test_prune_local_pruning_schedule(self): FILE: rigl/experimental/jax/pruning/init.py function sparse_init (line 25) | def sparse_init( FILE: rigl/experimental/jax/pruning/init_test.py class MaskedDense (line 28) | class MaskedDense(flax.deprecated.nn.Module): method apply (line 33) | def apply(self, class MaskedDenseSparseInit (line 47) | class MaskedDenseSparseInit(flax.deprecated.nn.Module): method apply (line 52) | def apply(self, class MaskedCNN (line 70) | class MaskedCNN(flax.deprecated.nn.Module): method apply (line 75) | def apply(self, class MaskedCNNSparseInit (line 89) | class MaskedCNNSparseInit(flax.deprecated.nn.Module): method apply (line 94) | def apply(self, class InitTest (line 112) | class InitTest(absltest.TestCase): method setUp (line 114) | def setUp(self): method test_init_kaiming_sparse_normal_output (line 121) | def test_init_kaiming_sparse_normal_output(self): method test_dense_no_mask (line 140) | def test_dense_no_mask(self): method test_dense_sparse_init_kaiming (line 157) | def test_dense_sparse_init_kaiming(self): method test_cnn_sparse_init_kaiming (line 193) | def test_cnn_sparse_init_kaiming(self): FILE: rigl/experimental/jax/pruning/mask_factory.py function create_mask (line 47) | def create_mask(mask_type, base_model, FILE: rigl/experimental/jax/pruning/mask_factory_test.py class MaskedDense (line 29) | class MaskedDense(flax.deprecated.nn.Module): method apply (line 34) | def apply(self, class MaskFactoryTest (line 46) | class MaskFactoryTest(parameterized.TestCase): method setUp (line 48) | def setUp(self): method _create_mask (line 60) | def _create_mask(self, mask_type): method test_mask_supported (line 66) | def test_mask_supported(self, mask_type): method test_mask_unsupported (line 73) | def test_mask_unsupported(self): FILE: rigl/experimental/jax/pruning/masked.py class MaskedModule (line 55) | class MaskedModule(flax.deprecated.nn.Module): method apply (line 66) | def apply(self, function masked (line 115) | def masked(module, mask): function generate_model_masks (line 120) | def generate_model_masks( function _filter_param (line 158) | def _filter_param(param_names, function mask_map (line 182) | def mask_map(model, function iterate_mask (line 234) | def iterate_mask( function shuffled_mask (line 257) | def shuffled_mask(model, rng, function random_mask (line 292) | def random_mask(model, function simple_mask (line 326) | def simple_mask(model, function symmetric_mask (line 348) | def symmetric_mask(model, class _PerNeuronShuffle (line 379) | class _PerNeuronShuffle: method __init__ (line 382) | def __init__(self, init_rng, sparsity): method __call__ (line 393) | def __call__(self, param_name, param): function shuffled_neuron_mask (line 418) | def shuffled_neuron_mask(model, function _fill_diagonal_wrap (line 452) | def _fill_diagonal_wrap(shape, function _random_neuron_mask (line 511) | def _random_neuron_mask(neuron_length, class _PerNeuronNoInputAblationShuffle (line 535) | class _PerNeuronNoInputAblationShuffle: method __init__ (line 538) | def __init__(self, init_rng, sparsity): method _get_rng (line 549) | def _get_rng(self): method __call__ (line 554) | def __call__(self, param_name, param): function shuffled_neuron_no_input_ablation_mask (line 597) | def shuffled_neuron_no_input_ablation_mask(model, function propagate_masks (line 637) | def propagate_masks( function mask_layer_sparsity (line 710) | def mask_layer_sparsity(mask_layer): function mask_sparsity (line 733) | def mask_sparsity( FILE: rigl/experimental/jax/pruning/masked_test.py class Dense (line 29) | class Dense(flax.deprecated.nn.Module): method apply (line 34) | def apply(self, inputs): class MaskedDense (line 39) | class MaskedDense(flax.deprecated.nn.Module): method apply (line 44) | def apply(self, class DenseTwoLayer (line 56) | class DenseTwoLayer(flax.deprecated.nn.Module): method apply (line 61) | def apply(self, inputs): class MaskedTwoLayerDense (line 67) | class MaskedTwoLayerDense(flax.deprecated.nn.Module): method apply (line 72) | def apply(self, class MaskedConv (line 89) | class MaskedConv(flax.deprecated.nn.Module): method apply (line 94) | def apply(self, class MaskedTwoLayerConv (line 105) | class MaskedTwoLayerConv(flax.deprecated.nn.Module): method apply (line 110) | def apply(self, class MaskedThreeLayerConvDense (line 127) | class MaskedThreeLayerConvDense(flax.deprecated.nn.Module): method apply (line 132) | def apply(self, class MaskedTwoLayerMixedConvDense (line 155) | class MaskedTwoLayerMixedConvDense(flax.deprecated.nn.Module): method apply (line 160) | def apply(self, class MaskedTest (line 176) | class MaskedTest(parameterized.TestCase): method setUp (line 179) | def setUp(self): method test_fully_masked_layer (line 241) | def test_fully_masked_layer(self): method test_no_mask_masked_layer (line 253) | def test_no_mask_masked_layer(self): method test_empty_mask_masked_layer (line 263) | def test_empty_mask_masked_layer(self): method test_invalid_mask (line 275) | def test_invalid_mask(self): method test_shuffled_mask_invalid_model (line 287) | def test_shuffled_mask_invalid_model(self): method test_shuffled_mask_invalid_sparsity (line 294) | def test_shuffled_mask_invalid_sparsity(self): method test_shuffled_mask_sparsity_full (line 307) | def test_shuffled_mask_sparsity_full(self): method test_shuffled_mask_sparsity_empty (line 328) | def test_shuffled_mask_sparsity_empty(self): method test_shuffled_mask_sparsity_half_full (line 349) | def test_shuffled_mask_sparsity_half_full(self): method test_shuffled_mask_sparsity_full_twolayer (line 359) | def test_shuffled_mask_sparsity_full_twolayer(self): method test_shuffled_mask_sparsity_empty_twolayer (line 390) | def test_shuffled_mask_sparsity_empty_twolayer(self): method test_random_invalid_model (line 416) | def test_random_invalid_model(self): method test_random_invalid_sparsity (line 423) | def test_random_invalid_sparsity(self): method test_random_mask_sparsity_full (line 436) | def test_random_mask_sparsity_full(self): method test_random_mask_sparsity_empty (line 451) | def test_random_mask_sparsity_empty(self): method test_random_mask_sparsity_half_full (line 468) | def test_random_mask_sparsity_half_full(self): method test_simple_mask_one_layer (line 480) | def test_simple_mask_one_layer(self): method test_simple_mask_two_layer (line 500) | def test_simple_mask_two_layer(self): method test_shuffled_mask_neuron_mask_sparsity_empty (line 528) | def test_shuffled_mask_neuron_mask_sparsity_empty(self): method test_shuffled_mask_neuron_mask_sparsity_half_full (line 549) | def test_shuffled_mask_neuron_mask_sparsity_half_full(self): method test_symmetric_mask_sparsity_empty (line 566) | def test_symmetric_mask_sparsity_empty(self): method test_symmetric_mask_sparsity_half_full (line 587) | def test_symmetric_mask_sparsity_half_full(self): method test_propagate_masks_ablated_neurons_one_layer (line 604) | def test_propagate_masks_ablated_neurons_one_layer(self): method test_propagate_masks_ablated_neurons_two_layers (line 625) | def test_propagate_masks_ablated_neurons_two_layers(self): method test_propagate_masks_ablated_neurons_two_layers_nonmasked (line 653) | def test_propagate_masks_ablated_neurons_two_layers_nonmasked(self): method test_propagate_masks_ablated_neurons_one_conv_layer (line 683) | def test_propagate_masks_ablated_neurons_one_conv_layer(self): method test_propagate_masks_ablated_neurons_two_conv_layers (line 704) | def test_propagate_masks_ablated_neurons_two_conv_layers(self): method test_propagate_masks_ablated_neurons_three_conv_fc_layers (line 734) | def test_propagate_masks_ablated_neurons_three_conv_fc_layers(self): method test_propagate_masks_ablated_neurons_mixed_conv_dense_layers (line 774) | def test_propagate_masks_ablated_neurons_mixed_conv_dense_layers(self): method test_mask_layer_sparsity_zero_mask (line 802) | def test_mask_layer_sparsity_zero_mask(self): method test_mask_layer_sparsity_half_mask (line 809) | def test_mask_layer_sparsity_half_mask(self): method test_mask_layer_sparsity_ones_mask (line 816) | def test_mask_layer_sparsity_ones_mask(self): method test_mask_sparsity_zero_mask (line 823) | def test_mask_sparsity_zero_mask(self): method test_mask_sparsity_ones_mask (line 829) | def test_mask_sparsity_ones_mask(self): method test_mask_sparsity_mixed_mask (line 835) | def test_mask_sparsity_mixed_mask(self): method test_generate_model_masks_depth_only (line 873) | def test_generate_model_masks_depth_only(self, depth): method test_generate_model_masks_indices (line 890) | def test_generate_model_masks_indices(self, depth, indices): method test_generate_model_masks_existing_mask (line 909) | def test_generate_model_masks_existing_mask(self, depth, existing_mask, method test_generate_model_masks_invalid_depth_zero (line 931) | def test_generate_model_masks_invalid_depth_zero(self): method test_generate_model_masks_invalid_index_toohigh (line 936) | def test_generate_model_masks_invalid_index_toohigh(self): method test_generate_model_masks_invalid_index_negative (line 941) | def test_generate_model_masks_invalid_index_negative(self): method test_shuffled_neuron_no_input_ablation_mask_invalid_model (line 946) | def test_shuffled_neuron_no_input_ablation_mask_invalid_model(self): method test_shuffled_neuron_no_input_ablation_mask_invalid_sparsity (line 954) | def test_shuffled_neuron_no_input_ablation_mask_invalid_sparsity(self): method test_shuffled_neuron_no_input_ablation_mask_sparsity_full (line 969) | def test_shuffled_neuron_no_input_ablation_mask_sparsity_full(self): method test_shuffled_neuron_no_input_ablation_mask_sparsity_empty (line 994) | def test_shuffled_neuron_no_input_ablation_mask_sparsity_empty(self): method test_shuffled_neuron_no_input_ablation_mask_sparsity_half_full (line 1016) | def test_shuffled_neuron_no_input_ablation_mask_sparsity_half_full(self): method test_shuffled_neuron_no_input_ablation_mask_sparsity_quarter_full (line 1033) | def test_shuffled_neuron_no_input_ablation_mask_sparsity_quarter_full(... method test_shuffled_neuron_no_input_ablation_mask_sparsity_full_twolayer (line 1050) | def test_shuffled_neuron_no_input_ablation_mask_sparsity_full_twolayer... method test_shuffled_neuron_no_input_ablation_mask_sparsity_empty_twolayer (line 1092) | def test_shuffled_neuron_no_input_ablation_mask_sparsity_empty_twolaye... FILE: rigl/experimental/jax/pruning/pruning.py function weight_magnitude (line 26) | def weight_magnitude(weights): function prune (line 31) | def prune( FILE: rigl/experimental/jax/pruning/pruning_test.py class MaskedDense (line 28) | class MaskedDense(flax.deprecated.nn.Module): method apply (line 33) | def apply(self, class MaskedTwoLayerDense (line 45) | class MaskedTwoLayerDense(flax.deprecated.nn.Module): method apply (line 50) | def apply(self, class MaskedConv (line 67) | class MaskedConv(flax.deprecated.nn.Module): method apply (line 72) | def apply(self, class MaskedTwoLayerConv (line 83) | class MaskedTwoLayerConv(flax.deprecated.nn.Module): method apply (line 88) | def apply(self, class PruningTest (line 105) | class PruningTest(absltest.TestCase): method setUp (line 108) | def setUp(self): method test_prune_single_layer_dense_no_mask (line 134) | def test_prune_single_layer_dense_no_mask(self): method test_prune_single_layer_local_pruning (line 145) | def test_prune_single_layer_local_pruning(self): method test_prune_single_layer_dense_with_mask (line 158) | def test_prune_single_layer_dense_with_mask(self): method test_prune_two_layers_dense_no_mask (line 172) | def test_prune_two_layers_dense_no_mask(self): method test_prune_two_layer_local_pruning_rate (line 186) | def test_prune_two_layer_local_pruning_rate(self): method test_prune_one_layer_conv_no_mask (line 206) | def test_prune_one_layer_conv_no_mask(self): method test_prune_one_layer_conv_with_mask (line 217) | def test_prune_one_layer_conv_with_mask(self): method test_prune_two_layer_conv_no_mask (line 231) | def test_prune_two_layer_conv_no_mask(self): FILE: rigl/experimental/jax/pruning/symmetry.py function count_permutations_mask_layer (line 30) | def count_permutations_mask_layer( function count_permutations_mask (line 125) | def count_permutations_mask(mask): function get_mask_stats (line 161) | def get_mask_stats(mask): FILE: rigl/experimental/jax/pruning/symmetry_test.py class MaskedDense (line 33) | class MaskedDense(flax.deprecated.nn.Module): method apply (line 42) | def apply(self, class MaskedConv (line 53) | class MaskedConv(flax.deprecated.nn.Module): method apply (line 62) | def apply(self, class MaskedTwoLayerDense (line 73) | class MaskedTwoLayerDense(flax.deprecated.nn.Module): method apply (line 82) | def apply(self, class SymmetryTest (line 99) | class SymmetryTest(parameterized.TestCase): method setUp (line 102) | def setUp(self): method test_count_permutations_layer_mask_full (line 123) | def test_count_permutations_layer_mask_full(self): method test_count_permutations_layer_mask_empty (line 146) | def test_count_permutations_layer_mask_empty(self): method test_count_permutations_conv_layer_mask_full (line 168) | def test_count_permutations_conv_layer_mask_full(self): method test_count_permutations_conv_layer_mask_empty (line 191) | def test_count_permutations_conv_layer_mask_empty(self): method test_count_permutations_layer_mask_known_perm (line 213) | def test_count_permutations_layer_mask_known_perm(self): method test_count_permutations_layer_mask_known_perm_zeros (line 247) | def test_count_permutations_layer_mask_known_perm_zeros(self): method test_count_permutations_shuffled_full_mask (line 279) | def test_count_permutations_shuffled_full_mask(self): method test_count_permutations_shuffled_empty_mask (line 297) | def test_count_permutations_shuffled_empty_mask(self): method test_count_permutations_mask_layer_twolayer_known_symmetric (line 316) | def test_count_permutations_mask_layer_twolayer_known_symmetric(self): method test_count_permutations_mask_layer_twolayer (line 396) | def test_count_permutations_mask_layer_twolayer(self, mask, unique, method test_count_permutations_mask_full (line 414) | def test_count_permutations_mask_full(self): method test_count_permutations_mask_bn_layer_full (line 433) | def test_count_permutations_mask_bn_layer_full(self): method test_count_permutations_mask_empty (line 452) | def test_count_permutations_mask_empty(self): method test_count_permutations_mask_twolayer_full (line 470) | def test_count_permutations_mask_twolayer_full(self): method test_count_permutations_mask_twolayers_empty (line 494) | def test_count_permutations_mask_twolayers_empty(self): method test_count_permutations_mask_twolayer_known_symmetric (line 515) | def test_count_permutations_mask_twolayer_known_symmetric(self): method test_count_permutations_mask_twolayer_known_non_symmetric (line 542) | def test_count_permutations_mask_twolayer_known_non_symmetric(self): method test_get_mask_stats_keys_values (line 569) | def test_get_mask_stats_keys_values(self): FILE: rigl/experimental/jax/random_mask.py function main (line 177) | def main(argv: List[str]): FILE: rigl/experimental/jax/random_mask_test.py class RandomMaskTest (line 26) | class RandomMaskTest(absltest.TestCase): method test_run_fc (line 28) | def test_run_fc(self): method test_run_conv (line 46) | def test_run_conv(self): method test_run_random (line 64) | def test_run_random(self): method test_run_per_neuron (line 82) | def test_run_per_neuron(self): method test_run_symmetric (line 100) | def test_run_symmetric(self): FILE: rigl/experimental/jax/shuffled_mask.py function main (line 178) | def main(argv: List[str]): FILE: rigl/experimental/jax/shuffled_mask_test.py class ShuffledMaskTest (line 26) | class ShuffledMaskTest(absltest.TestCase): method test_run_fc (line 28) | def test_run_fc(self): method test_run_conv (line 45) | def test_run_conv(self): method test_run_random (line 62) | def test_run_random(self): method test_run_per_neuron (line 79) | def test_run_per_neuron(self): method test_run_symmetric (line 96) | def test_run_symmetric(self): FILE: rigl/experimental/jax/train.py function run_training (line 86) | def run_training(): function main (line 170) | def main(argv): FILE: rigl/experimental/jax/train_test.py class TrainTest (line 26) | class TrainTest(absltest.TestCase): method test_train_driver_run (line 28) | def test_train_driver_run(self): FILE: rigl/experimental/jax/training/training.py function _shard_batch (line 51) | def _shard_batch(xs): function train_step (line 61) | def train_step( class Trainer (line 110) | class Trainer: method __init__ (line 118) | def __init__( FILE: rigl/experimental/jax/training/training_test.py class TrainingTest (line 33) | class TrainingTest(absltest.TestCase): method setUp (line 36) | def setUp(self): method test_train_one_step (line 76) | def test_train_one_step(self): method test_train_one_epoch (line 104) | def test_train_one_epoch(self): method test_train_one_epoch_tensorboard (line 140) | def test_train_one_epoch_tensorboard(self): method test_train_one_epoch_pruning_global_schedule (line 180) | def test_train_one_epoch_pruning_global_schedule(self): method test_train_one_epoch_pruning_local_schedule (line 217) | def test_train_one_epoch_pruning_local_schedule(self): method test_eval_batch (line 254) | def test_eval_batch(self): method test_eval (line 275) | def test_eval(self): FILE: rigl/experimental/jax/utils/utils.py function cross_entropy_loss (line 34) | def cross_entropy_loss(log_softmax_logits, function compute_metrics (line 48) | def compute_metrics(logits, function _np_converter (line 76) | def _np_converter(obj): function dump_dict_json (line 86) | def dump_dict_json(data_dict, path): function count_param (line 100) | def count_param(model, function cosine_similarity (line 120) | def cosine_similarity(a, b): function param_as_array (line 127) | def param_as_array(params): function cosine_similarity_model (line 133) | def cosine_similarity_model(initial_model, function vector_difference_norm_model (line 142) | def vector_difference_norm_model(initial_model, function pairwise_longest (line 154) | def pairwise_longest(iterable): FILE: rigl/experimental/jax/utils/utils_test.py class TwoLayerDense (line 34) | class TwoLayerDense(flax.deprecated.nn.Module): method apply (line 39) | def apply(self, inputs): class UtilsTest (line 47) | class UtilsTest(parameterized.TestCase): method setUp (line 50) | def setUp(self): method _create_logits_labels (line 68) | def _create_logits_labels(self, correct): method test_compute_metrics_correct (line 93) | def test_compute_metrics_correct(self): method test_compute_metrics_incorrect (line 122) | def test_compute_metrics_incorrect(self): method test_compute_metrics_equal_logits (line 151) | def test_compute_metrics_equal_logits(self): method test_dump_dict_json (line 180) | def test_dump_dict_json(self): method test_count_param_two_layer_dense (line 200) | def test_count_param_two_layer_dense(self): method test_count_invalid_param (line 209) | def test_count_invalid_param(self): method test_model_param_as_array (line 215) | def test_model_param_as_array(self): method test_cosine_similarity_random (line 228) | def test_cosine_similarity_random(self): method test_cosine_similarity_same (line 238) | def test_cosine_similarity_same(self): method test_cosine_similarity_same_model (line 247) | def test_cosine_similarity_same_model(self): method test_vector_difference_norm_diff_model (line 253) | def test_vector_difference_norm_diff_model(self): method test_vector_difference_norm_same_model (line 260) | def test_vector_difference_norm_same_model(self): method test_pairwise_longest_list_iterator (line 277) | def test_pairwise_longest_list_iterator( FILE: rigl/imagenet_resnet/imagenet_train_eval.py function set_lr_schedule (line 280) | def set_lr_schedule(): function set_custom_sparsity_map (line 308) | def set_custom_sparsity_map(): function lr_schedule (line 317) | def lr_schedule(current_epoch): function train_function (line 333) | def train_function(training_method, loss, cross_loss, reg_loss, output_dir, function resnet_model_fn_w_pruning (line 478) | def resnet_model_fn_w_pruning(features, labels, mode, params): class ExportModelHook (line 668) | class ExportModelHook(tf.train.SessionRunHook): method __init__ (line 671) | def __init__(self, classifier, export_dir): method begin (line 683) | def begin(self): method after_run (line 686) | def after_run(self, run_context, run_values): function main (line 703) | def main(argv): FILE: rigl/imagenet_resnet/mobilenetv1_model.py function _make_divisible (line 33) | def _make_divisible(v, divisor=8, min_value=None): function depthwise_conv2d_fixed_padding (line 43) | def depthwise_conv2d_fixed_padding(inputs, function conv2d_fixed_padding (line 95) | def conv2d_fixed_padding(inputs, function mbv1_block_ (line 156) | def mbv1_block_(inputs, function mobilenet_v1_generator (line 223) | def mobilenet_v1_generator(num_classes=1000, function mobilenet_v1 (line 345) | def mobilenet_v1(num_classes, FILE: rigl/imagenet_resnet/mobilenetv2_model.py function _make_divisible (line 33) | def _make_divisible(v, divisor=8, min_value=None): function depthwise_conv2d_fixed_padding (line 43) | def depthwise_conv2d_fixed_padding(inputs, function conv2d_fixed_padding (line 95) | def conv2d_fixed_padding(inputs, function inverted_res_block_ (line 156) | def inverted_res_block_(inputs, function mobilenet_v2_generator (line 255) | def mobilenet_v2_generator(num_classes=1000, function mobilenet_v2 (line 401) | def mobilenet_v2(num_classes, FILE: rigl/imagenet_resnet/pruning_layers.py function get_model_variables (line 29) | def get_model_variables(getter, function variable_getter (line 62) | def variable_getter(rename=None): function sparse_conv2d (line 72) | def sparse_conv2d(x, function sparse_fully_connected (line 175) | def sparse_fully_connected(x, FILE: rigl/imagenet_resnet/resnet_model.py function batch_norm_relu (line 41) | def batch_norm_relu(inputs, is_training, relu=True, init_zero=False, function fixed_padding (line 83) | def fixed_padding(inputs, kernel_size, data_format='channels_first'): class RandomSparseInitializer (line 111) | class RandomSparseInitializer(init_ops.Initializer): method __init__ (line 114) | def __init__(self, sparsity, seed=None, dtype=tf.float32): method __call__ (line 123) | def __call__(self, *args, **kwargs): method get_config (line 131) | def get_config(self): class SparseConvVarianceScalingInitializer (line 139) | class SparseConvVarianceScalingInitializer(init_ops.Initializer): method __init__ (line 142) | def __init__(self, sparsity, seed=None, dtype=tf.float32): method __call__ (line 149) | def __call__(self, shape, dtype=None, partition_info=None): method get_config (line 168) | def get_config(self): class SparseFCVarianceScalingInitializer (line 175) | class SparseFCVarianceScalingInitializer(init_ops.Initializer): method __init__ (line 178) | def __init__(self, sparsity, seed=None, dtype=tf.float32): method __call__ (line 185) | def __call__(self, shape, dtype=None, partition_info=None): method get_config (line 207) | def get_config(self): function _pick_initializer (line 214) | def _pick_initializer(kernel_initializer, init_method, pruning_method, function conv2d_fixed_padding (line 234) | def conv2d_fixed_padding(inputs, function residual_block_ (line 306) | def residual_block_(inputs, function bottleneck_block_ (line 396) | def bottleneck_block_(inputs, function block_group (line 504) | def block_group(inputs, function resnet_v1_generator (line 577) | def resnet_v1_generator(block_fn, function resnet_v1_ (line 734) | def resnet_v1_(resnet_depth, FILE: rigl/imagenet_resnet/train_test.py class DataInputTest (line 36) | class DataInputTest(tf.test.TestCase, parameterized.TestCase): method _retrieve_data (line 38) | def _retrieve_data(self, is_training, data_dir): method testTrainingPipeline (line 50) | def testTrainingPipeline(self, training_method): FILE: rigl/imagenet_resnet/utils.py function format_tensors (line 28) | def format_tensors(*dicts): function host_call_fn (line 59) | def host_call_fn(model_dir, **kwargs): function mask_summaries (line 83) | def mask_summaries(masks, with_img=False): function initialize_parameters_from_ckpt (line 93) | def initialize_parameters_from_ckpt(ckpt_path, model_dir, param_suffixes): FILE: rigl/imagenet_resnet/vgg.py function vgg_net (line 64) | def vgg_net(inputs, function vgg (line 203) | def vgg(vgg_type, FILE: rigl/mnist/mnist_train_eval.py function mnist_network_fc (line 112) | def mnist_network_fc(input_batch, reuse=False, model_pruning=False): function get_compressed_fc (line 165) | def get_compressed_fc(masks): function main (line 192) | def main(unused_args): FILE: rigl/mnist/visualize_mask_records.py function main (line 62) | def main(unused_args): FILE: rigl/rigl_tf2/init_utils.py function unit_scaled_init (line 23) | def unit_scaled_init(mask, method='fanavg_uniform', scale=1.0): function layer_scaled_init (line 70) | def layer_scaled_init(mask, method='fanavg_uniform', scale=1.0): function unit_scaled_init_tf1 (line 81) | def unit_scaled_init_tf1(mask, FILE: rigl/rigl_tf2/interpolate.py function test_model (line 61) | def test_model(model, d_test, batch_size=1000): function interpolate (line 80) | def interpolate(model_start, model_end, model_inter, d_set, function main (line 97) | def main(unused_argv): FILE: rigl/rigl_tf2/mask_updaters.py function get_all_layers (line 22) | def get_all_layers(model, filter_fn=lambda _: True): function is_pruned (line 33) | def is_pruned(layer): class MaskUpdater (line 37) | class MaskUpdater(object): method __init__ (line 49) | def __init__(self, model, optimizer, use_stateless=True, method prune_masks (line 58) | def prune_masks(self, prune_fraction): method update_masks (line 67) | def update_masks(self, drop_fraction): method get_all_pruning_layers (line 76) | def get_all_pruning_layers(self): method get_vars_and_masks (line 83) | def get_vars_and_masks(self): method get_drop_scores (line 93) | def get_drop_scores(self, all_vars, all_masks): method get_grow_scores (line 96) | def get_grow_scores(self, all_vars, all_masks): method generic_mask_update (line 99) | def generic_mask_update(self, mask, var, score_drop, score_grow, method reset_momentum (line 156) | def reset_momentum(self, var, new_connections): method _random_uniform (line 164) | def _random_uniform(self, *args, **kwargs): method _random_normal (line 173) | def _random_normal(self, *args, **kwargs): method set_validation_data (line 182) | def set_validation_data(self, val_x, val_y): method _get_gradients (line 185) | def _get_gradients(self, all_vars): class SET (line 195) | class SET(MaskUpdater): method get_drop_scores (line 204) | def get_drop_scores(self, all_vars, all_masks, noise_std=0): method get_grow_scores (line 214) | def get_grow_scores(self, all_vars, all_masks): class RigL (line 219) | class RigL(MaskUpdater): method get_drop_scores (line 225) | def get_drop_scores(self, all_vars, all_masks, noise_std=0): method get_grow_scores (line 235) | def get_grow_scores(self, all_vars, all_masks): class RigLInverted (line 239) | class RigLInverted(RigL): method get_grow_scores (line 245) | def get_grow_scores(self, all_vars, all_masks): class UpdateSchedule (line 251) | class UpdateSchedule(object): method __init__ (line 260) | def __init__(self, mask_updater, init_drop_fraction, update_freq, method get_drop_fraction (line 268) | def get_drop_fraction(self, step): method is_update_iter (line 271) | def is_update_iter(self, step): method update (line 286) | def update(self, step, check_update_iter=True): method prune (line 296) | def prune(self, prune_fraction): method set_validation_data (line 300) | def set_validation_data(self, val_x, val_y): class ConstantUpdateSchedule (line 304) | class ConstantUpdateSchedule(UpdateSchedule): method get_drop_fraction (line 307) | def get_drop_fraction(self, step): class CosineUpdateSchedule (line 311) | class CosineUpdateSchedule(UpdateSchedule): method __init__ (line 314) | def __init__(self, *args, **kwargs): method get_drop_fraction (line 322) | def get_drop_fraction(self, step): class ScaledLRUpdateSchedule (line 326) | class ScaledLRUpdateSchedule(UpdateSchedule): method __init__ (line 329) | def __init__(self, mask_updater, init_drop_fraction, update_freq, method _get_lr (line 336) | def _get_lr(self, step): method get_drop_fraction (line 342) | def get_drop_fraction(self, step): function get_mask_updater (line 359) | def get_mask_updater( FILE: rigl/rigl_tf2/metainit.py class ScaleSGD (line 23) | class ScaleSGD(tf1.train.Optimizer): method __init__ (line 29) | def __init__(self, learning_rate=0.1, momentum=0.9, mindim=3, method _prepare (line 40) | def _prepare(self): method _create_slots (line 44) | def _create_slots(self, var_list): method _resource_apply_dense (line 53) | def _resource_apply_dense(self, grad, handle): method _apply_dense (line 71) | def _apply_dense(self, grad, var): method _apply_sparse (line 74) | def _apply_sparse(self, grad, var): function meta_init (line 78) | def meta_init(model, loss, x_shape, y_shape, n_params, learning_rate=0.001, FILE: rigl/rigl_tf2/networks.py function lenet5 (line 25) | def lenet5(input_shape, function mlp (line 58) | def mlp(input_shape, FILE: rigl/rigl_tf2/train.py function get_rows (line 59) | def get_rows(model, variables, masks, ind_l, indices, x_batch, y_batch, function sparse_hessian_calculator (line 89) | def sparse_hessian_calculator(model, function hessian (line 170) | def hessian(model, function update_prune_step (line 195) | def update_prune_step(model, step): function log_sparsities (line 202) | def log_sparsities(model): function cosine_distance (line 212) | def cosine_distance(x, y): function flatten_list_of_vars (line 219) | def flatten_list_of_vars(var_list): function var_to_img (line 224) | def var_to_img(tensor): function mask_gradients (line 235) | def mask_gradients(model, gradients, variables): function train_model (line 248) | def train_model(model, function test_model (line 445) | def test_model(model, d_test, batch_size=1000): function main (line 461) | def main(unused_argv): FILE: rigl/rigl_tf2/utils.py function get_dataset (line 37) | def get_dataset(): function get_pruning_params (line 51) | def get_pruning_params(mode='prune', function maybe_prune_layer (line 75) | def maybe_prune_layer(layer, params, filter_fn): function get_network (line 82) | def get_network( function get_optimizer (line 182) | def get_optimizer(total_steps, FILE: rigl/rl/dqn_agents.py function flatten_list_of_vars (line 36) | def flatten_list_of_vars(var_list): function _get_bn_layer_name (line 41) | def _get_bn_layer_name(block_id, i): function _get_conv_layer_name (line 45) | def _get_conv_layer_name(block_id, i): class _Stack (line 49) | class _Stack(tf.keras.Model): method __init__ (line 53) | def __init__(self, method call (line 80) | def call(self, conv_out, training=False): class ImpalaNetwork (line 103) | class ImpalaNetwork(tf.keras.Model): method __init__ (line 120) | def __init__(self, method get_features (line 190) | def get_features(self, state, training=True): method call (line 205) | def call(self, state, training=True): class NatureDQNNetwork (line 211) | class NatureDQNNetwork(tf.keras.Model): method __init__ (line 214) | def __init__(self, num_actions, width=1, mode='dense', name=None): method call (line 284) | def call(self, state): class SparseDQNAgent (line 309) | class SparseDQNAgent(dqn_agent.DQNAgent): method __init__ (line 312) | def __init__(self, method _create_network (line 337) | def _create_network(self, name): method _set_additional_ops (line 344) | def _set_additional_ops(self): method _build_train_op (line 370) | def _build_train_op(self): method _create_summary_ops (line 406) | def _create_summary_ops(self, grads_and_vars): method update_prune_step (line 430) | def update_prune_step(self): method maybe_update_and_apply_masks (line 433) | def maybe_update_and_apply_masks(self): method maybe_init_masks (line 436) | def maybe_init_masks(self): method _train_step (line 440) | def _train_step(self): method _build_sync_op (line 459) | def _build_sync_op(self): method _build_networks (line 474) | def _build_networks(self): FILE: rigl/rl/run_experiment.py function create_sparse_agent (line 33) | def create_sparse_agent(sess, num_actions, agent=None, summary_writer=No... class SparseTrainRunner (line 54) | class SparseTrainRunner(run_experiment.Runner): method __init__ (line 57) | def __init__(self, method _run_one_phase_fix_episodes (line 127) | def _run_one_phase_fix_episodes(self, max_episodes, statistics): method _run_eval_phase (line 165) | def _run_eval_phase(self, statistics): method _run_one_step (line 177) | def _run_one_step(self, action): method run_experiment (line 186) | def run_experiment(self): FILE: rigl/rl/sparse_utils.py function get_total_params (line 36) | def get_total_params(model): function get_pruning_sparsities (line 56) | def get_pruning_sparsities( function get_pruning_params (line 86) | def get_pruning_params(mode, function maybe_prune_layer (line 113) | def maybe_prune_layer(layer, params, filter_fn=None): function get_wrap_fn (line 121) | def get_wrap_fn(mode): function update_prune_step (line 139) | def update_prune_step(model, step): function update_prune_masks (line 150) | def update_prune_masks(model): function get_all_layers (line 157) | def get_all_layers(model, filter_fn=lambda _: True): function get_all_variables_and_masks (line 168) | def get_all_variables_and_masks(model): function get_all_pruning_layers (line 179) | def get_all_pruning_layers(model): function log_sparsities (line 185) | def log_sparsities(model): class SparseOptTf2Mixin (line 197) | class SparseOptTf2Mixin: method compute_gradients (line 200) | def compute_gradients(self, *args, **kwargs): method set_model (line 204) | def set_model(self, model): method get_weights (line 207) | def get_weights(self): method get_masks (line 213) | def get_masks(self): method get_masked_weights (line 219) | def get_masked_weights(self): class UpdatedSETOptimizer (line 227) | class UpdatedSETOptimizer(SparseOptTf2Mixin, method _before_apply_gradients (line 230) | def _before_apply_gradients(self, grads_and_vars): class UpdatedRigLOptimizer (line 235) | class UpdatedRigLOptimizer(SparseOptTf2Mixin, method _before_apply_gradients (line 238) | def _before_apply_gradients(self, grads_and_vars): function init_masks (line 245) | def init_masks(model, FILE: rigl/rl/tfagents/dqn_train_eval.py class SparseDqnAgent (line 75) | class SparseDqnAgent(dqn_agent.DqnAgent): method __init__ (line 78) | def __init__(self, *args, **kwargs): method _train (line 95) | def _train(self, experience, weights): function _scale_width (line 151) | def _scale_width(num_units, width): function build_network (line 156) | def build_network( function train_eval (line 200) | def train_eval( function main (line 404) | def main(_): FILE: rigl/rl/tfagents/ppo_train_eval.py function _normalize_advantages (line 99) | def _normalize_advantages(advantages, axes=(0,), variance_epsilon=1e-8): class SparsePPOAgent (line 112) | class SparsePPOAgent(ppo_clip_agent.PPOClipAgent): method __init__ (line 115) | def __init__(self, method _process_experience_weights (line 168) | def _process_experience_weights(self, experience, weights): method _train (line 233) | def _train(self, experience, weights): method get_loss (line 424) | def get_loss(self, method value_estimation_loss (line 541) | def value_estimation_loss(self, method policy_gradient_loss (line 644) | def policy_gradient_loss( method entropy_regularization_loss (line 801) | def entropy_regularization_loss( class ReverbFixedLengthSequenceObserver (line 844) | class ReverbFixedLengthSequenceObserver(reverb_utils.ReverbAddTrajectory... method __call__ (line 857) | def __call__(self, trajectory): function train_eval (line 874) | def train_eval( function main (line 1175) | def main(_): FILE: rigl/rl/tfagents/sac_train_eval.py function create_fc_layers (line 81) | def create_fc_layers(layer_units, width=1.0, weight_decay=0): function create_identity_layer (line 90) | def create_identity_layer(): function create_sequential_critic_network (line 94) | def create_sequential_critic_network(obs_fc_layer_units, class _TanhNormalProjectionNetworkWrapper (line 176) | class _TanhNormalProjectionNetworkWrapper( method __init__ (line 180) | def __init__(self, sample_spec, predefined_outer_rank=1, weight_decay=... method call (line 186) | def call(self, inputs, network_state=(), **kwargs): function create_sequential_actor_network (line 194) | def create_sequential_actor_network(actor_fc_layers, class SparseSacAgent (line 234) | class SparseSacAgent(sac_agent.SacAgent): method __init__ (line 237) | def __init__(self, method _train (line 316) | def _train(self, experience, weights): function train_eval (line 455) | def train_eval( function main (line 698) | def main(_): FILE: rigl/rl/tfagents/sparse_encoding_network.py function _copy_layer (line 46) | def _copy_layer(layer): class EncodingNetwork (line 79) | class EncodingNetwork(network.Network): method __init__ (line 82) | def __init__(self, method call (line 297) | def call(self, observation, step_type=None, network_state=(), training... FILE: rigl/rl/tfagents/sparse_ppo_actor_network.py function tanh_and_scale_to_spec (line 30) | def tanh_and_scale_to_spec(inputs, spec): class PPOActorNetwork (line 38) | class PPOActorNetwork(): method __init__ (line 41) | def __init__(self, method create_sequential_actor_net (line 53) | def create_sequential_actor_net(self, FILE: rigl/rl/tfagents/sparse_ppo_discrete_actor_network.py function tanh_and_scale_to_spec (line 31) | def tanh_and_scale_to_spec(inputs, spec): class PPODiscreteActorNetwork (line 39) | class PPODiscreteActorNetwork(): method __init__ (line 42) | def __init__(self, seed_stream_class=tfp.util.SeedStream, method create_sequential_actor_net (line 57) | def create_sequential_actor_net(self, FILE: rigl/rl/tfagents/sparse_ppo_discrete_actor_network_test.py class DeterministicSeedStream (line 32) | class DeterministicSeedStream(object): method __init__ (line 35) | def __init__(self, seed, salt=''): method __call__ (line 39) | def __call__(self): class PpoActorNetworkTest (line 43) | class PpoActorNetworkTest(parameterized.TestCase, test_utils.TestCase): method setUp (line 45) | def setUp(self): method tearDown (line 52) | def tearDown(self): method _init_network (line 56) | def _init_network( method test_no_mismatched_shape (line 66) | def test_no_mismatched_shape(self): method test_is_sparse (line 83) | def test_is_sparse(self, is_sparse, sparse_output_layer, expected_laye... method test_width_scaling (line 96) | def test_width_scaling(self): method test_weight_decay (line 115) | def test_weight_decay(self, is_sparse, sparse_output_layer, FILE: rigl/rl/tfagents/sparse_tanh_normal_projection_network.py class SparseTanhNormalProjectionNetwork (line 34) | class SparseTanhNormalProjectionNetwork( method __init__ (line 42) | def __init__(self, FILE: rigl/rl/tfagents/sparse_value_network.py class ValueNetwork (line 42) | class ValueNetwork(network.Network): method __init__ (line 45) | def __init__(self, method call (line 160) | def call(self, observation, step_type=None, network_state=(), training... FILE: rigl/rl/tfagents/tf_sparse_utils.py function log_total_params (line 34) | def log_total_params(networks): function scale_width (line 43) | def scale_width(num_units, width): function wrap_all_layers (line 49) | def wrap_all_layers(layers, function wrap_layer (line 115) | def wrap_layer(layer, function is_valid_layer_to_wrap (line 144) | def is_valid_layer_to_wrap(layer): function log_sparsities (line 153) | def log_sparsities(model, model_name='q_net', log_images=False): function update_prune_step (line 174) | def update_prune_step(model, step): function flatten_list_of_vars (line 180) | def flatten_list_of_vars(var_list): function log_snr (line 186) | def log_snr(tape, loss, step, variables_to_train, freq=1000): FILE: rigl/rl/train.py function create_sparsetrain_runner (line 41) | def create_sparsetrain_runner(base_dir): function main (line 46) | def main(unused_argv): FILE: rigl/sparse_optimizers.py class PruningGetterTf1Mixin (line 46) | class PruningGetterTf1Mixin: method get_weights (line 49) | def get_weights(self): method get_masks (line 52) | def get_masks(self): method get_masked_weights (line 55) | def get_masked_weights(self): class SparseSETOptimizer (line 59) | class SparseSETOptimizer(PruningGetterTf1Mixin, class SparseRigLOptimizer (line 64) | class SparseRigLOptimizer(PruningGetterTf1Mixin, class SparseStaticOptimizer (line 69) | class SparseStaticOptimizer(SparseSETOptimizer): method __init__ (line 86) | def __init__(self, method generic_mask_update (line 109) | def generic_mask_update(self, mask, weights, noise_std=1e-5): class SparseMomentumOptimizer (line 126) | class SparseMomentumOptimizer(SparseSETOptimizer): method __init__ (line 149) | def __init__(self, method set_masked_grads (line 176) | def set_masked_grads(self, grads, weights): method compute_gradients (line 183) | def compute_gradients(self, loss, **kwargs): method _before_apply_gradients (line 195) | def _before_apply_gradients(self, grads_and_vars): method generic_mask_update (line 199) | def generic_mask_update(self, mask, weights, noise_std=1e-5): class SparseSnipOptimizer (line 217) | class SparseSnipOptimizer(tf_optimizer.Optimizer): method __init__ (line 235) | def __init__(self, method compute_gradients (line 254) | def compute_gradients(self, loss, **kwargs): method apply_gradients (line 258) | def apply_gradients(self, grads_and_vars, global_step=None, name=None): class SparseDNWOptimizer (line 340) | class SparseDNWOptimizer(tf_optimizer.Optimizer): method __init__ (line 360) | def __init__(self, method compute_gradients (line 375) | def compute_gradients(self, loss, var_list=None, **kwargs): method replace_with_masked_weights (line 388) | def replace_with_masked_weights(self, var_list): method replace_masked_weights (line 397) | def replace_masked_weights(self, grads_and_vars): method apply_gradients (line 408) | def apply_gradients(self, grads_and_vars, global_step=None, name=None): method get_weights (line 473) | def get_weights(self): method get_masks (line 476) | def get_masks(self): method get_masked_weights (line 479) | def get_masked_weights(self): FILE: rigl/sparse_optimizers_base.py function extract_number (line 45) | def extract_number(token): class SparseSETOptimizerBase (line 62) | class SparseSETOptimizerBase(tf_optimizer.Optimizer): method __init__ (line 87) | def __init__(self, method compute_gradients (line 113) | def compute_gradients(self, loss, **kwargs): method apply_gradients (line 118) | def apply_gradients(self, grads_and_vars, global_step=None, name=None): method _before_apply_gradients (line 148) | def _before_apply_gradients(self, grads_and_vars): method cond_mask_update_op (line 152) | def cond_mask_update_op(self, global_step, false_branch): method get_weights (line 189) | def get_weights(self): method get_masks (line 192) | def get_masks(self): method get_masked_weights (line 195) | def get_masked_weights(self): method is_mask_update_iter (line 198) | def is_mask_update_iter(self, global_step, last_update_step): method get_drop_fraction (line 232) | def get_drop_fraction(self, global_step, is_mask_update_iter_op): method generic_mask_update (line 260) | def generic_mask_update(self, mask, weights, noise_std=1e-5): method _get_update_op (line 276) | def _get_update_op(self, method reset_momentum (line 345) | def reset_momentum(self, weights, new_connections): method get_grow_tensor (line 355) | def get_grow_tensor(self, weights, method): method _random_uniform (line 402) | def _random_uniform(self, *args, **kwargs): method _random_normal (line 411) | def _random_normal(self, *args, **kwargs): class SparseRigLOptimizerBase (line 421) | class SparseRigLOptimizerBase(SparseSETOptimizerBase): method __init__ (line 444) | def __init__(self, method set_masked_grads (line 471) | def set_masked_grads(self, grads, weights): method compute_gradients (line 478) | def compute_gradients(self, loss, **kwargs): method apply_gradients (line 487) | def apply_gradients(self, grads_and_vars, global_step=None, name=None): method generic_mask_update (line 523) | def generic_mask_update(self, mask, weights, noise_std=1e-5): method get_grow_tensor (line 540) | def get_grow_tensor(self, weights, method): method reset_momentum (line 555) | def reset_momentum(self, weights, new_connections): FILE: rigl/sparse_optimizers_test.py class SparseSETOptimizerTest (line 38) | class SparseSETOptimizerTest(tf.test.TestCase, parameterized.TestCase): method _setup_graph (line 40) | def _setup_graph(self, n_inp, n_out, drop_frac, start_iter=1, end_iter=4, method testMaskNonUpdateIterations (line 72) | def testMaskNonUpdateIterations(self, n_inp, n_out, drop_frac): method testUpdateIterations (line 95) | def testUpdateIterations(self, n_inp, n_out, drop_frac): method testNoDrop (line 121) | def testNoDrop(self, start_iter, end_iter, freq_iter): method testNewConnectionZeroInit (line 141) | def testNewConnectionZeroInit(self): method testShapeOfGetGrowTensor (line 160) | def testShapeOfGetGrowTensor(self, shape, init_type): method testDtypeOfGetGrowTensor (line 172) | def testDtypeOfGetGrowTensor(self, dtype, init_type): method testValueErrorOfGetGrowTensor (line 182) | def testValueErrorOfGetGrowTensor(self, method): class SparseStaticOptimizerTest (line 192) | class SparseStaticOptimizerTest(tf.test.TestCase, parameterized.TestCase): method _setup_graph (line 194) | def _setup_graph(self, n_inp, n_out, drop_frac, start_iter=1, end_iter=4, method testMaskStatic (line 226) | def testMaskStatic(self, n_inp, n_out, drop_frac): class SparseMomentumOptimizerTest (line 247) | class SparseMomentumOptimizerTest(tf.test.TestCase, parameterized.TestCa... method _setup_graph (line 249) | def _setup_graph(self, n_inp, n_out, drop_frac, start_iter=1, end_iter=4, method testMomentumUpdate (line 276) | def testMomentumUpdate(self, n_inp, n_out, momentum): class SparseRigLOptimizerTest (line 297) | class SparseRigLOptimizerTest(tf.test.TestCase, parameterized.TestCase): method _setup_graph (line 299) | def _setup_graph(self, n_inp, n_out, drop_frac, start_iter=1, end_iter=4, method testMaskedGradientCalculation (line 331) | def testMaskedGradientCalculation(self, n_inp, n_out): method testApplyGradients (line 353) | def testApplyGradients(self, start_iter, end_iter, freq_iter, is_incre... class SparseSnipOptimizerTest (line 370) | class SparseSnipOptimizerTest(tf.test.TestCase, parameterized.TestCase): method _setup_graph (line 372) | def _setup_graph(self, default_sparsity, mask_init_method, method testSnipSparsity (line 407) | def testSnipSparsity(self, n_inp, n_out, default_sparsity): method testGradientUsed (line 422) | def testGradientUsed(self, n_inp, n_out, default_sparsity): method testInitialMaskIsDense (line 441) | def testInitialMaskIsDense(self, n_inp, n_out, default_sparsity): method testAfterSnipTraining (line 451) | def testAfterSnipTraining(self, n_inp, n_out, default_sparsity): class SparseDNWOptimizerTest (line 471) | class SparseDNWOptimizerTest(tf.test.TestCase, parameterized.TestCase): method _setup_graph (line 473) | def _setup_graph(self, method testDNWSparsity (line 515) | def testDNWSparsity(self, n_inp, n_out, default_sparsity): method testWeightsUsed (line 529) | def testWeightsUsed(self, n_inp, n_out, default_sparsity): method testGradientIsDense (line 548) | def testGradientIsDense(self, n_inp, n_out, default_sparsity): method testDNWUpdates (line 557) | def testDNWUpdates(self, n_inp, n_out, default_sparsity): method testSparsityAfterDNWUpdates (line 574) | def testSparsityAfterDNWUpdates(self, n_inp, n_out, default_sparsity): FILE: rigl/sparse_utils.py function mask_extract_name_fn (line 31) | def mask_extract_name_fn(mask_name): function get_n_zeros (line 35) | def get_n_zeros(size, sparsity): function calculate_sparsity (line 39) | def calculate_sparsity(masks): function get_mask_random_numpy (line 48) | def get_mask_random_numpy(mask_shape, sparsity, random_state=None): function get_mask_random (line 71) | def get_mask_random(mask, sparsity, dtype, random_state=None): function get_sparsities_erdos_renyi (line 90) | def get_sparsities_erdos_renyi(all_masks, function get_sparsities_uniform (line 210) | def get_sparsities_uniform(all_masks, function get_sparsities_str (line 238) | def get_sparsities_str(all_masks, default_sparsity): function get_sparsities (line 258) | def get_sparsities(all_masks, function get_mask_init_fn (line 319) | def get_mask_init_fn(all_masks, function _get_kernel (line 368) | def _get_kernel(layer): function get_stats (line 376) | def get_stats(masked_layers, FILE: rigl/sparse_utils_test.py class GetMaskRandomTest (line 29) | class GetMaskRandomTest(tf.test.TestCase, parameterized.TestCase): method _setup_session (line 31) | def _setup_session(self): method testMaskConnectionDeterminism (line 38) | def testMaskConnectionDeterminism(self, shape, sparsity): method testMaskFraction (line 49) | def testMaskFraction(self, shape, sparsity, expected_ones): method testMaskDtype (line 58) | def testMaskDtype(self, dtype): class GetSparsitiesTest (line 65) | class GetSparsitiesTest(tf.test.TestCase, parameterized.TestCase): method _setup_session (line 67) | def _setup_session(self): method testSparsityDictRandom (line 74) | def testSparsityDictRandom(self, default_sparsity): method testSparsityDictErdosRenyiCustom (line 87) | def testSparsityDictErdosRenyiCustom(self, default_sparsity): method testSparsityDictErdosRenyiError (line 98) | def testSparsityDictErdosRenyiError(self, default_sparsity): method testSparsityDictErdosRenyiSparsitiesScale (line 113) | def testSparsityDictErdosRenyiSparsitiesScale( FILE: rigl/str_sparsities.py function _name_map_str (line 86) | def _name_map_str(k): function read_all (line 109) | def read_all():