Skip to content

Commit ac92442

Browse files
authored
feat: cache psf (#26)
1 parent f373bc6 commit ac92442

2 files changed

Lines changed: 28 additions & 7 deletions

File tree

src/microsim/psf.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from functools import cache
34
from typing import TYPE_CHECKING
45

56
import numpy as np
@@ -38,7 +39,7 @@ def simpson(
3839
ni2sin2theta = objective.ni**2 * sintheta**2
3940
nsroot = xp.sqrt(objective.ns**2 - ni2sin2theta)
4041
ngroot = xp.sqrt(objective.ng**2 - ni2sin2theta)
41-
_z = zv[:, xp.newaxis, xp.newaxis] if zv.ndim else zv
42+
_z = xp.asarray(zv if np.isscalar(zv) else zv[:, xp.newaxis, xp.newaxis])
4243
L0 = (
4344
objective.ni * (ci - _z) * costheta
4445
+ zp * nsroot
@@ -85,8 +86,9 @@ def _cast_objective(objective: ObjectiveKwargs | ObjectiveLens | None) -> Object
8586
raise TypeError(f"Expected ObjectiveLens, got {type(objective)}")
8687

8788

89+
@cache
8890
def vectorial_rz(
89-
zv: npt.NDArray,
91+
zv: Sequence[float],
9092
nx: int = 51,
9193
pos: tuple[float, float, float] = (0, 0, 0),
9294
dxy: float = 0.04,
@@ -123,7 +125,9 @@ def vectorial_rz(
123125

124126
step = p.half_angle / nSamples
125127
theta = xp.arange(1, nSamples + 1) * step
126-
simpson_integral = simpson(p, theta, constJ, zv, ci, zpos, wave_num, xp=xp)
128+
simpson_integral = simpson(
129+
p, theta, constJ, xp.asarray(zv), ci, zpos, wave_num, xp=xp
130+
)
127131
return 8.0 * np.pi / 3.0 * simpson_integral * (step / ud) ** 2
128132

129133

@@ -177,8 +181,9 @@ def rz_to_xyz(
177181
# return o.reshape((nx, ny, nz)).T
178182

179183

184+
@cache
180185
def vectorial_psf(
181-
zv: npt.NDArray,
186+
zv: Sequence[float],
182187
nx: int = 31,
183188
ny: int | None = None,
184189
pos: tuple[float, float, float] = (0, 0, 0),
@@ -190,7 +195,7 @@ def vectorial_psf(
190195
xp: NumpyAPI | None = None,
191196
) -> npt.NDArray:
192197
xp = NumpyAPI.create(xp)
193-
zv = xp.asarray(zv * 1e-6) # convert to meters
198+
zv = tuple(np.asarray(zv) * 1e-6) # convert to meters
194199
ny = ny or nx
195200
rz = vectorial_rz(
196201
zv, np.maximum(ny, nx), pos, dxy, wvl, objective=objective, sf=sf, xp=xp
@@ -203,9 +208,9 @@ def vectorial_psf(
203208
return _psf
204209

205210

206-
def _centered_zv(nz: int, dz: float, pz: float = 0) -> npt.NDArray:
211+
def _centered_zv(nz: int, dz: float, pz: float = 0) -> tuple[float, ...]:
207212
lim = (nz - 1) * dz / 2
208-
return np.linspace(-lim + pz, lim + pz, nz)
213+
return tuple(np.linspace(-lim + pz, lim + pz, nz))
209214

210215

211216
def vectorial_psf_centered(

src/microsim/schema/lens.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,22 @@ class ObjectiveLens(SimBaseModel):
3232

3333
magnification: float = Field(1, description="magnification of objective lens.")
3434

35+
def __hash__(self) -> int:
36+
return hash(
37+
(
38+
self.numerical_aperture,
39+
self.coverslip_ri,
40+
self.coverslip_ri_spec,
41+
self.immersion_medium_ri,
42+
self.immersion_medium_ri_spec,
43+
self.specimen_ri,
44+
self.working_distance,
45+
self.coverslip_thickness,
46+
self.coverslip_thickness_spec,
47+
self.magnification,
48+
)
49+
)
50+
3551
@model_validator(mode="before")
3652
def _vroot(cls, values: Any) -> Any:
3753
if isinstance(values, dict):

0 commit comments

Comments
 (0)