11from __future__ import annotations
22
3+ from functools import cache
34from typing import TYPE_CHECKING
45
56import numpy as np
@@ -38,7 +39,7 @@ def simpson(
3839 ni2sin2theta = objective .ni ** 2 * sintheta ** 2
3940 nsroot = xp .sqrt (objective .ns ** 2 - ni2sin2theta )
4041 ngroot = xp .sqrt (objective .ng ** 2 - ni2sin2theta )
41- _z = zv [:, xp .newaxis , xp .newaxis ] if zv . ndim else zv
42+ _z = xp . asarray ( zv if np . isscalar ( zv ) else zv [:, xp .newaxis , xp .newaxis ])
4243 L0 = (
4344 objective .ni * (ci - _z ) * costheta
4445 + zp * nsroot
@@ -85,8 +86,9 @@ def _cast_objective(objective: ObjectiveKwargs | ObjectiveLens | None) -> Object
8586 raise TypeError (f"Expected ObjectiveLens, got { type (objective )} " )
8687
8788
89+ @cache
8890def vectorial_rz (
89- zv : npt . NDArray ,
91+ zv : Sequence [ float ] ,
9092 nx : int = 51 ,
9193 pos : tuple [float , float , float ] = (0 , 0 , 0 ),
9294 dxy : float = 0.04 ,
@@ -123,7 +125,9 @@ def vectorial_rz(
123125
124126 step = p .half_angle / nSamples
125127 theta = xp .arange (1 , nSamples + 1 ) * step
126- simpson_integral = simpson (p , theta , constJ , zv , ci , zpos , wave_num , xp = xp )
128+ simpson_integral = simpson (
129+ p , theta , constJ , xp .asarray (zv ), ci , zpos , wave_num , xp = xp
130+ )
127131 return 8.0 * np .pi / 3.0 * simpson_integral * (step / ud ) ** 2
128132
129133
@@ -177,8 +181,9 @@ def rz_to_xyz(
177181# return o.reshape((nx, ny, nz)).T
178182
179183
184+ @cache
180185def vectorial_psf (
181- zv : npt . NDArray ,
186+ zv : Sequence [ float ] ,
182187 nx : int = 31 ,
183188 ny : int | None = None ,
184189 pos : tuple [float , float , float ] = (0 , 0 , 0 ),
@@ -190,7 +195,7 @@ def vectorial_psf(
190195 xp : NumpyAPI | None = None ,
191196) -> npt .NDArray :
192197 xp = NumpyAPI .create (xp )
193- zv = xp .asarray (zv * 1e-6 ) # convert to meters
198+ zv = tuple ( np .asarray (zv ) * 1e-6 ) # convert to meters
194199 ny = ny or nx
195200 rz = vectorial_rz (
196201 zv , np .maximum (ny , nx ), pos , dxy , wvl , objective = objective , sf = sf , xp = xp
@@ -203,9 +208,9 @@ def vectorial_psf(
203208 return _psf
204209
205210
206- def _centered_zv (nz : int , dz : float , pz : float = 0 ) -> npt . NDArray :
211+ def _centered_zv (nz : int , dz : float , pz : float = 0 ) -> tuple [ float , ...] :
207212 lim = (nz - 1 ) * dz / 2
208- return np .linspace (- lim + pz , lim + pz , nz )
213+ return tuple ( np .linspace (- lim + pz , lim + pz , nz ) )
209214
210215
211216def vectorial_psf_centered (
0 commit comments