6
6
# pylint: disable = no-name-in-module
7
7
from __future__ import annotations
8
8
9
+ import array_api_extra as xpx
9
10
import numpy as np
11
+
12
+ # from scipy.special import ellipe, ellipk
13
+ from array_api_compat import array_namespace
10
14
from scipy .constants import mu_0 as MU0
11
- from scipy .special import ellipe , ellipk
12
15
13
16
from magpylib ._src .fields .special_cel import cel
17
+ from magpylib ._src .fields .special_elliptic import ellipe , ellipk
14
18
from magpylib ._src .input_checks import check_field_input
15
19
from magpylib ._src .utility import cart_to_cyl_coordinates , cyl_field_to_cart
16
20
@@ -58,22 +62,26 @@ def magnet_cylinder_axial_Bfield(z0: np.ndarray, r: np.ndarray, z: np.ndarray) -
58
62
-----
59
63
Implementation based on Derby, American Journal of Physics 78.3 (2010): 229-235.
60
64
"""
61
- n = len (z0 )
65
+ xp = array_namespace (z0 , r , z )
66
+ n = z0 .shape [0 ]
67
+ z0 = xp .astype (z0 , xp .float64 )
68
+ r = xp .astype (r , xp .float64 )
69
+ z = xp .astype (z , xp .float64 )
62
70
63
71
# some important quantities
64
72
zph , zmh = z + z0 , z - z0
65
73
dpr , dmr = 1 + r , 1 - r
66
74
67
- sq0 = np .sqrt (zmh ** 2 + dpr ** 2 )
68
- sq1 = np .sqrt (zph ** 2 + dpr ** 2 )
75
+ sq0 = xp .sqrt (zmh ** 2 + dpr ** 2 )
76
+ sq1 = xp .sqrt (zph ** 2 + dpr ** 2 )
69
77
70
- k1 = np .sqrt ((zph ** 2 + dmr ** 2 ) / (zph ** 2 + dpr ** 2 ))
71
- k0 = np .sqrt ((zmh ** 2 + dmr ** 2 ) / (zmh ** 2 + dpr ** 2 ))
78
+ k1 = xp .sqrt ((zph ** 2 + dmr ** 2 ) / (zph ** 2 + dpr ** 2 ))
79
+ k0 = xp .sqrt ((zmh ** 2 + dmr ** 2 ) / (zmh ** 2 + dpr ** 2 ))
72
80
gamma = dmr / dpr
73
- one = np .ones (n )
81
+ one = xp .ones (n , dtype = xp . float64 )
74
82
75
83
# radial field (unit polarization)
76
- Br = (cel (k1 , one , one , - one ) / sq1 - cel (k0 , one , one , - one ) / sq0 ) / np .pi
84
+ Br = (cel (k1 , one , one , - one ) / sq1 - cel (k0 , one , one , - one ) / sq0 ) / xp .pi
77
85
78
86
# axial field (unit polarization)
79
87
Bz = (
@@ -83,10 +91,10 @@ def magnet_cylinder_axial_Bfield(z0: np.ndarray, r: np.ndarray, z: np.ndarray) -
83
91
zph * cel (k1 , gamma ** 2 , one , gamma ) / sq1
84
92
- zmh * cel (k0 , gamma ** 2 , one , gamma ) / sq0
85
93
)
86
- / np .pi
94
+ / xp .pi
87
95
)
88
96
89
- return np . vstack ((Br , np .zeros (n ), Bz ))
97
+ return xp . stack ((Br , xp .zeros (n ), Bz ))
90
98
91
99
92
100
# CORE
@@ -145,11 +153,13 @@ def magnet_cylinder_diametral_Hfield(
145
153
(unpublished).
146
154
"""
147
155
# pylint: disable=too-many-statements
156
+ xp = array_namespace (z0 , r , z , phi )
148
157
149
- n = len ( z0 )
158
+ n = z0 . shape [ 0 ]
150
159
151
160
# allocate to treat small r special cases
152
- Hr , Hphi , Hz = np .empty ((3 , n ))
161
+ H = xp .empty ((3 , n ))
162
+ Hr , Hphi , Hz = H [0 , ...], H [1 , ...], H [2 , ...]
153
163
154
164
# compute repeated quantities for all cases
155
165
zp = z + z0
@@ -162,7 +172,7 @@ def magnet_cylinder_diametral_Hfield(
162
172
# case small_r: numerical instability of general solution
163
173
mask_small_r = r < 0.05
164
174
mask_general = ~ mask_small_r
165
- if np .any (mask_small_r ):
175
+ if xp .any (mask_small_r ):
166
176
phiX = phi [mask_small_r ]
167
177
zpX , zmX = zp [mask_small_r ], zm [mask_small_r ]
168
178
zp2X , zm2X = zp2 [mask_small_r ], zm2 [mask_small_r ]
@@ -171,8 +181,8 @@ def magnet_cylinder_diametral_Hfield(
171
181
# taylor series for small r
172
182
zpp = zp2X + 1
173
183
zmm = zm2X + 1
174
- sqrt_p = np .sqrt (zpp )
175
- sqrt_m = np .sqrt (zmm )
184
+ sqrt_p = xp .sqrt (zpp )
185
+ sqrt_m = xp .sqrt (zmm )
176
186
177
187
frac1 = zpX / sqrt_p
178
188
frac2 = zmX / sqrt_m
@@ -189,12 +199,12 @@ def magnet_cylinder_diametral_Hfield(
189
199
* r4X
190
200
)
191
201
192
- Hr [mask_small_r ] = - np .cos (phiX ) / 4 * (term1 + 9 * term2 + 25 * term3 )
202
+ Hr [mask_small_r ] = - xp .cos (phiX ) / 4 * (term1 + 9 * term2 + 25 * term3 )
193
203
194
- Hphi [mask_small_r ] = np .sin (phiX ) / 4 * (term1 + 3 * term2 + 5 * term3 )
204
+ Hphi [mask_small_r ] = xp .sin (phiX ) / 4 * (term1 + 3 * term2 + 5 * term3 )
195
205
196
206
Hz [mask_small_r ] = (
197
- - np .cos (phiX )
207
+ - xp .cos (phiX )
198
208
/ 4
199
209
* (
200
210
rX * (1 / zpp / sqrt_p - 1 / zmm / sqrt_m )
@@ -215,21 +225,21 @@ def magnet_cylinder_diametral_Hfield(
215
225
# if there are small_r, select the general/case variables
216
226
# when there are no small_r cases it is not necessary to slice with [True, True, Tue,...]
217
227
phi = phi [mask_general ]
218
- n = len ( phi )
228
+ n = phi . shape [ 0 ]
219
229
zp , zm = zp [mask_general ], zm [mask_general ]
220
230
zp2 , zm2 = zp2 [mask_general ], zm2 [mask_general ]
221
231
r , r2 = r [mask_general ], r2 [mask_general ]
222
232
223
- if np .any (mask_general ):
233
+ if xp .any (mask_general ):
224
234
rp = r + 1
225
235
rm = r - 1
226
236
rp2 = rp ** 2
227
237
rm2 = rm ** 2
228
238
229
239
ap2 = zp2 + rm ** 2
230
240
am2 = zm2 + rm ** 2
231
- ap = np .sqrt (ap2 )
232
- am = np .sqrt (am2 )
241
+ ap = xp .sqrt (ap2 )
242
+ am = xp .sqrt (am2 )
233
243
234
244
argp = - 4 * r / ap2
235
245
argm = - 4 * r / am2
@@ -238,24 +248,24 @@ def magnet_cylinder_diametral_Hfield(
238
248
# result is numerically stable in the vicinity of of r=r0
239
249
# so only the special case must be caught (not the surroundings)
240
250
mask_special = rm == 0
241
- argc = np .ones (n ) * 1e16 # should be np.Inf but leads to 1/0 problems in cel
251
+ argc = xp .ones (n ) * 1e16 # should be np.Inf but leads to 1/0 problems in cel
242
252
argc [~ mask_special ] = - 4 * r [~ mask_special ] / rm2 [~ mask_special ]
243
253
# special case 1/rm
244
- one_over_rm = np .zeros (n )
254
+ one_over_rm = xp .zeros (n )
245
255
one_over_rm [~ mask_special ] = 1 / rm [~ mask_special ]
246
256
247
257
elle_p = ellipe (argp )
248
258
elle_m = ellipe (argm )
249
259
ellk_p = ellipk (argp )
250
260
ellk_m = ellipk (argm )
251
- onez = np .ones (n )
252
- ellpi_p = cel (np .sqrt (1 - argp ), 1 - argc , onez , onez ) # elliptic_Pi
253
- ellpi_m = cel (np .sqrt (1 - argm ), 1 - argc , onez , onez ) # elliptic_Pi
261
+ onez = xp .ones (n )
262
+ ellpi_p = cel (xp .sqrt (1 - argp ), 1 - argc , onez , onez ) # elliptic_Pi
263
+ ellpi_m = cel (xp .sqrt (1 - argm ), 1 - argc , onez , onez ) # elliptic_Pi
254
264
255
265
# compute fields
256
266
Hr [mask_general ] = (
257
- - np .cos (phi )
258
- / (4 * np .pi * r2 )
267
+ - xp .cos (phi )
268
+ / (4 * xp .pi * r2 )
259
269
* (
260
270
- zm * am * elle_m
261
271
+ zp * ap * elle_p
@@ -266,8 +276,8 @@ def magnet_cylinder_diametral_Hfield(
266
276
)
267
277
268
278
Hphi [mask_general ] = (
269
- np .sin (phi )
270
- / (4 * np .pi * r2 )
279
+ xp .sin (phi )
280
+ / (4 * xp .pi * r2 )
271
281
* (
272
282
+ zm * am * elle_m
273
283
- zp * ap * elle_p
@@ -279,8 +289,8 @@ def magnet_cylinder_diametral_Hfield(
279
289
)
280
290
281
291
Hz [mask_general ] = (
282
- - np .cos (phi )
283
- / (2 * np .pi * r )
292
+ - xp .cos (phi )
293
+ / (2 * xp .pi * r )
284
294
* (
285
295
+ am * elle_m
286
296
- ap * elle_p
@@ -289,7 +299,7 @@ def magnet_cylinder_diametral_Hfield(
289
299
)
290
300
)
291
301
292
- return np . vstack ((Hr , Hphi , Hz ))
302
+ return xp . stack ((Hr , Hphi , Hz ))
293
303
294
304
295
305
def BHJM_magnet_cylinder (
@@ -304,21 +314,26 @@ def BHJM_magnet_cylinder(
304
314
"""
305
315
306
316
check_field_input (field )
317
+ xp = array_namespace (observers , dimension , polarization )
318
+ observers = xp .astype (observers , xp .float64 )
319
+ dimension = xp .astype (dimension , xp .float64 )
320
+ polarization = xp .astype (polarization , xp .float64 )
307
321
308
322
# transform to Cy CS --------------------------------------------
309
323
r , phi , z = cart_to_cyl_coordinates (observers )
310
- r0 , z0 = dimension .T / 2
324
+ dims = dimension .T / 2
325
+ r0 , z0 = dims [0 , ...], dims [1 , ...]
311
326
312
327
# scale invariance (make dimensionless)
313
328
r = r / r0
314
329
z = z / r0
315
330
z0 = z0 / r0
316
331
317
332
# allocate for output
318
- BHJM = polarization .astype (float )
333
+ BHJM = xp .astype (polarization , ( xp . float64 ) )
319
334
320
335
# inside/outside
321
- mask_between_bases = np .abs (z ) <= z0 # in-between top and bottom plane
336
+ mask_between_bases = xp .abs (z ) <= z0 # in-between top and bottom plane
322
337
mask_inside_hull = r <= 1 # inside Cylinder hull plane
323
338
mask_inside = mask_between_bases & mask_inside_hull
324
339
@@ -331,17 +346,23 @@ def BHJM_magnet_cylinder(
331
346
return BHJM / MU0
332
347
333
348
# SPECIAL CASE 1: on Cylinder edge
334
- mask_on_hull = np .isclose (r , 1 , rtol = 1e-15 , atol = 0 ) # on Cylinder hull plane
335
- mask_on_bases = np .isclose (abs (z ), z0 , rtol = 1e-15 , atol = 0 ) # on top or bottom plane
349
+ mask_on_hull = xpx .isclose (r , 1 , rtol = 1e-15 , atol = 0 ) # on Cylinder hull plane
350
+ mask_on_bases = xpx .isclose (
351
+ abs (z ), z0 , rtol = 1e-15 , atol = 0
352
+ ) # on top or bottom plane
336
353
mask_not_on_edge = ~ (mask_on_hull & mask_on_bases )
337
354
338
355
# axial/transv polarization cases
339
- pol_x , pol_y , pol_z = polarization .T
356
+ pol_x , pol_y , pol_z = (
357
+ polarization [..., 0 ],
358
+ polarization [..., 1 ],
359
+ polarization [..., 2 ],
360
+ )
340
361
mask_pol_tv = (pol_x != 0 ) | (pol_y != 0 )
341
362
mask_pol_ax = pol_z != 0
342
363
343
364
# SPECIAL CASE 2: pol = 0
344
- mask_pol_not_null = ~ (( pol_x == 0 ) * (pol_y == 0 ) * (pol_z == 0 ) )
365
+ mask_pol_not_null = ( pol_x != 0 ) | (pol_y != 0 ) | (pol_z != 0 )
345
366
346
367
# general case
347
368
mask_gen = mask_pol_not_null & mask_not_on_edge
@@ -354,9 +375,9 @@ def BHJM_magnet_cylinder(
354
375
BHJM *= 0
355
376
356
377
# transversal polarization contributions -----------------------
357
- if any (mask_pol_tv ):
358
- pol_xy = np .sqrt (pol_x ** 2 + pol_y ** 2 )[mask_pol_tv ]
359
- tetta = np . arctan2 (pol_y [mask_pol_tv ], pol_x [mask_pol_tv ])
378
+ if xp . any (mask_pol_tv ):
379
+ pol_xy = xp .sqrt (pol_x ** 2 + pol_y ** 2 )[mask_pol_tv ]
380
+ tetta = xp . atan2 (pol_y [mask_pol_tv ], pol_x [mask_pol_tv ])
360
381
361
382
BHJM [mask_pol_tv ] = (
362
383
magnet_cylinder_diametral_Hfield (
@@ -369,30 +390,42 @@ def BHJM_magnet_cylinder(
369
390
).T
370
391
371
392
# axial polarization contributions ----------------------------
372
- if any (mask_pol_ax ):
373
- BHJM [mask_pol_ax ] += (
374
- magnet_cylinder_axial_Bfield (
375
- z0 = z0 [mask_pol_ax ],
376
- r = r [mask_pol_ax ],
377
- z = z [mask_pol_ax ],
378
- )
379
- * pol_z [mask_pol_ax ]
380
- ).T
393
+ if xp .any (mask_pol_ax ):
394
+ BHJM [mask_pol_ax ] = (
395
+ BHJM [mask_pol_ax ]
396
+ + (
397
+ magnet_cylinder_axial_Bfield (
398
+ z0 = z0 [mask_pol_ax ],
399
+ r = r [mask_pol_ax ],
400
+ z = z [mask_pol_ax ],
401
+ )
402
+ * pol_z [mask_pol_ax ]
403
+ ).T
404
+ )
381
405
382
406
BHJM [:, 0 ], BHJM [:, 1 ] = cyl_field_to_cart (phi , BHJM [:, 0 ], BHJM [:, 1 ])
383
407
384
408
# add/subtract Mag when inside for B/H
385
409
if field == "B" :
386
- mask_tv_inside = mask_pol_tv * mask_inside
387
- if any (mask_tv_inside ): # tv computes H-field
388
- BHJM [mask_tv_inside , 0 ] += pol_x [mask_tv_inside ]
389
- BHJM [mask_tv_inside , 1 ] += pol_y [mask_tv_inside ]
410
+ mask_tv_inside = mask_pol_tv & mask_inside
411
+ mask_tv_inside = xp .broadcast_to (mask_tv_inside [:, xp .newaxis ], BHJM .shape )
412
+ mask_tv_inside = xpx .at (mask_tv_inside )[:, 2 ].set (False , copy = True )
413
+
414
+ if xp .any (mask_tv_inside ): # tv computes H-field
415
+ BHJM = xpx .at (BHJM )[mask_tv_inside ].set (
416
+ BHJM [mask_tv_inside ] + polarization [mask_tv_inside ]
417
+ )
418
+ # BHJM[:, 1] += pol_y * mask_tv_inside
390
419
return BHJM
391
420
392
421
if field == "H" :
393
- mask_ax_inside = mask_pol_ax * mask_inside
394
- if any (mask_ax_inside ): # ax computes B-field
395
- BHJM [mask_ax_inside , 2 ] -= pol_z [mask_ax_inside ]
422
+ mask_ax_inside = mask_pol_ax & mask_inside
423
+ mask_ax_inside = xp .broadcast_to (mask_ax_inside [:, xp .newaxis ], BHJM .shape )
424
+ mask_ax_inside = xpx .at (mask_ax_inside )[:, :2 ].set (False , copy = True )
425
+ if xp .any (mask_ax_inside ): # ax computes B-field
426
+ BHJM = xpx .at (BHJM )[mask_ax_inside ].set (
427
+ BHJM [mask_ax_inside ] - polarization [mask_ax_inside ]
428
+ )
396
429
return BHJM / MU0
397
430
398
431
msg = f"`output_field_type` must be one of ('B', 'H', 'M', 'J'), got { field !r} "
0 commit comments