-
Notifications
You must be signed in to change notification settings - Fork 30
tests: add unit tests for functions without direct test coverage #94
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,72 @@ | ||
| """ | ||
| Unit tests for ops.chunking module. | ||
| """ | ||
| import numpy as np | ||
| import pytest | ||
| import xarray as xr | ||
| from loguru import logger | ||
|
|
||
| from mllam_data_prep.ops.chunking import check_chunk_size, chunk_dataset | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def small_dataset(): | ||
| """Create a small test dataset.""" | ||
| return xr.Dataset( | ||
| { | ||
| "var1": (["x", "y"], np.random.random((10, 10))), | ||
| "var2": (["x", "y"], np.random.random((10, 10))), | ||
| }, | ||
| coords={"x": range(10), "y": range(10)}, | ||
| ) | ||
|
|
||
|
|
||
| def test_check_chunk_size_small_chunks(small_dataset): | ||
| """Test check_chunk_size with small chunks (should not warn).""" | ||
| chunks = {"x": 5, "y": 5} | ||
| # Should not raise or warn | ||
| check_chunk_size(small_dataset, chunks) | ||
|
|
||
|
|
||
| def test_check_chunk_size_large_chunks(small_dataset): | ||
| """Test check_chunk_size with large chunks (should warn).""" | ||
| # Use chunk sizes that exceed 1GB threshold | ||
| # For float64 (8 bytes), need chunks product > 1GB / 8 = 134217728 | ||
| # Using chunks of 12000 x 12000 = 144000000 elements > 134217728 | ||
| chunks = {"x": 12000, "y": 12000} | ||
|
|
||
| # Capture loguru logs using a handler | ||
| from io import StringIO | ||
|
|
||
| log_capture = StringIO() | ||
| handler_id = logger.add(log_capture, format="{message}") | ||
|
|
||
| try: | ||
| check_chunk_size(small_dataset, chunks) | ||
| log_output = log_capture.getvalue() | ||
| assert "exceeds" in log_output.lower() | ||
| finally: | ||
| logger.remove(handler_id) | ||
|
|
||
|
|
||
| def test_check_chunk_size_missing_dimension(small_dataset): | ||
| """Test check_chunk_size when dimension doesn't exist in variable.""" | ||
| chunks = {"x": 5, "z": 10} # z doesn't exist | ||
| # Should not raise, just skip the missing dimension | ||
| check_chunk_size(small_dataset, chunks) | ||
|
|
||
|
|
||
| def test_chunk_dataset_success(small_dataset): | ||
| """Test chunk_dataset successfully chunks a dataset.""" | ||
| chunks = {"x": 5, "y": 5} | ||
| chunked = chunk_dataset(small_dataset, chunks) | ||
| assert isinstance(chunked, xr.Dataset) | ||
| # Check that chunking was applied | ||
| assert chunked["var1"].chunks is not None | ||
|
|
||
|
|
||
| def test_chunk_dataset_invalid_chunks(small_dataset): | ||
| """Test chunk_dataset with invalid chunk specification.""" | ||
| chunks = {"x": -1} # Invalid chunk size | ||
| with pytest.raises(Exception, match="Error chunking dataset"): | ||
| chunk_dataset(small_dataset, chunks) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,47 @@ | ||
| """ | ||
| Unit tests for ops.loading module. | ||
| """ | ||
| import pytest | ||
| import xarray as xr | ||
|
|
||
| from mllam_data_prep.ops.loading import load_input_dataset | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def sample_dataset(): | ||
| """Create a simple test dataset.""" | ||
| return xr.Dataset( | ||
| {"var": (["x"], [1, 2, 3])}, | ||
| coords={"x": [0, 1, 2]}, | ||
| ) | ||
|
|
||
|
|
||
| def test_load_input_dataset_zarr(sample_dataset, tmp_path): | ||
| """Test load_input_dataset with zarr format.""" | ||
| zarr_path = tmp_path / "test.zarr" | ||
| sample_dataset.to_zarr(zarr_path, mode="w") | ||
|
|
||
| loaded = load_input_dataset(str(zarr_path)) | ||
| assert isinstance(loaded, xr.Dataset) | ||
| assert "var" in loaded.data_vars | ||
| assert list(loaded.x.values) == [0, 1, 2] | ||
|
|
||
|
|
||
| def test_load_input_dataset_netcdf(sample_dataset, tmp_path): | ||
| """Test load_input_dataset with netCDF format.""" | ||
| # Skip if NetCDF engine is not available | ||
| pytest.importorskip("netCDF4") | ||
|
|
||
| nc_path = tmp_path / "test.nc" | ||
| sample_dataset.to_netcdf(nc_path, engine="netcdf4") | ||
|
|
||
| loaded = load_input_dataset(str(nc_path)) | ||
| assert isinstance(loaded, xr.Dataset) | ||
| assert "var" in loaded.data_vars | ||
| assert list(loaded.x.values) == [0, 1, 2] | ||
|
|
||
|
|
||
| def test_load_input_dataset_nonexistent(): | ||
| """Test load_input_dataset with non-existent file.""" | ||
| with pytest.raises((OSError, FileNotFoundError)): | ||
| load_input_dataset("/nonexistent/path/to/file.zarr") | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,80 @@ | ||
| """ | ||
| Unit tests for helper functions in ops.selection module. | ||
| """ | ||
| import pandas as pd | ||
| import pytest | ||
| import xarray as xr | ||
|
|
||
| from mllam_data_prep.ops.selection import check_point_in_dataset, check_step | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def simple_time_dataset(): | ||
| """Create a simple dataset with time coordinate.""" | ||
| time_values = pd.date_range("2020-01-01", periods=5, freq="3H") | ||
| return xr.Dataset( | ||
| {"var": (["time"], range(5))}, | ||
| coords={"time": time_values}, | ||
| ) | ||
|
|
||
|
|
||
| def test_check_point_in_dataset_point_exists(simple_time_dataset): | ||
| """Test check_point_in_dataset when point exists in coordinate.""" | ||
| point = simple_time_dataset.time.values[2] | ||
| # Should not raise | ||
| check_point_in_dataset("time", point, simple_time_dataset) | ||
|
|
||
|
|
||
| def test_check_point_in_dataset_point_not_exists(simple_time_dataset): | ||
| """Test check_point_in_dataset when point does not exist in coordinate.""" | ||
| point = pd.Timestamp("2020-01-02T12:00") | ||
| with pytest.raises(ValueError, match="Provided value for coordinate time"): | ||
| check_point_in_dataset("time", point, simple_time_dataset) | ||
|
|
||
|
|
||
| def test_check_point_in_dataset_none_point(simple_time_dataset): | ||
| """Test check_point_in_dataset when point is None (should not raise).""" | ||
| # Should not raise when point is None | ||
| check_point_in_dataset("time", None, simple_time_dataset) | ||
|
|
||
|
|
||
| def test_check_step_constant_step_matches(simple_time_dataset): | ||
| """Test check_step when step is constant and matches requested step.""" | ||
| requested_step = pd.Timedelta(hours=3) | ||
| # Should not raise | ||
| check_step(requested_step, "time", simple_time_dataset) | ||
|
|
||
|
|
||
| def test_check_step_constant_step_mismatch(simple_time_dataset): | ||
| """Test check_step when step is constant but doesn't match requested step.""" | ||
| requested_step = pd.Timedelta(hours=6) | ||
| with pytest.raises(ValueError, match="Step size for coordinate time"): | ||
| check_step(requested_step, "time", simple_time_dataset) | ||
|
|
||
|
|
||
| def test_check_step_non_constant_step(): | ||
| """Test check_step when step size is not constant.""" | ||
| # Create dataset with non-constant time steps | ||
| time_values = pd.to_datetime( | ||
| ["2020-01-01T00:00", "2020-01-01T03:00", "2020-01-01T10:00", "2020-01-01T13:00"] | ||
| ) | ||
| ds = xr.Dataset( | ||
| {"var": (["time"], range(4))}, | ||
| coords={"time": time_values}, | ||
| ) | ||
| requested_step = pd.Timedelta(hours=3) | ||
| with pytest.raises(ValueError, match="Step size for coordinate time is not constant"): | ||
| check_step(requested_step, "time", ds) | ||
|
|
||
|
|
||
| def test_check_step_single_point_coordinate(): | ||
| """Test check_step with single point coordinate (should raise descriptive ValueError).""" | ||
| # Create dataset with single time point | ||
| time_values = pd.date_range("2020-01-01", periods=1, freq="3H") | ||
| ds = xr.Dataset( | ||
| {"var": (["time"], [1])}, | ||
| coords={"time": time_values}, | ||
| ) | ||
| requested_step = pd.Timedelta(hours=3) | ||
| with pytest.raises(ValueError, match="Cannot compute step size.*fewer than 2 points"): | ||
| check_step(requested_step, "time", ds) | ||
|
Comment on lines
+70
to
+80
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sample_dataset.to_netcdf(...)requires an optional NetCDF engine (typicallyscipy,netCDF4, orh5netcdf). The project dependencies don’t appear to include any of these, so this test may fail in CI depending on the environment. Consider usingpytest.importorskip(...)for the chosen engine and specifying it explicitly (e.g.,engine="scipy"), or adding an explicit test dependency to ensure NetCDF support is available.