|
14 | 14 |
|
15 | 15 | import jax |
16 | 16 |
|
| 17 | + from microsim._data_array import ArrayProtocol |
| 18 | + |
17 | 19 | _Shape: TypeAlias = tuple[int, ...] |
18 | 20 |
|
19 | 21 | # Anything that can be coerced to a shape tuple |
@@ -47,18 +49,19 @@ def create(cls, backend: BackendName | NumpyAPI | None) -> NumpyAPI: |
47 | 49 |
|
48 | 50 | return NumpyAPI() |
49 | 51 |
|
| 52 | + _random_seed: int | None = None |
| 53 | + _float_dtype: np.dtype | None = None |
| 54 | + |
50 | 55 | def __init__(self) -> None: |
51 | 56 | from scipy import signal, special, stats |
52 | 57 | from scipy.ndimage import map_coordinates |
53 | 58 |
|
54 | | - self._random_seed: int | None = None |
55 | 59 | self.xp = np |
56 | 60 | self.signal = signal |
57 | 61 | self.stats = stats |
58 | 62 | self.j0 = special.j0 |
59 | 63 | self.j1 = special.j1 |
60 | 64 | self.map_coordinates = map_coordinates |
61 | | - self._float_dtype: np.dtype | None = None |
62 | 65 |
|
63 | 66 | @property |
64 | 67 | def float_dtype(self) -> np.dtype | None: |
@@ -103,6 +106,11 @@ def poisson_rvs( |
103 | 106 | ) -> npt.NDArray: |
104 | 107 | return self.stats.poisson.rvs(lam, size=shape) # type: ignore |
105 | 108 |
|
| 109 | + def norm_rvs( |
| 110 | + self, loc: ArrayProtocol, scale: npt.ArrayLike | None = None |
| 111 | + ) -> ArrayProtocol: |
| 112 | + return self.stats.norm.rvs(loc, scale) # type: ignore |
| 113 | + |
106 | 114 | def fftconvolve( |
107 | 115 | self, a: ArrT, b: ArrT, mode: Literal["full", "valid", "same"] = "full" |
108 | 116 | ) -> ArrT: |
@@ -140,7 +148,6 @@ def __init__(self) -> None: |
140 | 148 |
|
141 | 149 | from ._jax_bessel import j0, j1 |
142 | 150 |
|
143 | | - self._random_seed: int | None = None |
144 | 151 | self.xp = jax.numpy |
145 | 152 | self.signal = signal |
146 | 153 | self.stats = stats |
@@ -173,6 +180,15 @@ def poisson_rvs( # type: ignore |
173 | 180 |
|
174 | 181 | return poisson(self._key, lam, shape=shape) |
175 | 182 |
|
| 183 | + def norm_rvs( |
| 184 | + self, loc: ArrayProtocol, scale: npt.ArrayLike | None = None |
| 185 | + ) -> ArrayProtocol: |
| 186 | + from jax.random import normal |
| 187 | + |
| 188 | + std_samples = normal(self._key, shape=loc.shape) |
| 189 | + # scale and shift |
| 190 | + return std_samples * scale + loc # type: ignore |
| 191 | + |
176 | 192 | def fftconvolve( |
177 | 193 | self, a: ArrT, b: ArrT, mode: Literal["full", "valid", "same"] = "full" |
178 | 194 | ) -> ArrT: |
@@ -211,6 +227,16 @@ def __init__(self) -> None: |
211 | 227 | self.j1 = special.j1 |
212 | 228 | self.map_coordinates = map_coordinates |
213 | 229 |
|
| 230 | + def poisson_rvs( |
| 231 | + self, lam: npt.ArrayLike, shape: Sequence[int] | None = None |
| 232 | + ) -> npt.NDArray: |
| 233 | + return self.xp.random.poisson(lam, shape) # type: ignore |
| 234 | + |
| 235 | + def norm_rvs( |
| 236 | + self, loc: ArrayProtocol, scale: npt.ArrayLike | None = None |
| 237 | + ) -> ArrayProtocol: |
| 238 | + return self.xp.random.normal(loc, scale) # type: ignore |
| 239 | + |
214 | 240 | def fftconvolve( |
215 | 241 | self, a: ArrT, b: ArrT, mode: Literal["full", "valid", "same"] = "full" |
216 | 242 | ) -> ArrT: |
|
0 commit comments