Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 40 additions & 3 deletions src/pycmor/std_lib/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

from pathlib import Path

import numpy as np
import pandas as pd
import xarray as xr
from xarray.core.utils import is_scalar
Expand Down Expand Up @@ -379,6 +380,41 @@ def _save_dataset_with_native_timespan(
)


# Variable names treated as geographic coordinates by the float64 cast:
# every latitude/longitude spelling, on its own and with a "_bnds" or
# "_bounds" suffix.
_GEO_COORD_NAMES = frozenset(
    f"{axis}{suffix}"
    for axis in ("lat", "lon", "latitude", "longitude")
    for suffix in ("", "_bnds", "_bounds")
)


def _cast_geo_coords_to_float64(da):
    """Cast geographic coordinate variables to float64 (CMOR3 / CF requirement).

    Every coordinate whose name is listed in ``_GEO_COORD_NAMES`` and whose
    dtype is not already float64 is cast. For an ``xr.Dataset`` the same is
    done for data variables with such a name (e.g. bounds that are stored as
    data variables rather than coordinates).

    Parameters
    ----------
    da : xr.DataArray or xr.Dataset
        Object whose geographic coordinates should be cast.

    Returns
    -------
    xr.DataArray or xr.Dataset
        A new object carrying the casted variables, or ``da`` itself when
        nothing needed casting. The object passed in is never modified.
    """
    coord_updates = {
        name: da[name].astype(np.float64)
        for name in da.coords
        if name in _GEO_COORD_NAMES and da[name].dtype != np.float64
    }
    if coord_updates:
        # Log before reassigning so the message still reports the source dtype.
        for name in coord_updates:
            logger.debug(
                f"Casting {name!r} from {da[name].dtype} to float64 for CMIP compliance"
            )
        da = da.assign_coords(coord_updates)

    if isinstance(da, xr.Dataset):
        var_updates = {}
        for name in da.data_vars:
            if name in _GEO_COORD_NAMES and da[name].dtype != np.float64:
                logger.debug(
                    f"Casting {name!r} from {da[name].dtype} to float64 for CMIP compliance"
                )
                var_updates[name] = da[name].astype(np.float64)
        if var_updates:
            # assign() returns a new Dataset; item assignment (da[name] = ...)
            # would modify the caller's Dataset in place whenever no
            # coordinate had been cast above.
            da = da.assign(var_updates)
    return da


