diff --git a/src/pycmor/std_lib/files.py b/src/pycmor/std_lib/files.py index fcc1b0be..2d2c1fd3 100644 --- a/src/pycmor/std_lib/files.py +++ b/src/pycmor/std_lib/files.py @@ -40,6 +40,7 @@ from pathlib import Path +import numpy as np import pandas as pd import xarray as xr from xarray.core.utils import is_scalar @@ -379,6 +380,41 @@ def _save_dataset_with_native_timespan( ) +_GEO_COORD_NAMES = frozenset( + { + "lat", "lon", "latitude", "longitude", + "lat_bnds", "lon_bnds", "lat_bounds", "lon_bounds", + "latitude_bnds", "longitude_bnds", "latitude_bounds", "longitude_bounds", + } +) + + +def _cast_geo_coords_to_float64(da): + """Cast geographic coordinate variables to float64 (CMOR3 / CF requirement). + + Accepts both xr.DataArray and xr.Dataset. + """ + coord_updates = { + name: da[name].astype(np.float64) + for name in da.coords + if name in _GEO_COORD_NAMES and da[name].dtype != np.float64 + } + if coord_updates: + for name in coord_updates: + logger.debug( + f"Casting {name!r} from {da[name].dtype} to float64 for CMIP compliance" + ) + da = da.assign_coords(coord_updates) + if isinstance(da, xr.Dataset): + for name in list(da.data_vars): + if name in _GEO_COORD_NAMES and da[name].dtype != np.float64: + logger.debug( + f"Casting {name!r} from {da[name].dtype} to float64 for CMIP compliance" + ) + da[name] = da[name].astype(np.float64) + return da + + def save_dataset(da: xr.DataArray, rule): """ Save dataset to one or more files. @@ -436,6 +472,10 @@ def save_dataset(da: xr.DataArray, rule): # Set default calendar if none is specified if time_encoding.get("calendar") is None: time_encoding["calendar"] = "standard" + + # CMOR3 / CF requirement: geographic coordinates must be float64 + da = _cast_geo_coords_to_float64(da) + if not has_time_axis(da): filepath = create_filepath(da, rule) return da.to_netcdf( @@ -455,7 +495,6 @@ def save_dataset(da: xr.DataArray, rule): ) if isinstance(da, xr.DataArray): da = da.to_dataset() - # Set time variable attributes if rule._pycmor_cfg("xarray_time_set_standard_name"): da[time_label].attrs["standard_name"] = "time" @@ -510,8 +549,6 @@ def save_dataset(da: xr.DataArray, rule): da[time_label].attrs["calendar"] = time_encoding["calendar"] # Ensure the encoding is set on the time variable itself - if isinstance(da, xr.DataArray): - da = da.to_dataset() da[time_label].encoding.update(time_encoding) if not has_time_axis(da): diff --git a/tests/unit/test_savedataset.py b/tests/unit/test_savedataset.py index a9582899..12ab2ce5 100644 --- a/tests/unit/test_savedataset.py +++ b/tests/unit/test_savedataset.py @@ -302,3 +302,79 @@ def test_save_dataset_with_custom_time_settings(tmp_path): assert ( time_var.encoding["calendar"] == custom_calendar ), f"XArray encoding calendar does not match. Expected {custom_calendar}, got {time_var.encoding['calendar']}" + + +@pytest.mark.parametrize("coord_name", ["lat", "lon", "latitude", "longitude"]) +@pytest.mark.parametrize("input_dtype", [np.float32, np.float16]) +def test_save_dataset_casts_geo_coords_to_float64(tmp_path, coord_name, input_dtype): + """Geographic coordinate variables must be written as float64 (CMOR3 / CF requirement).""" + dates = xr.cftime_range(start="2001", periods=2, freq="MS", calendar="noleap") + coords = { + "time": dates, + coord_name: np.array([-45.0, 0.0, 45.0], dtype=input_dtype), + } + da = xr.DataArray( + np.zeros((2, 3), dtype=np.float32), + coords=coords, + dims=["time", coord_name], + name="tos", + ) + rule = Mock() + rule._pycmor_cfg = PycmorConfigManager.from_pycmor_cfg({}) + rule.data_request_variable.frequency = "mon" + rule.data_request_variable.table_header.approx_interval = 30 + rule.cmor_variable = "tos" + rule.variant_label = "r1i1p1f1" + rule.source_id = "AWI-ESM-3" + rule.experiment_id = "historical" + rule.file_timespan = "2YS" + rule.output_directory = str(tmp_path) + + save_dataset(da, rule) + + saved = list(tmp_path.glob("*.nc")) + assert len(saved) == 1 + with xr.open_dataset(saved[0]) as ds: + assert ds[coord_name].dtype == np.float64, ( + f"{coord_name} should be float64, got {ds[coord_name].dtype}" + ) + + +@pytest.mark.parametrize("bounds_name", ["lat_bnds", "lon_bnds", "lat_bounds", "lon_bounds"]) +def test_save_dataset_casts_geo_bounds_to_float64(tmp_path, bounds_name): + """Geographic bounds variables must also be written as float64.""" + dates = xr.cftime_range(start="2001", periods=2, freq="MS", calendar="noleap") + coord_name = "lat" if "lat" in bounds_name else "lon" + coord_vals = np.array([-45.0, 0.0, 45.0], dtype=np.float32) + bounds_vals = np.array( + [[-67.5, -22.5], [-22.5, 22.5], [22.5, 67.5]], dtype=np.float32 + ) + ds = xr.Dataset( + { + "tos": xr.DataArray( + np.zeros((2, 3), dtype=np.float32), + coords={"time": dates, coord_name: coord_vals}, + dims=["time", coord_name], + ), + bounds_name: xr.DataArray(bounds_vals, dims=[coord_name, "bnds"]), + } + ) + rule = Mock() + rule._pycmor_cfg = PycmorConfigManager.from_pycmor_cfg({}) + rule.data_request_variable.frequency = "mon" + rule.data_request_variable.table_header.approx_interval = 30 + rule.cmor_variable = "tos" + rule.variant_label = "r1i1p1f1" + rule.source_id = "AWI-ESM-3" + rule.experiment_id = "historical" + rule.file_timespan = "2YS" + rule.output_directory = str(tmp_path) + + save_dataset(ds, rule) + + saved = list(tmp_path.glob("*.nc")) + assert len(saved) == 1 + with xr.open_dataset(saved[0]) as result: + assert result[bounds_name].dtype == np.float64, ( + f"{bounds_name} should be float64, got {result[bounds_name].dtype}" + )