Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixes

- fix bug which adds unwanted dimensions to the dataset [\#60](https://github.com/mllam/mllam-data-prep/pull/60), @ealerskans, @observingClouds
- correct chunk size estimate [\#59](https://github.com/mllam/mllam-data-prep/pull/59), @ealerskans
- fix bug arising when variables provided to derived functions are renamed [\#56](https://github.com/mllam/mllam-data-prep/pull/56), @leifdenby
- ensure config fields defaulting to `None` are typed as `Optional` and fields defaulting to `{}` are given a default-factory so that serialization with default values works correctly [\#63](https://github.com/mllam/mllam-data-prep/pull/63), @leifdenby
Expand Down
5 changes: 2 additions & 3 deletions mllam_data_prep/create_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,9 @@ def create_dataset(config: Config):
if input_config.coord_ranges is not None:
ds_input = selection.select_by_kwargs(ds_input, **input_config.coord_ranges)

# Initialize the output dataset and add dimensions
# Initialize the output dataset
ds = xr.Dataset()
ds.attrs.update(ds_input.attrs)
for dim in ds_input.dims:
ds = ds.assign_coords({dim: ds_input.coords[dim]})

if selected_variables:
logger.info(f"Extracting selected variables from dataset {dataset_name}")
Expand All @@ -190,6 +188,7 @@ def create_dataset(config: Config):
ds=ds_input,
derived_variable=derived_variable,
chunking=chunking_config,
target_dims=expected_input_var_dims,
)

_check_dataset_attributes(
Expand Down
7 changes: 4 additions & 3 deletions mllam_data_prep/ops/derive_variable/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
REQUIRED_FIELD_ATTRIBUTES = ["units", "long_name"]


def derive_variable(ds, derived_variable, chunking):
def derive_variable(ds, derived_variable, chunking, target_dims):
"""
Derive a variable using the `function` and `kwargs` of `derived_variable`.

Expand All @@ -33,15 +33,16 @@ def derive_variable(ds, derived_variable, chunking):
chunking: Dict[str, int]
Dictionary with keys as the dimensions to chunk along and values
with the chunk size
target_dims: List[str]
List of dims from ds to broadcast derived variable to,
if not used in calculation

Returns
-------
xr.Dataset
Dataset with derived variables included
"""

target_dims = list(ds.sizes.keys())

function_namespace = derived_variable.function
expected_field_attributes = derived_variable.attrs

Expand Down
280 changes: 280 additions & 0 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
"""Tests for the output dataset created by `mllam-data-prep`."""
import pytest
import yaml

import mllam_data_prep as mdp

with open("example.danra.yaml", "r") as file:
BASE_CONFIG = file.read()

HEIGHT_LEVEL_TEST_SECTION = """\
inputs:
danra_height_levels:
path: https://object-store.os-api.cci1.ecmwf.int/mllam-testdata/danra_cropped/v0.2.0/height_levels.zarr
dims: [time, x, y, altitude]
variables:
u:
altitude:
values: [100, 50,]
units: m
v:
altitude:
values: [100, 50, ]
units: m
dim_mapping:
time:
method: rename
dim: time
state_feature:
method: stack_variables_by_var_name
dims: [altitude]
name_format: "{var_name}{altitude}m"
grid_index:
method: stack
dims: [x, y]
target_output_variable: state
"""

PRESSURE_LEVEL_TEST_SECTION = """\
inputs:
danra_pressure_levels:
path: https://object-store.os-api.cci1.ecmwf.int/mllam-testdata/danra_cropped/v0.2.0/pressure_levels.zarr
dims: [time, x, y, pressure]
variables:
u:
pressure:
values: [1000,]
units: hPa
v:
pressure:
values: [1000, ]
units: hPa
dim_mapping:
time:
method: rename
dim: time
state_feature:
method: stack_variables_by_var_name
dims: [pressure]
name_format: "{var_name}{pressure}m"
grid_index:
method: stack
dims: [x, y]
target_output_variable: state
"""

SINGLE_LEVEL_SELECTED_VARIABLES_TEST_SECTION = """\
inputs:
danra_single_levels:
path: https://object-store.os-api.cci1.ecmwf.int/mllam-testdata/danra_cropped/v0.2.0/single_levels.zarr
dims: [time, x, y]
variables:
- t2m
- pres_seasurface
dim_mapping:
time:
method: rename
dim: time
state_feature:
method: stack_variables_by_var_name
name_format: "{var_name}"
grid_index:
method: stack
dims: [x, y]
target_output_variable: state
"""

SINGLE_LEVEL_DERIVED_VARIABLES_TEST_SECTION = """\
inputs:
danra_single_levels:
path: https://object-store.os-api.cci1.ecmwf.int/mllam-testdata/danra_cropped/v0.2.0/single_levels.zarr
dims: [time, x, y]
derived_variables:
# derive variables to be used as forcings
toa_radiation:
kwargs:
time: ds_input.time
lat: ds_input.lat
lon: ds_input.lon
function: mllam_data_prep.ops.derive_variable.physical_field.calculate_toa_radiation
dim_mapping:
time:
method: rename
dim: time
state_feature:
method: stack_variables_by_var_name
name_format: "{var_name}"
grid_index:
method: stack
dims: [x, y]
target_output_variable: state
"""

INVALID_PRESSURE_LEVEL_TEST_SECTION = """\
inputs:
danra_pressure_levels:
path: https://object-store.os-api.cci1.ecmwf.int/mllam-testdata/danra_cropped/v0.2.0/pressure_levels.zarr
dims: [time, x, y, pressure]
variables:
z:
pressure:
values: [1000,]
units: hPa
t:
pressure:
values: [800, ]
units: hPa
dim_mapping:
time:
method: rename
dim: time
state_feature:
method: stack_variables_by_var_name
dims: [pressure]
name_format: "{var_name}{pressure}m"
grid_index:
method: stack
dims: [x, y]
target_output_variable: state
"""


def update_config(config: str, update: str):
"""
Update provided config.

Parameters
----------
config: str
String with config in yaml format
update: str
String with the update in yaml format

Returns
-------
config: Config
Updated config
"""
original_config = mdp.Config.from_yaml(config)
update = yaml.safe_load(update)
modified_config = original_config.to_dict()
modified_config.update(update)
modified_config = mdp.Config.from_dict(modified_config)

return modified_config


@pytest.mark.parametrize(
"base_config, new_inputs_section",
[
(BASE_CONFIG, "{}"), # Does not modify the example config
(BASE_CONFIG, PRESSURE_LEVEL_TEST_SECTION),
(BASE_CONFIG, HEIGHT_LEVEL_TEST_SECTION),
(BASE_CONFIG, SINGLE_LEVEL_SELECTED_VARIABLES_TEST_SECTION),
(BASE_CONFIG, SINGLE_LEVEL_DERIVED_VARIABLES_TEST_SECTION),
],
)
def test_selected_output_variables(base_config, new_inputs_section):
"""
Test that the variables specified in each input dataset are
present in the output dataset.
"""
# Modify the example config
config = update_config(base_config, new_inputs_section)

# Create the dataset
ds = mdp.create_dataset(config=config)

# Check that the output variables are the ones selected
for _, input_config in config.inputs.items():
target_output_variable = input_config.target_output_variable

# Get the expected selected variable names
selected_variables = input_config.variables or []
if isinstance(selected_variables, dict):
selected_var_names = list(selected_variables.keys())
elif isinstance(selected_variables, list):
selected_var_names = selected_variables
else:
pytest.fail(
"Expected either 'list' or 'dict' but got"
f" type {type(selected_variables)} for 'variables'."
)

# Get the expected derived variable names
derived_variables = input_config.derived_variables or []
if isinstance(derived_variables, dict):
derived_var_names = list(derived_variables.keys())
elif isinstance(derived_variables, list):
derived_var_names = derived_variables
else:
pytest.fail(
"Expected either 'list' or 'dict' but got"
f" type {type(derived_variables)} for 'derived_variables'."
)

dim_mapping = input_config.dim_mapping[target_output_variable + "_feature"]
dims = dim_mapping.dims or []
name_format = dim_mapping.name_format

if len(dims) == 0:
selected_vars = selected_var_names
derived_vars = derived_var_names
elif len(dims) == 1:
coord = dims[0]
# Stack the variable names by coordinates, as is done in
# mdp.ops.stacking.stack_variables_by_coord_values
selected_vars = []
for var_name in selected_var_names:
coord_values = selected_variables[var_name][coord].values
formatted_var_names = [
name_format.format(var_name=var_name, **{coord: val})
for val in coord_values
]
selected_vars += formatted_var_names
# We currently do not support stacking of variables by coordinates
# for the derived variables
derived_vars = []

expected_variables = selected_vars + derived_vars
output_variables = ds[target_output_variable + "_feature"].values

if set(expected_variables) != set(output_variables):
# Check if there are missing or extra variable
missing_vars = list(set(expected_variables) - set(output_variables))
extra_vars = list(set(output_variables) - set(expected_variables))

error_message = (
f"Expected {expected_variables}, but got {output_variables}."
)
if missing_vars:
error_message += f"\nMissing variables: {missing_vars}"
if extra_vars:
error_message += f"\nExtra variables: {extra_vars}"

pytest.fail(error_message)


@pytest.mark.parametrize(
"base_config, update, expected_result",
[
(
BASE_CONFIG,
"{}",
False,
), # Do not modify the example config - should return False since we're expecting no nans
(
BASE_CONFIG,
INVALID_PRESSURE_LEVEL_TEST_SECTION,
True,
), # Dataset with nans - should return True
],
)
def test_output_dataset_for_nans(base_config, update, expected_result):
"""
Test that the output dataset does not contain any nan values.
"""
config = update_config(base_config, update)
ds = mdp.create_dataset(config=config)
nan_in_ds = any(ds.isnull().any().compute().to_array())
assert nan_in_ds == expected_result