Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions packages/pipeline/src/pyearthtools/pipeline/operations/dask/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

# type: ignore[reportPrivateImportUsage]

from itertools import accumulate
from typing import Optional, Any

import dask.array as da
Expand All @@ -26,14 +27,12 @@
class Stack(Joiner, DaskOperation):
"""
Stack a tuple of da.Array's

Currently cannot undo this operation
"""

_override_interface = ["Serial"]
_numpy_counterpart = "join.Stack"

def __init__(self, axis: Optional[int] = None):
def __init__(self, axis: Optional[int] = 0):
super().__init__()
self.record_initialisation()
self.axis = axis
Expand All @@ -43,14 +42,14 @@ def join(self, sample: tuple[Any, ...]) -> da.Array:
return da.stack(sample, self.axis) # type: ignore

def unjoin(self, sample: Any) -> tuple:
    """Unstack a stacked sample back into a tuple of arrays.

    Inverse of `join`: moves the stacked axis to position 0, then
    iterates over it so each slice becomes one element of the tuple.
    """
    # Removed dead `return super().unjoin(sample)` residue that made
    # the real implementation unreachable.
    return tuple(da.moveaxis(sample, self.axis, 0))


class VStack(Joiner, DaskOperation):
"""
Vertically Stack a tuple of da.Array's

Currently cannot undo this operation
"""

_override_interface = ["Serial"]
Expand All @@ -59,22 +58,24 @@ class VStack(Joiner, DaskOperation):
def __init__(self):
super().__init__()
self.record_initialisation()
self.offsets = None

def join(self, sample: tuple[Any, ...]) -> da.Array:
    """Vertically stack `sample`, recording the split offsets used by `unjoin`."""
    # Cumulative heights of all but the last array: the boundaries at
    # which the stacked result must later be split apart again.
    heights = [arr.shape[0] for arr in sample[:-1]]
    self.offsets = tuple(accumulate(heights))
    return da.vstack(sample)  # type: ignore

def unjoin(self, sample: Any) -> tuple:
    """Split a vertically stacked sample back into its original parts.

    Requires `join` to have been called first so `self.offsets` holds the
    row boundaries between the original arrays.
    """
    # Removed dead `return super().unjoin(sample)` residue; also renamed
    # `start`, which the original reused as both the tuple of boundaries
    # and the per-iteration loop variable (confusing shadowing).
    starts = (0,) + self.offsets
    ends = self.offsets + (sample.shape[0],)
    return tuple(sample[s:e] for s, e in zip(starts, ends, strict=True))


class HStack(Joiner, DaskOperation):
"""
Horizontally Stack a tuple of da.Array's

Currently cannot undo this operation
"""

_override_interface = ["Serial"]
Expand All @@ -83,22 +84,24 @@ class HStack(Joiner, DaskOperation):
def __init__(self):
super().__init__()
self.record_initialisation()
self.offsets = None

def join(self, sample: tuple[Any, ...]) -> da.Array:
    """Horizontally stack `sample`, recording the split offsets used by `unjoin`."""
    # Cumulative widths (axis 1) of all but the last array: the column
    # boundaries at which `unjoin` must split the stacked result.
    # NOTE(review): assumes inputs are at least 2-D — confirm, since
    # hstack of 1-D arrays concatenates along axis 0 instead.
    widths = [arr.shape[1] for arr in sample[:-1]]
    self.offsets = tuple(accumulate(widths))
    return da.hstack(sample)  # type: ignore

def unjoin(self, sample: Any) -> tuple:
    """Split a horizontally stacked sample back into its original parts.

    Requires `join` to have been called first so `self.offsets` holds the
    column (axis 1) boundaries between the original arrays.
    """
    # Removed dead `return super().unjoin(sample)` residue; renamed the
    # shadowed `start` variable for clarity.
    starts = (0,) + self.offsets
    ends = self.offsets + (sample.shape[1],)
    return tuple(sample[:, s:e, ...] for s, e in zip(starts, ends, strict=True))


class Concatenate(Joiner, DaskOperation):
"""
Concatenate a tuple of da.Array's

Currently cannot undo this operation
"""

_override_interface = ["Serial"]
Expand All @@ -108,10 +111,14 @@ def __init__(self, axis: Optional[int] = None):
super().__init__()
self.record_initialisation()
self.axis = axis
self.offsets = None

