diff --git a/src/pyrecest/_backend/_common.py b/src/pyrecest/_backend/_common.py index 64d89ea68..5a074f65c 100644 --- a/src/pyrecest/_backend/_common.py +++ b/src/pyrecest/_backend/_common.py @@ -62,6 +62,8 @@ def dot(a, b): if torch_pair is not None: a, b = torch_pair torch = _torch_module_for_values(a, b) + if a.ndim == 0 or b.ndim == 0: + return torch.multiply(a, b) if b.ndim == 1: return torch.einsum("...i,i->...", a, b) if a.ndim == 1: @@ -70,6 +72,8 @@ def dot(a, b): a = _np.asarray(a) b = _np.asarray(b) + if a.ndim == 0 or b.ndim == 0: + return _np.multiply(a, b) if b.ndim == 1: return _np.einsum("...i,i->...", a, b) if a.ndim == 1: diff --git a/src/pyrecest/_backend/_shared_numpy/__init__.py b/src/pyrecest/_backend/_shared_numpy/__init__.py index 33d0b82bb..36de9a2c6 100644 --- a/src/pyrecest/_backend/_shared_numpy/__init__.py +++ b/src/pyrecest/_backend/_shared_numpy/__init__.py @@ -381,6 +381,9 @@ def matmul(*args, **kwargs): def outer(a, b): + a = a if is_array(a) else array(a) + b = b if is_array(b) else array(b) + if a.ndim > 1 and b.ndim > 1: return _np.einsum("...i,...j->...ij", a, b) @@ -392,6 +395,9 @@ def outer(a, b): def matvec(A, b): + A = A if is_array(A) else array(A) + b = b if is_array(b) else array(b) + if b.ndim == 1: return _np.matmul(A, b) if A.ndim == 2: @@ -400,6 +406,11 @@ def matvec(A, b): def dot(a, b): + a = a if is_array(a) else array(a) + b = b if is_array(b) else array(b) + if a.ndim == 0 or b.ndim == 0: + return _np.multiply(a, b) + if b.ndim == 1: return _np.einsum("...i,i->...", a, b) diff --git a/src/pyrecest/_backend/jax/__init__.py b/src/pyrecest/_backend/jax/__init__.py index 386ed783d..7d1e11188 100644 --- a/src/pyrecest/_backend/jax/__init__.py +++ b/src/pyrecest/_backend/jax/__init__.py @@ -489,6 +489,8 @@ def dot(a, b): a = _jnp.asarray(a) b = _jnp.asarray(b) + if a.ndim == 0 or b.ndim == 0: + return _jnp.multiply(a, b) if b.ndim == 1: return _jnp.einsum("...i,i->...", a, b) if a.ndim == 1: diff --git a/src/pyrecest/_backend/pytorch/__init__.py b/src/pyrecest/_backend/pytorch/__init__.py index 677fca68f..a5017ca39 100644 --- a/src/pyrecest/_backend/pytorch/__init__.py +++ b/src/pyrecest/_backend/pytorch/__init__.py @@ -1366,7 +1366,13 @@ def is_array(x): def outer(a, b): + a = array(a) + b = array(b) + a, b = convert_to_wider_dtype([a, b]) + # TODO: improve for torch > 1.9 (dims=0 fails in 1.9) + if a.ndim == 0 or b.ndim == 0: + return _torch.multiply(a, b) return _torch.einsum("...i,...j->...ij", a, b) @@ -1392,6 +1398,9 @@ def dot(a, b): b = array(b) a, b = convert_to_wider_dtype([a, b]) + if a.ndim == 0 or b.ndim == 0: + return _torch.multiply(a, b) + if a.ndim == 1 and b.ndim == 1: return _torch.dot(a, b) diff --git a/tests/test_backend_contracts.py b/tests/test_backend_contracts.py index 57b8349f3..5583726ab 100644 --- a/tests/test_backend_contracts.py +++ b/tests/test_backend_contracts.py @@ -439,6 +439,15 @@ def test_batched_matvec_pairs_leading_dimensions(): assert _to_python(result) == [[1.0, 2.0], [6.0, 12.0]] +def test_matvec_accepts_array_like_inputs(): + result = backend.matvec( + [[1.0, 2.0], [3.0, 4.0]], + [5.0, 6.0], + ) + + assert _to_python(result) == [17.0, 39.0] + + def test_batched_dot_uses_last_axis_inner_product(): first = array([[1.0, 2.0], [3.0, 4.0]]) second = array([[5.0, 6.0], [7.0, 8.0]]) @@ -449,6 +458,17 @@ def test_batched_dot_uses_last_axis_inner_product(): assert _to_python(result) == [17.0, 53.0] +def test_dot_accepts_scalar_operands(): + assert _to_python(backend.dot(2.0, 3.0)) == 6.0 + assert _to_python(backend.dot(2.0, array([1.0, 2.0]))) == [2.0, 4.0] + assert _to_python(backend.dot(array([1.0, 2.0]), 3.0)) == [3.0, 6.0] + assert _to_python(backend.dot(array(2.0), array(3.0))) == 6.0 + + +def test_dot_accepts_array_like_inputs(): + assert _to_python(backend.dot([1.0, 2.0], [3.0, 4.0])) == 11.0 + + def test_batched_dot_accepts_high_rank_right_operand(): first = array([1.0, 2.0]) second = array( @@ -477,6 +497,17 @@ def test_batched_outer_pairs_leading_dimensions(): ] +def test_outer_accepts_array_like_inputs(): + result = backend.outer([1.0, 2.0], [3.0, 4.0]) + + assert _to_python(result) == [[3.0, 4.0], [6.0, 8.0]] + + +def test_outer_accepts_scalar_operands(): + assert _to_python(backend.outer(2.0, array([1.0, 2.0]))) == [2.0, 4.0] + assert _to_python(backend.outer(array([1.0, 2.0]), 3.0)) == [3.0, 6.0] + + def test_outer_accepts_high_rank_right_operand(): first = array([1.0, 2.0]) second = array(