Skip to content

Commit 49908d7

Browse files
authored
Change how JAX Keras backend requirement is reported (#2420)
Signed-off-by: Andrzej Kotłowski <andrzej.kotlowski@intel.com>
1 parent 273c9ec commit 49908d7

7 files changed

Lines changed: 24 additions & 16 deletions

File tree

docs/source/JAX.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ JAX
1010

1111
## Introduction
1212

13-
`neural_compressor_jax` provides an API for applying quantization on Keras models such as ViT and Gemma3.
13+
`neural_compressor.jax` provides an API for applying quantization to Keras models such as ViT and Gemma3.
14+
Since only JAX is supported as the Keras backend, the environment variable `KERAS_BACKEND` should be set to `jax`.
1415
The following 8-bit floating-point formats are supported: `fp8_e4m3` and `fp8_e5m2`.
1516

1617
Quantized models can be saved and loaded using standard Keras APIs

neural_compressor/jax/quantization/layers_dynamic.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,6 @@
3737
verify_api,
3838
)
3939

40-
if keras.config.backend() != "jax":
41-
raise ValueError(
42-
f"{__name__} only supports JAX backend, but the current backend is {keras.config.backend()}.\n"
43-
'Consider setting KERAS_BACKEND env var to "jax".'
44-
)
45-
4640
dynamic_quant_mapping = {}
4741

4842

neural_compressor/jax/quantization/layers_static.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,6 @@
3737
verify_api,
3838
)
3939

40-
if keras.config.backend() != "jax":
41-
raise ValueError(
42-
f"{__name__} only supports JAX backend, but the current backend is {keras.config.backend()}.\n"
43-
'Consider setting KERAS_BACKEND env var to "jax".'
44-
)
45-
4640
static_quant_mapping = {}
4741

4842

neural_compressor/jax/quantization/quantize.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from neural_compressor.common import logger
2222
from neural_compressor.common.base_config import BaseConfig, ComposableConfig, config_registry
2323
from neural_compressor.common.utils import Mode, log_process
24-
from neural_compressor.jax.utils import algos_mapping
24+
from neural_compressor.jax.utils import algos_mapping, check_backend
2525

2626

2727
def need_apply(configs_mapping: Dict[Tuple[str, callable], BaseConfig], algo_name):
@@ -58,6 +58,7 @@ def quantize_model(
5858
keras.Model: The quantized model.
5959
"""
6060
# fmt: on
61+
check_backend()
6162
if not inplace:
6263
raise NotImplementedError("Out of place quantization is not supported yet. "
6364
"Please set parameter inplace=True for quantize_model() to modify the model in-place")

neural_compressor/jax/quantization/saving.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from neural_compressor.common import logger
2929
from neural_compressor.common.base_config import config_registry
3030
from neural_compressor.jax.quantization.config import FRAMEWORK_NAME, BaseConfig, DynamicQuantConfig, StaticQuantConfig
31-
from neural_compressor.jax.utils.utility import dtype_mapping, iterate_over_layers
31+
from neural_compressor.jax.utils.utility import check_backend, dtype_mapping, iterate_over_layers
3232

3333

3434
def quant_config_to_json_object(quant_config: BaseConfig) -> dict:
@@ -446,6 +446,7 @@ def prepare_deserialized_quantized_model(
446446
Returns:
447447
Union[KerasQuantizedModelWrapperMixin, KerasQuantizedModelBackboneWrapper]: The transformed quantized model/backbone wrapper.
448448
"""
449+
check_backend()
449450
model_info = quant_config.get_model_info(model)
450451
configs_mapping = quant_config.to_config_mapping(model_info=model_info)
451452

neural_compressor/jax/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17-
from neural_compressor.jax.utils.utility import algos_mapping, register_algo
17+
from neural_compressor.jax.utils.utility import algos_mapping, register_algo, check_backend

neural_compressor/jax/utils/utility.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,23 @@
3030
from neural_compressor.common import logger
3131

3232

33+
def check_backend(raise_error=True):
34+
"""Check if the current Keras backend is JAX and log a warning or error if not."""
35+
36+
if keras.config.backend() != "jax":
37+
message = (
38+
f"neural_compressor.jax only supports JAX backend, but the current Keras backend is {keras.config.backend()}. "
39+
'Consider setting KERAS_BACKEND env var to "jax".'
40+
)
41+
if raise_error:
42+
raise ValueError(message)
43+
else:
44+
logger.warning(message)
45+
46+
47+
check_backend(raise_error=False)
48+
49+
3350
def add_fp8_support(function):
3451
"""Extend a dtype size function to support FP8 dtypes.
3552

0 commit comments

Comments
 (0)