def join(self, sample: tuple[Any, ...]) -> da.Array:
    """Concatenate `sample` along `self.axis`, recording split offsets for `unjoin`."""
    # Cumulative extents along the concatenation axis of all but the last
    # array — the boundaries where `unjoin` will split the result.
    extents = [arr.shape[self.axis] for arr in sample[:-1]]
    self.offsets = tuple(accumulate(extents))
    return da.concatenate(sample, self.axis)  # type: ignore

def unjoin(self, sample: Any) -> tuple:
    """Split a concatenated sample back into its original parts along `self.axis`.

    Requires `join` to have been called first so `self.offsets` holds the
    boundaries along the concatenation axis.
    """
    # Removed dead `return super().unjoin(sample)` residue; renamed the
    # shadowed `start` variable for clarity.
    starts = (0,) + self.offsets
    ends = self.offsets + (sample.shape[self.axis],)
    # NOTE(review): passes a `slice` as the `indices` argument of da.take —
    # dask's indexing accepts this, but it is not a documented contract of
    # `take`; confirm against the dask.array.take documentation.
    return tuple(da.take(sample, slice(s, e), self.axis) for s, e in zip(starts, ends, strict=True))
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.


from itertools import accumulate
from typing import Optional, Any

import numpy as np
Expand All @@ -23,8 +23,6 @@
class Stack(Joiner):
"""
Stack a tuple of np.ndarray's

Currently cannot undo this operation
"""

_override_interface = ["Delayed", "Serial"]
Expand All @@ -40,14 +38,16 @@ def join(self, sample: tuple[Any, ...]) -> np.ndarray:
return np.stack(sample, self.axis) # type: ignore

def unjoin(self, sample: Any) -> tuple:
    """Unstack a stacked sample back into a tuple of arrays.

    Inverse of `join`: moves the stacked axis to position 0, then
    iterates over it so each slice becomes one element of the tuple.
    """
    # Removed dead `return super().unjoin(sample)` residue. Also fixed the
    # misleading comment: np.stack does NOT accept axis=None — we simply
    # treat an unset (None) axis as axis 0 here.
    axis = self.axis if self.axis is not None else 0
    return tuple(np.moveaxis(sample, axis, 0))


class VStack(Joiner):
"""
Vertically Stack a tuple of np.ndarray's

Currently cannot undo this operation
"""

_override_interface = ["Delayed", "Serial"]
Expand All @@ -56,22 +56,23 @@ class VStack(Joiner):
def __init__(self):
super().__init__()
self.record_initialisation()
self.offsets = None # stores the vertical offset where each joined array is

def join(self, sample: tuple[Any, ...]) -> np.ndarray:
    """Vertically stack `sample`, recording the split offsets used by `unjoin`."""
    # Completed the dangling "# stores" comment left in the original:
    # cumulative row counts of all but the last array are the boundaries
    # at which `unjoin` must split the stacked result.
    self.offsets = tuple(accumulate(arr.shape[0] for arr in sample[:-1]))
    return np.vstack(
        sample,
    )  # type: ignore

def unjoin(self, sample: Any) -> tuple:
    """Split a vertically stacked sample back into its original parts.

    Requires `join` to have been called first so `self.offsets` holds the
    row boundaries; np.vsplit does the actual slicing.
    """
    # Removed dead `return super().unjoin(sample)` residue that made the
    # real implementation unreachable.
    return tuple(np.vsplit(sample, self.offsets))


class HStack(Joiner):
"""
Horizontally Stack a tuple of np.ndarray's

Currently cannot undo this operation
"""

_override_interface = ["Delayed", "Serial"]
Expand All @@ -80,22 +81,22 @@ class HStack(Joiner):
def __init__(self):
super().__init__()
self.record_initialisation()
self.offsets = None

def join(self, sample: tuple[Any, ...]) -> np.ndarray:
    """Horizontally stack `sample`, recording the split offsets used by `unjoin`."""
    # Cumulative widths (axis 1) of all but the last array: the column
    # boundaries at which `unjoin` must split the stacked result.
    # NOTE(review): assumes inputs are at least 2-D — confirm, since
    # hstack of 1-D arrays concatenates along axis 0 instead.
    widths = [arr.shape[1] for arr in sample[:-1]]
    self.offsets = tuple(accumulate(widths))
    return np.hstack(sample)  # type: ignore

