diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/dask/join.py b/packages/pipeline/src/pyearthtools/pipeline/operations/dask/join.py index 0045efde..dae15c0c 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/dask/join.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/dask/join.py @@ -15,6 +15,7 @@ # type: ignore[reportPrivateImportUsage] +from itertools import accumulate from typing import Optional, Any import dask.array as da @@ -26,14 +27,12 @@ class Stack(Joiner, DaskOperation): """ Stack a tuple of da.Array's - - Currently cannot undo this operation """ _override_interface = ["Serial"] _numpy_counterpart = "join.Stack" - def __init__(self, axis: Optional[int] = None): + def __init__(self, axis: Optional[int] = 0): super().__init__() self.record_initialisation() self.axis = axis @@ -43,14 +42,14 @@ def join(self, sample: tuple[Any, ...]) -> da.Array: return da.stack(sample, self.axis) # type: ignore def unjoin(self, sample: Any) -> tuple: - return super().unjoin(sample) + """Unstacks a stacked sample""" + # move the stacked axis to the zeroth axis and convert to tuple + return tuple(da.moveaxis(sample, self.axis, 0)) class VStack(Joiner, DaskOperation): """ Vertically Stack a tuple of da.Array's - - Currently cannot undo this operation """ _override_interface = ["Serial"] @@ -59,22 +58,24 @@ class VStack(Joiner, DaskOperation): def __init__(self): super().__init__() self.record_initialisation() + self.offsets = None def join(self, sample: tuple[Any, ...]) -> da.Array: """Join sample""" + self.offsets = tuple(accumulate(arr.shape[0] for arr in sample[:-1])) return da.vstack( sample, ) # type: ignore def unjoin(self, sample: Any) -> tuple: - return super().unjoin(sample) + start = (0,) + self.offsets + ends = self.offsets + (sample.shape[0],) + return tuple(sample[start:end] for start, end in zip(start, ends, strict=True)) class HStack(Joiner, DaskOperation): """ Horizontally Stack a tuple of da.Array's - - 
Currently cannot undo this operation """ _override_interface = ["Serial"] @@ -83,22 +84,24 @@ class HStack(Joiner): def __init__(self): super().__init__() self.record_initialisation() + self.offsets = None def join(self, sample: tuple[Any, ...]) -> da.Array: """Join sample""" + self.offsets = tuple(accumulate(arr.shape[1] for arr in sample[:-1])) return da.hstack( sample, ) # type: ignore def unjoin(self, sample: Any) -> tuple: - return super().unjoin(sample) + starts = (0,) + self.offsets + ends = self.offsets + (sample.shape[1],) + return tuple(sample[:, start:end, ...] for start, end in zip(starts, ends, strict=True)) class Concatenate(Joiner): """ Concatenate a tuple of da.Array's - - Currently cannot undo this operation """ _override_interface = ["Serial"] @@ -108,10 +111,14 @@ def __init__(self, axis: Optional[int] = None): super().__init__() self.record_initialisation() self.axis = axis + self.offsets = None def join(self, sample: tuple[Any, ...]) -> da.Array: """Join sample""" + self.offsets = tuple(accumulate((arr.size if self.axis is None else arr.shape[self.axis]) for arr in sample[:-1])) return da.concatenate(sample, self.axis) # type: ignore def unjoin(self, sample: Any) -> tuple: - return super().unjoin(sample) + starts = (0,) + self.offsets + ends = self.offsets + ((sample.size if self.axis is None else sample.shape[self.axis]),) + return tuple(da.take(sample, slice(start, end), 0 if self.axis is None else self.axis) for start, end in zip(starts, ends, strict=True)) diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/join.py b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/join.py index 72e519a8..13cf670b 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/join.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/numpy/join.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- +from itertools import accumulate from typing import Optional, Any import numpy as np @@ -23,8 +23,6 @@ class Stack(Joiner): """ Stack a tuple of np.ndarray's - - Currently cannot undo this operation """ _override_interface = ["Delayed", "Serial"] @@ -40,14 +38,16 @@ def join(self, sample: tuple[Any, ...]) -> np.ndarray: return np.stack(sample, self.axis) # type: ignore def unjoin(self, sample: Any) -> tuple: - return super().unjoin(sample) + """Unstacks a stacked sample""" + # np.stack rejects axis=None, so treat None as the default axis 0 + axis = self.axis if self.axis is not None else 0 + # move the stacked axis to the zeroth axis and convert to tuple + return tuple(np.moveaxis(sample, axis, 0)) class VStack(Joiner): """ Vertically Stack a tuple of np.ndarray's - - Currently cannot undo this operation """ _override_interface = ["Delayed", "Serial"] @@ -56,22 +56,23 @@ class VStack(Joiner): def __init__(self): super().__init__() self.record_initialisation() + self.offsets = None # row offsets at which each subsequent joined array begins; set by join() def join(self, sample: tuple[Any, ...]) -> np.ndarray: """Join sample""" + # record cumulative row offsets so unjoin can split at the original boundaries + self.offsets = tuple(accumulate(arr.shape[0] for arr in sample[:-1])) return np.vstack( sample, ) # type: ignore def unjoin(self, sample: Any) -> tuple: - return super().unjoin(sample) + return tuple(np.vsplit(sample, self.offsets)) class HStack(Joiner): """ Horizontally Stack a tuple of np.ndarray's - - Currently cannot undo this operation """ _override_interface = ["Delayed", "Serial"] @@ -80,22 +81,22 @@ class HStack(Joiner): def __init__(self): super().__init__() self.record_initialisation() + self.offsets = None def join(self, sample: tuple[Any, ...]) -> np.ndarray: """Join sample""" + self.offsets = tuple(accumulate(arr.shape[1] for arr in sample[:-1])) return np.hstack( sample, ) # type: ignore def unjoin(self, sample: Any) -> tuple: - return super().unjoin(sample) + return tuple(np.hsplit(sample, self.offsets)) class Concatenate(Joiner): 
""" Concatenate a tuple of np.ndarray's - - Currently cannot undo this operation """ _override_interface = ["Delayed", "Serial"] @@ -105,10 +106,12 @@ def __init__(self, axis: Optional[int] = None): super().__init__() self.record_initialisation() self.axis = axis + self.offsets = None def join(self, sample: tuple[Any, ...]) -> np.ndarray: """Join sample""" + self.offsets = tuple(accumulate((arr.size if self.axis is None else arr.shape[self.axis]) for arr in sample[:-1])) return np.concatenate(sample, self.axis) # type: ignore def unjoin(self, sample: Any) -> tuple: - return super().unjoin(sample) + return tuple(np.split(sample, self.offsets, axis=0 if self.axis is None else self.axis)) diff --git a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/join.py b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/join.py index c9c17623..7870ece7 100644 --- a/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/join.py +++ b/packages/pipeline/src/pyearthtools/pipeline/operations/xarray/join.py @@ -27,7 +27,7 @@ class Merge(Joiner): """ Merge a tuple of xarray object's. - Currently cannot undo this operation + Currently can only undo this operation with xr.Dataset and xr.DataArray inputs. 
""" _override_interface = "Serial" @@ -36,13 +36,26 @@ def __init__(self, merge_kwargs: Optional[dict[str, Any]] = None): super().__init__() self.record_initialisation() self._merge_kwargs = merge_kwargs + self._input_structure: list[tuple[Union[str, list[str]], dict]] = [] def join(self, sample: tuple[Union[xr.Dataset, xr.DataArray], ...]) -> xr.Dataset: """Join sample""" + self._input_structure = [ + (item.name, item.attrs) if isinstance(item, xr.DataArray) else (list(item.data_vars), item.attrs) + for item in sample + ] return xr.merge(sample, **(self._merge_kwargs or {})) - def unjoin(self, sample: Any) -> tuple: - return super().unjoin(sample) + def unjoin(self, sample: xr.Dataset) -> tuple: + result = [] + for keys, attrs in self._input_structure: + if isinstance(keys, str): + da = sample[keys] + da.attrs = attrs + result.append(da) + else: + result.append(xr.Dataset({k: sample[k] for k in keys}, attrs=attrs)) + return tuple(result) class LatLonInterpolate(Joiner): @@ -54,6 +67,8 @@ class LatLonInterpolate(Joiner): It assumed the dimensions 'latitude', 'longitude', 'time', and 'level' will be present. 'lat' or 'lon' may also be used for convenience. + + Currently cannot undo this operation. Raises NotImplementedError if undo is attempted. 
""" _override_interface = "Serial" @@ -68,9 +83,14 @@ def __init__( ): super().__init__() - self.raise_if_dimensions_wrong(reference_dataset) - self.record_initialisation() + self.record_initialisation() + if reference_dataset is None and reference_index is None: + raise ValueError("No reference dataset or reference index set") + elif reference_dataset is not None and reference_index is not None: + raise ValueError("Only one of reference_dataset or reference_index should be set") + elif reference_dataset is not None: + self.raise_if_dimensions_wrong(reference_dataset) self.reference_dataset = reference_dataset self.reference_index = reference_index self.interpolation_method = interpolation_method @@ -154,6 +174,8 @@ class GeospatialTimeSeriesMerge(Joiner): This joiner is more strict about the merging and interpolating, and also raises more informative error messages when it runs into trouble. + + Currently cannot undo this operation. Raises NotImplementedError if undo is attempted. """ _override_interface = "Serial" @@ -220,7 +242,7 @@ class InterpLike(Joiner): """ Merge a tuple of xarray object's. - Currently cannot undo this operation + Currently cannot undo this operation. Raises NotImplementedError if undo is attempted. """ _override_interface = "Serial" @@ -262,7 +284,7 @@ class Concatenate(Joiner): """ Concatenate a tuple of xarray object's - Currently cannot undo this operation + Currently cannot undo this operation. Unjoining a sample returns the same sample. """ _override_interface = "Serial" diff --git a/packages/pipeline/tests/operations/dask/test_dask_join.py b/packages/pipeline/tests/operations/dask/test_dask_join.py new file mode 100644 index 00000000..79bcd7e3 --- /dev/null +++ b/packages/pipeline/tests/operations/dask/test_dask_join.py @@ -0,0 +1,68 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2025. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial + +from pyearthtools.pipeline.operations.dask.join import Stack, VStack, HStack, Concatenate + +import dask.array as da +import numpy as np +import pytest + + +def _arrays(*shapes): + """Create dask arrays with given shapes whose elements are sequential integers.""" + offset = 0 + result = [] + for shape in shapes: + size = int(np.prod(shape)) + result.append(da.array(range(offset, offset + size)).reshape(shape)) + offset += size # offset ensures next array has different contents + return tuple(result) + + +# this parameterizations passes in the joiner class to test, with an appropriate axis as needed. +# It compares the joined result to an equivalent dask function, partially initialised with axis as needed. +# The shape of the input array passed to the test is adjusted based on the joiner. 
+@pytest.mark.parametrize( + ("joiner", "equiv_op", "input_arrays"), + [ + pytest.param(Stack(axis=0), partial(da.stack, axis=0), _arrays((2, 3), (2, 3)), id="Stack-axis0"), + pytest.param(Stack(axis=1), partial(da.stack, axis=1), _arrays((2, 3), (2, 3)), id="Stack-axis1"), + pytest.param(Stack(axis=2), partial(da.stack, axis=2), _arrays((2, 3), (2, 3)), id="Stack-axis2"), + pytest.param(VStack(), da.vstack, _arrays((1, 3, 2), (2, 3, 2)), id="VStack"), + pytest.param(HStack(), da.hstack, _arrays((3, 1, 2), (3, 2, 2)), id="HStack"), + pytest.param( + Concatenate(axis=0), partial(da.concatenate, axis=0), _arrays((1, 3, 2), (2, 3, 2)), id="Concatenate-axis0" + ), + pytest.param( + Concatenate(axis=1), partial(da.concatenate, axis=1), _arrays((3, 1, 2), (3, 2, 2)), id="Concatenate-axis1" + ), + pytest.param( + Concatenate(axis=2), partial(da.concatenate, axis=2), _arrays((3, 2, 1), (3, 2, 2)), id="Concatenate-axis2" + ), + ], +) +def test_join(joiner, equiv_op, input_arrays): + """Tests that joiners reproduce their dask equivalents and are reversible.""" + name = type(joiner).__name__ + + result = joiner.join(input_arrays) + expected = equiv_op(input_arrays) + assert np.array_equal(result.compute(), expected.compute()), f"{name}.join() did not reproduce expected behaviour." + + unjoined = joiner.unjoin(result) + assert isinstance(unjoined, tuple), f"{name}.unjoin() did not return a tuple." + for arr_undo, arr in zip(unjoined, input_arrays, strict=True): + assert np.array_equal(arr_undo.compute(), arr.compute()), f"{name}.unjoin() did not return original arrays." diff --git a/packages/pipeline/tests/operations/numpy/test_numpy_join.py b/packages/pipeline/tests/operations/numpy/test_numpy_join.py new file mode 100644 index 00000000..9ebc684c --- /dev/null +++ b/packages/pipeline/tests/operations/numpy/test_numpy_join.py @@ -0,0 +1,67 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2025. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial + +from pyearthtools.pipeline.operations.numpy.join import Stack, VStack, HStack, Concatenate + +import numpy as np +import pytest + + +def _arrays(*shapes): + """Create numpy arrays with given shapes whose elements are sequential integers.""" + offset = 0 + result = [] + for shape in shapes: + size = int(np.prod(shape)) + result.append(np.arange(offset, offset + size).reshape(shape)) + offset += size + return tuple(result) + + +# this parameterizations passes in the joiner class to test, with an appropriate axis as needed. +# It compares the joined result to an equivalent numpy function, partially initialised with axis as needed. +# The shape of the input array passed to the test is adjusted based on the joiner. 
+@pytest.mark.parametrize( + ("joiner", "equiv_op", "input_arrays"), + [ + pytest.param(Stack(axis=0), partial(np.stack, axis=0), _arrays((2, 3), (2, 3)), id="Stack-axis0"), + pytest.param(Stack(axis=1), partial(np.stack, axis=1), _arrays((2, 3), (2, 3)), id="Stack-axis1"), + pytest.param(Stack(axis=2), partial(np.stack, axis=2), _arrays((2, 3), (2, 3)), id="Stack-axis2"), + pytest.param(VStack(), np.vstack, _arrays((1, 3, 2), (2, 3, 2)), id="VStack"), + pytest.param(HStack(), np.hstack, _arrays((3, 1, 2), (3, 2, 2)), id="HStack"), + pytest.param( + Concatenate(axis=0), partial(np.concatenate, axis=0), _arrays((1, 3, 2), (2, 3, 2)), id="Concatenate-axis0" + ), + pytest.param( + Concatenate(axis=1), partial(np.concatenate, axis=1), _arrays((3, 1, 2), (3, 2, 2)), id="Concatenate-axis1" + ), + pytest.param( + Concatenate(axis=2), partial(np.concatenate, axis=2), _arrays((3, 2, 1), (3, 2, 2)), id="Concatenate-axis2" + ), + ], +) +def test_join(joiner, equiv_op, input_arrays): + """Tests that joiners reproduce their numpy equivalents and are reversible.""" + name = type(joiner).__name__ + + result = joiner.join(input_arrays) + expected = equiv_op(input_arrays) + assert np.array_equal(result, expected), f"{name}.join() did not reproduce expected behaviour." + + unjoined = joiner.unjoin(result) + assert isinstance(unjoined, tuple), f"{name}.unjoin() did not return a tuple." + for arr_undo, arr in zip(unjoined, input_arrays, strict=True): + assert np.array_equal(arr_undo, arr), f"{name}.unjoin() did not return original arrays." diff --git a/packages/pipeline/tests/operations/xarray/test_xarray_join.py b/packages/pipeline/tests/operations/xarray/test_xarray_join.py new file mode 100644 index 00000000..9d145828 --- /dev/null +++ b/packages/pipeline/tests/operations/xarray/test_xarray_join.py @@ -0,0 +1,316 @@ +# Copyright Commonwealth of Australia, Bureau of Meteorology 2025. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pyearthtools.pipeline.operations.xarray.join import ( + Merge, + LatLonInterpolate, + GeospatialTimeSeriesMerge, + InterpLike, + Concatenate, +) + +import numpy as np +import xarray as xr + +import pytest + + +def test_merge(): + coords = {"x": [1, 2, 3], "y": [4, 5, 6]} + + da = xr.DataArray( + np.arange(9).reshape(3, 3), + coords=coords, + dims=["x", "y"], + name="alpha", + attrs={"source": "model", "units": "K"}, + ) + + ds = xr.Dataset( + { + "beta": xr.DataArray(np.arange(9, 18).reshape(3, 3), coords=coords, dims=["x", "y"]), + "gamma": xr.DataArray(np.arange(18, 27).reshape(3, 3), coords=coords, dims=["x", "y"]), + }, + attrs={"source": "model", "resolution": "1deg"}, # "source" overlaps with da + ) + + sample = (da, ds) + + joiner = Merge() + + result = joiner.join(sample) + + assert result["alpha"].equals(da), "Merge.join didn't merge objects correctly." + assert result["beta"].equals(ds["beta"]), "Merge.join didn't merge objects correctly." + assert result["gamma"].equals(ds["gamma"]), "Merge.join didn't merge objects correctly." + assert result.attrs == da.attrs, "Merge.join result didn't preserve first object's attributes" + assert result.attrs != ds.attrs, "Merge.join didn't discard second object's attributes" + + unjoined = joiner.unjoin(result) + + assert isinstance(unjoined, tuple), "Merge.unjoin didn't result in a tuple." 
+ for d_undo, d_orig in zip(unjoined, sample, strict=True): + assert isinstance(d_undo, type(d_orig)) + assert d_undo.equals(d_orig), "Merge.unjoin didn't restore objects." + assert d_undo.attrs == d_orig.attrs, "Merge.unjoin didn't preserve attributes." + + # test passing kwargs to xr.merge + # should combine attributes + joiner = Merge(merge_kwargs={"combine_attrs": "no_conflicts"}) + + result = joiner.join(sample) + + assert result["alpha"].equals(da), 'passing combine_attrs="no_conflict" to Merge didn\'t merge object correctly.' + assert result["beta"].equals( + ds["beta"] + ), 'passing combine_attrs="no_conflict" to Merge didn\'t merge object correctly.' + assert result["gamma"].equals( + ds["gamma"] + ), 'passing combine_attrs="no_conflict" to Merge didn\'t merge object correctly.' + assert result.attrs == ( + da.attrs | ds.attrs + ), 'passing combine_attrs="no_conflict" to Merge didn\'t unionise attributes.' + + unjoined = joiner.unjoin(result) + + assert isinstance( + unjoined, tuple + ), 'passing combine_attrs="no_conflict" to Merge didn\'t result in a tuple when unjoining.' + for d_undo, d_orig in zip(unjoined, sample, strict=True): + assert isinstance( + d_undo, type(d_orig) + ), "passing combine_attrs=\"no_conflict\" to Merge didn't preserve object's type when unjoining." + assert d_undo.equals( + d_orig + ), 'passing combine_attrs="no_conflict" to Merge didn\'t restore object when unjoining.' + assert ( + d_undo.attrs == d_orig.attrs + ), 'passing combine_attrs="no_conflict" to Merge didn\'t preserve attributes when unjoining.' 
+ + +def _make_ds(var_name, data, lat, lon, time=None, lat_name="latitude", lon_name="longitude"): + """Create a Dataset with latitude, longitude, and time coords.""" + time = time or [0] + return xr.Dataset( + {var_name: xr.DataArray(data, dims=["time", lat_name, lon_name])}, + coords={"time": time, lat_name: lat, lon_name: lon}, + ) + + +@pytest.fixture +def ds_ref(): + return _make_ds(var_name="var_ref", data=np.arange(9).reshape(1, 3, 3), lat=[0.0, 1.0, 2.0], lon=[0.0, 1.0, 2.0]) + + +@pytest.mark.parametrize( + ("lat_name", "lon_name", "joiner_factory"), + [ + pytest.param( + "latitude", "longitude", lambda ds_ref: LatLonInterpolate(reference_index=0), id="reference_index" + ), + pytest.param("lat", "lon", lambda ds_ref: LatLonInterpolate(reference_dataset=ds_ref), id="reference_dataset"), + ], +) +def test_latlon_interpolate_join(lat_name, lon_name, joiner_factory): + """Tests that LatLonInterpolate merges and interpolates datasets to the reference grid.""" + ds_ref = _make_ds( + var_name="var1", + data=np.arange(9).reshape(1, 3, 3), + lat=[0.0, 1.0, 2.0], + lon=[0.0, 1.0, 2.0], + lat_name=lat_name, + lon_name=lon_name, + ) + ds_coarse = _make_ds( + var_name="var2", + data=np.arange(9, 13).reshape(1, 2, 2), + lat=[-0.25, 2.25], + lon=[-0.25, 2.25], + lat_name=lat_name, + lon_name=lon_name, + ) + + result = joiner_factory(ds_ref).join((ds_ref, ds_coarse)) + + assert "var1" in result.data_vars + assert "var2" in result.data_vars + # astype is needed because interp changes datatype + assert ds_ref["var1"].equals(result["var1"].astype(int)) + assert ds_ref.coords.equals(result["var2"].coords) + + expected_interp = np.array([[9.0, 9.0, 10.0], [9.0, 9.0, 10.0], [11.0, 11.0, 12.0]]) + assert np.array_equal(result["var2"].squeeze("time").values, expected_interp) + + +def test_latlon_interpolate_errors(ds_ref): + """Tests that LatLonInterpolate raises errors for invalid configurations.""" + ds_coarse = _make_ds(var_name="var2", data=np.arange(9, 13).reshape(1, 2, 2), 
lat=[-0.25, 2.25], lon=[-0.25, 2.25]) + + with pytest.raises(ValueError): + LatLonInterpolate() + + with pytest.raises(ValueError): + LatLonInterpolate(reference_dataset=ds_ref, reference_index=0) + + with pytest.raises(ValueError): + LatLonInterpolate(reference_dataset=ds_ref.rename({"latitude": "abc", "longitude": "123"})) + + joiner = LatLonInterpolate(reference_index=0) + joiner.reference_index = None + with pytest.raises(ValueError): + joiner.join((ds_ref, ds_coarse)) + + # unjoin not implemented + with pytest.raises(NotImplementedError): + joiner.unjoin(ds_ref) + + +def test_geospatial_timeseries_merge_join(ds_ref): + """Tests that GeospatialTimeSeriesMerge interpolates and merges datasets.""" + da_coarse = xr.DataArray( + np.arange(9, 13).reshape(1, 2, 2), + dims=["time", "latitude", "longitude"], + coords={"time": [0], "latitude": [-0.25, 2.25], "longitude": [-0.25, 2.25]}, + name="var2", + ) + + joiner = GeospatialTimeSeriesMerge(reference_index=0) + result = joiner.join((ds_ref, da_coarse)) + + assert "var_ref" in result.data_vars + assert "var2" in result.data_vars + assert ( + result["var2"].shape == ds_ref["var_ref"].shape + ), "GeospatialTimeSeriesMerge did not interpolate to reference grid shape." 
+ assert tuple(result.latitude.values) == tuple(ds_ref.latitude.values) + assert tuple(result.longitude.values) == tuple(ds_ref.longitude.values) + + +def test_geospatial_timeseries_merge_errors(ds_ref): + """Tests that GeospatialTimeSeriesMerge raises errors for invalid inputs.""" + ds_no_time = _make_ds( + var_name="var2", data=np.arange(9, 18).reshape(1, 3, 3), lat=ds_ref.latitude.values, lon=ds_ref.longitude.values + ).drop_dims("time") + + # fail when trying to join without setting reference + with pytest.raises(ValueError): + GeospatialTimeSeriesMerge().join((ds_ref, ds_ref)) + + joiner = GeospatialTimeSeriesMerge(reference_dataset=ds_ref) + # fail when trying to join datasets and one doesn't have the time dim + with pytest.raises(ValueError): + joiner.join((ds_no_time, ds_ref)) + with pytest.raises(ValueError): + joiner.join((ds_ref, ds_no_time)) + + # fail when trying to unjoin + with pytest.raises(NotImplementedError): + GeospatialTimeSeriesMerge().unjoin(None) + + +def test_interplike(ds_ref): + + da_coarse = xr.DataArray( + np.arange(9, 13).reshape(2, 2), + dims=["latitude", "longitude"], + coords={"latitude": [-0.25, 2.25], "longitude": [-0.25, 2.25]}, + name="var1", + ) + da_fine = xr.DataArray( + np.arange(13, 29).reshape(4, 4), + dims=["latitude", "longitude"], + coords={"latitude": [0.0, 0.67, 1.33, 2.0], "longitude": [0.0, 0.67, 1.33, 2.0]}, + name="var2", + ) + + # test default interpolation method (nearest) + joiner = InterpLike(reference_dataset=ds_ref) + result = joiner.join([da_coarse, da_fine]) + expected_nearest = { + "var1": np.array([[9.0, 9.0, 10.0], [9.0, 9.0, 10.0], [11.0, 11.0, 12.0]]), + "var2": np.array([[13.0, 14.0, 16.0], [17.0, 18.0, 20.0], [25.0, 26.0, 28.0]]), + } + for ds in ("var1", "var2"): + assert ds in result.data_vars, f"{ds} missing from joined dataset" + assert (1,) + result[ds].shape == ds_ref[ + "var_ref" + ].shape, f"InterpLike didn't interpolate {ds} onto ds_ref's coords" + assert 
np.array_equal(expected_nearest[ds], result[ds].values), f"Interplike didn't interpolate {ds}'s values" + + # test linear interpolation method + joiner = InterpLike(reference_dataset=ds_ref, method="linear") + result = joiner.join([da_coarse, da_fine]) + expected_linear = { + "var1": np.array([[9.3, 9.7, 10.1], [10.1, 10.5, 10.9], [10.9, 11.3, 11.7]]), + "var2": np.array([[13.0, 14.5, 16.0], [19.0, 20.5, 22.0], [25.0, 26.5, 28.0]]), + } + for ds in ("var1", "var2"): + assert np.allclose(expected_linear[ds], result[ds].values) + + # test reference index + joiner = InterpLike(reference_index=0) + result = joiner.join([ds_ref, da_coarse, da_fine]) + assert "var_ref" in result.data_vars, "InterpLike didn't preserve reference dataset" + assert ds_ref["var_ref"].equals(result["var_ref"].astype(int)), "InterpLike didn't reproduce reference" + for ds in ("var1", "var2"): + assert ds in result.data_vars, f"{ds} missing from joined dataset" + assert (1,) + result[ds].shape == ds_ref[ + "var_ref" + ].shape, f"InterpLike didn't interpolate {ds} onto ds_ref's coords" + assert np.array_equal(expected_nearest[ds], result[ds].values), f"Interplike didn't interpolate {ds}'s values" + + +def test_interplike_errors(ds_ref): + joiner = InterpLike() + with pytest.raises(ValueError): + joiner.join([ds_ref]) + + with pytest.raises(NotImplementedError): + joiner.unjoin(ds_ref) + + +def test_concatenate(): + # test with dataarrays + da1 = xr.DataArray(np.arange(6).reshape((2, 3)), coords={"x": range(2), "y": range(3)}) + da2 = xr.DataArray(np.arange(6, 18).reshape((4, 3)), coords={"x": range(4), "y": range(3)}) + joiner = Concatenate(concat_dim="x") + result = joiner.join([da1, da2]) + assert np.array_equal(result.values, np.arange(18).reshape((6, 3))) + + # test with datasets + ds1 = xr.Dataset({"var1": da1}) + ds2 = xr.Dataset({"var2": da2}) + result = joiner.join([ds1, ds2]) + expected = np.vstack((da1.values, np.full((4, 3), np.nan))) + assert np.array_equal(expected, 
result["var1"].values, equal_nan=True) + expected = np.vstack((np.full((2, 3), np.nan), da2.values)) + assert np.array_equal(expected, result["var2"].values, equal_nan=True) + + # test concat kwargs (dim kwarg should be ignored) + joiner = Concatenate(concat_dim="x", concat_kwargs={"fill_value": 0, "dim": "y"}) + result = joiner.join([ds1, ds2]) + expected = np.vstack((da1.values, np.zeros((4, 3)))) + assert np.array_equal( + expected, + result["var1"].values, + ) + expected = np.vstack((np.zeros((2, 3)), da2.values)) + assert np.array_equal( + expected, + result["var2"].values, + ) + + # unjoin not implemented: returns the input + joiner = Concatenate(concat_dim="x") + assert ds1.equals(joiner.unjoin(ds1))