Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mllam_data_prep/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ class InvalidConfigException(Exception):
pass


def validate_config(config_inputs):
def validate_config(config_inputs: Dict[str, "InputDataset"]) -> None:
"""
Validate that, in the config:
- either `variables` or `derived_variables` are present in the config
Expand Down
7 changes: 5 additions & 2 deletions mllam_data_prep/ops/chunking.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from typing import Dict

import numpy as np
import xarray as xr
from loguru import logger

# Max chunk size warning
CHUNK_MAX_SIZE_WARNING = 1 * 1024**3 # 1GB


def check_chunk_size(ds, chunks):
def check_chunk_size(ds: xr.Dataset, chunks: Dict[str, int]) -> None:
"""
Check the chunk size and warn if it exceeds CHUNK_MAX_SIZE_WARNING.

Expand Down Expand Up @@ -44,7 +47,7 @@ def check_chunk_size(ds, chunks):
)


def chunk_dataset(ds, chunks):
def chunk_dataset(ds: xr.Dataset, chunks: Dict[str, int]) -> xr.Dataset:
"""
Check the chunk size and chunk the dataset.

Expand Down
2 changes: 1 addition & 1 deletion mllam_data_prep/ops/loading.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import xarray as xr


def load_input_dataset(fp):
def load_input_dataset(fp: str) -> xr.Dataset:
"""
Load the dataset

Expand Down
13 changes: 11 additions & 2 deletions mllam_data_prep/ops/mapping.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
from typing import Any, Dict, List

import xarray as xr

from ..config import DimMapping
from .stacking import stack_variables_as_coord_values, stack_variables_by_coord_values


def _check_for_malformed_list_arg(s):
def _check_for_malformed_list_arg(s: Any) -> None:
if isinstance(s, str) and "," in s:
raise Exception(
"Rather than writing `{s}` to define a list you would `[{s}]` in the config file."
)


def map_dims_and_variables(ds, dim_mapping, expected_input_var_dims):
def map_dims_and_variables(
ds: xr.Dataset,
dim_mapping: Dict[str, DimMapping],
expected_input_var_dims: List[str],
) -> xr.DataArray:
"""
Map the input dimensions to the architecture dimensions
using the `dim_mapping` dictionary. Each key in the `dim_mapping`
Expand Down
22 changes: 17 additions & 5 deletions mllam_data_prep/ops/selection.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import datetime
from typing import List, Optional, Union

import pandas as pd
import xarray as xr

from ..config import Range


def _normalize_slice_startstop(s):
def _normalize_slice_startstop(
s: Union[pd.Timestamp, str, int, float]
) -> Union[pd.Timestamp, str, int, float]:
if isinstance(s, pd.Timestamp):
return s
elif isinstance(s, str):
Expand All @@ -17,7 +21,9 @@ def _normalize_slice_startstop(s):
return s


def _normalize_slice_step(s):
def _normalize_slice_step(
s: Optional[Union[pd.Timedelta, str, int, float]]
) -> Optional[Union[pd.Timedelta, str, int, float]]:
if isinstance(s, pd.Timedelta):
return s
elif isinstance(s, str):
Expand All @@ -29,7 +35,9 @@ def _normalize_slice_step(s):
return s


def select_by_kwargs(ds, **coord_ranges):
def select_by_kwargs(
ds: xr.Dataset, **coord_ranges: Union[Range, List[Any]]
) -> xr.Dataset:
"""
Do `xr.Dataset.sel` on `ds` using the `coord_ranges` to select the coordinates, for each
entry in the dictionary, the key is the coordinate name and the value is the selection
Expand Down Expand Up @@ -92,7 +100,7 @@ def select_by_kwargs(ds, **coord_ranges):
return ds


def check_point_in_dataset(coord, point, ds):
def check_point_in_dataset(coord: str, point: Any, ds: xr.Dataset) -> None:
"""
check that the requested point is in the data.
"""
Expand All @@ -102,7 +110,11 @@ def check_point_in_dataset(coord, point, ds):
)


def check_step(sel_step, coord, ds):
def check_step(
sel_step: Union[pd.Timedelta, datetime.timedelta],
coord: str,
ds: xr.Dataset,
) -> None:
"""
check that the step requested is exactly what the data has
"""
Expand Down
11 changes: 9 additions & 2 deletions mllam_data_prep/ops/stacking.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import xarray as xr


def stack_variables_as_coord_values(ds, name_format, combined_dim_name):
def stack_variables_as_coord_values(
ds: xr.Dataset, name_format: str, combined_dim_name: str
) -> xr.DataArray:
"""
combine all variables in an xr.Dataset into a single xr.DataArray
by stacking the variables along a new coordinate with the name given
Expand Down Expand Up @@ -50,7 +52,12 @@ def stack_variables_as_coord_values(ds, name_format, combined_dim_name):
return da_combined


def stack_variables_by_coord_values(ds, coord, name_format, combined_dim_name):
def stack_variables_by_coord_values(
ds: xr.Dataset,
coord: str,
name_format: str,
combined_dim_name: str,
) -> xr.DataArray:
"""
combine all variables in an xr.Dataset on all coordinate values of `coord`
into a single xr.DataArray
Expand Down
26 changes: 19 additions & 7 deletions mllam_data_prep/ops/subsetting.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
def extract_variable(ds, var_name, coords_to_sample=dict()):
from typing import Dict, Optional

import xarray as xr

from ..config import ValueSelection


def extract_variable(
ds: xr.Dataset,
var_name: str,
coords_to_sample: Optional[Dict[str, ValueSelection]] = None,
) -> xr.DataArray:
"""
Extract specified variable from the provided input dataset. If
coordinates for subsetting are defined, then subset the variable along
Expand All @@ -8,19 +19,20 @@ def extract_variable(ds, var_name, coords_to_sample=dict()):
----------
ds : xr.Dataset
Input dataset
var_name : Union[Dict, List]
Either a list or dictionary with variables to extract.
If a dictionary the keys are the variable name and the values are
entries for each coordinate and coordinate values to extract
coords_to_sample: Dict
var_name : str
Name of the variable to extract from the dataset
coords_to_sample: Dict[str, ValueSelection], optional
Optional argument for subsetting/sampling along the specified
coordinates
coordinates. Keys are coordinate names, values are ValueSelection
objects defining the values to select and optionally the units.

Returns
----------
da: xr.DataArray
Extracted variable (subsetted along the specified coordinates)
"""
if coords_to_sample is None:
coords_to_sample = {}

try:
da = ds[var_name]
Expand Down