def unjoin(self, sample: Any) -> tuple:
    """Split a horizontally stacked sample back into its original parts.

    Requires `join` to have been called first so `self.offsets` holds the
    column boundaries; np.hsplit does the actual slicing.
    """
    # Removed dead `return super().unjoin(sample)` residue that made the
    # real implementation unreachable.
    return tuple(np.hsplit(sample, self.offsets))


class Concatenate(Joiner):
"""
Concatenate a tuple of np.ndarray's

Currently cannot undo this operation
"""

_override_interface = ["Delayed", "Serial"]
Expand All @@ -105,10 +106,12 @@ def __init__(self, axis: Optional[int] = None):
super().__init__()
self.record_initialisation()
self.axis = axis
self.offsets = None

def join(self, sample: tuple[Any, ...]) -> np.ndarray:
    """Concatenate `sample` along `self.axis`, recording split offsets for `unjoin`."""
    # Cumulative extents along the concatenation axis of all but the last
    # array — the boundaries where `unjoin` will split the result.
    extents = [arr.shape[self.axis] for arr in sample[:-1]]
    self.offsets = tuple(accumulate(extents))
    return np.concatenate(sample, self.axis)  # type: ignore

def unjoin(self, sample: Any) -> tuple:
    """Split a concatenated sample back into its original parts.

    Requires `join` to have been called first so `self.offsets` holds the
    boundaries along `self.axis`; np.split does the actual slicing.
    """
    # Removed dead `return super().unjoin(sample)` residue that made the
    # real implementation unreachable.
    return tuple(np.split(sample, self.offsets, axis=self.axis))
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Merge(Joiner):
"""
Merge a tuple of xarray object's.

Currently cannot undo this operation
Currently can only undo this operation with xr.Dataset and xr.DataArray inputs.
"""

_override_interface = "Serial"
Expand All @@ -36,13 +36,26 @@ def __init__(self, merge_kwargs: Optional[dict[str, Any]] = None):
super().__init__()
self.record_initialisation()
self._merge_kwargs = merge_kwargs
self._input_structure: list[tuple[Union[str, list[str]], dict]] = []

def join(self, sample: tuple[Union[xr.Dataset, xr.DataArray], ...]) -> xr.Dataset:
    """Merge `sample`, recording each input's structure so `unjoin` can reverse it."""
    # For each input remember enough to rebuild it from the merged dataset:
    # a DataArray is identified by its name, a Dataset by its variable names;
    # both keep their attrs.
    structure: list[tuple[Union[str, list[str]], dict]] = []
    for item in sample:
        if isinstance(item, xr.DataArray):
            structure.append((item.name, item.attrs))
        else:
            structure.append((list(item.data_vars), item.attrs))
    self._input_structure = structure
    return xr.merge(sample, **(self._merge_kwargs or {}))

def unjoin(self, sample: xr.Dataset) -> tuple:
    """Reconstruct the original tuple of inputs from a merged dataset.

    Uses the structure recorded by `join`: a `str` key rebuilds a DataArray,
    a list of keys rebuilds a Dataset; stored attrs are restored on each.
    Requires `join` to have been called first.
    """
    # Removed the stale duplicate `def unjoin` / `return super().unjoin(sample)`
    # residue that preceded this implementation.
    # NOTE(review): an unnamed DataArray (name=None) would make sample[None]
    # fail here — confirm inputs are always named.
    result = []
    for keys, attrs in self._input_structure:
        if isinstance(keys, str):
            array = sample[keys]
            array.attrs = attrs
            result.append(array)
        else:
            result.append(xr.Dataset({k: sample[k] for k in keys}, attrs=attrs))
    return tuple(result)


class LatLonInterpolate(Joiner):
Expand All @@ -54,6 +67,8 @@ class LatLonInterpolate(Joiner):

It assumed the dimensions 'latitude', 'longitude', 'time', and 'level' will
be present. 'lat' or 'lon' may also be used for convenience.

