Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- consider full time information for derived calculation of TOA radiation [\#84](https://github.com/mllam/mllam-data-prep/pull/84) @observingClouds

### Fixes
- fix bug where `calc_stats` mutates dataset in loop causing incorrect `diff_std` computation [\#102](https://github.com/mllam/mllam-data-prep/issues/102) @RajdeepKushwaha5
- fix bug where coordinate selection of an unshared dimension isn't applied to subsequent ouput variables when an output variable without this dimension is processed before the others [\#90](https://github.com/mllam/mllam-data-prep/pull/90) @zweihuehner & @leifdenby

## [v0.7.0](https://github.com/mllam/mllam-data-prep/release/tag/v0.7.0)
Expand Down
5 changes: 3 additions & 2 deletions mllam_data_prep/ops/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def calc_stats(
"""
stats = {}
for op_split in statistics_config.ops:
ds_op = ds
try:
pre_op, op = op_split.split("_")
except ValueError:
Expand All @@ -42,11 +43,11 @@ def calc_stats(
if pre_op == "diff":
# subset to select only the variable which have the splitting_dim
vars_to_keep = [v for v in ds.data_vars if splitting_dim in ds[v].dims]
ds = ds[vars_to_keep].diff(dim=splitting_dim)
ds_op = ds[vars_to_keep].diff(dim=splitting_dim)
else:
raise NotImplementedError(pre_op)

fn = getattr(ds, op)
fn = getattr(ds_op, op)
stats[op_split] = fn(dim=statistics_config.dims)

return stats
57 changes: 57 additions & 0 deletions tests/test_calc_stats_fix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""
Test that calc_stats does not mutate the input dataset across loop iterations.

Regression test for https://github.com/mllam/mllam-data-prep/issues/102
"""

import numpy as np
import xarray as xr

from mllam_data_prep.config import Statistics
from mllam_data_prep.ops.statistics import calc_stats


def test_calc_stats_no_mutation_with_multiple_diff_ops():
"""
When ops=[mean, std, diff_mean, diff_std], diff_std should compute
std(diff(ds)) not std(diff(diff(ds))). The original dataset must not
be mutated between iterations.
"""
data = np.array([1.0, 3.0, 6.0, 10.0, 15.0])
ds = xr.Dataset({"temp": xr.DataArray(data, dims=["time"])})

stats_config = Statistics(
ops=["mean", "std", "diff_mean", "diff_std"], dims=["time"]
)
result = calc_stats(ds, stats_config, splitting_dim="time")

# expected from the original dataset
expected_mean = float(ds["temp"].mean())
expected_std = float(ds["temp"].std())

# expected from a single diff of the original dataset
ds_diff = ds.diff(dim="time") # [2, 3, 4, 5]
expected_diff_mean = float(ds_diff["temp"].mean())
expected_diff_std = float(ds_diff["temp"].std())

assert np.isclose(float(result["mean"]["temp"]), expected_mean)
assert np.isclose(float(result["std"]["temp"]), expected_std)
assert np.isclose(float(result["diff_mean"]["temp"]), expected_diff_mean)
assert np.isclose(float(result["diff_std"]["temp"]), expected_diff_std)


def test_calc_stats_preserves_original_dataset():
"""
The original dataset passed to calc_stats must remain unchanged after
the function returns.
"""
data = np.array([1.0, 3.0, 6.0, 10.0, 15.0])
ds = xr.Dataset({"temp": xr.DataArray(data, dims=["time"])})
ds_copy = ds.copy(deep=True)

stats_config = Statistics(
ops=["mean", "std", "diff_mean", "diff_std"], dims=["time"]
)
calc_stats(ds, stats_config, splitting_dim="time")

xr.testing.assert_identical(ds, ds_copy)
Loading