From a7f7324505966e1ddf8d28b8d224d3271d045b70 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Tue, 3 Dec 2024 07:48:27 +0100 Subject: [PATCH 01/14] start work on inverse operations --- mllam_data_prep/__init__.py | 2 +- mllam_data_prep/create_dataset.py | 52 ++++++++++++++++++++++++++++-- mllam_data_prep/recreate_inputs.py | 0 pdm.lock | 12 ++++++- pyproject.toml | 1 + tests/test_inverse.py | 29 +++++++++++++++++ 6 files changed, 92 insertions(+), 4 deletions(-) create mode 100644 mllam_data_prep/recreate_inputs.py create mode 100644 tests/test_inverse.py diff --git a/mllam_data_prep/__init__.py b/mllam_data_prep/__init__.py index 64bfa91..c34f855 100644 --- a/mllam_data_prep/__init__.py +++ b/mllam_data_prep/__init__.py @@ -7,4 +7,4 @@ # expose the public API from .config import Config, InvalidConfigException # noqa -from .create_dataset import create_dataset, create_dataset_zarr # noqa +from .create_dataset import create_dataset, create_dataset_zarr, recreate_inputs # noqa diff --git a/mllam_data_prep/create_dataset.py b/mllam_data_prep/create_dataset.py index 73996cf..85fb924 100644 --- a/mllam_data_prep/create_dataset.py +++ b/mllam_data_prep/create_dataset.py @@ -4,6 +4,7 @@ from pathlib import Path import numpy as np +import parse import xarray as xr from loguru import logger from numcodecs import Blosc @@ -19,6 +20,9 @@ # optional, so we can support both v0.2.0 and v0.5.0 SUPPORTED_CONFIG_VERSIONS = ["v0.2.0", "v0.5.0"] +STATISTICS_VARIABLE_NAME_FORMAT = "{var_name}__{split_name}__{op}" +SOURCE_DATASET_NAME_ATTR = "source_dataset" + def _check_dataset_attributes(ds, expected_attributes, dataset_name): # check that the dataset has the expected attributes with the expected values @@ -172,7 +176,7 @@ def create_dataset(config: Config): f" produce variable {target_output_var} from dataset {dataset_name}" ) from ex - da_target.attrs["source_dataset"] = dataset_name + da_target.attrs[SOURCE_DATASET_NAME_ATTR] = dataset_name # only need to do selection for the coordinates that the input dataset actually has if output_coord_ranges is not None: @@ -218,7 +222,10 @@ def create_dataset(config: Config): ) for op, op_dataarrays in split_stats.items(): for var_name, da in op_dataarrays.items(): - ds[f"{var_name}__{split_name}__{op}"] = da + stat_var_name = STATISTICS_VARIABLE_NAME_FORMAT.format( + var_name=var_name, split_name=split_name, op=op + ) + ds[stat_var_name] = da # add a new variable which contains the start, stop for each split, the coords would then be the split names # and the data would be the start, stop values @@ -278,3 +285,44 @@ def create_dataset_zarr(fp_config, fp_zarr: str = None): logger.info(f"Wrote training-ready dataset to {fp_zarr}") logger.info(ds) + + +def recreate_inputs(config: Config, ds: xr.Dataset): + """ + Recreate the input datasets from a zarr file created by + `create_dataset_zarr` by applying inverse operations to each step. + + Parameters + ---------- + config : Config + The configuration object defining the input datasets and how to map them to the output dataset. + """ + + for input_name, input_config in config.inputs.items(): + dim_mapping = input_config.dim_mapping + da_target = ds[input_config.target_output_variable] + + # find the dim mapping item that is the one where variable names + # are stacked into a feature dimension + feature_dim_name = None + for output_dim, mapping_config in dim_mapping.items(): + if mapping_config.method == "stack_variables_by_var_name": + feature_dim_name = output_dim + source_dims = mapping_config.dims + name_format = mapping_config.name_format + + name_parts = [] + for feature_value in da_target[feature_dim_name].values: + name_parts.append(parse.parse(name_format, feature_value).named) + + # add a variable for each source-dimension name, so that we can unstack with this later + for d in source_dims: + values = [name_part[d] for name_part in name_parts] + da_target[d] = xr.DataArray(values, dims=[feature_dim_name]) + + if len(source_dims) == 1: + da_target = da_target.swap_dims({feature_dim_name: source_dims[0]}) + else: + da_target = da_target.set_index({feature_dim_name: source_dims}).unstack( + feature_dim_name + ) diff --git a/mllam_data_prep/recreate_inputs.py b/mllam_data_prep/recreate_inputs.py new file mode 100644 index 0000000..e69de29 diff --git a/pdm.lock b/pdm.lock index 30a0e20..2d67ff9 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:ed345b0df8664a5ab1aadb77d5c4218ef15f135ddacd674f0d029e5a39f9654d" +content_hash = "sha256:9a4011ddc80d96fad270712b9100d335a91fe347fed94172e2cb0b7c7438c1f2" [[metadata.targets]] requires_python = ">=3.9" @@ -814,6 +814,16 @@ files = [ {file = "pandas-2.2.2.tar.gz", hash = "sha256:9e79019aba43cb4fda9e4d983f8e88ca0373adbb697ae9c6c43093218de28b54"}, ] +[[package]] +name = "parse" +version = "1.20.2" +summary = "parse() is the opposite of format()" +groups = ["default"] +files = [ + {file = "parse-1.20.2-py2.py3-none-any.whl", hash = "sha256:967095588cb802add9177d0c0b6133b5ba33b1ea9007ca800e526f42a85af558"}, + {file = "parse-1.20.2.tar.gz", hash = "sha256:b41d604d16503c79d81af5165155c0b20f6c8d6c559efa66b4b695c3e5a0a0ce"}, +] + [[package]] name = "parso" version = "0.8.4" diff --git a/pyproject.toml b/pyproject.toml index 9c65473..a037c7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "rich>=13.7.1", "dask>=2024.2.1", "psutil>=5.7.2", + "parse>=1.20.2", ] requires-python = ">=3.9" readme = "README.md" diff --git a/tests/test_inverse.py b/tests/test_inverse.py new file mode 100644 index 0000000..0cc2045 --- /dev/null +++ b/tests/test_inverse.py @@ -0,0 +1,29 @@ +from pathlib import Path + +import xarray as xr + +import mllam_data_prep as mdp + + +def test_danra_example_inverse(): + fp_config = Path(__file__).parent.parent / "example.danra.yaml" + config: mdp.Config = mdp.Config.from_yaml_file(fp_config) + + ds_transformed = mdp.create_dataset(config=config) + + input_datasets_inverted = mdp.recreate_inputs(config=config, ds=ds_transformed) + + for input_name, input_config in config.inputs.items(): + ds_input = xr.open_dataset(input_config.path) + ds_input_inverted = input_datasets_inverted[input_name] + + # find coordinate ranges in `ds_input_inverted` and subset `ds_input` to match + for dim, coord in ds_input_inverted.coords.items(): + if dim in ds_input.coords: + ds_input = ds_input.sel({dim: coord}) + + # check that the variables in `ds_input_inverted` are present in `ds_input` + for var in ds_input_inverted.data_vars: + assert var in ds_input.data_vars + # and check that the values are the same + xr.testing.assert_equal(ds_input[var], ds_input_inverted[var]) From dfe0df1a4f3463be498a7bc7d711e91aa75f102e Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Mon, 12 May 2025 13:31:47 +0200 Subject: [PATCH 02/14] first fully working implementation! --- mllam_data_prep/__init__.py | 3 +- mllam_data_prep/create_dataset.py | 46 ----- mllam_data_prep/recreate_inputs.py | 271 +++++++++++++++++++++++++++++ pyproject.toml | 6 + tests/test_inverse.py | 44 +++-- 5 files changed, 309 insertions(+), 61 deletions(-) diff --git a/mllam_data_prep/__init__.py b/mllam_data_prep/__init__.py index c34f855..6fa686f 100644 --- a/mllam_data_prep/__init__.py +++ b/mllam_data_prep/__init__.py @@ -7,4 +7,5 @@ # expose the public API from .config import Config, InvalidConfigException # noqa -from .create_dataset import create_dataset, create_dataset_zarr, recreate_inputs # noqa +from .create_dataset import create_dataset, create_dataset_zarr # noqa +from .recreate_inputs import recreate_inputs # noqa diff --git a/mllam_data_prep/create_dataset.py b/mllam_data_prep/create_dataset.py index 4581e7d..4576d71 100644 --- a/mllam_data_prep/create_dataset.py +++ b/mllam_data_prep/create_dataset.py @@ -5,7 +5,6 @@ from typing import Optional import numpy as np -import parse import xarray as xr import yaml import zarr @@ -414,48 +413,3 @@ def create_dataset_zarr( logger.info(f"Wrote training-ready dataset to {fp_zarr}") logger.info(ds) - - -def recreate_inputs(config: Config, ds: xr.Dataset): - """ - Recreate the input datasets from a zarr file created by - `create_dataset_zarr` by applying inverse operations to each step. - - Parameters - ---------- - config : Config - The configuration object defining the input datasets and how to map them to the output dataset. - """ - - for input_name, input_config in config.inputs.items(): - dim_mapping = input_config.dim_mapping - da_target = ds[input_config.target_output_variable] - - # find the dim mapping item that is the one where variable names - # are stacked into a feature dimension - feature_dim_name = None - source_dims: list[str] | None = [] - for output_dim, mapping_config in dim_mapping.items(): - if mapping_config.method == "stack_variables_by_var_name": - feature_dim_name = output_dim - source_dims = mapping_config.dims - name_format = mapping_config.name_format - break - - name_parts = [] - for feature_value in da_target[feature_dim_name].values: - name_parts.append(parse.parse(name_format, feature_value).named) - - if source_dims is not None: - - # add a variable for each source-dimension name, so that we can unstack with this later - for d in source_dims: - values = [name_part[d] for name_part in name_parts] - da_target[d] = xr.DataArray(values, dims=[feature_dim_name]) - - if len(source_dims) == 1: - da_target = da_target.swap_dims({feature_dim_name: source_dims[0]}) - else: - da_target = da_target.set_index( - {feature_dim_name: source_dims} - ).unstack(feature_dim_name) diff --git a/mllam_data_prep/recreate_inputs.py b/mllam_data_prep/recreate_inputs.py index e69de29..ca847cf 100644 --- a/mllam_data_prep/recreate_inputs.py +++ b/mllam_data_prep/recreate_inputs.py @@ -0,0 +1,271 @@ +from typing import Optional + +import parse +import xarray as xr +from loguru import logger + +from .config import Config +from .create_dataset import SOURCE_DATASET_NAME_ATTR + + +def _split_coord_values_as_variables( + da: xr.DataArray, name_format: str, target_dim: str +): + """ + Split the coordinate values of a DataArray into separate variables based on a name format. + + Parameters + ---------- + da : xr.DataArray + The DataArray to split. + name_format : str + The format string used to parse the coordinate values. + target_dim : str + The name of the coordinate dimension to split. + + Returns + ------- + dict[str, xr.DataArray] + A dictionary of new DataArrays, where the keys are the variable names and the values are the DataArrays. + """ + + dataarrays = [] + coord_values = da[target_dim].values + for coord_value in coord_values: + + da_feature = da.sel({target_dim: coord_value}) + name_parts = dict(parse.parse(name_format, coord_value).named) + # the "var_name" part of the coordinate value is the name of the + # variable that that the data came from + var_name = name_parts.pop("var_name") + # the rest are coordinate names and values + coords = name_parts + + da_original = da_feature.copy().squeeze() + da_original.name = var_name + for k, v in coords.items(): + # TODO: in future we should enforce that the format strings contain + # types so that we can parse the values to the correct type + if "." in v: + try: + v = float(v) + except ValueError: + pass + else: + try: + v = int(v) + except ValueError: + pass + + da_original[k] = v + + da_original = da_original.expand_dims(list(coords.keys())) + + var_units = da_feature[f"{target_dim}_units"].load().item() + var_long_name = da_feature[f"{target_dim}_long_name"].load().item() + + da_original.attrs["units"] = var_units + da_original.attrs["long_name"] = var_long_name + + # remove the coords (and aux coords) that represented the feature + # coordinate, the units, long_name and source_dataset + for d in [ + target_dim, + f"{target_dim}_units", + f"{target_dim}_long_name", + f"{target_dim}_{SOURCE_DATASET_NAME_ATTR}", + ]: + da_original = da_original.drop_vars(d) + + dataarrays.append(da_original) + + ds = xr.merge(dataarrays, join="exact") + + return ds + + +def recreate_inputs( + ds: xr.Dataset, config: Optional[Config] = None +) -> dict[str, xr.Dataset]: + """ + Recreate the input datasets from a zarr file created by + `create_dataset_zarr` by applying inverse operations to each step. + + Parameters + ---------- + ds : xr.Dataset + The mllam-data-prep dataset to recreate the input datasets from. + config: Config, optional + The configuration object defining the input datasets and how to map them to the output dataset. + If not provided, the config will be read from the dataset attributes. + + Returns + ------- + dict[str, xr.Dataset] + A dictionary of input datasets, where the keys are the input dataset names + and the values are the recreated input datasets. + """ + input_datasets = {} + if config is None: + config = Config.from_yaml(ds.creation_config) + + for input_name, input_config in config.inputs.items(): + dim_mapping = input_config.dim_mapping + da_target = ds[input_config.target_output_variable] + + # 1. First, we need to split out the coordinate that was used to stack + # multiple variables into. Find the dim mapping item that is the one + # where variable names are stacked into a feature dimension + feature_dim_name = None + for output_dim, mapping_config in dim_mapping.items(): + if mapping_config.method == "stack_variables_by_var_name": + feature_dim_name = output_dim + name_format: str = str(mapping_config.name_format) + break + + if feature_dim_name is None: + raise ValueError( + f"Could not find a feature dimension in the dim_mapping for input dataset {input_name}" + ) + dim_mapping.pop(output_dim) + ds_source = _split_coord_values_as_variables( + da=da_target, + name_format=name_format, + target_dim=feature_dim_name, + ) + + # 2. And then we handle the other mapping of dimensions + for output_dim, mapping_config in dim_mapping.items(): + method_name = mapping_config.method + if method_name == "stack_variables_by_var_name": + raise Exception( + "`stack_variables_by_var_name` should have been handled above" + ) + elif method_name == "rename": + # rename the dimension back again + ds_source = ds_source.rename({output_dim: mapping_config.dim}) + elif method_name == "stack": + # unstack the stacked dimension + ds_source = ds_source.set_index( + {output_dim: mapping_config.dims} + ).unstack(output_dim) + else: + raise NotImplementedError(method_name) + + # 3. Finally, we remove any variables that were derived from the input + # dataset + if input_config.derived_variables is not None: + derived_variables = input_config.derived_variables.keys() + ds_source = ds_source.drop_vars(derived_variables) + + # 4. Remove chunking information so that we can save the dataset with a + # new chunking + for var in ds_source.data_vars: + if "chunks" in ds_source[var].encoding: + del ds_source[var].encoding["chunks"] + + input_datasets[input_name] = ds_source + + return input_datasets + + +def _parse_string_to_dict(input_string, value_type=int): + """ + Parses a comma-separated key-value string into a dictionary. + The format is 'key=value,key2=value2'. Empty values and multiple values for the same key are not allowed. + + Parameters + ---------- + input_string : str + The input string to parse. It should be in the format 'key=value,key2=value2'. + value_type : type + The type to which the values should be converted. Default is int. + + Returns + ------- + dict + A dictionary with keys and values parsed from the input string. + + Raises + ------ + ValueError: If the input string is not in the correct format. + TypeError: If the value cannot be converted to the specified type. + KeyError: If a key appears more than once in the input string. + """ + + result = {} + + for item in input_string.split(","): + key_value_pair = item.strip().split("=") + if len(key_value_pair) != 2: + raise ValueError( + "Invalid format. Each key-value pair must be separated by '=' and the pair must be separated by ','." + ) + + key, value = key_value_pair + if key in result: + raise KeyError("Duplicate keys are not allowed.") + + result[key] = value_type(value) + + return result + + +def main(argv=None): + import argparse + + parser = argparse.ArgumentParser( + description="Recreate the input datasets from a zarr file created by create_dataset_zarr" + ) + parser.add_argument( + "zarr_dataset_path", + type=str, + help="The path to the zarr file to recreate the input datasets from", + ) + parser.add_argument( + "--output-path-format", + default="{input_name}.zarr", + type=str, + help="The format string for the output path. The input name will be replaced with the input dataset name", + ) + parser.add_argument( + "--chunks", + type=_parse_string_to_dict, + default={}, + help="The chunks to use for the output datasets. The format is" + "'key=value,key2=value2'. I.e. to use chunksize 1 along the time" + "dimension and 100 along the x-dimension, use `--chunks time=1,x=100`", + ) + parser.add_argument( + "--only-selected-inputs", + nargs="*", + default=None, + help="If provided, only the input datasets with these names will be recreated. " + "If not provided, all input datasets will be recreated.", + ) + + args = parser.parse_args(argv) + + ds = xr.open_zarr(args.zarr_dataset_path) + input_datasets = recreate_inputs(ds=ds) + if args.only_selected_inputs is not None: + missing_inputs = set(args.only_selected_inputs) - set(input_datasets.keys()) + if missing_inputs: + raise ValueError( + f"The following input datasets were not found in the zarr file: {missing_inputs}. " + f"The available input datasets are: {list(input_datasets.keys())})" + ) + input_datasets = { + k: v for k, v in input_datasets.items() if k in args.only_selected_inputs + } + + for input_name, ds_input in input_datasets.items(): + output_path = args.output_path_format.format(input_name=input_name) + logger.info( + f"Saving input dataset {input_name} to {output_path} with chunks={args.chunks}" + ) + ds_input.chunk(args.chunks).to_zarr(output_path, mode="w", consolidated=True) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 327f216..4b05742 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,5 +54,11 @@ dev = [ "ipdb>=0.13.13", "pre-commit>=3.7.1", ] + +[dependency-groups] +dev = [ + "ipdb>=0.13.13", + "pytest>=8.3.5", +] [project.scripts] mllam_data_prep = "mllam_data_prep:cli.call" diff --git a/tests/test_inverse.py b/tests/test_inverse.py index 0cc2045..b5832dd 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -5,25 +5,41 @@ import mllam_data_prep as mdp -def test_danra_example_inverse(): +def _test(): fp_config = Path(__file__).parent.parent / "example.danra.yaml" config: mdp.Config = mdp.Config.from_yaml_file(fp_config) ds_transformed = mdp.create_dataset(config=config) - input_datasets_inverted = mdp.recreate_inputs(config=config, ds=ds_transformed) + for recreation_config in [None, config]: + input_datasets_inverted = mdp.recreate_inputs( + config=recreation_config, ds=ds_transformed + ) + + for input_name, input_config in config.inputs.items(): + ds_input = xr.open_dataset(input_config.path) + ds_input_inverted = input_datasets_inverted[input_name] + + # the config may have performed subsetting (i.e. ds.sel) so we will + # find coordinate ranges in `ds_input_inverted` and subset + # `ds_input` to match. This allows us to check that the values are the same + # for each coordinate in `ds_input_inverted`, check if it is present in `ds_input` + for dim in ds_input_inverted.dims.keys(): + coord_values = ds_input_inverted.coords[dim].values + if dim in ds_input.coords: + ds_input = ds_input.sel({dim: coord_values}) + + # check that the variables in `ds_input_inverted` are present in `ds_input` + for var in ds_input_inverted.data_vars: + assert var in ds_input.data_vars + # and check that the values are the same + da_orig = ds_input[var] + da_inverted = ds_input_inverted[var].transpose(*da_orig.dims) + xr.testing.assert_equal(da_orig.coords, da_inverted.coords) - for input_name, input_config in config.inputs.items(): - ds_input = xr.open_dataset(input_config.path) - ds_input_inverted = input_datasets_inverted[input_name] - # find coordinate ranges in `ds_input_inverted` and subset `ds_input` to match - for dim, coord in ds_input_inverted.coords.items(): - if dim in ds_input.coords: - ds_input = ds_input.sel({dim: coord}) +def test_danra_example_inverse(): + import ipdb - # check that the variables in `ds_input_inverted` are present in `ds_input` - for var in ds_input_inverted.data_vars: - assert var in ds_input.data_vars - # and check that the values are the same - xr.testing.assert_equal(ds_input[var], ds_input_inverted[var]) + with ipdb.launch_ipdb_on_exception(): + _test() From a246f2880664eed82a18fa2e5b2b2fdd549295c7 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Wed, 24 Sep 2025 08:57:43 +0200 Subject: [PATCH 03/14] skip missing output targets in dataset to invert --- mllam_data_prep/recreate_inputs.py | 31 +++++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/mllam_data_prep/recreate_inputs.py b/mllam_data_prep/recreate_inputs.py index ca847cf..0ce1c8b 100644 --- a/mllam_data_prep/recreate_inputs.py +++ b/mllam_data_prep/recreate_inputs.py @@ -85,7 +85,9 @@ def _split_coord_values_as_variables( def recreate_inputs( - ds: xr.Dataset, config: Optional[Config] = None + ds: xr.Dataset, + config: Optional[Config] = None, + only_selected_inputs: Optional[list[str]] = None, ) -> dict[str, xr.Dataset]: """ Recreate the input datasets from a zarr file created by @@ -98,6 +100,9 @@ def recreate_inputs( config: Config, optional The configuration object defining the input datasets and how to map them to the output dataset. If not provided, the config will be read from the dataset attributes. + only_selected_inputs : list[str], optional + If provided, only the input datasets with these names will be recreated. + If not provided, all input datasets will be recreated. Returns ------- @@ -109,8 +114,18 @@ def recreate_inputs( if config is None: config = Config.from_yaml(ds.creation_config) - for input_name, input_config in config.inputs.items(): + if only_selected_inputs is None: + only_selected_inputs = list(config.inputs.keys()) + + for input_name in only_selected_inputs: + input_config = config.inputs[input_name] dim_mapping = input_config.dim_mapping + if input_config.target_output_variable not in ds: + logger.warning( + f"Target output variable {input_config.target_output_variable} " + f"for input dataset {input_name} not found in dataset, skipping" + ) + continue da_target = ds[input_config.target_output_variable] # 1. First, we need to split out the coordinate that was used to stack @@ -211,6 +226,7 @@ def _parse_string_to_dict(input_string, value_type=int): return result +@logger.catch(reraise=True) def main(argv=None): import argparse @@ -222,6 +238,13 @@ def main(argv=None): type=str, help="The path to the zarr file to recreate the input datasets from", ) + parser.add_argument( + "--config-path", + type=str, + default=None, + help="The path to the configuration file that was used to create the dataset. " + "If not provided, the config will be read from the dataset attributes.", + ) parser.add_argument( "--output-path-format", default="{input_name}.zarr", @@ -246,8 +269,10 @@ def main(argv=None): args = parser.parse_args(argv) + config = Config.from_yaml_file(args.config_path) if args.config_path else None + ds = xr.open_zarr(args.zarr_dataset_path) - input_datasets = recreate_inputs(ds=ds) + input_datasets = recreate_inputs(ds=ds, config=config) if args.only_selected_inputs is not None: missing_inputs = set(args.only_selected_inputs) - set(input_datasets.keys()) if missing_inputs: From 83c0ecf3a338159ae461a5045caca57060b75be1 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Wed, 24 Sep 2025 09:02:55 +0200 Subject: [PATCH 04/14] actually select specific inputs --- mllam_data_prep/recreate_inputs.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mllam_data_prep/recreate_inputs.py b/mllam_data_prep/recreate_inputs.py index 0ce1c8b..5381641 100644 --- a/mllam_data_prep/recreate_inputs.py +++ b/mllam_data_prep/recreate_inputs.py @@ -272,7 +272,9 @@ def main(argv=None): config = Config.from_yaml_file(args.config_path) if args.config_path else None ds = xr.open_zarr(args.zarr_dataset_path) - input_datasets = recreate_inputs(ds=ds, config=config) + input_datasets = recreate_inputs( + ds=ds, config=config, only_selected_inputs=args.only_selected_inputs + ) if args.only_selected_inputs is not None: missing_inputs = set(args.only_selected_inputs) - set(input_datasets.keys()) if missing_inputs: From d3cc028a29a6fa4d2f9b1c90eb0a55274baf431b Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Wed, 24 Sep 2025 13:38:29 +0200 Subject: [PATCH 05/14] use cf_xarray.encoding for MultiIndex instead of dropping --- mllam_data_prep/ops/mapping.py | 9 ++++++++- pyproject.toml | 1 + 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/mllam_data_prep/ops/mapping.py b/mllam_data_prep/ops/mapping.py index 9482ff8..838dcb0 100644 --- a/mllam_data_prep/ops/mapping.py +++ b/mllam_data_prep/ops/mapping.py @@ -1,3 +1,5 @@ +import cf_xarray as cfxr + from .stacking import stack_variables_as_coord_values, stack_variables_by_coord_values @@ -100,7 +102,12 @@ def map_dims_and_variables(ds, dim_mapping, expected_input_var_dims): # in the input dataset that we want to stack to create the architecture # dimension, this is for example used for flatting the spatial dimensions # into a single dimension representing the grid index - ds = ds.stack({arch_dim: source_dims}).reset_index(arch_dim) + ds = ds.stack({arch_dim: source_dims}) + # rather than .reset_index(arch_dim) here to remove the MultiIndex + # (which we previously did, since MultiIndexes can't be serialised + # to netcdf/zarr) we use cf_xarrays cf-compliant encoding/decoding + # here: + ds = cfxr.encode_multi_index_as_compress(ds, idxnames=arch_dim) else: raise NotImplementedError(method) diff --git a/pyproject.toml b/pyproject.toml index cb0bb4a..2de6346 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "packaging>=23.1", "deepdiff>=8.2.0", "parse>=1.20.2", + "cf-xarray>=0.9.4", ] requires-python = ">=3.9" readme = "README.md" From 42aea1f1da8926427235abecc7779cf187962f2c Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Wed, 24 Sep 2025 14:19:55 +0200 Subject: [PATCH 06/14] encode MultiIndex in create_dataset --- mllam_data_prep/create_dataset.py | 13 +++++++++++++ mllam_data_prep/ops/mapping.py | 7 ------- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/mllam_data_prep/create_dataset.py b/mllam_data_prep/create_dataset.py index 95bd486..3a9a13b 100644 --- a/mllam_data_prep/create_dataset.py +++ b/mllam_data_prep/create_dataset.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import Optional, Union +import cf_xarray as cfxr import numpy as np import xarray as xr import yaml @@ -294,6 +295,18 @@ def create_dataset(config: Config): ) ds["splits"] = da_splits + # We have to deal with the fact that MultiIndex objects (this would + # commonly before example `grid_index` created by stacking the `x` and `y` + # coordinates) can't be written to netcdf/zarr. In cf_xarray this has been + # handled in a cf-compliant manner using so-called "compression by + # gathering" (see + # https://cf-xarray.readthedocs.io/en/latest/generated/cf_xarray.encode_multi_index_as_compress.html#cf_xarray.encode_multi_index_as_compress). + # which allows us to safely roundtrip MultiIndexes through netcdf/zarr, + # using their encode and decode functions. + for idx in ds.indexes: + if isinstance(ds.indexes[idx], xr.MultiIndex): + ds = cfxr.encode_multi_index_as_compress(ds, idxnames=idx) + ds.attrs = {} ds.attrs["schema_version"] = config.schema_version ds.attrs["dataset_version"] = config.dataset_version diff --git a/mllam_data_prep/ops/mapping.py b/mllam_data_prep/ops/mapping.py index 838dcb0..f453699 100644 --- a/mllam_data_prep/ops/mapping.py +++ b/mllam_data_prep/ops/mapping.py @@ -1,5 +1,3 @@ -import cf_xarray as cfxr - from .stacking import stack_variables_as_coord_values, stack_variables_by_coord_values @@ -103,11 +101,6 @@ def map_dims_and_variables(ds, dim_mapping, expected_input_var_dims): # dimension, this is for example used for flatting the spatial dimensions # into a single dimension representing the grid index ds = ds.stack({arch_dim: source_dims}) - # rather than .reset_index(arch_dim) here to remove the MultiIndex - # (which we previously did, since MultiIndexes can't be serialised - # to netcdf/zarr) we use cf_xarrays cf-compliant encoding/decoding - # here: - ds = cfxr.encode_multi_index_as_compress(ds, idxnames=arch_dim) else: raise NotImplementedError(method) From da1cc2407a2b10b1107f754ea08c0d07d6d85d23 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Wed, 24 Sep 2025 14:23:51 +0200 Subject: [PATCH 07/14] use pd.MultiIndex not xr.MultiIndex --- mllam_data_prep/create_dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllam_data_prep/create_dataset.py b/mllam_data_prep/create_dataset.py index 3a9a13b..d44570d 100644 --- a/mllam_data_prep/create_dataset.py +++ b/mllam_data_prep/create_dataset.py @@ -6,6 +6,7 @@ import cf_xarray as cfxr import numpy as np +import pandas as pd import xarray as xr import yaml import zarr @@ -304,7 +305,7 @@ def create_dataset(config: Config): # which allows us to safely roundtrip MultiIndexes through netcdf/zarr, # using their encode and decode functions. for idx in ds.indexes: - if isinstance(ds.indexes[idx], xr.MultiIndex): + if isinstance(ds.indexes[idx], pd.MultiIndex): ds = cfxr.encode_multi_index_as_compress(ds, idxnames=idx) ds.attrs = {} From 12a9a9f203dff3657e708f91201f4ce7abd6c25d Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Wed, 24 Sep 2025 15:36:36 +0200 Subject: [PATCH 08/14] use cf_xarray decode --- mllam_data_prep/recreate_inputs.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mllam_data_prep/recreate_inputs.py b/mllam_data_prep/recreate_inputs.py index 5381641..7681393 100644 --- a/mllam_data_prep/recreate_inputs.py +++ b/mllam_data_prep/recreate_inputs.py @@ -1,5 +1,6 @@ from typing import Optional +import cf_xarray as cfxr import parse import xarray as xr from loguru import logger @@ -161,9 +162,9 @@ def recreate_inputs( ds_source = ds_source.rename({output_dim: mapping_config.dim}) elif method_name == "stack": # unstack the stacked dimension - ds_source = ds_source.set_index( - {output_dim: mapping_config.dims} - ).unstack(output_dim) + ds_source = cfxr.decode_multi_index_as_compression( + ds_source, output_dim + ) else: raise NotImplementedError(method_name) From 6c185cf5f77ee8d2f34d7ddf74cee8f365ba2368 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Thu, 25 Sep 2025 09:10:34 +0200 Subject: [PATCH 09/14] fix typo --- mllam_data_prep/recreate_inputs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllam_data_prep/recreate_inputs.py b/mllam_data_prep/recreate_inputs.py index 7681393..fea723c 100644 --- a/mllam_data_prep/recreate_inputs.py +++ b/mllam_data_prep/recreate_inputs.py @@ -162,8 +162,8 @@ def recreate_inputs( ds_source = ds_source.rename({output_dim: mapping_config.dim}) elif method_name == "stack": # unstack the stacked dimension - ds_source = cfxr.decode_multi_index_as_compression( - ds_source, output_dim + ds_source = cfxr.decode_compress_to_multi_index( + ds=ds_source, idxnames=output_dim ) else: raise NotImplementedError(method_name) From 4adc28d953772d8454128fd1356198145f026c3d Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Thu, 25 Sep 2025 14:13:10 +0200 Subject: [PATCH 10/14] final fixes to MultiIndex decode+unstack --- mllam_data_prep/recreate_inputs.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/mllam_data_prep/recreate_inputs.py b/mllam_data_prep/recreate_inputs.py index fea723c..b5f570a 100644 --- a/mllam_data_prep/recreate_inputs.py +++ b/mllam_data_prep/recreate_inputs.py @@ -162,9 +162,25 @@ def recreate_inputs( ds_source = ds_source.rename({output_dim: mapping_config.dim}) elif method_name == "stack": # unstack the stacked dimension + # To allow MultiIndex to zarr/netcdf + # mllam_data_prep.create_dataset encodes these using + # cf-compliant "gather compression" (see + # https://cf-xarray.readthedocs.io/en/latest/coding.html). + # To make sure decoding of these MultiIndex is possible we need + # to ensure that the required stacked coordinates (defined + # through the "compress" attribute) are included in the dataset + compress_attr = ds_source[output_dim].attrs["compress"] + required_coords = compress_attr.split(" ") + for coord in required_coords: + if coord not in ds.coords: + raise ValueError( + f"Cannot unstack dimension {output_dim} as the required " + f"coordinate {coord} is not in the dataset" + ) + ds_source[coord] = ds.coords[coord] ds_source = cfxr.decode_compress_to_multi_index( - ds=ds_source, idxnames=output_dim - ) + ds_source, idxnames=output_dim + ).unstack(output_dim) else: raise NotImplementedError(method_name) From 46ef6dbcdd62190cdc7b5b20e80be0d8c0516e67 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Thu, 25 Sep 2025 14:50:58 +0200 Subject: [PATCH 11/14] show default args in cli --- mllam_data_prep/recreate_inputs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllam_data_prep/recreate_inputs.py b/mllam_data_prep/recreate_inputs.py index b5f570a..512b15a 100644 --- a/mllam_data_prep/recreate_inputs.py +++ b/mllam_data_prep/recreate_inputs.py @@ -248,7 +248,8 @@ def main(argv=None): import argparse parser = argparse.ArgumentParser( - description="Recreate the input datasets from a zarr file created by create_dataset_zarr" + description="Recreate the input datasets from a zarr file created by create_dataset_zarr", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( "zarr_dataset_path", From d3a6221d29b376bb853977c0d4da0f8a79a11bea Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Thu, 25 Sep 2025 15:15:25 +0200 Subject: [PATCH 12/14] set attrs for inverted dataset --- mllam_data_prep/recreate_inputs.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/mllam_data_prep/recreate_inputs.py b/mllam_data_prep/recreate_inputs.py index 512b15a..38a3251 100644 --- a/mllam_data_prep/recreate_inputs.py +++ b/mllam_data_prep/recreate_inputs.py @@ -1,10 +1,13 @@ +import datetime from typing import Optional import cf_xarray as cfxr +import isodate import parse import xarray as xr from loguru import logger +from . import __version__ from .config import Config from .create_dataset import SOURCE_DATASET_NAME_ATTR @@ -305,6 +308,14 @@ def main(argv=None): } for input_name, ds_input in input_datasets.items(): + ds_input.attrs = {} + ds_input.attrs["recreated_from"] = args.zarr_dataset_path + if config is not None: + ds_input.attrs["recreation_config"] = config.to_yaml() + ds_input.attrs["source_dataset_name"] = input_name + ds_input.attrs["created_by"] = "mllam_data_prep.recreate_inputs" + ds_input.attrs["created_on"] = isodate.isoformat(datetime.datetime.utcnow()) + ds_input.attrs["mdp-version"] = __version__ output_path = args.output_path_format.format(input_name=input_name) logger.info( f"Saving input dataset {input_name} to {output_path} with chunks={args.chunks}" From 5979423370cbddd7223e835ce2c0f4c7008b9b0b Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Thu, 25 Sep 2025 15:21:24 +0200 Subject: [PATCH 13/14] use utc time --- mllam_data_prep/recreate_inputs.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mllam_data_prep/recreate_inputs.py b/mllam_data_prep/recreate_inputs.py index 38a3251..c9a0d02 100644 --- a/mllam_data_prep/recreate_inputs.py +++ b/mllam_data_prep/recreate_inputs.py @@ -2,7 +2,6 @@ from typing import Optional import cf_xarray as cfxr -import isodate import parse import xarray as xr from loguru import logger @@ -314,7 +313,9 @@ def main(argv=None): ds_input.attrs["recreation_config"] = config.to_yaml() ds_input.attrs["source_dataset_name"] = input_name ds_input.attrs["created_by"] = "mllam_data_prep.recreate_inputs" - ds_input.attrs["created_on"] = isodate.isoformat(datetime.datetime.utcnow()) + ds_input.attrs["created_on"] = datetime.datetime.now( + datetime.timezone.utc + ).isoformat() ds_input.attrs["mdp-version"] = __version__ output_path = args.output_path_format.format(input_name=input_name) logger.info( From 23b2bd1eb5077e1d45203a3fe1065ec9c397b6e3 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Tue, 25 Nov 2025 13:36:58 +0100 Subject: [PATCH 14/14] remove debug statement --- tests/test_inverse.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/test_inverse.py b/tests/test_inverse.py index b5832dd..eb1e8d4 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -36,10 +36,3 @@ def _test(): da_orig = ds_input[var] da_inverted = ds_input_inverted[var].transpose(*da_orig.dims) xr.testing.assert_equal(da_orig.coords, da_inverted.coords) - - -def test_danra_example_inverse(): - import ipdb - - with ipdb.launch_ipdb_on_exception(): - _test()