Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions graphcast/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,20 @@ def _flatten(tree: Any) -> dict[str, Any]:
elif isinstance(tree, (list, tuple)):
tree = dict(enumerate(tree))

assert isinstance(tree, dict)
if not isinstance(tree, dict):
raise TypeError(f"Expected dict, got {type(tree).__name__}")

flat = {}
for k, v in tree.items():
k = str(k)
assert _SEP not in k
if _SEP in k:
raise ValueError(f"Key '{k}' contains separator '{_SEP}'")
if dataclasses.is_dataclass(v) or isinstance(v, (dict, list, tuple)):
for a, b in _flatten(v).items():
flat[f"{k}{_SEP}{a}"] = b
else:
assert v is not None
if v is None:
raise ValueError(f"Unexpected None value for key '{k}'")
flat[k] = v
return flat

Expand Down Expand Up @@ -104,7 +107,8 @@ def _convert_types(typ: type[_T], value: Any) -> _T:
return typ(value)

if typ is np.ndarray:
assert isinstance(value, np.ndarray)
if not isinstance(value, np.ndarray):
raise TypeError(f"Expected np.ndarray, got {type(value).__name__}")
return value

if dataclasses.is_dataclass(typ):
Expand Down Expand Up @@ -135,13 +139,15 @@ def _convert_types(typ: type[_T], value: Any) -> _T:
base_type = getattr(typ, "__origin__", None)

if base_type is dict:
assert len(typ.__args__) == 2
if len(typ.__args__) != 2:
raise TypeError(f"Expected dict type with 2 args, got {len(typ.__args__)}")
key_type, value_type = typ.__args__
return {_convert_types(key_type, k): _convert_types(value_type, v)
for k, v in value.items()}

if base_type is list:
assert len(typ.__args__) == 1
if len(typ.__args__) != 1:
raise TypeError(f"Expected container type with 1 arg, got {len(typ.__args__)}")
value_type = typ.__args__[0]
return [_convert_types(value_type, v)
for _, v in sorted(value.items(), key=lambda x: int(x[0]))]
Expand All @@ -154,7 +160,10 @@ def _convert_types(typ: type[_T], value: Any) -> _T:
for _, v in sorted(value.items(), key=lambda x: int(x[0])))
else:
# A fixed length tuple of arbitrary types, eg: tuple[int, str, float]
assert len(typ.__args__) == len(value)
if len(typ.__args__) != len(value):
raise ValueError(
f"Tuple length mismatch: type has {len(typ.__args__)} args, "
f"value has {len(value)} elements")
return tuple(
_convert_types(t, v)
for t, (_, v) in zip(
Expand Down