Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 5b2ceb5

Browse files
committed
Adds array api support for cylinder
1 parent 84e18d0 commit 5b2ceb5

File tree

4 files changed

+176
-111
lines changed

4 files changed

+176
-111
lines changed

src/magpylib/_src/fields/field_BH_cylinder.py

Lines changed: 93 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,15 @@
66
# pylint: disable = no-name-in-module
77
from __future__ import annotations
88

9+
import array_api_extra as xpx
910
import numpy as np
11+
12+
# from scipy.special import ellipe, ellipk
13+
from array_api_compat import array_namespace
1014
from scipy.constants import mu_0 as MU0
11-
from scipy.special import ellipe, ellipk
1215

1316
from magpylib._src.fields.special_cel import cel
17+
from magpylib._src.fields.special_elliptic import ellipe, ellipk
1418
from magpylib._src.input_checks import check_field_input
1519
from magpylib._src.utility import cart_to_cyl_coordinates, cyl_field_to_cart
1620

@@ -58,22 +62,26 @@ def magnet_cylinder_axial_Bfield(z0: np.ndarray, r: np.ndarray, z: np.ndarray) -
5862
-----
5963
Implementation based on Derby, American Journal of Physics 78.3 (2010): 229-235.
6064
"""
61-
n = len(z0)
65+
xp = array_namespace(z0, r, z)
66+
n = z0.shape[0]
67+
z0 = xp.astype(z0, xp.float64)
68+
r = xp.astype(r, xp.float64)
69+
z = xp.astype(z, xp.float64)
6270

6371
# some important quantities
6472
zph, zmh = z + z0, z - z0
6573
dpr, dmr = 1 + r, 1 - r
6674

67-
sq0 = np.sqrt(zmh**2 + dpr**2)
68-
sq1 = np.sqrt(zph**2 + dpr**2)
75+
sq0 = xp.sqrt(zmh**2 + dpr**2)
76+
sq1 = xp.sqrt(zph**2 + dpr**2)
6977

70-
k1 = np.sqrt((zph**2 + dmr**2) / (zph**2 + dpr**2))
71-
k0 = np.sqrt((zmh**2 + dmr**2) / (zmh**2 + dpr**2))
78+
k1 = xp.sqrt((zph**2 + dmr**2) / (zph**2 + dpr**2))
79+
k0 = xp.sqrt((zmh**2 + dmr**2) / (zmh**2 + dpr**2))
7280
gamma = dmr / dpr
73-
one = np.ones(n)
81+
one = xp.ones(n, dtype=xp.float64)
7482

7583
# radial field (unit polarization)
76-
Br = (cel(k1, one, one, -one) / sq1 - cel(k0, one, one, -one) / sq0) / np.pi
84+
Br = (cel(k1, one, one, -one) / sq1 - cel(k0, one, one, -one) / sq0) / xp.pi
7785

7886
# axial field (unit polarization)
7987
Bz = (
@@ -83,10 +91,10 @@ def magnet_cylinder_axial_Bfield(z0: np.ndarray, r: np.ndarray, z: np.ndarray) -
8391
zph * cel(k1, gamma**2, one, gamma) / sq1
8492
- zmh * cel(k0, gamma**2, one, gamma) / sq0
8593
)
86-
/ np.pi
94+
/ xp.pi
8795
)
8896

89-
return np.vstack((Br, np.zeros(n), Bz))
97+
return xp.stack((Br, xp.zeros(n), Bz))
9098

9199

92100
# CORE
@@ -145,11 +153,13 @@ def magnet_cylinder_diametral_Hfield(
145153
(unpublished).
146154
"""
147155
# pylint: disable=too-many-statements
156+
xp = array_namespace(z0, r, z, phi)
148157

149-
n = len(z0)
158+
n = z0.shape[0]
150159

151160
# allocate to treat small r special cases
152-
Hr, Hphi, Hz = np.empty((3, n))
161+
H = xp.empty((3, n))
162+
Hr, Hphi, Hz = H[0, ...], H[1, ...], H[2, ...]
153163

