diff --git a/README.md b/README.md index efa5641..64db313 100644 --- a/README.md +++ b/README.md @@ -41,7 +41,7 @@ The available `install_options` are: For developers we suggest installing with the editable flag `-e` and the verbose flag `-v`: ```shell -pip install -e -v .[dev] +pip install -v -e .[dev] ``` use of the editable installation method will allow updates to the Python source code to be reflected in the installation without re-installing. Subsequent installs of `pyFMS` will recompile all extension modules, due to the methods of compilation used by `scikit-build-core`. diff --git a/pyfms/tools/README.md b/pyfms/tools/README.md new file mode 100644 index 0000000..9d1ab5a --- /dev/null +++ b/pyfms/tools/README.md @@ -0,0 +1,150 @@ +# generate-history + +Generates structurally faithful FMS raw history NetCDF files using the pyfms +diag manager. Useful for creating test fixtures for climate data processing +pipelines without needing a full model run. + +`generate-history` is installed as a command when pyfms is pip-installed. 
+ +--- + +## Usage + +Write a `diag_table.yaml` in the standard FMS format, then run one of: + +```tcsh +# Regular rectangular grid +generate-history diag_table.yaml \ + --nx 96 --ny 96 [--nz 65] \ + --calendar noleap \ + --nsteps 720 \ + --output-dir ./output + +# Cubed-sphere atmosphere (C96, 6 tiles) +mpirun -n 6 generate-history diag_table.yaml \ + --grid-type cubed-sphere --ntile 96 [--nz 65] \ + --calendar noleap \ + --nsteps 720 \ + --output-dir ./output + +# Tripolar ocean +generate-history diag_table.yaml \ + --grid-type tripolar --nx 1440 --ny 1080 [--nz 75] \ + --calendar noleap \ + --nsteps 720 \ + --q-vars uo,vo \ + --output-dir ./output +``` + +### All options + +| Option | Required | Description | +|---|---|---| +| `diag_table.yaml` | yes | Path to your diag_table.yaml | +| `--grid-type` | no | `regular` (default), `cubed-sphere`, or `tripolar` | +| `--nx`, `--ny` | for regular/tripolar | Horizontal grid dimensions | +| `--ntile` | for cubed-sphere | Tile size (e.g. 96 for C96, 48 for C48) | +| `--nz N` | no | Vertical levels; omit or 0 for all-2D output | +| `--calendar` | no | `noleap` (default), `julian`, `gregorian`, `thirty_day` | +| `--nsteps` | yes | Number of 1-hour steps to simulate | +| `--output-dir` | no | Where to write output (default: `./output`) | +| `--seed` | no | Random seed for reproducible data (default: 0) | +| `--q-vars` | no | Comma-separated var_names on the q-grid (tripolar only) | + +The internal model timestep is fixed at 1 hour. Output frequency is driven by +the `freq` field in the diag_table.yaml. Data values are random but reproducible. + +--- + +## Grid types + +### `regular` (default) +Standard rectangular grid. Produces a single output file per diag_files entry. + +### `cubed-sphere` +Six-tile gnomonic equal-angle cubed-sphere (GFDL FMS convention). Requires +`mpirun -n 6`. 
Produces: +- `{file_name}.tile1.nc` … `{file_name}.tile6.nc` — history files +- `C{ntile}_mosaic.nc` — FMS mosaic descriptor (for fregrid) +- `C{ntile}_grid.tile{1..6}.nc` — per-tile supergrid files (lat/lon/area/dx/dy) +- `C{ntile}_scrip.nc` — combined SCRIP file (for ESMF/esmpy) + +### `tripolar` +Single-tile tripolar ocean grid with h-point (tracer) and q-point (velocity) +axes. Produces: +- `{file_name}.nc` — history file with `xh`, `yh`, `xq`, `yq` dimensions +- `ocean_mosaic.nc` — FMS mosaic descriptor (for fregrid) +- `ocean_hgrid.nc` — supergrid file (lat/lon/area/dx/dy) +- `ocean_scrip.nc` — SCRIP file (for ESMF/esmpy) + +Variables named in `--q-vars` are placed on the q-grid; all others default +to the h-grid. + +--- + +## diag_table.yaml format + +```yaml +title: my_test +base_date: 2000 1 1 0 0 0 + +diag_files: +- file_name: atmos_month + freq: 1 months + time_units: hours + unlimdim: time + varlist: + - module: atm_mod + var_name: tas + reduction: average + kind: r4 + output_name: tas + - module: atm_mod + var_name: ua + reduction: average + kind: r4 + output_name: ua +``` + +`base_date` sets the simulation start time (year month day hour minute second). +`--nsteps 720` with a 1-hour internal step covers 30 days from that start time. + +--- + +## Checking output + +```python +import xarray as xr + +# Regular or tripolar +ds = xr.open_dataset("output/atmos_month.nc") +print(ds) + +# Cubed-sphere (one tile) +ds = xr.open_dataset("output/atmos_month.tile1.nc") +print(ds) + +# Grid spec +ds = xr.open_dataset("output/C96_mosaic.nc") +``` + +--- + +## Troubleshooting + +**No output files produced** +→ The simulation must run long enough to cross at least one output boundary. +Check that `--nsteps` × 1 hour exceeds the `freq` of every file in the diag_table. + +**`import pyfms` fails** +→ The pyfms venv is not active, or the required environment modules (gcc, mpich, +netcdf-c, netcdf-fortran) were not loaded. 
Load modules and activate the venv +before running. + +**cubed-sphere: `RuntimeError: requires exactly 6 MPI ranks`** +→ Run with `mpirun -n 6 generate-history ...` + +**cubed-sphere: no `.tileN.nc` files appear** +→ The cubic mosaic domain integration with diag_manager may need verification. +Check that `define_cubic_mosaic` is set as the current domain before calling +`diag_manager.init`. diff --git a/pyfms/tools/__init__.py b/pyfms/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pyfms/tools/generate_history.py b/pyfms/tools/generate_history.py new file mode 100644 index 0000000..dfab96f --- /dev/null +++ b/pyfms/tools/generate_history.py @@ -0,0 +1,593 @@ +#!/usr/bin/env python3 +""" +Generate structurally faithful FMS raw history files using pyfms diag_manager. + +Reads a diag_table.yaml to determine output files and variables, then runs +the FMS diag manager to produce real NetCDF output with reproducible random data. + +The internal model timestep is fixed at 1 hour (3600 s). Output scheduling is +driven by the 'freq' field in the diag_table.yaml, as in a real FMS run. + +Usage examples: + # Regular rectangular grid + generate-history diag_table.yaml --nx 96 --ny 96 [--nz 33] \\ + --calendar noleap --nsteps 720 + + # Cubed-sphere atmosphere (requires mpirun -n 6) + mpirun -n 6 generate-history diag_table.yaml \\ + --grid-type cubed-sphere --ntile 96 --nz 65 \\ + --calendar noleap --nsteps 720 + + # Tripolar ocean + generate-history diag_table.yaml \\ + --grid-type tripolar --nx 1440 --ny 1080 --nz 75 \\ + --calendar noleap --nsteps 720 \\ + --q-vars uo,vo + +--nsteps is the number of 1-hour steps to simulate (e.g. 720 = 30 days). +If --nz is absent or 0, all variables are 2-D (y, x). +If --nz > 0, all variables are 3-D (z, y, x). 
def parse_base_date(base_date_str: str | int) -> datetime:
    """Convert an FMS ``base_date`` value into a :class:`datetime`.

    The value is stringified first, so a scalar that YAML decoded as a
    non-string is still handled (and rejected if it does not split into six
    fields).

    Args:
        base_date_str: Six whitespace-separated integers in the order
            year, month, day, hour, minute, second.

    Returns:
        The datetime described by the six fields.

    Raises:
        ValueError: If the value does not split into exactly six fields, or
            if a field is not an integer / not a valid date component.
    """
    # str.split() with no separator already ignores leading/trailing blanks.
    fields = str(base_date_str).split()
    if len(fields) == 6:
        return datetime(*map(int, fields))
    raise ValueError(
        f"base_date must have 6 space-separated fields, got: {base_date_str!r}"
    )
