diff --git a/graphcast/checkpoint.py b/graphcast/checkpoint.py index b4c84339..4b609851 100644 --- a/graphcast/checkpoint.py +++ b/graphcast/checkpoint.py @@ -66,17 +66,20 @@ def _flatten(tree: Any) -> dict[str, Any]: elif isinstance(tree, (list, tuple)): tree = dict(enumerate(tree)) - assert isinstance(tree, dict) + if not isinstance(tree, dict): + raise TypeError(f"Expected dict, got {type(tree).__name__}") flat = {} for k, v in tree.items(): k = str(k) - assert _SEP not in k + if _SEP in k: + raise ValueError(f"Key '{k}' contains separator '{_SEP}'") if dataclasses.is_dataclass(v) or isinstance(v, (dict, list, tuple)): for a, b in _flatten(v).items(): flat[f"{k}{_SEP}{a}"] = b else: - assert v is not None + if v is None: + raise ValueError(f"Unexpected None value for key '{k}'") flat[k] = v return flat @@ -104,7 +107,8 @@ def _convert_types(typ: type[_T], value: Any) -> _T: return typ(value) if typ is np.ndarray: - assert isinstance(value, np.ndarray) + if not isinstance(value, np.ndarray): + raise TypeError(f"Expected np.ndarray, got {type(value).__name__}") return value if dataclasses.is_dataclass(typ): @@ -135,13 +139,15 @@ def _convert_types(typ: type[_T], value: Any) -> _T: base_type = getattr(typ, "__origin__", None) if base_type is dict: - assert len(typ.__args__) == 2 + if len(typ.__args__) != 2: + raise TypeError(f"Expected dict type with 2 args, got {len(typ.__args__)}") key_type, value_type = typ.__args__ return {_convert_types(key_type, k): _convert_types(value_type, v) for k, v in value.items()} if base_type is list: - assert len(typ.__args__) == 1 + if len(typ.__args__) != 1: + raise TypeError(f"Expected container type with 1 arg, got {len(typ.__args__)}") value_type = typ.__args__[0] return [_convert_types(value_type, v) for _, v in sorted(value.items(), key=lambda x: int(x[0]))] @@ -154,7 +160,10 @@ def _convert_types(typ: type[_T], value: Any) -> _T: for _, v in sorted(value.items(), key=lambda x: int(x[0]))) else: # A fixed length tuple of arbitrary types, eg: tuple[int, str, float] - assert len(typ.__args__) == len(value) + if len(typ.__args__) != len(value): + raise ValueError( + f"Tuple length mismatch: type has {len(typ.__args__)} args, " + f"value has {len(value)} elements") return tuple( _convert_types(t, v) for t, (_, v) in zip(