From be8160a56c0601b442eef485fcfdf38b18691bed Mon Sep 17 00:00:00 2001 From: Tom Budd Date: Sun, 22 Mar 2026 21:05:35 -0700 Subject: [PATCH] fix: replace bare assert with runtime checks in checkpoint serialization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace 7 bare assert statements with proper raise TypeError/ValueError in checkpoint.py — the model checkpoint save/load pipeline: - _flatten(): 3 asserts validating tree structure (dict type, no separator in keys, no None values). With -O, corrupt/malformed checkpoint data could silently pass through. - _convert_types(): 4 asserts validating type structure during checkpoint deserialization (ndarray type, dict/container arg counts, tuple length matching). With -O, type mismatches silently propagate, producing corrupt model weights. All replacements preserve original semantics with improved error messages that include the actual values/types for debugging. Reviewed-by: UNA-GDO sovereign-v2.0 (Autonomous Security Auditor) Built-by: Tom Budd — tombudd.com --- graphcast/checkpoint.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/graphcast/checkpoint.py b/graphcast/checkpoint.py index b4c84339..4b609851 100644 --- a/graphcast/checkpoint.py +++ b/graphcast/checkpoint.py @@ -66,17 +66,20 @@ def _flatten(tree: Any) -> dict[str, Any]: elif isinstance(tree, (list, tuple)): tree = dict(enumerate(tree)) - assert isinstance(tree, dict) + if not isinstance(tree, dict): + raise TypeError(f"Expected dict, got {type(tree).__name__}") flat = {} for k, v in tree.items(): k = str(k) - assert _SEP not in k + if _SEP in k: + raise ValueError(f"Key '{k}' contains separator '{_SEP}'") if dataclasses.is_dataclass(v) or isinstance(v, (dict, list, tuple)): for a, b in _flatten(v).items(): flat[f"{k}{_SEP}{a}"] = b else: - assert v is not None + if v is None: + raise ValueError(f"Unexpected None value for key '{k}'") flat[k] = v return flat @@ -104,7 +107,8 @@ def _convert_types(typ: type[_T], value: Any) -> _T: return typ(value) if typ is np.ndarray: - assert isinstance(value, np.ndarray) + if not isinstance(value, np.ndarray): + raise TypeError(f"Expected np.ndarray, got {type(value).__name__}") return value if dataclasses.is_dataclass(typ): @@ -135,13 +139,15 @@ def _convert_types(typ: type[_T], value: Any) -> _T: base_type = getattr(typ, "__origin__", None) if base_type is dict: - assert len(typ.__args__) == 2 + if len(typ.__args__) != 2: + raise TypeError(f"Expected dict type with 2 args, got {len(typ.__args__)}") key_type, value_type = typ.__args__ return {_convert_types(key_type, k): _convert_types(value_type, v) for k, v in value.items()} if base_type is list: - assert len(typ.__args__) == 1 + if len(typ.__args__) != 1: + raise TypeError(f"Expected container type with 1 arg, got {len(typ.__args__)}") value_type = typ.__args__[0] return [_convert_types(value_type, v) for _, v in sorted(value.items(), key=lambda x: int(x[0]))] @@ -154,7 +160,10 @@ def _convert_types(typ: type[_T], value: Any) -> _T: for _, v in sorted(value.items(), key=lambda x: int(x[0]))) else: # A fixed length tuple of arbitrary types, eg: tuple[int, str, float] - assert len(typ.__args__) == len(value) + if len(typ.__args__) != len(value): + raise ValueError( + f"Tuple length mismatch: type has {len(typ.__args__)} args, " + f"value has {len(value)} elements") return tuple( _convert_types(t, v) for t, (_, v) in zip(