def collect_stems(diag_table: dict) -> list[str]:
    """List the ``file_name`` of each ``diag_files`` entry, in file order.

    Args:
        diag_table: Parsed diag_table.yaml content as returned by
            yaml.safe_load.

    Returns:
        One file_name string per diag_files entry; an empty list when the
        table has no ``diag_files`` key.

    Raises:
        KeyError: If an entry has no ``file_name`` key.
    """
    stems: list[str] = []
    for entry in diag_table.get("diag_files", []):
        stems.append(entry["file_name"])
    return stems
def _register_cubed_sphere_axes(
    ntile: int, nz: int, domain_id: int
) -> dict[str, int | None]:
    """Register x, y and optionally z diag axes for a cubed-sphere grid.

    The horizontal axes are registered with ``tile_count=6`` and the cubic
    mosaic ``domain_id``; the vertical axis (when requested) is registered
    without a domain and flagged ``not_xy``.

    Args:
        ntile: Tile grid size; both horizontal axes get ``ntile`` points
            numbered 0 .. ntile-1.
        nz: Number of vertical levels; 0 (or less) means no z axis.
        domain_id: Cubic mosaic domain id from define_cubic_mosaic.

    Returns:
        Dict with keys 'x', 'y' and 'z' mapping to the axis ids returned by
        axis_init. 'z' is None when no vertical axis was registered, which is
        why the value type is ``int | None`` (same contract as
        _register_tripolar_axes).
    """
    id_x = pyfms.diag_manager.axis_init(
        name="x",
        axis_data=np.arange(ntile, dtype=np.float64),
        units="point_E",
        cart_name="x",
        domain_id=domain_id,
        long_name="point_E",
        set_name="atm",
        tile_count=6,
    )
    id_y = pyfms.diag_manager.axis_init(
        name="y",
        axis_data=np.arange(ntile, dtype=np.float64),
        units="point_N",
        cart_name="y",
        domain_id=domain_id,
        long_name="point_N",
        set_name="atm",
        tile_count=6,
    )

    # The vertical axis is optional: callers treat a None id as "2-D output".
    id_z: int | None = None
    if nz > 0:
        id_z = pyfms.diag_manager.axis_init(
            name="z",
            axis_data=np.arange(nz, dtype=np.float64),
            units="point_Z",
            cart_name="z",
            long_name="point_Z",
            set_name="atm",
            not_xy=True,
        )
    return {"x": id_x, "y": id_y, "z": id_z}
def _register_fields(
    vars_list: list[tuple[str, str, str]],
    h_axes: list[int],
    q_axes: list[int],
    q_var_names: set[str],
    start_time: datetime,
) -> list[tuple[int, str]]:
    """Register every diagnostic field with the diag manager.

    Each variable is registered with its own name as long_name, units
    "none", a missing value of -99.99 and a valid range of [-1e6, 1e6]
    expressed in the variable's dtype.

    Args:
        vars_list: (module_name, var_name, dtype) tuples to register.
        h_axes: Axis ids used for every variable not listed in q_var_names.
        q_axes: Axis ids used for variables listed in q_var_names.
        q_var_names: var_names that go on the q-grid axes.
        start_time: Passed to the diag manager as the field's init_time.

    Returns:
        (field_id, dtype_str) pairs, one per entry of vars_list and in the
        same order.
    """
    registered: list[tuple[int, str]] = []
    for module_name, var_name, dtype in vars_list:
        # Copy the chosen axis list so each registration gets its own list.
        if var_name in q_var_names:
            axis_ids = list(q_axes)
        else:
            axis_ids = list(h_axes)
        valid_range = np.array([-1e6, 1e6], dtype=dtype)
        field_id = pyfms.diag_manager.register_field_array(
            module_name=module_name,
            field_name=var_name,
            dtype=dtype,
            axes=axis_ids,
            long_name=var_name,
            units="none",
            missing_value=-99.99,
            range_data=valid_range,
            init_time=start_time,
        )
        registered.append((field_id, dtype))
    return registered