154164
# compute repeated quantities for all cases
155165
zp = z + z0
@@ -162,7 +172,7 @@ def magnet_cylinder_diametral_Hfield(
162172
# case small_r: numerical instability of general solution
163173
mask_small_r = r < 0.05
164174
mask_general = ~mask_small_r
165-
if np.any(mask_small_r):
175+
if xp.any(mask_small_r):
166176
phiX = phi[mask_small_r]
167177
zpX, zmX = zp[mask_small_r], zm[mask_small_r]
168178
zp2X, zm2X = zp2[mask_small_r], zm2[mask_small_r]
@@ -171,8 +181,8 @@ def magnet_cylinder_diametral_Hfield(
171181
# taylor series for small r
172182
zpp = zp2X + 1
173183
zmm = zm2X + 1
174-
sqrt_p = np.sqrt(zpp)
175-
sqrt_m = np.sqrt(zmm)
184+
sqrt_p = xp.sqrt(zpp)
185+
sqrt_m = xp.sqrt(zmm)
176186

177187
frac1 = zpX / sqrt_p
178188
frac2 = zmX / sqrt_m
@@ -189,12 +199,12 @@ def magnet_cylinder_diametral_Hfield(
189199
* r4X
190200
)
191201

192-
Hr[mask_small_r] = -np.cos(phiX) / 4 * (term1 + 9 * term2 + 25 * term3)
202+
Hr[mask_small_r] = -xp.cos(phiX) / 4 * (term1 + 9 * term2 + 25 * term3)
193203

194-
Hphi[mask_small_r] = np.sin(phiX) / 4 * (term1 + 3 * term2 + 5 * term3)
204+
Hphi[mask_small_r] = xp.sin(phiX) / 4 * (term1 + 3 * term2 + 5 * term3)
195205

196206
Hz[mask_small_r] = (
197-
-np.cos(phiX)
207+
-xp.cos(phiX)
198208
/ 4
199209
* (
200210
rX * (1 / zpp / sqrt_p - 1 / zmm / sqrt_m)
@@ -215,21 +225,21 @@ def magnet_cylinder_diametral_Hfield(
215225
# if there are small_r, select the general/case variables
216226
# when there are no small_r cases it is not necessary to slice with [True, True, Tue,...]
217227
phi = phi[mask_general]
218-
n = len(phi)
228+
n = phi.shape[0]
219229
zp, zm = zp[mask_general], zm[mask_general]
220230
zp2, zm2 = zp2[mask_general], zm2[mask_general]
221231
r, r2 = r[mask_general], r2[mask_general]
222232

223-
if np.any(mask_general):
233+
if xp.any(mask_general):
224234
rp = r + 1
225235
rm = r - 1
226236
rp2 = rp**2
227237
rm2 = rm**2
228238

229239
ap2 = zp2 + rm**2
230240
am2 = zm2 + rm**2
231-
ap = np.sqrt(ap2)
232-
am = np.sqrt(am2)
241+
ap = xp.sqrt(ap2)
242+
am = xp.sqrt(am2)
233243

234244
argp = -4 * r / ap2
235245
argm = -4 * r / am2
@@ -238,24 +248,24 @@ def magnet_cylinder_diametral_Hfield(
238248
# result is numerically stable in the vicinity of of r=r0
239249
# so only the special case must be caught (not the surroundings)
240250
mask_special = rm == 0
241-
argc = np.ones(n) * 1e16 # should be np.Inf but leads to 1/0 problems in cel
251+
argc = xp.ones(n) * 1e16 # should be np.Inf but leads to 1/0 problems in cel
242252
argc[~mask_special] = -4 * r[~mask_special] / rm2[~mask_special]
243253
# special case 1/rm
244-
one_over_rm = np.zeros(n)
254+
one_over_rm = xp.zeros(n)
245255
one_over_rm[~mask_special] = 1 / rm[~mask_special]
246256

247257
elle_p = ellipe(argp)
248258
elle_m = ellipe(argm)
249259
ellk_p = ellipk(argp)
250260
ellk_m = ellipk(argm)
251-
onez = np.ones(n)
252-
ellpi_p = cel(np.sqrt(1 - argp), 1 - argc, onez, onez) # elliptic_Pi
253-
ellpi_m = cel(np.sqrt(1 - argm), 1 - argc, onez, onez) # elliptic_Pi
261+
onez = xp.ones(n)
262+
ellpi_p = cel(xp.sqrt(1 - argp), 1 - argc, onez, onez) # elliptic_Pi
263+
ellpi_m = cel(xp.sqrt(1 - argm), 1 - argc, onez, onez) # elliptic_Pi
254264

255265
# compute fields
256266
Hr[mask_general] = (
257-
-np.cos(phi)
258-
/ (4 * np.pi * r2)
267+
-xp.cos(phi)
268+
/ (4 * xp.pi * r2)
259269
* (
260270
-zm * am * elle_m
261271
+ zp * ap * elle_p
@@ -266,8 +276,8 @@ def magnet_cylinder_diametral_Hfield(
266276
)
267277

268278
Hphi[mask_general] = (
269-
np.sin(phi)
270-
/ (4 * np.pi * r2)
279+
xp.sin(phi)
280+
/ (4 * xp.pi * r2)
271281
* (
272282
+zm * am * elle_m
273283
- zp * ap * elle_p
@@ -279,8 +289,8 @@ def magnet_cylinder_diametral_Hfield(
279289
)
280290

281291
Hz[mask_general] = (
282-
-np.cos(phi)
283-
/ (2 * np.pi * r)
292+
-xp.cos(phi)
293+
/ (2 * xp.pi * r)
284294
* (
285295
+am * elle_m
286296
- ap * elle_p
@@ -289,7 +299,7 @@ def magnet_cylinder_diametral_Hfield(
289299
)
290300
)
291301

292-
return np.vstack((Hr, Hphi, Hz))
302+
return xp.stack((Hr, Hphi, Hz))
293303

294304

295305
def BHJM_magnet_cylinder(
@@ -304,21 +314,26 @@ def BHJM_magnet_cylinder(
304314
"""
305315

306316
check_field_input(field)
317+
xp = array_namespace(observers, dimension, polarization)
318+
observers = xp.astype(observers, xp.float64)
319+
dimension = xp.astype(dimension, xp.float64)
320+
polarization = xp.astype(polarization, xp.float64)
307321

308322
# transform to Cy CS --------------------------------------------
309323
r, phi, z = cart_to_cyl_coordinates(observers)
310-
r0, z0 = dimension.T / 2
324+
dims = dimension.T / 2
325+
r0, z0 = dims[0, ...], dims[1, ...]
311326

312327
# scale invariance (make dimensionless)
313328
r = r / r0
314329
z = z / r0
315330
z0 = z0 / r0
316331

317332
# allocate for output
318-
BHJM = polarization.astype(float)
333+
BHJM = xp.astype(polarization, (xp.float64))
319334

320335
# inside/outside
321-
mask_between_bases = np.abs(z) <= z0 # in-between top and bottom plane
336+
mask_between_bases = xp.abs(z) <= z0 # in-between top and bottom plane
322337
mask_inside_hull = r <= 1 # inside Cylinder hull plane
323338
mask_inside = mask_between_bases & mask_inside_hull
324339

@@ -331,17 +346,23 @@ def BHJM_magnet_cylinder(
331346
return BHJM / MU0
332347

333348
# SPECIAL CASE 1: on Cylinder edge
334-
mask_on_hull = np.isclose(r, 1, rtol=1e-15, atol=0) # on Cylinder hull plane
335-
mask_on_bases = np.isclose(abs(z), z0, rtol=1e-15, atol=0) # on top or bottom plane
349+
mask_on_hull = xpx.isclose(r, 1, rtol=1e-15, atol=0) # on Cylinder hull plane
350+
mask_on_bases = xpx.isclose(
351+
abs(z), z0, rtol=1e-15, atol=0
352+
) # on top or bottom plane
336353
mask_not_on_edge = ~(mask_on_hull & mask_on_bases)
337354

338355
# axial/transv polarization cases
339-
pol_x, pol_y, pol_z = polarization.T
356+
pol_x, pol_y, pol_z = (
357+
polarization[..., 0],
358+
polarization[..., 1],
359+
polarization[..., 2],
360+
)
340361
mask_pol_tv = (pol_x != 0) | (pol_y != 0)
341362
mask_pol_ax = pol_z != 0
342363

343364
# SPECIAL CASE 2: pol = 0
344-
mask_pol_not_null = ~((pol_x == 0) * (pol_y == 0) * (pol_z == 0))
365+
mask_pol_not_null = (pol_x != 0) | (pol_y != 0) | (pol_z != 0)
345366

346367
# general case
347368
mask_gen = mask_pol_not_null & mask_not_on_edge
@@ -354,9 +375,9 @@ def BHJM_magnet_cylinder(
354375
BHJM *= 0
355376

356377
# transversal polarization contributions -----------------------
357-
if any(mask_pol_tv):
358-
pol_xy = np.sqrt(pol_x**2 + pol_y**2)[mask_pol_tv]
359-
tetta = np.arctan2(pol_y[mask_pol_tv], pol_x[mask_pol_tv])
378+
if xp.any(mask_pol_tv):
379+
pol_xy = xp.sqrt(pol_x**2 + pol_y**2)[mask_pol_tv]
380+
tetta = xp.atan2(pol_y[mask_pol_tv], pol_x[mask_pol_tv])
360381

361382
BHJM[mask_pol_tv] = (
362383
magnet_cylinder_diametral_Hfield(
@@ -369,30 +390,42 @@ def BHJM_magnet_cylinder(
369390
).T
370391

371392
# axial polarization contributions ----------------------------
372-
if any(mask_pol_ax):
373-
BHJM[mask_pol_ax] += (
374-
magnet_cylinder_axial_Bfield(
375-
z0=z0[mask_pol_ax],
376-
r=r[mask_pol_ax],
377-
z=z[mask_pol_ax],
378-
)
379-
* pol_z[mask_pol_ax]
380-
).T
393+
if xp.any(mask_pol_ax):
394+
BHJM[mask_pol_ax] = (
395+
BHJM[mask_pol_ax]
396+
+ (
397+
magnet_cylinder_axial_Bfield(
398+
z0=z0[mask_pol_ax],
399+
r=r[mask_pol_ax],
400+
z=z[mask_pol_ax],
401+
)
402+
* pol_z[mask_pol_ax]
403+
).T
404+
)
381405

382406
BHJM[:, 0], BHJM[:, 1] = cyl_field_to_cart(phi, BHJM[:, 0], BHJM[:, 1])
383407

384408
# add/subtract Mag when inside for B/H
385409
if field == "B":
386-
mask_tv_inside = mask_pol_tv * mask_inside
387-
if any(mask_tv_inside): # tv computes H-field
388-
BHJM[mask_tv_inside, 0] += pol_x[mask_tv_inside]
389-
BHJM[mask_tv_inside, 1] += pol_y[mask_tv_inside]
410+
mask_tv_inside = mask_pol_tv & mask_inside
411+
mask_tv_inside = xp.broadcast_to(mask_tv_inside[:, xp.newaxis], BHJM.shape)
412+
mask_tv_inside = xpx.at(mask_tv_inside)[:, 2].set(False, copy=True)
413+
414+
if xp.any(mask_tv_inside): # tv computes H-field
415+
BHJM = xpx.at(BHJM)[mask_tv_inside].set(
416+
BHJM[mask_tv_inside] + polarization[mask_tv_inside]
417+
)
418+
# BHJM[:, 1] += pol_y * mask_tv_inside
390419
return BHJM
391420

392421
if field == "H":
393-
mask_ax_inside = mask_pol_ax * mask_inside
394-
if any(mask_ax_inside): # ax computes B-field
395-
BHJM[mask_ax_inside, 2] -= pol_z[mask_ax_inside]
422+
mask_ax_inside = mask_pol_ax & mask_inside
423+
mask_ax_inside = xp.broadcast_to(mask_ax_inside[:, xp.newaxis], BHJM.shape)
424+
mask_ax_inside = xpx.at(mask_ax_inside)[:, :2].set(False, copy=True)
425+
if xp.any(mask_ax_inside): # ax computes B-field
426+
BHJM = xpx.at(BHJM)[mask_ax_inside].set(
427+
BHJM[mask_ax_inside] - polarization[mask_ax_inside]
428+
)
396429
return BHJM / MU0
397430

398431
msg = f"`output_field_type` must be one of ('B', 'H', 'M', 'J'), got {field!r}"

0 commit comments

Comments
 (0)