Repository: openai/chz
Branch: main
Commit: b01c082aa50b
Files: 48
Total size: 431.1 KB

Directory structure:
gitextract_3v4reuym/
├── .gitignore
├── CHANGELOG.md
├── LICENSE
├── README.md
├── chz/
│   ├── __init__.py
│   ├── blueprint/
│   │   ├── __init__.py
│   │   ├── _argmap.py
│   │   ├── _argv.py
│   │   ├── _blueprint.py
│   │   ├── _entrypoint.py
│   │   ├── _lazy.py
│   │   └── _wildcard.py
│   ├── data_model.py
│   ├── factories.py
│   ├── field.py
│   ├── mungers.py
│   ├── py.typed
│   ├── tiepin.py
│   ├── universal.py
│   ├── util.py
│   └── validators.py
├── docs/
│   ├── 01_quickstart.md
│   ├── 02_object_model.md
│   ├── 03_validation.md
│   ├── 04_command_line.md
│   ├── 05_blueprint.md
│   ├── 06_serialisation.md
│   ├── 21_post_init.md
│   ├── 22_field_api.md
│   ├── 91_philosophy.md
│   ├── 92_alternatives.md
│   └── 93_testimonials.md
├── pyproject.toml
└── tests/
    ├── test_blueprint.py
    ├── test_blueprint_cast.py
    ├── test_blueprint_errors.py
    ├── test_blueprint_meta_factory.py
    ├── test_blueprint_methods.py
    ├── test_blueprint_reference.py
    ├── test_blueprint_root_polymorphism.py
    ├── test_blueprint_unit.py
    ├── test_blueprint_variadic.py
    ├── test_data_model.py
    ├── test_factories.py
    ├── test_munge.py
    ├── test_tiepin.py
    ├── test_todo.py
    └── test_validate.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
__pycache__/
*.py[cod]
.DS_Store
.env
.venv
env/
venv/
build/
dist/
*.egg-info/
.tox/
.mypy_cache/

================================================
FILE: CHANGELOG.md
================================================
# Changelog

## November 2025

- fix most tests on Python 3.14
- support cast to `datetime.datetime`
- improve `is_subtype` for `TypedDict`s
- add `Computed` reference type, thanks sfitzgerald!
- support int dict keys in blueprint, thanks hessam!
- fix subparam mutation in the "template thing", thanks tz!
- improve docs, thanks awei!
- require newer `typing-extensions`

## September 2025

- add `dispatch_entrypoint`
- several changes to optimise argmap lookups by collapsing and consolidating layers
  - this is >10x speedup for some use cases
- always print additional diagnostics for extraneous args, thanks camillo!
- add `skip_default` arg to `beta_to_blueprint_values`, thanks charlieb!
- add special casing for tuples in `beta_argv_arg_to_string`, thanks tz!
- testing improvements

## August 2025

- mention the value of the closest ancestor for extraneous args to help with polymorphism confusion
- improve extraneous arg error message
- add `exclude` param to `asdict`, thanks andrey!
- fix subtype check in the "template thing", thanks tz!
- some cleanup of the "template thing"

## July 2025

- changes to add the "template thing" to blueprint, thanks xintao!
  - this feature is not available in the open source version and I plan to attempt to remove it from the internal version
- differentiate between untyped and zero length tuple in sequence param collection, thanks elwong!
- fix `beta_argv_arg_to_string` behaviour for list elements that are strings containing commas

## June 2025

- better error if annotation eval fails, thanks jelle!
- add `ge` and `le` validators, thanks cassirer!
- special casing to make `beta_argv_arg_to_string` handle dicts, thanks yjiao!

## May 2025

- error for duplicate class when name is ambiguous
- fix defaulting special case for nested args
- add `chz.traverse`, thanks hessam!
- better handling of type variables and meta factory casting
- special casing to make `beta_argv_arg_to_string` involving lists more compact
- improve `freeze_dict` munger static typing for optionals, thanks camillo!
- add `include_type` param to `asdict`, thanks wenda and andrei!
- internal refactoring

## March 2025

Improvements:
- add "universal CLI" via `python -m chz.universal`
- add `shallow` param to `asdict` to prevent deep copying, thanks wenda!
- look at `builtins` and `__main__` to find object factories
- support `*args` and `**kwargs` collection in blueprint
- support type variables in `is_subtype`
- fix variadics that match wildcards in more than one literal location
- fix blueprint apply to subpath with empty key, thanks hessam!
- support converter argument in field, thanks camillo!
- refactor param collection in blueprint
- various docs improvements, thanks andrey, csh, mtli!

Error messages:
- better error for a value with subparams specified
- improve error for blueprint type mismatch
- fix bug in `simplistic_type_of_value`
- include Python's native suggestions for `AttributeError` in blueprint attribute access, thanks yifan!

## February 2025

Improvements:
- revamp the docs
- improve casting for callables
- blindly trust explicit inheritance from protocol
- record `meta_factory_value` for non castable factory
- add `__eq__` to `castable`
- fix quoting in the `beta_blueprint_to_argv` thing
- expose the `beta_argv_arg_to_string` thing

Performance:
- add an optimisation when constructing large variadics for 6x speedup on some workloads
- rewrite the `beta_blueprint_to_argv` thing so it's now 40x faster
- make it easier to reuse `MakeResult` to save repeated blueprint make
- lru cache `inspect.getmembers_static` to speed up repeated construction
- refactoring to make optimisation easier

Error messages:
- show the full path more often when errors occur during blueprint construction, thanks mlim and gross!
- add error for case where you have duplicate classes due to `__main__` confusion
- improve error message when constructing ambiguous or mistyped callables
- minor improvements to error messages

## January 2025

Improvements:
- add basic support for functools.partial in blueprints
- allow parametrising the entrypoint in chz blueprints.
  this allows for a "universal" cli
- rewrite `beta_to_blueprint_values` to better support nesting and polymorphism
- improve interaction between type checking and munger
- ignore self references more consistently when there is a default value available
- better error when self references are missing a default value
- add `freeze_dict` munger, thanks camillo!
- use post init field value in hash, thanks camillo!
- prevent parameterisation of enum classes
- colourise and improve alignment of `--help` output
- various refactoring

Typing improvements:
- implement callable subtyping (especially useful for substructural typing)
- improve `is_subtype_instance` of protocols
- improve `is_subtype_instance` of None, thanks tongzhou!
- improve `is_subtype` handling of unions, thanks tongzhou!
- improve `is_subtype` handling of literals and `types.NoneType`
- better signature subtyping
- better casting for dict

## December 2024

Improvements:
- expand the error for wildcard matching variadic defaults, preventing a footgun
- optimise blueprint construction with large variadics, making a use case 2.7x faster
- add support for protocol subtyping for vitchyr use case
- pass field metadata through blueprint, for use in custom tools
- allow custom root for consistency in tree, thanks ignasi!
- fix `beta_blueprint_to_argv` with None args, thanks tongzhou!
- add test for unspecified `type(None)` trick to avoid instantiating defaulted class
- simplify some `meta_factory` logic
- fix standard `meta_factory` for `type[specialform]`

Error messages:
- mention layer name for extraneous arguments, so you know where the arg comes from
- reorder logic in cast for better errors
- more helpful error message when disallowing `__init__`, `__post_init__`, etc. thanks ebrevdo!
- other misc error message improvements
- misc internal docs

## November 2024

Two headline features for this month: references and `meta_factory` unification:
- references allow for deduplication of parameters and allow introducing indirection where some config is controlled by other teams
- `meta_factory` unification makes chz’s polymorphism more consistent and more powerful

Features:
- core of `meta_factory` unification, change default `meta_factory`
- infra for references, expose references
- use `X_values` from pre-init in `beta_to_blueprint_values`, thanks guillaume!
- give users access to methods_entrypoint blueprint
- add strict option to `Blueprint.apply`, thanks menick!
- add subpath to apply
- add override validators, thanks vineet!
- allow default values in nested_entrypoint
- make (wildcard) references not self-reference when defaulted
- recurse into dict in pretty_format
- make `meta_factory` lambda logic more robust
- support for python3.12 and 3.13

Typing features:
- basic pep 692 support
- add typeddict total=False and pep 655 support, thanks alec!
- add subtype support for pep 655 / required
- parse objects as literals
- allow casting to iterable
- allow ast eval of tuple for sequence
- better casting rules for list
- support casting pathlib
- add typeddict and callable tests

Error messages:
- improve two issues with `--help` in polymorphic command lines
- better error when we choose not to cast due to subparams
- batch errors for invalid ref targets
- improve error with reference cycles
- improve error message during blueprint evaluation
- include previous valid parent for non wildcard extraneous
- improve error mentioning closest matching parent for extraneous argument
- improve error messages on failure to interpret argument
- special case representation of objects from typing module

Internal:
- many refactoring changes and clean up, including large refactor of blueprint and changes for open source

## October 2024

- finally land support for variadic typeddicts
- add ability to attach user metadata to fields
- add better support for NewType, LiteralString, NamedTuple and other niche typing features
- add some support for PEP 646 unpacking of tuples
- add native support for casting fractions
- steps towards `meta_factory` unification. these changes make chz's polymorphism more powerful and more consistent
- allow disallowing `meta_factory`, useful in niche cases
- fix static typing of runtime typing to allow better downstream type checking

## September 2024

- add `blueprint_unspecified` to field, as generalisation of `chz.field(meta_factory=chz.factories.subclass(annot, default_cls=...))`.
  thanks to vitchyr for helping with this
- use `__orig_class__` to type check user defined generics, if possible
- add `chz.chz_fields` helper to access `__chz_fields__` attribute
- better error if there are no params and extraneous args

## August 2024

- add `check_field_consistency_in_tree` validator, as a way to help ensure your wildcards are doing what you want them to do
- use stdout for `--help`
- allow parsing empty tuple
- add a `const_default` validator for constant fields

## July 2024

- improvements to static types, thanks lmetz and wenda
- quick follow ups to `beta_blueprint_to_argv`, thanks hunter and noah
- improve `type_repr`, thanks davis
- minor error improvements

## June 2024

- support for polymorphic variadic generics
- fix some issues with pydantic support
- add `x_type` to improve static type checking of mungers
- add `beta_blueprint_to_argv`, thanks hunter
- fix callable subtyping with future annotations
- various improvements to `pretty_repr`, make dunder pure
- allow use of chz with abc
- add some special casing to avoid false positives with the conservative check against wildcard default factory interaction
- improve error message when validating types against a `Literal`
- improve error message when hashing chz class with unhashable fields, thanks alexk
- improve error message for unparseable type, thanks andmis
- fix typo in error message, thanks sean
- make various error messages more concise

## May 2024

- show default values in `--help`, includes some fancy logic around lambdas
- show values from unspecified_factory in `--help`, to make polymorphic construction easier to understand
- add `chz.methods_entrypoint` to easily make CLIs from classes
- support mapping and sequence variadics
- basic support for pydantic validation during runtime type checking, thanks camillo
- better handling of runtime contexts for future annotations support
- support for nested classes when `meta_factory` turns strings into classes
- better support for
  polymorphism in `beta_blueprint_to_values`, thanks wenda
- only error for variadic failure if variadic param specified
- more docs, more tests, cleaner help output, cleaner tracebacks

## ???

Established in 2022

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2024 OpenAI

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

================================================
FILE: README.md
================================================
# 🪤 chz *(pronounced "चीज़")*

`chz` helps you manage configuration, particularly from the command line.

`chz` is available on [PyPI](https://pypi.org/project/chz/).

To click the links below, please visit [Github](https://github.com/openai/chz).
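
For a rough feel of the command line style: arguments are given as `key=value` pairs, and dotted keys address nested parameters. The sketch below is plain Python and does not import `chz`; `parse_key_value_args` is a hypothetical helper that only mimics the splitting step and keeps values as raw strings, whereas the library itself goes on to cast values against type annotations.

```python
def parse_key_value_args(argv: list[str]) -> dict[str, str]:
    """Hypothetical helper (not part of chz): split `key=value` arguments into raw strings."""
    parsed: dict[str, str] = {}
    for arg in argv:
        # Split on the first "=" only, so values may themselves contain "="
        key, sep, value = arg.partition("=")
        if not sep:
            raise ValueError(
                f"Invalid argument {arg!r}. Specify arguments in the form key=value"
            )
        parsed[key] = value
    return parsed


if __name__ == "__main__":
    # Dotted keys such as "model.n_layers" stand for nested parameters
    print(parse_key_value_args(["name=demo", "model.n_layers=2"]))
```

Keeping the values as strings here is deliberate: deciding what `2` means (an `int`, a `str`, a layer count) needs the type information that only the real configuration classes carry.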

Overview:
- [Quickstart](docs/01_quickstart.md)
- [Declarative object model](docs/02_object_model.md)
  - [Immutability](docs/02_object_model.md#immutability)
- [Validation](docs/03_validation.md)
  - [Type checking](docs/03_validation.md#type-checking)
- [Command line parsing](docs/04_command_line.md)
  - [Discoverability](docs/04_command_line.md#discoverability---help-and-errors)
- [Partial application](docs/05_blueprint.md)
  - [Presets or shared configuration](docs/05_blueprint.md#presets-or-shared-configuration)
- [Serialisation and deserialisation](docs/06_serialisation.md)

More details:
- [Post init](docs/21_post_init.md)
- [Field API](docs/22_field_api.md)
- [Philosophy](docs/91_philosophy.md)
- [Alternatives](docs/92_alternatives.md)
- [Testimonials](docs/93_testimonials.md)

Please let @shantanu know if you have feedback!

================================================
FILE: chz/__init__.py
================================================
from typing import TYPE_CHECKING, Callable, TypeVar, overload

from . import blueprint, factories, mungers, tiepin, validators
from .blueprint import (
    Blueprint,
    Castable,
    dispatch_entrypoint,
    entrypoint,
    get_nested_target,
    methods_entrypoint,
    nested_entrypoint,
)
from .data_model import (
    asdict,
    beta_to_blueprint_values,
    chz_fields,
    chz_make_class,
    init_property,
    is_chz,
    replace,
    traverse,
)
from .field import field
from .validators import validate

__all__ = [
    "Blueprint",
    "asdict",
    "chz",
    "is_chz",
    "chz_fields",
    "entrypoint",
    "field",
    "get_nested_target",
    "init_property",
    "methods_entrypoint",
    "nested_entrypoint",
    "replace",
    "beta_to_blueprint_values",
    "traverse",
    "validate",
    "validators",
    "mungers",
    "Castable",
    # are the following public?
    "blueprint",
    "factories",
    "tiepin",
]


def _chz(cls=None, *, version: str | None = None, typecheck: bool | None = None):
    if cls is None:
        return lambda cls: chz_make_class(cls, version=version, typecheck=typecheck)
    return chz_make_class(cls, version=version, typecheck=typecheck)


if TYPE_CHECKING:
    _TypeT = TypeVar("_TypeT", bound=type)

    from typing_extensions import dataclass_transform

    @dataclass_transform(kw_only_default=True, frozen_default=True, field_specifiers=(field,))
    @overload
    def chz(version: str = ..., typecheck: bool = ...) -> Callable[[type], type]: ...

    @overload
    def chz(cls: _TypeT, /) -> _TypeT: ...

    def chz(*a, **k):
        raise NotImplementedError

else:
    chz = _chz

================================================
FILE: chz/blueprint/__init__.py
================================================
from chz.blueprint._argv import argv_to_blueprint_args as argv_to_blueprint_args
from chz.blueprint._argv import beta_argv_arg_to_string as beta_argv_arg_to_string
from chz.blueprint._argv import beta_blueprint_to_argv as beta_blueprint_to_argv
from chz.blueprint._blueprint import Blueprint as Blueprint
from chz.blueprint._blueprint import Castable as Castable
from chz.blueprint._blueprint import Reference as Reference
from chz.blueprint._entrypoint import ConstructionException as ConstructionException
from chz.blueprint._entrypoint import EntrypointHelpException as EntrypointHelpException
from chz.blueprint._entrypoint import ExtraneousBlueprintArg as ExtraneousBlueprintArg
from chz.blueprint._entrypoint import InvalidBlueprintArg as InvalidBlueprintArg
from chz.blueprint._entrypoint import MissingBlueprintArg as MissingBlueprintArg
from chz.blueprint._entrypoint import dispatch_entrypoint as dispatch_entrypoint
from chz.blueprint._entrypoint import entrypoint as entrypoint
from chz.blueprint._entrypoint import exit_on_entrypoint_error as exit_on_entrypoint_error
from chz.blueprint._entrypoint import get_nested_target as get_nested_target
from chz.blueprint._entrypoint import methods_entrypoint as methods_entrypoint
from chz.blueprint._entrypoint import nested_entrypoint as nested_entrypoint

================================================
FILE: chz/blueprint/_argmap.py
================================================
from __future__ import annotations

import bisect
import re
from dataclasses import dataclass
from typing import TYPE_CHECKING, AbstractSet, Any, Iterator, Mapping

from chz.blueprint._entrypoint import ExtraneousBlueprintArg
from chz.blueprint._wildcard import wildcard_key_approx, wildcard_key_to_regex

if TYPE_CHECKING:
    from chz.blueprint._blueprint import _MakeResult


class Layer:
    def __init__(self, args: Mapping[str, Any], layer_name: str | None):
        self._args = args
        self.layer_name = layer_name

        # Computed from the above
        self.qualified = {}
        self.wildcard = {}
        self._to_regex = {}
        # Match more specific wildcards first
        for k, v in sorted(args.items(), key=lambda kv: -len(kv[0])):
            if "..." in k:
                self.wildcard[k] = v
                self._to_regex[k] = wildcard_key_to_regex(k)
            else:
                self.qualified[k] = v

    def get_kv(self, exact_key: str) -> tuple[str, Any, str | None] | None:
        # TODO: remove this method
        if exact_key in self.qualified:
            return exact_key, self.qualified[exact_key], self.layer_name
        for wildcard_key, value in self.wildcard.items():
            if self._to_regex[wildcard_key].fullmatch(exact_key):
                return wildcard_key, value, self.layer_name
        return None

    def iter_keys(self) -> Iterator[tuple[str, bool]]:
        yield from ((k, False) for k in self.qualified)
        yield from ((k, True) for k in self.wildcard)

    def nest_subpath(self, subpath: str | None) -> Layer:
        if subpath is None:
            return self
        return Layer(
            {join_arg_path(subpath, k): v for k, v in self._args.items()},
            self.layer_name,
        )

    def __repr__(self) -> str:
        return f""


@dataclass(frozen=True)
class _FoundArgument:
    key: str
    value: Any
    layer_index: int
    layer_name: str | None


def _valid_parent(parts: list[str], param_paths: AbstractSet[str]) -> str | None:
    for i in reversed(range(1, len(parts))):
        parent = ".".join(parts[:i])
        if parent in param_paths:
            return parent
    return None


class ArgumentMap:
    def __init__(self, layers: list[Layer]) -> None:
        self._layers = layers

        self.consolidated = False
        self.consolidated_qualified: dict[str, tuple[Any, int]] = {}
        self.consolidated_qualified_sorted: list[str] = []
        self.consolidated_wildcard: list[tuple[str, re.Pattern[str], Any, int]] = []

    def add_layer(self, layer: Layer) -> None:
        self._layers.append(layer)
        self.consolidated = False

    def consolidate(self) -> None:
        self.consolidated_qualified = {}
        for i, layer in enumerate(self._layers):
            for key, value in layer.qualified.items():
                self.consolidated_qualified[key] = (value, i)
        self.consolidated_qualified_sorted = sorted(self.consolidated_qualified.keys())

        self.consolidated_wildcard = []
        for i, layer in reversed(list(enumerate(self._layers))):
            for wildcard_key, value in layer.wildcard.items():
                self.consolidated_wildcard.append(
                    (wildcard_key, layer._to_regex[wildcard_key], value, i)
                )
        self.consolidated = True

    def subpaths(self, path: str, strict: bool = False) -> list[str]:
        """Returns the suffix of arguments this contains that would match a subpath of path.

        The invariant is that for each element `suffix` in the returned list, `path + suffix`
        would match an argument in this map.

        Args:
            strict: Whether to avoid returning arguments that match path exactly.
        """
        assert self.consolidated, "ArgumentMap must be consolidated before calling subpaths"
        assert not path.endswith(".")

        wildcard_literal = path.split(".")[-1]
        # note path may be the empty string
        assert path.endswith(wildcard_literal)

        path_plus_dot = path + "."
        ret = []
        if not strict and path in self.consolidated_qualified:
            ret.append("")

        if not path:
            ret.extend([k for k in self.consolidated_qualified_sorted if k])

        index = bisect.bisect_left(self.consolidated_qualified_sorted, path_plus_dot)
        for i in range(index, len(self.consolidated_qualified_sorted)):
            key = self.consolidated_qualified_sorted[i]
            if not key.startswith(path_plus_dot):
                break
            ret.append(key.removeprefix(path_plus_dot))
            assert key == join_arg_path(path, ret[-1])

        for key, pattern, _value, _index in self.consolidated_wildcard:
            if not path:
                ret.append(key)
                continue

            # If it's not a wildcard, the logic is straightforward. But doing the equivalent
            # for wildcards is tricky!
            i = key.rfind(wildcard_literal)
            if i == -1:
                continue

            # The not strict case is not complicated, we just regex match
            if pattern.fullmatch(path):
                if not strict:
                    ret.append("")
                    assert pattern.fullmatch(path + ret[-1])
                continue

            # This needs a little thinking about.
            # Say path is "foo.bar" and key is "...bar...baz"
            # Then wildcard_literal is "bar" and we check if "...bar" matches "foo.bar"
            # Since it does, we append "...baz"
            while i != -1:
                if (
                    i + len(wildcard_literal) < len(key)
                    and key[i + len(wildcard_literal)] == "."
                    and wildcard_key_to_regex(key[: i + len(wildcard_literal)]).fullmatch(path)
                ):
                    assert i == 0 or key[i - 1] == "."
                    suffix = key[i + len(wildcard_literal) :]
                    if not suffix.startswith("..."):
                        suffix = suffix.removeprefix(".")
                    ret.append(suffix)
                    assert pattern.fullmatch(join_arg_path(path, ret[-1]))
                    break
                i_next = key.rfind(wildcard_literal, 0, i)
                assert i_next < i, "Infinite loop"
                i = i_next

        return ret

    def get_kv(self, exact_key: str, *, ignore_wildcards: bool = False) -> _FoundArgument | None:
        assert self.consolidated, "ArgumentMap must be consolidated before calling get_kv"
        lookup = self.consolidated_qualified.get(exact_key)
        if not ignore_wildcards:
            lookup_index = lookup[1] if lookup is not None else -1
            for wildcard_key, pattern, value, index in self.consolidated_wildcard:
                if index <= lookup_index:
                    break
                if pattern.fullmatch(exact_key):
                    layer_name = self._layers[index].layer_name
                    return _FoundArgument(wildcard_key, value, index, layer_name=layer_name)

        if lookup is not None:
            value, lookup_index = lookup
            layer_name = self._layers[lookup_index].layer_name
            return _FoundArgument(exact_key, value, lookup_index, layer_name=layer_name)
        return None

    def check_extraneous(
        self,
        used_args: set[tuple[str, int]],
        param_paths: AbstractSet[str],
        make_result: _MakeResult,
        *,
        entrypoint_repr: str,
    ) -> None:
        for index in range(len(self._layers)):
            layer = self._layers[index]
            for key, is_wildcard in layer.iter_keys():
                # If something is not in used_args, it means it was either extraneous or it got
                # clobbered because something in a higher layer matched it
                if (key, index) in used_args:
                    continue

                if (
                    # It's easy to check if a non-wildcard arg was clobbered. We just check if
                    # there was a param with that name (that we should have matched if not for
                    # presumed clobbering)
                    (not is_wildcard and key not in param_paths)
                    # For wildcards, we need to match against all param paths
                    or (
                        is_wildcard
                        and not any(layer._to_regex[key].fullmatch(p) for p in param_paths)
                    )
                ):
                    # Okay, we have an extraneous argument. We're going to error, but we should
                    # helpfully try to figure out what the user wanted
                    extra = ""
                    if layer.layer_name:
                        extra += f" (from {layer.layer_name})"

                    ratios = {p: wildcard_key_approx(key, p) for p in param_paths}
                    if ratios:
                        max_option = max(ratios, key=lambda v: ratios[v][0])
                        if ratios[max_option][0] > 0.1:
                            extra = f"\nDid you mean {ratios[max_option][1]!r}?"

                    if not is_wildcard:
                        nested_pattern = wildcard_key_to_regex("..." + key)
                        found_key = next(
                            (p for p in param_paths if nested_pattern.fullmatch(p)), None
                        )
                        if found_key is not None:
                            extra += (
                                f"\nDid you get the nesting wrong, maybe you meant {found_key!r}?"
                            )

                    if key.startswith("--"):
                        extra += "\nDid you mean to use allow_hyphens=True in your entrypoint?"

                    if not is_wildcard:
                        parts = key.split(".")
                        if len(parts) >= 2:
                            valid_parent = _valid_parent(parts, param_paths)
                            if valid_parent is None:
                                extra += f"\nNo param found matching {parts[0]!r}"
                            else:
                                from chz.blueprint._blueprint import _found_arg_desc

                                extra += f"\n\nParam {valid_parent!r} is closest valid ancestor"
                                parent_found_arg = self.get_kv(valid_parent)
                                param = make_result.all_params[valid_parent]
                                desc = _found_arg_desc(
                                    make_result,
                                    parent_found_arg,
                                    param_path=valid_parent,
                                    param=param,
                                    omit_redundant=False,
                                )
                                invalid_part = (
                                    ".".join(parts).removeprefix(valid_parent + ".").split(".")[0]
                                )
                                extra += f"\nParam {valid_parent!r} is set to {desc}"
                                extra += f"\nSubparam {invalid_part!r} does not exist on it"

                    raise ExtraneousBlueprintArg(
                        f"Extraneous argument {key!r} to Blueprint for {entrypoint_repr}"
                        + extra
                        + "\nAppend --help to your command to see valid arguments"
                    )

    def __repr__(self) -> str:
        return "ArgumentMap(\n" + "\n".join(" " + repr(layer) for layer in self._layers) + "\n)"


def join_arg_path(parent: str, child: str) -> str:
    if not parent:
        return child
    if child.startswith(".") or child == "":
        return parent + child
    return parent + "." + child

================================================
FILE: chz/blueprint/_argv.py
================================================
from __future__ import annotations

import itertools
import types
from typing import Any, TypeVar

import chz.blueprint
from chz.blueprint._argmap import Layer
from chz.blueprint._wildcard import wildcard_key_to_regex
from chz.tiepin import type_repr

_T = TypeVar("_T")


def argv_to_blueprint_args(
    argv: list[str], *, allow_hyphens: bool = False
) -> dict[str, chz.blueprint.Castable | chz.blueprint.Reference]:
    # TODO: allow stuff like model[family=linear n_layers=1]
    ret: dict[str, chz.blueprint.Castable | chz.blueprint.Reference] = {}
    for arg in argv:
        try:
            key, value = arg.split("=", 1)
        except ValueError:
            raise ValueError(
                f"Invalid argument {arg!r}. Specify arguments in the form key=value"
            ) from None
        if allow_hyphens:
            key = key.lstrip("-")

        # parse key@=reference syntax (note =@ would be ambiguous)
        if key.endswith("@"):
            ret[key.removesuffix("@")] = chz.blueprint.Reference(value)
        else:
            ret[key] = chz.blueprint.Castable(value)
    return ret


def beta_argv_arg_to_string(key: str, value: Any) -> list[str]:
    if isinstance(value, chz.blueprint.Castable):
        return [f"{key}={value.value}"]
    if isinstance(value, chz.blueprint.Reference):
        return [f"{key}@={value.ref}"]
    if isinstance(value, (types.FunctionType, type)):
        return [f"{key}={type_repr(value)}"]
    if isinstance(value, str):
        return [f"{key}={value}"]
    if isinstance(value, (int, float, bool)) or value is None:
        return [f"{key}={repr(value)}"]
    if isinstance(value, (list, tuple)):
        if all(isinstance(e, str) for e in value):
            if not any("," in e for e in value):
                return [f"{key}={','.join(value)}"]
            args_list = []
            for i, e in enumerate(value):
                args_list.extend(beta_argv_arg_to_string(f"{key}.{i}", e))
            return args_list
        elif all(isinstance(e, (int, float, bool)) or e is None for e in value):
            return [f"{key}={','.join(map(str, value))}"]
    if isinstance(value, dict):
        args_list = []
        for k, v in value.items():
            args_list.extend(beta_argv_arg_to_string(f"{key}.{k}", v))
        return args_list

    # Probably safe to use repr here, but I'm curious to see how people end up using this
    raise NotImplementedError(
        f"TODO: beta_blueprint_to_argv does not currently convert {value!r} of "
        f"type {type(value)} to string"
    )


def beta_blueprint_to_argv(blueprint: chz.Blueprint[_T]) -> list[str]:
    """Returns a list of arguments that would recreate the given blueprint.

    Please do not use this function without asking @shantanu, it is slow and not fully robust,
    and more importantly, there may well be a better way to accomplish your goal.
    """
    ret = [
        arg
        for key, value in _collapse_layers(blueprint)
        for arg in beta_argv_arg_to_string(key, value)
    ]
    return ret


def _collapse_layer(
    ordered_args: list[tuple[str, Any]], ordered_arg_keys: set[str], layer: Layer
) -> None:
    """Collapses `layer` into `ordered_args`, overriding any old keys as necessary."""
    layer_args: list[tuple[str, Any]] = []
    keys_to_remove: set[str] = set()
    for key, value in itertools.chain(layer.qualified.items(), layer.wildcard.items()):
        # Remove any previous args that would be overwritten by this one.
        wildcard = wildcard_key_to_regex(key) if "..." in key else None
        if wildcard:
            for prev_key in ordered_arg_keys:
                # TODO(shantanu): usually this regex is only matched against concrete keys
                # However, here we're matching against other wildcards
                if wildcard.fullmatch(prev_key):
                    keys_to_remove.add(prev_key)
        else:
            if key in ordered_arg_keys:
                keys_to_remove.add(key)

        layer_args.append((key, value))

    # Commit the new layer
    ordered_args[:] = [arg for arg in ordered_args if arg[0] not in keys_to_remove] + layer_args
    ordered_arg_keys.difference_update(keys_to_remove)
    ordered_arg_keys.update(key for key, _ in layer_args)


def _collapse_layers(blueprint: chz.Blueprint[_T]) -> list[tuple[str, Any]]:
    """Collapses the layers of a blueprint into a list of key-value pairs.

    These could be applied as a single layer to a new blueprint to recreate the original.
    """
    ordered_args: list[tuple[str, Any]] = []
    ordered_arg_keys: set[str] = set()
    for layer in blueprint._arg_map._layers:
        _collapse_layer(ordered_args, ordered_arg_keys, layer)
    return ordered_args

================================================
FILE: chz/blueprint/_blueprint.py
================================================
from __future__ import annotations

import ast
import collections.abc
import dataclasses
import functools
import inspect
import io
import sys
import textwrap
import typing
from dataclasses import dataclass
from typing import Any, Callable, Final, Generic, Mapping, Protocol

from typing_extensions import TypeVar

import chz
from chz.blueprint._argmap import ArgumentMap, Layer, _FoundArgument, join_arg_path
from chz.blueprint._argv import argv_to_blueprint_args
from chz.blueprint._entrypoint import (
    ConstructionException,
    EntrypointHelpException,
    ExtraneousBlueprintArg,
    InvalidBlueprintArg,
    MissingBlueprintArg,
)
from chz.blueprint._lazy import (
    Evaluatable,
    ParamRef,
    Thunk,
    Value,
    check_reference_targets,
    evaluate,
)
from chz.field import Field
from chz.tiepin import (
    CastError,
    _simplistic_try_cast,
    _simplistic_type_of_value,
    eval_in_context,
    is_kwargs_unpack,
    is_subtype_instance,
    is_typed_dict,
    type_repr,
)
from chz.util import MISSING, MISSING_TYPE

_T = TypeVar("_T")
_T_cov_def = TypeVar("_T_cov_def", covariant=True, default=Any)


class SpecialArg: ...


class Castable(SpecialArg):
    """A wrapper class for str if you want a Blueprint value to be magically type aware casted."""

    def __init__(self, value: str) -> None:
        self.value = value

    def __repr__(self) -> str:
        return f"Castable({self.value!r})"

    def __hash__(self) -> int:
        return hash(self.value)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Castable):
            try:
                return _simplistic_try_cast(self.value, type(other)) == other
            except CastError:
                return False
        return self.value == other.value


class Reference(SpecialArg):
    """A reference to another parameter in a Blueprint."""

    def __init__(self, ref: str) -> None:
        if "..." in ref:
            raise ValueError("Cannot use wildcard as a reference target")
        self.ref = ref

    def __repr__(self) -> str:
        return f"Reference({self.ref!r})"


@dataclass(frozen=True, kw_only=True)
class Computed(SpecialArg):
    """A parameter computed from other parameters in a Blueprint."""

    src: dict[str, Reference]
    compute: Callable[..., Any]

    def __repr__(self) -> str:
        arg_str = ", ".join(f"{k}@={v.ref}" for k, v in self.src.items())
        return f"Computed({arg_str})"


@dataclass(frozen=True)
class _MakeResult:
    # `value_mapping` is a dictionary mapping from parameter paths to Evaluatable values.
    # This ultimately contains all the kinds of values we will use in instantiation.
    # See chz.blueprint._lazy.evaluate for an example of using Evaluatable.
    value_mapping: dict[str, Evaluatable]

    # `all_params` is a dictionary containing all parameters we discover, mapping from that param
    # path to the parameter. Note what parameters we discover will depend on polymorphic
    # construction via meta_factories. We use all_params to provide a useful --help (and various
    # other things, e.g. detect clobbering when checking for extraneous arguments)
    all_params: dict[str, _Param]

    # `used_args` is a set of (key, layer_index) tuples that we use to track which arguments from
    # arg_map we've used. We use this to check for extraneous arguments.
used_args: set[tuple[str, int]] # `meta_factory_value` records what meta_factory we're using. This makes --help more # understandable in the presence of polymorphism, especially when factories come from # blueprint_unspecified. It's conceptually the same information as in Thunk.fn in value_mapping, # but preserves user input for variadics or generics (instead of being a constructor function) meta_factory_value: dict[str, Any] # `missing_params` is a list of parameters we know need are required but haven't been # specified. In theory, this is unnecessary because `__init__` will raise an error if # a required param is missing, but this improves diagnostics. missing_params: list[str] def _entrypoint_caster(o: str) -> object: raise chz.tiepin.CastError("Will not interpret entrypoint as a value") def _found_arg_desc( r: _MakeResult, found_arg: _FoundArgument | None, *, param_path: str, param: _Param, omit_redundant: bool = True, color: bool = False, ) -> str: if found_arg is None: if param_path in r.meta_factory_value: found_arg_str = type_repr(r.meta_factory_value[param_path]) if color: found_arg_str += " \033[90m(meta_factory)\033[0m" else: found_arg_str += " (meta_factory)" elif param.default is not None: found_arg_str = param.default.to_help_str() if color: found_arg_str += " \033[90m(default)\033[0m" else: found_arg_str += " (default)" elif ( param.meta_factory is not None and (factory := param.meta_factory.unspecified_factory()) is not None and (factory is not param.type or not omit_redundant) ): if getattr(factory, "__name__", None) == "": found_arg_str = _lambda_repr(factory) or type_repr(factory) else: found_arg_str = type_repr(factory) if color: found_arg_str += " \033[90m(blueprint_unspecified)\033[0m" else: found_arg_str += " (blueprint_unspecified)" else: found_arg_str = "-" else: if isinstance(found_arg.value, Castable): found_arg_str = repr(found_arg.value.value)[1:-1] elif isinstance(found_arg.value, Reference): found_arg_str = f"@={found_arg.value.ref}" 
elif isinstance(found_arg.value, Computed): arg_str = ", ".join(f"{k}@={v.ref}" for k, v in found_arg.value.src.items()) found_arg_str = f"f({arg_str})" else: found_arg_str = type_repr(found_arg.value) if found_arg.layer_name: if color: found_arg_str += f" \033[90m(from \033[94m{found_arg.layer_name}\033[90m)\033[0m" else: found_arg_str += f" (from {found_arg.layer_name})" return found_arg_str class Blueprint(Generic[_T_cov_def]): def __init__( self, target: chz.factories.MetaFactory | type[_T_cov_def] | Callable[..., _T_cov_def] ) -> None: """Instantiate a Blueprint. Args: target: The target object or callable we will instantiate or call. """ self.target = target if isinstance(target, chz.factories.MetaFactory): self.meta_factory = target if isinstance(target, chz.factories.standard): entrypoint_type = target.annotation entrypoint_doc = getattr(entrypoint_type, "__doc__", "") else: entrypoint_type = object entrypoint_doc = "" self.entrypoint_repr = type_repr(entrypoint_type) else: self.meta_factory = chz.factories.standard(annotation=target) entrypoint_type = target if self.meta_factory.unspecified_factory() is None: if not callable(target): raise ValueError(f"{target} is not callable") self.meta_factory = chz.factories.standard(annotation=object, unspecified=target) entrypoint_type = object self.entrypoint_repr = type_repr(target) entrypoint_doc = getattr(target, "__doc__", "") self.param = _Param( name="", type=entrypoint_type, meta_factory=self.meta_factory, default=None, doc=entrypoint_doc.strip() if entrypoint_doc else "", blueprint_cast=_entrypoint_caster, metadata={}, ) self._arg_map = ArgumentMap([]) def clone(self) -> Blueprint[_T_cov_def]: """Make a copy of this Blueprint.""" return Blueprint(self.target).apply(self) def apply( self, values: Blueprint[_T_cov_def] | Mapping[str, Any], layer_name: str | None = None, *, subpath: str | None = None, strict: bool = False, ) -> Blueprint[_T_cov_def]: """Modify this Blueprint by partially applying some 
arguments. Args: values: The values to partially apply to this Blueprint layer_name: The name of the layer to add (allows identification of the source of values) subpath: A subpath to nest the argument names under strict: Whether to eagerly check for extraneous arguments. This may not work well in cases where a polymorphic field is applied later. """ if isinstance(values, Mapping): self._arg_map.add_layer(Layer(values, layer_name).nest_subpath(subpath)) elif isinstance(values, Blueprint): if subpath is None: if values.target is not self.target: raise ValueError( f"Cannot apply Blueprint for {type_repr(values.target)} to Blueprint for " f"{type_repr(self.target)}" ) for layer in values._arg_map._layers: self._arg_map.add_layer(layer.nest_subpath(subpath)) else: raise TypeError(f"Expected dict or Blueprint, got {type(values)}") if strict: r = self._make_lazy() self._arg_map.check_extraneous( r.used_args, r.all_params.keys(), make_result=r, entrypoint_repr=self.entrypoint_repr, ) return self def apply_from_argv( self, argv: list[str], allow_hyphens: bool = False, layer_name: str = "command line" ) -> Blueprint[_T_cov_def]: """Apply arguments from argv to this Blueprint.""" values = argv_to_blueprint_args( [a for a in argv if a != "--help"], allow_hyphens=allow_hyphens ) self.apply(values, layer_name=layer_name) if "--help" in argv: argv.remove("--help") raise EntrypointHelpException(self.get_help(color=sys.stdout.isatty())) return self def _make_lazy(self) -> _MakeResult: all_params: dict[str, _Param] = {} used_args: set[tuple[str, int]] = set() meta_factory_value: dict[str, Any] = {} missing_params: list[str] = [] self._arg_map.consolidate() value_mapping: dict[str, Evaluatable] | ConstructionIssue | None value_mapping = _construct_param( self.param, "", self._arg_map, all_params=all_params, used_args=used_args, meta_factory_value=meta_factory_value, missing_params=missing_params, ) if isinstance(value_mapping, ConstructionIssue): raise 
ConstructionException(value_mapping.issue) if value_mapping is None: # value_mapping is None if _construct_param / _construct_unspecified_param # wants us to fallback to the default or thinks we're missing required arguments # This is sort of equivalent to what happens in _construct_factory unspecified_factory = self.meta_factory.unspecified_factory() if unspecified_factory is None: raise ConstructionException( f"Cannot construct {self.target} because it has no unspecified factory" ) value_mapping = {"": Thunk(unspecified_factory, {})} if "" in missing_params: missing_params.remove("") return _MakeResult( value_mapping=value_mapping, all_params=all_params, used_args=used_args, meta_factory_value=meta_factory_value, missing_params=missing_params, ) def _make_from_make_result(self, r: _MakeResult) -> _T_cov_def: self._arg_map.check_extraneous( r.used_args, r.all_params.keys(), make_result=r, entrypoint_repr=self.entrypoint_repr, ) check_reference_targets(r.value_mapping, r.all_params.keys()) # Note we check for extraneous args first, so we get better errors for typos if r.missing_params: raise MissingBlueprintArg( f"Missing required arguments for parameter(s): {', '.join(r.missing_params)}" ) # __chz_blueprint__ return evaluate(r.value_mapping) def make(self) -> _T_cov_def: """Instantiate or call the target object or callable.""" r = self._make_lazy() return self._make_from_make_result(r) def make_from_argv( self, argv: list[str] | None = None, allow_hyphens: bool = False ) -> _T_cov_def: """Like make, but suitable for command line entrypoints. This will apply arguments from argv to this Blueprint before attempting to make the target. If "--help" is in argv, this will print help text and exit. """ if argv is None: argv = sys.argv[1:] self.apply_from_argv(argv, allow_hyphens=allow_hyphens) return self.make() def get_help(self, *, color: bool = False) -> str: """Get help text for this Blueprint. Note that applied arguments may affect output here, e.g. 
in case of polymorphically constructed fields. """ r = self._make_lazy() f = io.StringIO() output = functools.partial(print, file=f) try: self._arg_map.check_extraneous( r.used_args, r.all_params.keys(), make_result=r, entrypoint_repr=self.entrypoint_repr, ) except ExtraneousBlueprintArg as e: output(f"WARNING: {e}\n") try: check_reference_targets(r.value_mapping, r.all_params.keys()) except InvalidBlueprintArg as e: output(f"WARNING: {e}\n") if r.missing_params: output( f"WARNING: Missing required arguments for parameter(s): {', '.join(r.missing_params)}\n" ) output(f"Entry point: {self.entrypoint_repr}") output() if self.param.doc: output(textwrap.indent(self.param.doc, " ")) output() param_output = [] for param_path, param in r.all_params.items(): # TODO: cast or meta_factory info, not just type found_arg = self._arg_map.get_kv(param_path) if ( not isinstance(self.target, chz.factories.MetaFactory) and param_path == "" and found_arg is None ): # If we're not using root polymorphism, skip this param continue found_arg_str = _found_arg_desc( r, found_arg, param_path=param_path, param=param, color=color ) param_output.append( (param_path or "", type_repr(param.type), found_arg_str, param.doc) ) clip = 40 lens = tuple(min(clip, max(map(len, column))) for column in zip(*param_output)) output("Arguments:") for p, typ, arg, doc in param_output: col = 0 target_col = 0 line = io.StringIO() add = functools.partial(print, file=line, end="") raw_string = p add(" ") if color: add(f"\033[1m{raw_string}\033[0m") else: add(raw_string) col += 2 + len(raw_string) target_col += 2 + lens[0] pad = (target_col - col) if col <= target_col else (-len(raw_string)) % 20 add(" " * pad) col += pad raw_string = typ add(" ") add(raw_string) col += 2 + len(raw_string) target_col += 2 + lens[1] pad = (target_col - col) if col <= target_col else (-len(raw_string)) % 20 add(" " * pad) col += pad raw_string = arg add(" ") add(raw_string) col += 2 + len(raw_string) target_col += 2 + lens[2] pad = 
(target_col - col) if col <= target_col else (-len(raw_string)) % 20
            add(" " * pad)
            col += pad

            raw_string = doc
            add(" ")
            if color:
                add(f"\033[90m{raw_string}\033[0m")
            else:
                add(raw_string)

            output(line.getvalue().rstrip())

        return f.getvalue()


def _lambda_repr(fn) -> str | None:
    try:
        src = inspect.getsource(fn).strip()
        tree = ast.parse(src)
        nodes = [node for node in ast.walk(tree) if isinstance(node, ast.Lambda)]
        if len(nodes) != 1:
            return None
        return ast.unparse(nodes[0])
    except Exception:
        return None


@dataclass(frozen=True, kw_only=True)
class _Default:
    value: Any | MISSING_TYPE
    factory: Callable[..., Any] | MISSING_TYPE

    def to_help_str(self) -> str:
        if self.factory is not MISSING:
            # Lambdas all share the name "<lambda>", so show their source instead
            if getattr(self.factory, "__name__", None) == "<lambda>":
                return f"({_lambda_repr(self.factory)})()" or ""
            # type_repr works reasonably well for functions too
            return f"{type_repr(self.factory)}()"
        ret = repr(self.value)
        if len(ret) > 40:
            return ""
        return ret

    def instantiate(self) -> Any:
        if not isinstance(self.factory, MISSING_TYPE):
            return self.factory()
        return self.value

    @classmethod
    def from_field(cls, field: Field) -> _Default | None:
        if field._default is MISSING and field._default_factory is MISSING:
            return None
        return _Default(value=field._default, factory=field._default_factory)

    @classmethod
    def from_inspect_param(cls, sigparam: inspect.Parameter) -> _Default | None:
        if sigparam.default is sigparam.empty:
            return None
        return _Default(value=sigparam.default, factory=MISSING)


@dataclass(frozen=True, kw_only=True)
class _Param:
    name: str
    type: Any
    meta_factory: chz.factories.MetaFactory | None
    default: _Default | None
    doc: str
    blueprint_cast: Callable[[str], object] | None
    metadata: dict[str, Any]

    def cast(self, value: str) -> object:
        # If we have a field-level cast, always use that!
        if self.blueprint_cast is not None:
            return self.blueprint_cast(value)

        # If we have a meta_factory, route casting through it.
This allows user expectations # of types that result from casting to better match their expectations from polymorphic # construction (e.g. using default_cls from chz.factories.subclass) if self.meta_factory is not None: return self.meta_factory.perform_cast(value) # TODO: maybe MetaFactory should have default impl and this should be: # return chz.factories.MetaFactory().perform_cast(value, self.type) return _simplistic_try_cast(value, self.type) def _get_variadic_elements(obj_path: str, arg_map: ArgumentMap) -> set[str]: elements = set() for subpath in arg_map.subpaths(obj_path, strict=True): assert subpath if subpath[0] != ".": element = subpath.split(".")[0] assert element elements.add(element) return elements def _collect_params_from_chz( obj: Any, obj_path: str, arg_map: ArgumentMap ) -> tuple[list[_Param], Callable[..., Any], list[Any]]: obj_origin = getattr(obj, "__origin__", obj) params = [] for field in chz.chz_fields(obj_origin).values(): params.append( _Param( name=field.logical_name, type=field.x_type, meta_factory=field.meta_factory, default=_Default.from_field(field), doc=field._doc, blueprint_cast=field._blueprint_cast, metadata=(field.metadata or {}), ) ) return params, obj, [] def _collect_params_from_sequence( obj: Any, obj_path: str, arg_map: ArgumentMap ) -> tuple[list[_Param], Callable[..., Any], list[Any]]: elements = _get_variadic_elements(obj_path, arg_map) max_element = max((int(e) for e in elements), default=-1) obj_origin = getattr(obj, "__origin__", obj) obj_type_construct = obj_origin type_for_index: Callable[[int], type] if obj_origin is list: element_type = getattr(obj, "__args__", [object])[0] type_for_index = lambda i: element_type variadic_types = [element_type] elif obj_origin is collections.abc.Sequence: element_type = getattr(obj, "__args__", [object])[0] type_for_index = lambda i: element_type variadic_types = [element_type] obj_type_construct = tuple elif obj_origin is tuple: args: tuple[Any, ...] 
| None = getattr(obj, "__args__", None) if args is None: args = (Any, ...) if len(args) == 2 and args[-1] is ...: # homogeneous tuple type_for_index = lambda i: args[0] variadic_types = [args[0]] else: # heterogeneous tuple if max_element >= len(args): raise TypeError( f"Tuple type {obj} for {obj_path!r} must take {len(args)} items; " f"arguments for index {max_element} were specified" + ( f". Homogeneous tuples should be typed as tuple[{type_repr(args[0])}, ...] not tuple[{type_repr(args[0])}]" if len(args) == 1 else "" ) ) type_for_index = lambda i: args[i] variadic_types = list(args) else: raise AssertionError params = [] for i in range(max_element + 1): element_type = type_for_index(i) params.append( _Param( name=str(i), type=element_type, meta_factory=chz.factories.standard(annotation=element_type), default=None, doc="", blueprint_cast=None, metadata={}, ) ) def sequence_constructor(**kwargs): return obj_type_construct(kwargs[str(i)] for i in range(max_element + 1)) obj_constructor = sequence_constructor return params, obj_constructor, variadic_types def _collect_params_from_mapping( obj: Any, obj_path: str, arg_map: ArgumentMap ) -> tuple[list[_Param], Callable[..., Any], list[Any]] | ConstructionIssue: elements = _get_variadic_elements(obj_path, arg_map) args: tuple[Any, ...] 
= getattr(obj, "__args__", ()) if len(args) == 0: element_type = object key_type = str elif len(args) == 2: if args[0] not in (str, int): if elements: raise TypeError( f"Variadic dict type must take str or int keys, not {type_repr(args[0])}" ) return ConstructionIssue( f"Variadic dict type must take str or int keys, not {type_repr(args[0])}" ) key_type = args[0] element_type = args[1] else: raise TypeError(f"Dict type {obj} must take 0 or 2 items") params = [] for element in elements: params.append( _Param( name=element, type=element_type, meta_factory=chz.factories.standard(annotation=element_type), default=None, doc="", blueprint_cast=None, metadata={}, ) ) def _dict(**kwargs) -> dict[int | str, Any]: return {key_type(k): v for k, v in kwargs.items()} return params, _dict, [element_type] def _collect_params_from_typed_dict( obj: Any, obj_path: str, arg_map: ArgumentMap ) -> tuple[list[_Param], Callable[..., Any], list[Any]]: obj_origin = getattr(obj, "__origin__", obj) params = [] variadic_types = [] for key, annotation in typing.get_type_hints(obj_origin).items(): required = key in obj_origin.__required_keys__ params.append( _Param( name=key, type=annotation, meta_factory=chz.factories.standard(annotation=annotation), # Mark the default as NotRequired to improve --help output # We don't actually use the default values in Blueprint since we let # instantiation handle insertion of default values default=(None if required else _Default(value=typing.NotRequired, factory=MISSING)), doc="", blueprint_cast=None, metadata={}, ) ) variadic_types.append(annotation) return params, obj_origin, variadic_types def _collect_params_from_callable( obj: Any, obj_path: str, arg_map: ArgumentMap ) -> tuple[list[_Param], Callable[..., Any], list[Any]] | ConstructionIssue: # Note you probably don't want to call this if obj is a primitive try: signature = inspect.signature(obj) except ValueError: return ConstructionIssue(f"Failed to get signature for {obj.__name__}") obj_constructor = 
obj has_pos_only = False has_pos_or_kwarg = False elements: set[str] | None = None params = [] for i, (name, sigparam) in enumerate(signature.parameters.items()): param_annot: Any if sigparam.annotation is sigparam.empty: if i == 0 and "." in obj.__qualname__: # potentially first parameter of a method, default the annotation to the class try: cls = getattr(inspect.getmodule(obj), obj.__qualname__.rsplit(".", 1)[0]) param_annot = cls except Exception: param_annot = object else: param_annot = object else: param_annot = sigparam.annotation if isinstance(param_annot, str): try: param_annot = eval_in_context(param_annot, obj) except Exception as e: raise ValueError( f"Failed to evaluate parameter {name}: {param_annot} in signature {signature} of object {obj}" ) from e if sigparam.kind == sigparam.POSITIONAL_ONLY: has_pos_only = True name = str(i) if sigparam.kind == sigparam.POSITIONAL_OR_KEYWORD: has_pos_or_kwarg = True if sigparam.kind == sigparam.VAR_POSITIONAL: if elements is None: elements = _get_variadic_elements(obj_path, arg_map) max_element = max((int(e) for e in elements if e.isdigit()), default=-1) if has_pos_or_kwarg and max_element >= 0: return ConstructionIssue( "Cannot collect parameters with both positional-or-keyword and variadic positional parameters" ) has_pos_only = True for j in range(i, max_element + 1): params.append( _Param( name=str(j), type=param_annot, meta_factory=chz.factories.standard(annotation=param_annot), default=None, doc="", blueprint_cast=None, metadata={}, ) ) continue if sigparam.kind == sigparam.VAR_KEYWORD: if is_kwargs_unpack(param_annot): if len(param_annot.__args__) != 1 or not is_typed_dict(param_annot.__args__[0]): return ConstructionIssue( f"Cannot collect parameters from {obj.__name__}, expected Unpack[TypedDict], not {param_annot}" ) for key, annotation in typing.get_type_hints(param_annot.__args__[0]).items(): # TODO: handle total=False and PEP 655 params.append( _Param( name=key, type=annotation, 
meta_factory=chz.factories.standard(annotation=annotation), default=None, doc="", blueprint_cast=None, metadata={}, ) ) else: if elements is None: elements = _get_variadic_elements(obj_path, arg_map) for element in elements - {p.name for p in params}: params.append( _Param( name=element, type=param_annot, meta_factory=chz.factories.standard(annotation=param_annot), default=None, doc="", blueprint_cast=None, metadata={}, ) ) continue # It could be interesting to let function defaults be chz.Field :-) # TODO: could be fun to parse function docstring params.append( _Param( name=name, type=param_annot, meta_factory=chz.factories.standard(annotation=param_annot), default=_Default.from_inspect_param(sigparam), doc="", blueprint_cast=None, metadata={}, ) ) if has_pos_only: def positional_constructor(**kwargs): a = [] kw = {} for k, v in kwargs.items(): if k.isdigit(): a.append((int(k), v)) else: kw[k] = v a = [v for _, v in sorted(a)] return obj(*a, **kw) obj_constructor = positional_constructor # Note params may be empty here if obj doesn't take any parameters. # This is usually okay, but has some interaction with fully defaulted subcomponents. 
# See test_nested_all_defaults and variants return params, obj_constructor, [] def _collect_params( obj: Any, obj_path: str, arg_map: ArgumentMap ) -> ( ConstructionIssue | tuple[ list[_Param], # params discovered Callable[..., Any], # constructor to call list[Any], # vaguely like [p.type for p in params], used only for sanity checking ] ): obj_origin = getattr(obj, "__origin__", obj) if chz.is_chz(obj_origin): return _collect_params_from_chz(obj, obj_path, arg_map) if isinstance(obj, functools.partial) and chz.is_chz(obj.func): if obj.args: return ConstructionIssue( f"Cannot collect parameters from partial function of chz class " f"{type_repr(obj.func)} with positional arguments" ) result = _collect_params(obj.func, obj_path, arg_map) if isinstance(result, ConstructionIssue): return result params, _constructor, variadic_types = result new_params = [] for param in params: if param.name in obj.keywords: # The actual value of the default should only matter for --help output param = dataclasses.replace( param, default=_Default(value=obj.keywords[param.name], factory=MISSING) ) new_params.append(param) return new_params, obj, variadic_types if obj_origin in {list, tuple, collections.abc.Sequence}: return _collect_params_from_sequence(obj, obj_path, arg_map) if obj_origin in {dict, collections.abc.Mapping}: return _collect_params_from_mapping(obj, obj_path, arg_map) if is_typed_dict(obj_origin): return _collect_params_from_typed_dict(obj, obj_path, arg_map) if obj_origin in {bool, int, float, str, bytes, None, type(None)}: return ConstructionIssue("Cannot collect parameters from primitive") if "enum" in sys.modules: import enum if type(obj) is enum.EnumMeta: return ConstructionIssue("Cannot collect parameters from Enum class") if callable(obj): return _collect_params_from_callable(obj, obj_path, arg_map) return ConstructionIssue( f"Could not collect parameters to construct {obj} of type {type_repr(obj)}" ) _K = TypeVar("_K") _V = TypeVar("_V", contravariant=True) class 
_WriteOnlyMapping(Generic[_K, _V], Protocol): def __setitem__(self, __key: _K, __value: _V, /) -> None: ... def update(self, __m: Mapping[_K, _V], /) -> None: ... class ConstructionIssue: def __init__(self, issue: str) -> None: self.issue = issue def _construct_factory( obj: Callable[..., _T], obj_path: str, arg_map: ArgumentMap, *, # Output parameters, do not use within this function # Typing these as write-only should help prevent accidental unsound use within this function # See _MakeResult for docs about these parameters all_params: _WriteOnlyMapping[str, _Param], used_args: set[tuple[str, int]], meta_factory_value: _WriteOnlyMapping[str, Any], missing_params: list[str], ) -> dict[str, Evaluatable] | ConstructionIssue: result = _collect_params(obj, obj_path, arg_map) del obj if isinstance(result, ConstructionIssue): return result params, obj_constructor, _ = result # Ideas: # TODO: Allow automatically accessing any attribute on parent for factories? # This eases the responsibility of getting the right structure for the config and could mean # we don't need to support wildcards? Reduces problems of something not getting specified # correctly. # "If you need a value, move it one level up." # TODO: Allow factories to return blueprints? This would allow for better presets, e.g. you # could do model=d4-dev model.n_layers=5 kwargs: dict[str, ParamRef] = {} value_mapping: dict[str, Evaluatable] = {} for param in params: sub_value_mapping = _construct_param( param, obj_path, arg_map, all_params=all_params, used_args=used_args, meta_factory_value=meta_factory_value, missing_params=missing_params, ) if isinstance(sub_value_mapping, ConstructionIssue): return sub_value_mapping if sub_value_mapping is None: continue param_path = (obj_path + "." 
if obj_path else "") + param.name value_mapping.update(sub_value_mapping) kwargs[param.name] = ParamRef(param_path) value_mapping[obj_path] = Thunk(obj_constructor, kwargs) return value_mapping def _construct_unspecified_param( param: _Param, *, param_path: str, arg_map: ArgumentMap, # Output parameters, do not use within this function # See _MakeResult for docs about these parameters all_params: _WriteOnlyMapping[str, _Param], used_args: set[tuple[str, int]], meta_factory_value: _WriteOnlyMapping[str, Any], missing_params: list[str], ) -> dict[str, Evaluatable] | ConstructionIssue | None: if ( param.meta_factory is not None and (factory := param.meta_factory.unspecified_factory()) is not None ): assert callable(factory) sub_all_params: dict[str, _Param] = {} sub_missing_params: list[str] = [] sub_used_args: set[tuple[str, int]] = set() sub_meta_factory_value: dict[str, Any] = {} value_mapping = _construct_factory( factory, param_path, arg_map, all_params=sub_all_params, used_args=sub_used_args, meta_factory_value=sub_meta_factory_value, missing_params=sub_missing_params, ) all_params.update(sub_all_params) used_args.update(sub_used_args) # TODO: should this be gated by use? meta_factory_value.update(sub_meta_factory_value) if isinstance(value_mapping, ConstructionIssue): if param_path == "" and param.default is None: # This is a little bit special case-y. But if we have a construction issue with # the root param, it's much better to forward it than for the user to get an error # about a missing required root argument. return value_mapping else: thunk = value_mapping[param_path] assert isinstance(thunk, Thunk) # Only do this if we have subcomponents specified (including wildcards) if thunk.kwargs: meta_factory_value[param_path] = factory missing_params.extend(sub_missing_params) return value_mapping # Alternatively, if a) we do not have a default, b) we're making a chz object # c) we know instantiation would always work, that's fine too. 
# A little special-case-y, but somewhat sane. It turns something that would # error due to lack of default into something reasonable. # See test_nested_all_defaults and variants if ( param.default is None and ( chz.is_chz(factory) or (isinstance(factory, functools.partial) and chz.is_chz(factory.func)) ) and ( all( p.default is not None for path, p in sub_all_params.items() if "." not in path.removeprefix(param_path + ".") ) ) ): assert not sub_missing_params meta_factory_value[param_path] = factory return value_mapping # If we have a default, make sure we don't extend missing_params if param.default is None: if sub_missing_params: missing_params.extend(sub_missing_params) else: # Happens if we collect no params, like non-chz field or variadics missing_params.append(param_path) else: # If we have a default, do some validation about wildcards + variadics _check_for_wildcard_matching_variadic_top_level(factory, param, param_path, arg_map) return None assert not sub_missing_params # If we have no subcomponents specified or we have no factory, we don't add any kwargs # When the object is created, this will be equivalent to: # `attr = default` or `attr = default_factory()` # If there is no default, we will raise MissingBlueprintArg, instead of relying on the # normal Python error during instantiation. We also rely on raising ExtraneousBlueprintArg # if there are arguments that go unused. 
# (In the case of Blueprint implementation bugs, if we're missing a param, __init__ will # have our back, but the extraneous logic has no backup) if param.default is None: missing_params.append(param_path) return None def _construct_param( param: _Param, obj_path: str, arg_map: ArgumentMap, *, # Output parameters, do not use within this function # See _MakeResult for docs about these parameters all_params: _WriteOnlyMapping[str, _Param], used_args: set[tuple[str, int]], meta_factory_value: _WriteOnlyMapping[str, Any], missing_params: list[str], ) -> dict[str, Evaluatable] | ConstructionIssue | None: # Returns None if we don't need to pass any value. This doesn't mean there's an error, # we might simply want the default or default_factory value. param_path: Final = (obj_path + "." if obj_path else "") + param.name all_params[param_path] = param found_arg = arg_map.get_kv(param_path) # If nothing is specified, check if we have a factory we can feed subcomponents to and if there # are specified subcomponents we could feed to it. Otherwise, if a default or default_factory # exists, we'll use that. if found_arg is None: return _construct_unspecified_param( param, param_path=param_path, arg_map=arg_map, all_params=all_params, used_args=used_args, meta_factory_value=meta_factory_value, missing_params=missing_params, ) used_args.add((found_arg.key, found_arg.layer_index)) spec: object = found_arg.value # Something is specified, so we must either add something to kwargs or error out # If something is specified, and is of the expected type, we just assign it: # `attr = spec` if not isinstance(spec, SpecialArg) and is_subtype_instance(spec, param.type): # (ignore SpecialArg's here, in case param.type is object) if not (param.meta_factory is not None and arg_map.subpaths(param_path, strict=True)): # TODO: deep copy? 
return {param_path: Value(spec)} # ..or if it's a Reference to some other parameter if isinstance(spec, Reference): if spec.ref == param_path: # If it's a self reference, treat it as if it were unspecified value_mapping = _construct_unspecified_param( param, param_path=param_path, arg_map=arg_map, all_params=all_params, used_args=used_args, meta_factory_value=meta_factory_value, missing_params=missing_params, ) if isinstance(value_mapping, ConstructionIssue): return value_mapping if value_mapping is None and param.default is not None: # See test_blueprint_reference_wildcard_default # TODO: this is the only place we instantiate a default default = param.default.instantiate() return {param_path: Value(default)} return value_mapping return {param_path: ParamRef(spec.ref)} elif isinstance(spec, Computed): # If it inherits from a set of other parameters if param_path in {spec.ref for spec in spec.src.values()}: # Same as the unspecified param case return _construct_unspecified_param( param, param_path=param_path, arg_map=arg_map, all_params=all_params, used_args=used_args, meta_factory_value=meta_factory_value, missing_params=missing_params, ) else: kwargs = {k: ParamRef(v.ref) for k, v in spec.src.items()} return {param_path: Thunk(kwargs=kwargs, fn=spec.compute)} # Otherwise, we see if we can cast it to the expected type: # `attr = trycast(spec.value, param.type)` if isinstance(spec, Castable): # If we have a meta_factory and we have args that are prefixed with the param path, we # will always want to construct that (if we successfully casted here when subcomponents # are specified, we'd just fail later because those subcomponents would be extraneous) if not (param.meta_factory is not None and arg_map.subpaths(param_path, strict=True)): try: casted_spec = param.cast(spec.value) return {param_path: Value(casted_spec)} except CastError: pass # ..or if it's a Reference to some other parameter if isinstance(spec, Reference): if spec.ref == param_path: # If it's a self 
reference, treat it as if it were unspecified value_mapping = _construct_unspecified_param( param, param_path=param_path, arg_map=arg_map, all_params=all_params, used_args=used_args, meta_factory_value=meta_factory_value, missing_params=missing_params, ) if isinstance(value_mapping, ConstructionIssue): return value_mapping if value_mapping is None and param.default is not None: # See test_blueprint_reference_wildcard_default # TODO: this is the only place we instantiate a default default = param.default.instantiate() return {param_path: Value(default)} return value_mapping return {param_path: ParamRef(spec.ref)} # Otherwise, see if it's something that can construct the expected type. For instance, # maybe it's a subclass of param.type, or more generally any `Callable[..., param.type]`, # in which case we do: # `attr = spec(...)` factory: Callable[..., Any] if is_subtype_instance(spec, Callable[..., param.type]): assert callable(spec) factory = spec value_mapping = _construct_factory( factory, param_path, arg_map, all_params=all_params, used_args=used_args, meta_factory_value=meta_factory_value, missing_params=missing_params, ) if isinstance(value_mapping, ConstructionIssue): return value_mapping meta_factory_value[param_path] = factory return value_mapping # Otherwise, see if it's something that can be casted into something that can construct # the expected type. For instance, maybe it's a string that's the name of a subclass of # param.type or "module:func" where module.func is a `func: Callable[..., param.type]`. 
# `attr = trycast(spec, constructor_type)(...)` if isinstance(spec, Castable): if param.meta_factory is not None: try: factory = param.meta_factory.from_string(spec.value) except chz.factories.MetaFromString as e: cast_error = None try: param.cast(spec.value) except CastError as e2: cast_error = str(e2) if cast_error is None: subpaths = arg_map.subpaths(param_path, strict=True) assert subpaths cast_error = f"Not a value, since subparameters were provided (e.g. {join_arg_path(param_path, subpaths[0])!r})" raise InvalidBlueprintArg( f"Could not interpret argument {spec.value!r} provided for param {param_path!r}...\n\n" f"- Failed to interpret it as a value:\n{cast_error}\n\n" f"- Failed to interpret it as a factory for polymorphic construction:\n{e}" ) from None assert callable(factory) value_mapping = _construct_factory( factory, param_path, arg_map, all_params=all_params, used_args=used_args, meta_factory_value=meta_factory_value, missing_params=missing_params, ) if isinstance(value_mapping, ConstructionIssue): return value_mapping meta_factory_value[param_path] = factory return value_mapping # This cast is just to raise the error we caught previously try: param.cast(spec.value) except CastError as e: raise InvalidBlueprintArg( f"Could not cast {spec.value!r} to {type_repr(param.type)}:\n{e}" ) from e # This next line should be unreachable... raise TypeError( f"Expected {param_path!r} to be castable to {type_repr(param.type)}, got {spec.value!r}" ) if not isinstance(spec, SpecialArg) and is_subtype_instance(spec, param.type): if param.meta_factory is not None: subpaths = arg_map.subpaths(param_path, strict=True) if subpaths: raise InvalidBlueprintArg( f"Could not interpret {spec!r} provided for param {param_path!r} " f"as a value, since subparameters were provided " f"(e.g. 
{join_arg_path(param_path, subpaths[0])!r})" ) raise TypeError( f"Expected {param_path!r} to be {type_repr(param.type)}, got {type_repr(_simplistic_type_of_value(spec))}" ) def _check_for_wildcard_matching_variadic_top_level( obj: object, param: _Param, obj_path: str, arg_map: ArgumentMap ): assert param.default is not None if ( type(param.default.value) is tuple and param.default.value == () ) or param.default.factory in {tuple, list, dict}: return result = _collect_params(obj, obj_path, arg_map) if isinstance(result, ConstructionIssue): return variadic_params, _, variadic_types = result if variadic_params: return if isinstance(param.default.value, (tuple, list)): variadic_types = list( set(variadic_types) | {type(element) for element in param.default.value} ) elif isinstance(param.default.value, dict): variadic_types = list( set(variadic_types) | {type(element) for element in param.default.value.values()} ) if not variadic_types: return # The case we're checking here is if we: # 1) have a param with a default # 1.5) the default is not an empty tuple or list or dict # 2) have a variadic factory for that param # 3) we do not find any variadic params # Then we check if any wildcards would have matched a param if we had one, since it can be # unintuitive that the default will not be affected by the wildcard (default / default_factory # are opaque and have no interaction with wildcards beyond their presence or absence). # See test_variadic_default_wildcard_error for element_type in variadic_types: result = _collect_params(element_type, obj_path + ".__chz_empty_variadic", arg_map) if isinstance(result, ConstructionIssue): continue subparams, _, _ = result for subparam in subparams: param_path = obj_path + ".__chz_empty_variadic." + subparam.name found_arg = arg_map.get_kv(param_path) param_path = obj_path + ".(variadic)." 
+ subparam.name if found_arg is not None: raise ConstructionException( f"\n\nYou've hit an interesting case.\n\n" f'The parameter "{obj_path}" is variadic ({type_repr(obj)}), but no ' "parametrisation was found (either variadic subparameters or a polymorphic " "parametrisation).\n" f'This is fine in theory, because "{obj_path}" has a ' f"default value.\n\n" f'However, you also specified the wildcard "{found_arg.key}" and you may ' f'have expected it to modify the value of "{param_path}".\n' "This is not possible -- default values / default_factory results are " "opaque to chz. " "The only way in which default / default_factory interact with Blueprint " "is presence / absence. So out of caution, here's an error!\n\n" "If this error is a false positive, consider scoping the wildcard more " "narrowly or using exact keys. As always, appending --help to a chz command " "will show you what gets mapped to which param." ) ================================================ FILE: chz/blueprint/_entrypoint.py ================================================ from __future__ import annotations import functools import inspect import io import os import sys from typing import Any, Callable, TypeVar import chz from chz.tiepin import eval_in_context, type_repr _T = TypeVar("_T") _F = TypeVar("_F", bound=Callable[..., Any]) class EntrypointException(Exception): ... class EntrypointHelpException(EntrypointException): ... class ExtraneousBlueprintArg(EntrypointException): ... class InvalidBlueprintArg(EntrypointException): ... class MissingBlueprintArg(EntrypointException): ... class ConstructionException(EntrypointException): ... 
def exit_on_entrypoint_error(fn: _F) -> _F:
    @functools.wraps(fn)
    def inner(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except EntrypointException as e:
            if isinstance(e, EntrypointHelpException):
                print(e, end="" if e.args[0][-1] == "\n" else "\n")
            else:
                print("Error:", file=sys.stderr)
                print(e, end="" if e.args[0][-1] == "\n" else "\n", file=sys.stderr)
            if "PYTEST_VERSION" in os.environ:
                raise
            sys.exit(1)

    return inner  # type: ignore[return-value]


@exit_on_entrypoint_error
def entrypoint(
    target: Callable[..., _T], *, argv: list[str] | None = None, allow_hyphens: bool = False
) -> _T:
    """Easy way to create a script entrypoint using chz.

    For example, if you wish to run a function:
    ```
    def do_something(alpha: int, beta: str, gamma: bytes) -> None:
        ...

    if __name__ == "__main__":
        chz.entrypoint(do_something)
    ```

    It also works for instantiating objects:
    ```
    @chz.chz
    class Run:
        name: str
        def launch(self) -> None: ...

    if __name__ == "__main__":
        run = chz.entrypoint(Run)
        run.launch()
    ```
    """
    # This function should be easily forkable, so do not make it more complicated
    return chz.Blueprint(target).make_from_argv(argv, allow_hyphens=allow_hyphens)


@exit_on_entrypoint_error
def nested_entrypoint(
    main: Callable[[Any], _T], *, argv: list[str] | None = None, allow_hyphens: bool = False
) -> _T:
    """Easy way to create a script entrypoint using chz for functions that take a chz object.

    For example:
    ```
    @chz.chz
    class Run:
        name: str

    def main(run: Run) -> None:
        ...

    if __name__ == "__main__":
        chz.nested_entrypoint(main)
    ```

    Tip: If your `main` function is `async`, you can just do
    `asyncio.run(chz.nested_entrypoint(main))`.
""" # This function should be easily forkable, so do not make it more complicated target = get_nested_target(main) value = chz.Blueprint(target).make_from_argv(argv, allow_hyphens=allow_hyphens) return main(value) @exit_on_entrypoint_error def methods_entrypoint( target: type[_T], *, argv: list[str] | None = None, transform: Callable[[chz.Blueprint[Any], Any, str], chz.Blueprint[Any]] | None = None, ) -> _T: """Easy way to create a script entrypoint using chz for methods on a class. For example, given main.py: ``` @chz.chz class Run: name: str def launch(self, cluster: str): "Launch a job on a cluster" return ("launch", self, cluster) if __name__ == "__main__": print(chz.methods_entrypoint(Run)) ``` Try out the following command line invocations: ``` python main.py launch self.name=job cluster=owl python main.py launch --help python main.py --help ``` Note that you can rename the `self` argument in your method to something else. """ if argv is None: argv = sys.argv[1:] is_help = not argv or argv[0] == "--help" is_valid = not argv or (argv[0].isidentifier() and hasattr(target, argv[0])) if is_help or not is_valid: f = io.StringIO() output = functools.partial(print, file=f) if not is_valid: output(f"WARNING: {argv[0]} is not a valid method") output(f"Entry point: methods of {type_repr(target)}") output() output("Available methods:") for name in dir(target): meth = getattr(target, name) if not name.startswith("_") and callable(meth): meth_doc = getattr(meth, "__doc__", "") or "" meth_doc = meth_doc.strip().split("\n", 1)[0] output(f" {name} {meth_doc}".rstrip()) raise EntrypointHelpException(f.getvalue()) blueprint = chz.Blueprint(getattr(target, argv[0])) if transform is not None: blueprint = transform(blueprint, target, argv[0]) return blueprint.make_from_argv(argv[1:]) @exit_on_entrypoint_error def dispatch_entrypoint( targets: dict[str, Callable[..., _T]], *, argv: list[str] | None = None ) -> _T: """Easy way to create a script entrypoint using chz for dispatching 
to different functions. Conceptually, this is strictly a subset of the universal `python -m chz.universal` entrypoint. Compared to that, or methods_entrypoint, this basically just lets you flatten args one level. ``` def say_hello(name: str) -> None: print(f"Hello, {name}!") def say_goodbye(name: str) -> None: print(f"Goodbye, {name}!") chz.dispatch_entrypoint({ "hello": say_hello, "goodbye": say_goodbye, }) ``` """ if argv is None: argv = sys.argv[1:] is_help = not argv or argv[0] == "--help" is_valid = not argv or (argv[0].isidentifier() and argv[0] in targets) if is_help or not is_valid: f = io.StringIO() output = functools.partial(print, file=f) if not is_valid: output(f"WARNING: {argv[0]} is not a valid entrypoint") output("Available entrypoints:") for name in targets: meth = targets[name] meth_doc = getattr(meth, "__doc__", "") or "" meth_doc = meth_doc.strip().split("\n", 1)[0] output(f" {name} {meth_doc}".rstrip()) raise EntrypointHelpException(f.getvalue()) return chz.Blueprint(targets[argv[0]]).make_from_argv(argv[1:]) def _resolve_annotation(annotation: Any, func: Any) -> Any: """Resolves a type annotation against the globals of the target function.""" if annotation is inspect.Parameter.empty: return None if isinstance(annotation, str): return eval_in_context(annotation, func) return annotation def get_nested_target(main: Callable[[_T], object]) -> type[_T]: """Returns the type of the first argument of a function. For example: ``` def main(run: Run) -> None: ... 
assert chz.get_nested_target(main) is Run ``` """ params = list(inspect.signature(main).parameters.values()) if not params or params[0].annotation == inspect.Parameter.empty: raise ValueError("Nested entrypoints must take a type annotated argument") if any(p.default is p.empty for p in params[1:]): raise ValueError("Nested entrypoints must take at most one argument without a default") return _resolve_annotation(params[0].annotation, main) ================================================ FILE: chz/blueprint/_lazy.py ================================================ import collections from typing import AbstractSet, Any, Callable, TypeVar from chz.blueprint._entrypoint import InvalidBlueprintArg from chz.blueprint._wildcard import wildcard_key_approx, wildcard_key_to_regex from chz.tiepin import type_repr T = TypeVar("T") class Evaluatable: ... class Value(Evaluatable): def __init__(self, value: Any) -> None: self.value = value def __repr__(self) -> str: return f"Value({self.value})" class ParamRef(Evaluatable): def __init__(self, ref: str) -> None: self.ref = ref def __repr__(self) -> str: return f"ParamRef({self.ref})" class Thunk(Evaluatable): def __init__(self, fn: Callable[..., Any], kwargs: dict[str, ParamRef]) -> None: self.fn = fn self.kwargs = kwargs def __repr__(self) -> str: return f"Thunk({type_repr(self.fn)}, {self.kwargs})" def evaluate(value_mapping: dict[str, Evaluatable]) -> Any: assert "" in value_mapping refs_in_progress = collections.OrderedDict[str, None]() def inner(ref: str) -> Any: if ref in refs_in_progress: cycle = " -> ".join(list(refs_in_progress.keys())[1:] + [ref]) raise RecursionError(f"Detected cyclic reference: {cycle}") refs_in_progress[ref] = None try: value = value_mapping[ref] assert isinstance(value, Evaluatable) if isinstance(value, Value): return value.value if isinstance(value, ParamRef): try: ret = inner(value.ref) except Exception as e: e.add_note(f" (when dereferencing {ref!r})") raise assert not isinstance(ret, Evaluatable) 
value_mapping[ref] = Value(ret) return ret if isinstance(value, Thunk): kwargs = {} for k, v in value.kwargs.items(): assert isinstance(v, ParamRef) try: kwargs[k] = inner(v.ref) except Exception as e: e.add_note(f" (when evaluating argument {k!r} for {type_repr(value.fn)})") raise ret = value.fn(**kwargs) return ret finally: item = refs_in_progress.popitem() assert item[0] == ref raise AssertionError return inner("") def check_reference_targets( value_mapping: dict[str, Evaluatable], param_paths: AbstractSet[str] ) -> None: invalid_references: dict[str, list[str]] = {} def record_invalid(ref: str, referrer: str) -> None: if not referrer: return if ref not in param_paths: referrers = invalid_references.setdefault(ref, []) if referrer not in referrers: referrers.append(referrer) def walk(value: Evaluatable, referrer: str) -> None: if isinstance(value, ParamRef): record_invalid(value.ref, referrer) elif isinstance(value, Thunk): for param_ref in value.kwargs.values(): walk(param_ref, referrer) for param_path, value in value_mapping.items(): walk(value, param_path) if invalid_references: errors = [] for reference, referrers in invalid_references.items(): ratios = {p: wildcard_key_approx(reference, p) for p in param_paths} extra = "" if ratios: max_option = max(ratios, key=lambda v: ratios[v][0]) if ratios[max_option][0] > 0.1: extra = f"\nDid you mean {ratios[max_option][1]!r}?" nested_pattern = wildcard_key_to_regex("..." + reference) found_key = next((p for p in param_paths if nested_pattern.fullmatch(p)), None) if found_key is not None: extra += f"\nDid you get the nesting wrong, maybe you meant {found_key!r}?" 
if len(referrers) > 1: referrers_str = "params " + ", ".join(referrers) else: referrers_str = f"param {referrers[0]}" errors.append(f"Invalid reference target {reference!r} for {referrers_str}" + extra) raise InvalidBlueprintArg("\n\n".join(errors)) ================================================ FILE: chz/blueprint/_wildcard.py ================================================ import re _FUZZY_SIMILARITY = 0.6 def wildcard_key_to_regex_str(key: str) -> str: if key.endswith("..."): raise ValueError("Wildcard not allowed at end of key") pattern = r"(.*\.)?" if key.startswith("...") else "" pattern += r"\.(.*\.)?".join(map(re.escape, key.removeprefix("...").split("..."))) return pattern def wildcard_key_to_regex(key: str) -> re.Pattern[str]: return re.compile(wildcard_key_to_regex_str(key)) def _wildcard_key_match(key: str, target_str: str) -> bool: # This is what the regex does; currently unused (but is tested) if key.endswith("..."): raise ValueError("Wildcard not allowed at end of key") pattern = ["..."] if key.startswith("...") else [] pattern += [x for x in re.split(r"(\.\.\.)|\.", key.removeprefix("...")) if x is not None] target = target_str.split(".") _grid: dict[tuple[int, int], bool] = {} def _match(i: int, j: int) -> bool: if i == len(pattern): return j == len(target) if j == len(target): return False if (i, j) in _grid: return _grid[i, j] if pattern[i] == "...": ret = _match(i, j + 1) or _match(i + 1, j) _grid[i, j] = ret return ret ret = pattern[i] == target[j] and _match(i + 1, j + 1) _grid[i, j] = ret return ret return _match(0, 0) def wildcard_key_approx(key: str, target_str: str) -> tuple[float, str]: """ Returns a score and a string representing what the key should have been to match the target. Currently only used in error messages. 
""" if key.endswith("..."): raise ValueError("Wildcard not allowed at end of key") pattern = ["..."] if key.startswith("...") else [] pattern += [x for x in re.split(r"(\.\.\.)|\.", key.removeprefix("...")) if x is not None] target = target_str.split(".") import difflib _grid: dict[tuple[int, int], tuple[float, tuple[str, ...]]] = {} def _match(i, j) -> tuple[float, tuple[str, ...]]: if i == len(pattern): return (1, ()) if j == len(target) else (0, ()) if j == len(target): return (0, ()) if (i, j) in _grid: return _grid[i, j] if pattern[i] == "...": with_wildcard = _match(i, j + 1) without_wildcard = _match(i + 1, j) if with_wildcard[0] * _FUZZY_SIMILARITY > without_wildcard[0]: score, value = with_wildcard score *= _FUZZY_SIMILARITY else: score, value = without_wildcard if value and value[0] != "...": value = ("...",) + value ret = (score, value) _grid[i, j] = ret return ret ratio = difflib.SequenceMatcher(a=pattern[i], b=target[j]).ratio() if ratio >= _FUZZY_SIMILARITY: score, value = _match(i + 1, j + 1) score *= ratio if value and value[0] != "...": value = (target[j] + ".",) + value else: value = (target[j],) + value ret = (score, value) _grid[i, j] = ret return ret return 0, () score, value = _match(0, 0) return score, "".join(value) ================================================ FILE: chz/data_model.py ================================================ """ This is the core implementation of the chz class. It's based off of the implementation of dataclasses, but is somewhat simpler. I also fixed a couple minor issues in dataclasses when writing this :-) Some non-exhaustive reasons why chz's feature set isn't built on top of dataclasses / attrs: - dataclasses is a general purpose class replacement, chz isn't. 
This lets us establish intention, have better defaults, make different tradeoffs, better errors in various places - Ability to have custom logic in chz.field - Clearer handling of type annotation evaluation and scopes - chz needs keyword-only arguments for various reasons (dataclasses acquired this only later) - Cool data model tricks like munging and init_property - Many small things """ import builtins import copy import dataclasses import functools import hashlib import inspect import sys import types import typing from collections.abc import Collection, Mapping from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar import typing_extensions from chz.field import Field from chz.tiepin import type_repr from chz.util import MISSING FrozenInstanceError = dataclasses.FrozenInstanceError _T = TypeVar("_T") _INIT_ALTERNATIVES: str = ( "For validation, see @chz.validate decorators. " "For per-field defaults, see `default` and `default_factory` options in chz.field. " "To perform post-initialization rewrites of field values, use `munger` option in chz.field " "or add an `init_property` to the class.\n" "See the docs for more details." ) def _create_fn( name: str, args: list[str], body: list[str], *, locals: dict[str, Any], globals: dict[str, Any] ): args_str = ",".join(args) body_str = "\n".join(f" {b}" for b in body) # Compute the text of the entire function. txt = f" def {name}({args_str}):\n{body_str}" # Free variables in exec are resolved in the global namespace. # The global namespace we have is user-provided, so we can't modify it for # our purposes. So we put the things we need into locals and introduce a # scope to allow the function we're creating to close over them. 
local_vars = ", ".join(locals.keys()) txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}" ns: Any = {} exec(txt, globals, ns) return ns["__create_fn__"](**locals) # ============================== # Method synthesis # ============================== def _synthesise_field_init(f: Field, out_vars: dict[str, Any]) -> tuple[str, str]: # This function modifies out_vars var_type = f"__chz_{f.logical_name}" out_vars[var_type] = f._raw_type var_default = f"__chz_dflt_{f.logical_name}" if f._default_factory is not MISSING: out_vars[var_default] = f._default_factory value = f"{var_default}() if {f.logical_name} is __chz_MISSING else {f.logical_name}" dflt_expr = " = __chz_MISSING" elif f._default is not MISSING: out_vars[var_default] = f._default # Is it ever useful to explicitly pass MISSING? # value = f"{var_default} if {f.logical_name} is __chz_MISSING else {f.logical_name}" value = f.logical_name dflt_expr = f" = {var_default}" else: value = f.logical_name dflt_expr = "" arg = f"{f.logical_name}: {var_type}{dflt_expr}" body = f"__chz_builtins.object.__setattr__(self, {f.x_name!r}, {value})" return arg, body def _synthesise_init(fields: Collection[Field], user_globals: dict[str, Any]) -> Callable[..., Any]: varlocals = {"__chz_MISSING": MISSING, "__chz_builtins": builtins} # __chz_args is not strictly necessary, but makes for better errors args = ["self", "*__chz_args"] body = [ "if __chz_args:", " raise __chz_builtins.TypeError(f'{self.__class__.__name__}.__init__ only takes keyword arguments')", "if '__chz_fields__' not in __chz_builtins.type(self).__dict__:", " raise __chz_builtins.TypeError(f'{self.__class__.__name__} is not decorated with @chz.chz')", ] for field in fields: if field.logical_name.startswith("__chz") or field.logical_name == "self": raise ValueError(f"Field name {field.logical_name!r} is reserved") _arg, _body = _synthesise_field_init(field, varlocals) args.append(_arg) body.append(_body) # Note it's important we validate before we check all 
init_property body.append("self.__chz_validate__()") body.append("self.__chz_init_property__()") return _create_fn("__init__", args, body, locals=varlocals, globals=user_globals) def __setattr__(self, name, value): raise FrozenInstanceError(f"Cannot modify field {name!r}") def __delattr__(self, name): raise FrozenInstanceError(f"Cannot delete field {name!r}") def _recursive_repr(user_function): import threading repr_running = set() @functools.wraps(user_function) def wrapper(self): key = id(self), threading.get_ident() if key in repr_running: return "..." repr_running.add(key) try: result = user_function(self) finally: repr_running.discard(key) return result return wrapper def __repr__(self) -> str: def field_repr(field: Field) -> str: # use x_name so that repr can be copy-pasted to create the same object if callable(field._repr): return field._repr(getattr(self, field.x_name)) assert isinstance(field._repr, bool) if field._repr: return repr(getattr(self, field.x_name)) return "..." contents = ", ".join( f"{field.logical_name}={field_repr(field)}" for field in self.__chz_fields__.values() ) return self.__class__.__qualname__ + f"({contents})" def __eq__(self, other): if self.__class__ is not other.__class__: return NotImplemented return all(getattr(self, name) == getattr(other, name) for name in self.__chz_fields__) def __hash__(self) -> int: try: return hash(tuple((name, getattr(self, name)) for name in self.__chz_fields__)) except TypeError as e: for name in self.__chz_fields__: value = getattr(self, name) try: hash(value) except TypeError: raise TypeError( f"Cannot hash chz field: {type(self).__name__}.{name}={value}" ) from e raise e def __chz_validate__(self) -> None: for field in self.__chz_fields__.values(): if field._munger is None: for validator in field._validator: # So without mungers, we always run validators against the raw value # There is currently code that relies on not running validator against a potential # user-specified init_property # TODO: is 
it unfortunate that x_name appears in error messages? validator(self, field.x_name) else: # With mungers, we run validators against both the munged and unmunged value # I'm willing to reconsider this, but want to be conservative for now for validator in field._validator: validator(self, field.logical_name) validator(self, field.x_name) for validator in getattr(self, "__chz_validators__", []): validator(self) @functools.lru_cache() def _get_init_properties(cls: type) -> list[str]: return [ name for name, _obj in inspect.getmembers_static(cls, lambda o: isinstance(o, init_property)) ] def __chz_init_property__(self) -> None: for name in _get_init_properties(self.__class__): getattr(self, name) def pretty_format(obj: Any, colored: bool = True) -> str: """Format a chz object for human readability.""" bold = "\033[1m" if colored else "" blue = "\033[34m" if colored else "" grey = "\033[90m" if colored else "" reset = "\033[0m" if colored else "" space = " " * 4 if isinstance(obj, (list, tuple)): if not obj or all(not is_chz(x) for x in obj): return repr(obj) a, b = ("[", "]") if isinstance(obj, list) else ("(", ")") items = [pretty_format(x, colored).replace("\n", "\n" + space) for x in obj] items_str = f",\n{space}".join(items) return f"{a}\n{space}{items_str},\n{b}" if isinstance(obj, dict): if not obj or all(not is_chz(x) for x in obj.values()): return repr(obj) items = [] for k, v in obj.items(): k_str = pretty_format(k, colored).replace("\n", "\n" + space) v_str = pretty_format(v, colored).replace("\n", "\n" + space) items.append(f"{k_str}: {v_str}") items_str = f",\n{space}".join(items) return f"{{\n{space}{items_str},\n}}" if not is_chz(obj): return repr(obj) cls_name = obj.__class__.__qualname__ out = f"{bold}{cls_name}({reset}\n" def field_repr(field: Field) -> str: # use x_name so that repr can be copy-pasted to create the same object if field._repr is False: return "..." 
if callable(field._repr): r = field._repr else: assert field._repr is True r = lambda o: pretty_format(o, colored=colored) x_val = getattr(obj, field.x_name) val = getattr(obj, field.logical_name) if x_val is val: return r(val) return f"{grey}{r(x_val)} # {reset}{r(val)}{grey} (after init){reset}" field_reprs: dict[bool, list[str]] = {True: [], False: []} for field in sorted(obj.__chz_fields__.values(), key=lambda f: f.logical_name): if field._default is not MISSING: matches_default = field._default is getattr(obj, field.x_name) elif field._default_factory is not MISSING: matches_default = field._default_factory() == getattr(obj, field.x_name) else: matches_default = False val_str = field_repr(field).replace("\n", "\n" + space) field_str = f"{space}{blue}{field.logical_name}={reset}{val_str},\n" field_reprs[matches_default].append(field_str) out += "".join(field_reprs[False]) if field_reprs[True]: out += f"{space}{bold}# Fields where pre-init value matches default:{reset}\n" out += "".join(field_reprs[True]) out += f"{bold}){reset}" return out def _repr_pretty_(self, p, cycle: bool) -> None: # for nice ipython printing p.text(pretty_format(self)) def __chz_pretty__(self, colored: bool = True) -> str: """Print a chz object for human readability.""" return pretty_format(self, colored=colored) # ============================== # Construction # ============================== def _is_classvar_annotation(annot: str | Any) -> bool: if isinstance(annot, str): # TODO: use better dataclass logic? 
        return annot.startswith(("typing.ClassVar[", "ClassVar["))
    return annot is typing.ClassVar or (
        type(annot) is typing._GenericAlias  # type: ignore[attr-defined]
        and annot.__origin__ is typing.ClassVar
    )


def _is_property_like(obj: Any) -> bool:
    # TODO: the semantics implied here could be more crisply defined and maybe generalised to
    # more descriptors
    return isinstance(obj, (property, init_property, functools.cached_property))


def chz_make_class(cls, version: str | None, typecheck: bool | None) -> type:
    if cls.__class__ is not type:
        if cls.__class__ is typing._ProtocolMeta:
            if typing_extensions.is_protocol(cls):
                raise TypeError("chz class cannot itself be a Protocol")
        else:
            import abc

            if cls.__class__ is not abc.ABCMeta:
                raise TypeError("Cannot use custom metaclass")

    user_module = cls.__module__
    cls_annotations = typing_extensions.get_annotations(cls)

    fields: dict[str, Field] = {}

    # Collect fields from parent classes
    for b in reversed(cls.__mro__):
        if hasattr(b, "__dataclass_fields__"):
            raise ValueError("Cannot mix chz with dataclasses")

        # Only process classes that have been processed by our decorator
        base_fields: dict[str, Field] | None = getattr(b, "__chz_fields__", None)
        if base_fields is None:
            continue
        for f in base_fields.values():
            if (
                f.logical_name in cls.__dict__
                and f.logical_name not in cls_annotations
                and not _is_property_like(getattr(cls, f.logical_name))
            ):
                # Do an LSP check against parent fields (for non-property-like members)
                raise ValueError(
                    f"Cannot override field {f.logical_name!r} with a non-field member; "
                    f"maybe you're missing a type annotation?"
                )
            else:
                fields[f.logical_name] = f

    # Collect fields from the current class
    for name, annotation in cls_annotations.items():
        if _is_classvar_annotation(annotation):
            continue

        # Find the field specification from the class __dict__
        value = cls.__dict__.get(name, MISSING)
        if value is MISSING:
            field = Field(name=name, raw_type=annotation)
        elif isinstance(value, Field):
            field = value
            field._name = name
            field._raw_type = annotation
            delattr(cls, name)
        else:
            if _is_property_like(value) or (
                isinstance(value, types.FunctionType)
                and value.__name__ != "<lambda>"
                and value.__qualname__.startswith(cls.__qualname__)
            ):
                # It's problematic to redefine the field in the same class, because it means we
                # lose any field specification or default value
                raise ValueError(f"Field {name!r} is clobbered by {type_repr(type(value))}")
            field = Field(name=name, raw_type=annotation, default=value)
            delattr(cls, name)
        field._user_module = user_module

        # Do a basic LSP check for new fields
        parent_value = getattr(cls, name, MISSING)  # note the delattr above
        if parent_value is not MISSING and not (
            field.logical_name in fields and isinstance(parent_value, init_property)
        ):
            raise ValueError(
                f"Cannot define field {name!r} because it conflicts with something defined on a "
                f"superclass: {parent_value!r}"
            )
        other_name = field.logical_name if name != field.logical_name else field.x_name
        parent_value = getattr(cls, other_name, MISSING)
        if (
            parent_value is not MISSING
            and not (field.logical_name in fields and isinstance(parent_value, init_property))
            and other_name not in cls.__dict__
        ):
            raise ValueError(
                f"Cannot define field {name!r} because it conflicts with something defined on a "
                f"superclass: {parent_value!r}"
            )

        if (
            name == field.logical_name
            and name not in cls.__dict__
            and name in fields
            and fields[name]._name != name
        ):
            raise ValueError(
                "I'm a little unsure of what the semantics should be here. "
                "See test_conflicting_superclass_x_field_in_base. "
                "Please let @shantanu know if you hit this. 
" f"You can also just rename the field in the subclass to X_{name}." ) # Create a default init_property for the field that accesses the raw X_ field munger: Any = field.get_munger() if munger is not None: if field.logical_name in cls.__dict__: raise ValueError( f"Cannot define {field.logical_name!r} in class when the associated field " f"has a munger" ) munger.__name__ = field.logical_name munger = init_property(munger) munger.__set_name__(cls, field.logical_name) setattr(cls, field.logical_name, munger) if ( # but don't clobber existing definitions... field.logical_name not in cls.__dict__ # ...if something is already there in class and field.logical_name not in fields # ...if a parent has defined the field ): fn: Any = lambda self, x_name=field.x_name: getattr(self, x_name) fn.__name__ = field.logical_name fn = init_property(fn) fn.__set_name__(cls, field.logical_name) setattr(cls, field.logical_name, fn) fields[field.logical_name] = field for name, value in cls.__dict__.items(): if isinstance(value, Field) and name not in cls_annotations: raise TypeError(f"{name!r} has no type annotation") # Mark the class as having been processed by our decorator cls.__chz_fields__ = fields if "__init__" in cls.__dict__: raise ValueError("Cannot define __init__ on a chz class. " + _INIT_ALTERNATIVES) if "__post_init__" in cls.__dict__: raise ValueError("Cannot define __post_init__ on a chz class. 
" + _INIT_ALTERNATIVES) cls.__init__ = _synthesise_init(fields.values(), sys.modules[user_module].__dict__) cls.__init__.__qualname__ = f"{cls.__qualname__}.__init__" cls.__chz_validate__ = __chz_validate__ cls.__chz_init_property__ = __chz_init_property__ if "__setattr__" in cls.__dict__: raise ValueError("Cannot define __setattr__ on a chz class") cls.__setattr__ = __setattr__ if "__delattr__" in cls.__dict__: raise ValueError("Cannot define __delattr__ on a chz class") cls.__delattr__ = __delattr__ if "__repr__" not in cls.__dict__: cls.__repr__ = __repr__ if "__eq__" not in cls.__dict__: cls.__eq__ = __eq__ if "__hash__" not in cls.__dict__: cls.__hash__ = __hash__ if "_repr_pretty_" not in cls.__dict__: # Special-cased by IPython cls._repr_pretty_ = _repr_pretty_ if "__chz_pretty__" not in cls.__dict__: cls.__chz_pretty__ = __chz_pretty__ if version is not None: import json # Hash all the fields and check the version matches expected_version = version.split("-")[0] key = [f.versioning_key() for f in sorted(fields.values(), key=lambda f: f.x_name)] key_bytes = json.dumps(key, separators=(",", ":")).encode() actual_version = hashlib.sha1(key_bytes).hexdigest()[:8] if actual_version != expected_version: raise ValueError(f"Version {version!r} does not match {actual_version!r}") if typecheck is not None: import chz.validators as chzval if typecheck: chzval._ensure_chz_validators(cls) if chzval._decorator_typecheck not in cls.__chz_validators__: cls.__chz_validators__.append(chzval._decorator_typecheck) else: if chzval._decorator_typecheck in getattr(cls, "__chz_validators__", []): raise ValueError("Cannot disable typecheck; all validators are inherited") return cls # ============================== # is_chz # ============================== def is_chz(c: object) -> bool: """Check if an object is a chz object.""" return hasattr(c, "__chz_fields__") # ============================== # __chz_fields__ # ============================== def chz_fields(c: object) -> dict[str, 
Field]: return c.__chz_fields__ # type: ignore[attr-defined] # ============================== # replace # ============================== def replace(obj: _T, /, **changes) -> _T: """Return a new object replacing specified fields with new values. Example: ``` @chz.chz class Foo: a: int b: str foo = Foo(a=1, b="hello") assert chz.replace(foo, a=101) == Foo(a=101, b="hello") ``` This just constructs a new object, so for example, the generated `__init__` gets run and validation will work exactly as if you manually constructed the new object. """ if not hasattr(obj, "__chz_fields__"): raise ValueError(f"{obj} is not a chz object") for field in obj.__chz_fields__.values(): if field.logical_name not in changes: changes[field.logical_name] = getattr(obj, field.x_name) return obj.__class__(**changes) # ============================== # asdict # ============================== def asdict( obj: object, shallow: bool = False, include_type: bool = False, exclude: Collection[str] | None = None, ) -> dict[str, Any]: """Recursively convert a chz object to a dict. This works similarly to dataclasses.asdict. Note no computed properties will be included in the output. See also: beta_to_blueprint_values Args: - shallow: if True, only take shallow copies of inner values. Otherwise, deep copies are made. - include_type: If True, include the type of the object in the output dict for each chz object. Useful for debugging and identity. - exclude: Iterable of field names to omit from the resulting dict at the top level. 
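Illustration of the `exclude` semantics described above. This is a stdlib-only sketch that uses dataclasses as a stand-in for chz classes; `to_dict`, `Inner` and `Outer` are hypothetical names and this is not chz's implementation. It only mirrors the documented behaviour: `exclude` filters fields at the top level, while nested objects are converted in full.

```python
import copy
import dataclasses
from typing import Any, Collection, Optional


def to_dict(obj: Any, exclude: Optional[Collection[str]] = None) -> Any:
    # Hypothetical helper: dataclass instances play the role of chz objects here.
    if dataclasses.is_dataclass(obj) and not isinstance(obj, type):
        return {
            f.name: to_dict(getattr(obj, f.name))  # nested calls receive no exclude
            for f in dataclasses.fields(obj)
            if not exclude or f.name not in exclude
        }
    if isinstance(obj, dict):
        return {k: to_dict(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [to_dict(v) for v in obj]
    if isinstance(obj, tuple):
        return tuple(to_dict(v) for v in obj)
    # Leaf values are deep-copied, analogous to shallow=False.
    return copy.deepcopy(obj)


@dataclasses.dataclass(frozen=True)
class Inner:
    name: str
    x: int


@dataclasses.dataclass(frozen=True)
class Outer:
    name: str
    inner: Inner


# "name" is dropped from Outer only; Inner keeps its own "name" field.
result = to_dict(Outer(name="outer", inner=Inner(name="inner", x=1)), exclude={"name"})
```

In this sketch `result` keeps the nested `name` key, which is the point of "at the top level".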
""" exclude_set = set(exclude) if exclude is not None else None def inner(x: Any, current_exclude: Collection[str] | None = None): if hasattr(x, "__chz_fields__"): result = { k: inner(getattr(x, k)) for k in x.__chz_fields__ if not current_exclude or k not in current_exclude } if include_type: result["__chz_type__"] = type_repr(type(x)) return result if isinstance(x, dict): return {k: inner(v) for k, v in x.items()} if isinstance(x, list): return [inner(x) for x in x] if isinstance(x, tuple): return tuple(inner(x) for x in x) if shallow: return x else: return copy.deepcopy(x) if not hasattr(obj, "__chz_fields__"): raise RuntimeError(f"{obj} is not a chz object") result = inner(obj, exclude_set) assert type(result) is dict return result def traverse(obj: Any, obj_path: str = "") -> Iterable[tuple[str, Any]]: """Traverses the chz object and yields (path, value) pairs for all sub attributes recursively.""" assert is_chz(obj) yield obj_path, obj for f in obj.__chz_fields__.values(): value = getattr(obj, f.logical_name) field_path = f"{obj_path}.{f.logical_name}" if obj_path else f.logical_name yield field_path, value if is_chz(value): yield from traverse(value, field_path) if isinstance(value, Mapping): for k, v in value.items(): if is_chz(v): yield from traverse(v, f"{field_path}.{k}") else: yield f"{field_path}.{k}", v elif isinstance(value, (list, tuple)): for i, v in enumerate(value): if is_chz(v): yield from traverse(v, f"{field_path}.{i}") else: yield f"{field_path}.{i}", v # ============================== # beta_to_blueprint_values # ============================== def beta_to_blueprint_values(obj, skip_defaults: bool = False) -> Any: """Return a dict which can be used to recreate the same object via blueprint. Args: - obj: The object to convert to blueprint values. - skip_defaults: If True, fields whose values are equal to their default values will be omitted from the output dict. If False (default), all fields are included. 
Example: ``` @chz.chz class Foo: a: int b: str foo = Foo(a=1, b="hello") assert chz.Blueprint(Foo).apply(chz.beta_to_blueprint_values(foo)).make() == foo ``` See also: asdict """ blueprint_values = {} def join_arg_path(parent: str, child: str) -> str: if not parent: return child if child.startswith("."): return parent + child return parent + "." + child def inner(obj: Any, path: str): if hasattr(obj, "__chz_fields__"): for field_name, field_info in obj.__chz_fields__.items(): value = getattr(obj, field_info.x_name) if skip_defaults and field_info._default == value: continue param_path = join_arg_path(path, field_name) if field_info.meta_factory is not None and ( type(value) is not typing.get_origin(field_info.meta_factory.unspecified_factory()) ): # Try to detect when we have used polymorphic construction blueprint_values[param_path] = type(value) inner(value, param_path) elif ( isinstance(obj, dict) and all(isinstance(k, str) for k in obj.keys()) and any(is_chz(v) for v in obj.values()) ): for k, v in obj.items(): param_path = join_arg_path(path, k) blueprint_values[param_path] = type(v) # may be overridden if not needed inner(v, param_path) elif isinstance(obj, (list, tuple)) and any(is_chz(v) for v in obj): for i, v in enumerate(obj): param_path = join_arg_path(path, str(i)) blueprint_values[param_path] = type(v) # may be overridden if not needed inner(v, param_path) else: blueprint_values[path] = obj inner(obj, "") return blueprint_values # ============================== # init_property # ============================== if TYPE_CHECKING: init_property = functools.cached_property else: class init_property: # Simplified and pickleable version of Python 3.8's cached_property # It's important that this remains a non-data descriptor def __init__(self, func: Callable[..., Any]) -> None: self.func = func self.name = None def __set_name__(self, owner, name): self.name = name # Basically just validation func_name = self.func.__name__ if ( name != func_name and func_name 
!= "" # TODO: remove figure out why mini needs name mangling and not func_name.endswith("__register_chz_has_state") ): raise ValueError("Are you doing something weird with init_property?") def __get__(self, obj: Any, cls: Any) -> Any: if obj is None: return self ret = self.func(obj) if self.name is not None: obj.__dict__[self.name] = ret return ret ================================================ FILE: chz/factories.py ================================================ import ast import collections import functools import importlib import re import sys import types import typing from typing import Any, Callable import typing_extensions from chz.tiepin import ( CastError, InstantiableType, TypeForm, _simplistic_try_cast, eval_in_context, is_instantiable_type, is_subtype, is_subtype_instance, is_union_type, type_repr, ) from chz.util import MISSING, MISSING_TYPE class MetaFromString(Exception): ... class MetaFactory: """ A metafactory represents a set of possible factories, where a factory is a callable that can give us a value of a given type. This is the heart of polymorphic construction in chz. The idea is that when instantiating Blueprints, you should be able to not only specify the arguments to whatever is being constructed, but also specify what the thing to be constructed is! In other words, when constructing a value, chz lets you specify the factory to produce it, in addition to the arguments to pass to that factory. In other other words, many tools will let you construct an X by specifying `...` to feed to `X(...)`. But chz lets you construct an X by specifying both callee and arguments in `...(...)` This concept is a little tricky, but it's fairly intuitive when you actually use it. Consider looking at the docstring for `subclass` for a concrete example. 
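As a rough illustration of "specify both callee and arguments", here is a stdlib-only sketch. The `make` helper, the `FACTORIES` table and the two classes are hypothetical and are not part of chz; a real metafactory resolves the callee through `from_string` and falls back to `unspecified_factory`.

```python
from typing import Any, Callable, Dict, Optional


# Hypothetical stand-ins for illustration only.
class Model:
    def __init__(self, n_layers: int = 1) -> None:
        self.n_layers = n_layers


class Transformer(Model):
    pass


# A metafactory in miniature: the set of callables that can produce a Model.
FACTORIES: Dict[str, Callable[..., Model]] = {"Model": Model, "Transformer": Transformer}


def make(factory: Optional[str], **kwargs: Any) -> Model:
    # With no factory named, fall back to a default callable
    # (the role played by `unspecified_factory`).
    fn = Model if factory is None else FACTORIES[factory]
    return fn(**kwargs)


default_model = make(None, n_layers=2)          # only arguments specified
transformer = make("Transformer", n_layers=16)  # callee and arguments specified
```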
""" def __init__(self) -> None: # Set by chz.Field self.field_annotation: TypeForm | MISSING_TYPE = MISSING self.field_module: types.ModuleType | str | MISSING_TYPE = MISSING def unspecified_factory(self) -> Callable[..., Any] | None: """The default callable to use to get a value of the expected type. If this returns None, there is no default. In order to construct a value of the expected type, the user must explicitly specify a factory. """ raise NotImplementedError def from_string(self, factory: str) -> Callable[..., Any]: """The callable that best corresponds to `factory`.""" raise NotImplementedError def perform_cast(self, value: str): # TODO: maybe make this default to: # return _simplistic_try_cast(value, default_target) raise NotImplementedError class subclass(MetaFactory): """ ATTN: this is soft deprecated, since `chz.factories.standard` is powerful enough to effectively do a superset of this. Read the docstring for MetaFactory first. ``` @chz.chz class Experiment: model: Model ``` In the above example, we want to construct a value for the model for our experiment. How should we go about making a model? The meta_factory we provide is what is meant to answer this question. And in this case, the answer we want is: we should make a model by attempting to instantiate `Model` or some subclass of `Model`. This is a common enough answer that chz in fact defaults to it. That is, here chz will set the meta_factory to be `subclass(base_cls=Model, default_cls=Model)` for our model field. See the logic in chz.Field. Given `model=Transformer model.n_layers=16 model.d_model=1024` chz will construct `Transformer(n_layers=16, d_model=1024` That is, if the user specifies a factory for the model field, e.g. model="Transformer", then the logic in `subclass.from_string` will attempt to find a subclass of `Model` (the `base_cls`) named `Transformer` and instantiate it. 
Given `model.n_layers=16 model.d_model=1024` chz will construct `Model(n_layers=Y, d_model=Z)` That is, if the user doesn't specify a factory (maybe they only specify subcomponents, like `model.n_layers=16`), then we will default to trying to instantiate `Model` (the `default_cls`). """ def __init__( self, base_cls: InstantiableType | MISSING_TYPE = MISSING, default_cls: InstantiableType | MISSING_TYPE = MISSING, ) -> None: super().__init__() self._base_cls = base_cls self._default_cls = default_cls def __repr__(self) -> str: return f"subclass(base_cls={self.base_cls!r}, default_cls={self.default_cls!r})" @property def base_cls(self) -> InstantiableType: if isinstance(self._base_cls, MISSING_TYPE): assert not isinstance(self.field_annotation, MISSING_TYPE) if not isinstance(self.field_annotation, InstantiableType): raise RuntimeError( f"Must explicitly specify base_cls since {self.field_annotation!r} " "is not an instantiable type" ) return self.field_annotation return self._base_cls @property def default_cls(self) -> InstantiableType: if isinstance(self._default_cls, MISSING_TYPE): return self.base_cls return self._default_cls def unspecified_factory(self) -> Callable[..., Any]: return self.default_cls # type: ignore[return-value] def from_string(self, factory: str) -> Callable[..., Any]: """ If factory=module:cls, import module and return cls. If factory=cls, do our best to find a subclass of base_cls named cls. """ return _find_subclass(factory, self.base_cls) def perform_cast(self, value: str): try: return _simplistic_try_cast(value, self.default_cls) except CastError: pass return _simplistic_try_cast(value, self.base_cls) class function(MetaFactory): def __init__( self, unspecified: Callable[..., Any] | None = None, *, default_module: str | types.ModuleType | None | MISSING_TYPE = MISSING, ) -> None: """ ATTN: this is soft deprecated, since `chz.factories.standard` is powerful enough to effectively do a superset of this. 
Read the docstring for `MetaFactory` and `subclass` first. If you specify `function` as your meta_factory, any function can serve as a factory to construct a value of the expected type. ``` def wikipedia_text(seed: int) -> Dataset: ... @chz.chz class Experiment: dataset: Dataset = field(meta_factory=function()) ``` In the above example, we want to construct a value of type `Dataset` for our experiment. The way we want to do this is by calling some function that returns a `Dataset`. Given `dataset=wikipedia_text dataset.seed=217` chz will construct `wikipedia_text(seed=217)`. If you use a fully qualified name like `function=module:fn` it's obvious where to find the function. Otherwise, chz looks for an appropriately named function in the module `default_module` (which defaults to the module in which the chz class was defined). If you love `wikipedia_text` and you don't wish to explicitly specify `dataset=wikipedia_text` every time, set the `unspecified` argument to be `wikipedia_text`. This way, chz will default to trying to call `wikipedia_text` to instantiate a value of type `Dataset`, instead of erroring because it doesn't know what factory to use to produce a Dataset. """ super().__init__() self.unspecified = unspecified self._default_module = default_module def __repr__(self) -> str: return f"function(unspecified={self.unspecified!r}, default_module={self.default_module!r})" @property def default_module(self) -> types.ModuleType | str | None: if isinstance(self._default_module, MISSING_TYPE): assert not isinstance(self.field_module, MISSING_TYPE) return self.field_module return self._default_module def unspecified_factory(self) -> Callable[..., Any] | None: return self.unspecified def from_string(self, factory: str) -> Callable[..., Any]: """ If factory=module:fn, import module and return fn. If factory=fn, look in the default module for a function named fn. 
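A stdlib-only sketch of the two lookup rules above. `resolve_factory` is a hypothetical helper, not chz's implementation; it omits the lambda branch and raises `LookupError` where chz would raise `MetaFromString`.

```python
import importlib
import json
import os
import types
from typing import Any, Optional


def resolve_factory(spec: str, default_module: Optional[types.ModuleType] = None) -> Any:
    # "module:attr" names the module explicitly; a bare "attr" needs a default module.
    if ":" in spec:
        module_name, attr_path = spec.split(":", 1)
        obj: Any = importlib.import_module(module_name)
    elif default_module is not None:
        obj, attr_path = default_module, spec
    else:
        raise LookupError(f"No module specified in {spec!r} and no default module given")
    # Dotted attribute paths such as "path.join" are followed one step at a time.
    for part in attr_path.split("."):
        obj = getattr(obj, part)
    return obj


dumps = resolve_factory("json:dumps")                   # fully qualified
join = resolve_factory("path.join", default_module=os)  # looked up in the default module
```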
""" if ":" not in factory: if self.default_module is None: raise MetaFromString( f"No module specified in {factory!r} and no default module specified" ) if isinstance(self.default_module, str): module = importlib.import_module(self.default_module) else: module = self.default_module var = factory else: module_name, var = factory.split(":", 1) if module_name != "lambda" and not module_name.startswith("lambda "): module = _module_from_name(module_name) else: import ast if isinstance(self.default_module, str): eval_ctx = importlib.import_module(self.default_module) elif self.default_module is not None: eval_ctx = self.default_module else: eval_ctx = None try: # TODO: add docs for this branch if isinstance(ast.parse(factory).body[0].value, ast.Lambda): # type: ignore[attr-defined] return eval_in_context(factory, eval_ctx) except Exception as e: raise MetaFromString( f"Could not interpret {factory!r} as a function: {e}" ) from None raise AssertionError return _module_getattr(module, var) def perform_cast(self, value: str): assert not isinstance(self.field_annotation, MISSING_TYPE) return _simplistic_try_cast(value, self.field_annotation) def _module_from_name(name: str) -> types.ModuleType: try: return importlib.import_module(name) except ImportError as e: raise MetaFromString( f"Could not import module {name!r} ({type(e).__name__}: {e})" ) from None def _module_getattr(mod: types.ModuleType, attr: str) -> Any: try: for a in attr.split("."): mod = getattr(mod, a) return mod except AttributeError as e: raise MetaFromString(str(e)) from None def _find_subclass(spec: str, superclass: TypeForm): module_name = None if ":" in spec: module_name, var = spec.split(":", 1) else: var = spec match = re.fullmatch(r"(?P[^\s\[\]]+)(\[(?P.+)\])?", var) if match is None: raise MetaFromString(f"Failed to parse '{spec}' as a class name") base = match.group("base") generic = match.group("generic") if module_name is None and not base.isidentifier(): if "." 
in base: # This effectively adds some basic support for module.symbol, not just module:symbol module_name, base = base.rsplit(".", 1) if not base.isidentifier(): raise MetaFromString( f"No subclass of {type_repr(superclass)} named {base!r} (invalid identifier)" ) if module_name is not None: module = _module_from_name(module_name) # TODO: think about this type ignore value = _maybe_generic( _module_getattr(module, base), generic, template=superclass, # type: ignore[arg-type] ) if is_subtype(value, superclass): return value raise MetaFromString( f"Expected a subtype of {type_repr(superclass)}, got {type_repr(value)}" ) superclass_class_origin = getattr(superclass, "__origin__", superclass) if superclass_class_origin in {object, typing.Any, typing_extensions.Any}: try: return _maybe_generic( _module_getattr(_module_from_name("__main__"), base), generic, template=superclass, # type: ignore[arg-type] ) except MetaFromString: pass try: return _maybe_generic( _module_getattr(_module_from_name("builtins"), base), generic, template=superclass, # type: ignore[arg-type] ) except MetaFromString: pass raise MetaFromString( f"Could not find {spec!r}, try a fully qualified name e.g. 
module_name:{spec}" ) from None if not is_instantiable_type(superclass_class_origin): raise MetaFromString(f"Could not find subclasses of {type_repr(superclass)}") assert superclass_class_origin is not type visited_subclasses = set() all_subclasses = collections.deque(superclass_class_origin.__subclasses__()) all_subclasses.appendleft(superclass) candidates = [] while all_subclasses: cls = all_subclasses.popleft() if cls in visited_subclasses: continue visited_subclasses.add(cls) if cls.__name__ == base: assert module_name is None candidates.append(_maybe_generic(cls, generic, template=superclass)) # type: ignore[arg-type] cls_origin = getattr(cls, "__origin__", cls) assert cls_origin is not type all_subclasses.extend(cls_origin.__subclasses__()) if len(candidates) == 0: raise MetaFromString(f"No subclass of {type_repr(superclass)} named {base!r}") if len(candidates) > 1: raise MetaFromString( f"Multiple subclasses of {type_repr(superclass)} named {base!r}: " f"{', '.join(type_repr(c) for c in candidates)}" ) return candidates[0] def _maybe_generic( cls: type, generic: str | None, template: InstantiableType ) -> Callable[..., Any]: if generic is None: return cls assert isinstance(generic, str) generic_args_str = generic.split(",") args: list[object] = [] for i, arg_str in enumerate(generic_args_str): arg_str = arg_str.strip() if ":" in arg_str: module_name, arg = arg_str.split(":", 1) args.append(_module_getattr(_module_from_name(module_name), arg)) elif arg_str == "...": args.append(...) 
else: # TODO: note this assumes covariance, also give a better error superclass = template.__args__[i] # type: ignore[union-attr] args.append(_find_subclass(arg_str, superclass)) origin: Any = getattr(cls, "__origin__", cls) return origin[*args] def _return_prospective(obj: Any, annotation: TypeForm, factory: str) -> Any: if annotation not in { object, typing.Any, typing_extensions.Any, } and not isinstance(annotation, typing.TypeVar): if is_subtype_instance(obj, annotation): # Allow things to be instances! # In some sense, this is just working around deficiencies in casting... return lambda: obj elif not callable(obj): assert is_subtype_instance(obj, annotation) # Also allow things to be instances if we would just error on the next line return lambda: obj if not callable(obj): raise MetaFromString(f"Expected {obj} from {factory!r} to be callable") if isinstance(obj, type) and not is_subtype(obj, annotation): extra = "" if getattr(annotation, "__module__", None) == "__main__": if any( hasattr(sys.modules["__main__"], (witness := parent).__name__) for parent in obj.__mro__ ): extra = f" (there may be confusion between {type_repr(witness)} and __main__:{witness.__name__})" raise MetaFromString( f"Expected {type_repr(obj)} from {factory!r} to be a subtype of {type_repr(annotation)}{extra}" ) return obj def get_unspecified_from_annotation(annotation: TypeForm) -> Callable[..., Any] | None: if typing.get_origin(annotation) is type: base_type = typing.get_args(annotation)[0] if is_union_type(base_type): # No unspecified for type[SpecialForm] e.g. 
type[int | str] # TODO: annotated return None return type[base_type] # type: ignore[return-value] if is_union_type(annotation): type_args = typing.get_args(annotation) if type_args and len(type_args) == 2 and type(None) in type_args: unwrapped_optional = [t for t in type_args if t is not type(None)][0] if callable(unwrapped_optional): return unwrapped_optional return None if is_instantiable_type(annotation): return annotation # type: ignore[return-value] if annotation is None: return lambda: None # Probably a special form return None class standard(MetaFactory): def __init__( self, *, annotation: TypeForm | MISSING_TYPE = MISSING, unspecified: Callable[..., Any] | None = None, default_module: str | types.ModuleType | None | MISSING_TYPE = MISSING, ) -> None: super().__init__() self._annotation = annotation self.original_unspecified = unspecified self._default_module = default_module def __repr__(self) -> str: return f"standard(annotation={self.annotation!r}, unspecified={self.original_unspecified!r}, default_module={self.default_module!r})" @property def annotation(self) -> TypeForm: if isinstance(self._annotation, MISSING_TYPE): assert not isinstance(self.field_annotation, MISSING_TYPE) return self.field_annotation return self._annotation @property def default_module(self) -> types.ModuleType | str | None: if isinstance(self._default_module, MISSING_TYPE): if isinstance(self.field_module, MISSING_TYPE): # TODO: maybe make this assert and make artificial use cases pass a value explicitly return None return self.field_module if isinstance(self._default_module, str): return _module_from_name(self._default_module) return self._default_module @functools.cached_property def computed_unspecified(self) -> Callable[..., Any] | None: return ( get_unspecified_from_annotation(self.annotation) if self.original_unspecified is None else self.original_unspecified ) def unspecified_factory(self) -> Callable[..., Any] | None: if ( self.computed_unspecified is not None and 
typing.get_origin(self.computed_unspecified) is type and typing.get_args(self.computed_unspecified) ): base_type = typing.get_args(self.computed_unspecified)[0] # TODO: remove special handling here and elsewhere by moving logic to collect_params return lambda: base_type return self.computed_unspecified def from_string(self, factory: str) -> Callable[..., Any]: if ":" in factory: module_name, var = factory.split(":", 1) # fun lambda case # TODO: add docs for fun lambda case if module_name == "lambda" or module_name.startswith("lambda "): default_module = self.default_module if isinstance(default_module, MISSING_TYPE) or default_module is None: eval_ctx = None else: eval_ctx = default_module try: if isinstance(ast.parse(factory).body[0].value, ast.Lambda): # type: ignore[attr-defined] return eval_in_context(factory, eval_ctx) except Exception as e: raise MetaFromString( f"Could not interpret {factory!r} as a function: {e}" ) from None raise AssertionError # we've just got something explicitly specified module = _module_from_name(module_name) match = re.fullmatch(r"(?P<base>[^\s\[\]]+)(\[(?P<generic>.+)\])?", var) if match is None: raise MetaFromString(f"Failed to parse {factory!r} as a class name") base = match.group("base") generic = match.group("generic") # TODO: think about this type ignore typ = _maybe_generic(_module_getattr(module, base), generic, template=self.annotation) # type: ignore[arg-type] return _return_prospective(typ, self.annotation, factory=factory) try: if self.annotation in {object, typing.Any, typing_extensions.Any}: return _find_subclass(factory, self.annotation) if typing.get_origin(self.annotation) is type: base_type = typing.get_args(self.annotation)[0] assert isinstance(base_type, type) typ = _find_subclass(factory, base_type) return lambda: typ if is_union_type(self.annotation): if self.original_unspecified is not None: try: if is_instantiable_type(self.original_unspecified): return _find_subclass(factory, self.original_unspecified) except 
MetaFromString: pass for t in typing.get_args(self.annotation): try: if is_instantiable_type(t): return _find_subclass(factory, t) except MetaFromString: pass if type(None) in typing.get_args(self.annotation) and factory == "None": return lambda: None raise MetaFromString(f"Could not produce a union instance from {factory!r}") if is_instantiable_type(self.annotation): return _find_subclass(factory, self.annotation) if self.annotation is None and factory == "None": return lambda: None except MetaFromString as e: try: default_module = self.default_module if isinstance(default_module, str): default_module = _module_from_name(default_module) if default_module is not None: obj = _module_getattr(default_module, factory) return _return_prospective(obj, self.annotation, factory=factory) except MetaFromString: pass raise e # Probably a special form raise MetaFromString( f"Could not produce a {type_repr(self.annotation)} instance from {factory!r}" ) def perform_cast(self, value: str): if self.original_unspecified is not None: try: return _simplistic_try_cast(value, self.original_unspecified) except CastError: pass return _simplistic_try_cast(value, self.annotation) ================================================ FILE: chz/field.py ================================================ from __future__ import annotations import functools import sys from typing import Any, Callable import chz from chz.mungers import Munger, default_munger from chz.tiepin import TypeForm from chz.util import MISSING, MISSING_TYPE _FieldValidator = Callable[[Any, str], None] def field( *, # default related default: Any | MISSING_TYPE = MISSING, default_factory: Callable[[], Any] | MISSING_TYPE = MISSING, # munger related munger: Munger | Callable[[Any, Any], Any] | None = None, x_type: TypeForm | MISSING_TYPE = MISSING, converter: Callable[[Any], Any] | None = None, # blueprint related meta_factory: chz.factories.MetaFactory | None | MISSING_TYPE = MISSING, blueprint_unspecified: Callable[..., Any] | 
MISSING_TYPE = MISSING, blueprint_cast: Callable[[str], object] | None = None, # misc validator: _FieldValidator | (list[_FieldValidator] | None) = None, repr: bool | Callable[[Any], str] = True, doc: str = "", metadata: dict[str, Any] | None = None, ) -> Any: """Customise a field in a chz class. Args: default: The default value for the field (if any). default_factory: A function that returns the default value for the field. Useful for mutable types, for instance, `default_factory=list`. This does not interact at all with parametrisation. Perhaps a better name would be lazy_default (but unfortunately, this is not supported by PEP 681, so static type checkers would lose the ability to understand the class). munger: Lets you adjust the value of a field. Essentially works the same as an init_property. x_type: Useful in combination with mungers. This specifies the type before munging that will be used for parsing and type checking. converter: Synonym for munger that works better with static type checkers. It accepts a munger object or a callable that will be called as fn(value, chzself=chzself). meta_factory: Represents the set of possible callables that can give us a value of a given type. blueprint_unspecified: Used to construct the meta_factory, if meta_factory is unspecified. This is the default callable `Blueprint` may attempt to call to get a value of the expected type. See the documentation in chz.factories for more information. In particular, the following two are equivalent: ``` x: Base = field(blueprint_unspecified=Sub) x: Base = field(meta_factory=chz.factories.subclass(Base, default_cls=Sub)) ``` blueprint_cast: A function that takes a str and returns an object. On failure to cast, it should raise `CastError`. Used to achieve custom parsing behaviour from the command line. Takes priority over the `__chz_cast__` dunder method (if present on the target type). validator: A function or list of functions that validate the field. 
Field validators take two arguments: the instance of the class and the name of the field. repr: Whether to include the field in the `__repr__` of the class. This can also be a callable to customise the repr of the field. doc: The docstring for the field. Used in `--help`. metadata: Arbitrary user-defined metadata to attach to the field. Useful when extending `chz`. """ return Field( name="", raw_type="", default=default, default_factory=default_factory, munger=munger, raw_x_type=x_type, converter=converter, meta_factory=meta_factory, blueprint_unspecified=blueprint_unspecified, blueprint_cast=blueprint_cast, validator=validator, repr=repr, doc=doc, metadata=metadata, ) class Field: def __init__( self, *, name: str, raw_type: TypeForm | str, default: Any = MISSING, default_factory: Callable[[], Any] | MISSING_TYPE = MISSING, munger: Munger | Callable[[Any, Any], Any] | None = None, raw_x_type: TypeForm | MISSING_TYPE = MISSING, converter: Callable[[Any], Any] | None = None, meta_factory: chz.factories.MetaFactory | None | MISSING_TYPE = MISSING, blueprint_unspecified: Callable[..., Any] | MISSING_TYPE = MISSING, blueprint_cast: Callable[[str], object] | None = None, validator: _FieldValidator | (list[_FieldValidator] | None) = None, repr: bool | Callable[[Any], str] = True, doc: str = "", metadata: dict[str, Any] | None = None, ): if default.__class__.__hash__ is None: raise ValueError( f"Mutable default {type(default)} for field " f"{name} is not allowed: use default_factory" ) if ( meta_factory is not MISSING and meta_factory is not None and not isinstance(meta_factory, chz.factories.MetaFactory) ): raise TypeError(f"meta_factory must be a MetaFactory, not {type(meta_factory)}") if blueprint_unspecified is not MISSING: if not callable(blueprint_unspecified): raise TypeError( f"blueprint_unspecified must be callable, not {type(blueprint_unspecified)}" ) if meta_factory is not MISSING: raise ValueError("Cannot specify both meta_factory and blueprint_unspecified") if 
default_factory is not MISSING: if not callable(default_factory): raise TypeError(f"default_factory must be callable, not {type(default_factory)}") if isinstance(default_factory, chz.factories.MetaFactory): raise TypeError( "default_factory must be a callable that returns a value, " "not a MetaFactory. Note that default_factory must be callable without any " "arguments and does not interact with parametrisation." ) if converter is not None: if munger is not None: raise ValueError("Cannot specify both converter and munger") if not callable(converter): raise TypeError(f"converter must be callable, not {type(converter)}") if isinstance(converter, Munger): munger = converter else: # Note: when the munger arg is a function, it is called as munger(chzself, value), # but converters must be defined with the value as the only positional parameter, # and so they are called as converter(value, chzself=chzself). # TODO: change the signature of functions passed to the `munger` argument to be # compatible with `converter`? 
c = converter munger = lambda s, v: c(v, chzself=s) # type: ignore if munger is not None and not callable(munger): raise TypeError(f"munger must be callable, not {type(munger)}") if validator is None: validator = [] elif not isinstance(validator, list): validator = [validator] self._name = name self._raw_type = raw_type self._raw_x_type = raw_x_type self._default = default self._default_factory = default_factory self._meta_factory = meta_factory self._blueprint_unspecified = blueprint_unspecified self._munger = munger self._validator: list[_FieldValidator] = validator self._blueprint_cast = blueprint_cast self._repr = repr self._doc = doc self._metadata = metadata # We used to pass the actual globals around, but cloudpickle did not like that # when it tried to pickle chz classes by value in __main__ # Note that this means that if we're using postponed annotations or quoted annotations # in __main__ that self.type will likely fail if this is ever pickled and unpickled self._user_module: str = "" @property def logical_name(self) -> str: for magic_prefix in ("隐", "_X_"): if self._name.startswith(magic_prefix): raise RuntimeError(f"Magic prefix {magic_prefix} no longer supported, use X_") if self._name.startswith("X_"): return self._name.removeprefix("X_") return self._name @property def x_name(self) -> str: return "X_" + self.logical_name @functools.cached_property def final_type(self) -> TypeForm: if not self._name: raise RuntimeError( "Something has gone horribly awry; are you using a chz.Field in a dataclass?" ) # Delay the eval until after the class if isinstance(self._raw_type, str): # TODO: handle forward ref assert self._user_module if self._user_module not in sys.modules: raise RuntimeError( f"Could not find module {self._user_module}; possibly a pickling issue?" 
                )
            user_globals = sys.modules[self._user_module].__dict__
            return eval(self._raw_type, user_globals)
        return self._raw_type

    @functools.cached_property
    def x_type(self) -> TypeForm:
        if isinstance(self._raw_x_type, MISSING_TYPE):
            return self.final_type
        return self._raw_x_type

    @property
    def meta_factory(self) -> chz.factories.MetaFactory | None:
        if self._meta_factory is None:
            return None

        if isinstance(self._meta_factory, MISSING_TYPE):
            if isinstance(self._blueprint_unspecified, MISSING_TYPE):
                unspec = None
            else:
                unspec = self._blueprint_unspecified

            import chz.factories

            ret = chz.factories.standard(
                annotation=self.x_type, unspecified=unspec, default_module=self._user_module
            )
            ret.field_annotation = self.x_type
            ret.field_module = self._user_module
            return ret

        self._meta_factory.field_annotation = self.x_type
        self._meta_factory.field_module = self._user_module
        return self._meta_factory

    def get_munger(self) -> Callable[[Any], None] | None:
        if self._munger is None:
            return None

        if isinstance(self._munger, Munger):
            m = self._munger
        else:
            assert callable(self._munger)
            m = default_munger(self._munger)
        # Must return a new callable every time
        return lambda chzself: m(getattr(chzself, self.x_name), chzself=chzself, field=self)

    @property
    def metadata(self) -> dict[str, Any] | None:
        return self._metadata

    def __repr__(self):
        return f"Field(name={self._name!r}, type={self.final_type!r}, ...)"

    def versioning_key(self) -> tuple[str, ...]:
        from chz.tiepin import approx_type_hash

        raw_type_key = approx_type_hash(self._raw_type)
        if self._default is MISSING:
            default_key = ""
        elif self._default.__repr__ is not object.__repr__:
            default_key = repr(self._default)
        else:
            default_key = self._default_factory.__class__.__name__
        if isinstance(self._default_factory, MISSING_TYPE):
            default_factory_key = ""
        else:
            # TODO: support lambdas
            default_factory_key = (
                self._default_factory.__module__ + "."
                + self._default_factory.__name__
            )
        return (self._name, raw_type_key, default_key, default_factory_key)


================================================
FILE: chz/mungers.py
================================================
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Mapping, TypeVar, overload

if TYPE_CHECKING:
    from frozendict import frozendict

    from chz.field import Field

_T = TypeVar("_T")
_K = TypeVar("_K")
_V = TypeVar("_V")


class Munger:
    """Marker class for mungers"""

    def __call__(self, value: Any, *, chzself: Any = None, field: Field | None = None) -> Any:
        raise NotImplementedError


class if_none(Munger):
    """If None, munge the field to the result of an arbitrary function of the chz object."""

    def __init__(self, replacement: Callable[[Any], Any]):
        self.replacement = replacement

    def __call__(self, value: _T | None, *, chzself: Any = None, field: Field | None = None) -> _T:
        if value is not None:
            return value
        return self.replacement(chzself)


class attr_if_none(Munger):
    """If None, munge the field to another attribute of the chz object."""

    def __init__(self, replacement_attr: str):
        self.replacement_attr = replacement_attr

    def __call__(self, value: _T | None, *, chzself: Any = None, field: Field | None = None) -> _T:
        if value is not None:
            return value
        return getattr(chzself, self.replacement_attr)


class default_munger(Munger):
    def __init__(self, fn: Callable[[Any, Any], Any]):
        self.fn = fn

    def __call__(self, value: Any, *, chzself: Any = None, field: Field | None = None) -> Any:
        # Note: when the munger arg is a function, it is called as munger(chzself, value),
        # and we keep that calling convention here. See also the comment in Field.__init__.
        return self.fn(chzself, value)


class freeze_dict(Munger):
    """Freezes a dictionary value so the object is hashable."""

    @overload
    def __call__(
        self, value: Mapping[_K, _V], *, chzself: Any = None, field: Field | None = None
    ) -> frozendict[_K, _V]: ...

    @overload
    def __call__(
        self, value: Mapping[_K, _V] | None, *, chzself: Any = None, field: Field | None = None
    ) -> frozendict[_K, _V] | None: ...

    def __call__(
        self, value: Mapping[_K, _V] | None, *, chzself: Any = None, field: Field | None = None
    ) -> frozendict[_K, _V] | None:
        from frozendict import frozendict

        if value is not None and not isinstance(value, frozendict):
            return frozendict[_K, _V](value)  # pyright: ignore[reportUnknownArgumentType]
        return value


================================================
FILE: chz/py.typed
================================================


================================================
FILE: chz/tiepin.py
================================================
"""

It's a fair question why this module exists, instead of using something third party.

There are two things I would have liked to farm out: 1) is_subtype_instance, 2) _simplistic_try_cast.

For is_subtype_instance, I would have liked to use `typeguard`. Unfortunately, the `typeguard`
version we were on did not support a lot of basic things. We couldn't upgrade either, because the
new version had breaking changes and more importantly created ref cycles in places that caused us
to hold on to GPU tensors for longer than we should have, causing GPU OOMs.
Update: I eventually got this fixed upstream.

_simplistic_try_cast, despite its name, seems to work better than most things out there for our
use case. It is also nice to be able to customise it for chz's purposes.

I also have another motivation, which is that by writing my own Python runtime type checker, I'll
become a better maintainer of typing.py / typing_extensions.py upstream.
"""

import ast
import collections.abc
import functools
import hashlib
import importlib
import inspect
import operator
import sys
import types
import typing

import typing_extensions


def type_repr(typ) -> str:
    # Similar to typing._type_repr
    if isinstance(typ, (types.GenericAlias, typing._GenericAlias)):
        if typ.__origin__.__module__ in {"typing", "typing_extensions", "collections.abc"}:
            if typ.__origin__ is collections.abc.Callable:
                return repr(typ).removeprefix("collections.abc.").removeprefix("typing.")

            # Based on typing._GenericAlias.__repr__
            name = typ.__origin__.__name__
            if typ.__args__:
                args = ", ".join([type_repr(a) for a in typ.__args__])
            else:
                args = "()"
            return f"{name}[{args}]"

        return repr(typ)

    if isinstance(typ, (type, types.FunctionType)):
        module = getattr(typ, "__module__", None)
        name = getattr(typ, "__qualname__", None)
        if name is None:
            name = getattr(typ, "__name__", None)
        if name is not None:
            if module == "typing":
                return f"{module}.{name}"
            if module is not None and module != "builtins" and module != "__main__":
                return f"{module}:{name}"
            return name

    if typ is ...:
        return "..."
    return repr(typ)


def _approx_type_to_bytes(t) -> bytes:
    # This tries to keep the resulting value similar with and without __future__ annotations
    # As a result, the conversion is approximate. For instance, `builtins.float` and
    # `class float: ...` will look the same.
    # If you need something more discerning, maybe just use pickle?
    # Although note that pickle
    # doesn't work on at least forward refs and non-module level typevars
    origin = getattr(t, "__origin__", None)
    args = getattr(t, "__args__", ())

    if origin is None:
        if isinstance(t, type):
            # don't use t.__module__, so that we're more likely to preserve hashes
            # with and without future annotations
            origin_bytes = t.__name__.encode("utf-8")
        elif isinstance(t, typing._SpecialForm):
            origin_bytes = t._name.encode("utf-8")
        elif isinstance(t, typing.TypeVar):
            origin_bytes = t.__name__.encode("utf-8")
        elif isinstance(t, typing.ForwardRef):
            origin_bytes = t.__forward_arg__.encode("utf-8")
        elif isinstance(t, str):
            origin_bytes = t.encode("utf-8")
        elif isinstance(t, (bytes, int, type(...), type(None))):  # enums?
            origin_bytes = repr(t).encode("utf-8")
        else:
            raise TypeError(f"Cannot convert {t} of {type(t)} to bytes")
    else:
        origin_bytes = _approx_type_to_bytes(origin)

    arg_bytes = (b"[" + b",".join(_approx_type_to_bytes(a) for a in args) + b"]") if args else b""
    return origin_bytes + arg_bytes


def approx_type_hash(t) -> str:
    return hashlib.sha1(_approx_type_to_bytes(t)).hexdigest()


def eval_in_context(annot: str, obj: object) -> typing.Any:
    # Based on inspect.get_annotations
    if isinstance(obj, type):
        obj_globals = None
        module_name = getattr(obj, "__module__", None)
        if module_name:
            module = sys.modules.get(module_name, None)
            if module:
                obj_globals = getattr(module, "__dict__", None)
        obj_locals = dict(vars(obj))
        unwrap = obj
    elif isinstance(obj, types.ModuleType):
        obj_globals = getattr(obj, "__dict__", None)
        obj_locals = None
        unwrap = None
    elif callable(obj):
        obj_globals = getattr(obj, "__globals__", None)
        obj_locals = None
        unwrap = obj
    elif obj is None:
        obj_globals = None
        obj_locals = None
        unwrap = None
    else:
        raise TypeError(f"{obj!r} is not a module, class, or callable.")

    if unwrap is not None:
        while True:
            if hasattr(unwrap, "__wrapped__"):
                unwrap = unwrap.__wrapped__
                continue
            if isinstance(unwrap, functools.partial):
                unwrap = unwrap.func
                continue
            break
        if
hasattr(unwrap, "__globals__"):
            obj_globals = unwrap.__globals__

    assert isinstance(annot, str)
    return eval(annot, obj_globals, obj_locals)


def maybe_eval_in_context(annot: str, obj: object) -> typing.Any:
    if isinstance(annot, str):
        return eval_in_context(annot, obj)
    if annot is inspect.Parameter.empty:
        return typing.Any
    return annot


if sys.version_info >= (3, 11):
    typing_Never = (
        typing.NoReturn,
        typing_extensions.NoReturn,
        typing_extensions.Never,
        typing.Never,
    )
else:
    typing_Never = (typing.NoReturn, typing_extensions.NoReturn, typing_extensions.Never)

TypeForm = object
InstantiableType: typing.TypeAlias = type | types.GenericAlias  # | typing._GenericAlias


def is_instantiable_type(t: TypeForm) -> typing.TypeGuard[InstantiableType]:
    origin = getattr(t, "__origin__", t)
    return isinstance(origin, type) and origin is not type


def is_union_type(t: TypeForm) -> bool:
    # This has gotten a little messy with Python 3.14
    origin = getattr(t, "__origin__", t)
    return origin is typing.Union or isinstance(t, types.UnionType) or t is types.UnionType


def is_typed_dict(t: TypeForm) -> bool:
    return isinstance(t, (typing._TypedDictMeta, typing_extensions._TypedDictMeta))


class CastError(Exception):
    pass


def _module_from_name(name: str) -> types.ModuleType:
    try:
        return importlib.import_module(name)
    except ImportError as e:
        raise CastError(f"Could not import module {name!r} ({type(e).__name__}: {e})") from None


def _module_getattr(mod: types.ModuleType, attr: str) -> typing.Any:
    try:
        for a in attr.split("."):
            mod = getattr(mod, a)
        return mod
    except AttributeError:
        raise CastError(f"No attribute named {attr!r} in module {mod.__name__}") from None


def _sort_for_union_preference(typs: tuple[TypeForm, ...]):
    def sort_key(typ):
        typ = getattr(typ, "__origin__", typ)
        if typ is str:
            # sort str to last, because anything can be cast to str
            return 1
        if typ is typing.Literal or typ is typing_extensions.Literal:
            # sort literals to first, because they exact match
            return -2
        if typ is type(None) or
typ is None:
            # None exact matches as well (like all singletons)
            return -1
        return 0

    # note this is a stable sort, so we preserve user ordering
    return sorted(typs, key=sort_key)


def is_args_unpack(t: TypeForm) -> bool:
    return getattr(t, "__unpacked__", False) or getattr(t, "__origin__", t) in {
        typing.Unpack,
        typing_extensions.Unpack,
    }


def is_kwargs_unpack(t: TypeForm) -> bool:
    return getattr(t, "__origin__", t) in {typing.Unpack, typing_extensions.Unpack}


def _unpackable_arg_length(t: TypeForm) -> tuple[int, bool]:
    item_args = None
    if getattr(t, "__unpacked__", False):
        assert t.__origin__ is tuple  # TODO
        item_args = t.__args__
    elif getattr(t, "__origin__", t) in {typing.Unpack, typing_extensions.Unpack}:
        assert len(t.__args__) == 1
        assert t.__args__[0].__origin__ is tuple
        item_args = t.__args__[0].__args__
    else:
        return (1, False)

    if not item_args or item_args[-1] is ...:
        assert len(item_args) == 2
        return (0, True)

    min_length = 0
    unbounded = False
    for item_arg in item_args:
        arg_length, arg_unbounded = _unpackable_arg_length(item_arg)
        min_length += arg_length
        unbounded |= arg_unbounded
    return (min_length, unbounded)


def _cast_unpacked_tuples(
    inst_items: list[str], args: tuple[TypeForm, ...]
) -> tuple[typing.Any, ...]:
    # Cursed PEP 646 stuff
    arg_lengths = [_unpackable_arg_length(arg) for arg in args]
    min_length = sum(arg_length for arg_length, _ in arg_lengths)
    if len(inst_items) < min_length:
        raise CastError(
            f"Could not cast {repr(','.join(inst_items))} to {type_repr(tuple[*args])} "
            "because of length mismatch"
        )

    ret = []
    i = 0
    for arg, (arg_length, arg_unbounded) in zip(args, arg_lengths):
        if is_args_unpack(arg):
            if arg_unbounded:
                arg_length += len(inst_items) - min_length
                min_length = len(inst_items)
            if getattr(arg, "__origin__", arg) in {typing.Unpack, typing_extensions.Unpack}:
                assert len(arg.__args__) == 1
                assert arg.__args__[0].__origin__ is tuple
                arg = arg.__args__[0]
            arg = arg.__args__
            if len(arg) == 0:
                ret.extend(inst_items[i : i + arg_length])
            elif len(arg) == 2 and arg[-1] is ...:
                ret.extend(
                    _cast_unpacked_tuples(inst_items[i : i + arg_length], (arg[0],) * arg_length)
                )
            else:
                ret.extend(_cast_unpacked_tuples(inst_items[i : i + arg_length], arg))
        else:
            assert arg_length == 1
            assert not arg_unbounded
            ret.append(_simplistic_try_cast(inst_items[i], arg))
        i += arg_length
    return tuple(ret)


def _simplistic_try_cast(inst_str: str, typ: TypeForm):
    origin = getattr(typ, "__origin__", typ)
    if is_union_type(origin):
        # sort str to last spot
        args = _sort_for_union_preference(getattr(typ, "__args__", ()))
        for arg in args:
            try:
                return _simplistic_try_cast(inst_str, arg)
            except CastError:
                pass
        raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}")

    if origin is typing.Any or origin is typing_extensions.Any or origin is object:
        try:
            return ast.literal_eval(inst_str)
        except (ValueError, SyntaxError):
            pass
        # Also accept some lowercase spellings
        if inst_str in {"true", "false"}:
            return inst_str == "true"
        if inst_str in {"none", "null", "NULL"}:
            return None
        return inst_str

    if isinstance(origin, typing.TypeVar):
        if origin.__constraints__:
            for constraint in origin.__constraints__:
                try:
                    return _simplistic_try_cast(inst_str, constraint)
                except
CastError:
                    pass
            raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}")
        if origin.__bound__:
            return _simplistic_try_cast(inst_str, origin.__bound__)
        return _simplistic_try_cast(inst_str, object)

    if origin is typing.Literal or origin is typing_extensions.Literal:
        values_by_type = {}
        for arg in getattr(typ, "__args__", ()):
            values_by_type.setdefault(type(arg), []).append(arg)
        for literal_typ, literal_values in values_by_type.items():
            try:
                value = _simplistic_try_cast(inst_str, literal_typ)
                if value in literal_values:
                    return value
            except CastError:
                pass
        raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}")

    if origin is None or origin is type(None):
        if inst_str == "None":
            return None
        raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}")

    if origin is bool:
        if inst_str in {"t", "true", "True", "1"}:
            return True
        if inst_str in {"f", "false", "False", "0"}:
            return False
        raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}")

    if origin is str:
        return inst_str
    if origin is float:
        try:
            return float(inst_str)
        except ValueError as e:
            raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}") from e
    if origin is int:
        try:
            return int(inst_str)
        except ValueError as e:
            raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}") from e

    if origin is list or origin is collections.abc.Sequence or origin is collections.abc.Iterable:
        if not inst_str:
            return []
        args = getattr(typ, "__args__", ())
        item_type = args[0] if args else typing.Any

        if inst_str[0] in {"[", "("}:
            try:
                value = ast.literal_eval(inst_str)
            except (ValueError, SyntaxError):
                raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}") from None
            if is_subtype_instance(value, typ):
                return value
            raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}")

        inst_items = inst_str.split(",") if inst_str else []
        ret = [_simplistic_try_cast(item, item_type) for item in inst_items]
        if origin is list:
            return ret
        return
tuple(ret)

    if origin is tuple:
        args = getattr(typ, "__args__", ())
        inst_items = inst_str.split(",") if inst_str else []
        if len(args) == 0:
            return tuple(inst_items)
        if len(args) == 2 and args[-1] is ...:
            item_type = args[0]
            return tuple(_simplistic_try_cast(item, item_type) for item in inst_items)
        num_unpack = sum(is_args_unpack(arg) for arg in args)
        if num_unpack == 0:
            # Great, normal heterogeneous tuple
            if len(args) != len(inst_items):
                raise CastError(
                    f"Could not cast {repr(inst_str)} to {type_repr(typ)} because of length mismatch"
                    + (
                        f". Homogeneous tuples should be typed as tuple[{type_repr(args[0])}, ...] not tuple[{type_repr(args[0])}]"
                        if len(args) == 1
                        else ""
                    )
                )
            return tuple(
                _simplistic_try_cast(item, item_typ) for item, item_typ in zip(inst_items, args)
            )
        else:
            # Cursed PEP 646 stuff
            return _cast_unpacked_tuples(inst_items, args)

    if origin is dict or origin is collections.abc.Mapping:
        if not inst_str:
            return {}
        if inst_str[0] == "{":
            try:
                value = ast.literal_eval(inst_str)
            except (ValueError, SyntaxError):
                raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}") from None
            if is_subtype_instance(value, typ):
                return value
        raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}")

    if origin is collections.abc.Callable:
        # TODO: also support type, callback protocols
        # TODO: unify with factories.from_string
        # TODO: needs module context
        if ":" in inst_str:
            try:
                module_name, var = inst_str.split(":", 1)
                module = _module_from_name(module_name)
                value = _module_getattr(module, var)
                if not is_subtype_instance(value, typ):
                    raise CastError(f"{type_repr(value)} is not a subtype of {type_repr(typ)}")
            except CastError as e:
                raise CastError(
                    f"Could not cast {repr(inst_str)} to {type_repr(typ)}. {e}"
                ) from None
            return value
        else:
            raise CastError(
                f"Could not cast {repr(inst_str)} to {type_repr(typ)}. Try using a fully qualified name, e.g.
module_name:{inst_str}"
            )

    if "torch" in sys.modules:
        import torch

        if origin is torch.dtype:
            value = getattr(torch, inst_str, None)
            if value and isinstance(value, torch.dtype):
                return value
            raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}")

    if "datetime" in sys.modules:
        import datetime

        if origin is datetime.datetime:
            try:
                return datetime.datetime.fromisoformat(inst_str)
            except ValueError:
                raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}") from None

    if "enum" in sys.modules:
        import enum

        if isinstance(origin, type) and issubclass(origin, enum.Enum):
            try:
                # Look up by name
                return origin[inst_str]
            except KeyError:
                pass
            # Fallback to looking up by value
            for member in origin:
                try:
                    value = _simplistic_try_cast(inst_str, type(member.value))
                except CastError:
                    continue
                if value == member.value:
                    return member
            raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}")

    if "fractions" in sys.modules:
        import fractions

        if origin is fractions.Fraction:
            try:
                return fractions.Fraction(inst_str)
            except ValueError as e:
                raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}") from e

    if "pathlib" in sys.modules:
        import pathlib

        if origin is pathlib.Path:
            return pathlib.Path(inst_str)

    if hasattr(origin, "__chz_cast__"):
        return origin.__chz_cast__(inst_str)

    if not isinstance(origin, type):
        raise CastError(f"Unrecognised type object {type_repr(typ)}")
    raise CastError(f"Could not cast {repr(inst_str)} to {type_repr(typ)}")


class _SignatureOf:
    def __init__(self, fn: typing.Callable, strip_self: bool = False):
        self.fn = fn
        self._sig = inspect.signature(fn)
        self.pos = []
        self.kwonly = {}
        self.varpos = None
        self.varkw = None
        for param in self._sig.parameters.values():
            if param.kind in {param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY}:
                self.pos.append(param)
            elif param.kind is param.KEYWORD_ONLY:
                self.kwonly[param.name] = param
            elif param.kind is param.VAR_POSITIONAL and param.name != "__chz_args":
                self.varpos = param
            elif
param.kind is param.VAR_KEYWORD:
                self.varkw = param
        if strip_self:
            if self.pos[0].name != "self":
                raise ValueError(f"Cannot strip self from signature of {self.fn}")
            self.pos = self.pos[1:]

        self.ret = self._sig.return_annotation
        if isinstance(self.ret, str):
            self.ret = eval_in_context(self.ret, self.fn)


def is_subtype(left: TypeForm, right: TypeForm) -> bool:
    left_origin = getattr(left, "__origin__", left)
    left_args = getattr(left, "__args__", ())
    right_origin = getattr(right, "__origin__", right)
    right_args = getattr(right, "__args__", ())

    if left_origin is typing.Any or left_origin is typing_extensions.Any:
        return True
    if right_origin is typing.Any or right_origin is typing_extensions.Any:
        return True
    if left_origin is None:
        if right_origin is None or right_origin is type(None):
            return True

    if is_union_type(right_origin):
        if is_union_type(left_origin):
            possible_left_types = left_args
        else:
            possible_left_types = [left]
        return all(
            any(is_subtype(possible_left, right_arg) for right_arg in right_args)
            for possible_left in possible_left_types
        )

    if right_origin is typing.Literal or right_origin is typing_extensions.Literal:
        if left_origin is typing.Literal or left_origin is typing_extensions.Literal:
            return all(left_arg in right_args for left_arg in left_args)
        return False
    if left_origin is typing.Literal or left_origin is typing_extensions.Literal:
        return all(is_subtype_instance(left_arg, right) for left_arg in left_args)

    if isinstance(left_origin, typing.TypeVar):
        if left_origin == right_origin:
            return True
        bound = left_origin.__bound__
        if bound is None:
            bound = object
        if is_subtype(bound, right_origin):
            return True
        if left_origin.__constraints__:
            return any(
                is_subtype(left_arg, right_origin) for left_arg in left_origin.__constraints__
            )
        return False

    if isinstance(right_origin, typing.TypeVar):
        if right_origin.__constraints__:
            return any(
                is_subtype(left_origin, constraint) for constraint in right_origin.__constraints__
            )
        if right_origin.__bound__:
            return
is_subtype(left_origin, right_origin.__bound__)
        return True

    if typing_extensions.is_protocol(left) and typing_extensions.is_protocol(right):
        left_attrs = typing_extensions.get_protocol_members(left)
        right_attrs = typing_extensions.get_protocol_members(right)
        if not right_attrs.issubset(left_attrs):
            return False
        # TODO: this is incorrect
        return True

    if typing_extensions.is_protocol(right):
        if not isinstance(left_origin, type):
            return False
        right_attrs = typing_extensions.get_protocol_members(right)
        if not all(hasattr(left_origin, attr) for attr in right_attrs):
            return False
        # TODO: this is incorrect
        return True

    if isinstance(left, _SignatureOf) and isinstance(right, _SignatureOf):
        empty = inspect.Parameter.empty

        for left_param, right_param in zip(left.pos, right.pos):
            if right_param.kind is right_param.POSITIONAL_OR_KEYWORD:
                if right_param.name != left_param.name:
                    return False
                if left_param.kind is left_param.POSITIONAL_ONLY:
                    return False
            if right_param.default is not empty and left_param.default is empty:
                return False
            left_param_annot = maybe_eval_in_context(left_param.annotation, left.fn)
            right_param_annot = maybe_eval_in_context(right_param.annotation, right.fn)
            if not is_subtype(right_param_annot, left_param_annot):
                return False

        if len(left.pos) < len(right.pos):
            # Okay if left has a *args that accepts all the extra args
            if left.varpos is None:
                return False
            left_varpos_annot = maybe_eval_in_context(left.varpos.annotation, left.fn)
            for i in range(len(left.pos), len(right.pos)):
                right_param = right.pos[i]
                right_param_annot = maybe_eval_in_context(right_param.annotation, right.fn)
                if not is_subtype(right_param_annot, left_varpos_annot):
                    return False

        if len(left.pos) > len(right.pos):
            # Must either have a default or correspond to a required keyword-only arg
            for i in range(len(right.pos), len(left.pos)):
                left_param = left.pos[i]
                if left_param.default is not empty:
                    continue
                if (
                    left_param.name in right.kwonly
                    and left_param.kind is left_param.POSITIONAL_OR_KEYWORD
                ):
                    continue
                return False

        for name in left.kwonly.keys() & right.kwonly.keys():
            right_param = right.kwonly[name]
            left_param = left.kwonly[name]
            if right_param.default is not empty and left_param.default is empty:
                return False
            left_param_annot = maybe_eval_in_context(left_param.annotation, left.fn)
            right_param_annot = maybe_eval_in_context(right_param.annotation, right.fn)
            if not is_subtype(right_param_annot, left_param_annot):
                return False

        for name in left.kwonly.keys() - right.kwonly.keys():
            # Must either have a default or match a varkwarg
            left_param = left.kwonly[name]
            if left_param.default is not empty:
                continue
            if right.varkw is not None:
                left_param_annot = maybe_eval_in_context(left_param.annotation, left.fn)
                right_varkw_annot = maybe_eval_in_context(right.varkw.annotation, right.fn)
                if is_subtype(right_varkw_annot, left_param_annot):
                    continue
            return False

        right_only_kwonly = right.kwonly.keys() - left.kwonly.keys()
        if right_only_kwonly:
            # Must correspond to a positional-or-keyword arg
            left_pos_or_kw = {p.name: p for p in left.pos if p.kind is p.POSITIONAL_OR_KEYWORD}
            for name in right_only_kwonly:
                if name not in left_pos_or_kw:
                    return False
                left_param = left_pos_or_kw[name]
                if right.kwonly[name].default is not empty and left_param.default is empty:
                    return False
                left_param_annot = maybe_eval_in_context(left_param.annotation, left.fn)
                right_param_annot = maybe_eval_in_context(right.kwonly[name].annotation, right.fn)
                if not is_subtype(right_param_annot, left_param_annot):
                    return False

        if right.varkw is not None:
            if left.varkw is None:
                return False
            right_varkw_annot = maybe_eval_in_context(right.varkw.annotation, right.fn)
            left_varkw_annot = maybe_eval_in_context(left.varkw.annotation, left.fn)
            if not is_subtype(right_varkw_annot, left_varkw_annot):
                return False

        if right.ret is not empty and left.ret is not empty:
            # TODO: handle Cls.__init__ like below
            if not is_subtype(left.ret, right.ret):
                return False

        return True

    if left_origin is collections.abc.Callable and
right_origin is collections.abc.Callable:
        *left_params, left_ret = left_args
        *right_params, right_ret = right_args
        if len(left_params) != len(right_params):
            return False
        if not is_subtype(left_ret, right_ret):
            return False
        return all(
            is_subtype(right_param, left_param)
            for left_param, right_param in zip(left_params, right_params)
        )

    if is_typed_dict(left_origin) and is_typed_dict(right_origin):
        if not right_origin.__required_keys__.issubset(left_origin.__required_keys__):
            return False
        left_hints = typing_extensions.get_type_hints(left_origin)
        right_hints = typing_extensions.get_type_hints(right_origin)
        for k, v in right_hints.items():
            if k not in left_hints:
                return False
            if not is_subtype(left_hints[k], v):
                # Technically this should be invariant due to mutability
                return False
        return True

    # TODO: handle other special forms

    if left_origin is right_origin and left_args == right_args:
        return True
    try:
        if not issubclass(left_origin, right_origin):
            return False
    except TypeError:
        return False

    # see comments in is_subtype_instance
    # TODO: add invariance
    # TODO: think about some of this logic more carefully
    if hasattr(left_origin, "__class_getitem__") and hasattr(right_origin, "__class_getitem__"):
        if (
            issubclass(right_origin, collections.abc.Mapping)
            and typing.Generic not in left_origin.__mro__
            and typing.Generic not in right_origin.__mro__
        ):
            if left_args:
                left_key, left_value = left_args
            else:
                left_key, left_value = typing.Any, typing.Any
            if right_args:
                right_key, right_value = right_args
            else:
                right_key, right_value = typing.Any, typing.Any
            return is_subtype(left_key, right_key) and is_subtype(left_value, right_value)

        if left_origin is tuple and right_origin is tuple:
            if not left_args:
                left_args = (typing.Any, ...)
            if not right_args:
                right_args = (typing.Any, ...)
            if len(right_args) == 2 and right_args[1] is ...:
                return all(is_subtype(left_arg, right_args[0]) for left_arg in left_args)
            if len(left_args) == 2 and left_args[1] is ...:
                return False
            return len(left_args) == len(right_args) and all(
                is_subtype(left_arg, right_arg)
                for left_arg, right_arg in zip(left_args, right_args)
            )

        if (
            issubclass(right_origin, collections.abc.Iterable)
            and typing.Generic not in left_origin.__mro__
            and typing.Generic not in right_origin.__mro__
        ):
            if left_args:
                (left_item,) = left_args
            else:
                left_item = typing.Any
            if right_args:
                (right_item,) = right_args
            else:
                right_item = typing.Any
            return is_subtype(left_item, right_item)

    return True


def is_subtype_instance(inst: typing.Any, typ: TypeForm) -> bool:
    if typ is typing.Any or typ is typing_extensions.Any:
        return True
    if typ is None and inst is None:
        return True

    if isinstance(typ, typing.TypeVar):
        if typ.__constraints__:
            # types must match exactly
            return any(
                type(inst) is getattr(c, "__origin__", c) and is_subtype_instance(inst, c)
                for c in typ.__constraints__
            )
        if typ.__bound__:
            return is_subtype_instance(inst, typ.__bound__)
        return True

    if isinstance(typ, typing.NewType):
        return isinstance(inst, typ.__supertype__)

    origin: typing.Any
    args: typing.Any
    if sys.version_info >= (3, 10) and isinstance(typ, types.UnionType):
        origin = typing.Union
    else:
        origin = getattr(typ, "__origin__", typ)
    args = getattr(typ, "__args__", ())
    del typ

    if origin is typing.Union:
        return any(is_subtype_instance(inst, t) for t in args)
    if origin is typing.Literal or origin is typing_extensions.Literal:
        return inst in args
    if origin is typing.LiteralString:
        return isinstance(inst, str)

    if is_typed_dict(origin):
        if not isinstance(inst, dict):
            return False
        for k, v in typing_extensions.get_type_hints(origin).items():
            if k in inst:
                if not is_subtype_instance(inst[k], v):
                    return False
            elif k in origin.__required_keys__:
                return False
        return True

    # Pydantic implements generics in a special way.
    # Just delegate validation to Pydantic.
    # Note that all pydantic models have __pydantic_generic_metadata__, even non-generic ones.
    if hasattr(origin, "__pydantic_generic_metadata__"):
        from pydantic import ValidationError

        try:
            origin.model_validate(inst)
            return True
        except ValidationError:
            return False

    if typing_extensions.is_protocol(origin):
        if getattr(origin, "_is_runtime_protocol", False):
            return isinstance(inst, origin)
        if origin in type(inst).__mro__:
            return True
        annotations = typing_extensions.get_type_hints(origin)
        for attr in sorted(typing_extensions.get_protocol_members(origin)):
            if not hasattr(inst, attr):
                return False
            if attr in annotations:
                if not is_subtype_instance(getattr(inst, attr), annotations[attr]):
                    return False
            elif callable(getattr(origin, attr)):
                if attr == "__call__" and isinstance(inst, (type, types.FunctionType)):
                    # inst will have a better inspect.signature than inst.__call__
                    inst_attr = inst
                else:
                    inst_attr = getattr(inst, attr)
                if not callable(inst_attr):
                    return False
                try:
                    signature = _SignatureOf(getattr(origin, attr), strip_self=True)
                except ValueError:
                    continue
                if not is_subtype_instance(inst_attr, signature):
                    return False
            else:
                raise AssertionError(f"Unexpected protocol member {attr} for {origin}")
        return True

    if isinstance(origin, _SignatureOf):
        try:
            inst_sig = _SignatureOf(inst)
        except ValueError:
            return True
        return is_subtype(inst_sig, origin)

    # We're done handling special forms, now just need to handle things like generics
    if not isinstance(origin, type):
        # TODO: handle other special forms before exit on this branch
        return False
    if not isinstance(inst, origin):
        # PEP 484 duck type compatibility
        if origin is complex and isinstance(inst, (int, float)):
            return True
        if origin is float and isinstance(inst, int):
            return True
        if origin is bytes and isinstance(inst, (bytearray, memoryview)):
            # TODO: maybe remove bytearray and memoryview ducktyping based on PEP 688
            return True
        if inst in typing_Never:
            return True
        if issubclass(type(inst),
typing_extensions.Any) or (
            sys.version_info >= (3, 11) and issubclass(type(inst), typing.Any)
        ):
            return True
        return False

    assert isinstance(inst, origin)
    if not args:
        return True

    # TODO: there's some confusion when checking issubclass against a generic collections.abc
    # base class, since you don't actually know whether the generic args of typ / origin correspond
    # to the generic args of the base class. So if we detect a user defined generic (i.e. based
    # on presence of Generic in the mro), we just fall back and don't assume we know the semantics
    # of what the generic args are.
    if issubclass(origin, collections.abc.Mapping) and typing.Generic not in origin.__mro__:
        key_type, value_type = args
        return all(
            is_subtype_instance(key, key_type) and is_subtype_instance(value, value_type)
            for key, value in inst.items()
        )
    if origin is tuple:
        if len(args) == 2 and args[1] is ...:
            return all(is_subtype_instance(i, args[0]) for i in inst)
        if len(inst) != len(args):
            return False
        return all(is_subtype_instance(i, t) for i, t in zip(inst, args))
    if issubclass(origin, collections.abc.Iterable) and typing.Generic not in origin.__mro__:
        (item_type,) = args
        return all(is_subtype_instance(item, item_type) for item in inst)
    if origin is type:
        (type_type,) = args
        return issubclass(inst, type_type)
    if origin is collections.abc.Callable:
        try:
            inst_sig = inspect.signature(inst)
        except ValueError:
            return True
        *params, ret = args
        if params != [...]:
            try:
                bound = inst_sig.bind(*params)
            except TypeError:
                return False
            for param, callable_param_type in bound.arguments.items():
                param = inst_sig.parameters[param]
                param_annot = maybe_eval_in_context(param.annotation, inst)
                # ooh, contravariance
                if param.kind is param.VAR_POSITIONAL:
                    if any(not is_subtype(cpt, param_annot) for cpt in callable_param_type):
                        return False
                elif not is_subtype(callable_param_type, param_annot):
                    return False
        if inst_sig.return_annotation is not inst_sig.empty:
            ret_annot = maybe_eval_in_context(inst_sig.return_annotation,
inst) # inspect.signature(Cls) will have Cls.__init__, which is annotated as -> None if not (isinstance(inst, type) and ret_annot is None and is_subtype(inst, ret)): if ret_annot is None: ret_annot = type(None) elif ret_annot in typing_Never: ret_annot = ret if ret_annot != ret and not is_subtype(ret_annot, ret): return False return True # We don't really know how to handle user defined generics if hasattr(inst, "__orig_class__"): # If we have an __orig_class__ and the origins match, check the args (assuming that they # are invariant, although maybe covariant is a better guess?) if inst.__orig_class__.__origin__ is origin: return inst.__orig_class__.__args__ == args # Otherwise, fail open return True # TODO: overloads # TODO: paramspec / concatenate # TODO: typeguard # TODO: annotated # TODO: self # TODO: pep 692 unpack # TODO: typevartuple?? def simplified_union(types): if len(types) == 0: return typing.Never if len(types) == 1: return types[0] union_types = [] for typ in types: if getattr(typ, "__args__", None) is None and any( is_subtype(typ, member) for member in union_types ): continue union_types.append(typ) types = union_types union_types = [] for typ in reversed(types): if getattr(typ, "__args__", None) is None and any( is_subtype(typ, member) for member in union_types ): continue union_types.append(typ) return functools.reduce(operator.or_, union_types) def _simplistic_type_of_value(value: object) -> TypeForm: # TODO: maybe remove this? 
Its current use is in diagnostics (for providing the actual type), # but is_subtype_instance is in a position to provide better diagnostics if hasattr(type(value), "__class_getitem__"): if isinstance(value, collections.abc.Mapping) and typing.Generic not in type(value).__mro__: return type(value)[ simplified_union([_simplistic_type_of_value(k) for k in value.keys()]), simplified_union([_simplistic_type_of_value(v) for v in value.values()]), ] if isinstance(value, tuple): if len(value) <= 10: return type(value)[tuple(_simplistic_type_of_value(v) for v in value)] return type(value)[simplified_union([_simplistic_type_of_value(v) for v in value]), ...] if ( isinstance(value, collections.abc.Iterable) and typing.Generic not in type(value).__mro__ ): return type(value)[simplified_union([_simplistic_type_of_value(v) for v in value])] if isinstance(value, type): return type[value] return type(value) ================================================ FILE: chz/universal.py ================================================ if __name__ == "__main__": import chz chz.entrypoint(object) ================================================ FILE: chz/util.py ================================================ class MISSING_TYPE: def __repr__(self) -> str: return "MISSING" MISSING = MISSING_TYPE() ================================================ FILE: chz/validators.py ================================================ from __future__ import annotations import collections import collections.abc import re from typing import Any, Callable, Literal import chz from chz.field import Field from chz.tiepin import _simplistic_type_of_value, is_subtype_instance, type_repr class validate: def __init__(self, fn: Callable[[Any], None]): self.fn = fn def __set_name__(self, owner: Any, name: str) -> None: _ensure_chz_validators(owner) owner.__chz_validators__.append(self.fn) setattr(owner, name, self.fn) def _ensure_chz_validators(cls: Any) -> None: if "__chz_validators__" not in cls.__dict__: # make a copy 
of the parent's validators, if any validators: list[Callable[[object], None]] = [] for base in cls.__bases__: validators.extend(getattr(base, "__chz_validators__", [])) cls.__chz_validators__ = validators def for_all_fields(fn: Callable[[Any, str], None]) -> Callable[[Any], None]: def inner(self: Any) -> None: for field in self.__chz_fields__.values(): fn(self, field.x_name) return inner def instancecheck(self: Any, attr: str) -> None: """A good old fashioned isinstance check based on the annotated type of the field.""" typ = self.__chz_fields__[attr.removeprefix("X_")].final_type value = getattr(self, attr) if not isinstance(value, typ): raise TypeError(f"Expected {attr} to be {type_repr(typ)}, got {type_repr(type(value))}") def typecheck(self: Any, attr: str) -> None: """A fancy type check based on the annotated type of the field.""" field = self.__chz_fields__[attr.removeprefix("X_")] typ = field.x_type value = getattr(self, attr) if not is_subtype_instance(value, typ): # TODO: is_subtype_instance is in a much better place to return diagnostics if getattr(typ, "__origin__", None) is Literal: raise TypeError(f"Expected {attr} to be {type_repr(typ)}, got {value!r}") raise TypeError( f"Expected {attr} to be {type_repr(typ)}, got {type_repr(_simplistic_type_of_value(value))}" ) def instance_of(typ: type) -> Callable[[Any, str], None]: """Check the attribute is an instance of the given type.""" def inner(self: Any, attr: str) -> None: value = getattr(self, attr) if not isinstance(value, typ): raise TypeError(f"Expected {attr} to be {type_repr(typ)}, got {type_repr(type(value))}") return inner def gt(base) -> Callable[[Any, str], None]: def inner(self: Any, attr: str) -> None: value = getattr(self, attr) if not value > base: raise ValueError(f"Expected {attr} to be greater than {base}, got {value}") return inner def lt(base) -> Callable[[Any, str], None]: def inner(self: Any, attr: str) -> None: value = getattr(self, attr) if not value < base: raise 
ValueError(f"Expected {attr} to be less than {base}, got {value}") return inner def ge(base) -> Callable[[Any, str], None]: def inner(self: Any, attr: str) -> None: value = getattr(self, attr) if not value >= base: raise ValueError(f"Expected {attr} to be greater or equal to {base}, got {value}") return inner def le(base) -> Callable[[Any, str], None]: def inner(self: Any, attr: str) -> None: value = getattr(self, attr) if not value <= base: raise ValueError(f"Expected {attr} to be less or equal to {base}, got {value}") return inner def valid_regex(self: Any, attr: str) -> None: """Check the attribute is a valid regex.""" import re value = getattr(self, attr) try: re.compile(value) except re.error as e: raise ValueError(f"Invalid regex in {attr}: {e}") from None def const_default(self: Any, attr: str) -> None: """Check the attribute matches the field's default value.""" from chz.util import MISSING_TYPE field: Field = self.__chz_fields__[attr.removeprefix("X_")] default = field._default if isinstance(default, MISSING_TYPE): raise ValueError( "const_default requires a default value (default_factory is not supported)" ) value = getattr(self, attr) if value != default: raise ValueError(f"Expected {attr} to match the default {default!r}, got {value!r}") def _decorator_typecheck(self: Any) -> None: for field in self.__chz_fields__.values(): typecheck(self, field.x_name) # TODO: typecheck(self, field.logical_name) def check_field_consistency_in_tree(obj: Any, fields: set[str], regex_root: str = "") -> None: """ This isn't itself a validator. See test_validate_field_consistency for example usage. This is effectively a way to paper over a potential missing feature in chz. 
""" values: dict[tuple[str, str], dict[object, list[str]]] = collections.defaultdict( lambda: collections.defaultdict(list) ) def inner(obj: Any, obj_path: str): assert chz.is_chz(obj) for f in obj.__chz_fields__.values(): value = getattr(obj, f.logical_name) field_path = f"{obj_path}.{f.logical_name}" if obj_path else f.logical_name regex_match = re.search(regex_root, obj_path) if f.logical_name in fields and regex_match: values[(regex_match.group(), f.logical_name)][value].append(field_path) if chz.is_chz(value): inner(value, field_path) if isinstance(value, collections.abc.Mapping): for k, v in value.items(): if chz.is_chz(v): inner(v, f"{field_path}.{k}") elif isinstance(value, collections.abc.Sequence): for i, v in enumerate(value): if chz.is_chz(v): inner(v, f"{field_path}.{i}") inner(obj, "") def paths_repr(paths: list[str]) -> str: if len(paths) <= 3: return ", ".join(paths) return ", ".join(paths[:3]) + f", ... ({len(paths) - 3} more)" for (_, field), value_to_paths in values.items(): if len(value_to_paths) > 1: raise ValueError( f"Field {field!r} has inconsistent values in object tree:\n" + "\n".join( f"{value!r} at {paths_repr(paths)}" for value, paths in value_to_paths.items() ) ) def _find_original_definitions(instance: Any) -> dict[str, tuple[Field, type]]: """Find the original field definitions in the parent classes of the instance.""" assert chz.is_chz(instance) fields = {} for cls in reversed(type(instance).__mro__): if not chz.is_chz(cls): continue for field in chz.chz_fields(cls).values(): if field.logical_name not in fields: fields[field.logical_name] = (field, cls) return fields def is_override( instance: Any, attr: str, *, original_defs: dict[str, tuple[Field, type]] | None = None ) -> None: """ Validator that checks if a field is an override of a field of the same type in a parent class. 
This validator will error out if either: - the field doesn't exist in any parent - the type of field on the child is not a subtype of the type of the field on the parent This is especially useful in case someone renames a field name in the parent class. You'll get an error message rather than your override being silently ignored. """ if original_defs is None: original_defs = _find_original_definitions(instance) logical_name = attr.removeprefix("X_") assert logical_name in original_defs original_field, original_class = original_defs[logical_name] if original_class is type(instance): raise ValueError( f"Field {logical_name} does not exist in any parent classes of {type_repr(type(instance))}" ) instance_value = getattr(instance, attr) if not chz.tiepin.is_subtype_instance(instance_value, original_field.final_type): raise ValueError( f"{type_repr(type(instance))}.{attr}' must be an instance of " f"{type_repr(original_field.final_type)} to match the type on the original definition " f"in {type_repr(original_class)}" ) class IsOverrideMixin: """A mixin that checks if fields are overrides of fields in parent classes. The following: ``` @chz.chz class Foo: x: int = chz.field(default=1, validator=is_override) y: int = chz.field(default=1, validator=is_override) ``` is equivalent to: ``` @chz.chz class Foo(IsOverrideMixin): x: int = chz.field(default=1) y: int = chz.field(default=1) ``` """ @validate def _check_overrides(self) -> None: fields = getattr(self, "__chz_fields__", None) if fields is None: return original_defs = _find_original_definitions(self) for field in fields.values(): is_override(self, field.x_name, original_defs=original_defs) ================================================ FILE: docs/01_quickstart.md ================================================ ## Quick start Turn any function into a command line tool: ```python import chz def main(name: str, age: int) -> None: print(f"Hello, {name}! 
You are {age} years old.") if __name__ == "__main__": chz.entrypoint(main) # python script.py name=foo age=21 ``` Or instantiate a class containing your configuration: ```python import chz @chz.chz class PersonConfig: name: str age: int def main(c: PersonConfig) -> None: print(f"Hello, {c.name}! You are {c.age} years old.") if __name__ == "__main__": chz.nested_entrypoint(main) # python script.py name=foo age=21 ``` ### [Next section — Object Model](./02_object_model.md) ================================================ FILE: docs/02_object_model.md ================================================ ## Declarative object model In the beginning there was `attrs`... although people may be more familiar with its stripped down nephew, `dataclasses`. `chz` continues in the same tradition. This should feel familiar: ```python @chz.chz class Experiment: name: str steps: int checkpoint_dir: str = "az://oai/default" ``` A quick comparison to `dataclasses`: - `chz` is not meant as a better `class`, but as a solution for configuration. It is opinionated and specialised in various ways that `dataclasses` is not. - `chz` has exclusively keyword-only fields. This is generally saner and solves various problems with `dataclasses`, especially in situations involving inheritance. - `chz` is immutable only. Configuration should not be mutable. `chz` supports partial application in ways that should hopefully obviate the need for mutable configuration (as we'll see later); you can also `chz.replace` to get a new object. - `chz`'s implementation makes different tradeoffs ## Fields, and specifying defaults `chz` lets you customise the fields of your objects using the `chz.field` function. 
The following example shows different ways you can specify the default value for a field:

```python
@chz.chz
class Experiment:
    name: str
    steps: int

    # directly assign a default value, useful for simple, immutable types
    checkpoint_dir: str = "az://oai/default"

    # via the `default` argument to `chz.field`, useful if you need to customise your field
    # (like hiding it from the repr), but still have a default
    password: str = chz.field(default="hunter2", repr=False)

    # via the `default_factory` argument to `chz.field`, useful if the default is mutable or
    # expensive to compute
    dataset: list[str] = chz.field(default_factory=download_all_of_wikipedia, doc="A dataset!")
```

See [`chz.field` docs](./22_field_api.md) for more details.

## Immutability

`chz` objects are immutable. This is a deliberate and non-negotiable design choice:

```python
e = Experiment(name="train_job", steps=100)
e.checkpoint_dir = "az://this/wont/work"  # raises FrozenInstanceError
```

That said, there are a couple of patterns that are useful.

If you need to compute derived data from existing fields, use `@chz.init_property` (or `@property` or `@functools.cached_property`):

```python
import re

import chz

@chz.chz
class Experiment:
    name: str
    steps: int

    @chz.init_property
    def log_path(self) -> str:
        # strip everything except ASCII letters from the name
        return re.sub(r"[^a-zA-Z]", "", self.name)
```

`chz.init_property` works exactly like `functools.cached_property`, except that it is automatically accessed during initialisation. This surfaces errors more reliably. Think of this as a replacement for `dataclasses.field(init=False)`.

For complex initialisation logic, `chz` has a [`Blueprint` mechanism](./05_blueprint.md#blueprints-and-partial-application) that is really powerful. This allows you to accomplish things like partial application, where you only specify some of your attributes at a time, or type aware parsing.

Note that if you already have a `chz` object and you want to replace a field on it, you can use `chz.replace`; this works similarly to `dataclasses.replace`.
## No `__post_init__` Note that `chz` does **not** have a `__post_init__` equivalent. If you wanted a `__post_init__` to do additional validation, `chz` has first-class support for validation. See [validation](./03_validation.md) for details. If you need arbitrary logic to determine a default value, consider using `default_factory`. If you need to munge your field based on the value of other fields, consider using `@property` to do something equivalent, or a `munger`. See [details and examples](./21_post_init.md) for more guidance with this use case. The details document also describes the "magix prefix" mechanism (`X_`) you may see in use with `chz`. ### [Next section — Validation](./03_validation.md) ================================================ FILE: docs/03_validation.md ================================================ ## Validation `chz` supports validation in a manner similar to `attrs`, but slightly nicer for class-level validation. `chz` supports both field-level validation and class-level validation. ```python from chz.validators import typecheck, gt @chz.chz class Fraction: # specify a validator for a given field numerator: int = chz.field(validator=typecheck) # or even multiple validators for a field! denominator: int = chz.field(validator=[typecheck, gt(0)]) # class-level validator that can check multiple fields @chz.validate def _check_reduced(self): if math.gcd(self.numerator, self.denominator) > 1: raise ValueError("Fraction is not reduced") Fraction(numerator="asdf", denominator=4) # raises TypeError: Expected numerator to be int, got str Fraction(numerator=2, denominator=0) # raises ValueError: Expected denominator to be greater than 0, got 0 Fraction(numerator=2, denominator=4) # raises ValueError: Fraction is not reduced Fraction(numerator=1, denominator=2) # works great! ``` Validation happens as part of the generated `__init__`. 
All `@chz.init_property` defined on your class will also be accessed at `__init__` time, ensuring that any errors raised when computing those properties are surfaced early.

## Type checking

`chz` is usable alongside static type checking. It also contains some facilities to do runtime type checking.

`chz` does not currently default to doing runtime type checking. The upsides are limited, since:
- `chz` has powerful, type-aware command line parsing
- `chz` can be understood by static type checkers

However, runtime type checking has several downsides: it's slow; it's not actually sound, so it cannot be a substitute for a static type checker; and it impedes certain kinds of interesting metaprogramming. It's also less clear how one would opt out of runtime type checking than it is to opt in (just add a validator).

Unlike `pydantic`, `chz` does not do implicit casting. I find implicit casting to be a huge footgun. Python is a strongly typed language and this is for the better. `chz` does allow for some forms of explicitly opted-in casting, as part of [the `Blueprint` mechanism](./05_blueprint.md#blueprints-and-partial-application).

With all that said, it remains easy to add runtime type checking! We saw an example of this on a per-field basis above, but here's how to easily do this for all fields in a class:

```python
@chz.chz(typecheck=True)
class TypeCheckedAlphabet:
    alpha: int
    beta: str
    gamma: bytes

# This is approximately equivalent to adding the following validator:
# @chz.validate
# def typecheck_all_fields(self):
#     from chz.validators import for_all_fields, typecheck
#     for_all_fields(typecheck)(self)
```

`chz`'s runtime type checking is also quite advanced and better in several respects than other open source libraries.

## Validation and inheritance

`chz`'s validation works as expected in the presence of inheritance: both class-level and field-level validators are inherited by the child class.
There is one caveat: if you clobber a field in a child class, you will also clobber any field-level validator specified in a parent class for that field, unless you explicitly respecify it. `chz` currently does not allow overriding validators in subclasses. This is because it would represent a Liskov substitution principle violation (and use cases are niche). If you need this, have your validator call some other method which you can then freely override. `chz` has some built-in validation, for instance, ensuring that fields do not clobber methods or properties defined on the parent class, etc. ### [Next section — Command line](./04_command_line.md) ================================================ FILE: docs/04_command_line.md ================================================ ## Command line parsing Type aware CLIs are really great. These let you focus on writing code, with types, and you get a CLI for free. `chz` gives this to you as well: ```python def launch_run(name: str, steps: int, checkpoint_dir: str = "az://oai/default"): ... if __name__ == "__main__": chz.entrypoint(launch_run) # The command line: # name=foo steps=100 # becomes: # launch_run(name="foo", steps=100) ``` `chz` will also let you parse into an object: ```python @chz.chz class Experiment: name: str steps: int checkpoint_dir: str = "az://oai/default" def main(): experiment = chz.entrypoint(Experiment) ... if __name__ == "__main__": main() # The command line: # name=foo steps=100 checkpoint_dir=az://oai/somewhere # becomes: # experiment = Experiment(name="foo", steps=100, checkpoint_dir="az://oai/somewhere") ``` If you have a `main` function that takes a single argument that's a `chz` object, you can have it serve as an entrypoint by using `chz.nested_entrypoint`: ```python def main(experiment: Experiment): ... if __name__ == "__main__": chz.nested_entrypoint(main) ``` All of this is pretty straightforward. 
Here's a case `chz` handles that's a bit more interesting: handling parsing when some of your fields are nested `chz` objects.

```python
@chz.chz
class Model:
    encoding: str

@chz.chz
class Experiment:
    name: str
    steps: int
    model: Model

# name=foo steps=100 model.encoding=gpt2
# becomes
# Experiment(name="foo", steps=100, model=Model(encoding="gpt2"))
```

## Hyphens

If you like `--hyphens` for your command line arguments, just use `chz.entrypoint(..., allow_hyphens=True)` or similar. But the zero hyphen life is pretty great once you get used to it.

## Polymorphic construction

Here's something only `chz` lets you do... what if you want `Model` to be polymorphic? That is, maybe you want to be able to specify `model=Transformer` or `model=Diffusion`. Or maybe you want to construct a field by calling some arbitrary factory function.

```python
import chz

def wikipedia_text(seed: int) -> Dataset:
    """Function that produces a dataset."""

@chz.chz
class Model:
    encoding_name: str = "gpt2"

@chz.chz
class Transformer(Model):
    n_layers: int = 1000
    d_model: int = 100000

@chz.chz
class Experiment:
    model: Model
    dataset: Dataset

experiment = chz.entrypoint(Experiment)

# The command line:
# model=Transformer model.n_layers=10 dataset=wikipedia_text dataset.seed=217
# becomes:
# Experiment(model=Transformer(n_layers=10), dataset=wikipedia_text(seed=217))
```

This is really powerful. Recursive polymorphic construction allows you to separate concerns more clearly and helps reduce boilerplate.

If you're familiar with the "callable" pattern we sometimes end up with in our frameworks (e.g. "dataset_callable") -- this enables having an interface that isn't a complete kludge.

This encourages modularity and allows for easy dependency injection. It is common for other libraries in this space to end up infecting all of your code; polymorphic construction helps you avoid this.

In other words, many tools will let you construct an `X` by specifying `...` to feed to `X(...)`.
But chz lets you construct an `X` by specifying both callee and arguments in `...(...)` Anyway, hopefully this all adds up to fewer 1000 line launch scripts or registries or horrible interfaces for parametrising datasets when using `chz`. This is probably the primary interesting feature in `chz`. ## Wildcards It can be a little tiresome specifying fully qualified paths for every field you want to set. To aid with this, `chz` supports wildcards in your blueprint arguments using "...". For example: ``` model=Transform ...encoding=gpt2 model...activation_fn=gelu ``` This will set `encoding` on all nested objects that take an `encoding` argument, and set `activation_fn` on all nested objects inside of `model` that take an `activation_fn` argument. Note that wildcards can match multiple (potentially nested) fields. Wildcard use is somewhat discouraged, particularly so outside of a command line context. ## Discoverability, `--help`, and errors Programs that use `chz.entrypoint` also get you a reasonable `--help` out of the box. ``` $ python script.py --help WARNING: Missing required arguments for parameter(s): dataset Entry point: Experiment Arguments: model Model Model (meta_factory) model.encoding_name str 'gpt2' (default) dataset Dataset - ``` One important note about `--help`: polymorphic construction means that arguments you specify can change the set of arguments you need to specify. For instance, in the above example, `model=Transformer` will allow you to also specify `model.n_layers` and `model.d_model`. However, passing `--help` to a `chz` script along with arguments, will show you all the arguments you can specify given the arguments you've already specified. That is, passing `model=Transformer --help` will show `model.n_layers` and `model.d_model` in the output. 
``` $ python script.py model=Transformer --help WARNING: Missing required arguments for parameter(s): dataset Entry point: Experiment Arguments: model Model Transformer (from command line) model.encoding_name str 'gpt2' (default) model.n_layers int 1000 (default) model.d_model int 100000 (default) dataset Dataset - ``` Note that `--help` will also show you the mapping of arguments you specify to fields: ``` $ python script.py model=Transformer ...n_layers=10 model.encoding_name=cl100k_base --help WARNING: Missing required arguments for parameter(s): dataset Entry point: Experiment Arguments: model Model Transformer (from command line) model.encoding_name str cl100k_base (from command line) model.n_layers int 10 (from command line) model.d_model int 100000 (default) dataset Dataset - ``` If you misspell an argument, `chz` will tell you what you probably meant. This fuzzy detection logic works well even for wildcard arguments. ``` chz.blueprint.ExtraneousBlueprintArg: Extraneous Blueprint argument 'modell' for __main__.Experiment Did you mean 'model'? ``` Finally, while `chz` allows you to clobber arguments, it does not allow arguments to go completely unused. This is important for sanity, but somehow a common bug in some CLI libraries, like `fire`. ## Variadic parameters chz supports polymorphic construction through variadic parameters. This works for lists, tuples, dicts (with str keys) and `TypedDict`s: ```python @chz.chz class Eval: name: str @chz.chz class Experiment: evals: list[Eval] experiment = chz.entrypoint(Experiment) # The command line: # evals.0.name=foo evals.1.name=bar # becomes: # Experiment(evals=[Eval(name="foo"), Eval(name="bar")]) ``` Variadic parameters can also be polymorphic, for instance, you could do: ```python # evals.0=EvalSubclass evals.0.name=foo evals.1.name=bar # becomes: # Experiment(evals=[EvalSubclass(name="foo"), Eval(name="bar")]) ``` ## `Blueprint`s, briefly We'll talk about `Blueprint`s more in the next section. 
For now, all you need to know is that the `Blueprint` class is the API that powers `chz`'s command line functionality. The `chz.entrypoint` function we saw above is basically doing: ``` def entrypoint(target: Callable[..., _T]) -> _T: return Blueprint(target).make_from_argv(sys.argv[1:]) ``` ## Casting The arguments you provide on the command line are strings. However, `chz` wants to give you your arguments with the correct type. By default, `chz` will try to cast your arguments for you to the correct type. This casting is a process you may wish to customise. The first method is by attaching a `__chz_cast__` classmethod to the target type. ```python @dataclass class Duration: seconds: int @classmethod def __chz_cast__(cls, value: str): try: return Duration(int(value.strip("hms")) * {"h": 3600, "m": 60, "s": 1}[value[-1]]) except Exception as e: raise CastError(f"Could not cast {value!r} to {cls.__name__}") from e @chz.chz class Args: t: Duration assert chz.Blueprint(Args).apply({"t": Castable("1h")}).make() == Args(t=Duration(3600)) ``` In the above, since `Duration` is used in the annotation, `chz` will attempt to use `Duration.__chz_cast__` to cast `Castable("1h")` to the correct type. The second method is by specifying a per-field function to `blueprint_cast` via `chz.field`: ```python def cast_binary(value: str) -> int: try: return int(value, 2) except Exception as e: raise CastError(f"Could not cast {value!r} to binary") from e @chz.chz class Args: binary: int = chz.field(blueprint_cast=cast_binary) assert chz.Blueprint(Args).apply({"binary": Castable("101")}).make() == Args(binary=5) ``` Field level casts will override the `__chz_cast__` method if both are applicable. Casting only applies to `Blueprint` (not `__init__` of your `chz` class), and only if the value passed to the `Blueprint` is a `Castable`. Python is a strongly typed language, this is a good thing, `chz` will not change your types willy nilly. 
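The precedence described above (a field-level cast wins over the type's `__chz_cast__` hook, and otherwise a default conversion is tried) can be illustrated with a toy model that does not import `chz` at all. This is a sketch under stated assumptions, not `chz`'s implementation: `cast_string`, its `field_cast` parameter, and the local `CastError` class are hypothetical names introduced here, and the final fallback of calling the target type on the string is an assumption.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


class CastError(Exception):
    """Local stand-in for the error raised by a failed cast (hypothetical)."""


@dataclass
class Duration:
    seconds: int

    @classmethod
    def __chz_cast__(cls, value: str) -> "Duration":
        units = {"h": 3600, "m": 60, "s": 1}
        try:
            return cls(int(value[:-1]) * units[value[-1]])
        except (KeyError, ValueError, IndexError) as e:
            raise CastError(f"Could not cast {value!r} to {cls.__name__}") from e


def cast_string(
    value: str,
    target: type,
    field_cast: Optional[Callable[[str], Any]] = None,
) -> Any:
    # 1. A per-field cast function takes priority when one is given.
    if field_cast is not None:
        return field_cast(value)
    # 2. Otherwise use the hook on the target type, if it defines one.
    hook = getattr(target, "__chz_cast__", None)
    if hook is not None:
        return hook(value)
    # 3. Assumed fallback: call the type on the string, e.g. int("5").
    try:
        return target(value)
    except (TypeError, ValueError) as e:
        raise CastError(f"Could not cast {value!r} to {target.__name__}") from e


def cast_binary(value: str) -> int:
    return int(value, 2)


one_hour = cast_string("1h", Duration)
five = cast_string("101", int, field_cast=cast_binary)
plain = cast_string("101", int)
```

The same string `"101"` yields different integers depending on whether a field-level cast is supplied, which is why the text stresses that casting is always explicitly opted into.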
### CLI from a class `chz` lets you easily create a script entrypoint based on the methods on a class using `chz.methods_entrypoint`. For example, given main.py: ```python import chz @chz.chz class Run: name: str def launch(self, cluster: str): "Launch a job on a cluster" return ("launch", self, cluster) if __name__ == "__main__": print(chz.methods_entrypoint(Run)) ``` Try out the following command line invocations: ``` python main.py launch self.name=job cluster=owl python main.py launch --help python main.py --help ``` Note that you can rename the `self` argument in your method to something else. ### Universal CLI ```python import chz chz.entrypoint(object) ``` This script probably isn't actually directly useful, but just to show you the power of `chz`, it will let you call most functions or create most objects. Try: - `python -m chz.universal '=print' '0=hello' '1=lambda name: name + "!"' '1.name=world'` - `python -m chz.universal '=calendar:Calendar' --help` See e.g. `test_root_polymorphism` for how you might actually want to use this. ### [Next section — Blueprints](./05_blueprint.md) ================================================ FILE: docs/05_blueprint.md ================================================ ## Blueprints and partial application `chz` has a `Blueprint` mechanism that powers the command line functionality. `chz.entrypoint` is just a thin wrapper around `Blueprint`. The `Blueprint` mechanism is a Python interface that allows you to do advanced initialisation of objects. In particular, it enables partial application of arguments to a `Blueprint`. Since `chz` objects are immutable, this can be a good substitute for a complex initialisation procedure that relies on mutability. 
```python
blueprint = chz.Blueprint(Experiment)
# Note that apply modifies the blueprint in place, use blueprint.clone() to make a copy
blueprint.apply({"encoding_name": "gpt2", "...n_layers": 100})
blueprint.apply({"model": Transformer})
blueprint.apply({"model.n_layers": 10_000})
blueprint.apply({"model.n_layers": Castable("10_000")})
experiment = blueprint.make()
# experiment = Experiment(model=Transformer(n_layers=10_000), encoding_name="gpt2")
```

Partial application is lazy and non-destructive. In particular, if you do something incorrect, you will only get errors when you actually try to instantiate, via `make`.

If for some reason you need the type aware casting logic that you get via the command line, you can opt in to it when using `Blueprint.apply` by wrapping your value in `Castable`, e.g. `blueprint.apply({"n_layers": Castable("100")})`.

Note that if you already have a `chz` object and you want to replace a field on it, you can use `chz.replace`; this works similarly to `dataclasses.replace`.

## Blueprint polymorphism recap

Roughly, the core idea of polymorphic construction is that instead of only being able to assign values to fields, you can also assign the return values of a call to fields:

If `chz` sees `field=value`, this is similar to `X(field=value)`.

But if `chz` sees `field=value field.a=1 field.b=2`, this is similar to `X(field=value(a=1, b=2))`.

For a full explanation of the Blueprint algorithm, see [Blueprint Algorithm](#blueprint-algorithm).

### Discovery and interpretation of valid polymorphic values

When figuring out which class you mean to instantiate when you do `model=Transformer`, `chz` will look at all currently created subclasses of `Model` to find the right one.

When calling functions, `chz` will look at all the functions in the module of the relevant config.

You can also specify a fully qualified path like `module:ClassName` / `package.module:function` and `chz` will import and find your object.
This can let you avoid ambiguity or reliance on import time side effects.

This discovery process can also be customised via `meta_factory`. This is an advanced feature, see `chz/factories.py` for more details.

TODO: talk about some of the more advanced tricks here (i.e. look at some of the things from `test_factories.py`)

## `blueprint_unspecified`

This is easiest to understand by example.

```python
@chz.chz
class Model: ...

@chz.chz
class Transformer(Model): ...

@chz.chz
class Experiment:
    model: Model = chz.field(blueprint_unspecified=Transformer)
```

Say you have an entrypoint that can run an experiment on an arbitrary model. But in practice, you mostly want to run experiments on `Transformer`s. Rather than force your users to specify `model=Transformer` every time, you can use `blueprint_unspecified` to specify what `chz` should attempt to polymorphically construct if there isn't an argument specified.

#### Confusion about `blueprint_unspecified` and `default/default_factory`

Users of `chz` are commonly confused by the relationship between `blueprint_unspecified` and `default/default_factory`. There is no relationship!

`Blueprint` will **never** look at the value of `default/default_factory`. The primary interaction with `Blueprint`s is that the presence or absence of `default/default_factory` marks an argument as optional or required.

When in doubt, I recommend not using `default/default_factory` for fields you wish to polymorphically construct.

(One could ask why `chz` doesn't attempt to infer `blueprint_unspecified` from `default/default_factory`. This is a good question, but has a longer answer than is worth going into here)

## Presets or shared configuration

Partial application gives you the ability to add presets. For example, consider a typical experiment command line:

```
                   ⤹ preset name
python main.py small_gpt seed=217 name=just_a_lil_guy
               ~~~~~~~~~
```

You could mimic this with something like:

```python
import sys

@chz.chz
class Experiment: ...

presets: dict[str, chz.Blueprint] = {
    "small_gpt": chz.Blueprint(Experiment).apply(
        {"seed": 0, "model": Transformer, "model.n_layers": 4},
        layer_name="small gpt preset",
    ),
    ...
}

def main():
    preset, *argv = sys.argv[1:]
    blueprint = presets[preset].clone()
    experiment = blueprint.make_from_argv(argv)
```

The layer name is a subtle thing that's quite important, since adding `--help` to any command line will show you exactly where each value being used is coming from:

```
Arguments:
  model                Model  Transformer (from small gpt preset)
  model.encoding_name  str    'gpt2' (default)
  model.n_layers       int    4 (from small gpt preset)
  ...
```

I will add built-in support for presets to `chz` some day. For now, add your own extensions to manipulate `Blueprint`s.

## Custom tooling

The `Blueprint` APIs are powerful. At OpenAI, there are a number of interesting custom tools that build on top of the `Blueprint` APIs.

In particular, take a look at `Blueprint._make_lazy`. It's also worth familiarising yourself with the `_ArgumentMap` class.

Don't be scared by the underscores. Just add tests for the extensions you write.

## Undocumented Blueprint features

There are a number of powerful `Blueprint` features that are not yet documented. The good news is they all have tests that demonstrate their usage.

I mention this here because if you hit some case you would like to express, there may already be a way to express it.

## Blueprint algorithm

The source code is of course the best source of truth. Very very roughly, the algorithm is:

1. Blueprint arguments are "layers" of dicts from arguments (possibly wildcard) to value provided.
2. For a given parameter `foo.bar`, find the latest layer that has an argument matching the `foo.bar` parameter.
3. If there is no matching argument, check to see if we can call something to construct the value.
    1. Check to see if there is a callable specified by `blueprint_unspecified`
    2. Otherwise, use `chz`'s best guess (if `chz` has one)
    3. Attempt to call this function, with recursive discovery of parameters.
    4. If this doesn't work out, we'll use `default/default_factory` if it exists, if not, we'll error for missing a required argument.
4. If there is such an argument, we now attempt to use it!
5. Check if it's a valid value for the parameter (or is a `Castable` that can be cast to the correct type). This is done by checking if the value is of the right type and if there are no additional subarguments specified.
6. Otherwise, attempt to use the value as a callable we can call to construct the value (or a `Castable` that can be cast to a callable).

### [Next section — Serialisation](./06_serialisation.md)

================================================
FILE: docs/06_serialisation.md
================================================
## Serialisation and deserialisation

`chz` will one day have a great story for versioned serialisation and deserialisation. The main obstacle is that I'm busy and cowardly.

Note also that it's easy to roll your own *un*versioned serialisation and deserialisation. There are two utility functions in `chz` that you may find useful: `chz.beta_to_blueprint_values` and `chz.asdict`.

Example:

```python
import chz

@chz.chz
class P:
    x: float
    y: float

@chz.chz
class C:
    s: str
    p: P

obj = C(s="foo", p=P(x=1.0, y=2.0))

print(chz.beta_to_blueprint_values(obj))
# {'s': 'foo', 'p': <class '__main__.P'>, 'p.x': 1.0, 'p.y': 2.0}

print(chz.Blueprint(type(obj)).apply(chz.beta_to_blueprint_values(obj)).make())
# C(s='foo', p=P(x=1.0, y=2.0))

print(chz.asdict(obj))
# {'s': 'foo', 'p': {'x': 1.0, 'y': 2.0}}
```
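The two printed shapes above (nested dict and dotted keys) are closely related. As a rough plain-Python illustration of that relationship, the sketch below flattens a nested dict into dotted keys and round-trips the nested form through JSON. The `flatten` helper is hypothetical and not part of `chz`; unlike `chz.beta_to_blueprint_values`, it does not record which class each nested object was built from.

```python
import json


def flatten(nested: dict, prefix: str = "") -> dict:
    """Turn {'p': {'x': 1.0}} into {'p.x': 1.0} (hypothetical helper, not a chz API)."""
    flat = {}
    for key, value in nested.items():
        path = f"{prefix}{key}"
        if isinstance(value, dict):
            # Recurse into nested dicts, extending the dotted path.
            flat.update(flatten(value, prefix=f"{path}."))
        else:
            flat[path] = value
    return flat


# Same shape as the `chz.asdict` output shown above.
as_dict = {"s": "foo", "p": {"x": 1.0, "y": 2.0}}

flat = flatten(as_dict)

# Plain dicts of primitives can be written out as JSON and read back unchanged.
encoded = json.dumps(as_dict)
decoded = json.loads(encoded)
```

This only covers the unversioned case the text mentions; anything polymorphic would also need the class information that the dotted-key form carries.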
<details>
<summary>Thoughts on pickle</summary>

Pickle is actually totally fine here, if you don't need human readability. `chz` is powerful enough that the ability to execute arbitrary code when deserialising is mostly going to be the same as `pickle`'s.

The other thing `pickle` doesn't give you is versioning. Here's a dumb hack that allows evolution for basic field additions.

```python
import pickle
import chz
from chz.util import MISSING

@chz.chz
class A:
    a: int

d = pickle.dumps(A(a=5))

@chz.chz
class A:
    a: int
    b: bool = True

    def __setstate__(self, state):
        for field in self.__chz_fields__.values():
            if field.x_name not in state:
                if field._default is not MISSING:
                    state[field.x_name] = field._default
                if field._default_factory is not MISSING:
                    state[field.x_name] = field._default_factory()
        self.__dict__.update(state)
        return self

print(pickle.loads(d))
```

</details>
### [Next section — Post Init](./21_post_init.md)

================================================
FILE: docs/21_post_init.md
================================================
## No `__post_init__`; details and examples

There are a couple reasons why `chz` does not have a `__post_init__` equivalent for munging your fields:
- `__post_init__` has bad ergonomics with immutable objects (e.g. you need to use `object.__setattr__` or some wrapper to mutate fields)
- `__post_init__` encourages non-local initialisation behaviour that can be hard to reason about
- `__post_init__`'s interaction with `super` is easy to mess up
- `__post_init__`'s interaction with static type checkers is bad (if you use munging in `__post_init__` to narrow types)

### Caveat!!

If I were to rewrite `chz` from scratch, I would do things a little bit differently here (and I may yet change some of this stuff). In a tale as old as time, I ended up where we are today via several changes in response to several things over a period of several years. I have one more big change planned at some point.

Most notably, chz sort of predates PEP 681. E.g. the `X_` stuff that is most incompatible with PEP 681 was made inconvenient (it used to be `隐_`), but these things proved useful and I relented.

### No `__post_init__`

Here is some detail about the constraints posed by wanting static type checking to work. Take a look at the following example:

```python
# This will error since chz does not have a __post_init__
@chz.chz
class Experiment:
    name: str
    steps: int
    wandb_log_name: Optional[str] = None

    def __post_init__(self):
        if self.wandb_log_name is None:
            self.wandb_log_name = self.name
        raise NotImplementedError("chz does not actually have a __post_init__; this is a hypothetical")
```

Now, anytime you try to use `experiment.wandb_log_name` you'll have to assert that it is not None because the type checker doesn't know that `__post_init__` will always set it.

This is a sorry state of affairs.
### Solution that fully works with static type checkers Compare that to doing: ```python # Recommended solution @chz.chz class Experiment: name: str steps: int wandb_log_name: Optional[str] = None @chz.init_property def wandb_log_name_value(self) -> str: return self.wandb_log_name or self.name ``` Now use the `experiment.wandb_log_name_value` attribute instead. And if you mess it up and use the wrong one, the type checker will warn you, since the types are different! It does suck that you have to come up with a different name like `wandb_log_name_value`. If you do this, I recommend using the `_value` suffix for this. ### Alternative "magic prefix" solution (does not fully work with static type checkers) ```python # Alternative "magic prefix" solution @chz.chz class Experiment: name: str steps: int # chz will magically strip the "X_" in the __init__ parameter, and do the equivalent of # `self.X_wandb_log_name = wandb_log_name` in __init__. X_wandb_log_name: Optional[str] = None @chz.init_property def wandb_log_name(self) -> str: return self.X_wandb_log_name or self.name # Now you can instantiate your object using `wandb_log_name` as a parameter # ...but static type checkers will complain about all direct instantiations of your object experiment = Experiment(name="train_job", steps=100, wandb_log_name=None) assert experiment.wandb_log_name == "train_job" ``` One note for why this design: because definitions with the same name in classes clobber each other, the field name needs to be different from the `init_property` name. Otherwise, `chz` is not able to access the `chz.field` spec / default value for the field. 
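To see why the two names must differ, here is a minimal plain-Python sketch (no `chz` involved; the class names `Clobbered` and `Distinct` are made up for illustration) of how a later definition in a class body replaces an earlier one with the same name:

```python
class Clobbered:
    # First binding: a class attribute that could have served as a field default.
    wandb_log_name = None

    # Second binding with the same name replaces the first in the class namespace.
    @property
    def wandb_log_name(self):
        return "computed"


class Distinct:
    # A different name for the raw field keeps both definitions available.
    X_wandb_log_name = None

    @property
    def wandb_log_name(self):
        return self.X_wandb_log_name or "computed"


# Only the property survives in Clobbered; the original default is gone.
clobbered_is_property = isinstance(vars(Clobbered)["wandb_log_name"], property)

# In Distinct, the raw default is still reachable under its own name.
distinct_default = vars(Distinct)["X_wandb_log_name"]
distinct_value = Distinct().wandb_log_name
```

In the `Clobbered` case nothing is left from which a field spec or default could be read, which is the situation the `X_` prefix avoids.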
### Alternative "munging" solution (mostly works with static type checkers) ```python # Alternative "munging" solution @chz.chz class Experiment: name: str steps: int # The value passed to the constructor will end up processed by the munger function # wandb_log_name: str = chz.field(munger=lambda self, value: value or self.name) # You can use the combinators in chz.mungers too, for example: # wandb_log_name: Optional[str] = chz.field(munger=attr_if_none("name")) # If the value passed to `__init__` can be of another type (say None), you can use x_type so # that type aware parsing continues to work wandb_log_name: str = chz.field(munger=attr_if_none("name"), x_type=str | None) experiment = Experiment(name="train_job", steps=100, wandb_log_name=None) ``` As you can see, mungers can access any other attribute (which itself may be munged). chz will handle the ordering well. If you do something recursive, you will get an error. The way to handle this is to explicitly access the raw unmunged value, which you can find on `self` with the `X_` prefix. Currently munged values are validated both before and after munging. This allows you to rely on validation during munging and as an invariant. The exact logic here may change in the future. Note that munging is best for defaulting logic. If you wish to simply change command line parsing logic, consider using `blueprint_cast` and `__chz_cast__` instead. ### Mechanics and `X_` prefix Here's what chz is basically doing under the hood. When it sees: ```python @chz.chz class Args: foo: int = chz.field(default=1) ``` It will convert this to: ```python @chz.chz class Args: def __init__(self, foo: int = 1): self.X_foo = foo ... # some other stuff, like validation @chz.init_property def foo(self) -> int: return self.X_foo ``` If there's a munger for the field `foo`, then the `init_property` added will do whatever the munger does, instead of just returning `self.X_foo`. 
This design has several advantages: - This handles any graph of field reference between munging in attributes well without e.g. forcing ordering constraints on definitions - We preserve a lot of information about intent - We don't have to worry about non-idempotent munging, e.g. when doing `chz.replace` - In fact, we can even detect impure or non-idempotent `init_property` - Similarly, when deserialising, we could detect if `init_property` behaviour has changed - It keeps semantics relatively consistent between all options discussed on this page - It works well when you inherit from another chz class, but wish to override the field with an `init_property` ### [Next section — Field API](./22_field_api.md) ================================================ FILE: docs/22_field_api.md ================================================ ## `chz.field` `chz.field` takes the following parameters: #### `default` Like with `dataclasses`, the default value for the field (if any). #### `default_factory` Like with `dataclasses`, a function that returns the default value for the field. Useful for mutable types, for instance, `default_factory=list`. Note: this does not interact with parametrisation / `Blueprint` / `blueprint_unspecified` / `meta_factory`. The only thing that matters to parametrisation is presence or absence of a `default` or `default_factory`. Perhaps a better name would be `lazy_default` (but unfortunately, this is not supported by PEP 681, so static type checkers would lose the ability to understand the class). #### `validator` A function or list of functions that validate the field. Field validators take two arguments: the instance of the class and the name of the field. See also: [Validation](./03_validation.md) #### `repr` If a boolean, whether or not to include the field in the `__repr__` of the class. If a callable, will be used to construct the `repr` of the field. #### `doc` The docstring for the field. Used in `--help`. 
#### `metadata` Arbitrary user-defined metadata to attach to the field. Useful when extending `chz`. #### `munger` Lets you adjust the value of a field. Essentially works the same as an init_property. See also: [Alternative "munging" solution](./21_post_init.md) #### `x_type` Useful in combination with mungers. This specifies the type before munging that will be used for parsing and type checking. #### `meta_factory` A metafactory represents the set of possible callables that can give us a value of a given type. Describes the set of callables that are capable of returning a valid value for the field if given a non-zero number of arguments. For instance, the meta factory `chz.factories.subclass(Model)` is a description of the set of callables that are capable of producing a `Model` (e.g. `{Transformer, Diffusion}`). This was more useful in previous versions of `chz`, but now `chz` infers what you want to do more reliably. See also: the docs in `chz/factories.py` #### `blueprint_unspecified` This is the default callable `Blueprint` may attempt to call to get a value of the expected type. See [Blueprint](./05_blueprint.md#blueprint_unspecified) #### `blueprint_cast` A function that takes a str and returns an object. On failure to cast, it should raise `CastError`. Used to achieve custom parsing behaviour from the command line. Takes priority over the `__chz_cast__` dunder method (if present on the target type). See also: [Casting](./04_command_line.md#casting) ### [Next section — Philosophy](./91_philosophy.md) ================================================ FILE: docs/91_philosophy.md ================================================ # Philosophy There are a few different ideas in `chz` and not all of them are equally valuable or well designed. In particular, if you're using `chz` just for the object model, I'm not sure how useful it is compared to other libraries (see [alternatives](./92_alternatives.md) for many options). Configuration may never be easy! 
How you design your configurations may well be a bigger deal than what configuration system you use. I have seen `chz` used in ways and at a scale that I never envisioned. I hope this library is useful to you, and if not, I hope it encourages you to build the tools you want (because you deserve them!) The rest of this page is just rambling about things. ## Modularity We've had some monolithic configuration at OpenAI that made systems painful to work with. Some ways to avoid this are deeper nested configuration hierarchies (so things can be self-contained / testable / reusable) and polymorphism (to reduce the cartesian product explosion of config space). I hope that the `chz` classes you define can be reused across multiple entrypoints. At some point, trying to do everything within a single entrypoint makes it hard to create a great experience for all users. Sharing `chz` classes but specialising your own entrypoint with partial application seems like a good balance. As a corollary, this has made me hesitant about adding things to `chz` class definitions that primarily affect `Blueprint`s that use those classes (e.g. even `blueprint_unspecified`). One downside of pushing for modularity and self-contained-ness is that it makes situations where you want to access fields of parents or siblings more awkward. I recommend at least using validators to ensure consistency. (There are some undocumented features that help with this, and I'll likely add more things here in the future) Wildcards are a somewhat controversial feature, but they do at least lower the cost of having fairly deep configuration hierarchies. ## Partial application I've seen some users have a bit of a learning curve when first encountering `chz.Blueprint`, but I'm actually fairly convinced repeated partial application and a one-time initialisation and validation is a good pattern in a lot of use cases (even if you don't use the command line stuff). 
If you find yourself doing `chz.replace` a lot, ask yourself if you should be using `Blueprint`! ## Managing state `chz` objects are immutable in the sense that you cannot reassign a field. This was a design choice informed by scars from a previous system (and one that I think has been quite healthy). Note however that fields on `chz` objects can be mutable objects. It can be convenient to have state on your configuration objects (e.g. this lets you reuse your polymorphic hierarchy). Currently, I recommend patterns like: ```python @chz.chz class Config: state: types.SimpleNamespace def get_state(self): return self.state def set_state(self, state): if state is None: self.state.value = 0 else: self.state.value = state.value def mutate(self): self.state.value += 1 ``` Not coincidentally, this will remind you of the pattern I use in `fiddle` (OpenAI's dataloader). Currently, I leave the details of state management to downstream applications, but let me know if you think I should merge some of these things into `chz`. ### [Next section — Alternatives](./92_alternatives.md) ================================================ FILE: docs/92_alternatives.md ================================================ ## Alternatives The most common question I get when someone first sees `chz` is "...but have you heard about X?" Here are some values for X that I have heard of. A lot of these libraries are great; `chz` builds off of ideas from multiple of them, executes some things differently, and also has some novel features. 
Anyway, here is a list of things `chz` is not: (data model) - [dataclasses](https://docs.python.org/3/library/dataclasses.html) - [attrs](https://www.attrs.org/en/stable/) - [msgspec](https://jcristharif.com/msgspec/) - [pydantic](https://docs.pydantic.dev/) - hyperparams (internal) (serialisation) - [msgspec](https://jcristharif.com/msgspec/) - [cattrs](https://catt.rs/en/stable/readme.html#features) - [apischema](https://wyfo.github.io/apischema/) - [marshmallow](https://marshmallow.readthedocs.io/en/stable/) - [dacite](https://github.com/konradhalas/dacite) - [dataclasses_json](https://github.com/lidatong/dataclasses-json) - dump (internal) (cli) - [fire](https://github.com/google/python-fire) - [appeal](https://github.com/larryhastings/appeal) - [typer](https://typer.tiangolo.com/) - smokey (internal) - ein (internal) (runtime typing) - [typeguard](https://github.com/agronholm/typeguard) - [trycast](https://github.com/davidfstr/trycast) - [runtype](https://github.com/erezsh/runtype) (config solutions) - [hydra](https://hydra.cc/docs/intro/) - [the other fiddle](https://github.com/google/fiddle) - [gin](https://github.com/google/gin-config) - [hyperstate](https://github.com/cswinter/hyperstate) I don't think there's anything on this list that covers the same set of functionality I'm aiming for here. I also have specific bones to pick with some of these libraries :-) Let me know if you think there's a feature that would be constructive to add! ================================================ FILE: docs/93_testimonials.md ================================================ ## Testimonials Unsolicited feedback from users of `chz`. To be honest, I'm surprised people like it this much: > “pretty much perfectly what I always wanted for configs” szymon > “open-sourcing chz would increase annual world GDP growth > .1%” daniel selsam > “chz was really really good insight :) ty for making a lot of things a lot simpler” hunter > “i really like chz. 
thank you for your service making it. at previous companies i feel like we struggled to create a similar config service” mostafa > “chz is amazing; i’ve never used it before, and after using it for 30 minutes i’m glad [we've switched to using chz]” alex nichol > “chz has quickly become one of my favorite Python libraries” chris koch > “chz is so good :froge-chefkiss: thanks for building it!!” dmed > “Hey, I just want to mention how much I love chz. Config management is one of the hardest things to get right on large software projects and I've never seen anything as good as chz!” adam lerer ================================================ FILE: pyproject.toml ================================================ [project] name = "chz" version = "0.4.0" description = "chz is a library for managing configuration" readme = "README.md" license = {file = "LICENSE"} authors = [{name = "Shantanu Jain"}, {email = "shantanu@openai.com"}] dependencies = [ "typing-extensions>=4.13", ] requires-python = ">=3.11" [project.urls] homepage = "https://github.com/openai/chz" repository = "https://github.com/openai/chz" changelog = "https://github.com/openai/chz/blob/main/CHANGELOG.md" [build-system] build-backend = "setuptools.build_meta" requires = ["setuptools>=62.4", "wheel"] [tool.setuptools.packages.find] include = ["chz*"] [tool.pytest.ini_options] addopts = [ "--strict-markers", "--strict-config", "-p", "no:pytest_mock_resources", "-p", "no:httpx", "-p", "no:aiohttp", "-p", "no:faker", "-p", "no:ddtrace", "-p", "no:ddtrace.pytest_bdd", "-p", "no:ddtrace.pytest_benchmark", "-p", "no:hypothesispytest", "-p", "no:anyio", "-p", "no:benchmark", "-p", "no:pytest_mock", "-p", "no:typeguard", "-p", "no:asyncio", ] [tool.mypy] strict = true disallow_untyped_decorators = true disallow_any_generics = true disallow_untyped_calls = true disallow_subclassing_any = false disallow_incomplete_defs = false disallow_untyped_defs = false warn_return_any = false warn_unreachable = true 
[[tool.mypy.overrides]]
module = ["chz.tiepin"]
ignore_errors = true

[tool.coverage.report]
exclude_lines = [
    "pragma: no cover",
    "raise AssertionError",
    "raise NotImplementedError",
    "if MYPY",
    "if TYPE_CHECKING",
    "elif TYPE_CHECKING",
]

================================================
FILE: tests/test_blueprint.py
================================================
import pytest

import chz
from chz.blueprint import (
    Castable,
    ConstructionException,
    ExtraneousBlueprintArg,
    InvalidBlueprintArg,
    MissingBlueprintArg,
)


def test_entrypoint():
    def foo(a: int, b: str, c: float = 1.0):
        return locals()

    argv = ["a=1", "b=str", "c=5"]
    assert chz.entrypoint(foo, argv=argv) == foo(1, "str", 5)

    argv = ["a=1", "b=str"]
    assert chz.entrypoint(foo, argv=argv) == foo(1, "str", 1)

    # test allow_hyphens
    assert chz.entrypoint(foo, argv=argv, allow_hyphens=True) == foo(1, "str", 1)

    argv = ["--a=1", "--b=str", "c=5"]
    with pytest.raises(
        ExtraneousBlueprintArg,
        match=(
            r"Extraneous argument '--a' to Blueprint for test_blueprint.test_entrypoint.<locals>.foo \(from command line\)"
            "\nDid you mean to use allow_hyphens=True in your entrypoint?"
        ),
    ):
        chz.entrypoint(foo, argv=argv)
    assert chz.entrypoint(foo, argv=argv, allow_hyphens=True) == foo(1, "str", 5)

    # test positional
    argv = ["a", "1", "b", "str"]
    with pytest.raises(
        ValueError, match="Invalid argument 'a'. Specify arguments in the form key=value"
    ):
        chz.entrypoint(foo, argv=argv)


def test_entrypoint_nested():
    @chz.chz
    class X:
        a: int
        b: str
        c: float = 1.0

    def main(x: X) -> list[X]:
        return [x]

    argv = ["a=1", "b=str", "c=5"]
    assert chz.nested_entrypoint(main, argv=argv) == [X(a=1, b="str", c=5)]

    argv = ["a=1", "b=str"]
    assert chz.nested_entrypoint(main, argv=argv) == [X(a=1, b="str", c=1)]


def test_apply_strictness():
    """Test strictness of application when configured and non-strictness when not."""

    @chz.chz
    class X:
        hello: int = 5

    # misspelled! No error on application, but error on make.
misspelled_bp = chz.Blueprint(X).apply({"hllo": 1}) with pytest.raises(ExtraneousBlueprintArg): misspelled_bp.make() # In strict mode, we get an error on apply. with pytest.raises(ExtraneousBlueprintArg): chz.Blueprint(X).apply({"hllo": 1}, strict=True) def test_basic_function_blueprint(): def foo(a: int, b: int | str, c: bool = False, d: bytes = b""): return locals() # regular assert chz.Blueprint(foo).apply({"a": 1, "b": "str"}).make() == foo(1, "str", False, b"") # default assert chz.Blueprint(foo).apply({"a": 1, "b": "2", "c": True}).make() == foo(1, "2", True, b"") # clobbered assert chz.Blueprint(foo).apply({"a": 1, "b": "str"}).apply({"a": 2}).make() == foo( 2, "str", False, b"" ) # castable assert chz.Blueprint(foo).apply( {"a": Castable("1"), "b": Castable("str"), "c": Castable("True")} ).make() == foo(1, "str", True, b"") with pytest.raises(TypeError, match="Expected 'a' to be int, got str"): chz.Blueprint(foo).apply({"a": "asdf"}).make() with pytest.raises( InvalidBlueprintArg, match=( "- Failed to interpret it as a value:\n" "Could not cast 'asdf' to int\n\n" "- Failed to interpret it as a factory for polymorphic construction:\n" "No subclass of int named 'asdf'" ), ): chz.Blueprint(foo).apply({"a": Castable("asdf")}).make() def test_basic_class_blueprint(): @chz.chz class X: a: int b: str # regular assert chz.Blueprint(X).apply({"a": 1, "b": "str"}).make() == X(a=1, b="str") # clobbered assert chz.Blueprint(X).apply({"a": 1, "b": "str"}).apply({"a": 2}).make() == X(a=2, b="str") # castable assert chz.Blueprint(X).apply({"a": Castable("1"), "b": "str"}).make() == X(a=1, b="str") with pytest.raises(TypeError, match="Expected 'a' to be int, got str"): chz.Blueprint(X).apply({"a": "asdf"}).make() with pytest.raises( InvalidBlueprintArg, match=( "- Failed to interpret it as a value:\n" "Could not cast 'asdf' to int\n\n" "- Failed to interpret it as a factory for polymorphic construction:\n" "No subclass of int named 'asdf'" ), ): 
        chz.Blueprint(X).apply({"a": Castable("asdf")}).make()


def test_blueprint_unused():
    def foo(alpha: int, beta: str = ""):
        return locals()

    with pytest.raises(
        ExtraneousBlueprintArg,
        match="Extraneous argument 'missing' to Blueprint for test_blueprint.test_blueprint_unused.<locals>.foo",
    ):
        assert chz.Blueprint(foo).apply({"alpha": 1, "missing": "oops"}).make()

    @chz.chz
    class Foo:
        alpha: int
        beta: str

    assert chz.Blueprint(Foo).apply({"alpha": 1, "beta": "str"}).make() == Foo(alpha=1, beta="str")

    with pytest.raises(
        ExtraneousBlueprintArg,
        match="Extraneous argument 'missing' to Blueprint for test_blueprint.test_blueprint_unused.<locals>.Foo",
    ):
        assert chz.Blueprint(Foo).apply({"alpha": 1, "missing": "oops"}).make()

    with pytest.raises(ExtraneousBlueprintArg, match=r"Did you mean 'alpha'"):
        assert chz.Blueprint(Foo).apply({"alpha": 1, "aleph": "oops"}).make()

    @chz.chz
    class Bar:
        foo: Foo
        gamma: int

    assert (
        chz.Blueprint(Bar).apply({"foo.alpha": 1, "foo.beta": "str", "gamma": 1}).make()
    ) == Bar(foo=Foo(alpha=1, beta="str"), gamma=1)

    with pytest.raises(ExtraneousBlueprintArg, match=r"Did you mean 'foo\.alpha'"):
        assert (
            chz.Blueprint(Bar)
            .apply({"foo.alpha": 1, "foo.beta": "str", "foo.aleph": "oops"})
            .make()
        )

    with pytest.raises(
        ExtraneousBlueprintArg,
        match=r"Did you get the nesting wrong, maybe you meant 'foo.alpha'\?",
    ):
        assert (
            chz.Blueprint(Bar)
            .apply({"foo.alpha": 1, "foo.beta": "str", "alpha": "oops", "gamma": 1})
            .make()
        )

    assert (
        chz.Blueprint(Bar).apply({"...alpha": 1, "...beta": "str", "...gamma": 1}).make()
    ) == Bar(foo=Foo(alpha=1, beta="str"), gamma=1)

    with pytest.raises(ExtraneousBlueprintArg, match=r"Did you mean '\.\.\.alpha'"):
        assert chz.Blueprint(Bar).apply({"...alpha": 1, "...aleph": "oops"}).make()

    with pytest.raises(ExtraneousBlueprintArg, match=r"Did you mean '\.\.\.foo\.\.\.alpha'"):
        assert chz.Blueprint(Bar).apply({"...alpha": 1, "...foo...aleph": "oops"}).make()

    @chz.chz
    class Baz:
        bar: Bar
        delta: int

    assert (
        chz.Blueprint(Baz)
.apply({"bar.foo.alpha": 1, "bar.foo.beta": "str", "bar.gamma": 1, "delta": 1}) .make() ) == Baz(bar=Bar(foo=Foo(alpha=1, beta="str"), gamma=1), delta=1) with pytest.raises(ExtraneousBlueprintArg, match=r"Did you mean 'bar\.foo\.alpha'"): assert ( chz.Blueprint(Baz) .apply({"bar.foo.alpha": 1, "bar.foo.beta": "str", "bar.foo.aleph": "oops"}) .make() ) with pytest.raises(ExtraneousBlueprintArg, match=r"Did you mean 'bar\.foo\.\.\.alpha'"): assert chz.Blueprint(Baz).apply({"...alpha": 1, "bar.foo...aleph": "oops"}).make() with pytest.raises(ExtraneousBlueprintArg, match=r"Did you mean '\.\.\.bar\.\.\.foo\.alpha'"): assert chz.Blueprint(Baz).apply({"...alpha": 1, "...bar...foo.aleph": "oops"}).make() with pytest.raises(ExtraneousBlueprintArg, match=r"Did you mean '\.\.\.bar\.\.\.foo\.alpha'"): assert chz.Blueprint(Baz).apply({"...alpha": 1, "...bar...foZ.aleph": "oops"}).make() def test_blueprint_unused_nested_default(): @chz.chz class Sub: alpha: int = 1 beta: str = "str" @chz.chz class Main: sub: Sub gamma: int with pytest.raises(ExtraneousBlueprintArg, match=r"Did you mean 'sub\.beta'"): chz.Blueprint(Main).apply({"sub.bata": "str"}).make() def test_blueprint_missing_args(): @chz.chz class Alpha: alpha: int beta: str with pytest.raises( MissingBlueprintArg, match=r"Missing required arguments for parameter\(s\): alpha, beta" ): chz.Blueprint(Alpha).make() @chz.chz class Main: alpha: Alpha with pytest.raises( MissingBlueprintArg, match=r"Missing required arguments for parameter\(s\): alpha.alpha, alpha.beta", ): chz.Blueprint(Main).make() with pytest.raises( MissingBlueprintArg, match=r"Missing required arguments for parameter\(s\): alpha.beta" ): chz.Blueprint(Main).apply({"alpha.alpha": 1}).make() @chz.chz class MainDefault: alpha: Alpha = chz.field(default=Alpha(alpha=1, beta="str")) other: int with pytest.raises( MissingBlueprintArg, match=r"Missing required arguments for parameter\(s\): other" ): chz.Blueprint(MainDefault).make() def three_item_dataset(first: 
str = "a", second: str = "b", third: str = "c") -> list[str]: return [first, second, third] @chz.chz class Model: family: str n_layers: int salt: bytes = b"salt" @chz.chz class Experiment: model: Model dataset: list[str] = chz.field(doc="Yummy data!") def test_nested_construction(): expected = Experiment( model=Model(family="linear", n_layers=1, salt=b"0000"), dataset=["a", "b", "c"] ) assert ( chz.Blueprint(Experiment) .apply( { "model.family": "linear", "model.n_layers": 1, "model.salt": b"0000", "dataset": ["a", "b", "c"], } ) .make() == expected ) def test_nested_construction_with_default_value(): expected = Experiment( model=Model(family="linear", n_layers=1, salt=b"salt"), dataset=["a", "b", "c"] ) assert ( chz.Blueprint(Experiment) .apply({"model.family": "linear", "model.n_layers": 1, "dataset": ["a", "b", "c"]}) .make() == expected ) def test_nested_construction_with_factory_dataset(): expected = Experiment( model=Model(family="linear", n_layers=1, salt=b"0000"), dataset=["first", "b", "third"] ) assert ( chz.Blueprint(Experiment) .apply( { "model.family": "linear", "model.n_layers": 1, "model.salt": b"0000", "dataset": three_item_dataset, "dataset.first": "first", "dataset.third": "third", } ) .make() == expected ) def test_nested_construction_with_wildcards(): expected = Experiment( model=Model(family="linear", n_layers=1, salt=b"0000"), dataset=["first", "b", "third"] ) assert ( chz.Blueprint(Experiment) .apply( { "...family": "linear", "...n_layers": 1, "...salt": b"0000", "...dataset": three_item_dataset, "...first": "first", "...third": "third", } ) .make() == expected ) assert ( chz.Blueprint(Experiment) .apply( { "...family": "linear", "...n_layers": 1, "...salt": b"0000", "...dataset": Castable("three_item_dataset"), "dataset...first": "first", "...third": "third", } ) .make() == expected ) assert ( chz.Blueprint(Experiment) .apply( { "...family": "linear", "...n_layers": 1, "...salt": b"0000", "...dataset": Castable("three_item_dataset"), 
"...dataset...first": "first", # even more wildcard "...third": "third", } ) .make() == expected ) def test_nested_all_defaults(): @chz.chz class X: value: int = 0 @chz.chz class Y: x: X y: int assert chz.Blueprint(Y).apply({"y": 5}).make() == Y(x=X(value=0), y=5) def test_nested_not_all_defaults(): @chz.chz class X: value: int @chz.chz class Y: x: X @chz.chz class Parent: child: Y | None = None assert chz.Blueprint(Parent).make() == Parent(child=None) def test_nested_all_defaults_primitive(): @chz.chz class X: value: int @chz.chz class Y: x: X | None = None assert chz.Blueprint(Y).make() == Y(x=None) def test_nested_all_defaults_unspecified_nested(): @chz.chz class Value: value: int @chz.chz class A: value: Value @chz.chz class B(A): value: Value = Value(value=1) @chz.chz class Main: a: A = chz.field(blueprint_unspecified=B) assert chz.Blueprint(Main).make() == Main(a=B(value=Value(value=1))) def test_nested_construction_with_default_factory(): @chz.chz class ChildV1: required1: int @chz.chz class ChildV2: required2: int @chz.chz class Parent: child_v1: ChildV1 = chz.field(default_factory=lambda: ChildV1(required1=1)) child_v2: ChildV2 = chz.field(default_factory=lambda: ChildV2(required2=2)) assert chz.Blueprint(Parent).apply({}).make() == Parent( child_v1=ChildV1(required1=1), child_v2=ChildV2(required2=2) ) def test_help(): assert ( chz.Blueprint(Experiment).get_help() == """\ WARNING: Missing required arguments for parameter(s): model.family, model.n_layers, dataset Entry point: test_blueprint:Experiment Arguments: model test_blueprint:Model - model.family str - model.n_layers int - model.salt bytes b'salt' (default) dataset list[str] - Yummy data! 
""" ) assert ( chz.Blueprint(Experiment).apply({"model.family": "gpt"}, layer_name="gpt config").get_help() == """\ WARNING: Missing required arguments for parameter(s): model.n_layers, dataset Entry point: test_blueprint:Experiment Arguments: model test_blueprint:Model test_blueprint:Model (meta_factory) model.family str 'gpt' (from gpt config) model.n_layers int - model.salt bytes b'salt' (default) dataset list[str] - Yummy data! """ ) @chz.chz class Foo: a: int = chz.field(default_factory=lambda: 1 + 1) assert ( chz.Blueprint(Foo).get_help() == """\ Entry point: test_blueprint:test_help..Foo Arguments: a int (lambda: 1 + 1)() (default) """ ) assert ( chz.Blueprint(Foo).apply({"a": Castable("2")}, layer_name="preset").get_help() == """\ Entry point: test_blueprint:test_help..Foo Arguments: a int 2 (from preset) """ ) def test_logical_name_blueprint(): @chz.chz class X: X_seed1: int X_seed2: int @property def seed1(self): return self.X_seed1 + 100 @property def seed2(self): return self.X_seed2 + 100 x = chz.Blueprint(X).apply({"seed1": 1, "seed2": 2}).make() assert x.seed1 == 101 assert x.seed2 == 102 assert x == X(seed1=1, seed2=2) def test_blueprint_unpack_kwargs(): from typing import TypedDict, Unpack class Args(TypedDict): a: int b: str def foo(**kwargs: Unpack[Args]) -> Args: return Args(**kwargs) assert chz.Blueprint(foo).apply( {"a": chz.Castable("1"), "b": chz.Castable("2")} ).make() == Args(a=1, b="2") with pytest.raises( MissingBlueprintArg, match=r"Missing required arguments for parameter\(s\): b" ): chz.Blueprint(foo).apply({"a": 1}).make() def test_blueprint_castable_but_subpaths(): @chz.chz class A: field: str @chz.chz class Main: a: A = chz.field(blueprint_cast=lambda s: A(field=s)) with pytest.raises( InvalidBlueprintArg, match=r"""Could not interpret argument 'works' provided for param 'a'... - Failed to interpret it as a value: Not a value, since subparameters were provided \(e.g. 
'a.field'\) - Failed to interpret it as a factory for polymorphic construction: No subclass of test_blueprint:test_blueprint_castable_but_subpaths..A named 'works'""", ): chz.Blueprint(Main).apply({"a": Castable("works"), "a.field": Castable("field")}).make() def test_blueprint_value_but_subpaths(): @chz.chz class A: field: int @chz.chz class Main: a: A | None with pytest.raises( InvalidBlueprintArg, match=r"Could not interpret None provided for param 'a' as a value, " r"since subparameters were provided \(e.g. 'a.field'\)", ): chz.Blueprint(Main).apply({"a": None, "a.field": Castable("field")}).make() def test_blueprint_apply_subpath(): @chz.chz class A: field: int @chz.chz class AA: a: A @chz.chz class Main: a: AA field: int = 0 assert chz.Blueprint(Main).apply({"field": 1}, subpath="a.a").make() == Main( a=AA(a=A(field=1)), field=0 ) assert chz.Blueprint(Main).apply({"...field": 1}, subpath="").make() == Main( a=AA(a=A(field=1)), field=1 ) assert chz.Blueprint(Main).apply({"...field": 1}, subpath="a").make() == Main( a=AA(a=A(field=1)), field=0 ) with pytest.raises(ExtraneousBlueprintArg, match=r"Extraneous argument 'b.field' to Blueprint"): chz.Blueprint(Main).apply({"field": 1}, subpath="b").make() def test_blueprint_enum_all_defaults(): import enum class E(enum.StrEnum): foo = enum.auto() bar = enum.auto() @chz.chz class Inner: e: E = E.foo @chz.chz class Args: e: E = E.foo inner: Inner assert chz.Blueprint(Args).make() == Args(e=E.foo, inner=Inner(e=E.foo)) assert chz.Blueprint(Args).apply({"inner.e": Castable("bar")}).make() == Args( e=E.foo, inner=Inner(e=E.bar) ) def test_blueprint_functools_partial(): import functools def foo(a: int = 1, b: int = 2): return a, b partial_foo = functools.partial(foo, a=3, b=4) assert chz.Blueprint(partial_foo).apply({"a": 5}).make() == (5, 4) @chz.chz class Foo: a: int = 1 b: int = 2 partial_foo = functools.partial(Foo, a=3, b=4) assert chz.Blueprint(partial_foo).apply({"a": 5}).make() == Foo(a=5, b=4) @chz.chz class A: a: 
str = "a"

    @chz.chz
    class B(A):
        a: str = "b"

    @chz.chz
    class Main:
        a: A = chz.field(blueprint_unspecified=B)
        field: int = 0

    partial_main = functools.partial(Main, field=1)
    assert chz.Blueprint(partial_main).make() == Main(a=B(a="b"), field=1)


def test_blueprint_unspecified_functools_partial():
    @chz.chz
    class A:
        field: int
        typ: str = "a"

    @chz.chz
    class B(A):
        typ: str = "b"

    import functools

    @chz.chz
    class Main:
        a: A = chz.field(blueprint_unspecified=functools.partial(B, field=1))

    assert chz.Blueprint(Main).make() == Main(a=B(field=1))

    @chz.chz
    class C(A):
        missing: int

    @chz.chz
    class Main:
        a: A = chz.field(blueprint_unspecified=functools.partial(C, field=2))

    with pytest.raises(
        MissingBlueprintArg, match=r"Missing required arguments for parameter\(s\): a.missing"
    ):
        chz.Blueprint(Main).make()


def test_blueprint_positional_only():
    def pos_only(a: int = 42, /):
        return a

    assert chz.entrypoint(pos_only, argv=[]) == 42

    def pos_only_no_default(a: int, /):
        return a

    with pytest.raises(
        MissingBlueprintArg, match=r"Missing required arguments for parameter\(s\): 0"
    ):
        chz.entrypoint(pos_only_no_default, argv=[])

    assert chz.entrypoint(pos_only_no_default, argv=["0=1"]) == 1


def test_blueprint_args_kwargs():
    def args_only(*args: int):
        return args

    assert chz.entrypoint(args_only, argv=[]) == ()
    assert chz.entrypoint(args_only, argv=["0=1", "1=2"]) == (1, 2)

    def kwargs_only(**kwargs: int):
        return kwargs

    assert chz.entrypoint(kwargs_only, argv=[]) == {}
    assert chz.entrypoint(kwargs_only, argv=["a=1", "b=2"]) == {"a": 1, "b": 2}

    def pos_only_args_kwargs(x: int, /, *args: int, **kwargs: int):
        return x, args, kwargs

    with pytest.raises(
        MissingBlueprintArg, match=r"Missing required arguments for parameter\(s\): 0"
    ):
        chz.entrypoint(pos_only_args_kwargs, argv=[])

    assert chz.entrypoint(pos_only_args_kwargs, argv=["0=1", "1=2"]) == (1, (2,), {})
    assert chz.entrypoint(pos_only_args_kwargs, argv=["0=1", "1=2", "a=3", "b=4"]) == (
        1,
        (2,),
        {"a": 3, "b": 4},
    )

    def poskw_args_kwargs(x: int,
        *args: int, **kwargs: int):
        return x, args, kwargs

    with pytest.raises(
        MissingBlueprintArg, match=r"Missing required arguments for parameter\(s\): x"
    ):
        chz.entrypoint(poskw_args_kwargs, argv=[])

    with pytest.raises(
        ConstructionException,
        match=r"both positional-or-keyword and variadic positional parameters",
    ):
        chz.entrypoint(poskw_args_kwargs, argv=["x=1", "0=2"])

    assert chz.entrypoint(poskw_args_kwargs, argv=["x=1"]) == (1, (), {})


================================================
FILE: tests/test_blueprint_cast.py
================================================
# Note test_blueprint_meta_factory.py also contains tests relevant to casting

from typing import Literal

import pytest

import chz
from chz.blueprint import Castable, InvalidBlueprintArg
from chz.tiepin import CastError


def test_castable():
    @chz.chz
    class A:
        a: bool | Literal["both"]

    assert chz.Blueprint(A).apply({"a": Castable("True")}).make() == A(a=True)
    assert chz.Blueprint(A).apply({"a": Castable("False")}).make() == A(a=False)
    assert chz.Blueprint(A).apply({"a": Castable("both")}).make() == A(a="both")

    with pytest.raises(
        InvalidBlueprintArg,
        match=r"Could not cast 'maybe' to (Union\[bool, Literal\['both'\]\]|bool \| typing.Literal\['both'\])",
    ):
        assert chz.Blueprint(A).apply({"a": Castable("maybe")}).make()

    @chz.chz
    class B:
        b: str | None

    assert chz.Blueprint(B).apply({"b": Castable("None")}).make() == B(b=None)
    assert chz.Blueprint(B).apply({"b": Castable("Nona")}).make() == B(b="Nona")

    @chz.chz
    class C:
        a: tuple[int, ...]
        b: tuple[str, int, str | None]

    assert chz.Blueprint(C).apply({"a": Castable("1,2,3"), "b": Castable("1,2,None")}).make() == C(
        a=(1, 2, 3), b=("1", 2, None)
    )
    assert chz.Blueprint(C).apply({"a": Castable("1,2,3"), "b": Castable("1,2,3")}).make() == C(
        a=(1, 2, 3), b=("1", 2, "3")
    )

    with pytest.raises(
        InvalidBlueprintArg,
        match=(
            "- Failed to interpret it as a value:\n"
            r"Could not cast '1,2,3,4' to tuple\[str, int, str \| None\] because of length mismatch"
            "\n\n- Failed to interpret it as a factory for polymorphic construction:\n"
            r"No subclass of tuple\[str, int, str \| None\] named '1,2,3,4' \(invalid identifier\)"
        ),
    ):
        assert chz.Blueprint(C).apply({"a": (1,), "b": Castable("1,2,3,4")}).make()

    with pytest.raises(
        InvalidBlueprintArg, match=r"No subclass of tuple\[int, \.\.\.\] named 'asdf'"
    ):
        assert chz.Blueprint(C).apply({"a": Castable("asdf"), "b": ("1", 2, "3")}).make()

    @chz.chz
    class D:
        a: Literal["a", 123]

    assert chz.Blueprint(D).apply({"a": Castable("a")}).make() == D(a="a")
    assert chz.Blueprint(D).apply({"a": Castable("123")}).make() == D(a=123)

    with pytest.raises(InvalidBlueprintArg, match=r"Could not cast 'b' to Literal\['a', 123\]"):
        assert chz.Blueprint(D).apply({"a": Castable("b")}).make()


def test_castable_object_str():
    @chz.chz
    class A:
        a: object

    assert chz.Blueprint(A).apply({"a": Castable("1")}).make() == A(a=1)
    assert chz.Blueprint(A).apply({"a": Castable("1a")}).make() == A(a="1a")


def test_meta_factory_cast_unspecified():
    @chz.chz
    class Base:
        value: int
        cls: int = 0

        @classmethod
        def __chz_cast__(cls, data: str):  # noqa: ANN206
            return Base(value=int(data))

    @chz.chz
    class DefaultChild(Base):
        value: int
        cls: int = 2

        @classmethod
        def __chz_cast__(cls, data: str):  # noqa: ANN206
            return DefaultChild(value=int(data))

    @chz.chz
    class X:
        a: Base = chz.field(
            meta_factory=chz.factories.subclass(base_cls=Base, default_cls=DefaultChild)
        )

    assert chz.Blueprint(X).apply({"a": Castable("3")}).make() == X(a=DefaultChild(value=3, cls=2))


def
test_chz_cast_dunder():
    from dataclasses import dataclass

    @dataclass
    class Duration:
        seconds: int

        @classmethod
        def __chz_cast__(cls, value: str):  # noqa: ANN206
            try:
                return Duration(int(value.strip("hms")) * {"h": 3600, "m": 60, "s": 1}[value[-1]])
            except Exception as e:
                raise CastError(f"Could not cast {value!r} to {cls.__name__}") from e

    @chz.chz
    class X:
        t: Duration

    assert chz.Blueprint(X).apply({"t": Castable("1h")}).make() == X(t=Duration(3600))
    assert chz.Blueprint(X).apply({"t": Castable("1m")}).make() == X(t=Duration(60))

    with pytest.raises(
        InvalidBlueprintArg,
        match="Failed to interpret it as a value:\nCould not cast 'yikes' to Duration",
    ):
        chz.Blueprint(X).apply({"t": Castable("yikes")}).make()

    @chz.chz
    class Y:
        t1: str | Duration
        t2: Duration | str
        t3: str | Duration
        t4: Duration | str

    assert chz.Blueprint(Y).apply(
        {
            "t1": Castable("1h"),
            "t2": Castable("1m"),
            "t3": Castable("yikes"),
            "t4": Castable("yikes"),
        }
    ).make() == Y(t1=Duration(3600), t2=Duration(60), t3="yikes", t4="yikes")


def test_cast_per_field():
    @chz.chz
    class X:
        a: str = chz.field(blueprint_cast=lambda x: x[0])
        b: str = chz.field(blueprint_cast=lambda x: x[1])

    assert chz.Blueprint(X).apply({"a": Castable("abc"), "b": Castable("abc")}).make() == X(
        a="a", b="b"
    )


================================================
FILE: tests/test_blueprint_errors.py
================================================
import pytest

import chz
from chz.blueprint import ConstructionException, ExtraneousBlueprintArg


def test_target_bad_signature():
    def bad(a: int, b: str): ...

    bad.__text_signature__ = "not a signature"

    with pytest.raises(ConstructionException, match=r"Failed to get signature for bad"):
        chz.entrypoint(bad, argv=[])


def test_target_just_plain_old_bad():
    with pytest.raises(ValueError, match="42 is not callable"):
        chz.entrypoint(42, argv=[])


def test_target_no_params_extraneous():
    def good(): ...

with pytest.raises( ExtraneousBlueprintArg, match=r"Extraneous argument 'a' to Blueprint for .*\.good" ): chz.entrypoint(good, argv=["a=42"]) def test_nested_target_default_values(): @chz.chz class Main: a: int def good(m: Main, b="asdf", c=1): return m.a assert chz.nested_entrypoint(good, argv=["a=42"]) == 42 def bad(m: Main, b, c=1): ... with pytest.raises( ValueError, match=r"Nested entrypoints must take at most one argument without a default", ): chz.nested_entrypoint(bad, argv=["a=42"]) def test_blueprint_extraneous_valid_parent(): @chz.chz class C: field: int @chz.chz class B: c: C @chz.chz class A: b: B with pytest.raises( ExtraneousBlueprintArg, match=r"""Extraneous argument 'b.cc.nope' to Blueprint for test_blueprint_errors:test_blueprint_extraneous_valid_parent..A \(from command line\) Param 'b' is closest valid ancestor Param 'b' is set to test_blueprint_errors:test_blueprint_extraneous_valid_parent..B \(blueprint_unspecified\) Subparam 'cc' does not exist on it Append --help to your command to see valid arguments""", ): chz.entrypoint(A, argv=["b.cc.nope=0"]) ================================================ FILE: tests/test_blueprint_meta_factory.py ================================================ import typing from typing import Optional import pytest import chz from chz.blueprint import Castable, InvalidBlueprintArg, MissingBlueprintArg class A: pass class B(A): pass class C(B): pass def test_meta_factory_subclass(): @chz.chz class Main: obj: A = chz.field(meta_factory=chz.factories.subclass(base_cls=A, default_cls=A)) argv = ["obj=A"] ret = chz.entrypoint(Main, argv=argv) assert type(ret.obj) is A # Test basic subclass functionality, ie B -> A argv = ["obj=B"] ret = chz.entrypoint(Main, argv=argv) assert type(ret.obj) is B # Test multiple inheritance, ie C -> B -> A argv = ["obj=C"] ret = chz.entrypoint(Main, argv=argv) assert type(ret.obj) is C def test_meta_factory_subclass_limited(): # Test that a subclass of B is not accepted @chz.chz class Main: 
obj: A = chz.field(meta_factory=chz.factories.subclass(base_cls=B, default_cls=A)) argv = ["obj=A"] with pytest.raises( InvalidBlueprintArg, match=( "Failed to interpret it as a factory for polymorphic construction:\n" "No subclass of test_blueprint_meta_factory.B named 'A'" ), ): chz.entrypoint(Main, argv=argv) def test_meta_factory_default_subclass(): @chz.chz class Parent: required0: int @chz.chz class Child2(Parent): required2: int @chz.chz class Main: field: Parent = chz.field(meta_factory=chz.factories.subclass(Parent, default_cls=Child2)) assert (chz.Blueprint(Main).apply({"field.required0": 0, "field.required2": 2}).make()) == Main( field=Child2(required0=0, required2=2) ) assert ( chz.Blueprint(Main) .apply({"field.required0": 0, "field": Child2, "field.required2": 2}) .make() ) == Main(field=Child2(required0=0, required2=2)) assert ( chz.Blueprint(Main).apply({"field.required0": 0}).get_help() == """\ WARNING: Missing required arguments for parameter(s): field.required2 Entry point: test_blueprint_meta_factory:test_meta_factory_default_subclass..Main Arguments: field test_blueprint_meta_factory:test_meta_factory_default_subclass..Parent test_blueprint_meta_factory:test_meta_factory_default_subclass..Child2 (meta_factory) field.required0 int 0 field.required2 int - """ ) assert ( chz.Blueprint(Main).apply({"field.required0": 0, "field": Child2}).get_help() == """\ WARNING: Missing required arguments for parameter(s): field.required2 Entry point: test_blueprint_meta_factory:test_meta_factory_default_subclass..Main Arguments: field test_blueprint_meta_factory:test_meta_factory_default_subclass..Parent test_blueprint_meta_factory:test_meta_factory_default_subclass..Child2 field.required0 int 0 field.required2 int - """ ) @chz.chz class Main2: field: Parent | None = chz.field( meta_factory=chz.factories.subclass(Parent, default_cls=Child2), default=None ) assert ( chz.Blueprint(Main2).get_help() == """\ Entry point: 
test_blueprint_meta_factory:test_meta_factory_default_subclass..Main2 Arguments: field test_blueprint_meta_factory.test_meta_factory_default_subclass..Parent | None None (default) field.required0 int - field.required2 int - """ ) assert ( chz.Blueprint(Main2).apply({"field.required0": 0}).get_help() == """\ WARNING: Missing required arguments for parameter(s): field.required2 Entry point: test_blueprint_meta_factory:test_meta_factory_default_subclass..Main2 Arguments: field test_blueprint_meta_factory.test_meta_factory_default_subclass..Parent | None test_blueprint_meta_factory:test_meta_factory_default_subclass..Child2 (meta_factory) field.required0 int 0 field.required2 int - """ ) def test_meta_factory_blueprint_unspecified(): @chz.chz class Parent: required0: int @chz.chz class Child2(Parent): required2: int @chz.chz class Main: field: Parent = chz.field(blueprint_unspecified=Child2) assert (chz.Blueprint(Main).apply({"field.required0": 0, "field.required2": 2}).make()) == Main( field=Child2(required0=0, required2=2) ) assert ( chz.Blueprint(Main) .apply({"field.required0": 0, "field": Child2, "field.required2": 2}) .make() ) == Main(field=Child2(required0=0, required2=2)) assert ( chz.Blueprint(Main).get_help() == """\ WARNING: Missing required arguments for parameter(s): field.required0, field.required2 Entry point: test_blueprint_meta_factory:test_meta_factory_blueprint_unspecified..Main Arguments: field test_blueprint_meta_factory:test_meta_factory_blueprint_unspecified..Parent test_blueprint_meta_factory:test_meta_factory_blueprint_unspecified..Child2 (blueprint_unspecified) field.required0 int - field.required2 int - """ ) def test_meta_factory_blueprint_unspecified_more(): @chz.chz class Sub: x: int = 1 @chz.chz class Config: sub: Sub sub2: Sub @chz.chz class MySub(Sub): x: int = 2 @chz.chz class MySub2(Sub): x: int = 3 @chz.chz class MyConfig(Config): sub: Sub = chz.field(blueprint_unspecified=MySub) sub2: Sub = chz.field(blueprint_unspecified=MySub2) # 
Check defaults get overwritten properly config = chz.Blueprint(MyConfig).make() assert config == MyConfig(sub=MySub(x=2), sub2=MySub2(x=3)) # Check you can still set the nested values properly config = chz.Blueprint(MyConfig).apply_from_argv(["sub.x=4", "sub2.x=5"]).make() assert config == MyConfig(sub=MySub(x=4), sub2=MySub2(x=5)) # Ensure that's okay to override a field with the base Sub class config = chz.Blueprint(MyConfig).apply_from_argv(["sub=Sub"]).make() assert config == MyConfig(sub=Sub(x=1), sub2=MySub2(x=3)) # Lastly, check that it's okay to override the fields with custom Sub classes config = chz.Blueprint(MyConfig).apply_from_argv(["sub=MySub2", "sub2=MySub"]).make() assert config == MyConfig(sub=MySub2(x=3), sub2=MySub(x=2)) def test_meta_factory_blueprint_unspecified_all_default_help(): @chz.chz class X: value: int = 0 @chz.chz class Main: field: object = chz.field(blueprint_unspecified=X) assert ( chz.Blueprint(Main).get_help() == """\ Entry point: test_blueprint_meta_factory:test_meta_factory_blueprint_unspecified_all_default_help..Main Arguments: field object test_blueprint_meta_factory:test_meta_factory_blueprint_unspecified_all_default_help..X (meta_factory) field.value int 0 (default) """ ) def test_meta_factory_blueprint_unspecified_optional(): @chz.chz class X: value: int = 0 @chz.chz class Main: value: int = 0 field: X | None = chz.field(blueprint_unspecified=X, default=None) assert chz.Blueprint(Main).apply({"...value": 1}).make() == Main(value=1, field=X(value=1)) @chz.chz class Main: value: int = 0 field: X | None = chz.field(blueprint_unspecified=type(None), default=None) assert chz.Blueprint(Main).apply({"...value": 1}).make() == Main(value=1, field=None) def test_meta_factory_subclass_generic(): T = typing.TypeVar("T") @chz.chz class Base(typing.Generic[T]): pass @chz.chz class Sub(Base[int]): value: int = 0 @chz.chz class Main1: obj: Base argv = ["obj=Base"] ret = chz.entrypoint(Main1, argv=argv) assert type(ret.obj) is Base argv = 
["obj=Sub"]
    ret = chz.entrypoint(Main1, argv=argv)
    assert type(ret.obj) is Sub

    @chz.chz
    class Main2:
        obj: Base[int]

    argv = ["obj=Base"]
    ret = chz.entrypoint(Main2, argv=argv)
    assert type(ret.obj) is Base

    argv = ["obj=Sub"]
    ret = chz.entrypoint(Main2, argv=argv)
    assert type(ret.obj) is Sub

    argv = ["obj=Sub", "obj.value=3"]
    ret = chz.entrypoint(Main2, argv=argv)
    assert type(ret.obj) is Sub
    assert ret.obj.value == 3


def test_meta_factory_optional():
    @chz.chz
    class Child2:
        x: int

    @chz.chz
    class Parent:
        child: Optional[Child2]  # noqa: UP045

    @chz.chz
    class Parent2:
        child: Child2 | None

    assert chz.Blueprint(Parent).apply({"child.x": 3}).make() == Parent(child=Child2(x=3))
    assert chz.Blueprint(Parent2).apply({"child.x": 3}).make() == Parent2(child=Child2(x=3))


def test_meta_factory_union():
    from dataclasses import dataclass

    @dataclass
    class O1: ...

    @dataclass
    class O2: ...

    @chz.chz
    class Main:
        z: O1 | O2

    assert chz.Blueprint(Main).apply({"z": Castable("O1")}).make() == Main(z=O1())
    assert chz.Blueprint(Main).apply({"z": Castable("O2")}).make() == Main(z=O2())


def test_meta_factory_non_chz():
    class Actor:
        def __init__(self):
            self.label = "actor"

    class WakeActor(Actor):
        def __init__(self):
            self.label = "wake_actor"

    @chz.chz
    class Args:
        actor: Actor = chz.field(meta_factory=chz.factories.subclass(Actor))

    assert chz.Blueprint(Args).apply({"actor": Castable("Actor")}).make().actor.label == "actor"
    assert (
        chz.Blueprint(Args).apply({"actor": Castable("WakeActor")}).make().actor.label
        == "wake_actor"
    )

    with pytest.raises(
        MissingBlueprintArg, match=r"Missing required arguments for parameter\(s\): actor"
    ):
        chz.Blueprint(Args).make()


def test_meta_factory_function_lambda():
    import calendar

    @chz.chz
    class Main:
        a: A = chz.field(meta_factory=chz.factories.function(), default=object())
        cal: calendar.Calendar = chz.field(
            meta_factory=chz.factories.function(default_module="calendar"), default=object()
        )

    argv = ["a=lambda: A()", "cal=lambda d: Calendar(int(d))", "cal.d=3"]
    ret =
chz.entrypoint(Main, argv=argv)
    assert type(ret.a) is A
    assert type(ret.cal) is calendar.Calendar
    assert ret.cal.firstweekday == 3

    @chz.chz
    class Main:
        a: A = chz.field(default=object())
        cal: calendar.Calendar = chz.field(
            meta_factory=chz.factories.standard(default_module="calendar"), default=object()
        )

    argv = ["a=lambda: A()", "cal=lambda d: Calendar(int(d))", "cal.d=3"]
    ret = chz.entrypoint(Main, argv=argv)
    assert type(ret.a) is A
    assert type(ret.cal) is calendar.Calendar
    assert ret.cal.firstweekday == 3


def test_meta_factory_type_subclass():
    @chz.chz
    class Main:
        a: type[A]

    assert chz.entrypoint(Main, argv=["a=A"]).a is A
    assert chz.entrypoint(Main, argv=["a=B"]).a is B
    assert chz.entrypoint(Main, argv=["a=C"]).a is C

    with pytest.raises(
        InvalidBlueprintArg, match="Could not interpret argument 'int' provided for param 'a'"
    ):
        chz.entrypoint(Main, argv=["a=int"])


def test_meta_factory_function_union():
    @chz.chz
    class A:
        field: str = "a"

    @chz.chz
    class B(A): ...

    def make_tuple(s0: B | None = None, s1: B | None = None):
        if s0 is None:
            s0 = B(field="s0default")
        if s1 is None:
            s1 = B(field="s1default")
        return (s0, s1)

    @chz.chz
    class Main:
        specs: tuple[A, ...]
= chz.field(blueprint_unspecified=make_tuple) assert chz.entrypoint(Main, argv=["specs.s1=B"]).specs == ( B(field="s0default"), B(field="a"), ) def test_meta_factory_none(): @chz.chz class Main: a: A = chz.field(meta_factory=None) with pytest.raises( InvalidBlueprintArg, match="Could not cast 'A' to test_blueprint_meta_factory:A" ): chz.entrypoint(Main, argv=["a=A"]) ================================================ FILE: tests/test_blueprint_methods.py ================================================ import re import textwrap from unittest.mock import patch import pytest import chz from chz.blueprint import EntrypointHelpException, ExtraneousBlueprintArg @chz.chz class Run1: name: str def launch(self, cluster: str): """Launch a job on a cluster.""" return ("launch", self, cluster) def history(self): return ("history", self) @chz.chz class RunDefault: def launch(self, cluster: str): return ("launch", self, cluster) def test_methods_entrypoint(): assert chz.methods_entrypoint(Run1, argv=["launch", "self.name=job", "cluster=big"]) == ( "launch", Run1(name="job"), "big", ) assert chz.methods_entrypoint(Run1, argv=["history", "self.name=job"]) == ( "history", Run1(name="job"), ) assert chz.methods_entrypoint(RunDefault, argv=["launch", "cluster=big"]) == ( "launch", RunDefault(), "big", ) with pytest.raises(ExtraneousBlueprintArg, match="Extraneous argument 'self.cluster'"): chz.methods_entrypoint(Run1, argv=["launch", "self.name=job", "self.cluster=big"]) def test_methods_entrypoint_help(): with pytest.raises( EntrypointHelpException, match="""\ Entry point: methods of test_blueprint_methods:Run1 Available methods: history launch Launch a job on a cluster. """, ): chz.methods_entrypoint(Run1, argv=[]) orig_get_help = chz.blueprint._blueprint.Blueprint.get_help with ( # Disable color, which messes with the pytest.raises(..., match=...) 
patch( "chz.blueprint._blueprint.Blueprint.get_help", lambda self, color: orig_get_help(self, color=False), ), pytest.raises( EntrypointHelpException, match=re.escape( textwrap.dedent( """\ WARNING: Missing required arguments for parameter(s): self.name, cluster Entry point: test_blueprint_methods:Run1.launch Launch a job on a cluster. Arguments: self test_blueprint_methods:Run1 - self.name str - cluster str""" ), ), ), ): chz.methods_entrypoint(Run1, argv=["launch", "--help"]) @chz.chz class RunAltSelfParam: name: str def launch(run, cluster: str): return ("launch", run, cluster) def test_methods_entrypoint_self(): assert chz.methods_entrypoint( RunAltSelfParam, argv=["launch", "run.name=job", "cluster=big"] ) == ("launch", RunAltSelfParam(name="job"), "big") with pytest.raises(ExtraneousBlueprintArg, match="Extraneous argument 'self.name'"): chz.methods_entrypoint(RunAltSelfParam, argv=["launch", "self.name=job", "cluster=big"]) with pytest.raises(ExtraneousBlueprintArg, match="Extraneous argument 'run.name'"): chz.methods_entrypoint(Run1, argv=["launch", "run.name=job", "cluster=big"]) @chz.chz class RunDefaultChild(RunDefault): ... 
def test_methods_entrypoint_polymorphic(): assert chz.methods_entrypoint( RunDefault, argv=["launch", "self=RunDefaultChild", "cluster=big"] ) == ("launch", RunDefaultChild(), "big") def test_methods_entrypoint_transform(): def transform(blueprint, target, method): if method == "launch": return blueprint.apply({"name": "job"}, subpath="self") return blueprint assert chz.methods_entrypoint(Run1, argv=["launch", "cluster=big"], transform=transform) == ( "launch", Run1(name="job"), "big", ) ================================================ FILE: tests/test_blueprint_reference.py ================================================ import pytest import chz from chz.blueprint import InvalidBlueprintArg, MissingBlueprintArg, Reference def test_blueprint_reference(): @chz.chz class Main: a: str b: str obj = chz.Blueprint(Main).apply({"a": "foo", "b": Reference("a")}).make() assert obj == Main(a="foo", b="foo") obj = chz.Blueprint(Main).apply_from_argv(["a=foo", "b@=a"]).make() assert obj == Main(a="foo", b="foo") assert ( chz.Blueprint(Main).apply({"a": "foo", "b": Reference("a")}).get_help() == """\ Entry point: test_blueprint_reference:test_blueprint_reference..Main Arguments: a str 'foo' b str @=a """ ) with pytest.raises(InvalidBlueprintArg, match=r"Invalid reference target 'c' for param b"): chz.Blueprint(Main).apply({"a": "foo", "b": Reference("c")}).make() def test_blueprint_reference_multiple_invalid(): @chz.chz class Main: a: int b: int c: int with pytest.raises( InvalidBlueprintArg, match="""\ Invalid reference target 'x' for params a, b Invalid reference target 'bb' for param c Did you mean 'b'?""", ): chz.Blueprint(Main).apply_from_argv(["a@=x", "b@=x", "c@=bb"]).make() def test_blueprint_reference_nested(): @chz.chz class C: c: int @chz.chz class B: b: int c: C @chz.chz class A: a: int b: B obj = chz.Blueprint(A).apply_from_argv(["a@=b.b", "b.c.c@=a", "b.b=5"]).make() assert obj == A(a=5, b=B(b=5, c=C(c=5))) def test_blueprint_reference_wildcard(): @chz.chz class 
B:
        name: str

    @chz.chz
    class A:
        name: str
        b: B

    @chz.chz
    class Main:
        name: str
        a: A

    obj = chz.Blueprint(Main).apply_from_argv(["...name@=name", "name=foo"]).make()
    assert obj == Main(name="foo", a=A(name="foo", b=B(name="foo")))

    obj = chz.Blueprint(Main).apply_from_argv(["...name@=a.b.name", "a.b.name=foo"]).make()
    assert obj == Main(name="foo", a=A(name="foo", b=B(name="foo")))


def test_blueprint_reference_wildcard_default():
    @chz.chz
    class A:
        name: str

    @chz.chz
    class Main:
        name: str = "foo"
        a: A

    obj = chz.Blueprint(Main).apply_from_argv(["...name@=name"]).make()
    assert obj == Main(name="foo", a=A(name="foo"))


def test_blueprint_reference_wildcard_default_no_default():
    @chz.chz
    class Defaults:
        a: int

    @chz.chz
    class A:
        defaults: Defaults
        a: int

    with pytest.raises(
        MissingBlueprintArg, match=r"Missing required arguments for parameter\(s\): defaults.a"
    ):
        chz.Blueprint(A).apply_from_argv(["...a@=defaults.a"]).make()


def test_blueprint_reference_wildcard_default_constructable():
    @chz.chz
    class Object:
        a: int = 1

    @chz.chz
    class Defaults:
        obj: Object
        a: int = 2

    @chz.chz
    class Main:
        defaults: Defaults
        obj: Object
        a: int = 3

    assert chz.Blueprint(Main).apply_from_argv(
        ["...obj@=defaults.obj", "defaults.obj.a=4"]
    ).make() == Main(
        defaults=Defaults(obj=Object(a=4), a=2),
        obj=Object(a=4),
        a=3,
    )

    assert chz.Blueprint(Main).apply_from_argv(["...obj@=defaults.obj", "...a=4"]).make() == Main(
        defaults=Defaults(obj=Object(a=4), a=4),
        obj=Object(a=4),
        a=4,
    )


def test_blueprint_reference_cycle():
    @chz.chz
    class Main:
        a: int
        b: int

    with pytest.raises(RecursionError, match="Detected cyclic reference: a -> b -> a"):
        chz.Blueprint(Main).apply_from_argv(["a@=b", "b@=a"]).make()

    @chz.chz
    class Main:
        a: int

    with pytest.raises(
        MissingBlueprintArg, match=r"Missing required arguments for parameter\(s\): a"
    ):
        chz.Blueprint(Main).apply_from_argv(["a@=a"]).make()


================================================
FILE: tests/test_blueprint_root_polymorphism.py
================================================ import re import chz def test_root_polymorphism(): @chz.chz class X: a: int b: str = "str" @chz.chz class Y(X): c: float = 1.0 def foo(a: int, b: str = "default", c: float = 3.0): return Y(a=a, b=b, c=c) assert chz.Blueprint(X).apply({"a": 0}).make() == X(a=0, b="str") assert chz.Blueprint(X | None).apply({"a": 0}).make() == X(a=0, b="str") assert chz.Blueprint(X).apply({"": Y, "a": 0}).make() == Y(a=0, b="str", c=1.0) assert chz.Blueprint(X).apply({"": foo, "a": 2}).make() == Y(a=2, b="default", c=3.0) assert chz.Blueprint(object).apply({"": X, "a": 1}).make() == X(a=1, b="str") assert chz.Blueprint(object).apply({"": X, "...a": 1}).make() == X(a=1, b="str") assert chz.Blueprint(object).apply({"": foo, "a": 1}).make() == Y(a=1, b="default", c=3.0) # TODO: make help better if root is object or Any and no arguments are provided assert re.fullmatch( r"""Entry point: object The base class of the class hierarchy.* Arguments: object test_blueprint_root_polymorphism:test_root_polymorphism..X The base class of the class hierarchy.* a int 1 b str 'str' \(default\) """, chz.Blueprint(object).apply({"": X, "a": 1}).get_help(), flags=re.DOTALL, ) assert re.fullmatch( r"""Entry point: object The base class of the class hierarchy.* Arguments: object test_blueprint_root_polymorphism:test_root_polymorphism..foo The base class of the class hierarchy.* a int 1 b str 'default' \(default\) c float 3.0 \(default\) """, chz.Blueprint(object).apply({"": foo, "a": 1}).get_help(), flags=re.DOTALL, ) ================================================ FILE: tests/test_blueprint_unit.py ================================================ import pytest from chz.blueprint import Blueprint, Castable, beta_argv_arg_to_string, beta_blueprint_to_argv from chz.blueprint._argmap import ArgumentMap, Layer, join_arg_path from chz.blueprint._wildcard import _wildcard_key_match, wildcard_key_to_regex def test_beta_argv_arg_to_string(): k = "a" v = { "b": { 
"c": { "d": 1, }, "e": 2, }, "f": 3, } assert beta_argv_arg_to_string(k, v) == ["a.b.c.d=1", "a.b.e=2", "a.f=3"] assert beta_argv_arg_to_string("nums", [1, 2, 3]) == ["nums=1,2,3"] assert beta_argv_arg_to_string("flags", [True, None, False]) == ["flags=True,None,False"] assert beta_argv_arg_to_string("k", ["1,2,3"]) == ["k.0=1,2,3"] class C: d: int class B: c: C e: int class A: b: B f: int blueprint = Blueprint(A) blueprint.apply({"a.b": {"c": {"d": 1}, "e": 2}, "a.f": 3}) assert beta_blueprint_to_argv(blueprint) == ["a.b.c.d=1", "a.b.e=2", "a.f=3"] def test_wildcard_key_to_regex(): assert wildcard_key_to_regex("a.b.c").pattern == r"a\.b\.c" assert wildcard_key_to_regex("a...c").pattern == r"a\.(.*\.)?c" assert wildcard_key_to_regex("...a").pattern == r"(.*\.)?a" assert wildcard_key_to_regex("...a...c").pattern == r"(.*\.)?a\.(.*\.)?c" with pytest.raises(ValueError, match="Wildcard not allowed at end of key"): wildcard_key_to_regex("...") with pytest.raises(ValueError, match="Wildcard not allowed at end of key"): wildcard_key_to_regex("a...") def test_wildcard_key_match(): assert _wildcard_key_match("a.b.c", "a.b.c") assert _wildcard_key_match("a...c", "a.b.c") assert _wildcard_key_match("...c", "a.b.c") assert _wildcard_key_match("...a...c", "a.b.c") assert _wildcard_key_match("...a.b.c", "a.b.c") assert _wildcard_key_match("...a.b...c", "a.b.c") assert _wildcard_key_match("...a...b.b...a", "a.b.x.b.b.a") assert _wildcard_key_match("...x.y.z", "x.y.z") assert not _wildcard_key_match("a.b.d", "a.b.c") assert not _wildcard_key_match("...a", "a.b.c") assert not _wildcard_key_match("a...b", "a.b.c") assert not _wildcard_key_match("...a", "xxa") assert not _wildcard_key_match("...a...b.b...a", "a.b.x.b.c.a") assert not _wildcard_key_match("...a...y", "a...a...z") assert not _wildcard_key_match("...a...b...a", "a...b...a...b...b...a.z") assert wildcard_key_to_regex("a.b.c").fullmatch("a.b.c") assert wildcard_key_to_regex("a...c").fullmatch("a.b.c") assert 
wildcard_key_to_regex("...c").fullmatch("a.b.c") assert wildcard_key_to_regex("...a...c").fullmatch("a.b.c") assert wildcard_key_to_regex("...a.b.c").fullmatch("a.b.c") assert wildcard_key_to_regex("...a.b...c").fullmatch("a.b.c") assert wildcard_key_to_regex("...a...b.b...a").fullmatch("a.b.x.b.b.a") assert wildcard_key_to_regex("...x.y.z").fullmatch("x.y.z") assert not wildcard_key_to_regex("a.b.d").fullmatch("a.b.c") assert not wildcard_key_to_regex("...a").fullmatch("a.b.c") assert not wildcard_key_to_regex("a...b").fullmatch("a.b.c") assert not wildcard_key_to_regex("...a").fullmatch("xxa") assert not wildcard_key_to_regex("...a...b.b...a").fullmatch("a.b.x.b.c.a") assert not wildcard_key_to_regex("...a...y").fullmatch("a...a...z") assert not wildcard_key_to_regex("...a...b...a").fullmatch("a...b...a...b...b...a.z") def test_join_arg_path(): assert join_arg_path("parent", "child") == "parent.child" assert join_arg_path("grand.parent", "child") == "grand.parent.child" assert join_arg_path("parent", "...child") == "parent...child" assert join_arg_path("grand...parent", "child") == "grand...parent.child" assert join_arg_path("", "child") == "child" assert join_arg_path("", "...child") == "...child" def test_arg_map(): layer = Layer({"a.b.c": 0, "a.b.c.one": 1, "a.b.c.two": 2}, None) arg_map = ArgumentMap([layer]) arg_map.consolidate() assert arg_map.get_kv("a.b.c.one").key == "a.b.c.one" assert arg_map.get_kv("a.b.c.two").key == "a.b.c.two" assert arg_map.get_kv("a.b") == None assert arg_map.get_kv("a.b.c.zero") == None assert arg_map.get_kv("a.b.d") == None assert arg_map.subpaths("a.b.c") == ["", "one", "two"] assert arg_map.subpaths("a.b.c", strict=True) == ["one", "two"] assert arg_map.subpaths("a.b.c.one") == [""] assert arg_map.subpaths("a.b.c.one", strict=True) == [] layer = Layer({"prefix_suffix": 1}, None) arg_map = ArgumentMap([layer]) arg_map.consolidate() assert arg_map.subpaths("prefix") == [] assert arg_map.subpaths("prefix", strict=True) == [] 
layer = Layer({"": 1, "a": 2}, None) arg_map = ArgumentMap([layer]) arg_map.consolidate() assert arg_map.subpaths("") == ["", "a"] assert arg_map.subpaths("", strict=True) == ["a"] def test_arg_map_wildcard(): layer_wildcard = Layer({"a...c.one": 1, "a...c.two": 2}, None) arg_map = ArgumentMap([layer_wildcard]) arg_map.consolidate() assert arg_map.get_kv("a.b.c.one").key == "a...c.one" assert arg_map.get_kv("a.b.b.b.b.c.one").key == "a...c.one" assert arg_map.subpaths("a.b.c") == ["one", "two"] assert arg_map.subpaths("a.b.c.one") == [""] assert arg_map.subpaths("a.b.c.one", strict=True) == [] layer_wildcard = Layer({"...one": 1}, None) arg_map = ArgumentMap([layer_wildcard]) arg_map.consolidate() assert arg_map.subpaths("a.b.c.one") == [""] assert arg_map.subpaths("a.b.c.two") == [] assert arg_map.subpaths("a.b.c.one", strict=True) == [] layer_wildcard = Layer({"a...one...b...one": 1}, None) arg_map = ArgumentMap([layer_wildcard]) arg_map.consolidate() assert arg_map.subpaths("a.one.x.one") == ["...b...one"] layer_wildcard = Layer({"...prefix_suffix": 1}, None) arg_map = ArgumentMap([layer_wildcard]) arg_map.consolidate() assert arg_map.subpaths("something.prefix") == [] assert arg_map.subpaths("something.prefix", strict=True) == [] layer_wildcard = Layer({"...a.key.key": 1, "...a.key.key...x": 2}, None) arg_map = ArgumentMap([layer_wildcard]) arg_map.consolidate() assert arg_map.subpaths("a.key") == ["key...x", "key"] assert arg_map.subpaths("") == ["...a.key.key...x", "...a.key.key"] layer_wildcard = Layer({"...c": 1, "...b.c": 2}, None) arg_map = ArgumentMap([layer_wildcard]) arg_map.consolidate() assert arg_map.get_kv("a.b.c").key == "...b.c" assert arg_map.get_kv("a.c.c").key == "...c" wildcard_layer = Layer({"...bar.delta": "wildcard"}, "wild") qualified_layer = Layer({"foo.bar.alpha": "alpha", "foo.bar.delta": "qualified"}, "qual") arg_map = ArgumentMap([wildcard_layer, qualified_layer]) arg_map.consolidate() assert arg_map.get_kv("another.bar.delta").key 
== "...bar.delta" assert arg_map.get_kv("foo.bar.delta").key == "foo.bar.delta" arg_map = ArgumentMap([qualified_layer, wildcard_layer]) arg_map.consolidate() assert arg_map.get_kv("another.bar.delta").key == "...bar.delta" assert arg_map.get_kv("foo.bar.delta").key == "...bar.delta" def test_layer(): l = Layer({"...a": 0, "a": 1}, None) assert l.get_kv("a") == ("a", 1, None) l = Layer({"a": 1, "...a": 0}, None) assert l.get_kv("a") == ("a", 1, None) l = Layer({"...z": 1, "...x...y...z": 2, "...y...z": 3}, None) assert l.get_kv("x.y.z") == ("...x...y...z", 2, None) def test_collapse_layers(): class Dummy: pass from chz.blueprint._argv import _collapse_layers b = Blueprint(Dummy) b.apply({"a.b.c.d": 3, "a...d": 4, "...d": 5, "a...e": 7, "...e": 6, "a.b.c.d.f": 9}) b.apply({"...f": 10}) assert set(_collapse_layers(b)) == { ("a.b.c.d", 3), ("a...d", 4), ("...d", 5), ("a...e", 7), ("...e", 6), ("...f", 10), } b = Blueprint(Dummy) b.apply({"a.b.c.d": 3}) b.apply({"a...d": 4}) b.apply({"...d": 5}) b.apply({"a...e": 7}) b.apply({"...e": 6}) b.apply({"a.b.c.d.f": 9}) b.apply({"...f": 10}) assert set(_collapse_layers(b)) == {("...d", 5), ("...e", 6), ("...f", 10)} b = Blueprint(Dummy) b.apply({"...d": 5}) b.apply({"a.b.c.d": 3}) b.apply({"a...d": 4}) b.apply({"...f": 10}) b.apply({"a.b.c.d.f": 9}) assert set(_collapse_layers(b)) == {("...d", 5), ("a...d", 4), ("...f", 10), ("a.b.c.d.f", 9)} b = Blueprint(Dummy) b.apply({"...d": 5, "a.b.c.d": 3, "a...d": 4, "...f": 10, "a.b.c.d.f": 9}) assert set(_collapse_layers(b)) == { ("...d", 5), ("a.b.c.d", 3), ("a...d", 4), ("...f", 10), ("a.b.c.d.f", 9), } def test_collapse_blueprint_to_argv(): class Dummy: pass b = Blueprint(Dummy) b.apply({"...d": 5}) b.apply({"a.b.c.d": 3}) b.apply({"a...d": 4}) b.apply({"...f": 10}) b.apply({"a.b.c.d.f": 9}) b.apply({"a.b.c.d.f.e": None}) assert beta_blueprint_to_argv(b) == [ "...d=5", "a...d=4", "...f=10", "a.b.c.d.f=9", "a.b.c.d.f.e=None", ] def test_apply_from_argv(): class Dummy: pass from 
chz.blueprint._argv import _collapse_layers b = Blueprint(Dummy) b.apply_from_argv(["...d=5"]) assert beta_blueprint_to_argv(b) == ["...d=5"] assert _collapse_layers(b)[0][1].value == "5" def test_apply_with_types(): class Dummy: pass b = Blueprint(Dummy) b.apply({"a": 1, "b": beta_blueprint_to_argv, "c": Blueprint}) assert beta_blueprint_to_argv(b) == [ "a=1", "b=chz.blueprint._argv:beta_blueprint_to_argv", "c=chz.blueprint._blueprint:Blueprint", ] def test_castable_eq(): assert Castable("None") == Castable("None") assert Castable("None") == None assert Castable("1") == Castable("1") assert Castable("1") == 1 assert Castable("1") != 2 assert Castable("x") != 2 ================================================ FILE: tests/test_blueprint_variadic.py ================================================ import typing import pytest import chz from chz.blueprint import ( Castable, ConstructionException, ExtraneousBlueprintArg, InvalidBlueprintArg, MissingBlueprintArg, ) def test_variadic_list(): @chz.chz class X: a: int @chz.chz class MainList: xs: list[X] assert chz.Blueprint(MainList).apply({"xs.0.a": 1}).make() == MainList(xs=[X(a=1)]) assert chz.Blueprint(MainList).apply( {"xs.0.a": 1, "xs.1.a": 2, "xs.2.a": 3} ).make() == MainList(xs=[X(a=1), X(a=2), X(a=3)]) with pytest.raises( MissingBlueprintArg, match=r"Missing required arguments for parameter\(s\): xs" ): chz.Blueprint(MainList).make() with pytest.raises( MissingBlueprintArg, match=r"Missing required arguments for parameter\(s\): xs.1.a" ): chz.Blueprint(MainList).apply({"xs.0.a": 1, "xs.2.a": 3}).make() @chz.chz class MainListDefault: xs: list[X] = chz.field(default_factory=list) assert chz.Blueprint(MainListDefault).make() == MainListDefault(xs=[]) def test_variadic_wildcard(): @chz.chz class X: a: int b: int @chz.chz class MainList: xs: list[X] with pytest.raises(ExtraneousBlueprintArg, match=r"Extraneous argument '\.\.\.a'"): chz.Blueprint(MainList).apply({"...a": 1}).make() with 
pytest.raises(ExtraneousBlueprintArg, match=r"Extraneous argument '\.\.\.0.a'"): chz.Blueprint(MainList).apply({"...0.a": 1}).make() assert chz.Blueprint(MainList).apply({"xs.0.a": 0, "...0.b": 1}).make() == MainList( xs=[X(a=0, b=1)] ) assert chz.Blueprint(MainList).apply({"xs.0.a": 0, "...b": 1}).make() == MainList( xs=[X(a=0, b=1)] ) with pytest.raises(ExtraneousBlueprintArg, match=r"Extraneous argument '\.\.\.0.a'"): chz.Blueprint(MainList).apply({"...0.a": 0}).make() with pytest.raises(ExtraneousBlueprintArg, match=r"Extraneous argument '\.\.\.0'"): chz.Blueprint(MainList).apply({"...0": 0}).make() assert chz.Blueprint(MainList).apply({"...xs.0.a": 0, "...xs.0.b": 0}).make() == MainList( xs=[X(a=0, b=0)] ) with pytest.raises(ExtraneousBlueprintArg, match=r"Extraneous argument 'xs\.\.\.a'"): chz.Blueprint(MainList).apply({"xs...a": 5}).make() def test_variadic_tuple(): @chz.chz class X: a: int @chz.chz class MainHomoTuple: xs: tuple[X, ...] assert chz.Blueprint(MainHomoTuple).apply( {"xs.0.a": 1, "xs.1.a": 2, "xs.2.a": 3} ).make() == MainHomoTuple(xs=(X(a=1), X(a=2), X(a=3))) with pytest.raises( MissingBlueprintArg, match=r"Missing required arguments for parameter\(s\): xs" ): chz.Blueprint(MainHomoTuple).make() with pytest.raises( MissingBlueprintArg, match=r"Missing required arguments for parameter\(s\): xs.1.a" ): chz.Blueprint(MainHomoTuple).apply({"xs.0.a": 1, "xs.2.a": 3}).make() @chz.chz class Y: b: str @chz.chz class MainHeteroTuple: xs: tuple[X, Y, X] assert chz.Blueprint(MainHeteroTuple).apply( {"xs.0.a": 1, "xs.1.b": "str", "xs.2.a": 3} ).make() == MainHeteroTuple(xs=(X(a=1), Y(b="str"), X(a=3))) with pytest.raises( TypeError, match=r"Tuple type tuple\[.*X.*Y.*X\] for 'xs' must take 3 items; arguments for index 9 were specified", ): chz.Blueprint(MainHeteroTuple).apply({"xs.0.a": 1, "xs.9.b": "str"}).make() with pytest.raises(ExtraneousBlueprintArg, match=r"Extraneous argument 'xs.1.a'"): chz.Blueprint(MainHeteroTuple).apply({"xs.0.a": 1, "xs.1.a": 
2, "xs.2.a": 3}).make() def test_variadic_dict(): @chz.chz class X: a: int @chz.chz class MainDict: xs: dict[str, X] assert chz.Blueprint(MainDict).apply( {"xs.first.a": 1, "xs.second.a": 2, "xs.3.a": 3} ).make() == MainDict(xs={"first": X(a=1), "second": X(a=2), "3": X(a=3)}) def test_variadic_collections_type(): @chz.chz class X: a: int @chz.chz class Main: seq: typing.Sequence[X] map: typing.Mapping[str, X] assert chz.Blueprint(Main).apply( {"seq.0.a": 1, "seq.1.a": 2, "map.first.a": 3, "map.second.a": 4} ).make() == Main(seq=(X(a=1), X(a=2)), map={"first": X(a=3), "second": X(a=4)}) def test_variadic_dict_non_int_or_str_key(): @chz.chz class MainDict: xs: dict[float, str] with pytest.raises(TypeError, match="Variadic dict type must take str or int keys, not float"): chz.Blueprint(MainDict).apply({"xs.0": "a", "xs.1": "2"}).make() assert chz.Blueprint(MainDict).apply({"xs": {1: "2"}}).make() == MainDict(xs={1: "2"}) @chz.chz class MainDict2: xs: dict[int, str] | None = None assert chz.Blueprint(MainDict2).make() == MainDict2(xs=None) def test_variadic_dict_unannotated(): @chz.chz class MainDict: xs: dict assert chz.Blueprint(MainDict).apply({"xs.0": "a", "xs.first": 123}).make() == MainDict( xs={"0": "a", "first": 123} ) def test_variadic_typed_dict(): class Foo(typing.TypedDict): bar: int baz: str @chz.chz(typecheck=True) class Main: foo: Foo assert chz.Blueprint(Main).apply( {"foo.bar": Castable("3"), "foo.baz": Castable("43")} ).make() == Main(foo={"bar": 3, "baz": "43"}) with pytest.raises( ExtraneousBlueprintArg, match=r"Extraneous argument 'foo.typo' to Blueprint for .*Main" ): chz.Blueprint(Main).apply({"foo.bar": 3, "foo.typo": "baz"}).make() with pytest.raises(TypeError, match=r"Expected 'foo.bar' to be int, got str"): chz.Blueprint(Main).apply({"foo.bar": "bar", "foo.baz": "baz"}).make() with pytest.raises( InvalidBlueprintArg, match=( "Could not interpret argument 'bar' provided for param 'foo.bar'...\n\n" "- Failed to interpret it as a value:\n" 
"Could not cast 'bar' to int" ), ): chz.Blueprint(Main).apply({"foo.bar": Castable("bar"), "foo.baz": "baz"}).make() def test_variadic_typed_dict_not_required(): class Foo(typing.TypedDict): a: int b: typing.Required[int] c: typing.NotRequired[int] class Bar(Foo, total=False): d: int e: typing.Required[int] f: typing.NotRequired[int] class Baz(Bar): g: int h: typing.Required[int] i: typing.NotRequired[int] @chz.chz class Main: foo: Foo bar: Bar baz: Baz assert chz.Blueprint(Main).apply( { "foo.a": Castable("1"), "foo.b": Castable("2"), "foo.c": Castable("3"), "bar.a": Castable("4"), "bar.b": Castable("5"), "bar.c": Castable("6"), "bar.d": Castable("7"), "bar.e": Castable("8"), "bar.f": Castable("9"), "baz.a": Castable("10"), "baz.b": Castable("11"), "baz.c": Castable("12"), "baz.d": Castable("13"), "baz.e": Castable("14"), "baz.f": Castable("15"), "baz.g": Castable("16"), "baz.h": Castable("17"), "baz.i": Castable("18"), } ).make() == Main( foo={"a": 1, "b": 2, "c": 3}, bar={"a": 4, "b": 5, "c": 6, "d": 7, "e": 8, "f": 9}, baz={"a": 10, "b": 11, "c": 12, "d": 13, "e": 14, "f": 15, "g": 16, "h": 17, "i": 18}, ) # Test that c, d, f, i are not required assert chz.Blueprint(Main).apply( { "foo.a": Castable("1"), "foo.b": Castable("2"), "bar.a": Castable("3"), "bar.b": Castable("4"), "bar.e": Castable("5"), "baz.a": Castable("6"), "baz.b": Castable("7"), "baz.e": Castable("8"), "baz.g": Castable("9"), "baz.h": Castable("10"), } ).make() == Main( foo={"a": 1, "b": 2}, bar={"a": 3, "b": 4, "e": 5}, baz={"a": 6, "b": 7, "e": 8, "g": 9, "h": 10}, ) # Test that a, b, e, g, h are required with pytest.raises( MissingBlueprintArg, match=( r"Missing required arguments for parameter\(s\): " r"foo.a, foo.b, bar.a, bar.b, bar.e, baz.a, baz.b, baz.e, baz.g, baz.h" ), ): chz.Blueprint(Main).make() print(chz.Blueprint(Main).get_help()) assert ( chz.Blueprint(Main).get_help() == """WARNING: Missing required arguments for parameter(s): foo.a, foo.b, bar.a, bar.b, bar.e, baz.a, baz.b, 
baz.e, baz.g, baz.h Entry point: test_blueprint_variadic:test_variadic_typed_dict_not_required..Main Arguments: foo test_blueprint_variadic:test_variadic_typed_dict_not_required..Foo - foo.a int - foo.b int - foo.c int typing.NotRequired (default) bar test_blueprint_variadic:test_variadic_typed_dict_not_required..Bar - bar.a int - bar.b int - bar.c int typing.NotRequired (default) bar.d int typing.NotRequired (default) bar.e int - bar.f int typing.NotRequired (default) baz test_blueprint_variadic:test_variadic_typed_dict_not_required..Baz - baz.a int - baz.b int - baz.c int typing.NotRequired (default) baz.d int typing.NotRequired (default) baz.e int - baz.f int typing.NotRequired (default) baz.g int - baz.h int - baz.i int typing.NotRequired (default) """ ) def test_variadic_default(): @chz.chz class X: a: int = 0 @chz.chz class MainList: xs: list[X] assert chz.Blueprint(MainList).apply({"xs.3.a": 5}).make() == MainList( xs=[X(a=0), X(a=0), X(a=0), X(a=5)] ) def test_variadic_default_wildcard_error(): @chz.chz class X: a: int @chz.chz class MainList: xs: list[X] = chz.field(default_factory=lambda: [X(a=0)]) a: int # same name as X.a, to prevent unused wildcard error with pytest.raises( ConstructionException, match=( r'The parameter "xs" is variadic(.|\n)*' r'However, you also specified the wildcard "\.\.\.a" and you may ' r'have expected it to modify the value of "xs\.\(variadic\)\.a"' ), ): chz.Blueprint(MainList).apply({"...a": 1}).make() @chz.chz class MainListOk: xs: list[X] = chz.field(default_factory=list) a: int # same name as X.a, to prevent unused wildcard error assert chz.Blueprint(MainListOk).apply({"...a": 1}).make() == MainListOk(xs=[], a=1) def test_variadic_default_wildcard_error_using_types_from_default(): @chz.chz class Clause: def value(self) -> bool: raise NotImplementedError @chz.chz class SimpleClause(Clause): val: bool def value(self) -> bool: return self.val @chz.chz class FalseClause(SimpleClause): val: bool = False @chz.chz class 
AndClause(Clause): clauses: tuple[Clause, ...] = () def value(self) -> bool: return all(clause.value() for clause in self.clauses) @chz.chz class MyClause(AndClause): # Need to check both Clause and FalseClause for wildcard matches clauses: tuple[Clause, ...] = (FalseClause(), FalseClause()) with pytest.raises( ConstructionException, match=( r'The parameter "clauses.1.clauses" is variadic(.|\n)*' r'However, you also specified the wildcard "\.\.\.val" and you may ' r'have expected it to modify the value of "clauses.1.clauses\.\(variadic\)\.val"' ), ): chz.Blueprint(AndClause).apply_from_argv( ["clauses.0=SimpleClause", "clauses.1=MyClause", "...val=True"] ).make() def test_polymorphic_variadic_generic(): @chz.chz class A: a: int @chz.chz class AA(A): ... @chz.chz class MainList: xs: list[A] assert chz.Blueprint(MainList).apply( {"xs": Castable("list[AA]"), "xs.0.a": 1} ).make() == MainList(xs=[AA(a=1)]) @chz.chz class MainTuple: xs: tuple[A, ...] assert chz.Blueprint(MainTuple).apply( {"xs": Castable("tuple[AA, ...]"), "xs.0.a": 1} ).make() == MainTuple(xs=(AA(a=1),)) @chz.chz class MainListList: xs: list[list[A]] # This is getting a little silly :-) assert chz.Blueprint(MainListList).apply( {"xs": Castable("list[list[AA]]"), "xs.0.0.a": 1} ).make() == MainListList(xs=[[AA(a=1)]]) ================================================ FILE: tests/test_data_model.py ================================================ # ruff: noqa: F811 import dataclasses import functools import json import re import typing import pytest import chz def test_basic(): @chz.chz class X: a: int @chz.chz() class Y: a: int = 3 assert X(a=1).a == 1 assert Y().a == 3 assert chz.is_chz(X) assert chz.is_chz(X(a=1)) assert not chz.is_chz(1) with_future_annotation = """ from __future__ import annotations try: class _test: _: _test except NameError: import sys if sys.version_info < (3, 14): raise AssertionError("from __future__ import annotations should be imported") from None """ without_future_annotation
= """ try: class _test: _: _test except NameError: pass else: import sys if sys.version_info < (3, 14): raise AssertionError("from __future__ import annotations should not be imported") """ basic_definition = """ @chz.chz class X: a: int b: int = chz.field() c: str = "yikes" d: str = chz.field(default="yonks") e: str = chz.field(default_factory=lambda: "zeiks") """ def _test_construct_helper(X): with pytest.raises(TypeError, match="missing 2 required keyword-only arguments: 'a' and 'b'"): X() with pytest.raises(TypeError, match="missing 2 required keyword-only arguments: 'a' and 'b'"): X(1, 2) with pytest.raises(TypeError, match="missing 2 required keyword-only arguments: 'a' and 'b'"): X(c="okay") x = X(a=1, b=2) assert x.a == 1 assert x.b == 2 assert x.c == "yikes" assert x.d == "yonks" assert x.e == "zeiks" x = X(a=3, b=4, c="hijinks", d="iflunks", e="jourks") assert x.a == 3 assert x.b == 4 assert x.c == "hijinks" assert x.d == "iflunks" assert x.e == "jourks" def test_construct_without_future_annotations(): prog = without_future_annotation + basic_definition ns = {} exec(compile(prog, "", "exec", dont_inherit=True), {"chz": chz}, ns) X = ns["X"] _test_construct_helper(X) def test_construct_with_future_annotations(): prog = with_future_annotation + basic_definition ns = {} exec(compile(prog, "", "exec", dont_inherit=True), {"chz": chz}, ns) X = ns["X"] _test_construct_helper(X) def test_inheritance(): @chz.chz class X: a: int b: str c: str = chz.field(default="yikes") @chz.chz class Y(X): d: int e: str = chz.field(default="yonks") value = Y(a=1, b="2", d=3) assert value.a == 1 assert value.b == "2" assert value.c == "yikes" assert value.d == 3 assert value.e == "yonks" value = Y(a=1, b="2", d=3, e="4") assert value.a == 1 assert value.b == "2" assert value.c == "yikes" assert value.d == 3 assert value.e == "4" with pytest.raises(TypeError, match="missing 1 required keyword-only argument: 'd'"): Y(a=1, b="2") # type: ignore with pytest.raises( ValueError, 
match="Cannot override field 'c' with a non-field member; maybe you're missing a type annotation?", ): @chz.chz class Z(X): c = "asdf" class NonChz(X): pass with pytest.raises(TypeError, match="NonChz is not decorated with @chz.chz"): NonChz(a=1, b="2", c="3") def test_immutability(): @chz.chz class X: a: int x = X(a=1) with pytest.raises(chz.data_model.FrozenInstanceError): x.a = 2 # type: ignore with pytest.raises(chz.data_model.FrozenInstanceError): x.b = 1 # type: ignore @chz.chz class Y: a: int @functools.cached_property def b(self): return self.a @property def c(self): return self.a @c.setter def c(self, value): self.a = value # type: ignore @chz.init_property def d(self): return self.a y = Y(a=1) assert y.b == 1 with pytest.raises(chz.data_model.FrozenInstanceError): y.b = 2 # type: ignore assert y.c == 1 with pytest.raises(chz.data_model.FrozenInstanceError): y.c = 2 # type: ignore with pytest.raises(chz.data_model.FrozenInstanceError): del y.c # type: ignore assert y.d == 1 with pytest.raises(chz.data_model.FrozenInstanceError): y.d = 2 # type: ignore # Here's the loophole object.__setattr__(y, "a", 2) assert y.a == 2 assert y.b == 1 assert y.c == 2 def test_no_post_init(): with pytest.raises(ValueError, match="Cannot define __post_init__"): @chz.chz class X: a: int def __post_init__(self): pass def test_no_annotation(): @chz.chz class X: a = 1 X() with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'a'"): X(a=11) def test_asdict(): @chz.chz class Y: x: int y: bool @chz.chz class X: a: int b: str c: Y d: dict[str, bool] e: list[float] f: tuple[int, ...] 
x = X(a=1, b="2", c=Y(x=3, y=True), d={"a": True}, e=[1.0, 2.0], f=(1, 2)) assert chz.asdict(x) == { "a": 1, "b": "2", "c": {"x": 3, "y": True}, "d": {"a": True}, "e": [1.0, 2.0], "f": (1, 2), } def test_asdict_computed_properties(): @chz.chz class C: x: float @property def doubled(self): return self.x * 2 @functools.cached_property def tripled(self): return self.x * 3 @chz.init_property def quadrupled(self): return self.x * 4 c = C(x=1.0) assert chz.asdict(c) == {"x": 1.0} # Carefully test cached_property behavior. Cached properties which have # not been accessed are not in the __dict__... assert "tripled" not in c.__dict__ # ...accessing them adds them to the __dict__... _ = c.tripled assert "tripled" in c.__dict__ # ...but they still don't appear in the asdict output assert chz.asdict(c) == {"x": 1.0} def test_asdict_include_type(): @chz.chz class X: a: int b: int @chz.chz class Y: a: int b: int x = X(a=1, b=2) y = Y(a=1, b=2) assert chz.asdict(x) == chz.asdict(y) assert chz.asdict(x, include_type=True) != chz.asdict(y, include_type=True) @chz.chz class Outer: @chz.chz class Config: v: int def test_asdict_include_type_nested_class(): cfg = Outer.Config(v=1) assert chz.asdict(cfg, include_type=True)["__chz_type__"] == "test_data_model:Outer.Config" def test_asdict_exclude(): @chz.chz class Inner: x: int b: int @chz.chz class OuterLocal: a: int b: int inner: Inner obj = OuterLocal(a=1, b=2, inner=Inner(x=3, b=4)) assert chz.asdict(obj, exclude={"b"}) == {"a": 1, "inner": {"x": 3, "b": 4}} def test_replace(): @chz.chz class X: a: int b: int x = X(a=1, b=2) y = chz.replace(x, a=3) assert y is not x assert x.a == 1 assert x.b == 2 assert y.a == 3 assert y.b == 2 z = chz.replace(y, a=4, b=5) assert z is not x assert z is not y assert y.a == 3 assert y.b == 2 assert z.a == 4 assert z.b == 5 @chz.chz class Y: a: int X_e: int @functools.cached_property def b(self): return self.a @property def c(self): return self.a @c.setter def c(self, value): self.a = value # type: 
ignore @chz.init_property def d(self): return self.a y = Y(a=1, e=11) y = chz.replace(y, a=2) y = chz.replace(y, e=12) with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'b'"): chz.replace(y, b=1) with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'c'"): chz.replace(y, c=1) with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'd'"): chz.replace(y, d=1) with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'X_e'"): chz.replace(y, X_e=1) def test_repr(): @chz.chz class X: a: int b: int assert repr(X(a=1, b=2)) == "test_repr..X(a=1, b=2)" @chz.chz class Y: X_seed1: int @chz.init_property def seed1(self): return self.X_seed1 + 10 assert repr(Y(seed1=1)) == "test_repr..Y(seed1=1)" @chz.chz class Z: a: int = chz.field(repr=False) b: int = chz.field(repr=lambda x: f"?{x}?") assert repr(Z(a=1, b=2)) == "test_repr..Z(a=..., b=?2?)" def test_eq(): @chz.chz class X: a: int b: int x = X(a=1, b=2) y = X(a=1, b=2) z = X(a=1, b=3) assert x == y assert x != z assert x != 1 def test_hash(): @chz.chz class X: a: int b: int @chz.init_property def c(self): return self.b + 1 x = X(a=1, b=2) y = X(a=1, b=2) x2 = chz.replace(x, a=2) z = X(a=1, b=3) assert hash(x) == hash(y) assert hash(x) != hash(z) assert hash(x2) != hash(x) @chz.chz class Q: a: list[int] = chz.field(default_factory=lambda: [1, 2, 3]) q = Q() with pytest.raises(TypeError, match=re.escape("Cannot hash chz field: Q.a=[1, 2, 3]")): hash(q) @chz.chz class R: a: tuple[int, ...] = chz.field(default_factory=lambda: (1, 2, 3)) # Tuples are hashable. 
hash(R()) value = 0 @chz.chz class S: @chz.init_property def a(self): nonlocal value value += 1 return value # Since init property is not adding a __chz_fields__ # the instances of S result in the same hash value # despite not having the same a value s1 = S() s2 = S() assert hash(s1) == hash(s2) assert s1.a != s2.a @chz.chz class T: X_a: int @chz.init_property def a(self): return [self.X_a] with pytest.raises(TypeError): hash(T(a=1)) @chz.chz class U: X_a: list[int] @chz.init_property def a(self): return tuple(self.X_a) hash(U(a=[1, 2, 3])) def test_blueprint_values(): @chz.chz class Y: c: int @chz.init_property def b(self): return self.c + 1 @chz.chz class X: b: int c: Y = chz.field(default_factory=lambda: Y(c=1)) @chz.init_property def a(self): return self.c.b + self.b x = X(b=2) assert x.a == 4 # (c=1) + 1 + (b=2) assert x.c.b == 2 # (c=1) + 1 values = chz.beta_to_blueprint_values(x) assert values == {"b": 2, "c": Y, "c.c": 1} assert chz.Blueprint(X).apply(values).make() == x x = X(b=2, c=Y(c=3)) assert x.a == 6 # (c=3) + 1 + (b=2) assert x.c.b == 4 # (c=3) + 1 values = chz.beta_to_blueprint_values(x) assert values == {"b": 2, "c": Y, "c.c": 3} assert chz.Blueprint(X).apply(values).make() == x @chz.chz class Q: a: int b: tuple[int, ...] 
= chz.field(default_factory=lambda: (1, 2, 3)) q = Q(a=1) values = chz.beta_to_blueprint_values(q) assert values == {"a": 1, "b": (1, 2, 3)} assert chz.Blueprint(Q).apply(values).make() == q @chz.chz class R: a: int b: list[int] = chz.field(default_factory=lambda: [1, 2, 3]) r = R(a=1) # Ensure that castable + json mostly works values = chz.beta_to_blueprint_values(r) assert ( chz.Blueprint(R) .apply( {key: chz.Castable(str(value)) for key, value in json.loads(json.dumps(values)).items()} ) .make() == r ) # This test ensures that we handle the following as expected: # - default_factory # - munged values # - castable values @chz.chz class T: default_factory: str = chz.field(default_factory=lambda: "?") default_value: str = chz.field(default="!") munged_value: str = chz.field(default="?", munger=lambda instance, value: value + "!!") t = T(munged_value="Hello") values = chz.beta_to_blueprint_values(t) assert values == {"default_factory": "?", "default_value": "!", "munged_value": "Hello"} assert chz.Blueprint(T).apply(values).make() == t # This test is dedicated to testing that `x_type`/`blueprint_cast` works fine @chz.chz class U: value: int = chz.field(x_type=str, blueprint_cast=int) u = U(value="7") # The type of the blueprint should be the one we are supposed to pass in the blueprint # Not the one after instantiation values = chz.beta_to_blueprint_values(u) assert values == {"value": "7"} assert chz.Blueprint(U).apply(values).make() == u # This test verifies that derived properties aren't serialized and X_ fields are not exposed @chz.chz class W: value: int @chz.init_property def value_2(self) -> int: return self.value * 2 w = W(value=5) values = chz.beta_to_blueprint_values(w) assert values == {"value": 5} assert "value_2" not in values assert "value" in values assert w.value_2 == 10 def test_blueprint_values_polymorphic(): @chz.chz class X: a: int @property def name(self) -> str: return "x" @chz.chz class Y(X): b: int @property def name(self) -> str: return "y" 
@chz.chz class Z(X): c: int @property def name(self) -> str: return "z" @chz.chz class Y2(Y): d: int @property def name(self) -> str: return "y2" @chz.chz class W: x: X = chz.field(meta_factory=chz.factories.subclass(base_cls=X, default_cls=Y)) w: int w = W(x=Y(a=1, b=2), w=3) values = chz.beta_to_blueprint_values(w) assert values == {"x": Y, "x.a": 1, "x.b": 2, "w": 3} w_new = chz.Blueprint(W).apply(values).make() assert w_new == w assert w_new.x.name == "y" w = W(x=Z(a=1, c=2), w=3) values = chz.beta_to_blueprint_values(w) assert values == {"x": Z, "x.a": 1, "x.c": 2, "w": 3} w_new = chz.Blueprint(W).apply(values).make() assert w_new == w assert w_new.x.name == "z" w = W(x=Y2(a=1, b=2, d=3), w=4) values = chz.beta_to_blueprint_values(w) assert values == {"x": Y2, "x.a": 1, "x.b": 2, "x.d": 3, "w": 4} w_new = chz.Blueprint(W).apply(values).make() assert w_new == w assert w_new.x.name == "y2" @chz.chz class W_Union: x: Y | Z w: int wu = W_Union(x=Y(a=1, b=2), w=3) values = chz.beta_to_blueprint_values(wu) assert values == {"x": Y, "x.a": 1, "x.b": 2, "w": 3} wu_new = chz.Blueprint(W_Union).apply(values).make() assert wu_new == wu assert wu_new.x.name == "y" wu = W_Union(x=Z(a=1, c=2), w=3) values = chz.beta_to_blueprint_values(wu) assert values == {"x": Z, "x.a": 1, "x.c": 2, "w": 3} wu_new = chz.Blueprint(W_Union).apply(values).make() assert wu_new == wu assert wu_new.x.name == "z" def test_blueprint_values_variadic(): @chz.chz class A: a: int @chz.chz class B(A): b: int @chz.chz class Main: list_a: list[A] dict_a: dict[str, A] main = Main(list_a=[A(a=1), B(a=2, b=3)], dict_a={"a": A(a=4), "b": B(a=5, b=6)}) values = chz.beta_to_blueprint_values(main) assert values == { "list_a.0": A, "list_a.0.a": 1, "list_a.1": B, "list_a.1.a": 2, "list_a.1.b": 3, "dict_a.a": A, "dict_a.a.a": 4, "dict_a.b": B, "dict_a.b.a": 5, "dict_a.b.b": 6, } assert chz.Blueprint(Main).apply(values).make() == main @chz.chz class Main: list_a: list[A | int] dict_a: dict[str, A | int] main = 
Main(list_a=[A(a=1), 2], dict_a={"a": A(a=4), "b": 5}) values = chz.beta_to_blueprint_values(main) assert values == { "list_a.0": A, "list_a.0.a": 1, "list_a.1": 2, "dict_a.a": A, "dict_a.a.a": 4, "dict_a.b": 5, } assert chz.Blueprint(Main).apply(values).make() == main main = Main(list_a=[1, 2], dict_a={"a": 3, "b": 4}) values = chz.beta_to_blueprint_values(main) assert values == { "list_a": [1, 2], "dict_a": {"a": 3, "b": 4}, } assert chz.Blueprint(Main).apply(values).make() == main def test_blueprint_values_skip_defaults(): @chz.chz class Y: c: int d: int = 3 @chz.init_property def b(self): return self.c + 1 @chz.chz class X: b: int c: Y = chz.field(default_factory=lambda: Y(c=1)) e: int = 4 f: int = 5 @chz.init_property def a(self): return self.c.b + self.b x = X(b=2, e=4, f=6) values = chz.beta_to_blueprint_values(x, skip_defaults=True) assert values == {"b": 2, "c": Y, "c.c": 1, "f": 6} assert chz.Blueprint(X).apply(values).make() == x x = X(b=2, c=Y(c=3, d=4), f=5) values = chz.beta_to_blueprint_values(x, skip_defaults=True) assert values == {"b": 2, "c": Y, "c.c": 3, "c.d": 4} assert chz.Blueprint(X).apply(values).make() == x # Does not skip default values generated by default_factory @chz.chz class Q: a: int b: tuple[int, ...] = chz.field(default_factory=lambda: (1, 2, 3)) c: tuple[int, ...] = (4, 5, 6) q = Q(a=1) values = chz.beta_to_blueprint_values(q, skip_defaults=True) assert values == {"a": 1, "b": (1, 2, 3)} assert chz.Blueprint(Q).apply(values).make() == q def test_blueprint_values_unspecified_sequence(): @chz.chz class Element: v: int @chz.chz class TheClass: fixed_tuple: tuple[Element, Element] = chz.field(blueprint_unspecified=lambda: 1 / 0) var_tuple: tuple[Element, ...] 
= chz.field(blueprint_unspecified=lambda: 1 / 0) fixed_list: list[Element] = chz.field(blueprint_unspecified=lambda: 1 / 0) var_list: list[Element] = chz.field(blueprint_unspecified=lambda: 1 / 0) obj1 = TheClass( fixed_tuple=(Element(v=1), Element(v=2)), var_tuple=(Element(v=3), Element(v=4), Element(v=5)), fixed_list=[Element(v=6), Element(v=7)], var_list=[Element(v=8), Element(v=9), Element(v=10)], ) vals = chz.beta_to_blueprint_values(obj1) obj2 = chz.Blueprint(TheClass).apply(vals).make() assert obj1 == obj2 def test_duplicate_fields(): # There is no way to detect this since __annotations__ is a dictionary @chz.chz class X: a: int a: int # noqa: PIE794 X(a=1) def test_no_type_annotation_on_field(): with pytest.raises(TypeError, match="'a' has no type annotation"): @chz.chz class X: a = chz.field(default=0) def test_logical_name(): @chz.chz class X: X_seed1: int seed2: int @property def seed1(self): return self.X_seed1 + 100 assert len(X.__chz_fields__) == 2 assert X.__chz_fields__["seed1"].logical_name == "seed1" assert X.__chz_fields__["seed2"].logical_name == "seed2" x = X(seed1=1, seed2=2) assert x.seed1 == 101 assert x.X_seed1 == 1 assert x.seed2 == 2 assert x.X_seed2 == 2 def test_init_property(): value = 0 @chz.chz class A: @chz.init_property def a(self): nonlocal value value += 1 return value @chz.chz class B(A): @chz.init_property def b(self): nonlocal value value += 1 return value b1 = B() assert b1.a == 1 assert b1.b == 2 value = 10 b2 = B() assert b2.a == 11 assert b2.b == 12 assert b1.a == 1 assert b1.b == 2 @chz.chz class X: @chz.init_property def a(self): raise RuntimeError with pytest.raises(RuntimeError): X() @chz.chz class Y(X): pass with pytest.raises(RuntimeError): Y() def test_init_property_top_level(): # There isn't really anything special about init_property, you can just use it as a one-liner # if you don't care about static type checking. 
It's probably better to use a munger though # Note that in this case it is possible for init_property to get called more than once @chz.chz class A: a: int b = chz.init_property(lambda self: self.a + 1) a = A(a=1) assert a.__dict__ == {"X_a": 1, "a": 1, "b": 2} assert a.b == 2 with pytest.raises(ValueError, match="Field 'b' is clobbered by chz.data_model.init_property"): @chz.chz class B: a: int b: int = chz.init_property(lambda self: self.a + 1) # with type annotation def test_default_init_property(): @chz.chz class A: attr: int a = A(attr=1) assert a.__dict__ == {"X_attr": 1, "attr": 1} assert a.attr == 1 assert a.__dict__ == {"X_attr": 1, "attr": 1} def test_init_property_x_field(): @chz.chz class A: X_attr: int @chz.init_property def attr(self): return self.X_attr + 1 a = A(attr=1) assert a.X_attr == 1 assert a.attr == 2 assert a.__dict__ == {"X_attr": 1, "attr": 2} def test_conflicting_superclass_no_fields_in_base(): @chz.chz class BaseA: def method(self): return 1 @property def prop(self): return 1 @chz.init_property def init_prop(self): return 1 with pytest.raises( ValueError, match="Cannot define field 'method' because it conflicts with something defined on a superclass", ): @chz.chz class A1(BaseA): method: int with pytest.raises( ValueError, match="Cannot define field 'X_method' because it conflicts with something defined on a superclass", ): @chz.chz class A1X(BaseA): X_method: int with pytest.raises( ValueError, match="Cannot define field 'prop' because it conflicts with something defined on a superclass", ): @chz.chz class A2(BaseA): prop: int with pytest.raises( ValueError, match="Cannot define field 'X_prop' because it conflicts with something defined on a superclass", ): @chz.chz class A2X(BaseA): X_prop: int with pytest.raises( ValueError, match="Cannot define field 'init_prop' because it conflicts with something defined on a superclass", ): @chz.chz class A3(BaseA): init_prop: int # We could consider allowing this. 
In which case, you want: # assert A3(init_prop=2).X_init_prop == 2 # assert A3(init_prop=2).init_prop == 2 with pytest.raises( ValueError, match="Cannot define field 'X_init_prop' because it conflicts with something defined on a superclass", ): @chz.chz class A3X(BaseA): X_init_prop: int def test_conflicting_superclass_field_in_base(): @chz.chz class BaseB: field: int = 0 assert BaseB().X_field == 0 assert BaseB().field == 0 @chz.chz class B1(BaseB): X_field: int = 1 assert B1().X_field == 1 assert B1().field == 1 @chz.chz class B2(BaseB): X_field: int = 1 @chz.init_property def field(self): return self.X_field + 10 assert B2().X_field == 1 assert B2().field == 11 @chz.chz class B3(BaseB): @chz.init_property def field(self): return self.X_field + 100 assert B3().X_field == 0 assert B3().field == 100 @chz.chz class B4(BaseB): field: int = 1 assert B4().X_field == 1 assert B4().field == 1 def test_conflicting_superclass_x_field_in_base(): @chz.chz class BaseC: X_field: int = 0 @chz.init_property def field(self): return self.X_field + 10 assert BaseC().X_field == 0 assert BaseC().field == 10 @chz.chz class C1(BaseC): X_field: int = 1 assert C1().X_field == 1 assert C1().field == 11 @chz.chz class C2(BaseC): @chz.init_property def field(self): return self.X_field + 100 assert C2().X_field == 0 assert C2().field == 100 with pytest.raises(ValueError, match="little unsure of what the semantics should be here"): @chz.chz class C3(BaseC): field: int = 1 # assert C3().X_field == 1 # Should this be 11 or 1? # The argument for 11 is that it's exactly the same case as C1 # The argument for 1 is that it matches go-to-definition better # I'm in favour of 11, but I'll stick with a custom error for now... 
# assert C3().field == 11 def test_field_clobbering_in_same_class(): with pytest.raises(ValueError, match="Field 'a' is clobbered by chz.data_model.init_property"): @chz.chz class X: a: int = 1 @chz.init_property def a(self): return 1 with pytest.raises(ValueError, match="Field 'a' is clobbered by function"): @chz.chz class Y: a: int = 1 def a(self): return 1 @chz.chz class OK1: # lambdas are special cased (since they're more likely to be default values than methods) a: typing.Callable[[], int] = lambda: 1 @chz.chz class OK2: a: typing.Callable[[], None] = test_field_clobbering_in_same_class def test_dataclass_errors(): with pytest.raises( RuntimeError, match=r"Something has gone horribly awry; are you using a chz.Field in a dataclass\?", ): @dataclasses.dataclass class X: a: int = chz.field(default=1) def test_cloudpickle_main(): import cloudpickle # noqa: F401 main = """ import chz from threading import Lock unpickleable = Lock() class Normal: def __repr__(self): return "normal" assert __name__ == "__main__" @chz.chz class X: one: int norm: Normal = chz.field(default_factory=lambda: Normal()) import base64 import cloudpickle print(base64.b64encode(cloudpickle.dumps(X(one=1))).decode()) """ import base64 import pickle import subprocess import sys pickled = subprocess.check_output([sys.executable, "-c", main]) try: unpickled = pickle.loads(base64.b64decode(pickled)) except pickle.UnpicklingError as e: e.add_note("Maybe you forgot to remove a print statement?") raise assert unpickled.one == 1 assert repr(unpickled) == "X(one=1, norm=normal)" def test_protocol(): with pytest.raises(TypeError, match="chz class cannot itself be a Protocol"): @chz.chz class Disallowed(typing.Protocol): a: int class Proto(typing.Protocol): a: int @chz.chz class Allowed(Proto): pass # Protocol fields do not become chz fields automatically Allowed() def test_abc(): import abc @chz.chz class IDontMakeAnyPromisesAboutBehaviourHere(abc.ABC): a: int class Abc(abc.ABC): a: int @chz.chz class 
Allowed(Abc): pass # ABC fields do not become chz fields automatically Allowed() def test_pretty_format(): from chz.data_model import pretty_format @chz.chz class Child: name: str age: int @chz.chz class Parent: name: str age: int X_nickname: str | None = None child: Child = chz.field(default_factory=lambda: Child(name="bob", age=1)) @chz.init_property def nickname(self) -> str: return self.X_nickname or self.name obj = Parent(name="alice", age=30) expected = f"""{Parent.__qualname__}( age=30, name='alice', # Fields where pre-init value matches default: child={Child.__qualname__}( age=1, name='bob', ), nickname=None # 'alice' (after init), )""" assert pretty_format(obj, colored=False) == expected @chz.chz class Collection: children: list[Child] named_children: dict[str, Child] obj = Collection( children=[Child(name="alice", age=1)], named_children={"bob": Child(name="bob", age=2)}, ) assert ( pretty_format(obj, colored=False) == """test_pretty_format.<locals>.Collection( children=[ test_pretty_format.<locals>.Child( age=1, name='alice', ), ], named_children={ 'bob': test_pretty_format.<locals>.Child( age=2, name='bob', ), }, )""" ) def test_metadata(): @chz.chz class X: a: int = chz.field(metadata={"foo": "bar"}) assert X.__chz_fields__["a"].metadata == {"foo": "bar"} def test_traverse(): @chz.chz class A: a_value: int = 15 @chz.chz class B: a: A = chz.field(default_factory=A) b_value: str = "hi" @chz.chz class C: ba: tuple[A | B, ...] = (A(), B()) c_value: tuple[str, ...]
= ("hello", "world") assert list(chz.traverse(C())) == [ ("", C(ba=(A(a_value=15), B(a=A(a_value=15), b_value="hi")), c_value=("hello", "world"))), ("ba", (A(a_value=15), B(a=A(a_value=15), b_value="hi"))), ("ba.0", A(a_value=15)), ("ba.0.a_value", 15), ("ba.1", B(a=A(a_value=15), b_value="hi")), ("ba.1.a", A(a_value=15)), ("ba.1.a.a_value", 15), ("ba.1.b_value", "hi"), ("c_value", ("hello", "world")), ("c_value.0", "hello"), ("c_value.1", "world"), ] def test_int_dict_keys(): @chz.chz class A: int_keyed: dict[int, str] str_keyed: dict[str, str] a = chz.Blueprint(A).make_from_argv( [ "int_keyed.1=one", "int_keyed.2=two", "str_keyed.a=ay", "str_keyed.b=bee", ] ) assert a.int_keyed == {1: "one", 2: "two"} assert a.str_keyed == {"a": "ay", "b": "bee"} ================================================ FILE: tests/test_factories.py ================================================ """ Watch out for some of the extra parentheses in these tests. """ import typing import pytest from chz.factories import MetaFromString, standard class A: ... class B(A): ... B_alias = B class C(B): ... class X: ...
def foo(): return A() bar = 0 a = A() def test_standard_subclass(): f = standard(annotation=A) assert f.unspecified_factory() is A assert f.from_string("A") is A assert f.from_string("B") is B assert f.from_string("C") is C with pytest.raises(MetaFromString, match="No subclass of test_factories:A named 'X'"): f.from_string("X") with pytest.raises( MetaFromString, match="Expected test_factories:X from 'test_factories:X' to be a subtype of test_factories:A", ): f.from_string(f"{__name__}:X") assert f.from_string(f"{__name__}:A") is A assert f.from_string(f"{__name__}:B") is B assert f.from_string(f"{__name__}:C") is C assert f.from_string(f"{__name__}.A") is A assert f.from_string(f"{__name__}.B") is B assert f.from_string(f"{__name__}.C") is C def test_standard_subclass_unspecified(): f = standard(annotation=A, unspecified=B) assert f.unspecified_factory() is B assert f.from_string("A") is A assert f.from_string("B") is B assert f.from_string("C") is C with pytest.raises(MetaFromString, match="No subclass of test_factories:A named 'X'"): f.from_string("X") assert f.from_string(f"{__name__}:A") is A assert f.from_string(f"{__name__}:B") is B assert f.from_string(f"{__name__}:C") is C def test_standard_subclass_module(): f = standard(annotation=A) with pytest.raises(MetaFromString, match="No subclass of test_factories:A named 'a'"): f.from_string("a") with pytest.raises(MetaFromString, match="No subclass of test_factories:A named 'foo'"): f.from_string("foo") with pytest.raises(MetaFromString, match="No subclass of test_factories:A named 'bar'"): f.from_string("bar") with pytest.raises(MetaFromString, match="No subclass of test_factories:A named 'X'"): f.from_string("X") assert f.from_string(f"{__name__}:a")() is a assert f.from_string(f"{__name__}:foo") is foo with pytest.raises(MetaFromString, match="Expected 0 from 'test_factories:bar' to be callable"): f.from_string(f"{__name__}:bar") f = standard(annotation=A, default_module=__name__) assert f.from_string("a")() 
is a assert f.from_string("foo") is foo with pytest.raises(MetaFromString, match="No subclass of test_factories:A named 'bar'"): assert f.from_string("bar") is foo with pytest.raises(MetaFromString, match="No subclass of test_factories:A named 'X'"): f.from_string("X") assert f.from_string(f"{__name__}:a")() is a assert f.from_string(f"{__name__}:foo") is foo with pytest.raises(MetaFromString, match="Expected 0 from 'test_factories:bar' to be callable"): f.from_string(f"{__name__}:bar") with pytest.raises( MetaFromString, match="Expected test_factories:X from 'test_factories:X' to be a subtype of test_factories:A", ): f.from_string(f"{__name__}:X") def test_standard_subclass_object_any(): import collections.abc for any_object in (object, typing.Any): f = standard(annotation=any_object) with pytest.raises(MetaFromString, match="Could not find 'a', try a fully qualified name"): f.from_string("a") with pytest.raises( MetaFromString, match="Could not find 'foo', try a fully qualified name" ): f.from_string("foo") assert f.from_string(f"{__name__}:a")() is a assert f.from_string(f"{__name__}:foo") is foo assert f.from_string(f"{__name__}:bar")() is bar f = standard(annotation=any_object, default_module=__name__) assert f.from_string("a")() is a assert f.from_string("foo") is foo assert f.from_string("bar")() is bar assert f.from_string("collections.abc.MutableSequence") is collections.abc.MutableSequence f = standard(annotation=any_object, unspecified=type[object]) assert f.unspecified_factory() != type[object] assert f.unspecified_factory()() is object f = standard(annotation=any_object, unspecified=type) assert f.unspecified_factory() is type def test_standard_type_subclass(): f = standard(annotation=type[A]) assert f.unspecified_factory()() is A assert f.from_string("A")() is A assert f.from_string("B")() is B assert f.from_string("C")() is C with pytest.raises(MetaFromString, match="No subclass of test_factories:A named 'X'"): f.from_string("X") assert 
f.from_string(f"{__name__}:A")() is A assert f.from_string(f"{__name__}:B")() is B assert f.from_string(f"{__name__}:C")() is C def test_standard_type_subclass_unspecified(): f = standard(annotation=type[A], unspecified=type[B]) assert f.unspecified_factory()() is B assert f.from_string("A")() is A assert f.from_string("B")() is B assert f.from_string("C")() is C with pytest.raises(MetaFromString, match="No subclass of test_factories:A named 'X'"): f.from_string("X") assert f.from_string(f"{__name__}:A")() is A assert f.from_string(f"{__name__}:B")() is B assert f.from_string(f"{__name__}:C")() is C def test_standard_type_subclass_module(): f = standard(annotation=type[A]) with pytest.raises(MetaFromString, match="No subclass of test_factories:A named 'B_alias'"): f.from_string("B_alias") assert f.from_string(f"{__name__}:B_alias")() is B f = standard(annotation=type[A], default_module=__name__) assert f.from_string("B_alias")() is B assert f.from_string(f"{__name__}:B_alias")() is B def test_standard_union(): f = standard(annotation=A | X) assert f.unspecified_factory() is None assert f.from_string("A") is A assert f.from_string("B") is B assert f.from_string("C") is C assert f.from_string("X") is X with pytest.raises(MetaFromString, match="Could not produce a union instance from 'object'"): f.from_string("object") assert f.from_string(f"{__name__}:A") is A assert f.from_string(f"{__name__}:B") is B assert f.from_string(f"{__name__}:C") is C assert f.from_string(f"{__name__}:X") is X def test_standard_union_unspecified(): f = standard(annotation=A | X, unspecified=B) assert f.unspecified_factory() is B assert f.from_string("A") is A assert f.from_string("B") is B assert f.from_string("C") is C assert f.from_string("X") is X with pytest.raises(MetaFromString, match="Could not produce a union instance from 'object'"): f.from_string("object") assert f.from_string(f"{__name__}:A") is A assert f.from_string(f"{__name__}:B") is B assert f.from_string(f"{__name__}:C") is 
C assert f.from_string(f"{__name__}:X") is X def test_standard_union_optional(): f = standard(annotation=A | None) assert f.unspecified_factory() is A assert f.from_string("A") is A assert f.from_string("None")() is None with pytest.raises(MetaFromString, match="Could not produce a union instance from 'object'"): f.from_string("object") f = standard(annotation=int | None) assert f.perform_cast("123") == 123 assert f.perform_cast("None") is None def test_standard_union_module(): f = standard(annotation=A | X) with pytest.raises(MetaFromString, match="Could not produce a union instance from 'a'"): f.from_string("a") with pytest.raises(MetaFromString, match="Could not produce a union instance from 'foo'"): f.from_string("foo") assert f.from_string(f"{__name__}:a")() is a assert f.from_string(f"{__name__}:foo") is foo f = standard(annotation=A, default_module=__name__) assert f.from_string("a")() is a assert f.from_string("foo") is foo assert f.from_string(f"{__name__}:a")() is a assert f.from_string(f"{__name__}:foo") is foo def test_standard_union_type(): f = standard(annotation=type[A] | type[X]) assert f.unspecified_factory() == None f = standard(annotation=type[A | X]) assert f.unspecified_factory() == None f = standard(annotation=type[A] | type[X], unspecified=type[B]) assert f.unspecified_factory() != type[B] assert f.unspecified_factory()() is B f = standard(annotation=type[A | X], unspecified=type[B]) assert f.unspecified_factory() != type[B] assert f.unspecified_factory()() is B def test_standard_type_generic(): f = standard(annotation=type[list[int]]) assert f.unspecified_factory() is not list assert f.unspecified_factory() != list[int] assert f.unspecified_factory()() == list[int] def test_standard_lambda(): f = standard(annotation=int) assert f.from_string("lambda: 123")() == 123 assert f.from_string("lambda x, y: x + y")(1, 2) == 3 def test_standard_none(): f = standard(annotation=None) assert f.unspecified_factory()() is None assert 
f.from_string("None")() is None def test_standard_special_forms(): f = standard(annotation=typing.Literal["foo", "bar"]) assert f.unspecified_factory() is None with pytest.raises( MetaFromString, match=r"Could not produce a Literal\['foo', 'bar'\] instance from 'foo'" ): f.from_string("foo") def test_standard_subclass_duplicate(): class Parent: ... ca = type("Child", (Parent,), {}) # noqa: F841 cb = type("Child", (Parent,), {}) # noqa: F841 f = standard(annotation=Parent) with pytest.raises( MetaFromString, match=r"Multiple subclasses of .*Parent named 'Child': .*Child, .*Child" ): f.from_string("Child") ================================================ FILE: tests/test_munge.py ================================================ from typing import Any, Callable, TypedDict, TypeVar import pytest import chz from chz.mungers import attr_if_none, if_none def test_munger(): @chz.chz class A: a: int = chz.field(munger=lambda s, v: s.b) b: int = chz.field(munger=lambda s, v: s.c + 10) c: int = chz.field(munger=lambda s, v: s.X_a + 100) x = A(a=1, b=47, c=94) assert x.X_a == 1 assert x.a == 111 assert x.X_b == 47 assert x.b == 111 assert x.X_c == 94 assert x.c == 101 def test_munger_call_count(): count = 0 def munger(s, v): nonlocal count count += 1 return v * 2 @chz.chz class A: a: int = chz.field(munger=munger) a = A(a=18) assert a.a == 36 assert count == 1 assert a.a == 36 assert count == 1 assert a.a == 36 assert count == 1 def test_munger_conflict(): with pytest.raises( ValueError, match="Cannot define 'a' in class when the associated field has a munger" ): @chz.chz class A: X_a: int = chz.field(munger=lambda s, v: s.X_a) @chz.init_property def a(self): return 1 def test_munge_recursive(): @chz.chz class A: # Since we pass the value, there is little need to do this # And if there is need, it's still better to do s.X_a # We could make this a different kind of error other than RecursionError, but seems fine a: int = chz.field(munger=lambda s, v: s.a) with 
pytest.raises(RecursionError): A(a=1) def test_munger_combinators(): @chz.chz class A: a: int = chz.field(munger=if_none(lambda self: self.c)) b: int = chz.field(munger=attr_if_none("a")) c: int = 42 a = A(a=None, b=None) assert a.a == 42 assert a.b == 42 a = A(a=1, b=None) assert a.a == 1 assert a.b == 1 a = A(a=None, b=2) assert a.a == 42 assert a.b == 2 a = A(a=1, b=2) assert a.a == 1 assert a.b == 2 def test_munger_x_type(): @chz.chz(typecheck=True) class A: a: int = chz.field(munger=lambda s, v: int(v + "0") + 1, x_type=str) a = A(a="123") assert a.X_a == "123" assert a.a == 1231 a = chz.Blueprint(A).apply({"a": chz.blueprint.Castable("456")}).make() assert a.X_a == "456" assert a.a == 4561 @chz.chz(typecheck=True) class B: b: int = chz.field(munger=lambda s, v: v, x_type=str) with pytest.raises(TypeError, match="Expected X_b to be str, got int"): B(b=0) B(b="0") # TODO: this could raise @chz.chz(typecheck=True) class C: X_c: int @chz.init_property def c(self) -> str: return str(self.X_c) assert C(c=0).c == "0" def test_munger_freeze_dict(): class MyDict(TypedDict): a: int b: int @chz.chz class A: d: dict[str, int] = chz.field(munger=chz.mungers.freeze_dict()) d2: MyDict = chz.field(munger=chz.mungers.freeze_dict()) x = A(d={"a": 1, "b": 2}, d2=MyDict(a=1, b=2)) hash(x) def test_converter(): @chz.chz class A: a: int = chz.field(converter=if_none(lambda self: self.c)) c: int = 42 a = A(a=None) assert a.a == 42 b = A(a=3) assert b.a == 3 def test_converter_and_munger(): with pytest.raises(ValueError, match="Cannot specify both converter and munger"): @chz.chz class A: a: int = chz.field( converter=if_none(lambda self: self.c), munger=if_none(lambda self: self.c) ) c: int = 42 def test_converter_fn(): @chz.chz class A: a: int = chz.field(converter=lambda v, **kwargs: v or 10) c: int = 42 a = A(a=None) assert a.a == 10 _T = TypeVar("_T") def if_none_fn(default: _T) -> Callable[[_T | None], _T]: def convert(value: _T | None, *, chzself: Any = None) -> _T: return 
value or default return convert def test_converter_fn_typed(): @chz.chz class A: a: int = chz.field(converter=if_none_fn(10)) c: int = 42 a = A(a=None) assert a.a == 10 def test_converter_freeze_dict(): from frozendict import frozendict class MyDict(TypedDict): a: int b: int @chz.chz class A: d: frozendict[str, int] = chz.field(converter=chz.mungers.freeze_dict()) d2: MyDict = chz.field(converter=chz.mungers.freeze_dict()) # type: ignore x = A(d={"a": 1, "b": 2}, d2=MyDict(a=1, b=2)) hash(x) ================================================ FILE: tests/test_tiepin.py ================================================ # ruff: noqa: UP006 # ruff: noqa: UP007 # ruff: noqa: UP045 import collections.abc import enum import fractions import pathlib import re import sys import typing import pytest import typing_extensions from chz.tiepin import ( CastError, _simplistic_try_cast, _simplistic_type_of_value, approx_type_hash, is_subtype, is_subtype_instance, type_repr, ) def test_type_repr(): assert type_repr(int) == "int" assert type_repr(list[int]) == "list[int]" class X: ... assert type_repr(X) == "test_tiepin:test_type_repr.<locals>.X" assert type_repr(typing.Literal["asdf"]) == "Literal['asdf']" assert type_repr(typing.Union[int, str]) in ("Union[int, str]", "int | str") assert ( type_repr(typing.Callable[[int], X]) == "Callable[[int], test_tiepin.test_type_repr.<locals>.X]" ) assert type_repr(typing.Sequence[str]) == "Sequence[str]" assert type_repr(collections.abc.Sequence[str]) == "Sequence[str]" def test_is_subtype_instance_basic(): assert is_subtype_instance(None, None) assert is_subtype_instance(1, int) assert is_subtype_instance(True, bool) assert not is_subtype_instance("str", int) assert not is_subtype_instance(1, str) assert not is_subtype_instance(int, int) assert not is_subtype_instance(int, float) class A: ... class B(A): ...
assert is_subtype_instance(A(), A) assert is_subtype_instance(B(), B) assert is_subtype_instance(B(), A) assert not is_subtype_instance(A(), B) assert not is_subtype_instance(A, A) assert not is_subtype_instance(B, A) assert not is_subtype_instance(A, B) assert not is_subtype_instance(B, B) def test_is_subtype_instance_user_defined_generic(): T = typing.TypeVar("T") class X(typing.Generic[T]): ... assert is_subtype_instance(X(), X) assert is_subtype_instance(X(), X[int]) assert is_subtype_instance(X(), X[str]) # Note that we do use __orig_class__ where possible assert is_subtype_instance(X[int](), X[int]) assert not is_subtype_instance(X[str](), X[int]) def test_is_subtype_instance_user_defined_generic_abc(): T = typing.TypeVar("T") class X(typing.Generic[T]): def unrelated(self) -> T: ... def __iter__(self): return iter([1, 2, 3]) assert is_subtype_instance(X(), X) assert is_subtype_instance(X(), X[int]) assert is_subtype_instance(X(), X[str]) assert is_subtype_instance(X(), collections.abc.Iterable) assert is_subtype_instance(X(), collections.abc.Iterable[int]) assert not is_subtype_instance(X(), collections.abc.Iterable[str]) class Y(typing.Generic[T]): def __call__(self, x: int) -> str: ... 
assert is_subtype_instance(Y(), Y) assert is_subtype_instance(Y(), Y[int]) assert is_subtype_instance(Y(), Y[str]) assert is_subtype_instance(Y(), Y[bytes]) assert is_subtype_instance(Y(), typing.Callable) assert is_subtype_instance(Y(), typing.Callable[[int], str]) assert not is_subtype_instance(Y(), typing.Callable[[int], int]) assert not is_subtype_instance(Y(), typing.Callable[[str], int]) assert not is_subtype_instance(Y(), typing.Callable[[str], str]) def test_is_subtype_instance_duck_type(): assert is_subtype_instance(1, int) assert is_subtype_instance(1, float) assert is_subtype_instance(1, complex) assert not is_subtype_instance(1.0, int) assert is_subtype_instance(1.0, float) assert is_subtype_instance(1.0, complex) assert not is_subtype_instance(1 + 1j, int) assert not is_subtype_instance(1 + 1j, float) assert is_subtype_instance(1 + 1j, complex) assert is_subtype_instance(bytearray(), bytes) assert is_subtype_instance(memoryview(b""), bytes) assert not is_subtype_instance(bytes(), bytearray) def test_is_subtype_instance_any(): assert is_subtype_instance(1, typing.Any) assert is_subtype_instance("str", typing.Any) assert is_subtype_instance([], typing.Any) assert is_subtype_instance(int, typing.Any) if sys.version_info >= (3, 11): class Mock(typing.Any): ... 
assert is_subtype_instance(Mock(), typing.Any) assert is_subtype_instance(Mock(), int) assert is_subtype_instance(Mock(), str) assert not is_subtype_instance(1, Mock) def test_is_subtype_instance_list_abc(): T = typing.TypeVar("T") seq: typing.Any for seq in ( list, typing.List, collections.abc.Sequence, collections.abc.Iterable, typing.Iterable, collections.abc.Collection, list[T], typing.List[T], collections.abc.Sequence[T], collections.abc.Iterable[T], typing.Iterable[T], collections.abc.Collection[T], ): assert is_subtype_instance([], seq) assert is_subtype_instance([], seq[int]) assert is_subtype_instance([], seq[object]) assert is_subtype_instance([], seq[typing.Any]) assert is_subtype_instance([1, 2], seq) assert is_subtype_instance([1, 2], seq[int]) assert is_subtype_instance([1, 2], seq[object]) assert is_subtype_instance([1, 2], seq[typing.Any]) assert not is_subtype_instance([1, 2], seq[str]) assert is_subtype_instance([1, 2, "3"], seq) assert is_subtype_instance([1, 2, "3"], seq[object]) assert is_subtype_instance([1, 2, "3"], seq[typing.Any]) assert not is_subtype_instance([1, 2, "3"], seq[int]) def test_is_subtype_instance_iterable(): assert is_subtype_instance([], typing.Iterable) assert is_subtype_instance([1, 2, 3], typing.Iterable[int]) assert not is_subtype_instance([1, 2, "3"], typing.Iterable[int]) assert is_subtype_instance((), typing.Iterable) assert is_subtype_instance((1, 2, 3), typing.Iterable[int]) assert not is_subtype_instance((1, 2, "3"), typing.Iterable[int]) assert is_subtype_instance({}, typing.Iterable) assert is_subtype_instance({1: 2, 3: 4}, typing.Iterable[int]) assert not is_subtype_instance({1: 2, "3": 4}, typing.Iterable[int]) assert is_subtype_instance(set(), typing.Iterable) assert is_subtype_instance({1, 2, 3}, typing.Iterable[int]) assert not is_subtype_instance({1, 2, "3"}, typing.Iterable[int]) assert is_subtype_instance(frozenset(), typing.Iterable) assert is_subtype_instance(frozenset({1, 2, 3}), typing.Iterable[int]) 
assert not is_subtype_instance(frozenset({1, 2, "3"}), typing.Iterable[int]) assert is_subtype_instance("", typing.Iterable) assert is_subtype_instance("123", typing.Iterable[str]) assert not is_subtype_instance("123", typing.Iterable[int]) assert is_subtype_instance(b"", typing.Iterable) assert is_subtype_instance(b"123", typing.Iterable[int]) assert not is_subtype_instance(b"123", typing.Iterable[str]) assert is_subtype_instance(bytearray(), typing.Iterable) assert is_subtype_instance(bytearray(b"123"), typing.Iterable[int]) assert not is_subtype_instance(bytearray(b"123"), typing.Iterable[str]) assert is_subtype_instance(memoryview(b""), typing.Iterable) assert is_subtype_instance(memoryview(b"123"), typing.Iterable[int]) assert not is_subtype_instance(memoryview(b"123"), typing.Iterable[str]) assert is_subtype_instance(range(0), typing.Iterable) assert is_subtype_instance(range(3), typing.Iterable[int]) assert not is_subtype_instance(range(3), typing.Iterable[str]) assert is_subtype_instance(map(int, [1, 2, 3]), typing.Iterable) assert is_subtype_instance(map(int, [1, 2, 3]), typing.Iterable[int]) assert not is_subtype_instance(map(str, [1, 2, 3]), typing.Iterable[int]) def test_is_subtype_instance_tuple(): assert is_subtype_instance((1, 2), tuple) assert is_subtype_instance((1, 2), tuple[int, int]) assert is_subtype_instance((1, 2), tuple[int, ...]) assert is_subtype_instance((1, 2), typing.Tuple) assert is_subtype_instance((1, 2), typing.Tuple[int, int]) assert is_subtype_instance((1, 2), typing.Tuple[int, ...]) assert not is_subtype_instance((1, 2), tuple[int, int, int]) assert not is_subtype_instance((1, 2), tuple[int, str]) assert not is_subtype_instance((1, 2), typing.Tuple[int, int, int]) assert not is_subtype_instance((1, 2), typing.Tuple[int, str]) assert is_subtype_instance((1, "str", "bytes"), tuple) assert is_subtype_instance((1, "str", "bytes"), tuple[int, str, str]) assert is_subtype_instance((1, "str", "bytes"), typing.Tuple) assert 
is_subtype_instance((1, "str", "bytes"), typing.Tuple[int, str, str]) assert not is_subtype_instance((1, "str", "bytes"), tuple[int, ...]) def test_is_subtype_instance_mapping(): K = typing.TypeVar("K") V = typing.TypeVar("V") map: typing.Any for map in ( dict, typing.Dict, collections.abc.MutableMapping, collections.abc.Mapping, dict[K, V], typing.Dict[K, V], collections.abc.MutableMapping[K, V], collections.abc.Mapping[K, V], ): assert is_subtype_instance({}, map) assert is_subtype_instance({}, map[str, int]) assert is_subtype_instance({"a": 1}, map) assert is_subtype_instance({"a": 1}, map[str, int]) assert is_subtype_instance({"a": 1}, map[str, typing.Any]) assert is_subtype_instance({"a": 1}, map[typing.Any, typing.Any]) assert not is_subtype_instance({"a": 1}, map[int, int]) assert not is_subtype_instance({"a": 1}, map[int, typing.Any]) def test_is_subtype_instance_typed_dict(): for t_TypedDict in (typing.TypedDict, typing_extensions.TypedDict): class A(t_TypedDict): a: int b: str class B(A): c: bytes assert not is_subtype_instance({}, A) assert not is_subtype_instance({"a": 1}, A) assert is_subtype_instance({"a": 1, "b": "str"}, A) assert not is_subtype_instance({"a": 1, "b": "str"}, B) assert is_subtype_instance({"a": 1, "b": "str", "c": b"bytes"}, A) assert is_subtype_instance({"a": 1, "b": "str", "c": b"bytes"}, B) assert not is_subtype_instance({"a": 1, "b": 1}, A) assert not is_subtype_instance({"c": b"bytes"}, A) assert not is_subtype_instance({"c": b"bytes"}, B) def test_is_subtype_instance_typed_dict_required(): for t_TypedDict in (typing.TypedDict, typing_extensions.TypedDict): class Foo(t_TypedDict): a: int b: typing.Required[int] c: typing.NotRequired[int] class Bar(Foo, total=False): d: int e: typing.Required[int] f: typing.NotRequired[int] class Baz(Bar): g: int h: typing.Required[int] i: typing.NotRequired[int] class Qux(t_TypedDict, total=False): j: int assert not is_subtype_instance({}, Foo) assert is_subtype_instance({"a": 1, "b": 1}, Foo) 
        assert not is_subtype_instance({"a": 1, "b": "str"}, Foo)
        assert is_subtype_instance({"a": 1, "b": 1, "c": 1}, Foo)
        assert not is_subtype_instance({"a": 1, "b": 1, "c": "str"}, Foo)
        assert is_subtype_instance({"a": 1, "b": 1, "c": 1, "xyz": object()}, Foo)

        assert not is_subtype_instance({}, Bar)
        assert not is_subtype_instance({"a": 1, "b": 1}, Bar)
        assert is_subtype_instance({"a": 1, "b": 1, "e": 1}, Bar)
        assert is_subtype_instance({"a": 1, "b": 1, "c": 1, "d": 1, "e": 1, "f": 1}, Bar)
        assert not is_subtype_instance({"a": 1, "b": 1, "c": 1, "d": 1, "e": 1, "f": "str"}, Bar)

        assert not is_subtype_instance({"a": 1, "b": 1, "e": 1}, Baz)
        assert is_subtype_instance({"a": 1, "b": 1, "e": 1, "g": 1, "h": 1}, Baz)

        assert is_subtype_instance({}, Qux)
        assert is_subtype_instance({"j": 1}, Qux)
        assert not is_subtype_instance({"j": "str"}, Qux)


def test_is_subtype_instance_named_tuple():
    class A(typing.NamedTuple):
        a: int
        b: str

    class B(A): ...

    assert is_subtype_instance(A(a=1, b="str"), A)
    assert is_subtype_instance(A(a=1, b="str"), tuple[int, str])
    assert not is_subtype_instance(A(a=1, b="str"), tuple[int, int])
    assert not is_subtype_instance(A(a=1, b="str"), B)


def test_is_subtype_instance_type_var():
    T = typing.TypeVar("T")
    assert is_subtype_instance(1, T)
    assert is_subtype_instance("str", T)
    assert is_subtype_instance([1, 2], list[T])
    assert is_subtype_instance((1,), tuple[T])

    assert is_subtype_instance("a", typing.AnyStr)
    assert not is_subtype_instance(1, typing.AnyStr)
    assert is_subtype_instance(["a"], list[typing.AnyStr])
    assert is_subtype_instance(["a"], typing.List[typing.AnyStr])
    assert not is_subtype_instance([1], list[typing.AnyStr])
    assert is_subtype_instance(["this is", b"not quite right"], list[typing.AnyStr])

    C = typing.TypeVar("C", list[int], list[str])
    assert is_subtype_instance([], C)
    assert not is_subtype_instance([b"bytes"], C)

    B = typing.TypeVar("B", bound=int)
    assert is_subtype_instance(1, B)
    assert is_subtype_instance(False, B)
    assert not is_subtype_instance("str", B)


def test_is_subtype_instance_union():
    assert is_subtype_instance(1, typing.Union[int, str])
    assert is_subtype_instance("str", typing.Union[int, str])
    assert not is_subtype_instance(b"bytes", typing.Union[int, str])

    assert is_subtype_instance([1], typing.List[typing.Union[int, str]])
    assert is_subtype_instance(["str", 1], typing.List[typing.Union[int, str]])
    assert is_subtype_instance([1], list[typing.Union[int, str]])
    assert is_subtype_instance(["str", 1], list[typing.Union[int, str]])

    if sys.version_info >= (3, 10):
        assert is_subtype_instance(1, int | str)
        assert is_subtype_instance("str", int | str)
        assert not is_subtype_instance(b"bytes", int | str)

        assert is_subtype_instance([1], typing.List[int | str])
        assert is_subtype_instance(["str", 1], typing.List[int | str])
        assert is_subtype_instance([1], list[int | str])
        assert is_subtype_instance(["str", 1], list[int | str])


def test_is_subtype_instance_callable() -> None:
    # lambda / untyped subtyping
    assert is_subtype_instance(lambda: None, typing.Callable)
    assert is_subtype_instance(lambda x: x, typing.Callable)
    assert is_subtype_instance(lambda: None, typing.Callable[[], None])
    assert is_subtype_instance(lambda x: x, typing.Callable[[int], int])
    assert is_subtype_instance(lambda: None, typing.Callable[..., None])
    assert is_subtype_instance(lambda x: x, typing.Callable[..., int])

    assert not is_subtype_instance(lambda: None, typing.Callable[[int], int])
    assert not is_subtype_instance(lambda x: x, typing.Callable[[], None])
    assert not is_subtype_instance(1, typing.Callable[[int], str])

    # typed functions subtyping
    def foo(x: int, y: str, z: bytes = ...) -> None: ...

    assert is_subtype_instance(foo, typing.Callable[[int, str], None])
    assert is_subtype_instance(foo, typing.Callable[[int, str, bytes], None])
    assert not is_subtype_instance(foo, typing.Callable[[str, int, bytes], None])
    assert not is_subtype_instance(foo, typing.Callable[[int, str, bytes], str])

    def bar(x: object) -> bool: ...
    assert is_subtype_instance(bar, typing.Callable[[int], int])
    assert is_subtype_instance(bar, typing.Callable[[str], bool])
    assert is_subtype_instance(bar, typing.Callable[..., bool])
    assert not is_subtype_instance(bar, typing.Callable[..., str])

    def baz(x, y, z): ...

    assert is_subtype_instance(baz, typing.Callable[[int, str, bytes], int])
    assert is_subtype_instance(baz, typing.Callable[[str, bytes, int], bytes])

    # type subtyping
    class A:
        def __init__(self, x: int) -> None: ...

    class B(A): ...

    assert is_subtype_instance(A, typing.Callable[[int], A])
    assert is_subtype_instance(B, typing.Callable[[int], A])
    assert not is_subtype_instance(A, typing.Callable[[int], B])

    # callable instance subtyping
    class Call:
        def __call__(self, x: int) -> None: ...

    assert is_subtype_instance(Call(), typing.Callable[[int], None])
    assert is_subtype_instance(Call(), typing.Callable[..., None])
    assert not is_subtype_instance(Call(), typing.Callable[[str], None])
    assert not is_subtype_instance(Call(), typing.Callable[[int, int], None])

    # function with arguments of generic type
    def takes_dict(x: dict[int, str]) -> int: ...

    assert is_subtype_instance(takes_dict, typing.Callable[[dict[int, str]], int])
    assert is_subtype_instance(takes_dict, typing.Callable[[dict[int, typing.Any]], int])
    assert is_subtype_instance(takes_dict, typing.Callable[[dict[typing.Any, str]], int])
    assert is_subtype_instance(takes_dict, typing.Callable[[dict[typing.Any, typing.Any]], int])
    assert not is_subtype_instance(takes_dict, typing.Callable[[dict[object, object]], int])
    assert not is_subtype_instance(takes_dict, typing.Callable[[dict[int, str]], None])
    assert not is_subtype_instance(takes_dict, typing.Callable[[dict[str, int]], int])

    # more contravariance tests
    class C(B): ...

    def takes_b(x: B) -> None: ...
    assert is_subtype_instance(takes_b, typing.Callable[[C], None])
    assert is_subtype_instance(takes_b, typing.Callable[[B], None])
    assert not is_subtype_instance(takes_b, typing.Callable[[A], None])

    # varargs
    def varargs(*args: int) -> None: ...

    assert is_subtype_instance(varargs, typing.Callable[[int], None])
    assert is_subtype_instance(varargs, typing.Callable[[int, int], None])
    assert not is_subtype_instance(varargs, typing.Callable[[str], None])

    # varkwargs
    def varkwargs(**kwargs: int) -> None: ...

    assert is_subtype_instance(varkwargs, typing.Callable[[], None])
    assert not is_subtype_instance(varkwargs, typing.Callable[[int], None])
    assert not is_subtype_instance(varkwargs, typing.Callable[[int, int], None])

    # param spec
    P = typing.ParamSpec("P")
    assert not is_subtype_instance(lambda: None, typing.Callable[[P], None])
    assert not is_subtype_instance(lambda: foo, typing.Callable[[P], None])
    assert not is_subtype_instance(lambda: bar, typing.Callable[[P], None])
    assert not is_subtype_instance(lambda: A, typing.Callable[[P], None])


def test_is_subtype_instance_callable_protocol():
    class A: ...
    class B(A): ...
    class C(B): ...

    class P1(typing.Protocol):
        def __call__(self, x: B) -> None: ...

    def p1(x: B) -> None: ...
    def p2(x: A) -> None: ...
    def p3(x: B = ...) -> None: ...
    def p4(x: B, y: int = ...) -> None: ...

    assert is_subtype_instance(p1, P1)
    assert is_subtype_instance(p2, P1)
    assert is_subtype_instance(p3, P1)
    assert is_subtype_instance(p4, P1)

    def p5(x: B, /) -> None: ...
    def p6(y: B, /) -> None: ...
    def p7(x: C) -> None: ...
    def p8(y: B) -> None: ...
    def p9(x: B, y: int) -> None: ...

    assert not is_subtype_instance(p5, P1)
    assert not is_subtype_instance(p6, P1)
    assert not is_subtype_instance(p7, P1)
    assert not is_subtype_instance(p8, P1)
    assert not is_subtype_instance(p9, P1)

    class P2(typing.Protocol):
        def __call__(self, x: B, /) -> None: ...
    assert is_subtype_instance(p1, P2)
    assert is_subtype_instance(p2, P2)
    assert is_subtype_instance(p3, P2)
    assert is_subtype_instance(p4, P2)
    assert is_subtype_instance(p5, P2)
    assert is_subtype_instance(p6, P2)

    class P3(typing.Protocol):
        def __call__(self, x: B, y: int = ...) -> None: ...

    assert not is_subtype_instance(p1, P3)
    assert not is_subtype_instance(p9, P3)

    def p10(*args: int | A) -> None: ...
    def p11(*args: int | C) -> None: ...

    assert is_subtype_instance(p10, P3)
    assert not is_subtype_instance(p11, P3)

    class P4(typing.Protocol):
        def __call__(self, x: B, *, y: int) -> None: ...

    assert not is_subtype_instance(p1, P4)
    assert is_subtype_instance(p4, P4)

    def p12(x: B, *, y: int = ...) -> None: ...
    def p13(x: B, *, y: int) -> None: ...
    def p14(x: B, *, y: B = ...) -> None: ...
    def p15(**kwargs: int | B) -> None: ...
    def p16(x: B, **kwargs: int) -> None: ...
    def p17(x: B, *, y: int, z: int = ...) -> None: ...
    def p18(x: B, y: B) -> None: ...

    assert is_subtype_instance(p4, P4)
    assert is_subtype_instance(p9, P4)
    assert is_subtype_instance(p12, P4)
    assert is_subtype_instance(p13, P4)
    assert not is_subtype_instance(p14, P4)
    assert not is_subtype_instance(p15, P4)
    assert not is_subtype_instance(p16, P4)
    assert is_subtype_instance(p17, P4)
    assert not is_subtype_instance(p18, P4)

    class P5(typing.Protocol):
        def __call__(self, x: B, *, y: int = ...) -> None: ...

    assert is_subtype_instance(p4, P5)
    assert not is_subtype_instance(p9, P5)
    assert is_subtype_instance(p12, P5)
    assert not is_subtype_instance(p13, P5)
    assert not is_subtype_instance(p14, P5)
    assert not is_subtype_instance(p15, P5)
    assert not is_subtype_instance(p16, P5)
    assert not is_subtype_instance(p17, P5)
    assert not is_subtype_instance(p18, P5)

    class P7(typing.Protocol):
        def __call__(self, x: B, **kwargs: B) -> None: ...

    assert not is_subtype_instance(p1, P7)
    assert not is_subtype_instance(p4, P7)
    assert not is_subtype_instance(p12, P7)

    def p19(x: B, **kwargs: A) -> None: ...
    def p20(x: B, **kwargs: C) -> None: ...
    def p21(x: B, y: B, **kwargs: B) -> None: ...
    def p22(x: B, *, y: B, **kwargs: B) -> None: ...
    def p23(x: B, *, y: int, **kwargs: B) -> None: ...

    assert is_subtype_instance(p19, P7)
    assert not is_subtype_instance(p20, P7)
    assert not is_subtype_instance(p21, P7)
    assert is_subtype_instance(p22, P7)
    assert not is_subtype_instance(p23, P7)

    class P8(typing.Protocol):
        def __call__(self) -> B: ...

    def p22() -> A: ...
    def p23() -> C: ...

    assert not is_subtype_instance(p22, P8)
    assert is_subtype_instance(p23, P8)


def test_is_subtype_instance_protocol_chz_callable():
    class P(typing.Protocol):
        def __call__(self, a: int) -> int: ...

    import chz

    @chz.chz
    class Bad:
        def __call__(self) -> int: ...

    @chz.chz
    class Good:
        def __call__(self, a: int) -> int: ...

    assert not is_subtype_instance(Bad(), P)
    assert not is_subtype_instance(Bad, P)
    assert is_subtype_instance(Good(), P)
    assert not is_subtype_instance(Good, P)


def test_is_subtype_instance_protocol_attr():
    class A: ...
    class B(A): ...
    class C(B): ...

    class FooProto(typing.Protocol):
        x: B

        def foo(self) -> int: ...

    class Foo:
        def __init__(self, x) -> None:
            self.x = x

        def foo(self) -> int:
            return 1

    assert not is_subtype_instance(Foo(A()), FooProto)
    assert is_subtype_instance(Foo(B()), FooProto)
    assert is_subtype_instance(Foo(C()), FooProto)

    b = Foo(B())
    b.foo = 1
    assert not is_subtype_instance(b, FooProto)

    assert not is_subtype_instance(object(), FooProto)


def test_is_subtype_instance_runtime_protocol():
    @typing.runtime_checkable
    class FooProto(typing.Protocol):
        def foo(self) -> int: ...
    class Foo:
        def foo(self) -> int:
            return 1

    assert is_subtype_instance(Foo(), FooProto)
    assert not is_subtype_instance(object(), FooProto)


def test_is_subtype_instance_literal():
    assert is_subtype_instance(1, typing.Literal[1])
    assert is_subtype_instance("str", typing.Literal["str"])
    assert is_subtype_instance(1, typing.Literal[1, 2])
    assert is_subtype_instance("str", typing.Literal[1, "str"])

    assert not is_subtype_instance(1, typing.Literal[2])
    assert not is_subtype_instance("str", typing.Literal[1, "bytes"])

    assert is_subtype_instance(None, typing.Literal[None])
    assert is_subtype_instance(None, typing.Literal[1, "bytes", None])


def test_is_subtype_instance_type():
    assert is_subtype_instance(int, type)
    assert is_subtype_instance(str, type)
    assert is_subtype_instance(type, type)
    assert is_subtype_instance(int, typing.Type)
    assert is_subtype_instance(str, typing.Type)
    assert is_subtype_instance(type, typing.Type)

    assert is_subtype_instance(int, type[int])
    assert is_subtype_instance(str, type[str])
    assert is_subtype_instance(type, type[type])
    assert is_subtype_instance(int, typing.Type[int])
    assert is_subtype_instance(str, typing.Type[str])
    assert is_subtype_instance(type, typing.Type[type])

    assert not is_subtype_instance(int, type[str])
    assert not is_subtype_instance(str, type[int])
    assert not is_subtype_instance(type, type[int])
    assert not is_subtype_instance(int, typing.Type[str])
    assert not is_subtype_instance(str, typing.Type[int])
    assert not is_subtype_instance(type, typing.Type[int])


def test_is_subtype_instance_enum():
    class Color(enum.Enum):
        RED = 1
        GREEN = 2

    assert is_subtype_instance(Color.RED, enum.Enum)
    assert is_subtype_instance(Color.RED, Color)
    assert not is_subtype_instance(1, Color)
    assert not is_subtype_instance("RED", Color)

    assert is_subtype_instance(Color.RED, typing.Literal[Color.RED])
    assert is_subtype_instance(Color.RED, typing.Literal[Color.RED, Color.GREEN])
    assert not is_subtype_instance(Color.RED, typing.Literal[Color.GREEN])


def test_is_subtype_instance_new_type():
    N = typing.NewType("N", int)
    assert is_subtype_instance(1, N)
    assert not is_subtype_instance("1", N)


def test_is_subtype_instance_literal_string():
    assert is_subtype_instance("str", typing.LiteralString)
    assert not is_subtype_instance(1, typing.LiteralString)


def test_is_subtype_instance_explicit_protocol_lsp_violation():
    class P(typing.Protocol):
        def makes_int(self) -> int: ...

    class Implicit:
        def makes_int(self) -> str: ...

    class Explicit(P):
        def makes_int(self) -> str: ...

    assert not is_subtype_instance(Implicit(), P)
    assert is_subtype_instance(Explicit(), P)


def test_is_subtype_instance_pydantic() -> None:
    import pydantic

    T = typing.TypeVar("T")

    class Thing(pydantic.BaseModel, typing.Generic[T]):
        type: str = "thing"
        x: T

    assert is_subtype_instance(Thing[int](x=5), Thing)


def test_is_subtype_instance_pydantic_utils() -> None:
    import pydantic
    import pydantic_core

    try:
        from pydantic_utils import get_polymorphic_generic_model_schema
    except ImportError:
        pytest.skip("pydantic_utils not installed")

    T = typing.TypeVar("T")

    class Foo(pydantic.BaseModel, typing.Generic[T]):
        type: str = "foo"
        x: T

        @classmethod
        def __get_pydantic_core_schema__(
            cls,
            source: typing.Type[pydantic.BaseModel],  # noqa: UP006
            handler: pydantic.GetCoreSchemaHandler,
        ) -> pydantic_core.core_schema.CoreSchema:
            return get_polymorphic_generic_model_schema(
                cls, __class__, source, handler,  # type:ignore[name-defined]
            )

    class Bar(Foo[T], typing.Generic[T]):
        type: str = "bar"
        y: T

    assert is_subtype_instance(Foo[int](x=5), Foo)
    assert Foo[typing.Any].model_validate(Foo[int](x=5))
    assert is_subtype_instance(Foo[int](x=5), Foo[typing.Any])
    assert is_subtype_instance(Bar[int](x=5, y=2), Foo[int])
    # This is currently broken in pydantic_utils.
    # assert not is_subtype_instance(Bar[str](x="a", y="b"), Foo[int])
    assert is_subtype_instance(Bar(x=5, y=2), Foo[int])


def test_is_subtype():
    assert is_subtype(int, int)
    assert not is_subtype(int, str)

    assert is_subtype(list, list)
    assert is_subtype(list[int], list)
    assert is_subtype(list, list[int])
    assert is_subtype(list[int], list[typing.Any])
    assert is_subtype(list[typing.Any], list[int])
    assert not is_subtype(list[int], list[str])

    class A: ...
    class B(A): ...

    assert is_subtype(tuple, tuple)
    assert is_subtype(tuple[int], tuple[int, ...])
    assert is_subtype(tuple[int, int], tuple[int, ...])
    assert not is_subtype(tuple[int, ...], tuple[int, int])
    assert not is_subtype(tuple[int, str], tuple[int, ...])
    assert is_subtype(tuple[B, B], tuple[A, B])
    assert not is_subtype(tuple[B, A], tuple[B, B])
    assert not is_subtype(tuple[B, B], tuple[A, B, object])

    assert is_subtype(dict, dict)
    assert is_subtype(dict[typing.Any, B], dict[str, A])
    assert not is_subtype(dict[A, A], dict[A, B])

    assert is_subtype(int, int | str)
    assert is_subtype(str, int | str)
    assert not is_subtype(bytes, int | str)
    assert is_subtype(int | str, str | int)
    assert is_subtype(int | str, int | bytes | str)
    assert not is_subtype(int | str | bytes, int | str)

    assert is_subtype(int, typing.Union[int, str])
    assert is_subtype(str, typing.Union[int, str])
    assert not is_subtype(bytes, typing.Union[int, str])
    assert is_subtype(typing.Union[int, str], typing.Union[str, int])
    assert is_subtype(typing.Union[int, str], typing.Union[int, bytes, str])
    assert not is_subtype(typing.Union[int, str, bytes], typing.Union[int, str])

    assert is_subtype(None, None)
    assert is_subtype(None, int | None)
    assert is_subtype(None, typing.Optional[int])

    assert is_subtype(typing.Literal[1, 2], typing.Literal[3, 2, 1])
    assert not is_subtype(typing.Literal[1, 2, 4], typing.Literal[3, 2, 1])
    assert is_subtype(typing.Literal[1, 2], str | int)
    assert is_subtype(typing.Literal[1, 2], str | typing.Literal[1, 2, 3])

    assert is_subtype(typing.Callable[[int], str], typing.Callable[[int], str])
    assert not is_subtype(typing.Callable[[int], str], typing.Callable[[int, int], str])
    assert is_subtype(typing.Callable[[int], str], typing.Callable[[int], str | None])
    assert not is_subtype(typing.Callable[[int], str | None], typing.Callable[[int], str])
    assert is_subtype(typing.Callable[[int | None], None], typing.Callable[[int], None])
    assert not is_subtype(typing.Callable[[int], None], typing.Callable[[int | None], None])
    assert not is_subtype(typing.Callable[[int], None], typing.Callable[[str], None])


def test_is_subtype_protocol():
    class P1(typing.Protocol):
        def foo(self) -> int: ...

    class P2(typing.Protocol):
        def foo(self) -> int: ...
        def bar(self) -> int: ...

    class Good:
        def foo(self) -> int: ...

    class Bad:
        def bar(self) -> int: ...

    assert not is_subtype(P1, P2)
    assert is_subtype(P2, P1)

    assert is_subtype(Good, P1)
    assert not is_subtype(Bad, P1)
    assert not is_subtype(Good, P2)

    def a() -> P1: ...

    assert is_subtype_instance(a, typing.Callable[..., P1])
    assert not is_subtype_instance(a, typing.Callable[..., P2])


def test_is_subtype_typed_dict():
    class A(typing.TypedDict):
        a: int
        b: str

    assert is_subtype(A, typing.Mapping[str, typing.Any])
    assert not is_subtype(typing.Mapping[str, typing.Any], A)
    assert is_subtype(A, dict[str, int | str])
    assert is_subtype(A, dict[str, int])

    class B(A):
        c: bytes

    assert is_subtype(B, A)
    assert not is_subtype(A, B)

    class B_alt(typing.TypedDict):
        a: int
        b: str
        c: bytes

    assert is_subtype(B_alt, B)
    assert is_subtype(B, B_alt)
    assert is_subtype(B_alt, A)
    assert not is_subtype(A, B_alt)

    class A_not_total(typing.TypedDict, total=False):
        a: int
        b: str

    assert is_subtype(A, A_not_total)
    assert not is_subtype(A_not_total, A)


def test_is_subtype_typevar() -> None:
    T_int = typing.TypeVar("T_int", bound=int)
    assert is_subtype(T_int, int)
    assert is_subtype(T_int, object)
    assert not is_subtype(T_int, str)
    assert is_subtype(int, T_int)
    assert not is_subtype(str, T_int)

    T = typing.TypeVar("T")
    assert is_subtype(T, object)
    assert not is_subtype(T, str)
    assert is_subtype(T, T)
    assert is_subtype(object, T)
    assert is_subtype(str, T)

    T_constrained = typing.TypeVar("T_constrained", int, str)
    assert is_subtype(T_constrained, int | str)
    assert is_subtype(T_constrained, object)
    assert not is_subtype(T_constrained, bytes)
    assert is_subtype(T_constrained, T_constrained)
    assert is_subtype(int, T_constrained)
    assert is_subtype(str, T_constrained)
    assert not is_subtype(object, T_constrained)


def test_no_return():
    assert is_subtype_instance(typing.NoReturn, int)
    assert is_subtype_instance(typing.NoReturn, str)

    def foo() -> typing.NoReturn: ...

    assert is_subtype_instance(foo, typing.Callable[[], None])
    assert is_subtype_instance(foo, typing.Callable[[], str])

    if sys.version_info >= (3, 11):
        assert is_subtype_instance(typing.Never, int)
        assert is_subtype_instance(typing.Never, str)

        def foo() -> typing.Never: ...

        assert is_subtype_instance(foo, typing.Callable[[], None])
        assert is_subtype_instance(foo, typing.Callable[[], str])


def test_try_cast_object_any():
    for obj_any in (object, typing.Any, typing_extensions.Any):
        assert _simplistic_try_cast("1", obj_any) == 1
        assert _simplistic_try_cast("1a", obj_any) == "1a"
        assert _simplistic_try_cast("1j", obj_any) == 1j
        assert _simplistic_try_cast("{1: ('2', [3])}", obj_any) == {1: ("2", [3])}

        assert _simplistic_try_cast("null", obj_any) is None
        assert _simplistic_try_cast("none", obj_any) is None
        assert _simplistic_try_cast("true", obj_any) is True
        assert _simplistic_try_cast("false", obj_any) is False


def test_try_cast_tuple():
    assert _simplistic_try_cast("1", tuple) == ("1",)
    assert _simplistic_try_cast("1,2", tuple) == ("1", "2")
    assert _simplistic_try_cast("1,2", tuple[str, str]) == ("1", "2")
    assert _simplistic_try_cast("1,2", tuple[typing.Any, ...]) == (1, 2)
    assert _simplistic_try_cast("1a,2", tuple[typing.Any, ...]) == ("1a", 2)
    assert _simplistic_try_cast("1,2", tuple[int, ...]) == (1, 2)
    assert _simplistic_try_cast("", tuple) == ()

    # can't distinguish untyped tuple from zero length tuple
    assert _simplistic_try_cast("1", tuple[()]) == ("1",)
    assert _simplistic_try_cast("1,2", tuple[()]) == ("1", "2")


def test_try_cast_list():
    assert _simplistic_try_cast("1", list) == [1]
    assert _simplistic_try_cast("1,str", list) == [1, "str"]
    assert _simplistic_try_cast("1,2", list[int]) == [1, 2]
    assert _simplistic_try_cast("1,2", list[str]) == ["1", "2"]

    assert _simplistic_try_cast("[1]", list) == [1]
    assert _simplistic_try_cast("[1,'str']", list) == [1, "str"]
    assert _simplistic_try_cast("[1,2]", list[int]) == [1, 2]

    with pytest.raises(CastError, match=r"Could not cast '\[1,2\]' to list\[str\]"):
        _simplistic_try_cast("[1,2]", list[str])
    with pytest.raises(CastError, match=r"Could not cast '\[1,str\]' to list"):
        _simplistic_try_cast("[1,str]", list)


def test_try_cast_sequence_iterable():
    for origin in {
        collections.abc.Sequence,
        collections.abc.Iterable,
        typing.Sequence,
        typing.Iterable,
    }:
        assert _simplistic_try_cast("1", origin) == (1,)
        assert _simplistic_try_cast("1,str", origin) == (1, "str")
        assert _simplistic_try_cast("1,2", origin[int]) == (1, 2)
        assert _simplistic_try_cast("1,2", origin[str]) == ("1", "2")

        assert _simplistic_try_cast("[1]", origin) == [1]
        assert _simplistic_try_cast("[1,'str']", origin) == [1, "str"]
        assert _simplistic_try_cast("[1,2]", origin[int]) == [1, 2]

        assert _simplistic_try_cast("(1,)", origin) == (1,)
        assert _simplistic_try_cast("(1,'str')", origin) == (1, "str")
        assert _simplistic_try_cast("(1,2)", origin[int]) == (1, 2)

        with pytest.raises(CastError, match=r"Could not cast '\(1\)' to \w+"):
            _simplistic_try_cast("(1)", origin)
        with pytest.raises(CastError, match=r"Could not cast '\[1,2\]' to \w+\[str\]"):
            _simplistic_try_cast("[1,2]", origin[str])
        with pytest.raises(CastError, match=r"Could not cast '\[1,str\]' to \w+"):
            _simplistic_try_cast("[1,str]", origin)


def test_try_cast_dict():
    assert _simplistic_try_cast("{1: 2}", dict) == {1: 2}
    assert _simplistic_try_cast("{1: 2}", dict[int, int]) == {1: 2}
    assert _simplistic_try_cast("{1: '2'}", dict[int, str]) == {1: "2"}
    assert _simplistic_try_cast("{1: '2'}", dict[int, typing.Any]) == {1: "2"}

    with pytest.raises(CastError, match=r"""Could not cast "\{1: '2'\}" to dict\[int, int\]"""):
        _simplistic_try_cast("{1: '2'}", dict[int, int])
    with pytest.raises(CastError, match=r"""Could not cast "\{1: '2'\}" to dict\[str, str\]"""):
        _simplistic_try_cast("{1: '2'}", dict[str, str])
    with pytest.raises(CastError, match=r"Could not cast '\{str: str\}' to dict"):
        _simplistic_try_cast("{str: str}", dict)


def test_try_cast_callable():
    assert (
        _simplistic_try_cast(
            "chz.tiepin:_simplistic_try_cast", typing.Callable[[str, typing.Any], typing.Any]
        )
        is _simplistic_try_cast
    )
    assert (
        _simplistic_try_cast("chz.tiepin:_simplistic_try_cast", typing.Callable[..., typing.Any])
        is _simplistic_try_cast
    )

    with pytest.raises(
        CastError,
        match=r"Could not cast 'chz.tiepin:_simplistic_try_cast' to Callable\[\[int\], int\]",
    ):
        _simplistic_try_cast("chz.tiepin:_simplistic_try_cast", typing.Callable[[int], int])

    with pytest.raises(
        CastError,
        match=r"Could not cast 'does_not_exist:penguin' to Callable\[\[int\], int\]\. Could not import module.*ModuleNotFoundError",
    ):
        _simplistic_try_cast("does_not_exist:penguin", typing.Callable[[int], int])


def test_try_cast_tuple_unpack():
    # fmt: off
    assert _simplistic_try_cast("1,2,3", tuple[int, *tuple[str, int]]) == (1, "2", 3)
    assert _simplistic_try_cast("1,2,3", tuple[int, typing.Unpack[tuple[str, int]]]) == (1, "2", 3)

    assert _simplistic_try_cast("1,2,3", tuple[int, *tuple[str, ...]]) == (1, "2", "3")
    assert _simplistic_try_cast("1,2,3", tuple[int, typing.Unpack[tuple[str, ...]]]) == (1, "2", "3")

    assert _simplistic_try_cast("1,2,3,4", tuple[int, *tuple[str, ...]]) == (1, "2", "3", "4")
    assert _simplistic_try_cast("1,2,3,4", tuple[int, typing.Unpack[tuple[str, ...]]]) == (1, "2", "3", "4")

    assert _simplistic_try_cast("1,2,3,4", tuple[int, *tuple[str, ...], int, int]) == (1, "2", 3, 4)
    assert _simplistic_try_cast("1,2,3,4", tuple[int, typing.Unpack[tuple[str, ...]], int, int]) == (1, "2", 3, 4)

    assert _simplistic_try_cast("1,2,3,4", tuple[int, *tuple[str, *tuple[int, int]]]) == (1, "2", 3, 4)
    assert _simplistic_try_cast("1,2,3,4", tuple[int, typing.Unpack[tuple[str, typing.Unpack[tuple[int, int]]]]]) == (1, "2", 3, 4)
    # fmt: on

    with pytest.raises(
        CastError,
        match=re.escape(
            "Could not cast '1,2' to tuple[int, *tuple[str, int]] because of length mismatch"
        ),
    ):
        _simplistic_try_cast("1,2", tuple[int, *tuple[str, int]])


def test_try_cast_union_overlap():
    assert _simplistic_try_cast("1", str | int) == 1
    assert _simplistic_try_cast("1", int | str) == 1

    assert _simplistic_try_cast("None", str | None) == None
    assert _simplistic_try_cast("None", None | str) == None

    assert _simplistic_try_cast("all", tuple[str, ...] | typing.Literal["all"]) == "all"
    assert _simplistic_try_cast("all", typing.Literal["all"] | tuple[str, ...]) == "all"

    assert _simplistic_try_cast("None", None | typing.Literal["None"]) == "None"
    assert _simplistic_try_cast("None", typing.Literal["None"] | None) == "None"

    assert _simplistic_try_cast("None", None | tuple[str, ...]) == None
    assert _simplistic_try_cast("None", tuple[str, ...] | None) == None

    assert _simplistic_try_cast("", tuple[str, ...] | str) == ()
    assert _simplistic_try_cast("", tuple[str, ...] | typing.Literal[""]) == ""


def test_try_cast_enum():
    class Color(enum.Enum):
        RED = 1
        GREEN = 2

    assert _simplistic_try_cast("RED", Color) == Color.RED
    assert _simplistic_try_cast("GREEN", Color) == Color.GREEN
    assert _simplistic_try_cast("1", Color) == Color.RED
    assert _simplistic_try_cast("2", Color) == Color.GREEN

    with pytest.raises(CastError, match="Could not cast 'BLUE' to .*Color"):
        _simplistic_try_cast("BLUE", Color)
    with pytest.raises(CastError, match="Could not cast '3' to .*Color"):
        _simplistic_try_cast("3", Color)


def test_try_cast_fractions():
    assert _simplistic_try_cast("1/2", fractions.Fraction) == fractions.Fraction(1, 2)
    assert _simplistic_try_cast("1", fractions.Fraction) == fractions.Fraction(1)
    assert _simplistic_try_cast("1.5", fractions.Fraction) == fractions.Fraction(3, 2)


def test_try_cast_pathlib():
    assert _simplistic_try_cast("foo", pathlib.Path) == pathlib.Path("foo")


def test_try_cast_typevar():
    assert _simplistic_try_cast("foo", typing.TypeVar("T")) == "foo"
    assert _simplistic_try_cast("foo", typing.TypeVar("T", int, str)) == "foo"
    assert _simplistic_try_cast("5", typing.TypeVar("T", bound=int)) == 5
    assert _simplistic_try_cast("5", typing.TypeVar("T", bound=str | int)) == 5

    with pytest.raises(CastError, match="Could not cast 'foo' to ~T"):
        assert _simplistic_try_cast("foo", typing.TypeVar("T", int, float)) == "foo"

    with pytest.raises(CastError, match="Could not cast 'five' to int"):
        assert _simplistic_try_cast("five", typing.TypeVar("T", bound=int)) == 5


def test_approx_type_hash():
    import builtins
    from typing import Callable, Literal, TypeVar, Union

    _T = TypeVar("_T")

    assert approx_type_hash(int)[:8] == "46f8ab7c"
    assert approx_type_hash(str)[:8] == "3442496b"
    assert approx_type_hash("str")[:8] == "3442496b"
    assert approx_type_hash(builtins.str)[:8] == "3442496b"

    class float: ...

    assert approx_type_hash(builtins.float)[:8] == "685e8036"
    assert approx_type_hash(float)[:8] == "685e8036"  # can't tell the difference...
    assert approx_type_hash("float")[:8] == "685e8036"

    assert approx_type_hash(list[int])[:8] == "e4c2cba0"
    assert approx_type_hash("list[int]")[:8] == "e4c2cba0"
    assert approx_type_hash(list["int"])[:8] == "e4c2cba0"
    assert approx_type_hash(list[_T])[:8] == "c6eb1529"

    assert approx_type_hash(Union[int, str])[:8] == "c1729268"
    assert approx_type_hash(Union[str, int])[:8] == "d811461d"
    assert approx_type_hash(Union[str, "int"])[:8] == "d811461d"

    assert approx_type_hash(Callable[[int], str])[:8] == "0dc453ef"
    assert approx_type_hash(Literal[1, "asdf", False])[:8] == "ee5b7e0f"


def test_simplistic_type_of_value():
    tov = _simplistic_type_of_value

    assert tov(1) is int
    assert tov("foo") is str

    assert tov([1, 2, 3]) == list[int]
    assert tov([1, 2, 3.0]) == list[int | float]
    assert tov([1, 2, "3"]) == list[int | str]

    assert tov((1, 2, 3)) == tuple[int, int, int]
    assert tov((1, 2, 3.0)) == tuple[int, int, float]
    assert tov((1, "2", 3.0)) == tuple[int, str, float]
    assert tov(tuple(i for i in range(12))) == tuple[int, ...]

    assert tov([(1, 2), (3, 4)]) == list[tuple[int, int]]
    assert tov([(1, 2), (3, 4, 5)]) == list[tuple[int, int] | tuple[int, int, int]]

    assert tov({1: "a", "b": 2}) == dict[int | str, str | int]

    assert tov(int) == type[int]

    class A: ...
    class B(A): ...
    class C(A): ...
    assert tov([A(), B()]) == list[A]
    assert tov([B(), A()]) == list[A]
    assert tov([B(), C()]) == list[B | C]


================================================
FILE: tests/test_todo.py
================================================
import pytest

import chz

# TODO: test inheritance, setattr, repr


def test_version():
    @chz.chz(version="b4d37d6e")
    class X1:
        a: int

    @chz.chz(version="b4d37d6e-3")
    class X2:
        a: int

    with pytest.raises(ValueError, match="Version 'b4d37d6e' does not match '3902ee27'"):

        @chz.chz(version="b4d37d6e")
        class X3:
            a: int
            b: int


================================================
FILE: tests/test_validate.py
================================================
import math
import re
from typing import Generic, TypeVar

import pytest

import chz

T = TypeVar("T")


def test_validate_readme():
    @chz.chz
    class Fraction:
        numerator: int = chz.field(validator=chz.validators.typecheck)
        denominator: int = chz.field(validator=[chz.validators.typecheck, chz.validators.gt(0)])

        @chz.validate
        def _check_reduced(self):
            if math.gcd(self.numerator, self.denominator) > 1:
                raise ValueError("Fraction is not reduced")

    Fraction(numerator=1, denominator=2)
    Fraction(numerator=2, denominator=1)
    with pytest.raises(ValueError, match=r"Fraction is not reduced"):
        Fraction(numerator=2, denominator=4)


def test_validate():
    @chz.chz
    class X:
        attr: int = chz.field(validator=chz.validators.instancecheck)

    X(attr=1)
    with pytest.raises(TypeError, match="Expected X_attr to be int, got str"):
        X(attr="1")  # type: ignore

    @chz.chz
    class Y:
        attr: int = chz.field(validator=chz.validators.instance_of(int))

        @chz.validate
        def _attr_validator(self):
            if self.attr < 0:
                raise ValueError("attr must be non-negative")

    Y(attr=1)
    with pytest.raises(TypeError, match="Expected X_attr to be int, got str"):
        Y(attr="1")  # type: ignore
    with pytest.raises(ValueError, match="attr must be non-negative"):
        Y(attr=-1)

    @chz.chz
    class Z:
        attr: int | str = chz.field(validator=chz.validators.typecheck)

    Z(attr=1)
    Z(attr="asdf")
    with pytest.raises(TypeError, match=r"int \| str, got bytes"):
        Z(attr=b"fdsa")  # type: ignore


def test_validate_replace():
    @chz.chz
    class X:
        attr: int = chz.field(validator=chz.validators.typecheck)

    x = X(attr=1)
    x = chz.replace(x, attr=2)
    with pytest.raises(TypeError, match="Expected X_attr to be int, got str"):
        chz.replace(x, attr="3")


def test_for_all_fields():
    @chz.chz
    class X:
        a: str
        b: int

        @chz.validate
        def typecheck_all_fields(self):
            chz.validators.for_all_fields(chz.validators.typecheck)(self)

    X(a="asdf", b=1)
    with pytest.raises(TypeError, match="Expected X_a to be str, got int"):
        X(a=1, b=1)
    with pytest.raises(TypeError, match="Expected X_b to be int, got str"):
        X(a="asdf", b="asdf")
    with pytest.raises(TypeError, match="Expected X_a to be str, got int"):
        X(a=1, b="asdf")


def test_validate_inheritance_field_level():
    @chz.chz
    class X:
        a: str = chz.field(validator=chz.validators.typecheck)

    @chz.chz
    class Y(X):
        b: int

    with pytest.raises(TypeError, match="Expected X_a to be str, got int"):
        Y(a=1, b=1)

    @chz.chz
    class A:
        x: X = chz.field(validator=chz.validators.typecheck)

    @chz.chz
    class B(A):
        x: Y

    A(x=X(a="asdf"))
    A(x=Y(a="asdf", b=1))
    B(x=Y(a="asdf", b=1))
    # But note that if you clobber an attribute, the field-level validator also gets clobbered
    B(x=X(a="asdf"))


def test_validate_init_property():
    @chz.chz
    class A1:
        X_attr: str = chz.field(validator=chz.validators.instancecheck)

        @chz.init_property
        def attr(self) -> str:
            return str(self.X_attr)

    A1(attr="attr")
    with pytest.raises(TypeError, match="Expected X_attr to be str, got int"):
        A1(attr=1)

    @chz.chz
    class A2:
        X_attr: int = chz.field(validator=chz.validators.instancecheck)

        @chz.init_property
        def attr(self) -> str:  # changes type
            return str(self.X_attr)

    A2(attr=1)
    with pytest.raises(TypeError, match="Expected X_attr to be int, got str"):
        A2(attr="attr")


def test_validate_init_property_order():
    @chz.chz
    class A:
        value: int = chz.field(validator=chz.validators.gt(0))

        @chz.init_property
        def reciprocal(self):
            return 1 / self.value

    with pytest.raises(ValueError, match="Expected X_value to be greater than 0, got 0"):
        A(value=0)


def test_validate_munger():
    # See comments in __chz_validate__

    @chz.chz
    class A:
        a: int = chz.field(munger=lambda s, v: 100, validator=chz.validators.gt(10))

    with pytest.raises(ValueError, match="Expected X_a to be greater than 10, got 1"):
        A(a=1)

    @chz.chz
    class A:
        a: int = chz.field(munger=lambda s, v: 100, validator=chz.validators.lt(10))

    with pytest.raises(ValueError, match="Expected a to be less than 10, got 100"):
        A(a=1)


def test_validate_ge_le() -> None:
    @chz.chz
    class A:
        value: int = chz.field(validator=chz.validators.ge(0))

    A(value=0)
    with pytest.raises(ValueError, match="Expected X_value to be greater or equal to 0, got -1"):
        A(value=-1)

    @chz.chz
    class B:
        value: int = chz.field(validator=chz.validators.le(0))

    B(value=0)
    with pytest.raises(ValueError, match="Expected X_value to be less or equal to 0, got 1"):
        B(value=1)


def test_validate_inheritance_class_level():
    @chz.chz
    class X:
        a: str

        @chz.validate
        def check_a_is_banana(self):
            if self.a != "banana":
                raise ValueError("Banana only")

    @chz.chz
    class Y(X):
        b: int

    with pytest.raises(ValueError, match="Banana only"):
        Y(a="nana", b=1)

    @chz.chz
    class Z(Y):
        c: bytes

        @chz.validate
        def check_c_is_not_empty(self):
            if not self.c:
                raise ValueError("c must not be empty")

        @chz.validate
        def check_b_is_positive(self):
            if self.b < 0:
                raise ValueError("b must be positive")

    X(a="banana")
    Y(a="banana", b=1)
    Y(a="banana", b=-1)

    with pytest.raises(ValueError, match="Banana only"):
        Z(a="nana", b=1, c=b"asdf")
    with pytest.raises(ValueError, match="b must be positive"):
        Z(a="banana", b=-1, c=b"asdf")
    Z(a="banana", b=1, c=b"asdf")

    assert len(X.__chz_validators__) == 1
    assert len(Y.__chz_validators__) == 1
    assert len(Z.__chz_validators__) == 3


def test_validate_decorator_option():
    @chz.chz(typecheck=True)
    class X:
        a: str

    X(a="asdf")
    with pytest.raises(TypeError, match="Expected X_a to be str, got int"):
        X(a=1)

    @chz.chz
    class Y(X):
        b: int

    Y(a="asdf", b=1)
    with pytest.raises(TypeError, match="Expected X_a to be str, got int"):
        Y(a=1, b=1)
    with pytest.raises(TypeError, match="Expected X_b to be int, got str"):
        Y(a="asdf", b="asdf")

    @chz.chz(typecheck=True)
    class Z(X):
        c: bytes

    assert len(Z.__chz_validators__) == 1

    with pytest.raises(ValueError, match="Cannot disable typecheck; all validators are inherited"):

        @chz.chz(typecheck=False)
        class A(X):
            pass


def test_validate_mixins():
    results = set()

    @chz.chz
    class M1:
        @chz.validate
        def v1(self):
            results.add("v1")

    class M2NonChz:
        @chz.validate
        def v2(self):
            results.add("v2")

    @chz.chz
    class M3:
        @chz.validate
        def v3(self):
            results.add("v3")

    @chz.chz
    class Main(M1, M2NonChz, M3):
        @chz.validate
        def v4(self):
            results.add("v4")

    Main()
    assert results == {"v1", "v2", "v3", "v4"}


def test_validate_valid_regex():
    @chz.chz
    class A:
        attr: str = chz.field(validator=chz.validators.valid_regex)

    A(attr=".*")
    with pytest.raises(
        ValueError, match="Invalid regex in X_attr: nothing to repeat at position 0"
    ):
        A(attr="*")


def test_validate_literal():
    from typing import Literal

    @chz.chz(typecheck=True)
    class A:
        attr: Literal["a", "b"]

    A(attr="a")
    A(attr="b")
    with pytest.raises(TypeError, match=r"Expected X_attr to be Literal\['a', 'b'\], got 'c'"):
        A(attr="c")


def test_validate_const_default():
    @chz.chz
    class Image:
        encoding: str

    @chz.chz
    class PNG(Image):
        encoding: str = chz.field(default="png", validator=chz.validators.const_default)

    PNG()
    PNG(encoding="png")
    with pytest.raises(
        ValueError, match="Expected X_encoding to match the default 'png', got 'jpg'"
    ):
        PNG(encoding="jpg")


def test_validate_field_consistency():
    @chz.chz
    class D:
        const: int

    @chz.chz
    class C:
        map: dict[str, D]
        seq: list[D]

    @chz.chz
    class B:
        const: int
        c: C

    @chz.chz
    class A:
        const: int
        b: B

        @chz.validate
        def field_consistency(self):
            chz.validators.check_field_consistency_in_tree(self, {"const"})

    with pytest.raises(
        ValueError,
        match=re.escape(
            """\
Field 'const' has inconsistent values in object
tree: 1 at const 2 at b.const 3 at b.c.map.a.const 4 at b.c.seq.0.const, b.c.seq.1.const, b.c.seq.2.const, ... (1 more)""" ), ): A( const=1, b=B( const=2, c=C(map={"a": D(const=3)}, seq=[D(const=4), D(const=4), D(const=4), D(const=4)]), ), ) @chz.chz class F: const: int @chz.chz class E: seq: list[F] @chz.chz class D: const: int e: E @chz.validate def field_consistency(self): chz.validators.check_field_consistency_in_tree(self, {"const"}, regex_root=r"e\.seq") # This should not raise an error because the check is only done on the `e.seq` field assert D(const=1, e=E(seq=[F(const=3), F(const=3)])).e.seq[0].const == 3 with pytest.raises( ValueError, match=re.escape( """\ Field 'const' has inconsistent values in object tree: 3 at e.seq.0.const 4 at e.seq.1.const""" ), ): D(const=1, e=E(seq=[F(const=3), F(const=4)])) def test_is_override_catches_non_overriding() -> None: @chz.chz class HasBase: x_different_name: int = 0 @chz.chz class MyHasBase(HasBase): x: int = chz.field(default=1, validator=chz.validators.is_override) with pytest.raises( ValueError, match="Field x does not exist in any parent classes of test_validate:.*MyHasBase", ): MyHasBase() def test_is_override_catches_bad_types() -> None: @chz.chz class Base: x: int = 1 my_tuple: tuple[int, str] = (0, "good") my_homogenous_tuple: tuple[int, ...] = (0, 1) my_dict: dict[int, str] = chz.field(default_factory=lambda: {0: "hi"}) Base() # OK @chz.chz class GoodOverride(Base): x: int = chz.field(default=5, validator=chz.validators.is_override) my_tuple: tuple[int, str] = chz.field( default=(1, "hi"), validator=chz.validators.is_override ) my_homogenous_tuple: tuple[int, ...] 
= chz.field( default=(1, 2), validator=chz.validators.is_override ) my_dict: dict[int, str] = chz.field( default_factory=lambda: {0: "hi"}, validator=chz.validators.is_override ) result = GoodOverride() assert result.x == 5 assert result.my_tuple == (1, "hi") assert result.my_homogenous_tuple == (1, 2) assert result.my_dict == {0: "hi"} @chz.chz class BadInt(Base): x: int = chz.field(default="oops", validator=chz.validators.is_override) @chz.chz class BadTuple(Base): my_tuple: tuple[int, str] = chz.field(default=(1, 0), validator=chz.validators.is_override) @chz.chz class BadHomogenousTuple(Base): my_homogenous_tuple: tuple[int, ...] = chz.field( default=(1, "oops"), validator=chz.validators.is_override ) @chz.chz class BadDict(Base): my_dict: dict[int, str] = chz.field( default_factory=lambda: {"oops": "foo"}, validator=chz.validators.is_override ) for cls in [BadInt, BadTuple, BadHomogenousTuple, BadDict]: with pytest.raises( ValueError, match=r"test_validate:.*Bad.+\.X_.+' must be an instance of .+? to match the type on the original definition in test_validate:.*Base", ): cls() def test_is_override_mixin_catches_bad_types() -> None: @chz.chz class Base: x: int = 1 my_tuple: tuple[int, str] = (0, "good") my_homogenous_tuple: tuple[int, ...] = (0, 1) my_dict: dict[int, str] = chz.field(default_factory=lambda: {0: "hi"}) Base() # OK @chz.chz class GoodOverride(Base, chz.validators.IsOverrideMixin): x: int = chz.field(default=5) my_tuple: tuple[int, str] = chz.field(default=(1, "hi")) my_homogenous_tuple: tuple[int, ...] 
= chz.field(default=(1, 2)) my_dict: dict[int, str] = chz.field(default_factory=lambda: {0: "hi"}) result = GoodOverride() assert result.x == 5 assert result.my_tuple == (1, "hi") assert result.my_homogenous_tuple == (1, 2) assert result.my_dict == {0: "hi"} @chz.chz class BadInt(Base, chz.validators.IsOverrideMixin): x: int = chz.field(default="oops") @chz.chz class BadTuple(Base, chz.validators.IsOverrideMixin): my_tuple: tuple[int, str] = chz.field(default=(1, 0)) @chz.chz class BadHomogenousTuple(Base, chz.validators.IsOverrideMixin): my_homogenous_tuple: tuple[int, ...] = chz.field(default=(1, "oops")) @chz.chz class BadDict(Base, chz.validators.IsOverrideMixin): my_dict: dict[int, str] = chz.field(default_factory=lambda: {"oops": "foo"}) for cls in [BadInt, BadTuple, BadHomogenousTuple, BadDict]: with pytest.raises( ValueError, match=r"test_validate:.*Bad.+\.X_.+' must be an instance of .+? to match the type on the original definition in test_validate:.*Base", ): cls() def test_is_override_catches_bad_generic_default_factory() -> None: class Box(Generic[T]): def __init__(self, value: T): self.value = value @chz.chz class Atom(Generic[T]): box: Box[str] # Check that normal overriding words @chz.chz class MyGoodAtom(Atom, Generic[T]): box: Box[str] = chz.field( default_factory=lambda: Box[str]("hi"), validator=chz.validators.is_override ) assert MyGoodAtom().box.value == "hi" @chz.chz class MyBadAtom(Atom, Generic[T]): box: Box[str] = chz.field( default_factory=lambda: Box[int](5), validator=chz.validators.is_override ) with pytest.raises( ValueError, match=r"test_validate:.*MyBadAtom.X_box' must be an instance of .*Box\[str\] to match the type on the original definition in test_validate:.*\.Atom", ): MyBadAtom() def test_is_override_works_with_default_factory() -> None: class Base: def __init__(self) -> None: self.value = 1 @chz.chz class HasBases: bases: tuple[Base, ...] 

    def my_bad_factory() -> tuple[Base, ...]:
        return Base(), "oop", Base()  # type: ignore

    @chz.chz
    class MyBadHasBases(HasBases):
        bases: tuple[Base, ...] = chz.field(
            default_factory=my_bad_factory, validator=chz.validators.is_override
        )

    with pytest.raises(
        ValueError,
        match=r".*MyBadHasBases.X_bases' must be an instance of tuple\[.*Base, \.\.\.\] to match the type on the original definition in .*\.HasBases",
    ):
        MyBadHasBases()


def test_is_override_mixin_catches_bad_types_in_subclasses() -> None:
    @chz.chz
    class Atom:
        x: int = 1

    @chz.chz
    class MyBadAtom(Atom, chz.validators.IsOverrideMixin):
        x: int = chz.field(default="foo")

    @chz.chz
    class Container:
        atom: Atom

    @chz.chz
    class MyBadContainer(Container):
        atom: Atom = chz.field(default_factory=MyBadAtom, blueprint_unspecified=MyBadAtom)

    with pytest.raises(
        ValueError,
        match=".*MyBadAtom.X_x' must be an instance of int to match the type on the original definition in .*Atom",
    ):
        MyBadContainer()

    with pytest.raises(
        ValueError,
        match=".*MyBadAtom.X_x' must be an instance of int to match the type on the original definition in .*Atom",
    ):
        chz.Blueprint(MyBadContainer).make()


def test_is_override_mixin_works_on_field_default() -> None:
    @chz.chz
    class Base:
        x: int = 1

    @chz.chz
    class BaseSub(Base, chz.validators.IsOverrideMixin):
        x: int = "foo"  # type: ignore # that's the point of this test!

    with pytest.raises(
        ValueError,
        match=".*BaseSub.X_x' must be an instance of int to match the type on the original definition in .*Base",
    ):
        BaseSub()

    @chz.chz
    class BadIntermediate(Base):
        x: str = "sneaky intermediate class trying to mess things up"  # type: ignore

    @chz.chz
    class BadFinal(BadIntermediate, chz.validators.IsOverrideMixin):
        pass

    with pytest.raises(
        ValueError,
        match=".*BadFinal.X_x' must be an instance of int to match the type on the original definition in .*Base",
    ):
        BadFinal()

    @chz.chz
    class BadFinalThatMatchesIntermediate(BadIntermediate, chz.validators.IsOverrideMixin):
        x: str = "strings are bad here because it doesn't match the Base definition!"  # type: ignore

    with pytest.raises(
        ValueError,
        match=".*BadFinalThatMatchesIntermediate.X_x' must be an instance of int to match the type on the original definition in .*Base",
    ):
        BadFinalThatMatchesIntermediate()


def test_is_override_mixin_works_with_x_fields() -> None:
    @chz.chz
    class Base:
        X_value: str = "foo"

        @chz.init_property
        def value(self) -> tuple[str, ...]:
            return tuple(self.X_value.split(","))

    @chz.chz
    class SomeOverride(Base, chz.validators.IsOverrideMixin):
        X_value: str = chz.field(default="a,b")

    instance = chz.Blueprint(SomeOverride).make()
    assert instance.value == ("a", "b")

    @chz.chz
    class BadOverride(Base, chz.validators.IsOverrideMixin):
        X_value: tuple[str, ...] = chz.field(
            default=("look at me eagerly create", "a tuple of strings")
        )  # type: ignore

    @chz.chz
    class BadOverride2(Base, chz.validators.IsOverrideMixin):
        X_value: str = chz.field(default=("type signature is good", "but default is bad"))

    with pytest.raises(
        ValueError,
        match=r".*BadOverride.X_value' must be an instance of str to match the type on the original definition in .*Base",
    ):
        BadOverride()

    with pytest.raises(
        ValueError,
        match=r".*BadOverride2.X_value' must be an instance of str to match the type on the original definition in .*Base",
    ):
        BadOverride2()