def main() -> None:
    """Entry point for the generate-history CLI tool.

    Work is done in this order:

    1. Parse and validate the command line and the diag_table (no FMS calls
       yet, so any input error exits before FMS is initialised).
    2. Stage ``diag_table.yaml`` and a minimal ``input.nml`` in the output
       directory and change into it.
    3. Initialise FMS, build the domain, register axes and fields.
    4. Run the time loop, then shut down the diag manager and FMS.
    5. On the root PE only, write grid-spec files and stamp the history
       files for cubed-sphere / tripolar grids.

    Raises:
        ValueError: If the diag_table lists no variables or its base_date is
            malformed.
        SystemExit: Via argparse for invalid command-line arguments or a
            diag_table that is not a mapping with a ``base_date`` entry.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("diag_table", help="Path to diag_table.yaml")
    parser.add_argument(
        "--grid-type",
        default="regular",
        choices=["regular", "cubed-sphere", "tripolar"],
        help="Grid type: regular (default), cubed-sphere, or tripolar",
    )
    parser.add_argument("--nx", type=int, help="Grid points in x (regular / tripolar)")
    parser.add_argument("--ny", type=int, help="Grid points in y (regular / tripolar)")
    parser.add_argument(
        "--ntile",
        type=int,
        help="Tile size for cubed-sphere (e.g. 96 for C96)",
    )
    parser.add_argument(
        "--nz",
        type=int,
        default=0,
        help="Vertical levels (0 = all vars 2-D)",
    )
    parser.add_argument(
        "--calendar",
        default="NOLEAP",
        type=str.upper,
        choices=list(CALENDAR_MAP),
    )
    parser.add_argument(
        "--nsteps",
        type=int,
        required=True,
        help="Number of 1-hour steps to simulate",
    )
    parser.add_argument(
        "--output-dir",
        default="./output",
        help="Directory for output files (default: ./output)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="Random seed for reproducible data (default: 0)",
    )
    parser.add_argument(
        "--q-vars",
        default="",
        help="Comma-separated var_names to assign to q-grid (tripolar only)",
    )
    args = parser.parse_args()

    grid_type = args.grid_type
    nz = args.nz
    three_d = nz > 0
    timestep = timedelta(seconds=INTERNAL_TIMESTEP_SECONDS)
    rng = np.random.default_rng(args.seed)
    q_var_names: set[str] = {
        name.strip() for name in args.q_vars.split(",") if name.strip()
    }

    # -----------------------------------------------------------------------
    # Argument validation (before any FMS call)
    # -----------------------------------------------------------------------
    if args.nsteps < 1:
        parser.error("--nsteps must be a positive integer")
    if grid_type == "cubed-sphere":
        if args.ntile is None:
            parser.error("--ntile is required for --grid-type cubed-sphere")
        if args.ntile < 1:
            parser.error("--ntile must be a positive integer")
        ntile = args.ntile
        nx = ny = ntile
    else:
        if args.nx is None or args.ny is None:
            parser.error("--nx and --ny are required for regular/tripolar grid types")
        if args.nx < 1 or args.ny < 1:
            parser.error("--nx and --ny must be positive integers")
        nx, ny = args.nx, args.ny
        ntile = None

    with open(args.diag_table, encoding="utf-8") as fh:
        diag_table = yaml.safe_load(fh)
    # An empty file loads as None; report that as a usage error rather than
    # a bare TypeError/KeyError on the base_date lookup.
    if not isinstance(diag_table, dict) or "base_date" not in diag_table:
        parser.error(
            f"{args.diag_table}: expected a YAML mapping with a base_date entry"
        )

    start_time = parse_base_date(diag_table["base_date"])
    end_time = start_time + timestep * args.nsteps

    # Checked here, before FMS is initialised, so that an empty table does
    # not abort the process with FMS still running.
    vars_list = collect_vars(diag_table)
    if not vars_list:
        raise ValueError("No variables found in diag_table.yaml")

    # -----------------------------------------------------------------------
    # Stage run directory
    # -----------------------------------------------------------------------
    outdir = Path(args.output_dir).resolve()
    outdir.mkdir(parents=True, exist_ok=True)
    source_table = Path(args.diag_table).resolve()
    staged_table = outdir / "diag_table.yaml"
    # shutil.copy raises SameFileError when source and destination are the
    # same file (e.g. --output-dir . with ./diag_table.yaml); in that case
    # the table is already in place.
    if not (staged_table.exists() and staged_table.samefile(source_table)):
        shutil.copy(source_table, staged_table)
    (outdir / "input.nml").write_text(
        "&diag_manager_nml\n use_modern_diag = .true.\n/\n"
    )

    os.chdir(outdir)

    pyfms.fms.init(calendar_type=CALENDAR_MAP[args.calendar]())

    # -----------------------------------------------------------------------
    # Domain setup
    # -----------------------------------------------------------------------
    if grid_type == "cubed-sphere":
        domain_id, compute = _setup_cubed_sphere_domain(ntile)
    else:
        domain_id, compute = _setup_regular_domain(nx, ny)

    # NOTE(review): the tool's README suggests the current domain may need to
    # be set before diag_manager.init for cubed-sphere output; the original
    # call order is kept here — confirm against FMS behaviour.
    pyfms.diag_manager.init(diag_model_subset=pyfms.diag_manager.DIAG_ALL)
    pyfms.mpp_domains.set_current_domain(domain_id=domain_id)

    # -----------------------------------------------------------------------
    # Axis registration
    # -----------------------------------------------------------------------
    if grid_type == "cubed-sphere":
        axes_dict = _register_cubed_sphere_axes(ntile, nz, domain_id)
        xy_axes = [axes_dict["x"], axes_dict["y"]]
        xy_q_axes = xy_axes
    elif grid_type == "tripolar":
        axes_dict = _register_tripolar_axes(nx, ny, nz, domain_id)
        xy_axes = [axes_dict["xh"], axes_dict["yh"]]  # default h-axes
        xy_q_axes = [axes_dict["xq"], axes_dict["yq"]]
    else:
        axes_dict = _register_regular_axes(nx, ny, nz, domain_id)
        xy_axes = [axes_dict["x"], axes_dict["y"]]
        xy_q_axes = xy_axes

    # Only the tripolar grid has distinct q-point axes; for the other grid
    # types q_axes is the same list as h_axes.
    if three_d and axes_dict["z"] is not None:
        h_axes = xy_axes + [axes_dict["z"]]
        q_axes = xy_q_axes + [axes_dict["z"]] if grid_type == "tripolar" else h_axes
    else:
        h_axes = xy_axes
        q_axes = xy_q_axes if grid_type == "tripolar" else h_axes

    # -----------------------------------------------------------------------
    # Field registration
    # -----------------------------------------------------------------------
    field_ids = _register_fields(vars_list, h_axes, q_axes, q_var_names, start_time)

    pyfms.diag_manager.set_time_end(end_time)

    # -----------------------------------------------------------------------
    # Local array shape (per-PE compute domain)
    # -----------------------------------------------------------------------
    isize = compute["iec"] - compute["isc"] + 1
    jsize = compute["jec"] - compute["jsc"] + 1
    local_shape: tuple[int, ...] = (isize, jsize, nz) if three_d else (isize, jsize)

    # -----------------------------------------------------------------------
    # Time loop
    # -----------------------------------------------------------------------
    _run_time_loop(field_ids, local_shape, args.nsteps, start_time, timestep, rng)

    # Ask for the rank while FMS is still initialised; the post-processing
    # itself stays after fms.end() as in the original ordering.
    is_root = pyfms.mpp.pe() == pyfms.mpp.root_pe()

    pyfms.diag_manager.end(end_time)
    pyfms.fms.end()

    # -----------------------------------------------------------------------
    # Post-processing: grid spec files and global attributes (root PE only)
    # -----------------------------------------------------------------------
    if is_root:
        stems = collect_stems(diag_table)

        if grid_type == "cubed-sphere":
            write_cubed_sphere_gridspec(outdir, ntile)
            stamp_cubed_sphere_history(outdir, ntile, stems)

        elif grid_type == "tripolar":
            write_tripolar_gridspec(outdir, nx, ny)
            stamp_tripolar_history(outdir, stems)

        print(f"Output written to: {outdir}")
+ +Two grid types are supported: + cubed-sphere: C{ntile}_mosaic.nc + C{ntile}_grid.tile{N}.nc + C{ntile}_scrip.nc + tripolar: ocean_mosaic.nc + ocean_hgrid.nc + ocean_scrip.nc +""" + +from __future__ import annotations + +from pathlib import Path + +import netCDF4 as nc +import numpy as np +from numpy.typing import NDArray + + +EARTH_RADIUS = 6.371e6 + + +# --------------------------------------------------------------------------- +# Cubed-sphere geometry +# --------------------------------------------------------------------------- + + +def _face_vectors(tile: int) -> tuple[NDArray, NDArray, NDArray]: + """Return the face-normal (N), right-tangent (R), and up-tangent (U) unit + vectors for one of the 6 FMS gnomonic equal-angle cube faces. + + The tile ordering and orientation follow the GFDL/FMS convention produced + by ``make_hgrid --grid_type gnomonic_ed``: + tile 1 — equatorial, lon_center=350°; x=east, y=north + tile 2 — equatorial, lon_center=80°; x=east, y=north + tile 3 — north polar cap; x=toward 170°, y=toward 260° + tile 4 — equatorial, lon_center=170°; x=south (−z), y=toward 260° + tile 5 — equatorial, lon_center=260°; x=south (−z), y=toward 350° + tile 6 — south polar cap; x=toward 80°, y=toward 350° + + Tiles 4 and 5 have the x-axis pointing southward (−z) so that their + edges align correctly with tiles 1, 2, 3, and 6 at the cube contacts. + + The contact connectivity in the C{N}_mosaic.nc files produced by this + module is only valid for this specific orientation. + + Args: + tile: Integer in [1, 6]. + + Returns: + Tuple of three (3,) float64 unit vectors (N, R, U). 
+ """ + if tile == 1: + lon = np.radians(350.0) + N = np.array([np.cos(lon), np.sin(lon), 0.0]) + R = np.array([-np.sin(lon), np.cos(lon), 0.0]) # east at lon=350° + U = np.array([0.0, 0.0, 1.0]) + elif tile == 2: + lon = np.radians(80.0) + N = np.array([np.cos(lon), np.sin(lon), 0.0]) + R = np.array([-np.sin(lon), np.cos(lon), 0.0]) # east at lon=80° + U = np.array([0.0, 0.0, 1.0]) + elif tile == 3: + N = np.array([0.0, 0.0, 1.0]) + R = np.array([np.cos(np.radians(170.0)), np.sin(np.radians(170.0)), 0.0]) + U = np.array([np.cos(np.radians(260.0)), np.sin(np.radians(260.0)), 0.0]) + elif tile == 4: + lon = np.radians(170.0) + N = np.array([np.cos(lon), np.sin(lon), 0.0]) + R = np.array([0.0, 0.0, -1.0]) # south (−z) + U = np.array([-np.sin(lon), np.cos(lon), 0.0]) # east at lon=170° → toward 260° + elif tile == 5: + lon = np.radians(260.0) + N = np.array([np.cos(lon), np.sin(lon), 0.0]) + R = np.array([0.0, 0.0, -1.0]) # south (−z) + U = np.array([-np.sin(lon), np.cos(lon), 0.0]) # east at lon=260° → toward 350° + elif tile == 6: + N = np.array([0.0, 0.0, -1.0]) + R = np.array([np.cos(np.radians(80.0)), np.sin(np.radians(80.0)), 0.0]) + U = np.array([np.cos(np.radians(350.0)), np.sin(np.radians(350.0)), 0.0]) + else: + raise ValueError(f"tile must be 1–6, got {tile}") + return N, R, U + + +def _gnomonic_tile_latlon(ntile: int, tile: int) -> tuple[NDArray, NDArray]: + """Compute gnomonic equal-angle supergrid lat/lon for one cube face. + + The supergrid has (2*ntile+1, 2*ntile+1) points — cell centers and corners + interleaved. The equal-angle parameterisation maps the angular interval + [-π/4, π/4] uniformly in both directions. + + Args: + ntile: Model grid size per tile (e.g. 96 for C96). + tile: Integer in [1, 6]. + + Returns: + Tuple (lat, lon) each of shape (2*ntile+1, 2*ntile+1) in degrees. + lon is in [0°, 360°). 
+ """ + sn = 2 * ntile + ang = np.linspace(-np.pi / 4, np.pi / 4, sn + 1) + N, R, U = _face_vectors(tile) + A, B = np.meshgrid(ang, ang) # (sn+1, sn+1); A varies along i, B along j + ta = np.tan(A)[..., np.newaxis] + tb = np.tan(B)[..., np.newaxis] + P = N + ta * R + tb * U # (sn+1, sn+1, 3) + r = np.linalg.norm(P, axis=-1, keepdims=True) + P = P / r + lat = np.degrees(np.arcsin(np.clip(P[..., 2], -1.0, 1.0))) + lon = np.degrees(np.arctan2(P[..., 1], P[..., 0])) % 360.0 + return lat, lon + + +# --------------------------------------------------------------------------- +# Shared geometry helpers +# --------------------------------------------------------------------------- + + +def _haversine(lat1: NDArray, lon1: NDArray, lat2: NDArray, lon2: NDArray) -> NDArray: + """Great circle distance in metres between arrays of (lat, lon) pairs. + + Args: + lat1, lon1: Starting coordinates in degrees. + lat2, lon2: Ending coordinates in degrees. + + Returns: + Array of distances in metres, same shape as inputs. + """ + d_lat = np.radians(lat2 - lat1) + d_lon = np.radians(lon2 - lon1) + lat1_r = np.radians(lat1) + lat2_r = np.radians(lat2) + a = ( + np.sin(d_lat / 2) ** 2 + + np.cos(lat1_r) * np.cos(lat2_r) * np.sin(d_lon / 2) ** 2 + ) + return 2.0 * EARTH_RADIUS * np.arcsin(np.sqrt(np.clip(a, 0.0, 1.0))) + + +def _to_cartesian(lat_deg: NDArray, lon_deg: NDArray) -> NDArray: + """Convert (lat, lon) degrees to 3-D Cartesian unit vectors. + + Args: + lat_deg: Latitude array in degrees. + lon_deg: Longitude array in degrees. + + Returns: + Array of shape (*lat_deg.shape, 3). + """ + lat = np.radians(lat_deg) + lon = np.radians(lon_deg) + return np.stack( + [np.cos(lat) * np.cos(lon), np.cos(lat) * np.sin(lon), np.sin(lat)], axis=-1 + ) + + +def _triangle_solid_angle(a: NDArray, b: NDArray, c: NDArray) -> NDArray: + """Solid angle (steradians) of a spherical triangle with vertices a, b, c. + + Uses the formula: Omega = 2 * arctan( |a·(b×c)| / (1 + a·b + b·c + a·c) ). 
+ + Args: + a, b, c: Arrays of shape (..., 3) — Cartesian unit vectors. + + Returns: + Array of solid angles in steradians, shape (...,). + """ + bxc = np.cross(b, c) + num = np.abs(np.einsum("...i,...i->...", a, bxc)) + den = 1.0 + ( + np.einsum("...i,...i->...", a, b) + + np.einsum("...i,...i->...", b, c) + + np.einsum("...i,...i->...", a, c) + ) + return 2.0 * np.arctan2(num, den) + + +def _cell_areas(lat: NDArray, lon: NDArray) -> NDArray: + """Cell areas in m² for a supergrid lat/lon array. + + Splits each quadrilateral cell into two triangles and sums + spherical excess areas (Girard's theorem). + + Args: + lat, lon: Arrays of shape (M+1, N+1) — supergrid corners. + + Returns: + Array of shape (M, N) — one area per cell in m². + """ + v = _to_cartesian(lat, lon) + v1 = v[:-1, :-1] # bottom-left + v2 = v[:-1, 1:] # bottom-right + v3 = v[1:, 1:] # top-right + v4 = v[1:, :-1] # top-left + a1 = _triangle_solid_angle(v1, v2, v3) + a2 = _triangle_solid_angle(v1, v3, v4) + return (a1 + a2) * EARTH_RADIUS ** 2 + + +def _dx_dy(lat: NDArray, lon: NDArray) -> tuple[NDArray, NDArray]: + """Edge lengths in metres on a supergrid. + + Args: + lat, lon: Supergrid arrays of shape (nyp, nxp). + + Returns: + Tuple (dx, dy): + dx — shape (nyp, nx) — zonal edge lengths (i-direction) + dy — shape (ny, nxp) — meridional edge lengths (j-direction) + """ + dx = _haversine(lat[:, :-1], lon[:, :-1], lat[:, 1:], lon[:, 1:]) + dy = _haversine(lat[:-1, :], lon[:-1, :], lat[1:, :], lon[1:, :]) + return dx, dy + + +def _angle_dx(lat: NDArray, lon: NDArray) -> NDArray: + """Approximate grid rotation angle (degrees east) at every supergrid point. + + Defined as the angle between the local x-axis of the grid and geographic + east. Computed from the bearing between adjacent x-direction points. + + Args: + lat, lon: Supergrid arrays of shape (nyp, nxp). + + Returns: + Array of shape (nyp, nxp) in degrees. 
+ """ + # Forward bearing at each interior/edge point (central/forward difference) + lat_r = np.radians(lat) + lon_r = np.radians(lon) + # Use central difference except at right boundary + dlat = np.diff(lat_r, axis=1) # (nyp, nxp-1) + dlon = np.diff(lon_r, axis=1) + lat_mid = (lat_r[:, :-1] + lat_r[:, 1:]) / 2 + bearing = np.degrees( + np.arctan2( + np.sin(dlon) * np.cos(lat_r[:, 1:]), + np.cos(lat_mid) * np.sin(lat_r[:, 1:]) + - np.sin(lat_mid) * np.cos(lat_r[:, 1:]) * np.cos(dlon), + ) + ) # (nyp, nxp-1) + # Pad right edge by repeating the last column + angle = np.concatenate([bearing, bearing[:, -1:]], axis=1) + return angle + + +# --------------------------------------------------------------------------- +# Mosaic file writers +# --------------------------------------------------------------------------- + + +def _write_char_var(ncvar: nc.Variable, value: str) -> None: + """Write a string into a (string,) char variable, null-padded.""" + n = ncvar.shape[0] + b = value.encode("ascii") + arr = np.zeros(n, dtype="S1") + arr[: len(b)] = np.frombuffer(b, dtype="S1") + ncvar[:] = arr + + +def _write_char_array(ncvar: nc.Variable, values: list[str]) -> None: + """Write strings into a (n, string) char variable, null-padded.""" + n = ncvar.shape[1] + for i, v in enumerate(values): + b = v.encode("ascii") + arr = np.zeros(n, dtype="S1") + arr[: len(b)] = np.frombuffer(b, dtype="S1") + ncvar[i, :] = arr + + +def _write_grid_tile_nc( + path: Path, lat: NDArray, lon: NDArray, *, projection: str = "cube_gnomonic" +) -> None: + """Write a single grid tile NetCDF file (FMS hgrid format). + + Args: + path: Output file path. + lat, lon: Supergrid arrays of shape (nyp, nxp) in degrees. + projection: Value for the ``tile:projection`` attribute. 
def _write_mosaic_nc(
    path: Path,
    mosaic_name: str,
    gridlocation: str,
    gridfiles: list[str],
    gridtiles: list[str],
    contacts: list[str],
    contact_index: list[str],
) -> None:
    """Write a FMS mosaic descriptor NetCDF file.

    Args:
        path: Output file path.
        mosaic_name: Value of the ``mosaic`` variable.
        gridlocation: Directory containing grid tile files.
        gridfiles: List of tile filename strings.
        gridtiles: List of tile name strings (e.g. ["tile1", ...]).
        contacts: List of contact specifier strings.
        contact_index: List of contact index strings (same length as contacts).

    Raises:
        ValueError: If ``gridtiles`` and ``gridfiles``, or ``contact_index``
            and ``contacts``, differ in length.
    """
    ntiles = len(gridfiles)
    ncontact = len(contacts)
    # Check the parallel lists before touching the filesystem so a bad call
    # does not leave a half-written mosaic file behind.
    if len(gridtiles) != ntiles:
        raise ValueError(
            f"gridtiles has {len(gridtiles)} entries but gridfiles has {ntiles}"
        )
    if len(contact_index) != ncontact:
        raise ValueError(
            f"contact_index has {len(contact_index)} entries but contacts "
            f"has {ncontact}"
        )

    with nc.Dataset(path, "w", format="NETCDF3_64BIT_OFFSET") as ds:
        ds.createDimension("ntiles", ntiles)
        ds.createDimension("ncontact", ncontact)
        ds.createDimension("string", 255)

        vm = ds.createVariable("mosaic", "c", ("string",))
        vm.standard_name = "grid_mosaic_spec"
        vm.children = "gridtiles"
        vm.contact_regions = "contacts"
        vm.grid_descriptor = ""
        _write_char_var(vm, mosaic_name)

        vgl = ds.createVariable("gridlocation", "c", ("string",))
        vgl.standard_name = "grid_file_location"
        _write_char_var(vgl, gridlocation)

        vgf = ds.createVariable("gridfiles", "c", ("ntiles", "string"))
        _write_char_array(vgf, gridfiles)

        vgt = ds.createVariable("gridtiles", "c", ("ntiles", "string"))
        _write_char_array(vgt, gridtiles)

        vc = ds.createVariable("contacts", "c", ("ncontact", "string"))
        vc.standard_name = "grid_contact_spec"
        vc.contact_type = "boundary"
        vc.alignment = "true"
        vc.contact_index = "contact_index"
        vc.orientation = "orient"
        _write_char_array(vc, contacts)

        vci = ds.createVariable("contact_index", "c", ("ncontact", "string"))
        vci.standard_name = "starting_ending_point_index_of_contact"
        _write_char_array(vci, contact_index)

        ds.grid_version = "0.2"


def _write_scrip_nc(path: Path, lat: NDArray, lon: NDArray) -> None:
    """Write a SCRIP-format grid file derived from a supergrid.

    Cell centres are the odd-indexed supergrid points ``(2*iy+1, 2*ix+1)``.
    The four corners of cell ``(iy, ix)`` are the surrounding even-indexed
    points, stored in the order ``(2*iy, 2*ix)``, ``(2*iy, 2*ix+2)``,
    ``(2*iy+2, 2*ix+2)``, ``(2*iy+2, 2*ix)``.  Cells are flattened row by row,
    so the x index varies fastest in every ``grid_size`` array.

    Args:
        path: Output file path.
        lat, lon: Supergrid arrays of shape (nyp, nxp) in degrees,
            where nyp = 2*ny+1, nxp = 2*nx+1.

    Raises:
        ValueError: If *lat* and *lon* differ in shape, or the shape is not
            ``(2*ny+1, 2*nx+1)`` with ``ny, nx >= 1``.
    """
    if lat.shape != lon.shape:
        raise ValueError(
            f"lat and lon must share a shape, got {lat.shape} and {lon.shape}"
        )
    nyp, nxp = lat.shape
    if nyp < 3 or nxp < 3 or nyp % 2 == 0 or nxp % 2 == 0:
        raise ValueError(
            f"supergrid shape must be (2*ny+1, 2*nx+1) with ny, nx >= 1, "
            f"got {(nyp, nxp)}"
        )
    ny, nx = (nyp - 1) // 2, (nxp - 1) // 2
    grid_size = ny * nx

    # Centres sit on the odd supergrid rows/columns.
    lat_ctr = lat[1::2, 1::2].ravel()  # (grid_size,)
    lon_ctr = lon[1::2, 1::2].ravel()

    def _corners(field: NDArray) -> NDArray:
        """Gather the four even-indexed corner values of every cell."""
        return np.stack(
            [
                field[:-1:2, :-1:2].ravel(),  # (2*iy,   2*ix)
                field[:-1:2, 2::2].ravel(),  # (2*iy,   2*ix+2)
                field[2::2, 2::2].ravel(),  # (2*iy+2, 2*ix+2)
                field[2::2, :-1:2].ravel(),  # (2*iy+2, 2*ix)
            ],
            axis=-1,
        )  # (grid_size, 4)

    lat_corners = _corners(lat)
    lon_corners = _corners(lon)

    with nc.Dataset(path, "w", format="NETCDF3_64BIT_OFFSET") as ds:
        ds.createDimension("grid_size", grid_size)
        ds.createDimension("grid_corners", 4)
        ds.createDimension("grid_rank", 2)

        vdims = ds.createVariable("grid_dims", "i4", ("grid_rank",))
        # SCRIP lists the fastest-varying dimension first; the flattened
        # arrays above run along x fastest, so x comes first.
        vdims[:] = [nx, ny]

        vcl = ds.createVariable("grid_center_lat", "f8", ("grid_size",))
        vcl.units = "degrees"
        vcl[:] = lat_ctr

        vclo = ds.createVariable("grid_center_lon", "f8", ("grid_size",))
        vclo.units = "degrees"
        vclo[:] = lon_ctr

        vkl = ds.createVariable("grid_corner_lat", "f8", ("grid_size", "grid_corners"))
        vkl.units = "degrees"
        vkl[:] = lat_corners

        vklo = ds.createVariable("grid_corner_lon", "f8", ("grid_size", "grid_corners"))
        vklo.units = "degrees"
        vklo[:] = lon_corners

        vmask = ds.createVariable("grid_imask", "i4", ("grid_size",))
        vmask.units = "unitless"
        vmask[:] = np.ones(grid_size, dtype=np.int32)

        ds.title = "SCRIP grid file"
        ds.conventions = "SCRIP"
def write_cubed_sphere_gridspec(outdir: Path | str, ntile: int) -> None:
    """Write a complete cubed-sphere FMS grid spec to *outdir*.

    Files produced, in this order:
        C{ntile}_grid.tile{1..6}.nc  — per-tile supergrid files
        C{ntile}_mosaic.nc           — FMS mosaic descriptor (for fregrid)
        C{ntile}_scrip.nc            — SCRIP file (for ESMF/esmpy, bonus)

    Tile coordinates come from ``_gnomonic_tile_latlon``.  The mosaic lists
    the twelve tile-to-tile contacts, with index ranges expressed on the
    supergrid (``2 * ntile`` points per tile side).

    Args:
        outdir: Directory where files will be written (must exist).
        ntile: Model grid size per tile side (e.g. 96 for C96).
    """
    target_dir = Path(outdir)
    s = 2 * ntile  # supergrid size per tile side
    mosaic = f"C{ntile}_mosaic"
    tile_ids = range(1, 7)
    tile_files = [f"C{ntile}_grid.tile{tile_id}.nc" for tile_id in tile_ids]

    tile_lats = []
    tile_lons = []
    for tile_id, tile_file in zip(tile_ids, tile_files):
        tile_lat, tile_lon = _gnomonic_tile_latlon(ntile, tile_id)
        tile_lats.append(tile_lat)
        tile_lons.append(tile_lon)
        _write_grid_tile_nc(
            target_dir / tile_file, tile_lat, tile_lon, projection="cube_gnomonic"
        )

    # One row per contact: (first tile, second tile, supergrid index ranges).
    seams = (
        (1, 2, f"{s}:{s},1:{s}::1:1,1:{s}"),
        (1, 3, f"1:{s},{s}:{s}::1:1,{s}:1"),
        (1, 5, f"1:1,1:{s}::{s}:1,{s}:{s}"),
        (1, 6, f"1:{s},1:1::1:{s},{s}:{s}"),
        (2, 3, f"1:{s},{s}:{s}::1:{s},1:1"),
        (2, 4, f"{s}:{s},1:{s}::{s}:1,1:1"),
        (2, 6, f"1:{s},1:1::{s}:{s},{s}:1"),
        (3, 4, f"{s}:{s},1:{s}::1:1,1:{s}"),
        (3, 5, f"1:{s},{s}:{s}::1:1,{s}:1"),
        (4, 5, f"1:{s},{s}:{s}::1:{s},1:1"),
        (4, 6, f"{s}:{s},1:{s}::{s}:1,1:1"),
        (5, 6, f"{s}:{s},1:{s}::1:1,1:{s}"),
    )
    contacts = [
        f"{mosaic}:tile{first}::{mosaic}:tile{second}" for first, second, _ in seams
    ]
    contact_index = [index_spec for _, _, index_spec in seams]

    _write_mosaic_nc(
        target_dir / f"{mosaic}.nc",
        mosaic_name=mosaic,
        gridlocation="./",
        gridfiles=tile_files,
        gridtiles=[f"tile{tile_id}" for tile_id in tile_ids],
        contacts=contacts,
        contact_index=contact_index,
    )

    # SCRIP: all six tiles go into one flat file.
    _write_cubed_sphere_scrip(
        target_dir / f"C{ntile}_scrip.nc", tile_lats, tile_lons, ntile
    )
+ """ + sn = 2 * ntile + cells_per_tile = ntile * ntile + grid_size = 6 * cells_per_tile + + lat_ctr = np.empty(grid_size) + lon_ctr = np.empty(grid_size) + lat_corners = np.empty((grid_size, 4)) + lon_corners = np.empty((grid_size, 4)) + + for t, (lat, lon) in enumerate(zip(all_lat, all_lon)): + start = t * cells_per_tile + end = start + cells_per_tile + + lat_ctr[start:end] = lat[1::2, 1::2].ravel() + lon_ctr[start:end] = lon[1::2, 1::2].ravel() + + iy = np.arange(ntile) + ix = np.arange(ntile) + IY, IX = np.meshgrid(iy, ix, indexing="ij") + si = 2 * IY + sj = 2 * IX + + lat_corners[start:end, 0] = lat[si, sj].ravel() + lat_corners[start:end, 1] = lat[si, sj + 2].ravel() + lat_corners[start:end, 2] = lat[si + 2, sj + 2].ravel() + lat_corners[start:end, 3] = lat[si + 2, sj].ravel() + lon_corners[start:end, 0] = lon[si, sj].ravel() + lon_corners[start:end, 1] = lon[si, sj + 2].ravel() + lon_corners[start:end, 2] = lon[si + 2, sj + 2].ravel() + lon_corners[start:end, 3] = lon[si + 2, sj].ravel() + + ds = nc.Dataset(path, "w", format="NETCDF3_64BIT_OFFSET") + try: + ds.createDimension("grid_size", grid_size) + ds.createDimension("grid_corners", 4) + ds.createDimension("grid_rank", 2) + + vdims = ds.createVariable("grid_dims", "i4", ("grid_rank",)) + vdims[:] = [6 * ntile, ntile] + + vcl = ds.createVariable("grid_center_lat", "f8", ("grid_size",)) + vcl.units = "degrees" + vcl[:] = lat_ctr + + vclo = ds.createVariable("grid_center_lon", "f8", ("grid_size",)) + vclo.units = "degrees" + vclo[:] = lon_ctr + + vkl = ds.createVariable("grid_corner_lat", "f8", ("grid_size", "grid_corners")) + vkl.units = "degrees" + vkl[:] = lat_corners + + vklo = ds.createVariable("grid_corner_lon", "f8", ("grid_size", "grid_corners")) + vklo.units = "degrees" + vklo[:] = lon_corners + + vmask = ds.createVariable("grid_imask", "i4", ("grid_size",)) + vmask.units = "unitless" + vmask[:] = np.ones(grid_size, dtype=np.int32) + + ds.title = f"Cubed-sphere C{ntile} SCRIP grid" + ds.conventions 
def _tripolar_supergrid(
    nx: int, ny: int, lat_south: float = -80.0, lat_bp: float = 65.0
) -> tuple[NDArray, NDArray]:
    """Generate a simplified tripolar ocean supergrid.

    South of *lat_bp* the grid is a plain longitude/latitude grid whose rows
    are uniformly spaced between *lat_south* and *lat_bp*.  From the row at
    *lat_bp* up to the pole the longitudes are "folded": the eastern half
    mirrors the western half, and each row is scaled linearly toward 0° so
    that the top row collapses onto longitude 0.  Rows in this cap are
    uniformly spaced between *lat_bp* and 90°.  Supergrid rows are shared
    between the two regions in proportion to their latitude extent, with at
    least two rows in the regular region.

    Args:
        nx: Number of model grid cells in x.
        ny: Number of model grid cells in y.
        lat_south: Southernmost latitude of the domain.
        lat_bp: Latitude of the bipolar fold join.

    Returns:
        Tuple (lat, lon) each of shape (2*ny+1, 2*nx+1) in degrees.
        lon is in [-180°, 180°].

    Raises:
        ValueError: If ``nx`` or ``ny`` is smaller than 1, or the latitudes
            do not satisfy ``-90 <= lat_south < lat_bp <= 90``.
    """
    if nx < 1 or ny < 1:
        raise ValueError(f"nx and ny must be >= 1, got nx={nx}, ny={ny}")
    if not -90.0 <= lat_south < lat_bp <= 90.0:
        raise ValueError(
            "latitudes must satisfy -90 <= lat_south < lat_bp <= 90, "
            f"got lat_south={lat_south}, lat_bp={lat_bp}"
        )

    sny = 2 * ny
    snx = 2 * nx

    # x-axis: uniform longitude spanning the full 360°.
    lon_1d = np.linspace(-180.0, 180.0, snx + 1, endpoint=True)

    # y-axis: number of supergrid rows in the regular region and in the cap,
    # proportional to the latitude range each one covers.
    lat_range_total = 90.0 - lat_south
    lat_range_reg = lat_bp - lat_south
    n_reg = max(2, int(round(sny * lat_range_reg / lat_range_total)))
    n_bp = sny - n_reg

    lat_reg = np.linspace(lat_south, lat_bp, n_reg + 1)
    lat_cap = np.linspace(lat_bp, 90.0, n_bp + 1)
    # Drop the duplicated lat_bp row when joining the two regions.
    lat_1d = np.concatenate([lat_reg, lat_cap[1:]])  # (sny+1,)

    LON, LAT = np.meshgrid(lon_1d, lat_1d)  # (sny+1, snx+1)

    # Folded longitude: the right half mirrors the left half about the
    # centre column.  It is the same for every cap row, so build it once.
    half = snx // 2
    fold_lon = lon_1d.copy()
    fold_lon[half + 1 :] = lon_1d[:half][::-1]

    # Blend factor per cap row: 0 at the join row, 1 at the pole row.
    frac = np.arange(n_bp + 1) / max(1, n_bp)
    LON[n_reg:, :] = fold_lon + frac[:, np.newaxis] * (0.0 - fold_lon)

    return LAT, LON
# Variables that keep their attributes untouched when history files are stamped.
_HISTORY_TIME_VARS = frozenset(
    {"time", "time_bnds", "average_T1", "average_T2", "average_DT"}
)


def _stamp_dataset(ds, global_attrs: dict[str, str]) -> None:
    """Set *global_attrs* on *ds* and tag every non-time variable.

    Each variable whose name is not in ``_HISTORY_TIME_VARS`` receives
    ``interp_method = "conserve_order1"``.
    """
    for attr_name, attr_value in global_attrs.items():
        setattr(ds, attr_name, attr_value)
    for var_name, variable in ds.variables.items():
        if var_name in _HISTORY_TIME_VARS:
            continue
        variable.interp_method = "conserve_order1"


def stamp_cubed_sphere_history(
    outdir: Path | str, ntile: int, file_stems: list[str]
) -> None:
    """Add cubed-sphere grid attributes to FMS history tile files.

    For each file stem in *file_stems*, opens
    ``{outdir}/{stem}.tile{N}.nc`` (N = 1..6) in append mode and sets:
        ``grid_type = "cubic_mosaic"``
        ``grid_tile = "N"``
        ``associated_files = "area: C{ntile}_mosaic.nc"``
    Every variable other than ``time``, ``time_bnds``, ``average_T1``,
    ``average_T2`` and ``average_DT`` also gets
    ``interp_method = "conserve_order1"``.  Files that do not exist are
    skipped silently.

    Args:
        outdir: Directory containing the history files.
        ntile: Tile size (used to build the mosaic file name).
        file_stems: List of output file name stems (without tile suffix).
    """
    base_dir = Path(outdir)
    for stem in file_stems:
        for tile_id in range(1, 7):
            tile_path = base_dir / f"{stem}.tile{tile_id}.nc"
            if tile_path.exists():
                with nc.Dataset(tile_path, "a") as ds:
                    _stamp_dataset(
                        ds,
                        {
                            "grid_type": "cubic_mosaic",
                            "grid_tile": str(tile_id),
                            "associated_files": f"area: C{ntile}_mosaic.nc",
                        },
                    )


def stamp_tripolar_history(outdir: Path | str, file_stems: list[str]) -> None:
    """Add tripolar grid attributes to FMS ocean history files.

    For each stem, opens ``{outdir}/{stem}.nc`` in append mode and sets
    ``grid_type = "tripolar"`` and
    ``associated_files = "area: ocean_mosaic.nc"``.  Every variable other
    than ``time``, ``time_bnds``, ``average_T1``, ``average_T2`` and
    ``average_DT`` also gets ``interp_method = "conserve_order1"``.  Files
    that do not exist are skipped silently.

    Args:
        outdir: Directory containing the history files.
        file_stems: List of output file name stems.
    """
    base_dir = Path(outdir)
    for stem in file_stems:
        history_path = base_dir / f"{stem}.nc"
        if history_path.exists():
            with nc.Dataset(history_path, "a") as ds:
                _stamp_dataset(
                    ds,
                    {
                        "grid_type": "tripolar",
                        "associated_files": "area: ocean_mosaic.nc",
                    },
                )