diff --git a/kerchunk/grib2.py b/kerchunk/grib2.py
index 2b84ceb7..a8126b4c 100644
--- a/kerchunk/grib2.py
+++ b/kerchunk/grib2.py
@@ -1,12 +1,15 @@
 import base64
 import copy
 import io
+import os
 import logging
 from collections import defaultdict
 import warnings
-from typing import Iterable, List, Dict, Set, TYPE_CHECKING, Optional, Callable
+from enum import unique, Enum
+from typing import Iterable, List, Dict, Set, TYPE_CHECKING, Optional, Callable, Any
 import ujson
 import itertools
+import re
 
 if TYPE_CHECKING:
     import pandas as pd
@@ -21,6 +24,25 @@
 from kerchunk.codecs import GRIBCodec
 from kerchunk.combine import MultiZarrToZarr, drop
 
+COORD_DIM_MAPPING: dict[str, str] = dict(
+    time="run_times",
+    valid_time="valid_times",
+    step="model_horizons",
+)
+
+
+@unique
+class AggregationType(Enum):
+    """
+    Enum of the supported aggregation types.
+    TODO: is this useful elsewhere?
+    """
+
+    HORIZON = "horizon"
+    VALID_TIME = "valid_time"
+    RUN_TIME = "run_time"
+    BEST_AVAILABLE = "best_available"
+
 try:
     import cfgrib
