SYMBOL INDEX (1523 symbols across 122 files) FILE: docs/conf.py function linkcode_resolve (line 130) | def linkcode_resolve(domain, info): FILE: docs/ext/link_tf_api.py function tf_role_fn (line 44) | def tf_role_fn( function tf_doc_url (line 86) | def tf_doc_url(text): function setup (line 122) | def setup(app): FILE: docs/ext/link_tf_api_test.py class LinkTfApiTest (line 23) | class LinkTfApiTest(absltest.TestCase): method test_non_existent (line 25) | def test_non_existent(self): method test_link_to_top_level (line 29) | def test_link_to_top_level(self): method test_link_to_nested_package (line 34) | def test_link_to_nested_package(self): method test_link_to_method_of_exported_class (line 39) | def test_link_to_method_of_exported_class(self): method test_link_to_non_existent_method_of_exported_class (line 44) | def test_link_to_non_existent_method_of_exported_class(self): FILE: examples/functional_mlp_mnist.py function main (line 26) | def main(unused_argv): FILE: examples/simple_mnist.py function mnist (line 25) | def mnist(split: str, batch_size: int) -> tf.data.Dataset: function train_step (line 52) | def train_step( function train_epoch (line 71) | def train_epoch( function test_accuracy (line 83) | def test_accuracy( function main (line 97) | def main(unused_argv): FILE: examples/simple_mnist_test.py class SimpleMnistTest (line 23) | class SimpleMnistTest(test_utils.TestCase): method setUp (line 25) | def setUp(self): method test_train_epoch (line 29) | def test_train_epoch(self): method test_test_accuracy (line 46) | def test_test_accuracy(self): FILE: setup.py function _get_sonnet_version (line 7) | def _get_sonnet_version(): function _parse_requirements (line 17) | def _parse_requirements(requirements_txt_path): FILE: sonnet/src/axis_norm.py class LayerNorm (line 27) | class LayerNorm(base.Module): method __init__ (line 64) | def __init__(self, method __call__ (line 129) | def __call__(self, method _initialize (line 181) | def _initialize(self, inputs: tf.Tensor): class InstanceNorm (line 210) | class InstanceNorm(LayerNorm): method __init__ (line 222) | def __init__(self, FILE: sonnet/src/axis_norm_test.py class LayerNormTest (line 25) | class LayerNormTest(test_utils.TestCase, parameterized.TestCase): method testSimpleCase (line 27) | def testSimpleCase(self): method testSimpleCaseVar (line 35) | def testSimpleCaseVar(self): method testSimpleCaseNCHWVar (line 48) | def testSimpleCaseNCHWVar(self): method testDataFormatAgnosticVar (line 62) | def testDataFormatAgnosticVar(self): method testSimpleCaseTensor (line 80) | def testSimpleCaseTensor(self): method testSimpleCaseNCHWTensor (line 91) | def testSimpleCaseNCHWTensor(self): method testDataFormatAgnosticTensor (line 105) | def testDataFormatAgnosticTensor(self): method testInvalidDataFormat (line 128) | def testInvalidDataFormat(self, data_format): method testValidDataFormatChannelsFirst (line 136) | def testValidDataFormatChannelsFirst(self, data_format): method testValidDataFormatChannelsLast (line 143) | def testValidDataFormatChannelsLast(self, data_format): method testInvalidAxis (line 150) | def testInvalidAxis(self, axis): method testNoScaleAndInitProvided (line 155) | def testNoScaleAndInitProvided(self): method testNoOffsetBetaInitProvided (line 164) | def testNoOffsetBetaInitProvided(self): method testCreateScaleAndScaleProvided (line 173) | def testCreateScaleAndScaleProvided(self): method testCreateOffsetAndOffsetProvided (line 180) | def testCreateOffsetAndOffsetProvided(self): method testSliceAxis (line 188) | def testSliceAxis(self): method testRankChanges (line 204) | def testRankChanges(self): method testWorksWithFunction (line 218) | def testWorksWithFunction(self): method testShapeAgnostic (line 231) | def testShapeAgnostic(self): method test5DDataFormatAgnostic (line 254) | def test5DDataFormatAgnostic(self): method test3DDataFormatAgnostic (line 277) | def test3DDataFormatAgnostic(self): method testInstanceNormCorrectAxis (line 300) | def testInstanceNormCorrectAxis(self): method testInstanceNormCorrectNCW (line 308) | def testInstanceNormCorrectNCW(self): FILE: sonnet/src/base.py function no_name_scope (line 36) | def no_name_scope(method: T) -> T: class ModuleMetaclass (line 58) | class ModuleMetaclass(abc.ABCMeta): method __new__ (line 61) | def __new__( method __call__ (line 108) | def __call__(cls: Type[T], *args, **kwargs) -> T: function safe_compare (line 147) | def safe_compare(a, b) -> bool: function auto_repr (line 158) | def auto_repr(cls: Type[Any], *args, **kwargs) -> str: function fancy_repr (line 214) | def fancy_repr(name: str, value: Any) -> str: function indent (line 221) | def indent(amount: int, s: str) -> str: function wrap_with_name_scope (line 228) | def wrap_with_name_scope( function wrap_with_name_scope_no_exception (line 266) | def wrap_with_name_scope_no_exception( function with_name_scope (line 284) | def with_name_scope(method: T) -> T: function allow_empty_variables (line 337) | def allow_empty_variables(module_or_cls: T) -> T: function assert_tf2 (line 359) | def assert_tf2(): class Module (line 368) | class Module(tf.Module, metaclass=ModuleMetaclass): method __init__ (line 390) | def __init__(self, name: Optional[str] = None): method variables (line 413) | def variables(self): method trainable_variables (line 443) | def trainable_variables(self): class Optimizer (line 474) | class Optimizer(Module): method apply (line 478) | def apply(self, updates: Sequence[types.ParameterUpdate], FILE: sonnet/src/base_test.py class BaseTest (line 27) | class BaseTest(test_utils.TestCase): method test_basic (line 29) | def test_basic(self): method testWrappedMethod (line 33) | def testWrappedMethod(self): method testControlFlow (line 39) | def testControlFlow(self): class TestModuleNaming (line 46) | class TestModuleNaming(tf.test.TestCase): method test_single_name (line 48) | def test_single_name(self): method test_construct_in_scope (line 53) | def test_construct_in_scope(self): method test_enters_name_scope_in_call (line 59) | def test_enters_name_scope_in_call(self): method test_enters_name_scope_in_other_method (line 64) | def test_enters_name_scope_in_other_method(self): method test_subclassed_module (line 69) | def test_subclassed_module(self): method test_submodule_created_late (line 76) | def test_submodule_created_late(self): method test_does_not_evaluate_property_methods (line 84) | def test_does_not_evaluate_property_methods(self): method test_overridden_name_scope (line 89) | def test_overridden_name_scope(self): method test_patched_callable (line 94) | def test_patched_callable(self): method test_property (line 101) | def test_property(self): method test_property_no_name_scope (line 108) | def test_property_no_name_scope(self): method test_ctor_no_name_scope (line 115) | def test_ctor_no_name_scope(self): method test_ctor_no_name_scope_no_super (line 120) | def test_ctor_no_name_scope_no_super(self): method test_invalid_name (line 126) | def test_invalid_name(self): method test_modules_not_numbered_in_eager (line 131) | def test_modules_not_numbered_in_eager(self): method test_module_numbering_in_graph (line 140) | def test_module_numbering_in_graph(self): method test_ctor_error_closes_name_scope (line 150) | def test_ctor_error_closes_name_scope(self): method test_ctor_error_handles_ctor_not_opening_name_scope (line 159) | def test_ctor_error_handles_ctor_not_opening_name_scope(self): method test_forward_method_closes_name_scope (line 168) | def test_forward_method_closes_name_scope(self): method test_get_attr_doesnt_enter_name_scope (line 175) | def test_get_attr_doesnt_enter_name_scope(self): method test_get_attribute_doesnt_enter_name_scope (line 189) | def test_get_attribute_doesnt_enter_name_scope(self): class VariableNamingTest (line 204) | class VariableNamingTest(tf.test.TestCase): method test_variable_names (line 206) | def test_variable_names(self): class AutoReprTest (line 213) | class AutoReprTest(tf.test.TestCase): method test_order_matches_argspec (line 215) | def test_order_matches_argspec(self): method test_defaults_ignored (line 219) | def test_defaults_ignored(self): method test_does_not_fail_with_hostile_input (line 223) | def test_does_not_fail_with_hostile_input(self): method test_args_are_repred (line 230) | def test_args_are_repred(self): method test_long_repr_multi_line (line 236) | def test_long_repr_multi_line(self): method test_repr_wildcard (line 251) | def test_repr_wildcard(self): method test_repr_non_bool_equality (line 259) | def test_repr_non_bool_equality(self): class ForwardMethodsTest (line 274) | class ForwardMethodsTest(tf.test.TestCase): method testFunctionType (line 276) | def testFunctionType(self): method testEntersNameScope_call (line 281) | def testEntersNameScope_call(self): method testEntersNameScope_concreteFunction (line 289) | def testEntersNameScope_concreteFunction(self): class AbcTest (line 298) | class AbcTest(tf.test.TestCase): method testAbstract (line 300) | def testAbstract(self): method testConcrete (line 305) | def testConcrete(self): method testCallMethodsOnParent (line 312) | def testCallMethodsOnParent(self): class CustomGradientTest (line 317) | class CustomGradientTest(test_utils.TestCase): method test_custom_gradient (line 319) | def test_custom_gradient(self): class ZeroGradModule (line 331) | class ZeroGradModule(base.Module): method __call__ (line 334) | def __call__(self, x): class LambdaModule (line 349) | class LambdaModule(base.Module): method __call__ (line 351) | def __call__(self, x): function get_name_scope (line 355) | def get_name_scope(): function wrapt_decorator (line 361) | def wrapt_decorator(method, instance, args, kwargs): class WraptModule (line 369) | class WraptModule(base.Module): method __call__ (line 372) | def __call__(self, x): class ControlFlowModule (line 376) | class ControlFlowModule(base.Module): method __call__ (line 378) | def __call__(self, x): class ErrorModuleError (line 385) | class ErrorModuleError(Exception): class ErrorModule (line 389) | class ErrorModule(base.Module): method __init__ (line 391) | def __init__(self, call_super, raise_in_constructor=True): method __call__ (line 397) | def __call__(self): class RecursiveModule (line 401) | class RecursiveModule(base.Module): method __init__ (line 403) | def __init__(self, depth, trainable=True): class AbstractModule (line 411) | class AbstractModule(base.Module, metaclass=abc.ABCMeta): method __call__ (line 414) | def __call__(self, x): method foo (line 417) | def foo(self): class ConcreteModule (line 421) | class ConcreteModule(AbstractModule): method __call__ (line 423) | def __call__(self, x): class TreeModule (line 427) | class TreeModule(base.Module): method __init__ (line 429) | def __init__(self, name=None): method new_leaf (line 433) | def new_leaf(self, name=None): class ReturnsNameScopeModule (line 439) | class ReturnsNameScopeModule(base.Module): method alternative_forward (line 441) | def alternative_forward(self): method __call__ (line 444) | def __call__(self): class SubclassedReturnsNameScopeModule (line 448) | class SubclassedReturnsNameScopeModule(ReturnsNameScopeModule): method alternative_alternative_forward (line 450) | def alternative_alternative_forward(self): class PropertyThrowsWhenCalledModule (line 454) | class PropertyThrowsWhenCalledModule(base.Module): method raise_assertion_error (line 457) | def raise_assertion_error(self): class ModuleOverridingNameScope (line 461) | class ModuleOverridingNameScope(ReturnsNameScopeModule): method name_scope (line 464) | def name_scope(self): class CommonErrorsTest (line 468) | class CommonErrorsTest(test_utils.TestCase, parameterized.TestCase): method test_not_calling_super_constructor (line 470) | def test_not_calling_super_constructor(self): method test_calls_method_before_super (line 476) | def test_calls_method_before_super(self): method test_annotated_method_is_allowed (line 481) | def test_annotated_method_is_allowed(self): method test_requests_variables_before_they_exist (line 486) | def test_requests_variables_before_they_exist(self, property_name): method test_allow_empty_variables_instance (line 496) | def test_allow_empty_variables_instance(self, property_name): method test_allow_empty_variables_class (line 502) | def test_allow_empty_variables_class(self, property_name): class NoopModule (line 507) | class NoopModule(base.Module): method __init__ (line 509) | def __init__(self, a=None): class RaisesOnEquality (line 514) | class RaisesOnEquality: method __repr__ (line 518) | def __repr__(self): method __eq__ (line 521) | def __eq__(self, other): method __ne__ (line 525) | def __ne__(self, other): class NeverCreatesVariables (line 531) | class NeverCreatesVariables(base.Module): class ModuleWithFunctionAnnotatedCall (line 535) | class ModuleWithFunctionAnnotatedCall(base.Module): method forward (line 538) | def forward(self): method forward_ag (line 542) | def forward_ag(self): class CtorNoNameScope (line 546) | class CtorNoNameScope(base.Module): method __init__ (line 549) | def __init__(self): class CtorNoNameScopeNoSuper (line 555) | class CtorNoNameScopeNoSuper(base.Module): method __init__ (line 558) | def __init__(self): class PropertyModule (line 562) | class PropertyModule(base.Module): method __init__ (line 564) | def __init__(self): method some_property (line 569) | def some_property(self): method some_property (line 574) | def some_property(self, my_property): method no_name_scope_property (line 579) | def no_name_scope_property(self): method no_name_scope_property (line 585) | def no_name_scope_property(self, my_property): class DoesNotCallSuperConstructorModule (line 589) | class DoesNotCallSuperConstructorModule(base.Module): method __init__ (line 591) | def __init__(self): class CallsMethodBeforeSuperConstructorModule (line 596) | class CallsMethodBeforeSuperConstructorModule(base.Module): method __init__ (line 598) | def __init__(self, allowed_method): method no_name_scope (line 606) | def no_name_scope(self): method with_name_scope (line 609) | def with_name_scope(self): class CustomMetaclass (line 613) | class CustomMetaclass(type): method __new__ (line 617) | def __new__(cls, name, bases, clsdict): class CombiningMetaclass (line 623) | class CombiningMetaclass(base.ModuleMetaclass, CustomMetaclass): method __new__ (line 627) | def __new__(cls, name, bases, clsdict): class ModuleWithCustomMetaclass (line 633) | class ModuleWithCustomMetaclass(base.Module, metaclass=CombiningMetaclass): method __init__ (line 635) | def __init__(self): class CustomMetaclassTest (line 640) | class CustomMetaclassTest(tf.test.TestCase): method testSupportsCustomMetaclass (line 642) | def testSupportsCustomMetaclass(self): class TakesSubmodules (line 649) | class TakesSubmodules(base.Module): method __init__ (line 651) | def __init__(self, submodules, name=None): class WildcardInit (line 655) | class WildcardInit(base.Module): method __init__ (line 657) | def __init__(self, a, b, *args, **kwargs): FILE: sonnet/src/batch_apply.py class BatchApply (line 25) | class BatchApply(base.Module): method __init__ (line 45) | def __init__(self, method __call__ (line 53) | def __call__(self, *args, **kwargs): function first_leaf (line 76) | def first_leaf(args, kwargs) -> Optional[Any]: function split_leading_dim (line 86) | def split_leading_dim( function maybe_prod (line 128) | def maybe_prod(s: Sequence[Union[int, None]]) -> Optional[int]: function merge_leading_dims (line 136) | def merge_leading_dims( FILE: sonnet/src/batch_apply_test.py class BatchApplyTest (line 32) | class BatchApplyTest(test_utils.TestCase): method test_simple (line 34) | def test_simple(self): method test_no_output (line 40) | def test_no_output(self): method test_kwargs (line 45) | def test_kwargs(self): class MergeLeadingDimsTest (line 51) | class MergeLeadingDimsTest(test_utils.TestCase, parameterized.TestCase): method test_x_not_tensor (line 54) | def test_x_not_tensor(self, x): method test_static_shape (line 58) | def test_static_shape(self, x_shape, num_dims): method test_dynamic_shape (line 65) | def test_dynamic_shape(self, x_shape, num_dims): method test_dynamic_shape_has_static_info_in_graph (line 79) | def test_dynamic_shape_has_static_info_in_graph(self, x_shape, num_dims): class SplitLeadingDimTest (line 99) | class SplitLeadingDimTest(test_utils.TestCase, parameterized.TestCase): method test_x_not_tensor (line 102) | def test_x_not_tensor(self, x): method test_static_shape (line 106) | def test_static_shape(self, i_shape, num_dims): method test_dynamic_shape (line 115) | def test_dynamic_shape(self, i_shape, num_dims): method test_dynamic_shape_has_static_info_in_graph (line 134) | def test_dynamic_shape_has_static_info_in_graph(self, i_shape, num_dims): class NoOutputModule (line 154) | class NoOutputModule(base.Module): method __call__ (line 156) | def __call__(self, x): class KwargsModule (line 160) | class KwargsModule(base.Module): method __call__ (line 162) | def __call__(self, x, is_training=None): class AddOne (line 167) | class AddOne(base.Module): method __call__ (line 169) | def __call__(self, x): FILE: sonnet/src/batch_norm.py class BaseBatchNorm (line 29) | class BaseBatchNorm(base.Module): method __init__ (line 74) | def __init__(self, method __call__ (line 127) | def __call__(self, method _initialize (line 200) | def _initialize(self, inputs: tf.Tensor): method _moments (line 239) | def _moments(self, inputs: tf.Tensor, method _update_statistics (line 257) | def _update_statistics(self, mean, variance): class BatchNorm (line 265) | class BatchNorm(BaseBatchNorm): method __init__ (line 277) | def __init__(self, FILE: sonnet/src/batch_norm_test.py class BaseBatchNormTest (line 27) | class BaseBatchNormTest(test_utils.TestCase, parameterized.TestCase): method testSimpleTraining (line 29) | def testSimpleTraining(self): method testSimpleTrainingNCHW (line 44) | def testSimpleTrainingNCHW(self): method testSimpleTraining3D (line 60) | def testSimpleTraining3D(self): method testSimpleTraining3DNCDHW (line 75) | def testSimpleTraining3DNCDHW(self): method testNoScaleAndOffset (line 91) | def testNoScaleAndOffset(self): method testSingleBatchInference (line 103) | def testSingleBatchInference(self): method testWithTfFunction (line 114) | def testWithTfFunction(self, autograph): method testWithTfFunctionTfArgs (line 143) | def testWithTfFunctionTfArgs(self, autograph): method testUsingTestStats (line 166) | def testUsingTestStats(self): method testIsTrainingFalseFirstCall (line 183) | def testIsTrainingFalseFirstCall(self): method testInvalidDataFormat (line 194) | def testInvalidDataFormat(self, data_format): method testValidDataFormatChannelsFirst (line 206) | def testValidDataFormatChannelsFirst(self, data_format): method testValidDataFormatChannelsLast (line 217) | def testValidDataFormatChannelsLast(self, data_format): method testNoScaleAndInitProvided (line 227) | def testNoScaleAndInitProvided(self): method testNoOffsetBetaInitProvided (line 237) | def testNoOffsetBetaInitProvided(self): class BatchNormTest (line 248) | class BatchNormTest(test_utils.TestCase, parameterized.TestCase): method testSimple (line 250) | def testSimple(self): class TestMetric (line 261) | class TestMetric: method __init__ (line 263) | def __init__(self): method update (line 267) | def update(self, x): method value (line 275) | def value(self): method initialize (line 278) | def initialize(self, x): FILE: sonnet/src/bias.py class Bias (line 27) | class Bias(base.Module): method __init__ (line 67) | def __init__(self, method _initialize (line 91) | def _initialize(self, inputs): method __call__ (line 106) | def __call__(self, inputs: tf.Tensor, multiplier: types.FloatLike = No... function calculate_bias_shape (line 127) | def calculate_bias_shape(input_shape: types.ShapeLike, FILE: sonnet/src/bias_test.py class BiasTest (line 22) | class BiasTest(test_utils.TestCase): method test_output_shape (line 24) | def test_output_shape(self): method test_output_size_valid (line 29) | def test_output_size_valid(self): method test_bias_dims_scalar (line 33) | def test_bias_dims_scalar(self): method test_bias_dims_custom (line 38) | def test_bias_dims_custom(self): method test_bias_dims_negative_out_of_order (line 45) | def test_bias_dims_negative_out_of_order(self): method test_bias_dims_invalid (line 50) | def test_bias_dims_invalid(self): method test_b_init_defaults_to_zeros (line 56) | def test_b_init_defaults_to_zeros(self): method test_b_init_custom (line 61) | def test_b_init_custom(self): method test_name (line 67) | def test_name(self): method test_multiplier (line 73) | def test_multiplier(self): FILE: sonnet/src/build.py function _int_or_none (line 23) | def _int_or_none(o): function _promote_shapes (line 27) | def _promote_shapes(o): function _maybe_tensor_spec (line 34) | def _maybe_tensor_spec(shape, dtype): function build (line 39) | def build( FILE: sonnet/src/build_test.py class BuildTest (line 22) | class BuildTest(test_utils.TestCase): method test_call_with_shape_lke_object (line 24) | def test_call_with_shape_lke_object(self): method test_output_spec (line 28) | def test_output_spec(self): method test_does_not_trigger_sideeffects (line 35) | def test_does_not_trigger_sideeffects(self): function tensor_identity (line 42) | def tensor_identity(x): class IncrementsCounter (line 47) | class IncrementsCounter(tf.Module): method __call__ (line 49) | def __call__(self): FILE: sonnet/src/conformance/api_test.py class PublicSymbolsTest (line 24) | class PublicSymbolsTest(test_utils.TestCase): method test_src_not_exported (line 26) | def test_src_not_exported(self): method test_supports_reload (line 29) | def test_supports_reload(self): FILE: sonnet/src/conformance/build_test.py function if_present (line 28) | def if_present(f): class BuildTest (line 32) | class BuildTest(test_utils.TestCase, parameterized.TestCase): method test_build (line 35) | def test_build(self, module_fn, input_shape, dtype): method assertCompatible (line 44) | def assertCompatible(self, a: tf.TensorSpec, b: tf.TensorSpec): FILE: sonnet/src/conformance/checkpoint_test.py class TestCheckpoint (line 30) | class TestCheckpoint: method __init__ (line 33) | def __init__(self, golden=None, **kwargs): method save (line 43) | def save(self): method restore_latest (line 46) | def restore_latest(self, assert_consumed): function with_soft_placement (line 55) | def with_soft_placement(f): class GoldenCheckpointsTest (line 65) | class GoldenCheckpointsTest(test_utils.TestCase, parameterized.TestCase): method test_save_load (line 69) | def test_save_load(self, golden): method test_save_then_load_new_instance (line 102) | def test_save_then_load_new_instance(self, golden): method test_restore_on_create (line 132) | def test_restore_on_create(self, golden): method test_restore_golden (line 159) | def test_restore_golden(self, golden): class ReplicatorCheckpointTest (line 174) | class ReplicatorCheckpointTest(test_utils.TestCase, parameterized.TestCa... method replicator_or_skip (line 176) | def replicator_or_skip(self, replicator_fn, use_function): method test_save_restore (line 186) | def test_save_restore(self, golden, replicator_fn, use_function): method test_restore_from_golden (line 240) | def test_restore_from_golden(self, golden, replicator_fn): method test_restore_from_non_distributed (line 257) | def test_restore_from_non_distributed(self, golden, replicator_fn, method test_restore_on_create (line 304) | def test_restore_on_create(self, golden, replicator_fn): method test_restore_on_create_in_replica_context (line 332) | def test_restore_on_create_in_replica_context(self, golden, replicator... function setUpModule (line 375) | def setUpModule(): FILE: sonnet/src/conformance/checkpoints/generate.py function safe_mkdir (line 37) | def safe_mkdir(directory): function safe_unlink (line 45) | def safe_unlink(path): function main (line 53) | def main(unused_argv): FILE: sonnet/src/conformance/copy_test.py class CopyTest (line 27) | class CopyTest(test_utils.TestCase, parameterized.TestCase): method test_copy (line 30) | def test_copy(self, golden): FILE: sonnet/src/conformance/descriptors.py class Wrapped (line 24) | class Wrapped(snt.Module): method __init__ (line 27) | def __init__(self, wrapped: snt.Module): class Training (line 32) | class Training(Wrapped): method __call__ (line 35) | def __call__(self, x: tf.Tensor): class Recurrent (line 39) | class Recurrent(Wrapped): method __init__ (line 42) | def __init__(self, method __call__ (line 49) | def __call__(self, x: tf.Tensor): function unwrap (line 61) | def unwrap(module: snt.Module) -> snt.Module: function recurrent_factory (line 205) | def recurrent_factory( function unroll_descriptors (line 212) | def unroll_descriptors(descriptors, unroller=None): FILE: sonnet/src/conformance/descriptors_test.py class DescriptorsTest (line 28) | class DescriptorsTest(test_utils.TestCase): method test_coverage (line 30) | def test_coverage(self): FILE: sonnet/src/conformance/distribute_test.py class TpuReplicatorTest (line 29) | class TpuReplicatorTest(test_utils.TestCase, parameterized.TestCase): method test_variable_creation_in_replica_context (line 33) | def test_variable_creation_in_replica_context(self, golden, replicator... method assertSameValuePerReplica (line 54) | def assertSameValuePerReplica(self, replicator, per_replica): method test_unroll (line 63) | def test_unroll( FILE: sonnet/src/conformance/doctest_test.py class DoctestTest (line 27) | class DoctestTest(test_utils.TestCase, parameterized.TestCase): method setUp (line 32) | def setUp(self): method test_doctest (line 41) | def test_doctest(self, module): FILE: sonnet/src/conformance/function_test.py class FunctionTest (line 32) | class FunctionTest(test_utils.TestCase, parameterized.TestCase): method test_trace (line 36) | def test_trace( method test_create_variables_eagerly (line 49) | def test_create_variables_eagerly( method test_trace_batch_agnostic (line 63) | def test_trace_batch_agnostic( method test_trace_batch_apply_batch_agnostic (line 78) | def test_trace_batch_apply_batch_agnostic( method test_optimizer_dense (line 99) | def test_optimizer_dense( method test_optimizer_sparse (line 117) | def test_optimizer_sparse( FILE: sonnet/src/conformance/goldens.py function named_goldens (line 28) | def named_goldens() -> Sequence[Tuple[str, "Golden"]]: function all_goldens (line 32) | def all_goldens(test_method): function _register_golden (line 36) | def _register_golden(module_cls, golden_name): function list_goldens (line 46) | def list_goldens(): function range_like (line 50) | def range_like(t, start=0): class Golden (line 77) | class Golden(abc.ABC): method create_module (line 81) | def create_module(self): method create_all_variables (line 86) | def create_all_variables(self, module): method forward (line 91) | def forward(self, module, x=None): class AbstractGolden (line 96) | class AbstractGolden(Golden): method input_spec (line 108) | def input_spec(self): method num_variables (line 112) | def num_variables(self): method forward (line 115) | def forward(self, module, x=None): method create_all_variables (line 120) | def create_all_variables(self, module): class Linear1x1Test (line 131) | class Linear1x1Test(AbstractGolden): class LinearNoBias1x1 (line 138) | class LinearNoBias1x1(AbstractGolden): class Conv1D (line 145) | class Conv1D(AbstractGolden): class Conv2D (line 152) | class Conv2D(AbstractGolden): class Conv3D (line 159) | class Conv3D(AbstractGolden): class Conv1DTranspose (line 166) | class Conv1DTranspose(AbstractGolden): class Conv2DTranspose (line 174) | class Conv2DTranspose(AbstractGolden): class Conv3DTranspose (line 182) | class Conv3DTranspose(AbstractGolden): class DepthwiseConv2D (line 190) | class DepthwiseConv2D(AbstractGolden): class MLP (line 197) | class MLP(AbstractGolden): class MLPNoBias (line 204) | class MLPNoBias(AbstractGolden): class Cifar10ConvNet (line 211) | class Cifar10ConvNet(AbstractGolden): method forward (line 217) | def forward(self, module, x=None): class LayerNorm (line 224) | class LayerNorm(AbstractGolden): class Instance (line 232) | class Instance(AbstractGolden): class GroupNorm (line 240) | class GroupNorm(AbstractGolden): class BaseBatchNorm (line 248) | class BaseBatchNorm(AbstractGolden): method forward (line 254) | def forward(self, module, x=None): class BaseBatchNormScaleOffset (line 261) | class BaseBatchNormScaleOffset(AbstractGolden): method forward (line 267) | def forward(self, module, x=None): class BatchNorm (line 274) | class BatchNorm(AbstractGolden): method forward (line 279) | def forward(self, module, x=None): class BatchNormScaleOffset (line 286) | class BatchNormScaleOffset(AbstractGolden): method forward (line 291) | def forward(self, module, x=None): class ExponentialMovingAverage (line 298) | class ExponentialMovingAverage(AbstractGolden): method forward (line 304) | def forward(self, module, x=None): class BatchNormTraining (line 311) | class BatchNormTraining(AbstractGolden): method forward (line 317) | def forward(self, module, x=None): class CrossReplicaBatchNorm (line 325) | class CrossReplicaBatchNorm(AbstractGolden): method forward (line 331) | def forward(self, module, x=None): class DropoutVariableRate (line 338) | class DropoutVariableRate(AbstractGolden): method forward (line 344) | def forward(self, module, x=None): class AbstractRNNGolden (line 351) | class AbstractRNNGolden(AbstractGolden): method forward (line 353) | def forward(self, module, x=None): class Conv1DLSTM (line 365) | class Conv1DLSTM(AbstractRNNGolden): method create_module (line 369) | def create_module(self): class Conv2DLSTM (line 377) | class Conv2DLSTM(AbstractRNNGolden): method create_module (line 381) | def create_module(self): class Conv3DLSTM (line 389) | class Conv3DLSTM(AbstractRNNGolden): method create_module (line 393) | def create_module(self): class GRU (line 401) | class GRU(AbstractRNNGolden): class LSTM (line 408) | class LSTM(AbstractRNNGolden): class LSTMWithProjection (line 415) | class LSTMWithProjection(AbstractRNNGolden): class UnrolledLSTM (line 422) | class UnrolledLSTM(AbstractRNNGolden): class VanillaRNN (line 429) | class VanillaRNN(AbstractRNNGolden): class TrainableState (line 436) | class TrainableState(AbstractGolden): class BiasTest (line 443) | class BiasTest(AbstractGolden): class EmbedTest (line 450) | class EmbedTest(AbstractGolden): class MeanTest (line 457) | class MeanTest(AbstractGolden): class SumTest (line 465) | class SumTest(AbstractGolden): class ResNet (line 473) | class ResNet(AbstractGolden): method forward (line 479) | def forward(self, module, x=None): class VectorQuantizerTest (line 486) | class VectorQuantizerTest(AbstractGolden): method create_module (line 488) | def create_module(self): method forward (line 495) | def forward(self, module, x=None): class VectorQuantizerEMATrainTest (line 507) | class VectorQuantizerEMATrainTest(AbstractGolden): method create_module (line 509) | def create_module(self): method forward (line 516) | def forward(self, module, x=None): class VectorQuantizerEMAEvalTest (line 529) | class VectorQuantizerEMAEvalTest(AbstractGolden): method create_module (line 531) | def create_module(self): method forward (line 538) | def forward(self, module, x=None): class FooMetric (line 553) | class FooMetric(snt.Metric): method initialize (line 556) | def initialize(self, x): method reset (line 559) | def reset(self): method update (line 562) | def update(self, x): FILE: sonnet/src/conformance/goldens_test.py class CoverageTest (line 25) | class CoverageTest(test_utils.TestCase): method test_all_modules_covered (line 27) | def test_all_modules_covered(self): FILE: sonnet/src/conformance/keras_test.py class KerasTest (line 30) | class KerasTest(test_utils.TestCase, parameterized.TestCase): method test_build_without_batch (line 33) | def test_build_without_batch(self, module_fn, input_shape, dtype): method test_sonnet_module_as_layer (line 55) | def test_sonnet_module_as_layer(self, module_fn, input_shape, dtype): method test_build_with_updating_module (line 78) | def test_build_with_updating_module(self): method test_layer_with_model (line 89) | def test_layer_with_model(self): method test_symbolic_model (line 107) | def test_symbolic_model(self, module_fn, input_shape, dtype): method test_layer_adapter_custom_method (line 123) | def test_layer_adapter_custom_method(self): method test_keras_layer_inside_sonnet_module (line 134) | def test_keras_layer_inside_sonnet_module(self): method test_to_config (line 146) | def test_to_config(self): method test_from_config (line 151) | def test_from_config(self): class LayerAdapter (line 157) | class LayerAdapter(tf.keras.layers.Layer): method __init__ (line 177) | def __init__(self, module, method="__call__", dtype=tf.float32): method from_config (line 184) | def from_config(cls, config): method to_config (line 187) | def to_config(self): method _trace_and_initialize (line 190) | def _trace_and_initialize(self, input_shape): method compute_output_shape (line 198) | def compute_output_shape(self, input_shape): method build (line 202) | def build(self, input_shape): method call (line 216) | def call(self, inputs): class ModuleWithLayer (line 220) | class ModuleWithLayer(snt.Module): method __init__ (line 222) | def __init__(self): method __call__ (line 226) | def __call__(self, x): class ModuleWithUpdateInCall (line 230) | class ModuleWithUpdateInCall(snt.Module): method _init (line 233) | def _init(self, x): method __call__ (line 236) | def __call__(self, x): class ModuleWithCustomForward (line 242) | class ModuleWithCustomForward(snt.Module): method _init (line 245) | def _init(self, x): method forward (line 248) | def forward(self, x): FILE: sonnet/src/conformance/optimizer_test.py class OptimizerConformanceTest (line 26) | class OptimizerConformanceTest(test_utils.TestCase, parameterized.TestCa... method test_variable_order_is_constant (line 32) | def test_variable_order_is_constant(self, module_fn, input_shape, dtype, FILE: sonnet/src/conformance/pickle_test.py class PickleTest (line 26) | class PickleTest(test_utils.TestCase, parameterized.TestCase): method test_pickle (line 31) | def test_pickle(self, golden): FILE: sonnet/src/conformance/saved_model_test.py class SavedModelTest (line 28) | class SavedModelTest(test_utils.TestCase, parameterized.TestCase): method test_save_restore_cycle (line 31) | def test_save_restore_cycle(self, golden): FILE: sonnet/src/conformance/tensorflow1_test.py class TensorFlow1Test (line 23) | class TensorFlow1Test(test_utils.TestCase): method test_requires_tf2 (line 25) | def test_requires_tf2(self): FILE: sonnet/src/conformance/xla_test.py class XLATest (line 26) | class XLATest(test_utils.TestCase, parameterized.TestCase): method test_compile (line 29) | def test_compile(self, golden): method test_jit_scope (line 57) | def test_jit_scope(self, golden): FILE: sonnet/src/conv.py class ConvND (line 28) | class ConvND(base.Module): method __init__ (line 31) | def __init__(self, method __call__ (line 100) | def __call__(self, inputs: tf.Tensor) -> tf.Tensor: method _initialize (line 129) | def _initialize(self, inputs: tf.Tensor): method _make_w (line 150) | def _make_w(self): class Conv1D (line 164) | class Conv1D(ConvND): method __init__ (line 167) | def __init__(self, class Conv2D (line 218) | class Conv2D(ConvND): method __init__ (line 221) | def __init__(self, class Conv3D (line 273) | class Conv3D(ConvND): method __init__ (line 276) | def __init__(self, FILE: sonnet/src/conv_test.py function create_constant_initializers (line 25) | def create_constant_initializers(w, b, with_bias): class ConvTest (line 35) | class ConvTest(test_utils.TestCase, parameterized.TestCase): method testPaddingFunctionReached (line 37) | def testPaddingFunctionReached(self): method testIncorrectN (line 60) | def testIncorrectN(self, n): method testInitializerKeysInvalidWithoutBias (line 70) | def testInitializerKeysInvalidWithoutBias(self): method testIncorrectRankInput (line 80) | def testIncorrectRankInput(self): method testDefaultInitializers (line 90) | def testDefaultInitializers(self, dtype): method testFunction (line 122) | def testFunction(self, with_bias, padding): method testUnknownBatchSizeNHWC (line 152) | def testUnknownBatchSizeNHWC(self): method testUnknownBatchSizeNCHW (line 168) | def testUnknownBatchSizeNCHW(self): method testUnknownChannels (line 187) | def testUnknownChannels(self, autograph): method testUnknownSpatialDims (line 201) | def testUnknownSpatialDims(self): class Conv2DTest (line 222) | class Conv2DTest(test_utils.TestCase, parameterized.TestCase): method testComputationPaddingSame (line 225) | def testComputationPaddingSame(self, with_bias): method testComputationPaddingValid (line 247) | def testComputationPaddingValid(self, with_bias): class Conv1DTest (line 268) | class Conv1DTest(test_utils.TestCase, parameterized.TestCase): method testComputationPaddingSame (line 271) | def testComputationPaddingSame(self, with_bias): method testComputationPaddingValid (line 292) | def testComputationPaddingValid(self, with_bias): class Conv3DTest (line 313) | class Conv3DTest(test_utils.TestCase, parameterized.TestCase): method testComputationPaddingSame (line 316) | def testComputationPaddingSame(self, with_bias): method testComputationPaddingValid (line 344) | def testComputationPaddingValid(self, with_bias): FILE: sonnet/src/conv_transpose.py function smart_concat (line 28) | def smart_concat(v1, v2): function smart_lambda (line 35) | def smart_lambda(func, v1, v2): class ConvNDTranspose (line 42) | class ConvNDTranspose(base.Module): method __init__ (line 52) | def __init__(self, method __call__ (line 122) | def __call__(self, inputs): method _initialize (line 149) | def _initialize(self, inputs): method _make_w (line 171) | def _make_w(self): method _get_output_shape (line 185) | def _get_output_shape(self, inputs): class Conv1DTranspose (line 210) | class Conv1DTranspose(ConvNDTranspose): method __init__ (line 213) | def __init__(self, class Conv2DTranspose (line 267) | class Conv2DTranspose(ConvNDTranspose): method __init__ (line 270) | def __init__(self, class Conv3DTranspose (line 324) | class Conv3DTranspose(ConvNDTranspose): method __init__ (line 327) | def __init__(self, FILE: sonnet/src/conv_transpose_test.py function create_constant_initializers (line 27) | def create_constant_initializers(w, b, with_bias): class ConvTransposeTest (line 37) | class ConvTransposeTest(test_utils.TestCase, parameterized.TestCase): method testIncorrectN (line 40) | def testIncorrectN(self, n): method testIncorrectPadding (line 51) | def testIncorrectPadding(self): method testBiasInitNoBias (line 58) | def testBiasInitNoBias(self): method testIncorrectOutputShape (line 65) | def testIncorrectOutputShape(self): method testGraphConv (line 79) | def testGraphConv(self, with_bias, padding): method testUnknownBatchSizeNHWC (line 111) | def testUnknownBatchSizeNHWC(self): method testUnknownBatchSizeNCHW (line 127) | def testUnknownBatchSizeNCHW(self): method testUnknownShapeDims (line 146) | def testUnknownShapeDims(self): method testGivenOutputShape (line 162) | def testGivenOutputShape(self): method testUnknownChannels (line 174) | def testUnknownChannels(self, autograph): method testInitializerVariance (line 192) | def testInitializerVariance(self, num_spatial_dims, kernel_shape, class Conv2DTransposeTest (line 213) | class Conv2DTransposeTest(test_utils.TestCase, parameterized.TestCase): method testComputationPaddingSame (line 216) | def testComputationPaddingSame(self, with_bias): method testComputationPaddingValid (line 238) | def testComputationPaddingValid(self, with_bias): method testShapeDilated (line 260) | def testShapeDilated(self): class Conv1DTransposeTest (line 274) | class Conv1DTransposeTest(test_utils.TestCase, parameterized.TestCase): method testComputationPaddingSame (line 277) | def testComputationPaddingSame(self, with_bias): method testComputationPaddingValid (line 299) | def testComputationPaddingValid(self, with_bias): class Conv3DTransposeTest (line 321) | class Conv3DTransposeTest(test_utils.TestCase, parameterized.TestCase): method testComputationPaddingSame (line 324) | def testComputationPaddingSame(self, with_bias): method testComputationPaddingValid (line 348) | def testComputationPaddingValid(self, with_bias): FILE: sonnet/src/custom_getter.py function _patch_getattribute (line 28) | def _patch_getattribute(cls, new_getattribute): function _custom_getter (line 37) | def _custom_getter( function custom_variable_getter (line 102) | def custom_variable_getter( function _is_variable (line 159) | def _is_variable(x): FILE: sonnet/src/custom_getter_test.py class CustomVariableGetterTest (line 25) | class CustomVariableGetterTest(test_utils.TestCase): method testDoesNotModifyNonVariables (line 27) | def testDoesNotModifyNonVariables(self): class DoctestTest (line 44) | class DoctestTest(test_utils.TestCase): method testDoctest (line 46) | def testDoctest(self): FILE: sonnet/src/deferred.py class Deferred (line 20) | class Deferred(base.Module): method __init__ (line 47) | def __init__(self, constructor, call_methods=("__call__",), name=None): method target (line 77) | def target(self): method __call__ (line 92) | def __call__(self, *args, **kwargs): method __str__ (line 95) | def __str__(self): method __repr__ (line 98) | def __repr__(self): method __getattr__ (line 101) | def __getattr__(self, name): method __setattr__ (line 109) | def __setattr__(self, name, value): method __delattr__ (line 117) | def __delattr__(self, name): function _materialize_then_call (line 125) | def _materialize_then_call(module, method_name): FILE: sonnet/src/deferred_test.py class DeferredTest (line 23) | class DeferredTest(test_utils.TestCase): method test_target (line 25) | def test_target(self): method test_only_computes_target_once (line 30) | def test_only_computes_target_once(self): method test_attr_forwarding_fails_before_construction (line 39) | def test_attr_forwarding_fails_before_construction(self): method test_getattr (line 44) | def test_getattr(self): method test_setattr (line 49) | def test_setattr(self): method test_setattr_on_target (line 57) | def test_setattr_on_target(self): method test_delattr (line 67) | def test_delattr(self): method test_alternative_forward (line 74) | def test_alternative_forward(self): method test_alternative_forward_call_type_error (line 78) | def test_alternative_forward_call_type_error(self): method test_name_scope (line 84) | def test_name_scope(self): method test_str (line 90) | def test_str(self): method test_repr (line 95) | def test_repr(self): class ExampleModule (line 101) | class ExampleModule(base.Module): method __init__ (line 103) | def __init__(self): method __str__ (line 107) | def __str__(self): method __repr__ (line 110) | def __repr__(self): method __call__ (line 113) | def __call__(self): class AlternativeForwardModule (line 117) | class AlternativeForwardModule(base.Module): method forward (line 119) | def forward(self): FILE: sonnet/src/depthwise_conv.py class DepthwiseConv2D (line 28) | class DepthwiseConv2D(base.Module): method __init__ (line 35) | def __init__(self, method __call__ (line 97) | def __call__(self, inputs: tf.Tensor) -> tf.Tensor: method _initialize (line 113) | def _initialize(self, inputs: tf.Tensor): FILE: sonnet/src/depthwise_conv_test.py function create_constant_initializers (line 26) | def create_constant_initializers(w, b, with_bias): class DepthwiseConvTest (line 36) | class DepthwiseConvTest(test_utils.TestCase, parameterized.TestCase): method testInitializerKeysInvalidWithoutBias (line 38) | def testInitializerKeysInvalidWithoutBias(self): method testDefaultInitializers (line 48) | def testDefaultInitializers(self, dtype): method testFunction (line 69) | def testFunction(self, with_bias, padding): method testUnknownBatchSizeNHWC (line 97) | def testUnknownBatchSizeNHWC(self): method testUnknownBatchSizeNCHW (line 110) | def testUnknownBatchSizeNCHW(self): method testUnknownSpatialDims (line 125) | def testUnknownSpatialDims(self): method testUnknownChannels (line 143) | def testUnknownChannels(self, autograph): method testComputationSame (line 155) | def testComputationSame(self, with_bias): method testComputationValid (line 175) | def testComputationValid(self, with_bias): method testComputationValidMultiChannel (line 193) | def testComputationValidMultiChannel(self, with_bias): method testSharing (line 210) | def testSharing(self, with_bias): FILE: sonnet/src/distribute/distributed_batch_norm.py class CrossReplicaBatchNorm (line 28) | class CrossReplicaBatchNorm(batch_norm.BaseBatchNorm): method __init__ (line 43) | def __init__(self, method _initialize (line 91) | def _initialize(self, inputs: tf.Tensor): method _moments (line 98) | def _moments(self, inputs: tf.Tensor, FILE: sonnet/src/distribute/distributed_batch_norm_test.py class CrossReplicaBatchNormTest (line 26) | class CrossReplicaBatchNormTest(test_utils.TestCase, parameterized.TestC... method testDefaultReplicaContext (line 30) | def testDefaultReplicaContext(self): method testWithMultipleDevicesMirrored (line 41) | def testWithMultipleDevicesMirrored(self): method testWithTpuStrategy (line 80) | def testWithTpuStrategy(self): class TestMetric (line 119) | class TestMetric: method __init__ (line 121) | def __init__(self): method update (line 125) | def update(self, x): method value (line 133) | def value(self): method initialize (line 136) | def initialize(self, x): function setUpModule (line 141) | def setUpModule(): FILE: sonnet/src/distribute/replicator.py function replica_local_creator (line 27) | def replica_local_creator(next_creator, **kwargs) -> tf.Variable: class Replicator (line 38) | class Replicator(tf.distribute.MirroredStrategy): method scope (line 79) | def scope(self): class TpuReplicator (line 92) | class TpuReplicator(TPUStrategy): method scope (line 138) | def scope(self): function create_variables_eagerly (line 145) | def create_variables_eagerly(f: Callable[..., T]) -> Callable[..., T]: function _eager_variable_creator (line 181) | def _eager_variable_creator(getter, initial_value, **kwargs): function _eager_initial_values (line 199) | def _eager_initial_values(): FILE: sonnet/src/distribute/replicator_test.py function _create_variable_in_cross_replica_context (line 28) | def _create_variable_in_cross_replica_context(replicator): class TrainableVariable (line 34) | class TrainableVariable: method __call__ (line 36) | def __call__(self): function _create_variable_in_replica_context (line 42) | def _create_variable_in_replica_context(replicator): function all_variable_creators (line 56) | def all_variable_creators(): class ReplicatorTest (line 61) | class ReplicatorTest(test_utils.TestCase, parameterized.TestCase): method test_variable_synchronization_default (line 68) | def test_variable_synchronization_default(self, replicator_fn, create_... method test_variable_aggregation_default (line 78) | def test_variable_aggregation_default(self, replicator_fn, create_var): method test_variable_trainable_default (line 87) | def test_variable_trainable_default(self, replicator_fn, create_var): method test_variable_trainable (line 96) | def test_variable_trainable(self, replicator_fn, trainable): method test_assign (line 109) | def test_assign(self, replicator_fn, method_name, value, cross_replica): method test_read_value (line 130) | def test_read_value(self, replicator_fn, cross_replica): method test_falls_back_to_graph (line 152) | def test_falls_back_to_graph(self, autograph): method test_requires_eager (line 163) | def test_requires_eager(self, autograph): method test_eager_variable_creator (line 171) | def test_eager_variable_creator(self, autograph): class MyOnesInitializer (line 193) | class MyOnesInitializer(initializers.Initializer): method __call__ (line 195) | def __call__(self, shape, dtype): class FailsInEagerMode (line 200) | class FailsInEagerMode(initializers.Initializer): method __call__ (line 202) | def __call__(self, shape, dtype): function setUpModule (line 208) | def setUpModule(): FILE: sonnet/src/distribute/replicator_test_utils.py function _replicator_primary_device (line 25) | def _replicator_primary_device() -> snt_replicator.Replicator: function _tpu_replicator_or_skip_test (line 40) | def _tpu_replicator_or_skip_test() -> snt_replicator.TpuReplicator: function named_replicators (line 52) | def named_replicators() -> Sequence[Tuple[str, Callable[[], Strategy]]]: FILE: sonnet/src/dropout.py class Dropout (line 26) | class Dropout(base.Module): method __init__ (line 35) | def __init__(self, method __call__ (line 58) | def __call__(self, x: tf.Tensor, is_training: types.BoolLike) -> tf.Te... FILE: sonnet/src/dropout_test.py class DropoutTest (line 24) | class DropoutTest(test_utils.TestCase, parameterized.TestCase): method test_sum_close (line 27) | def test_sum_close(self, rate): method test_dropout_rate (line 37) | def test_dropout_rate(self, rate): method test_dropout_is_actually_random (line 49) | def test_dropout_is_actually_random(self): method test_with_tf_function_with_booleans (line 58) | def test_with_tf_function_with_booleans(self, autograph): method test_with_tf_function_with_variables (line 72) | def test_with_tf_function_with_variables(self, autograph): FILE: sonnet/src/embed.py class Embed (line 27) | class Embed(base.Module): method __init__ (line 30) | def __init__(self, method __call__ (line 97) | def __call__(self, inputs): function embedding_dim (line 105) | def embedding_dim(vocab_size: int): function dense_gradient (line 125) | def dense_gradient(x: tf.Tensor): FILE: sonnet/src/embed_test.py class EmbedTest (line 24) | class EmbedTest(test_utils.TestCase, parameterized.TestCase): method test_vocab_size (line 27) | def test_vocab_size(self, vocab_size): method test_embed_dim (line 33) | def test_embed_dim(self, embed_dim): method test_existing_vocab (line 39) | def test_existing_vocab(self, vocab_size, embed_dim): method test_densify_gradients (line 47) | def test_densify_gradients(self, densify_gradients): method test_initializer (line 57) | def test_initializer(self): method test_pinned_to_cpu (line 61) | def test_pinned_to_cpu(self): method test_trainable (line 68) | def test_trainable(self, trainable): method test_dtype (line 73) | def test_dtype(self, dtype): method test_name (line 79) | def test_name(self): FILE: sonnet/src/functional/haiku.py class TensorVariableCallbacks (line 33) | class TensorVariableCallbacks(threading.local): method __init__ (line 38) | def __init__(self): method notify (line 43) | def notify(self, variable): method __call__ (line 50) | def __call__(self, callback): function notify (line 63) | def notify(f): function defer_property (line 72) | def defer_property(name): function safe_read_tensor_value (line 76) | def safe_read_tensor_value(variable): function defer_read (line 105) | def defer_read(): function defer_raise_notimplemented (line 110) | def defer_raise_notimplemented(): function defer_indexed (line 117) | def defer_indexed(f): function defer_assign (line 121) | def defer_assign(map_fn=None): class TensorVariable (line 136) | class TensorVariable(tf.Variable): method __init__ (line 139) | def __init__(self, value, trainable, name=None): method __repr__ (line 200) | def __repr__(self): function tv_to_tensor (line 215) | def tv_to_tensor(value, dtype=None, name=None, as_ref=None): function create_tensor_variables (line 227) | def create_tensor_variables(): function track_tensor_variables (line 253) | def track_tensor_variables(): function track_new_variables (line 260) | def track_new_variables(): function track_initial_state (line 272) | def track_initial_state(): function initial_value_by_ref (line 283) | def initial_value_by_ref(tf_variables): function final_value_by_ref (line 288) | def final_value_by_ref(tf_variables): function transform (line 293) | def transform(f) -> Transformed: function transform_with_state (line 349) | def transform_with_state(f) -> TransformedWithState: function without_state (line 444) | def without_state(with_state: TransformedWithState) -> Transformed: FILE: sonnet/src/functional/haiku_test.py class TensorVariableTest (line 25) | class TensorVariableTest(test_utils.TestCase, parameterized.TestCase): method test_initial_value (line 27) | def test_initial_value(self): method test_trainable (line 36) | def test_trainable(self, trainable): method test_name (line 44) | def test_name(self): method test_name_with_scope (line 49) | def test_name_with_scope(self): method test_shape (line 55) | def test_shape(self, shape): method test_dtype (line 61) | def test_dtype(self, dtype): method test_attributes_do_not_notify (line 66) | def test_attributes_do_not_notify(self): method test_read_captured_variables_included (line 88) | def test_read_captured_variables_included(self): method test_captured_variable_from_other_function_raises (line 99) | def test_captured_variable_from_other_function_raises(self): method test_assign (line 116) | def test_assign(self): method test_assign_add (line 124) | def test_assign_add(self): method test_assign_sub (line 132) | def test_assign_sub(self): class NetworkTest (line 141) | class NetworkTest(test_utils.TestCase, parameterized.TestCase): method test_transform (line 143) | def test_transform(self): method test_initial_values_preserved (line 163) | def test_initial_values_preserved(self): method test_variables_in_transform_set_to_none (line 180) | def test_variables_in_transform_set_to_none(self): method test_disallows_variables_in_apply (line 192) | def test_disallows_variables_in_apply(self): method test_state_returns_initial_value (line 198) | def test_state_returns_initial_value(self): method test_state_counter (line 213) | def test_state_counter(self): method test_state_ema (line 225) | def test_state_ema(self): FILE: sonnet/src/functional/jax.py function device_put (line 24) | def device_put(t, device=None): function device_get (line 28) | def device_get(t): function jit (line 33) | def jit(f, device=None): function grad (line 40) | def grad(f, argnums=0, has_aux=False): function value_and_grad (line 54) | def value_and_grad(f, argnums=0, has_aux=False): FILE: sonnet/src/functional/jax_test.py class JaxTest (line 23) | class JaxTest(test_utils.TestCase, parameterized.TestCase): method test_jit_copies_to_device (line 25) | def test_jit_copies_to_device(self): method test_device_put (line 39) | def test_device_put(self): class GradTest (line 52) | class GradTest(test_utils.TestCase, parameterized.TestCase): method test_grad (line 54) | def test_grad(self): method test_argnums (line 60) | def test_argnums(self): method test_has_aux (line 69) | def test_has_aux(self): function get_accelerators (line 78) | def get_accelerators(): FILE: sonnet/src/functional/optimizers.py function optimizer (line 30) | def optimizer(cls: Type[base.Optimizer]) -> Callable[..., TransformedOpt... function _split_on_trainable (line 102) | def _split_on_trainable(opt_state): function _merge (line 113) | def _merge(a, b): function _wrap_optimizer (line 120) | def _wrap_optimizer(opt: base.Optimizer) -> TransformedOptimizer: FILE: sonnet/src/functional/optimizers_test.py class OptimizersTest (line 29) | class OptimizersTest(test_utils.TestCase, parameterized.TestCase): method test_sgd (line 31) | def test_sgd(self): method test_adam (line 43) | def test_adam(self): method test_adam_with_variable_lr (line 54) | def test_adam_with_variable_lr(self, trainable_lr): FILE: sonnet/src/functional/utils.py function get_first_accelerator (line 24) | def get_first_accelerator(): function run_on_device (line 33) | def run_on_device(f, device): function get_name_scope (line 47) | def get_name_scope(): function first_non_none (line 52) | def first_non_none(*args): function compose (line 56) | def compose(f0, *fs): FILE: sonnet/src/group_norm.py class GroupNorm (line 27) | class GroupNorm(base.Module): method __init__ (line 66) | def __init__(self, method __call__ (line 132) | def __call__(self, method _initialize (line 188) | def _initialize(self, inputs: tf.Tensor): FILE: sonnet/src/group_norm_test.py class GroupNormTest (line 25) | class GroupNormTest(test_utils.TestCase, parameterized.TestCase): method testSimpleCase (line 27) | def testSimpleCase(self): method testSimpleCaseVar (line 36) | def testSimpleCaseVar(self): method testSimpleCaseNCHWVar (line 50) | def testSimpleCaseNCHWVar(self): method testDataFormatAgnosticVar (line 65) | def testDataFormatAgnosticVar(self): method testSimpleCaseTensor (line 80) | def testSimpleCaseTensor(self): method testSimpleCaseNCHWTensor (line 92) | def testSimpleCaseNCHWTensor(self): method testDataFormatAgnosticTensor (line 104) | def testDataFormatAgnosticTensor(self): method testInvalidDataFormat (line 124) | def testInvalidDataFormat(self, data_format): method testValidDataFormatChannelsFirst (line 135) | def testValidDataFormatChannelsFirst(self, data_format): method testValidDataFormatChannelsLast (line 145) | def testValidDataFormatChannelsLast(self, data_format): method testInvalidAxis (line 155) | def testInvalidAxis(self, axis): method testNoScaleAndInitProvided (line 161) | def testNoScaleAndInitProvided(self): method testNoOffsetBetaInitProvided (line 170) | def testNoOffsetBetaInitProvided(self): method testCreateScaleAndScaleProvided (line 179) | def testCreateScaleAndScaleProvided(self): method testCreateOffsetAndOffsetProvided (line 187) | def testCreateOffsetAndOffsetProvided(self): method testSliceAxis (line 196) | def testSliceAxis(self): method testRankChanges (line 211) | def testRankChanges(self): method testIncompatibleGroupsAndTensor (line 227) | def testIncompatibleGroupsAndTensor(self, shape): method testWorksWithFunction (line 238) | def testWorksWithFunction(self): method testBatchSizeAgnostic (line 252) | def testBatchSizeAgnostic(self): method test5DDataFormatAgnostic (line 276) | def test5DDataFormatAgnostic(self): method test3DDataFormatAgnostic (line 296) | def test3DDataFormatAgnostic(self): FILE: sonnet/src/initializers.py class Initializer (line 25) | class Initializer(abc.ABC): method __call__ (line 29) | def __call__(self, shape: types.ShapeLike, dtype: tf.DType) -> tf.Tensor: class Zeros (line 34) | class Zeros(Initializer): method __call__ (line 37) | def __call__(self, shape: types.ShapeLike, dtype: tf.DType) -> tf.Tensor: class Ones (line 42) | class Ones(Initializer): method __call__ (line 45) | def __call__(self, shape: types.ShapeLike, dtype: tf.DType) -> tf.Tensor: class Constant (line 50) | class Constant(Initializer): method __init__ (line 53) | def __init__(self, value: Union[float, int]): method __call__ (line 59) | def __call__(self, shape: types.ShapeLike, dtype: tf.DType) -> tf.Tensor: class RandomUniform (line 65) | class RandomUniform(Initializer): method __init__ (line 72) | def __init__(self, method __call__ (line 89) | def __call__(self, shape: types.ShapeLike, dtype: tf.DType): class RandomNormal (line 99) | class RandomNormal(Initializer): method __init__ (line 102) | def __init__(self, method __call__ (line 119) | def __call__(self, shape: types.ShapeLike, dtype: tf.DType) -> tf.Tensor: class TruncatedNormal (line 129) | class TruncatedNormal(Initializer): method __init__ (line 137) | def __init__(self, method __call__ (line 154) | def __call__(self, shape: types.ShapeLike, dtype: tf.DType): class Identity (line 164) | class Identity(Initializer): method __init__ (line 170) | def __init__(self, gain: float = 1.0): method __call__ (line 178) | def __call__(self, shape: types.ShapeLike, dtype: tf.DType) -> tf.Tensor: class Orthogonal (line 195) | class Orthogonal(Initializer): method __init__ (line 214) | def __init__(self, gain: float = 1.0, seed: Optional[int] = None): method __call__ (line 224) | def __call__(self, shape: types.ShapeLike, dtype: tf.DType) -> tf.Tensor: class VarianceScaling (line 252) | class VarianceScaling(Initializer): method __init__ (line 287) | def __init__(self, method __call__ (line 317) | def __call__(self, shape: types.ShapeLike, dtype: tf.DType) -> tf.Tensor: function check_initializers (line 345) | def check_initializers(initializers: Mapping[str, Initializer], function _compute_fans (line 364) | def _compute_fans(shape: types.ShapeLike): function _as_floating_dtype (line 391) | def _as_floating_dtype(dtype: tf.DType) -> tf.DType: function _as_numerical_dtype (line 398) | def _as_numerical_dtype(dtype: tf.DType) -> tf.DType: FILE: sonnet/src/initializers_test.py class InitializersTest (line 26) | class InitializersTest(test_utils.TestCase, parameterized.TestCase): method assertDifferentInitializerValues (line 28) | def assertDifferentInitializerValues(self, method assertRange (line 40) | def assertRange(self, class ConstantInitializersTest (line 61) | class ConstantInitializersTest(InitializersTest): method testZeros (line 64) | def testZeros(self, dtype): method testOnes (line 73) | def testOnes(self, dtype): method testConstantInvalidValue (line 85) | def testConstantInvalidValue(self, value, value_type): method testConstantValidValue (line 92) | def testConstantValidValue(self, value, dtype): method testInvalidDataType (line 101) | def testInvalidDataType(self, initializer): method testInvalidDataTypeConstant (line 107) | def testInvalidDataTypeConstant(self): method testTFFunction (line 113) | def testTFFunction(self): method testBatchAgnostic (line 121) | def testBatchAgnostic(self): class RandomUniformInitializerTest (line 132) | class RandomUniformInitializerTest(InitializersTest): method testRangeInitializer (line 134) | def testRangeInitializer(self): method testDifferentInitializer (line 144) | def testDifferentInitializer(self, dtype): method testInvalidDataType (line 148) | def testInvalidDataType(self): method testTFFunction (line 154) | def testTFFunction(self): method testBatchAgnostic (line 164) | def testBatchAgnostic(self): class RandomNormalInitializerTest (line 177) | class RandomNormalInitializerTest(InitializersTest): method testRangeInitializer (line 179) | def testRangeInitializer(self): method testDifferentInitializer (line 186) | def testDifferentInitializer(self): method testInvalidDataType (line 191) | def testInvalidDataType(self, dtype): method testTFFunction (line 197) | def testTFFunction(self): method testBatchAgnostic (line 207) | def testBatchAgnostic(self): class TruncatedNormalInitializerTest (line 220) | class TruncatedNormalInitializerTest(InitializersTest): method testRangeInitializer (line 222) | def testRangeInitializer(self): method testDifferentInitializer (line 230) | def testDifferentInitializer(self): method testInvalidDataType (line 235) | def testInvalidDataType(self, dtype): method testTFFunction (line 241) | def testTFFunction(self): method testBatchAgnostic (line 251) | def testBatchAgnostic(self): class IdentityInitializerTest (line 264) | class IdentityInitializerTest(InitializersTest): method testRange (line 269) | def testRange(self, shape, gain, dtype): method testInvalidDataType (line 280) | def testInvalidDataType(self): method testInvalidShape (line 287) | def testInvalidShape(self, dtype): method testTFFunction (line 294) | def testTFFunction(self): method testTFFunction4D (line 302) | def testTFFunction4D(self): method testBatchAgnostic (line 310) | def testBatchAgnostic(self): class OrthogonalInitializerTest (line 321) | class OrthogonalInitializerTest(InitializersTest): method testRangeInitializer (line 323) | def testRangeInitializer(self): method testDuplicatedInitializer (line 327) | def testDuplicatedInitializer(self): method testInvalidDataType (line 332) | def testInvalidDataType(self, dtype): method testInvalidShape (line 338) | def testInvalidShape(self): method testShapesValues (line 349) | def testShapesValues(self, shape): method testTFFunctionSimple (line 364) | def testTFFunctionSimple(self): method testTFFunction (line 371) | def testTFFunction(self): method testBatchAgnostic (line 382) | def testBatchAgnostic(self): class VarianceScalingInitializerTest (line 396) | class VarianceScalingInitializerTest(InitializersTest): method testTruncatedNormalDistribution (line 398) | def testTruncatedNormalDistribution(self): method testNormalDistribution (line 405) | def testNormalDistribution(self): method testUniformDistribution (line 412) | def testUniformDistribution(self): method testGlorotUniform (line 419) | def testGlorotUniform(self): method test_GlorotNormal (line 430) | def test_GlorotNormal(self): method testLecunUniform (line 444) | def testLecunUniform(self): method testLecunNormal (line 455) | def testLecunNormal(self): method testHeUniform (line 467) | def testHeUniform(self): method testHeNormal (line 478) | def testHeNormal(self): method testMixedShape (line 493) | def testMixedShape(self, mode, distribution): method testWithTFFunction (line 506) | def testWithTFFunction(self, mode, distribution): method testBatchAgnostic (line 519) | def testBatchAgnostic(self, mode, distribution): method testInvalidDataType (line 533) | def testInvalidDataType(self, dtype): method testCheckInitializersInvalidType (line 539) | def testCheckInitializersInvalidType(self): method testCheckInitalizersEmpty (line 544) | def testCheckInitalizersEmpty(self): method testCheckInitalizersValid (line 550) | def testCheckInitalizersValid(self, keys): method testCheckInitalizersInvalid (line 556) | def testCheckInitalizersInvalid(self): FILE: sonnet/src/leaky_clip_by_value.py function leaky_clip_by_value (line 23) | def leaky_clip_by_value(t: tf.Tensor, FILE: sonnet/src/leaky_clip_by_value_test.py class LeakyClipByValueTest (line 23) | class LeakyClipByValueTest(test_utils.TestCase, parameterized.TestCase): method test_leaky_clip_by_value_forward (line 25) | def test_leaky_clip_by_value_forward(self): method test_leaky_clip_by_value_backward (line 46) | def test_leaky_clip_by_value_backward(self, init, fn, expected_grad): FILE: sonnet/src/linear.py class Linear (line 27) | class Linear(base.Module): method __init__ (line 30) | def __init__(self, method _initialize (line 59) | def _initialize(self, inputs: tf.Tensor): method __call__ (line 82) | def __call__(self, inputs: tf.Tensor) -> tf.Tensor: FILE: sonnet/src/linear_test.py class LinearTest (line 24) | class LinearTest(test_utils.TestCase, parameterized.TestCase): method testInitW (line 26) | def testInitW(self): method testInitB (line 31) | def testInitB(self): method testInitializerKeysInvalidWithoutBias (line 36) | def testInitializerKeysInvalidWithoutBias(self): method testParametersCreatedOnce (line 40) | def testParametersCreatedOnce(self): method testParameterShape (line 48) | def testParameterShape(self): method testParameterDtype (line 58) | def testParameterDtype(self, dtype): method testBiasZeroInitialized (line 70) | def testBiasZeroInitialized(self): method testCall (line 75) | def testCall(self): method testCallMultiBatch (line 95) | def testCallMultiBatch(self): method testFunction (line 109) | def testFunction(self, with_bias): method testUnknownBatchSize (line 125) | def testUnknownBatchSize(self): method testUnknownInputSize (line 141) | def testUnknownInputSize(self): method testMultiBatchOutputDimensions (line 151) | def testMultiBatchOutputDimensions(self): method testIncorrectDims (line 168) | def testIncorrectDims(self, shape): method testInputSize (line 173) | def testInputSize(self): method testOutputSize (line 181) | def testOutputSize(self): FILE: sonnet/src/metrics.py class Metric (line 25) | class Metric(base.Module, metaclass=abc.ABCMeta): method initialize (line 29) | def initialize(self, value): method update (line 33) | def update(self, value): method value (line 37) | def value(self): method reset (line 41) | def reset(self): method __call__ (line 44) | def __call__(self, value): class Sum (line 50) | class Sum(Metric): method __init__ (line 53) | def __init__(self, name: Optional[str] = None): method initialize (line 58) | def initialize(self, value: tf.Tensor): method update (line 62) | def update(self, value: tf.Tensor): method _checked_sum (line 68) | def _checked_sum(self): method value (line 74) | def value(self) -> tf.Tensor: method reset (line 78) | def reset(self): class Mean (line 85) | class Mean(Metric): method __init__ (line 88) | def __init__(self, name: Optional[str] = None): method initialize (line 94) | def initialize(self, value: tf.Tensor): method update (line 98) | def update(self, value: tf.Tensor): method _checked_sum (line 105) | def _checked_sum(self) -> tf.Variable: method value (line 111) | def value(self) -> tf.Tensor: method reset (line 118) | def reset(self): FILE: sonnet/src/metrics_test.py class SumTest (line 22) | class SumTest(test_utils.TestCase): method testSimple (line 24) | def testSimple(self): method testInitialize (line 29) | def testInitialize(self): method testReset (line 34) | def testReset(self): class MeanTest (line 42) | class MeanTest(test_utils.TestCase): method testSimple (line 44) | def testSimple(self): method testInitialize (line 49) | def testInitialize(self): method testReset (line 54) | def testReset(self): FILE: sonnet/src/mixed_precision.py function enable (line 30) | def enable(dtype): function disable (line 40) | def disable(): function _get_mixed_precision_mode (line 45) | def _get_mixed_precision_mode(): function _maybe_cast_element (line 50) | def _maybe_cast_element(x, dtype): function _maybe_cast_structure (line 56) | def _maybe_cast_structure(x, dtype: tf.DType): function _cast_call (line 60) | def _cast_call(f, new_dtype, args, kwargs): function modes (line 73) | def modes(valid_types): function scope (line 141) | def scope(dtype: tf.DType): FILE: sonnet/src/mixed_precision_test.py class DummyVar (line 25) | class DummyVar(base.Module, test_utils.TestCase): method __init__ (line 27) | def __init__(self, x): method check_type (line 32) | def check_type(self, _, dtype): method check_type_structure (line 38) | def check_type_structure(self, _, dtype): method runTest (line 43) | def runTest(self): class DummyInput (line 47) | class DummyInput(test_utils.TestCase): method __init__ (line 49) | def __init__(self, _): method check_type (line 53) | def check_type(self, x, dtype): method check_type_structure (line 57) | def check_type_structure(self, x, dtype): method runTest (line 61) | def runTest(self): class MixedPrecisionClassTest (line 66) | class MixedPrecisionClassTest(test_utils.TestCase): method test_float16_mode_variable_eligible_class (line 68) | def test_float16_mode_variable_eligible_class(self, test_class): method test_float16_mode_disable_class (line 81) | def test_float16_mode_disable_class(self, test_class): method test_float16_mode_nested_eligible_class (line 94) | def test_float16_mode_nested_eligible_class(self, test_class): method test_float16_mode_eligible_multiple_instances_class (line 117) | def test_float16_mode_eligible_multiple_instances_class(self, test_cla... method test_float16_mode_ineligible_multiple_instances_class (line 134) | def test_float16_mode_ineligible_multiple_instances_class(self, test_c... method test_float16_mode_multiple_instances_different_eligibility_class (line 152) | def test_float16_mode_multiple_instances_different_eligibility_class( method test_bfloat16_input_float16_mode_eligible_class (line 171) | def test_bfloat16_input_float16_mode_eligible_class(self, test_class): method test_float16_input_float32_mode_eligible_class (line 182) | def test_float16_input_float32_mode_eligible_class(self, test_class): method test_function_create_module_eligible (line 195) | def test_function_create_module_eligible(self, test_class): method test_function_create_module_ineligible (line 210) | def test_function_create_module_ineligible(self, test_class): method test_function_create_module_not_decorated (line 225) | def test_function_create_module_not_decorated(self, test_class): method test_scoping_option (line 237) | def test_scoping_option(self, test_class): method test_scoping_disable (line 250) | def test_scoping_disable(self, test_class): method test_nested_scoping (line 264) | def test_nested_scoping(self, test_class): class MixedPrecisionTest (line 282) | class MixedPrecisionTest(test_utils.TestCase): method test_float16_mode_eligible_func (line 284) | def test_float16_mode_eligible_func(self): method test_float32_mode_eligible_func (line 300) | def test_float32_mode_eligible_func(self): method test_float16_mode_ineligible_func (line 314) | def test_float16_mode_ineligible_func(self): method test_dont_cast_non_floats_func (line 329) | def test_dont_cast_non_floats_func(self): method test_non_tensor_variable_input_no_cast_func (line 344) | def test_non_tensor_variable_input_no_cast_func(self): method test_float16_mode_enabled_call_function (line 359) | def test_float16_mode_enabled_call_function(self): method test_float16_mode_tensor_eligible_class (line 389) | def test_float16_mode_tensor_eligible_class(self): FILE: sonnet/src/moving_averages.py class ExponentialMovingAverage (line 25) | class ExponentialMovingAverage(metrics.Metric): method __init__ (line 50) | def __init__(self, decay: types.FloatLike, name: Optional[str] = None): method update (line 67) | def update(self, value: tf.Tensor): method value (line 78) | def value(self) -> tf.Tensor: method reset (line 82) | def reset(self): method initialize (line 91) | def initialize(self, value: tf.Tensor): FILE: sonnet/src/moving_averages_test.py class ExponentialMovingAverageTest (line 23) | class ExponentialMovingAverageTest(test_utils.TestCase, parameterized.Te... method testCall (line 25) | def testCall(self): method testUpdateAndValue (line 31) | def testUpdateAndValue(self): method testReset (line 39) | def testReset(self): method testResetVector (line 49) | def testResetVector(self): method testValueEqualsLatestUpdate (line 58) | def testValueEqualsLatestUpdate(self): method testWithTFFunction (line 68) | def testWithTFFunction(self, autograph): method testResetWithTFFunction (line 79) | def testResetWithTFFunction(self, autograph): method testAlternativeShape (line 90) | def testAlternativeShape(self, shape): FILE: sonnet/src/nets/cifar10_convnet.py class Cifar10ConvNet (line 28) | class Cifar10ConvNet(base.Module): method __init__ (line 36) | def __init__(self, method __call__ (line 91) | def __call__( FILE: sonnet/src/nets/cifar10_convnet_test.py class ModelTest (line 24) | class ModelTest(parameterized.TestCase, test_utils.TestCase): method testModelCreation (line 26) | def testModelCreation(self): method testFailedModelCreation (line 31) | def testFailedModelCreation(self): method testModelForwards (line 39) | def testModelForwards(self, batch_size): method testModelForwardsFunction (line 53) | def testModelForwardsFunction(self, batch_size): method testDifferentSizedImages (line 66) | def testDifferentSizedImages(self): method testDefunBackProp (line 80) | def testDefunBackProp(self): FILE: sonnet/src/nets/dnc/control.py function get_controller_ctor (line 29) | def get_controller_ctor(controller_name): class FeedForward (line 40) | class FeedForward(recurrent.RNNCore): method __init__ (line 51) | def __init__(self, method __call__ (line 69) | def __call__(self, inputs, prev_state): method initial_state (line 86) | def initial_state(self, batch_size): function deep_core (line 90) | def deep_core(control_name, FILE: sonnet/src/nets/dnc/control_test.py class CoreTest (line 26) | class CoreTest(test_utils.TestCase, parameterized.TestCase): method testShape (line 30) | def testShape(self, constructor): class FeedForwardTest (line 44) | class FeedForwardTest(test_utils.TestCase): method testShape (line 46) | def testShape(self): method testValues (line 60) | def testValues(self): class DeepCore (line 79) | class DeepCore(test_utils.TestCase, parameterized.TestCase): method testShape (line 94) | def testShape(self, control_name, num_layers): FILE: sonnet/src/nets/dnc/read.py function read (line 20) | def read(memory, FILE: sonnet/src/nets/dnc/read_test.py class ReadTest (line 23) | class ReadTest(test_utils.TestCase): method testShape (line 25) | def testShape(self): method testValues (line 37) | def testValues(self): FILE: sonnet/src/nets/dnc/util.py function segment_dim (line 22) | def segment_dim(inputs, dim, shapes): function batch_invert_permutation (line 81) | def batch_invert_permutation(permutations): function batch_gather (line 90) | def batch_gather(values, indices): function one_hot (line 97) | def one_hot(length, index): function apply_linear (line 104) | def apply_linear(inputs, linear_modules, activation=tf.identity): function apply_split_linear (line 132) | def apply_split_linear(lin_module_1, FILE: sonnet/src/nets/dnc/util_test.py class SegmentDimTest (line 26) | class SegmentDimTest(test_utils.TestCase, parameterized.TestCase): method testShape (line 30) | def testShape(self, initial_shape, final_shape): method testShapeNegative (line 50) | def testShapeNegative(self, initial_shape, final_shape): method testValues (line 68) | def testValues(self): method testInvalidDims (line 77) | def testInvalidDims(self): class BatchInvertPermutationTest (line 84) | class BatchInvertPermutationTest(test_utils.TestCase): method testCorrectOutput (line 86) | def testCorrectOutput(self): class BatchGatherTest (line 104) | class BatchGatherTest(test_utils.TestCase): method testCorrectOutput (line 106) | def testCorrectOutput(self): class LinearTest (line 114) | class LinearTest(test_utils.TestCase, parameterized.TestCase): method testLinearOutputOneModule (line 116) | def testLinearOutputOneModule(self): method testLinearOutputTwoModules (line 128) | def testLinearOutputTwoModules(self): method testDifferentOutputSizeBreaks (line 144) | def testDifferentOutputSizeBreaks(self): method testNonMatchingStructureBreaks (line 168) | def testNonMatchingStructureBreaks(self, input_sizes, module_hidden_si... method testListMustBeLengthTwo (line 188) | def testListMustBeLengthTwo(self, input_sizes, module_hidden_sizes): FILE: sonnet/src/nets/dnc/write.py function additive_write (line 20) | def additive_write(memory, address, values): function erase (line 38) | def erase(memory, address, reset_weights): function erase_rows (line 65) | def erase_rows(memory, address, reset_row_weights): function erase_and_write (line 89) | def erase_and_write(memory, address, reset_weights, values): FILE: sonnet/src/nets/dnc/write_test.py class EraseRowsTest (line 23) | class EraseRowsTest(test_utils.TestCase): method testShape (line 25) | def testShape(self): method testValues (line 38) | def testValues(self): class EraseTest (line 79) | class EraseTest(test_utils.TestCase): method testShape (line 81) | def testShape(self): method testValues (line 94) | def testValues(self): class EraseAndWriteTest (line 132) | class EraseAndWriteTest(test_utils.TestCase): method testShape (line 134) | def testShape(self): method testValues (line 148) | def testValues(self): class AdditiveWriteTest (line 173) | class AdditiveWriteTest(test_utils.TestCase): method testShape (line 175) | def testShape(self): method testValues (line 188) | def testValues(self): FILE: sonnet/src/nets/mlp.py class MLP (line 25) | class MLP(base.Module): method __init__ (line 28) | def __init__(self, method __call__ (line 75) | def __call__(self, inputs: tf.Tensor, is_training=None) -> tf.Tensor: method reverse (line 106) | def reverse(self, FILE: sonnet/src/nets/mlp_test.py class MLPTest (line 24) | class MLPTest(test_utils.TestCase, parameterized.TestCase): method test_b_init_when_with_bias_false (line 26) | def test_b_init_when_with_bias_false(self): method test_submodules (line 31) | def test_submodules(self, num_layers, dropout_rate): method test_applies_activation (line 36) | def test_applies_activation(self, num_layers): method test_activate_final (line 43) | def test_activate_final(self, num_layers): method test_adds_index_to_layer_names (line 50) | def test_adds_index_to_layer_names(self, num_layers): method test_passes_with_bias_to_layers (line 56) | def test_passes_with_bias_to_layers(self, with_bias): method test_repeat_initializer (line 61) | def test_repeat_initializer(self): method test_default_name (line 69) | def test_default_name(self): method test_custom_name (line 73) | def test_custom_name(self): method test_reverse_default_name (line 77) | def test_reverse_default_name(self): method test_reverse_custom_name (line 81) | def test_reverse_custom_name(self): method test_reverse_override_name (line 85) | def test_reverse_override_name(self): method test_reverse (line 91) | def test_reverse(self): method test_reverse_passed_with_bias (line 96) | def test_reverse_passed_with_bias(self, with_bias): method test_reverse_w_init (line 101) | def test_reverse_w_init(self): method test_reverse_b_init (line 107) | def test_reverse_b_init(self): method test_reverse_activation (line 113) | def test_reverse_activation(self): method test_dropout_requires_is_training (line 120) | def test_dropout_requires_is_training(self): method test_no_dropout_rejects_is_training (line 126) | def test_no_dropout_rejects_is_training(self, is_training): method test_reverse_activate_final (line 132) | def test_reverse_activate_final(self, activate_final): method test_applies_activation_with_dropout (line 140) | def test_applies_activation_with_dropout(self, use_dropout, is_training): method test_repr (line 148) | def test_repr(self): function reversed_mlp (line 156) | def reversed_mlp(**kwargs): class CountingActivation (line 162) | class CountingActivation: method __init__ (line 164) | def __init__(self): method __call__ (line 167) | def __call__(self, x): class CountingInitializer (line 172) | class CountingInitializer: method __init__ (line 174) | def __init__(self): method __call__ (line 177) | def __call__(self, shape, dtype=tf.float32): FILE: sonnet/src/nets/resnet.py class BottleNeckBlockV1 (line 28) | class BottleNeckBlockV1(base.Module): method __init__ (line 31) | def __init__(self, method __call__ (line 91) | def __call__(self, inputs, is_training): class BottleNeckBlockV2 (line 107) | class BottleNeckBlockV2(base.Module): method __init__ (line 110) | def __init__(self, method __call__ (line 166) | def __call__(self, inputs, is_training): class BlockGroup (line 182) | class BlockGroup(base.Module): method __init__ (line 185) | def __init__(self, method __call__ (line 213) | def __call__(self, inputs, is_training): class ResNet (line 220) | class ResNet(base.Module): method __init__ (line 223) | def __init__(self, method __call__ (line 301) | def __call__(self, inputs, is_training): class ResNet50 (line 321) | class ResNet50(ResNet): method __init__ (line 324) | def __init__(self, FILE: sonnet/src/nets/resnet_test.py class ResnetTest (line 23) | class ResnetTest(test_utils.TestCase, parameterized.TestCase): method test_simple (line 26) | def test_simple(self, resnet_v2): method test_tf_function (line 35) | def test_tf_function(self, resnet_v2): method test_error_incorrect_args_block_list (line 50) | def test_error_incorrect_args_block_list(self, list_length): method test_error_incorrect_args_channel_list (line 58) | def test_error_incorrect_args_channel_list(self, list_length): FILE: sonnet/src/nets/vqvae.py class VectorQuantizer (line 25) | class VectorQuantizer(base.Module): method __init__ (line 50) | def __init__(self, method __call__ (line 77) | def __call__(self, inputs, is_training): method quantize (line 134) | def quantize(self, encoding_indices): class VectorQuantizerEMA (line 142) | class VectorQuantizerEMA(base.Module): method __init__ (line 176) | def __init__(self, method __call__ (line 221) | def __call__(self, inputs, is_training): method quantize (line 297) | def quantize(self, encoding_indices): FILE: sonnet/src/nets/vqvae_test.py class VqvaeTest (line 26) | class VqvaeTest(parameterized.TestCase, test_utils.TestCase): method testConstruct (line 38) | def testConstruct(self, constructor, kwargs): method testShapeChecking (line 84) | def testShapeChecking(self, constructor, kwargs): method testNoneBatch (line 102) | def testNoneBatch(self, constructor, kwargs): method testEmaUpdating (line 112) | def testEmaUpdating(self, use_tf_function, dtype): method testEmbeddingsNotTrainable (line 147) | def testEmbeddingsNotTrainable(self): FILE: sonnet/src/once.py function _check_no_output (line 24) | def _check_no_output(output): function once (line 29) | def once(f): FILE: sonnet/src/once_test.py class OnceTest (line 24) | class OnceTest(parameterized.TestCase): method test_runs_once (line 26) | def test_runs_once(self): method test_always_returns_none (line 38) | def test_always_returns_none(self): method test_does_not_cache_on_error (line 43) | def test_does_not_cache_on_error(self): method test_method (line 54) | def test_method(self): method test_method_does_not_cache_on_error (line 64) | def test_method_does_not_cache_on_error(self): method test_pickle_method_before_evaluation (line 78) | def test_pickle_method_before_evaluation(self): method test_pickle_method_already_evaluated (line 88) | def test_pickle_method_already_evaluated(self): method test_inline (line 97) | def test_inline(self): method test_adds_property (line 109) | def test_adds_property(self, factory): function nop (line 114) | def nop(): class NoOpCallable (line 118) | class NoOpCallable: method nop (line 120) | def nop(self): method __call__ (line 123) | def __call__(self): class Counter (line 127) | class Counter: method increment (line 131) | def increment(self): FILE: sonnet/src/optimizers/adam.py function adam_update (line 27) | def adam_update(g, alpha, beta_1, beta_2, epsilon, t, m, v): class Adam (line 37) | class Adam(base.Optimizer): method __init__ (line 56) | def __init__(self, method _initialize (line 82) | def _initialize(self, parameters: Sequence[tf.Variable]): method apply (line 90) | def apply(self, updates: Sequence[types.ParameterUpdate], FILE: sonnet/src/optimizers/adam_test.py class ComparisonTest (line 29) | class ComparisonTest(optimizer_tests.AbstractFuzzTest): method _make_tf (line 32) | def _make_tf(self, learning_rate, beta_1, beta_2, epsilon): method _make_snt (line 39) | def _make_snt(self, learning_rate, beta_1, beta_2, epsilon): method testComparingSonnetAndTensorFlow (line 47) | def testComparingSonnetAndTensorFlow(self, config): class AdamTest (line 52) | class AdamTest(optimizer_tests.OptimizerTestBase): method make_optimizer (line 54) | def make_optimizer(self, **kwargs): method testDense (line 59) | def testDense(self): method testSparse (line 76) | def testSparse(self): method testVariableHyperParams (line 117) | def testVariableHyperParams(self): method testHyperParamDTypeConversion (line 132) | def testHyperParamDTypeConversion(self): method testAuxVariablesColocatedWithOriginal (line 147) | def testAuxVariablesColocatedWithOriginal(self): class ReferenceAdamTest (line 156) | class ReferenceAdamTest(optimizer_tests.OptimizerTestBase): method make_optimizer (line 158) | def make_optimizer(self, **kwargs): FILE: sonnet/src/optimizers/momentum.py function momentum_update (line 27) | def momentum_update(update, learning_rate, mu, momentum, use_nesterov): class Momentum (line 37) | class Momentum(base.Optimizer): method __init__ (line 47) | def __init__(self, method _initialize (line 67) | def _initialize(self, parameters): method apply (line 72) | def apply(self, updates: Sequence[types.ParameterUpdate], FILE: sonnet/src/optimizers/momentum_test.py class ComparisonTest (line 28) | class ComparisonTest(optimizer_tests.AbstractFuzzTest): method _make_tf (line 31) | def _make_tf(self, learning_rate, momentum, use_nesterov): method _make_snt (line 37) | def _make_snt(self, learning_rate, momentum, use_nesterov): method testComparingSonnetAndTensorFlow (line 44) | def testComparingSonnetAndTensorFlow(self, config): class MomentumTest (line 49) | class MomentumTest(optimizer_tests.OptimizerTestBase): method make_optimizer (line 51) | def make_optimizer(self, **kwargs): method testDense (line 58) | def testDense(self): method testDenseNesterov (line 75) | def testDenseNesterov(self): method testSparse (line 93) | def testSparse(self): method testSparseNesterov (line 121) | def testSparseNesterov(self): method testVariableHyperParams (line 150) | def testVariableHyperParams(self): method testHyperParamDTypeConversion (line 171) | def testHyperParamDTypeConversion(self): method testAuxVariablesColocatedWithOriginal (line 183) | def testAuxVariablesColocatedWithOriginal(self): class ReferenceMomentumTest (line 194) | class ReferenceMomentumTest(MomentumTest): method make_optimizer (line 196) | def make_optimizer(self, **kwargs): FILE: sonnet/src/optimizers/optimizer_tests.py class WrappedTFOptimizer (line 27) | class WrappedTFOptimizer(base.Optimizer): method __init__ (line 32) | def __init__(self, optimizer: tf.optimizers.Optimizer): method __getattr__ (line 36) | def __getattr__(self, name): method apply (line 39) | def apply(self, updates, params): function is_tf_optimizer (line 43) | def is_tf_optimizer(optimizer): class OptimizerTestBase (line 47) | class OptimizerTestBase(test_utils.TestCase): method make_optimizer (line 50) | def make_optimizer(self, *args, **kwargs): method testNoneUpdate (line 53) | def testNoneUpdate(self): method testDifferentLengthUpdatesParams (line 60) | def testDifferentLengthUpdatesParams(self): method testEmptyParams (line 71) | def testEmptyParams(self): method testAllUpdatesNone (line 79) | def testAllUpdatesNone(self): method testInconsistentDTypes (line 90) | def testInconsistentDTypes(self): method testUnsuppportedStrategyError (line 101) | def testUnsuppportedStrategyError(self): class AbstractFuzzTest (line 116) | class AbstractFuzzTest(test_utils.TestCase, parameterized.TestCase): method _make_tf (line 119) | def _make_tf(self, learning_rate, momentum, use_nesterov): method _make_snt (line 122) | def _make_snt(self, learning_rate, momentum, use_nesterov): method assertParametersRemainClose (line 125) | def assertParametersRemainClose(self, seed, config, num_steps=100, ato... function _generate_dense_data (line 143) | def _generate_dense_data(seed, num_steps): function _apply_optimizer (line 162) | def _apply_optimizer(data, apply_fn): function named_product (line 174) | def named_product(**config): FILE: sonnet/src/optimizers/optimizer_utils.py function check_distribution_strategy (line 34) | def check_distribution_strategy(): function check_updates_parameters (line 44) | def check_updates_parameters(updates: Sequence[types.ParameterUpdate], function check_same_dtype (line 54) | def check_same_dtype(update: types.ParameterUpdate, parameter: tf.Variab... function deduplicate_indexed_slices (line 61) | def deduplicate_indexed_slices(indexed_slice: tf.IndexedSlices): FILE: sonnet/src/optimizers/rmsprop.py function rmsprop_update (line 28) | def rmsprop_update(update, decay, learning_rate, epsilon, mu, mom, ms, mg): class RMSProp (line 40) | class RMSProp(base.Optimizer): method __init__ (line 73) | def __init__(self, method _initialize (line 104) | def _initialize(self, parameters: Sequence[tf.Variable]): method apply (line 114) | def apply(self, updates: Sequence[types.ParameterUpdate], FILE: sonnet/src/optimizers/rmsprop_test.py class ComparisonTest (line 30) | class ComparisonTest(optimizer_tests.AbstractFuzzTest): method _make_tf (line 33) | def _make_tf(self, learning_rate, decay, momentum, epsilon, centered): method _make_snt (line 41) | def _make_snt(self, learning_rate, decay, momentum, epsilon, centered): method testComparingSonnetAndTensorFlow (line 50) | def testComparingSonnetAndTensorFlow(self, config): class RMSPropTest (line 55) | class RMSPropTest(optimizer_tests.OptimizerTestBase): method make_optimizer (line 57) | def make_optimizer(self, **kwargs): method testDense (line 62) | def testDense(self): method testDenseCentered (line 79) | def testDenseCentered(self): method testSparse (line 96) | def testSparse(self): method testSparseCentered (line 124) | def testSparseCentered(self): method testVariableHyperParams (line 152) | def testVariableHyperParams(self): method testHyperParamDTypeConversion (line 166) | def testHyperParamDTypeConversion(self): method testAuxVariablesColocatedWithOriginal (line 186) | def testAuxVariablesColocatedWithOriginal(self): class ReferenceRMSPropTest (line 198) | class ReferenceRMSPropTest(RMSPropTest): method make_optimizer (line 200) | def make_optimizer(self, **kwargs): FILE: sonnet/src/optimizers/sgd.py class SGD (line 25) | class SGD(base.Optimizer): method __init__ (line 32) | def __init__(self, method apply (line 44) | def apply(self, updates: Sequence[types.ParameterUpdate], FILE: sonnet/src/optimizers/sgd_test.py class SGDTest (line 22) | class SGDTest(optimizer_tests.OptimizerTestBase): method make_optimizer (line 24) | def make_optimizer(self, *args, **kwargs): method testDense (line 29) | def testDense(self): method testSparse (line 37) | def testSparse(self): method testVariableLearningRate (line 55) | def testVariableLearningRate(self): method testLearningRateDTypeConversion (line 69) | def testLearningRateDTypeConversion(self): class ReferenceSGDTest (line 80) | class ReferenceSGDTest(SGDTest): method make_optimizer (line 82) | def make_optimizer(self, *args, **kwargs): FILE: sonnet/src/pad.py function valid (line 25) | def valid(effective_kernel_size: int): # pylint: disable=unused-argument function same (line 30) | def same(effective_kernel_size: int): function full (line 35) | def full(effective_kernel_size: int): function causal (line 40) | def causal(effective_kernel_size: int): function reverse_causal (line 45) | def reverse_causal(effective_kernel_size: int): function create (line 50) | def create( FILE: sonnet/src/pad_test.py class PadTest (line 23) | class PadTest(test_utils.TestCase, parameterized.TestCase): method test_padding_2d (line 25) | def test_padding_2d(self): method test_padding_1d (line 29) | def test_padding_1d(self): method test_padding_3d (line 33) | def test_padding_3d(self): method test_padding_incorrect_input (line 39) | def test_padding_incorrect_input(self, kernel_size, rate): method test_padding_valid (line 45) | def test_padding_valid(self): method test_padding_same (line 49) | def test_padding_same(self): method test_padding_full (line 53) | def test_padding_full(self): method test_padding_causal (line 57) | def test_padding_causal(self): method test_padding_reverse_causal (line 61) | def test_padding_reverse_causal(self): method test_same_padding (line 67) | def test_same_padding(self, kernel_size, stride, rate): method test_valid_padding (line 80) | def test_valid_padding(self, kernel_size, stride, rate): FILE: sonnet/src/parallel_linear.py class ParallelLinears (line 27) | class ParallelLinears(base.Module): method __init__ (line 39) | def __init__(self, method _initialize (line 68) | def _initialize(self, inputs: tf.Tensor): method __call__ (line 95) | def __call__(self, inputs: tf.Tensor) -> tf.Tensor: FILE: sonnet/src/parallel_linear_test.py class ParallelLinearTest (line 23) | class ParallelLinearTest(test_utils.TestCase): method test_output_size_correct (line 25) | def test_output_size_correct(self): method test_behaves_same_as_stacked_linears (line 31) | def test_behaves_same_as_stacked_linears(self): FILE: sonnet/src/recurrent.py class RNNCore (line 40) | class RNNCore(base.Module, metaclass=abc.ABCMeta): method __call__ (line 54) | def __call__(self, inputs: types.TensorNest, prev_state): method initial_state (line 72) | def initial_state(self, batch_size: types.IntegerLike, **kwargs): class UnrolledRNN (line 84) | class UnrolledRNN(base.Module, metaclass=abc.ABCMeta): method __call__ (line 92) | def __call__(self, input_sequence: types.TensorNest, method initial_state (line 111) | def initial_state(self, batch_size: types.IntegerLike, **kwargs): class TrainableState (line 123) | class TrainableState(base.Module): method for_core (line 137) | def for_core(cls, method __init__ (line 157) | def __init__(self, method __call__ (line 191) | def __call__(self, batch_size: int) -> types.TensorNest: function static_unroll (line 198) | def static_unroll( class _ListWrapper (line 279) | class _ListWrapper: method __init__ (line 288) | def __init__(self, data): function dynamic_unroll (line 295) | def dynamic_unroll( function _unstack_input_sequence (line 386) | def _unstack_input_sequence(input_sequence): function _safe_where (line 427) | def _safe_where(condition, x, y): # pylint: disable=g-doc-args function _rnn_step (line 437) | def _rnn_step(core, input_tas, sequence_length, t, prev_outputs, prev_st... class VanillaRNN (line 457) | class VanillaRNN(RNNCore): method __init__ (line 475) | def __init__(self, method input_to_hidden (line 512) | def input_to_hidden(self) -> tf.Variable: method hidden_to_hidden (line 516) | def hidden_to_hidden(self) -> tf.Variable: method __call__ (line 519) | def __call__(self, inputs: types.TensorNest, method initial_state (line 531) | def initial_state(self, batch_size: int) -> tf.Tensor: method _initialize (line 536) | def _initialize(self, inputs: tf.Tensor): class _LegacyDeepRNN (line 541) | class _LegacyDeepRNN(RNNCore): method __init__ (line 548) | def __init__(self, method __call__ (line 566) | def __call__(self, inputs, prev_state): method initial_state (line 595) | def initial_state(self, batch_size, **kwargs): class DeepRNN (line 603) | class DeepRNN(_LegacyDeepRNN): method __init__ (line 626) | def __init__(self, layers, name: Optional[str] = None): function deep_rnn_with_skip_connections (line 630) | def deep_rnn_with_skip_connections( class _ResidualWrapper (line 672) | class _ResidualWrapper(RNNCore): method __init__ (line 679) | def __init__(self, base_core: RNNCore): method __call__ (line 683) | def __call__(self, inputs: types.TensorNest, prev_state: types.TensorN... method initial_state (line 689) | def initial_state(self, batch_size, **kwargs): function deep_rnn_with_residual_connections (line 693) | def deep_rnn_with_residual_connections( class LSTM (line 738) | class LSTM(RNNCore): method __init__ (line 781) | def __init__(self, method __call__ (line 829) | def __call__(self, inputs, prev_state): method initial_state (line 835) | def initial_state(self, batch_size: int) -> LSTMState: method input_to_hidden (line 842) | def input_to_hidden(self): method hidden_to_hidden (line 846) | def hidden_to_hidden(self): method _initialize (line 850) | def _initialize(self, inputs): function _lstm_fn (line 882) | def _lstm_fn(inputs, prev_state, w_i, w_h, b, projection=None): class UnrolledLSTM (line 901) | class UnrolledLSTM(UnrolledRNN): method __init__ (line 909) | def __init__(self, method __call__ (line 943) | def __call__(self, input_sequence, initial_state): method initial_state (line 949) | def initial_state(self, batch_size): method input_to_hidden (line 956) | def input_to_hidden(self): method hidden_to_hidden (line 960) | def hidden_to_hidden(self): method _initialize (line 964) | def _initialize(self, input_sequence): function _specialize_per_device (line 985) | def _specialize_per_device(api_name, specializations, default): function _fallback_unrolled_lstm (line 1045) | def _fallback_unrolled_lstm(input_sequence, initial_state, w_i, w_h, b): function _block_unrolled_lstm (line 1052) | def _block_unrolled_lstm(input_sequence, initial_state, w_i, w_h, b): function _cudnn_unrolled_lstm (line 1070) | def _cudnn_unrolled_lstm(input_sequence, initial_state, w_i, w_h, b): class _RecurrentDropoutWrapper (line 1111) | class _RecurrentDropoutWrapper(RNNCore): method __init__ (line 1123) | def __init__(self, base_core: RNNCore, rates, seed: Optional[int] = No... method __call__ (line 1137) | def __call__(self, inputs, prev_state): method initial_state (line 1147) | def initial_state(self, batch_size, **kwargs): function lstm_with_recurrent_dropout (line 1161) | def lstm_with_recurrent_dropout(hidden_size, dropout=0.5, seed=None, **k... class _ConvNDLSTM (line 1203) | class _ConvNDLSTM(RNNCore): method __init__ (line 1247) | def __init__(self, method __call__ (line 1314) | def __call__(self, inputs, prev_state): method input_to_hidden (line 1332) | def input_to_hidden(self): method hidden_to_hidden (line 1336) | def hidden_to_hidden(self): method initial_state (line 1339) | def initial_state(self, batch_size): method _initialize (line 1349) | def _initialize(self, inputs): class Conv1DLSTM (line 1357) | class Conv1DLSTM(_ConvNDLSTM): # pylint: disable=missing-docstring,empt... method __init__ (line 1360) | def __init__(self, class Conv2DLSTM (line 1408) | class Conv2DLSTM(_ConvNDLSTM): # pylint: disable=missing-docstring,empt... method __init__ (line 1411) | def __init__(self, class Conv3DLSTM (line 1459) | class Conv3DLSTM(_ConvNDLSTM): # pylint: disable=missing-docstring,empt... method __init__ (line 1462) | def __init__(self, class GRU (line 1510) | class GRU(RNNCore): method __init__ (line 1538) | def __init__(self, method __call__ (line 1568) | def __call__(self, inputs, prev_state): method initial_state (line 1587) | def initial_state(self, batch_size): method input_to_hidden (line 1592) | def input_to_hidden(self): method hidden_to_hidden (line 1596) | def hidden_to_hidden(self): method _initialize (line 1600) | def _initialize(self, inputs): class CuDNNGRU (line 1613) | class CuDNNGRU(RNNCore): method __init__ (line 1627) | def __init__(self, method __call__ (line 1657) | def __call__(self, inputs, prev_state): method input_to_hidden (line 1706) | def input_to_hidden(self): method hidden_to_hidden (line 1710) | def hidden_to_hidden(self): method initial_state (line 1713) | def initial_state(self, batch_size): method _initialize (line 1718) | def _initialize(self, inputs): function _check_inputs_dtype (line 1730) | def _check_inputs_dtype(inputs, expected_dtype): FILE: sonnet/src/recurrent_test.py class VanillaRNNTest (line 29) | class VanillaRNNTest(test_utils.TestCase, parameterized.TestCase): method setUp (line 31) | def setUp(self): method testComputationAgainstNumPy (line 38) | def testComputationAgainstNumPy(self, use_tf_function): method testDtypeMismatch (line 57) | def testDtypeMismatch(self): method testInitialization (line 66) | def testInitialization(self): class DeepRNNTest (line 80) | class DeepRNNTest(test_utils.TestCase, parameterized.TestCase): method setUp (line 82) | def setUp(self): method testComputationAgainstNumPy (line 89) | def testComputationAgainstNumPy(self, use_tf_function): method testComputationAgainstNumPyWithCallables (line 111) | def testComputationAgainstNumPyWithCallables(self, use_tf_function): method testInitialState (line 123) | def testInitialState(self): method testWithSkipConnectionsOutputs (line 132) | def testWithSkipConnectionsOutputs(self, use_tf_function): method testWithConnectionsValidation (line 148) | def testWithConnectionsValidation(self): class LSTMTest (line 155) | class LSTMTest(test_utils.TestCase, parameterized.TestCase): method setUp (line 157) | def setUp(self): method testComputationAgainstNumPy (line 165) | def testComputationAgainstNumPy(self, use_tf_function, projection_size, method testDtypeMismatch (line 197) | def testDtypeMismatch(self): method testInitialization (line 207) | def testInitialization(self): method testRecurrentDropout (line 225) | def testRecurrentDropout(self, rate): method testRecurrentDropoutInvalid (line 248) | def testRecurrentDropoutInvalid(self): class UnrolledLSTMTest (line 254) | class UnrolledLSTMTest(test_utils.TestCase, parameterized.TestCase): method setUp (line 256) | def setUp(self): method testComputationAgainstLSTM (line 264) | def testComputationAgainstLSTM(self, num_steps, use_tf_function): method testNumStepsPolymorphism (line 300) | def testNumStepsPolymorphism(self, use_tf_function): method testDtypeMismatch (line 322) | def testDtypeMismatch(self): method testInitialization (line 333) | def testInitialization(self): class ConvNDLSTMTest (line 348) | class ConvNDLSTMTest(test_utils.TestCase, parameterized.TestCase): method setUp (line 350) | def setUp(self): method testComputationAgainstNumPy (line 362) | def testComputationAgainstNumPy(self, use_tf_function, core_cls): method testDtypeMismatch (line 403) | def testDtypeMismatch(self): method testInitialization (line 421) | def testInitialization(self): class GRUTest (line 442) | class GRUTest(test_utils.TestCase, parameterized.TestCase): method setUp (line 444) | def setUp(self): method testComputationAgainstNumPy (line 451) | def testComputationAgainstNumPy(self, use_tf_function): method testDtypeMismatch (line 473) | def testDtypeMismatch(self): method testInitialization (line 482) | def testInitialization(self): function expit (line 496) | def expit(x): class CuDNNGRUTest (line 500) | class CuDNNGRUTest(test_utils.TestCase, parameterized.TestCase): method setUp (line 502) | def setUp(self): method testComputationAgainstTF (line 513) | def testComputationAgainstTF(self, num_steps): method testDtypeMismatch (line 550) | def testDtypeMismatch(self): method testInitialization (line 559) | def testInitialization(self): class Counter (line 573) | class Counter(recurrent.RNNCore): method __init__ (line 583) | def __init__(self, hidden_size, name=None): method __call__ (line 588) | def __call__(self, inputs, prev_state): method initial_state (line 598) | def initial_state(self, batch_size): class Replicate (line 602) | class Replicate(recurrent.RNNCore): method __init__ (line 605) | def __init__(self, base_core, n, name=None): method __call__ (line 610) | def __call__(self, inputs, prev_state): method initial_state (line 614) | def initial_state(self, batch_size, **kwargs): class TrainableStateTest (line 618) | class TrainableStateTest(test_utils.TestCase, parameterized.TestCase): method testUnmasked (line 631) | def testUnmasked(self, initial_values_shape): method testMasked (line 644) | def testMasked(self): method testForCore (line 659) | def testForCore(self): class UnrollTest (line 684) | class UnrollTest(test_utils.TestCase, parameterized.TestCase): method setUp (line 686) | def setUp(self): method testFlat (line 694) | def testFlat(self, use_tf_function, unroll_fn): method testNestedInputs (line 708) | def testNestedInputs(self, use_tf_function, unroll_fn): method testNestedOutputs (line 725) | def testNestedOutputs(self, use_tf_function, unroll_fn): method testEmptyOutputs (line 742) | def testEmptyOutputs(self, use_tf_function, unroll_fn): method testZeroSteps (line 755) | def testZeroSteps(self, use_tf_function, unroll_fn): method testInconsistentSteps (line 766) | def testInconsistentSteps(self, use_tf_function, unroll_fn): method testVariableLengthOneZeroLength (line 778) | def testVariableLengthOneZeroLength(self, use_tf_function, unroll_fn): method testVariableLengthRange (line 794) | def testVariableLengthRange(self, use_tf_function, unroll_fn): method assertConsistentWithLength (line 809) | def assertConsistentWithLength(self, output_sequence, sequence_length): method testVariableLengthAllFull (line 819) | def testVariableLengthAllFull(self, use_tf_function, unroll_fn): method testVariableLengthAllEmpty (line 835) | def testVariableLengthAllEmpty(self, use_tf_function, unroll_fn): class UnknownStepsUnrollTest (line 852) | class UnknownStepsUnrollTest(test_utils.TestCase): method setUp (line 854) | def setUp(self): method testStaticUnroll (line 862) | def testStaticUnroll(self): method testDynamicUnroll (line 873) | def testDynamicUnroll(self): method testDynamicUnrollInconsistentSteps (line 886) | def testDynamicUnrollInconsistentSteps(self): FILE: sonnet/src/regularizers.py class Regularizer (line 24) | class Regularizer(abc.ABC): method __call__ (line 28) | def __call__(self, tensors: Sequence[tf.Tensor]) -> tf.Tensor: class L1 (line 39) | class L1(Regularizer): method __init__ (line 47) | def __init__(self, scale: types.FloatLike): method __repr__ (line 59) | def __repr__(self): method __call__ (line 65) | def __call__(self, tensors: Sequence[tf.Tensor]) -> tf.Tensor: class L2 (line 73) | class L2(Regularizer): method __init__ (line 81) | def __init__(self, scale: types.FloatLike): method __repr__ (line 93) | def __repr__(self): method __call__ (line 99) | def __call__(self, tensors: Sequence[tf.Tensor]) -> tf.Tensor: class OffDiagonalOrthogonal (line 107) | class OffDiagonalOrthogonal(Regularizer): method __init__ (line 136) | def __init__(self, scale: types.FloatLike): method __repr__ (line 147) | def __repr__(self): method __call__ (line 153) | def __call__(self, tensors: Sequence[tf.Tensor]) -> tf.Tensor: function _check_scale (line 168) | def _check_scale(scale: types.FloatLike) -> types.FloatLike: FILE: sonnet/src/regularizers_test.py class L1Test (line 23) | class L1Test(test_utils.TestCase): method testAgainstNumPy (line 25) | def testAgainstNumPy(self): method testNegativeScale (line 36) | def testNegativeScale(self): method testEmpty (line 40) | def testEmpty(self): class L2Test (line 44) | class L2Test(test_utils.TestCase): method testAgainstNumPy (line 46) | def testAgainstNumPy(self): method testNegativeScale (line 57) | def testNegativeScale(self): method testEmpty (line 61) | def testEmpty(self): class OffDiagonalOrthogonalTest (line 65) | class OffDiagonalOrthogonalTest(test_utils.TestCase): method testAgainstNumPy (line 67) | def testAgainstNumPy(self): method testNegativeScale (line 81) | def testNegativeScale(self): method testEmpty (line 85) | def testEmpty(self): FILE: sonnet/src/reshape.py function reshape (line 26) | def reshape(inputs: tf.Tensor, function flatten (line 34) | def flatten(inputs: tf.Tensor, name: str = "flatten") -> tf.Tensor: function _infer_shape (line 39) | def _infer_shape(output_shape: types.ShapeLike, dimensions: Sequence[int]): class Reshape (line 61) | class Reshape(base.Module): method __init__ (line 88) | def __init__(self, method _initialize (line 115) | def _initialize(self, inputs: tf.Tensor): method __call__ (line 123) | def __call__(self, inputs: tf.Tensor) -> tf.Tensor: method reversed (line 162) | def reversed(self, name: Optional[str] = None) -> "Reshape": class Flatten (line 173) | class Flatten(Reshape): method __init__ (line 183) | def __init__(self, preserve_dims: int = 1, name: Optional[str] = None): FILE: sonnet/src/reshape_test.py class ReshapeTest (line 26) | class ReshapeTest(test_utils.TestCase, parameterized.TestCase): method testReshape (line 34) | def testReshape(self, preserve_dims, expected_output_shape): method testInvalid_multipleWildcard (line 39) | def testInvalid_multipleWildcard(self): method testInvalid_negativeSize (line 44) | def testInvalid_negativeSize(self): method testInvalid_type (line 50) | def testInvalid_type(self): method testIncompatibleShape (line 55) | def testIncompatibleShape(self): method testInferShape (line 65) | def testInferShape(self): method testAddDimensions (line 72) | def testAddDimensions(self): method testFlatten (line 85) | def testFlatten(self): method testUnknownBatchSize (line 92) | def testUnknownBatchSize(self): method testReverse (line 99) | def testReverse(self): method testReverse_name (line 120) | def testReverse_name(self): method testInvalidPreserveDimsError (line 126) | def testInvalidPreserveDimsError(self): method testBuildDimError (line 130) | def testBuildDimError(self): method testPreserve (line 144) | def testPreserve(self, preserve): method testRun (line 166) | def testRun(self, preserve, trailing_in, trailing_out): class FlattenTest (line 185) | class FlattenTest(test_utils.TestCase, parameterized.TestCase): method testFlatten (line 188) | def testFlatten(self, batch_size): method testFlatten_unknownBatchSize (line 196) | def testFlatten_unknownBatchSize(self): method testFlatten_unknownNonBatchSize (line 205) | def testFlatten_unknownNonBatchSize(self): method testPreserveDimsOk (line 215) | def testPreserveDimsOk(self, preserve_dims): method testPreserveDimsError (line 226) | def testPreserveDimsError(self, preserve_dims): method testFlattenWithZeroDim (line 233) | def testFlattenWithZeroDim(self): method testInvalidFlattenFromError (line 238) | def testInvalidFlattenFromError(self): method testBuildDimError (line 242) | def testBuildDimError(self): method testReverse (line 249) | def testReverse(self, batch_size): FILE: sonnet/src/scale_gradient.py function scale_gradient (line 24) | def scale_gradient( FILE: sonnet/src/scale_gradient_test.py class ScaleGradientTest (line 25) | class ScaleGradientTest(test_utils.TestCase, parameterized.TestCase): method test_scale (line 29) | def test_scale(self, t_, scale): FILE: sonnet/src/sequential.py class Sequential (line 22) | class Sequential(base.Module): method __init__ (line 57) | def __init__(self, method __call__ (line 63) | def __call__(self, inputs, *args, **kwargs): FILE: sonnet/src/sequential_test.py class SequentialTest (line 26) | class SequentialTest(test_utils.TestCase, parameterized.TestCase): method test_empty (line 29) | def test_empty(self, value): method test_empty_drops_varargs_varkwargs (line 34) | def test_empty_drops_varargs_varkwargs(self, value): method test_identity_chain (line 39) | def test_identity_chain(self, value): method test_call (line 43) | def test_call(self): method test_varargs_varkwargs_to_call (line 47) | def test_varargs_varkwargs_to_call(self): function identity (line 54) | def identity(v): function append_character (line 58) | def append_character(c): FILE: sonnet/src/test_utils.py class TestCase (line 35) | class TestCase(tf.test.TestCase): method setUp (line 40) | def setUp(self): method tearDown (line 63) | def tearDown(self): method primary_device (line 70) | def primary_device(self): method device_types (line 79) | def device_types(self): method get_atol (line 82) | def get_atol(self): function find_all_sonnet_modules (line 100) | def find_all_sonnet_modules( function find_sonnet_python_modules (line 114) | def find_sonnet_python_modules( function combined_named_parameters (line 135) | def combined_named_parameters(*parameters): function named_bools (line 162) | def named_bools(name) -> Sequence[Tuple[str, bool]]: FILE: sonnet/src/utils.py function replicate (line 32) | def replicate( function _is_object (line 49) | def _is_object(f: Any) -> bool: function decorator (line 54) | def decorator( function get_channel_index (line 103) | def get_channel_index(data_format: str) -> int: function assert_rank (line 133) | def assert_rank(inputs, rank: int): function assert_minimum_rank (line 141) | def assert_minimum_rank(inputs, rank: int): function _synchronized (line 149) | def _synchronized(f: Callable[..., T]) -> Callable[..., T]: function smart_autograph (line 159) | def smart_autograph(f: Callable[..., T]) -> Callable[..., T]: function variable_like (line 204) | def variable_like(inputs: Union[tf.Tensor, tf.Variable], function _render_spec (line 218) | def _render_spec(shape: tf.TensorShape, dtype: tf.DType) -> str: function _simple_device (line 248) | def _simple_device(var: tf.Variable) -> str: function _name_scope_then_rank (line 258) | def _name_scope_then_rank(var: tf.Variable): function format_variables (line 264) | def format_variables(variables: Sequence[tf.Variable], function log_variables (line 280) | def log_variables(variables: Sequence[tf.Variable]): class CompareById (line 295) | class CompareById(Generic[T]): method __init__ (line 298) | def __init__(self, wrapped: T): method __hash__ (line 301) | def __hash__(self): method __eq__ (line 309) | def __eq__(self, other): method __lt__ (line 314) | def __lt__(self, other): FILE: sonnet/src/utils_test.py class ReplicateTest (line 34) | class ReplicateTest(test_utils.TestCase, parameterized.TestCase): method testSingleValue (line 37) | def testSingleValue(self, value): method testListLengthOne (line 44) | def testListLengthOne(self, value): method testTupleLengthN (line 51) | def testTupleLengthN(self, value): method testListLengthN (line 59) | def testListLengthN(self, value): method testIncorrectLength (line 65) | def testIncorrectLength(self): class DecoratorTest (line 73) | class DecoratorTest(test_utils.TestCase): method test_callable_object (line 75) | def test_callable_object(self): method test_function (line 91) | def test_function(self): method test_unbound_method (line 101) | def test_unbound_method(self): method test_bound_method (line 117) | def test_bound_method(self): class ChannelIndexTest (line 133) | class ChannelIndexTest(test_utils.TestCase, parameterized.TestCase): method test_returns_index_channels_first (line 136) | def test_returns_index_channels_first(self, data_format): method test_returns_index_channels_last (line 140) | def test_returns_index_channels_last(self, data_format): method test_invalid_strings (line 144) | def test_invalid_strings(self, data_format): class AssertRankTest (line 151) | class AssertRankTest(test_utils.TestCase, parameterized.TestCase): method test_valid_rank (line 158) | def test_valid_rank(self, input_fn): method test_invalid_rank (line 165) | def test_invalid_rank(self, rank): class SmartAutographTest (line 180) | class SmartAutographTest(test_utils.TestCase): method test_smart_ag (line 182) | def test_smart_ag(self): class VariableLikeTest (line 204) | class VariableLikeTest(test_utils.TestCase, parameterized.TestCase): method test_copies_shape (line 208) | def test_copies_shape(self, a): method test_copies_dtype (line 217) | def test_copies_dtype(self, a): method test_copies_device (line 223) | def test_copies_device(self, a): method test_default_initializer_is_zero (line 229) | def test_default_initializer_is_zero(self): method test_override_initializer (line 234) | def test_override_initializer(self): method test_copies_variable_trainable (line 240) | def test_copies_variable_trainable(self, trainable): method test_default_trainable_for_tensor (line 245) | def test_default_trainable_for_tensor(self): method test_override_trainable (line 251) | def test_override_trainable(self, trainable): method test_copies_variable_name (line 256) | def test_copies_variable_name(self): method test_default_name_for_tensor (line 261) | def test_default_name_for_tensor(self): method test_override_name (line 267) | def test_override_name(self, a): class FormatVariablesTest (line 273) | class FormatVariablesTest(test_utils.TestCase): method test_format_variables (line 275) | def test_format_variables(self): method test_log_variables (line 285) | def test_log_variables(self): class NotHashable (line 295) | class NotHashable: method __hash__ (line 297) | def __hash__(self): class CompareByIdTest (line 301) | class CompareByIdTest(test_utils.TestCase): method test_access (line 303) | def test_access(self): method test_hash (line 308) | def test_hash(self): method test_eq (line 313) | def test_eq(self):