|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from typing import Any |
| 4 | + |
| 5 | +import rasterio |
| 6 | +from pystac import Item |
| 7 | +from rasterio.transform import from_gcps |
| 8 | + |
| 9 | +from earthdaily._eds_logging import LoggerConfig |
| 10 | +from earthdaily.datacube._builder import _replace_item_hrefs |
| 11 | +from earthdaily.datacube.constants import DEFAULT_HREF_PATH |
| 12 | + |
| 13 | +logger = LoggerConfig(logger_name=__name__).get_logger() |
| 14 | + |
| 15 | + |
class RasterMetadataEnricher:
    """Enriches STAC items with projection and raster band metadata by reading asset files via rasterio."""

    @staticmethod
    def enrich_item(
        item: Item,
        *,
        force: bool = False,
    ) -> Item:
        """
        Enrich a STAC item with projection and raster band metadata.

        Opens each GeoTIFF asset with rasterio to extract CRS, transform, shape,
        and per-band information, then writes it back into the STAC item structure.

        Parameters
        ----------
        item : Item
            STAC item to enrich. All work is done on a dictionary copy obtained
            from ``item.to_dict``; the result is rebuilt with ``Item.from_dict``.
        force : bool, default False
            If False, skip assets that already carry projection/raster metadata.
            If True, overwrite existing values.

        Returns
        -------
        Item
            A **new** item with enriched metadata.
        """
        item_dict = item.to_dict(transform_hrefs=False)
        # Asset hrefs are rewritten by the project helper before any file is
        # opened (helper and DEFAULT_HREF_PATH are defined elsewhere).
        resolved = _replace_item_hrefs([Item.from_dict(item_dict)], DEFAULT_HREF_PATH)
        item_dict = resolved[0].to_dict(transform_hrefs=False)

        RasterMetadataEnricher._enrich_item(item_dict, force)

        return Item.from_dict(item_dict)

    @staticmethod
    def _enrich_item(item: dict[str, Any], force: bool) -> None:
        """
        Enrich every TIFF asset of an item dictionary in place.

        Assets whose ``type`` does not start with ``image/tiff`` or that have no
        ``href`` are ignored. Assets that are missing projection or raster
        metadata (or all TIFF assets when ``force`` is True) are read through
        ``_process_asset``; fully described assets only inherit missing fields
        from the item properties. The ``gsd``/``proj:epsg``/``proj:code`` values
        of successfully handled assets are then offered to
        ``_update_item_properties``.

        Parameters
        ----------
        item : dict
            STAC item in dictionary form; mutated in place.
        force : bool
            Whether existing metadata may be overwritten.
        """
        gsd_values: list[float] = []
        epsg_values: list[int] = []
        code_values: list[str] = []

        # ``or {}`` also covers keys that are present but set to None.
        item_properties: dict[str, Any] = item.get("properties") or {}
        assets: dict[str, dict] = item.get("assets") or {}

        for asset_key, asset in assets.items():
            # ``type`` is optional and may be None, which has no ``startswith``.
            media_type = asset.get("type") or ""
            if not media_type.startswith("image/tiff") or "href" not in asset:
                continue

            has_projection = bool(asset.get("proj:code") or asset.get("proj:epsg") or asset.get("proj:wkt2"))
            has_raster = bool(asset.get("raster:bands"))

            if force or not (has_projection and has_raster):
                result = RasterMetadataEnricher._process_asset(asset)
                if result is None:
                    # The file could not be read; a warning was already logged.
                    continue

                proj_metadata, raster_metadata = result
                if proj_metadata is None:
                    logger.warning("CRS retrieval failed for item %s, asset %s.", item.get("id"), asset_key)
                    continue

                # Existing band metadata is only replaced when forced.
                raster_data = None if not force and has_raster else raster_metadata
                RasterMetadataEnricher._update_asset_metadata(asset, proj_metadata, raster_data, force)
            else:
                RasterMetadataEnricher._copy_properties_to_asset(asset, item_properties)

            gsd = asset.get("gsd")
            epsg = asset.get("proj:epsg")
            epsg_code = asset.get("proj:code")
            if gsd:
                gsd_values.append(gsd)
            if epsg:
                epsg_values.append(epsg)
            if epsg_code:
                code_values.append(epsg_code)

        RasterMetadataEnricher._update_item_properties(item, gsd_values, epsg_values, code_values, force)

    @staticmethod
    def _band_attribute(dataset: Any, name: str, index: int) -> Any:
        """Return ``getattr(dataset, name)[index]`` or None when the attribute is missing, empty or too short."""
        values = getattr(dataset, name, None)
        if values and index < len(values):
            return values[index]
        return None

    @staticmethod
    def _extract_metadata(dataset: Any) -> tuple[dict | None, dict | None]:
        """
        Build projection and raster-band metadata from an opened dataset.

        Parameters
        ----------
        dataset : Any
            Object exposing ``crs``, ``transform``, ``gcps``, ``height``,
            ``width``, ``count``, ``dtypes``, ``nodata`` and optionally
            ``units``, ``scales`` and ``offsets`` (a rasterio dataset in
            ``_process_asset``).

        Returns
        -------
        tuple
            ``(proj_metadata, raster_metadata)``. Both are None when the dataset
            has no CRS and no usable ground control points.
        """
        import math  # local import keeps this block independent of the file header

        crs = dataset.crs
        transform = dataset.transform

        if crs is None:
            # Fall back to ground control points when the file is not georeferenced.
            gcps, gcp_crs = dataset.gcps
            if not (gcps and gcp_crs):
                return None, None
            transform = from_gcps(gcps)
            crs = gcp_crs

        epsg = crs.to_epsg()
        if epsg:
            crs_info = {"proj:code": f"EPSG:{epsg}", "proj:epsg": epsg}
        else:
            crs_info = {"proj:wkt2": crs.to_wkt()}

        # Pixel width is the length of the transform's column vector (a, d).
        # This equals abs(a) for north-up rasters and stays correct (and
        # non-zero) for rotated or GCP-derived transforms.
        # NOTE(review): the value is in CRS units, so it is presumably degrees
        # for geographic CRSs -- confirm that consumers of "gsd" expect this.
        pixel_size = math.hypot(transform.a, transform.d)

        proj_metadata = {
            "proj:transform": [
                float(transform.a),
                float(transform.b),
                float(transform.c),
                float(transform.d),
                float(transform.e),
                float(transform.f),
            ],
            "proj:shape": [dataset.height, dataset.width],
            "gsd": pixel_size,
            **crs_info,
        }

        # The dataset-level nodata value is applied to every band, so it is
        # validated once instead of once per band.
        # NOTE(review): a float NaN/inf nodata is written as a float here --
        # confirm whether downstream serialisation needs the string spelling.
        nodata = dataset.nodata
        keep_nodata = nodata is not None and (
            isinstance(nodata, (int, float)) or nodata in ("nan", "inf", "-inf")
        )

        bands = []
        for index in range(dataset.count):
            band_data: dict[str, Any] = {
                "data_type": str(dataset.dtypes[index]),
                "spatial_resolution": pixel_size,
            }

            if keep_nodata:
                band_data["nodata"] = nodata

            # Units are skipped when empty; scale/offset only when None, so
            # that legitimate zero values are kept.
            unit = RasterMetadataEnricher._band_attribute(dataset, "units", index)
            if unit:
                band_data["unit"] = unit

            scale = RasterMetadataEnricher._band_attribute(dataset, "scales", index)
            if scale is not None:
                band_data["scale"] = scale

            offset = RasterMetadataEnricher._band_attribute(dataset, "offsets", index)
            if offset is not None:
                band_data["offset"] = offset

            bands.append(band_data)

        raster_metadata = {"raster:bands": bands}
        return proj_metadata, raster_metadata

    @staticmethod
    def _process_asset(asset: dict[str, Any]) -> tuple[dict | None, dict | None] | None:
        """
        Open the asset's ``href`` with rasterio and extract its metadata.

        Returns
        -------
        tuple or None
            The result of ``_extract_metadata``, or None when opening or
            reading the file raised; the failure is logged as a warning
            (deliberate best-effort: one unreadable asset must not abort the
            whole item).
        """
        try:
            with rasterio.open(asset["href"]) as dataset:
                return RasterMetadataEnricher._extract_metadata(dataset)
        except Exception as e:
            logger.warning("Failed to open asset %s: %s", asset.get("href", "<no href>"), e)
            return None

    @staticmethod
    def _update_asset_metadata(
        asset: dict[str, Any],
        proj_metadata: dict | None,
        raster_metadata: dict | None,
        force: bool,
    ) -> None:
        """
        Merge extracted metadata into an asset dictionary in place.

        Projection fields are written key by key; without ``force`` only keys
        missing from the asset are added. Raster metadata is merged only when
        the asset has no ``raster:bands`` key or ``force`` is True.
        """
        if proj_metadata:
            for key, value in proj_metadata.items():
                if force or key not in asset:
                    asset[key] = value

        if raster_metadata and (force or "raster:bands" not in asset):
            asset.update(raster_metadata)

    @staticmethod
    def _copy_properties_to_asset(asset: dict[str, Any], item_properties: dict[str, Any]) -> None:
        """Copy projection/raster fields from the item properties to the asset when the asset lacks them."""
        inherited = (
            "proj:code",
            "proj:epsg",
            "proj:wkt2",
            "proj:transform",
            "proj:shape",
            "gsd",
            "raster:bands",
        )
        for attr in inherited:
            if attr in item_properties and attr not in asset:
                asset[attr] = item_properties[attr]

    @staticmethod
    def _update_item_properties(
        item: dict[str, Any],
        all_gsd_values: list[float],
        all_epsg_values: list[int],
        all_epsg_code_values: list[str],
        force: bool,
    ) -> None:
        """
        Promote asset-level values to the item properties when they are unambiguous.

        A value is written only when all assets agree on it (exactly one
        distinct truthy value). Without ``force`` an existing non-None item
        property is kept.

        Parameters
        ----------
        item : dict
            STAC item in dictionary form; its ``properties`` are mutated in
            place and created when absent.
        all_gsd_values, all_epsg_values, all_epsg_code_values : list
            Values collected from the assets for ``gsd``, ``proj:epsg`` and
            ``proj:code`` respectively.
        force : bool
            Whether existing item properties may be overwritten.
        """
        properties = item.get("properties")
        if properties is None:
            # Attach the dict to the item; writing into a detached default
            # would silently discard every update.
            properties = item["properties"] = {}

        candidates = (
            ("gsd", all_gsd_values),
            ("proj:code", all_epsg_code_values),
            ("proj:epsg", all_epsg_values),
        )
        for key, values in candidates:
            distinct = {value for value in values if value}
            if len(distinct) != 1:
                continue
            if force or properties.get(key) is None:
                properties[key] = next(iter(distinct))
0 commit comments