@@ -654,6 +676,144 @@ def parse_grib_idx(
     return result.set_index("idx")
 
 
+def repeat_steps(step_index: "pd.TimedeltaIndex", to_length: int) -> np.array:
+    return np.tile(step_index.to_numpy(), int(np.ceil(to_length / len(step_index))))[
+        :to_length
+    ]
+
+
+def create_steps(steps_index: "pd.Index", to_length) -> np.array:
+    return np.vstack([repeat_steps(si, to_length) for si in steps_index])
+
+
+def store_coord_var(key: str, zstore: dict, coords: tuple[str, ...], data: np.array):
+    if np.isnan(data).any():
+        if f"{key}/.zarray" not in zstore:
+            logger.debug("Skipping nan coordinate with no variable %s", key)
+            return
+        else:
+            logger.info("Trying to add coordinate var %s with nan value!", key)
+
+    zattrs = ujson.loads(zstore[f"{key}/.zattrs"])
+    zarray = ujson.loads(zstore[f"{key}/.zarray"])
+    # Use a list, not a tuple
+    zarray["chunks"] = [*data.shape]
+    zarray["shape"] = [*data.shape]
+    zattrs["_ARRAY_DIMENSIONS"] = [
+        COORD_DIM_MAPPING[v] if v in COORD_DIM_MAPPING else v for v in coords
+    ]
+
+    zstore[f"{key}/.zarray"] = ujson.dumps(zarray)
+    zstore[f"{key}/.zattrs"] = ujson.dumps(zattrs)
+
+    vkey = ".".join(["0" for _ in coords])
+    data_bytes = data.tobytes()
+    try:
+        encoded_val = data_bytes.decode("ascii")
+    except UnicodeDecodeError:
+        encoded_val = (b"base64:" + base64.b64encode(data_bytes)).decode("ascii")
+    zstore[f"{key}/{vkey}"] = encoded_val
+
+
+def store_data_var(
+    key: str,
+    zstore: dict,
+    dims: dict[str, int],
+    coords: dict[str, tuple[str, ...]],
+    data: "pd.DataFrame",
+    steps: np.array,
+    times: np.array,
+    lvals: Optional[np.array],
+):
+    import pandas as pd
+
+    zattrs = ujson.loads(zstore[f"{key}/.zattrs"])
+    zarray = ujson.loads(zstore[f"{key}/.zarray"])
+
+    dcoords = coords["datavar"]
+
+    # The lat/lon y/x coordinates are always the last two
+    lat_lon_dims = {
+        k: v for k, v in zip(zattrs["_ARRAY_DIMENSIONS"][-2:], zarray["shape"][-2:])
+    }
+    full_coords = dcoords + tuple(lat_lon_dims.keys())
+    full_dims = dict(**dims, **lat_lon_dims)
+
+    # All chunk dimensions are 1 except for lat/lon or x/y
+    zarray["chunks"] = [
+        1 if c not in lat_lon_dims else lat_lon_dims[c] for c in full_coords
+    ]
+    zarray["shape"] = [full_dims[k] for k in full_coords]
+    if zarray["fill_value"] is None:
+        # TODO: check the dtype first?
+ zarray["fill_value"] = np.nan + + zattrs["_ARRAY_DIMENSIONS"] = [ + COORD_DIM_MAPPING[v] if v in COORD_DIM_MAPPING else v for v in full_coords + ] + + zstore[f"{key}/.zarray"] = ujson.dumps(zarray) + zstore[f"{key}/.zattrs"] = ujson.dumps(zattrs) + + idata = data.set_index(["time", "step", "level"]).sort_index() + + for idx in itertools.product(*[range(dims[k]) for k in dcoords]): + # Build an iterator over each of the single dimension chunks + # TODO Replace this with a reindex operation and iterate the result + # if the .loc call is slow inside the loop + dim_idx = {k: v for k, v in zip(dcoords, idx)} + + iloc: tuple[Any, ...] = ( + times[tuple([dim_idx[k] for k in coords["time"]])], + steps[tuple([dim_idx[k] for k in coords["step"]])], + ) + if lvals is not None: + iloc = iloc + (lvals[idx[-1]],) # type:ignore[assignment] + + try: + # Squeeze if needed to get a series. Noop if already a series Df has multiple rows + dval = idata.loc[iloc].squeeze() + except KeyError: + logger.info(f"Error getting vals {iloc} for in path {key}") + continue + + assert isinstance( + dval, pd.Series + ), f"Got multiple values for iloc {iloc} in key {key}: {dval}" + + if pd.isna(dval.inline_value): + # List of [URI(Str), offset(Int), length(Int)] using python (not numpy) types. + record = [dval.uri, dval.offset.item(), dval.length.item()] + else: + record = dval.inline_value + # lat/lon y/x have only the zero chunk + vkey = ".".join([str(v) for v in (idx + (0, 0))]) + zstore[f"{key}/{vkey}"] = record + + +def strip_datavar_chunks( + kerchunk_store: dict, keep_list: tuple[str, ...] = ("latitude", "longitude") +) -> None: + """ + Modify in place a kerchunk reference store to strip the kerchunk references + for variables not in the keep list. + + :param kerchunk_store: a kerchunk ref spec store + :param keep_list: the list of variables to keep references + """ + zarr_store = kerchunk_store["refs"] + + zchunk_matcher = re.compile(r"^(?P.*)\/(?P\d+[\.\d+]*)$") + for key in list(zarr_store.keys()): + matched = zchunk_matcher.match(key) + if matched: + logger.debug("Matched! %s", matched) + if any([matched.group("name").endswith(keeper) for keeper in keep_list]): + logger.debug("Skipping key %s", matched.group("name")) + continue + del zarr_store[key] + + def build_path(path: Iterable[str | None], suffix: Optional[str] = None) -> str: """ Returns the path to access the values in a zarr store without a leading "/" @@ -975,3 +1135,257 @@ def build_idx_grib_mapping( ) return result + + +def map_from_index( + run_time: "pd.Timestamp", + mapping: "pd.DataFrame", + idxdf: "pd.DataFrame", + raw_merged: bool = False, +) -> "pd.DataFrame": + """ + Main method used for building index dataframes from parsed IDX files + merged with the correct mapping for the horizon + + Parameters + ---------- + + run_time : pd.Timestamp + the run time timestamp of the idx data + mapping : pd.DataFrame + the mapping data derived from comparing the idx attributes to the + CFGrib attributes for a given horizon + idxdf : pd.DataFrame + the dataframe of offsets and lengths for each grib message and its + attributes derived from an idx file + raw_merged : bool + Used for debugging to see all the columns in the merge. 
+        By default, it returns the index columns with the corrected time
+        values plus the index metadata.
+
+    Returns
+    -------
+    pd.DataFrame : the index dataframe that will be used to read variable data
+        from the grib file
+    """
+
+    idxdf = idxdf.reset_index().set_index("attrs")
+    mapping = mapping.reset_index().set_index("attrs")
+    mapping.drop(columns="uri", inplace=True)  # Drop the URI column from the mapping
+
+    if not idxdf.index.is_unique:
+        raise ValueError("Parsed idx data must have unique attrs to merge on!")
+
+    if not mapping.index.is_unique:
+        raise ValueError("Mapping data must have unique attrs to merge on!")
+
+    # Merge the offset and length from the idx file with the varname, step and level from the mapping
+    result = idxdf.merge(mapping, on="attrs", how="left", suffixes=("", "_mapping"))
+
+    if raw_merged:
+        return result
+    else:
+        # Take the grib_uri column from the idxdf and ignore the uri column from the mapping:
+        # we want the offset, length and uri of the index file with the varname, step and level of the mapping
+        selected_results = result.rename(columns=dict(grib_uri="uri"))[
+            [
+                "varname",
+                "typeOfLevel",
+                "stepType",
+                "name",
+                "step",
+                "level",
+                "time",
+                "valid_time",
+                "uri",
+                "offset",
+                "length",
+                "inline_value",
+            ]
+        ]
+    # Drop the inline values from the mapping data
+    selected_results.loc[:, "inline_value"] = None
+    selected_results.loc[:, "time"] = run_time
+    selected_results.loc[:, "valid_time"] = (
+        selected_results.time + selected_results.step
+    )
+    logger.info("Dropping %d nan varnames", selected_results.varname.isna().sum())
+    selected_results = selected_results.loc[~selected_results.varname.isna(), :]
+    return selected_results.reset_index(drop=True)
+
+
+def reinflate_grib_store(
+    axes: list["pd.Index"],
+    aggregation_type: AggregationType,
+    chunk_index: "pd.DataFrame",
+    zarr_ref_store: dict,
+) -> dict:
+    """
+    Given a zarr_store hierarchy, pull out the variables present in the
+    chunks dataframe and reinflate the zarr variables, adding any needed
+    dimensions. This is a select operation based on the time axes provided.
+    Assumes everything is stored in hours, per grib convention.
+ # TODO finish & validate valid_time, run_time & best_available aggregation modes + + :param axes: a list of new axes for aggregation + :param aggregation_type: the type of fmrc aggregation + :param chunk_index: a dataframe containing the kerchunk index + :param zarr_ref_store: the deflated (chunks removed) zarr store + :return: the inflated zarr store + """ + # Make a deep copy so we don't modify the input + zstore = copy.deepcopy(zarr_ref_store["refs"]) + + axes_by_name: dict[str, pd.Index] = {pdi.name: pdi for pdi in axes} + # Validate axis names + time_dims: dict[str, int] = {} + time_coords: dict[str, tuple[str, ...]] = {} + # TODO: add a data class or other method of typing and validating the variables created in this if block + if aggregation_type == AggregationType.HORIZON: + # Use index length horizons containing timedelta ranges for the set of steps + time_dims["step"] = len(axes_by_name["step"]) + time_dims["valid_time"] = len(axes_by_name["valid_time"]) + + time_coords["step"] = ("step", "valid_time") + time_coords["valid_time"] = ("step", "valid_time") + time_coords["time"] = ("step", "valid_time") + time_coords["datavar"] = ("step", "valid_time") + + steps = create_steps(axes_by_name["step"], time_dims["valid_time"]) + valid_times = np.tile( + axes_by_name["valid_time"].to_numpy(), (time_dims["step"], 1) + ) + times = valid_times - steps + + elif aggregation_type == AggregationType.VALID_TIME: + # Provide an index of steps and an index of valid times + time_dims["step"] = len(axes_by_name["step"]) + time_dims["valid_time"] = len(axes_by_name["valid_time"]) + + time_coords["step"] = ("step",) + time_coords["valid_time"] = ("valid_time",) + time_coords["time"] = ("valid_time", "step") + time_coords["datavar"] = ("valid_time", "step") + + steps = axes_by_name["step"].to_numpy() + valid_times = axes_by_name["valid_time"].to_numpy() + + steps2d = np.tile(axes_by_name["step"], (time_dims["valid_time"], 1)) + valid_times2d = np.tile( + np.reshape(axes_by_name["valid_time"], (-1, 1)), (1, time_dims["step"]) + ) + times = valid_times2d - steps2d + + elif aggregation_type == AggregationType.RUN_TIME: + # Provide an index of steps and an index of run times. 
+ time_dims["step"] = len(axes_by_name["step"]) + time_dims["time"] = len(axes_by_name["time"]) + + time_coords["step"] = ("step",) + time_coords["valid_time"] = ("time", "step") + time_coords["time"] = ("time",) + time_coords["datavar"] = ("time", "step") + + steps = axes_by_name["step"].to_numpy() + times = axes_by_name["time"].to_numpy() + + # The valid times will be runtimes by steps + steps2d = np.tile(axes_by_name["step"], (time_dims["time"], 1)) + times2d = np.tile( + np.reshape(axes_by_name["time"], (-1, 1)), (1, time_dims["step"]) + ) + valid_times = times2d + steps2d + + elif aggregation_type == AggregationType.BEST_AVAILABLE: + time_dims["valid_time"] = len(axes_by_name["valid_time"]) + assert ( + len(axes_by_name["time"]) == 1 + ), "The time axes must describe a single 'as of' date for best available" + reference_time = axes_by_name["time"].to_numpy()[0] + + time_coords["step"] = ("valid_time",) + time_coords["valid_time"] = ("valid_time",) + time_coords["time"] = ("valid_time",) + time_coords["datavar"] = ("valid_time",) + + valid_times = axes_by_name["valid_time"].to_numpy() + times = np.where(valid_times <= reference_time, valid_times, reference_time) + steps = valid_times - times + else: + raise RuntimeError(f"Invalid aggregation_type argument: {aggregation_type}") + + # Copy all the groups that contain variables in the chunk dataset + unique_groups = chunk_index.set_index( + ["varname", "stepType", "typeOfLevel"] + ).index.unique() + + # Drop keys not in the unique groups + for key in list(zstore.keys()): + # Separate the key as a path keeping only: varname, stepType and typeOfLevel + # Treat root keys like ".zgroup" as special and return an empty tuple + lookup = tuple( + [val for val in os.path.dirname(key).split("/")[:3] if val != ""] + ) + if lookup not in unique_groups: + del zstore[key] + + # Now update the zstore for each variable. + for key, group in chunk_index.groupby(["varname", "stepType", "typeOfLevel"]): + base_path = "/".join(key) + lvals = group.level.unique() + dims = time_dims.copy() + coords = time_coords.copy() + if len(lvals) == 1: + lvals = lvals.squeeze() + dims[key[2]] = 0 + elif len(lvals) > 1: + lvals = np.sort(lvals) + # multipel levels + dims[key[2]] = len(lvals) + coords["datavar"] += (key[2],) + else: + raise ValueError("") + + # Convert to floating point seconds + # td.astype("timedelta64[s]").astype(float) / 3600 # Convert to floating point hours + store_coord_var( + key=f"{base_path}/time", + zstore=zstore, + coords=time_coords["time"], + data=times.astype("datetime64[s]"), + ) + + store_coord_var( + key=f"{base_path}/valid_time", + zstore=zstore, + coords=time_coords["valid_time"], + data=valid_times.astype("datetime64[s]"), + ) + + store_coord_var( + key=f"{base_path}/step", + zstore=zstore, + coords=time_coords["step"], + data=steps.astype("timedelta64[s]").astype("float64") / 3600.0, + ) + + store_coord_var( + key=f"{base_path}/{key[2]}", + zstore=zstore, + coords=(key[2],) if lvals.shape else (), + data=lvals, # all grib levels are floats + ) + + store_data_var( + key=f"{base_path}/{key[0]}", + zstore=zstore, + dims=dims, + coords=coords, + data=group, + steps=steps, + times=times, + lvals=lvals if lvals.shape else None, + ) + + return dict(refs=zstore, version=1)