Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
66a591f
In test_hypercube_graph, add random weights to a fixture and make the…
vabor112 Sep 11, 2024
e1d96dc
Move tests around to align with the current package structure better
vabor112 Sep 11, 2024
cc85678
Some fixes and improvements inspired by revising tests. Detailed list…
vabor112 Oct 18, 2024
623fb63
Fix take_along_axis for Python 3.9 and below
vabor112 Oct 18, 2024
9beb71f
Merge branch 'main' into improve_tests
vabor112 Oct 18, 2024
e3ff505
Unify most tests for the Circle, HypercubeGraph, Hypersphere, Special…
vabor112 Oct 21, 2024
738d1f6
Remove deprecated tests
vabor112 Oct 21, 2024
575b11b
Add Graph and Mesh to tests/spaces/test_eigenfunctions_basics.py
vabor112 Oct 21, 2024
e952af8
Fix ProductDiscreteSpectrumSpace's incompatibility with the TensorFlo…
vabor112 Oct 24, 2024
9c5b65e
Graph now remembers if the Laplacian is normalized and complaints if …
vabor112 Oct 24, 2024
77596fa
Revise tests/spaces/test_graph.py and tests/spaces/test_eigenfunction…
vabor112 Oct 24, 2024
1d36892
Stop tests/kernels/test_normalization.py from requesting more than th…
vabor112 Oct 24, 2024
50caf1b
Rename test_eigenfunctions_basics.py to test_eigenfunctions.py
vabor112 Oct 24, 2024
3b6799e
Python 3.8 compatibility
vabor112 Oct 24, 2024
9281d5e
Fix Graph checks for the scipy_sparse backend
vabor112 Oct 25, 2024
d392bc0
Move MaternKarhunenLoeveKernel import inside DeterministicFeatureMapC…
vabor112 Oct 27, 2024
a817021
Revise test_lie_groups.py. test_characters_orthogonal currently fails…
vabor112 Oct 27, 2024
4616703
Fix a bug in SpecialOrthogonal introduced by cc85678. Fixes issue #152
vabor112 Oct 28, 2024
94b29b1
Finish revising tests/spaces/test_lie_groups.py
vabor112 Oct 28, 2024
c95842e
Revise tests/kernels/test_matern_karhunenloeve_kernel.py. Make check_…
vabor112 Nov 11, 2024
f86e34b
Minor fixes in tests/kernels/test_matern_karhunenloeve_kernel.py
vabor112 Nov 11, 2024
9b2c318
Fix __str__ in geometric_kernels/spaces/spd.py
vabor112 Nov 11, 2024
45157f1
New tests for MaternFeatureMapKernel: tests/kernels/test_feature_map_…
vabor112 Nov 11, 2024
61c81c4
Remove tests/kernels/test_normalization.py: normalization tests are n…
vabor112 Nov 11, 2024
0b0a406
Make check_function_with_backend a bit more informative. Add spaces()…
vabor112 Nov 12, 2024
2df9f8e
Make random in SpecialOrthogonal and SpecialUnitary compatible with j…
vabor112 Nov 12, 2024
b73c7ed
Revise tests/feature_maps/test_feature_maps.py and add tests/sampling…
vabor112 Nov 12, 2024
57d389b
Remove tests/test_dtypes.py: same things are now more thoroughly chec…
vabor112 Nov 12, 2024
7876a0c
Move scripts/compute_characters.py from geometric_kernels/utils to sc…
vabor112 Nov 17, 2024
bf294c2
Continuing to revise tests. Many changes, see the list below.
vabor112 Nov 17, 2024
a46ff93
Change parameters in tests/spaces/test_hyperbolic.py and tests/spaces…
vabor112 Nov 17, 2024
c050d53
Fix typing issues in `ProductGeometricKernel` and `ProductDiscreteSpe…
vabor112 Nov 18, 2024
d32bb35
Add pytest-cov to test_requirements.txt and flags --cov --cov-report=…
vabor112 Nov 18, 2024
edb6169
Incorporate the feedback from stoprightthere's review
vabor112 Nov 21, 2024
1903ac5
Merge branch 'main' into improve_tests
vabor112 Nov 21, 2024
35ca2d9
Incorporate the feedback from stoprightthere's review [continued]
vabor112 Nov 25, 2024
3115c5e
Incorporate the feedback from stoprightthere's review [continued 2]
vabor112 Nov 27, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,5 @@ lint:


test: ## Run the tests, start with the failing ones and break on first fail.
pytest -v -x --ff -rN -Wignore -s --tb=short --durations=0 tests
pytest -v -x --ff -rN -Wignore -s --tb=short --durations=0 --cov --cov-report=xml tests
pytest --nbmake --nbmake-kernel=python3 --durations=0 --nbmake-timeout=1000 --ignore=notebooks/frontends/GPJax.ipynb notebooks/
33 changes: 25 additions & 8 deletions docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,24 @@ @inproceedings{jaquier2021
year={2021}
}

@article{azangulov2022,
title = {Stationary Kernels and Gaussian Processes on Lie Groups and their Homogeneous Spaces I: the compact case},
author = {Azangulov, Iskander and Smolensky, Andrei and Terenin, Alexander and Borovitskiy, Viacheslav},
journal = {arXiv preprint arXiv:2208.14960},
year = {2022}
@article{azangulov2024a,
title={Stationary Kernels and Gaussian Processes on Lie Groups and their Homogeneous Spaces I: the compact case},
author={Azangulov, Iskander and Smolensky, Andrei and Terenin, Alexander and Borovitskiy, Viacheslav},
journal={Journal of Machine Learning Research},
year={2024},
volume={25},
number={280},
pages={1--52},
}

@article{azangulov2023,
@article{azangulov2024b,
title={Stationary Kernels and Gaussian Processes on Lie Groups and their Homogeneous Spaces II: non-compact symmetric spaces},
author={Azangulov, Iskander and Smolensky, Andrei and Terenin, Alexander and Borovitskiy, Viacheslav},
journal={arXiv preprint arXiv:2301.13088},
year={2023}
journal={Journal of Machine Learning Research},
year={2024},
volume={25},
number={281},
pages={1--51},
}

@inproceedings{yang2024,
Expand Down Expand Up @@ -135,4 +141,15 @@ @inproceedings{borovitskiy2023
author={Borovitskiy, Viacheslav and Karimi, Mohammad Reza and Somnath, Vignesh Ram and Krause, Andreas},
booktitle={International Conference on Artificial Intelligence and Statistics},
year={2023},
}

@article{sawyer1992,
author = {Sawyer, Patrice},
journal = {Canadian Journal of Mathematics},
number = {3},
pages = {624--651},
publisher = {Cambridge University Press},
title = {The heat equation on the spaces of positive definite matrices},
volume = {44},
year = {1992},
}
2 changes: 1 addition & 1 deletion docs/theory/addition_theorem.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ For example, for $M = \mathbb{S}_2$ and $L = 20$ the corresponding $J$ is $400$.
In the simplest special case of $\mathbb{S}_d$, the circle $\mathbb{S}_1$, the eigenfunctions are given by $\sin(l \theta), \cos(l \theta)$, where $l$ indexes levels.
The outer product $\cos(l \theta) \cos(l \theta') + \sin(l \theta) \sin(l \theta')$ in this case can be simplified to $\cos(l (\theta-\theta')) = \cos(l d_{\mathbb{S}_1}(\theta, \theta'))$ thanks to an elementary trigonometric identity.

Such addition theorems appear beyond hyperspheres, for example for Lie groups and other compact homogeneous spaces :cite:p:`azangulov2022`.
Such addition theorems appear beyond hyperspheres, for example for Lie groups and other compact homogeneous spaces :cite:p:`azangulov2024a`.
In the library, such spaces use the class :class:`~.EigenfunctionsWithAdditionTheorem` to represent the spectrum of $\Delta_{\mathcal{M}}$.
For them, the *number of levels* parameter of the :class:`~.kernels.MaternKarhunenLoeveKernel` maps to $L$ in the above formula.

Expand Down
4 changes: 2 additions & 2 deletions docs/theory/symmetric.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
.. warning::
You can get by fine without reading this page for almost all use cases, just use the standard :class:`~.kernels.MaternGeometricKernel`, following the example notebooks :doc:`on hyperbolic spaces </examples/Hyperbolic>` and :doc:`on the space of symmetric positive definite matrices (SPD) </examples/SPD>`.

This is optional material meant to explain the basic theory and based mainly on :cite:t:`azangulov2023`.
This is optional material meant to explain the basic theory and based mainly on :cite:t:`azangulov2024b`.

=======
Theory
Expand All @@ -20,7 +20,7 @@ In the Euclidean case, closed form expressions for kernels are available and ran
No closed form expressions for kernels are usually available on other non-compact symmetric spaces.
Because of that, random Fourier features are the basic means of computing the kernels in this case.

A complete mathematical treatise can be found in :cite:t:`azangulov2023`.
A complete mathematical treatise can be found in :cite:t:`azangulov2024b`.
Here we briefly present the main ideas.
Recall that the usual Euclidean random Fourier features boil down to

Expand Down
3 changes: 2 additions & 1 deletion geometric_kernels/feature_maps/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from beartype.typing import Dict, Optional, Tuple

from geometric_kernels.feature_maps.base import FeatureMap
from geometric_kernels.kernels.karhunen_loeve import MaternKarhunenLoeveKernel
from geometric_kernels.spaces import DiscreteSpectrumSpace


Expand All @@ -25,6 +24,8 @@ class DeterministicFeatureMapCompact(FeatureMap):
"""

def __init__(self, space: DiscreteSpectrumSpace, num_levels: int):
from geometric_kernels.kernels.karhunen_loeve import MaternKarhunenLoeveKernel

self.space = space
self.num_levels = num_levels
self.kernel = MaternKarhunenLoeveKernel(space, num_levels)
Expand Down
4 changes: 2 additions & 2 deletions geometric_kernels/feature_maps/probability_densities.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def _randcat_fix(

def _alphas(n: int) -> B.Numeric:
r"""
Compute alphas for Prop. 16 & 17 of cite:t:`azangulov2023`
Compute alphas for Prop. 16 & 17 of cite:t:`azangulov2024b`
for the hyperbolic space of dimension `n`.

:param n:
Expand Down Expand Up @@ -299,7 +299,7 @@ def _sample_mixture_matern(
) -> Tuple[B.RandomState, B.Numeric]:
r"""
Sample from the mixture distribution from Prop. 17 of
cite:t:`azangulov2023` for specific alphas `alpha`, length
cite:t:`azangulov2024b` for specific alphas `alpha`, length
scale ($\kappa$) `lengthscale`, smoothness `nu` and dimension `dim`, using
`key` random state.

Expand Down
3 changes: 2 additions & 1 deletion geometric_kernels/feature_maps/random_phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

from geometric_kernels.feature_maps.base import FeatureMap
from geometric_kernels.feature_maps.probability_densities import base_density_sample
from geometric_kernels.kernels.karhunen_loeve import MaternKarhunenLoeveKernel
from geometric_kernels.lab_extras import complex_like, from_numpy, is_complex
from geometric_kernels.spaces import DiscreteSpectrumSpace, NoncompactSymmetricSpace

Expand All @@ -43,6 +42,8 @@ def __init__(
num_levels: int,
num_random_phases: int = 3000,
):
from geometric_kernels.kernels.karhunen_loeve import MaternKarhunenLoeveKernel

self.space = space
self.num_levels = num_levels
self.num_random_phases = num_random_phases
Expand Down
17 changes: 12 additions & 5 deletions geometric_kernels/kernels/product.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
assert isinstance(kernel.space, Space)
self.spaces.append(kernel.space)
self.element_shapes = [space.element_shape for space in self.spaces]
self.element_dtypes = [space.element_dtype for space in self.spaces]

if dimension_indices is None:
dimensions = [math.prod(shape) for shape in self.element_shapes]
Expand Down Expand Up @@ -114,9 +115,13 @@ def K(self, params: Dict[str, B.Numeric], X, X2=None, **kwargs) -> B.Numeric:
if X2 is None:
X2 = X

Xs = project_product(X, self.dimension_indices, self.element_shapes)
X2s = project_product(X2, self.dimension_indices, self.element_shapes)
params_list = params_to_params_list(params)
Xs = project_product(
X, self.dimension_indices, self.element_shapes, self.element_dtypes
)
X2s = project_product(
X2, self.dimension_indices, self.element_shapes, self.element_dtypes
)
params_list = params_to_params_list(len(self.kernels), params)

return B.prod(
B.stack(
Expand All @@ -130,8 +135,10 @@ def K(self, params: Dict[str, B.Numeric], X, X2=None, **kwargs) -> B.Numeric:
)

def K_diag(self, params, X):
Xs = project_product(X, self.dimension_indices, self.element_shapes)
params_list = params_to_params_list(params)
Xs = project_product(
X, self.dimension_indices, self.element_shapes, self.element_dtypes
)
params_list = params_to_params_list(len(self.kernels), params)

return B.prod(
B.stack(
Expand Down
34 changes: 34 additions & 0 deletions geometric_kernels/lab_extras/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,3 +343,37 @@ def dtype_bool(reference: B.RandomState):
:param reference:
A random state to infer the backend from.
"""


@dispatch
@abstract()
def bool_like(reference: B.Numeric):
"""
Return the type of the reference if it is of boolean type.
Otherwise return `bool` dtype of a backend based on the reference.

:param reference:
Array of any backend.
"""


def smart_cast(
dtype: Union[B.Bool, B.Int, B.Float, B.Complex, B.Numeric], x: B.Numeric
):
"""
Return `x` cast to the `dtype` abstract data type.

:param dtype:
An abstract DType of lab, one of `B.Bool`, `B.Int`, `B.Float`,
`B.Complex`, `B.Numeric`.
:param x:
Array of any backend.
"""
if dtype == B.Bool:
return B.cast(bool_like(x), x)
elif dtype == B.Int:
return B.cast(int_like(x), x)
elif dtype == B.Float:
return B.cast(float_like(x), x)
elif dtype == B.Complex:
return B.cast(complex_like(x), x)
32 changes: 24 additions & 8 deletions geometric_kernels/lab_extras/jax/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import lab as B
from beartype.typing import List
from lab import dispatch
from plum import Union
from plum import Union, convert

_Numeric = Union[B.Number, B.JAXNumeric]

Expand Down Expand Up @@ -89,7 +89,9 @@ def float_like(reference: B.JAXNumeric):
"""
reference_dtype = reference.dtype
if jnp.issubdtype(reference_dtype, jnp.floating):
return reference_dtype
return convert(
reference_dtype, B.JAXDType
) # JAX .dtype returns a NumPy data type. This converts it to a JAX one.
else:
return jnp.float64

Expand All @@ -106,7 +108,9 @@ def dtype_integer(reference: B.JAXRandomState): # type: ignore
def int_like(reference: B.JAXNumeric):
reference_dtype = reference.dtype
if jnp.issubdtype(reference_dtype, jnp.integer):
return reference_dtype
return convert(
reference_dtype, B.JAXDType
) # JAX .dtype returns a NumPy data type. This converts it to a JAX one.
else:
return jnp.int32

Expand Down Expand Up @@ -155,10 +159,7 @@ def complex_like(reference: B.JAXNumeric):
"""
Return `complex` dtype of a backend based on the reference.
"""
if B.dtype(reference) == jnp.float32:
return jnp.complex64
else:
return jnp.complex128
return B.promote_dtypes(jnp.complex64, reference.dtype)


@dispatch
Expand Down Expand Up @@ -244,4 +245,19 @@ def dtype_bool(reference: B.JAXRandomState): # type: ignore
"""
Return `bool` dtype of a backend based on the reference.
"""
return bool
return jnp.bool_


@dispatch
def bool_like(reference: B.JAXRandomState):
"""
Return the type of the reference if it is of boolean type.
Otherwise return `bool` dtype of a backend based on the reference.
"""
reference_dtype = reference.dtype
if jnp.issubdtype(reference_dtype, jnp.bool_):
return convert(
reference_dtype, B.JAXDType
) # JAX .dtype returns a NumPy data type. This converts it to a JAX one.
else:
return jnp.bool_
20 changes: 15 additions & 5 deletions geometric_kernels/lab_extras/numpy/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,7 @@ def complex_like(reference: B.NPNumeric):
"""
Return `complex` dtype of a backend based on the reference.
"""
if reference.dtype == np.float32:
return np.complex64
else:
return np.complex128
return B.promote_dtypes(np.complex64, reference.dtype)


@dispatch
Expand Down Expand Up @@ -239,4 +236,17 @@ def dtype_bool(reference: B.NPRandomState): # type: ignore
"""
Return `bool` dtype of a backend based on the reference.
"""
return bool
return np.bool_


@dispatch
def bool_like(reference: B.NPNumeric):
"""
Return the type of the reference if it is of boolean type.
Otherwise return `bool` dtype of a backend based on the reference.
"""
reference_dtype = reference.dtype
if np.issubdtype(reference_dtype, np.bool_):
return reference_dtype
else:
return np.bool_
26 changes: 17 additions & 9 deletions geometric_kernels/lab_extras/numpy/sparse_extras.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import sys

import lab as B
import scipy
import scipy.sparse as sp
Expand All @@ -11,15 +13,21 @@
SparseArray defines a lab data type that covers all possible sparse
scipy arrays, so that multiple dispatch works with such arrays.
"""
SparseArray = Union[
sp.bsr_matrix,
sp.coo_matrix,
sp.csc_matrix,
sp.csr_matrix,
sp.dia_matrix,
sp.dok_matrix,
sp.lil_matrix,
]
if sys.version_info[:2] <= (3, 8):
SparseArray = Union[
sp.bsr_matrix,
sp.coo_matrix,
sp.csc_matrix,
sp.csr_matrix,
sp.dia_matrix,
sp.dok_matrix,
sp.lil_matrix,
]
else:
SparseArray = Union[
sp.sparray,
sp.spmatrix,
]


@dispatch
Expand Down
26 changes: 21 additions & 5 deletions geometric_kernels/lab_extras/tensorflow/extras.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import sys

import lab as B
import tensorflow as tf
import tensorflow_probability as tfp
Expand All @@ -13,7 +15,11 @@ def take_along_axis(a: _Numeric, index: _Numeric, axis: int = 0) -> _Numeric: #
"""
Gathers elements of `a` along `axis` at `index` locations.
"""
return tf.gather(a, B.flatten(index), axis=axis)
if sys.version_info[:2] <= (3, 9):
index = tf.cast(index, tf.int32)
return tf.experimental.numpy.take_along_axis(
a, index, axis=axis
) # the absence of explicit cast to int64 causes an error for Python 3.9 and below


@dispatch
Expand Down Expand Up @@ -164,10 +170,7 @@ def complex_like(reference: B.TFNumeric):
"""
Return `complex` dtype of a backend based on the reference.
"""
if B.dtype(reference) == tf.float32:
return tf.complex64
else:
return tf.complex128
return B.promote_dtypes(tf.complex64, reference.dtype)


@dispatch
Expand Down Expand Up @@ -251,3 +254,16 @@ def dtype_bool(reference: B.TFRandomState): # type: ignore
Return `bool` dtype of a backend based on the reference.
"""
return tf.bool


@dispatch
def bool_like(reference: B.NPNumeric):
"""
Return the type of the reference if it is of boolean type.
Otherwise return `bool` dtype of a backend based on the reference.
"""
reference_dtype = reference.dtype
if reference_dtype.is_bool:
return reference_dtype
else:
return tf.bool
Loading