Skip to content

Commit 9426d7e

Browse files
fix cupy (#38)
* fix cupy * style(pre-commit.ci): auto fixes [...] * fix cupy random * fix jax and cupy * style(pre-commit.ci): auto fixes [...] * fix lint --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 07554d6 commit 9426d7e

File tree

6 files changed

+41
-10
lines changed

6 files changed

+41
-10
lines changed

src/microsim/psf.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,9 @@ def rz_to_xyz(
157157
xp = NumpyAPI.create(xp)
158158

159159
# Create XY grid of radius values.
160-
rmap = radius_map(xyshape, off) * sf
160+
rmap = radius_map(xyshape, off, xp=xp) * sf
161161
nz = rz.shape[0]
162+
162163
out = xp.asarray(
163164
[
164165
xp.map_coordinates(
@@ -201,7 +202,7 @@ def vectorial_psf(
201202
).astype(xp.float_dtype)
202203

203204
offsets = xp.asarray(pos[:2]) / (dxy * 1e-6)
204-
_psf = rz_to_xyz(rz, (ny, nx), sf, off=offsets) # type: ignore [arg-type]
205+
_psf = rz_to_xyz(rz, (ny, nx), sf, off=offsets, xp=xp) # type: ignore [arg-type]
205206
if normalize:
206207
_psf /= xp.max(_psf)
207208
return _psf

src/microsim/schema/backend.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
import jax
1616

17+
from microsim._data_array import ArrayProtocol
18+
1719
_Shape: TypeAlias = tuple[int, ...]
1820

1921
# Anything that can be coerced to a shape tuple
@@ -47,18 +49,19 @@ def create(cls, backend: BackendName | NumpyAPI | None) -> NumpyAPI:
4749

4850
return NumpyAPI()
4951

52+
_random_seed: int | None = None
53+
_float_dtype: np.dtype | None = None
54+
5055
def __init__(self) -> None:
5156
from scipy import signal, special, stats
5257
from scipy.ndimage import map_coordinates
5358

54-
self._random_seed: int | None = None
5559
self.xp = np
5660
self.signal = signal
5761
self.stats = stats
5862
self.j0 = special.j0
5963
self.j1 = special.j1
6064
self.map_coordinates = map_coordinates
61-
self._float_dtype: np.dtype | None = None
6265

6366
@property
6467
def float_dtype(self) -> np.dtype | None:
@@ -103,6 +106,11 @@ def poisson_rvs(
103106
) -> npt.NDArray:
104107
return self.stats.poisson.rvs(lam, size=shape) # type: ignore
105108

109+
def norm_rvs(
110+
self, loc: ArrayProtocol, scale: npt.ArrayLike | None = None
111+
) -> ArrayProtocol:
112+
return self.stats.norm.rvs(loc, scale) # type: ignore
113+
106114
def fftconvolve(
107115
self, a: ArrT, b: ArrT, mode: Literal["full", "valid", "same"] = "full"
108116
) -> ArrT:
@@ -140,7 +148,6 @@ def __init__(self) -> None:
140148

141149
from ._jax_bessel import j0, j1
142150

143-
self._random_seed: int | None = None
144151
self.xp = jax.numpy
145152
self.signal = signal
146153
self.stats = stats
@@ -173,6 +180,15 @@ def poisson_rvs( # type: ignore
173180

174181
return poisson(self._key, lam, shape=shape)
175182

183+
def norm_rvs(
184+
self, loc: ArrayProtocol, scale: npt.ArrayLike | None = None
185+
) -> ArrayProtocol:
186+
from jax.random import normal
187+
188+
std_samples = normal(self._key, shape=loc.shape)
189+
# scale and shift
190+
return std_samples * scale + loc # type: ignore
191+
176192
def fftconvolve(
177193
self, a: ArrT, b: ArrT, mode: Literal["full", "valid", "same"] = "full"
178194
) -> ArrT:
@@ -211,6 +227,16 @@ def __init__(self) -> None:
211227
self.j1 = special.j1
212228
self.map_coordinates = map_coordinates
213229

230+
def poisson_rvs(
231+
self, lam: npt.ArrayLike, shape: Sequence[int] | None = None
232+
) -> npt.NDArray:
233+
return self.xp.random.poisson(lam, shape) # type: ignore
234+
235+
def norm_rvs(
236+
self, loc: ArrayProtocol, scale: npt.ArrayLike | None = None
237+
) -> ArrayProtocol:
238+
return self.xp.random.normal(loc, scale) # type: ignore
239+
214240
def fftconvolve(
215241
self, a: ArrT, b: ArrT, mode: Literal["full", "valid", "same"] = "full"
216242
) -> ArrT:

src/microsim/schema/detectors/_camera.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,11 @@ def adc_gain(self) -> float:
8181
def max_intensity(self) -> int:
8282
return int(2**self.bit_depth - 1)
8383

84-
def quantize_electrons(self, total_electrons: npt.NDArray) -> npt.NDArray:
85-
voltage = stats.norm.rvs(total_electrons, self.read_noise) * self.gain
86-
return np.round((voltage / self.adc_gain) + self.offset) # type: ignore
84+
def quantize_electrons(
85+
self, total_electrons: npt.NDArray, xp: NumpyAPI
86+
) -> npt.NDArray:
87+
voltage = xp.norm_rvs(total_electrons, self.read_noise) * self.gain
88+
return xp.round((voltage / self.adc_gain) + self.offset) # type: ignore
8789

8890

8991
class CameraCCD(Camera): ...

src/microsim/schema/detectors/_simulate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def simulate_camera(
7474
total_electrons = camera.apply_em_gain(total_electrons)
7575

7676
# model read noise
77-
gray_values = camera.quantize_electrons(total_electrons)
77+
gray_values = camera.quantize_electrons(total_electrons, xp)
7878

7979
# sCMOS binning
8080
if binning > 1 and isinstance(camera, CameraCMOS):

src/microsim/util.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,8 @@ def ortho_plot(
220220
import matplotlib.pyplot as plt
221221
from matplotlib.colors import PowerNorm
222222

223+
if hasattr(img, "get"):
224+
img = img.get()
223225
img = np.asarray(img)
224226
"""Plot XY and XZ slices of a 3D array."""
225227
_, ax = plt.subplots(ncols=2, figsize=(10, 5))

tests/test_simulation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def test_schema(
6767
assert type(out1.data).__module__.split(".")[0].startswith(np_backend)
6868

6969
out2 = sim1.run()
70-
if seed is None:
70+
if seed is None and np_backend != "jax":
7171
assert not np.allclose(out1, out2)
7272
else:
7373
np.testing.assert_allclose(out1, out2)

0 commit comments

Comments
 (0)