Currently cannot undo this operation. Raises NotImplementedError if undo is attempted.
"""

_override_interface = "Serial"
Expand All @@ -68,9 +83,14 @@ def __init__(
):
super().__init__()

self.raise_if_dimensions_wrong(reference_dataset)

self.record_initialisation()

if reference_dataset is None and reference_index is None:
raise ValueError("No reference dataset or reference index set")
elif reference_dataset is not None and reference_index is not None:
raise ValueError("Only one of reference_dataset or reference_index should be set")
elif reference_dataset:
self.raise_if_dimensions_wrong(reference_dataset)
self.reference_dataset = reference_dataset
self.reference_index = reference_index
self.interpolation_method = interpolation_method
Expand Down Expand Up @@ -154,6 +174,8 @@ class GeospatialTimeSeriesMerge(Joiner):

This joiner is more strict about the merging and interpolating, and also
raises more informative error messages when it runs into trouble.

Currently cannot undo this operation. Raises NotImplementedError if undo is attempted.
"""

_override_interface = "Serial"
Expand Down Expand Up @@ -220,7 +242,7 @@ class InterpLike(Joiner):
"""
Merge a tuple of xarray object's.

Currently cannot undo this operation
Currently cannot undo this operation. Raises NotImplementedError if undo is attempted.
"""

_override_interface = "Serial"
Expand Down Expand Up @@ -262,7 +284,7 @@ class Concatenate(Joiner):
"""
Concatenate a tuple of xarray object's

Currently cannot undo this operation
Currently cannot undo this operation. Unjoining a sample returns the same sample.
"""

_override_interface = "Serial"
Expand Down
68 changes: 68 additions & 0 deletions packages/pipeline/tests/operations/dask/test_dask_join.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright Commonwealth of Australia, Bureau of Meteorology 2025.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

from pyearthtools.pipeline.operations.dask.join import Stack, VStack, HStack, Concatenate

import dask.array as da
import numpy as np
import pytest


def _arrays(*shapes):
    """Create dask arrays with given shapes whose elements are sequential integers."""
    arrays = []
    start = 0
    for shape in shapes:
        count = int(np.prod(shape))
        # Offsetting by the running total guarantees every array holds
        # distinct values, so joins/unjoins can't pass by coincidence.
        arrays.append(da.array(range(start, start + count)).reshape(shape))
        start += count
    return tuple(arrays)


# This parametrization passes in the joiner instance to test, with an appropriate
# axis as needed. The joined result is compared against the equivalent dask
# function, partially initialised with the same axis. Input array shapes are
# chosen per joiner so the operation is well-defined.
@pytest.mark.parametrize(
    ("joiner", "equiv_op", "input_arrays"),
    [
        pytest.param(Stack(axis=0), partial(da.stack, axis=0), _arrays((2, 3), (2, 3)), id="Stack-axis0"),
        pytest.param(Stack(axis=1), partial(da.stack, axis=1), _arrays((2, 3), (2, 3)), id="Stack-axis1"),
        pytest.param(Stack(axis=2), partial(da.stack, axis=2), _arrays((2, 3), (2, 3)), id="Stack-axis2"),
        pytest.param(VStack(), da.vstack, _arrays((1, 3, 2), (2, 3, 2)), id="VStack"),
        pytest.param(HStack(), da.hstack, _arrays((3, 1, 2), (3, 2, 2)), id="HStack"),
        pytest.param(
            Concatenate(axis=0), partial(da.concatenate, axis=0), _arrays((1, 3, 2), (2, 3, 2)), id="Concatenate-axis0"
        ),
        pytest.param(
            Concatenate(axis=1), partial(da.concatenate, axis=1), _arrays((3, 1, 2), (3, 2, 2)), id="Concatenate-axis1"
        ),
        pytest.param(
            Concatenate(axis=2), partial(da.concatenate, axis=2), _arrays((3, 2, 1), (3, 2, 2)), id="Concatenate-axis2"
        ),
    ],
)
def test_join(joiner, equiv_op, input_arrays):
    """Tests that joiners reproduce their dask equivalents and are reversible."""
    name = type(joiner).__name__

    # Forward direction: joiner must match the reference dask operation.
    joined = joiner.join(input_arrays)
    reference = equiv_op(input_arrays)
    assert np.array_equal(joined.compute(), reference.compute()), f"{name}.join() did not reproduce expected behaviour."

    # Backward direction: unjoin must recover each original array exactly.
    restored = joiner.unjoin(joined)
    assert isinstance(restored, tuple), f"{name}.unjoin() did not return a tuple."
    for recovered, original in zip(restored, input_arrays, strict=True):
        assert np.array_equal(recovered.compute(), original.compute()), f"{name}.unjoin() did not return original arrays."
Loading
Loading