Skip to content

Commit 7f33a2a

Browse files
juntyr and treigerm authored
Update numcodecs-wasm, test NaNs in ZFP (#57)
* Update numcodecs-wasm, test NaNs in ZFP
* fix typing
* Fix minor bugs

---------

Co-authored-by: Tim Reichelt <[email protected]>
1 parent fd241ad commit 7f33a2a

File tree

6 files changed

+40
-36
lines changed

6 files changed

+40
-36
lines changed

pyproject.toml

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,21 @@ dependencies = [
1313
"matplotlib~=3.8",
1414
"netcdf4~=1.7.2",
1515
"numcodecs>=0.13.0,<0.17",
16-
"numcodecs-combinators[xarray]~=0.2.4",
17-
"numcodecs-observers~=0.1.1",
18-
"numcodecs-wasm~=0.1.7",
19-
"numcodecs-wasm-bit-round~=0.3.0",
20-
"numcodecs-wasm-fixed-offset-scale~=0.3.0",
21-
"numcodecs-wasm-jpeg2000~=0.2.0",
22-
"numcodecs-wasm-pco~=0.2.1",
23-
"numcodecs-wasm-round~=0.4.0",
24-
"numcodecs-wasm-sperr~=0.1.0",
25-
"numcodecs-wasm-stochastic-rounding~=0.1.1",
26-
"numcodecs-wasm-sz3~=0.6.0",
27-
"numcodecs-wasm-tthresh~=0.2.0",
28-
"numcodecs-wasm-zfp~=0.5.3",
29-
"numcodecs-wasm-zfp-classic~=0.3.3",
30-
"numcodecs-wasm-zstd~=0.3.0",
16+
"numcodecs-combinators[xarray]~=0.2.10",
17+
"numcodecs-observers~=0.1.2",
18+
"numcodecs-wasm~=0.2.1",
19+
"numcodecs-wasm-bit-round~=0.4.0",
20+
"numcodecs-wasm-fixed-offset-scale~=0.4.0",
21+
"numcodecs-wasm-jpeg2000~=0.3.0",
22+
"numcodecs-wasm-pco~=0.3.0",
23+
"numcodecs-wasm-round~=0.5.0",
24+
"numcodecs-wasm-sperr~=0.2.0",
25+
"numcodecs-wasm-stochastic-rounding~=0.2.0",
26+
"numcodecs-wasm-sz3~=0.7.0",
27+
"numcodecs-wasm-tthresh~=0.3.0",
28+
"numcodecs-wasm-zfp~=0.6.0",
29+
"numcodecs-wasm-zfp-classic~=0.4.0",
30+
"numcodecs-wasm-zstd~=0.4.0",
3131
"pandas~=2.2",
3232
"scipy~=1.14",
3333
"seaborn~=0.13.2",

src/climatebenchpress/compressor/compressors/abc.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def build(
196196

197197
# Class interface
198198
@classproperty
199-
def registry(cls) -> Mapping:
199+
def registry(cls) -> Mapping[str, type["Compressor"]]:
200200
return MappingProxyType(Compressor._registry)
201201

202202
# Implementation details
@@ -247,11 +247,13 @@ def _get_variant_bounds(
247247
converted_bounds: dict[VariableName, dict[VariantName, ErrorBound]] = dict()
248248
variant_names = {cls.name}
249249
for var, error_bound in error_bounds.items():
250+
cls_has_abs_error_impl: bool = cls.has_abs_error_impl # type: ignore
250251
abs_bound_codec = (
251-
error_bound.abs_error is not None and cls.has_abs_error_impl
252+
error_bound.abs_error is not None and cls_has_abs_error_impl
252253
)
254+
cls_has_rel_error_impl: bool = cls.has_rel_error_impl # type: ignore
253255
rel_bound_codec = (
254-
error_bound.rel_error is not None and cls.has_rel_error_impl
256+
error_bound.rel_error is not None and cls_has_rel_error_impl
255257
)
256258
if abs_bound_codec or rel_bound_codec:
257259
# If codec is compatible with the error bound no transformation

src/climatebenchpress/compressor/compressors/zfp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,5 @@ class Zfp(Compressor):
2222
@staticmethod
2323
def abs_bound_codec(error_bound, **kwargs):
2424
return numcodecs_wasm_zfp_classic.ZfpClassic(
25-
mode="fixed-accuracy", tolerance=error_bound
25+
mode="fixed-accuracy", tolerance=error_bound, non_finite="allow-unsafe"
2626
)

src/climatebenchpress/compressor/compressors/zfp_round.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,6 @@ class ZfpRound(Compressor):
2525

2626
@staticmethod
2727
def abs_bound_codec(error_bound, **kwargs):
28-
return numcodecs_wasm_zfp.Zfp(mode="fixed-accuracy", tolerance=error_bound)
28+
return numcodecs_wasm_zfp.Zfp(
29+
mode="fixed-accuracy", tolerance=error_bound, non_finite="allow-unsafe"
30+
)

src/climatebenchpress/compressor/scripts/compress.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
import argparse
44
import json
55
import traceback
6-
from collections.abc import Container
6+
from collections.abc import Container, Mapping
77
from pathlib import Path
88
from typing import Callable
99

1010
import numcodecs_observers
11+
import numpy as np
1112
import xarray as xr
1213
from numcodecs.abc import Codec
1314
from numcodecs_combinators.stack import CodecStack
@@ -68,24 +69,24 @@ def compress(
6869
continue
6970

7071
ds = xr.open_dataset(dataset, chunks=dict(), engine="zarr")
71-
ds_dtypes, ds_abs_mins, ds_abs_maxs, ds_mins, ds_maxs = (
72-
dict(),
73-
dict(),
74-
dict(),
75-
dict(),
76-
dict(),
77-
)
72+
ds_dtypes: dict[str, np.dtype] = dict()
73+
ds_abs_mins: dict[str, float] = dict()
74+
ds_abs_maxs: dict[str, float] = dict()
75+
ds_mins: dict[str, float] = dict()
76+
ds_maxs: dict[str, float] = dict()
7877
for v in ds:
78+
vs: str = str(v)
7979
abs_vals = xr.ufuncs.abs(ds[v])
80+
ds_dtypes[vs] = ds[v].dtype
8081
# Take minimum of non-zero absolute values to avoid division by zero.
81-
ds_abs_mins[v] = abs_vals.where(abs_vals > 0).min().values.item()
82-
ds_abs_maxs[v] = abs_vals.max().values.item()
83-
ds_mins[v] = ds[v].min().values.item()
84-
ds_maxs[v] = ds[v].max().values.item()
85-
ds_dtypes[v] = ds[v].dtype
82+
ds_abs_mins[vs] = abs_vals.where(abs_vals > 0).min().values.item()
83+
ds_abs_maxs[vs] = abs_vals.max().values.item()
84+
ds_mins[vs] = ds[v].min().values.item()
85+
ds_maxs[vs] = ds[v].max().values.item()
8686

8787
error_bounds = get_error_bounds(datasets_error_bounds, dataset.parent.name)
88-
for compressor in Compressor.registry.values():
88+
registry: Mapping[str, type[Compressor]] = Compressor.registry # type: ignore
89+
for compressor in registry.values():
8990
if compressor.name in exclude_compressor:
9091
continue
9192
if include_compressor and compressor.name not in include_compressor:
@@ -216,7 +217,7 @@ def get_error_bounds(
216217
args = parser.parse_args()
217218

218219
compress(
219-
basepath=Path(),
220+
basepath=args.basepath,
220221
exclude_dataset=args.exclude_dataset,
221222
include_dataset=args.include_dataset,
222223
exclude_compressor=args.exclude_compressor,

src/climatebenchpress/compressor/scripts/concatenate_metrics.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,4 +189,3 @@ def get_error_bound_name(
189189
args = parser.parse_args()
190190

191191
concatenate_metrics(basepath=args.basepath)
192-
concatenate_metrics(basepath=args.basepath)

0 commit comments

Comments
 (0)