def save_dataset(da: xr.DataArray, rule):
"""
Save dataset to one or more files.
Expand Down Expand Up @@ -436,6 +472,10 @@ def save_dataset(da: xr.DataArray, rule):
# Set default calendar if none is specified
if time_encoding.get("calendar") is None:
time_encoding["calendar"] = "standard"

# CMOR3 / CF requirement: geographic coordinates must be float64
da = _cast_geo_coords_to_float64(da)

if not has_time_axis(da):
filepath = create_filepath(da, rule)
return da.to_netcdf(
Expand All @@ -455,7 +495,6 @@ def save_dataset(da: xr.DataArray, rule):
)
if isinstance(da, xr.DataArray):
da = da.to_dataset()

# Set time variable attributes
if rule._pycmor_cfg("xarray_time_set_standard_name"):
da[time_label].attrs["standard_name"] = "time"
Expand Down Expand Up @@ -510,8 +549,6 @@ def save_dataset(da: xr.DataArray, rule):
da[time_label].attrs["calendar"] = time_encoding["calendar"]

# Ensure the encoding is set on the time variable itself
if isinstance(da, xr.DataArray):
da = da.to_dataset()
da[time_label].encoding.update(time_encoding)

if not has_time_axis(da):
Expand Down
76 changes: 76 additions & 0 deletions tests/unit/test_savedataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,3 +302,79 @@ def test_save_dataset_with_custom_time_settings(tmp_path):
assert (
time_var.encoding["calendar"] == custom_calendar
), f"XArray encoding calendar does not match. Expected {custom_calendar}, got {time_var.encoding['calendar']}"


@pytest.mark.parametrize("coord_name", ["lat", "lon", "latitude", "longitude"])
@pytest.mark.parametrize("input_dtype", [np.float32, np.float16])
def test_save_dataset_casts_geo_coords_to_float64(tmp_path, coord_name, input_dtype):
    """A low-precision geographic coordinate must come back as float64 after saving."""
    # Two monthly noleap time steps over a three-point geographic axis that
    # is stored in the parametrized (narrower than float64) dtype.
    field = xr.DataArray(
        np.zeros((2, 3), dtype=np.float32),
        coords={
            "time": xr.cftime_range(
                start="2001", periods=2, freq="MS", calendar="noleap"
            ),
            coord_name: np.array([-45.0, 0.0, 45.0], dtype=input_dtype),
        },
        dims=["time", coord_name],
        name="tos",
    )

    rule = Mock()
    rule._pycmor_cfg = PycmorConfigManager.from_pycmor_cfg({})
    # Where the output goes and how it is split into files
    rule.output_directory = str(tmp_path)
    rule.file_timespan = "2YS"
    # Identifying attributes of the rule
    rule.cmor_variable = "tos"
    rule.experiment_id = "historical"
    rule.source_id = "AWI-ESM-3"
    rule.variant_label = "r1i1p1f1"
    # Data-request metadata
    rule.data_request_variable.table_header.approx_interval = 30
    rule.data_request_variable.frequency = "mon"

    save_dataset(field, rule)

    written = list(tmp_path.glob("*.nc"))
    assert len(written) == 1
    with xr.open_dataset(written[0]) as ds:
        assert ds[coord_name].dtype == np.float64, (
            f"{coord_name} should be float64, got {ds[coord_name].dtype}"
        )


@pytest.mark.parametrize("bounds_name", ["lat_bnds", "lon_bnds", "lat_bounds", "lon_bounds"])
def test_save_dataset_casts_geo_bounds_to_float64(tmp_path, bounds_name):
    """A float32 geographic bounds variable must come back as float64 after saving."""
    # The bounds variable is attached to whichever axis its name mentions.
    if "lat" in bounds_name:
        coord_name = "lat"
    else:
        coord_name = "lon"

    centres = np.array([-45.0, 0.0, 45.0], dtype=np.float32)
    edges = np.array(
        [[-67.5, -22.5], [-22.5, 22.5], [22.5, 67.5]], dtype=np.float32
    )
    time_axis = xr.cftime_range(start="2001", periods=2, freq="MS", calendar="noleap")

    tos = xr.DataArray(
        np.zeros((2, 3), dtype=np.float32),
        coords={"time": time_axis, coord_name: centres},
        dims=["time", coord_name],
    )
    bounds = xr.DataArray(edges, dims=[coord_name, "bnds"])
    ds = xr.Dataset({"tos": tos, bounds_name: bounds})

    rule = Mock()
    rule._pycmor_cfg = PycmorConfigManager.from_pycmor_cfg({})
    # Where the output goes and how it is split into files
    rule.output_directory = str(tmp_path)
    rule.file_timespan = "2YS"
    # Identifying attributes of the rule
    rule.cmor_variable = "tos"
    rule.experiment_id = "historical"
    rule.source_id = "AWI-ESM-3"
    rule.variant_label = "r1i1p1f1"
    # Data-request metadata
    rule.data_request_variable.table_header.approx_interval = 30
    rule.data_request_variable.frequency = "mon"

    save_dataset(ds, rule)

    written = list(tmp_path.glob("*.nc"))
    assert len(written) == 1
    with xr.open_dataset(written[0]) as result:
        assert result[bounds_name].dtype == np.float64, (
            f"{bounds_name} should be float64, got {result[bounds_name].dtype}"
        )
Loading