diff --git a/CHANGELOG.md b/CHANGELOG.md index 2703d0cdd..65a38d1c2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,8 @@ # Changelog +## [Unreleased] - YYYY-MM-DD +- Added `units_length` input to the `show` function to allow displaying axes with different length units. This parameter can be set individually for each subplot. ([#786](https://github.com/magpylib/magpylib/pull/786)) + ## [5.0.4] - 2024-06-18 - Add support for Numpy 2.0 ([#795](https://github.com/magpylib/magpylib/pull/789)) - Fix markers legend not being suppressible ([#795](https://github.com/magpylib/magpylib/pull/789)) diff --git a/README.md b/README.md index 485c0f9fb..e53873454 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ Conda Cloud - MyBinder link + MyBinder link black @@ -136,7 +136,7 @@ A valid software citation could be author = {{Michael-Ortner et al.}}, title = {magpylib}, url = {https://magpylib.readthedocs.io/en/latest/}, - version = {5.0.4}, + version = {5.1.0dev}, date = {2023-06-25}, } ``` diff --git a/docs/_pages/user_guide/docs/docs_graphics.md b/docs/_pages/user_guide/docs/docs_graphics.md index 4ab213db8..6cce8f201 100644 --- a/docs/_pages/user_guide/docs/docs_graphics.md +++ b/docs/_pages/user_guide/docs/docs_graphics.md @@ -4,9 +4,9 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.0 + jupytext_version: 1.16.1 kernelspec: - display_name: Python 3 + display_name: Python 3 (ipykernel) language: python name: python3 orphan: true @@ -422,3 +422,23 @@ with magpy.show_context(loop, sens, animation=True) as sc: sc.show(output=["Hx", "Hy", "Hz"], row=2) sc.show(output="Hxyz", col=2, row=2) ``` + +### Canvas length units + +When displaying very small Magpylib objects, the axes scaling in meters might be inadequate and you may want to use other units that fit the system dimensions more nicely. The example below shows how to display an object (in this case the same) with different length units and zoom levels. + +```{tip} +Setting `units_length="auto"` will infer the most suitable units based on the maximum range of the system. +``` + +```{code-cell} ipython3 +import magpylib as magpy + +c1 = magpy.magnet.Cuboid(dimension=(0.001, 0.001, 0.001), polarization=(1, 2, 3)) + +with magpy.show_context(c1, backend="matplotlib") as s: + s.show(row=1, col=1, units_length="auto", zoom=0) + s.show(row=1, col=2, units_length="mm", zoom=1) + s.show(row=2, col=1, units_length="µm", zoom=2) + s.show(row=2, col=2, units_length="m", zoom=3) +``` diff --git a/docs/_pages/user_guide/examples/examples_vis_pv_streamlines.md b/docs/_pages/user_guide/examples/examples_vis_pv_streamlines.md index ac1e53b3d..8fc2fd30a 100644 --- a/docs/_pages/user_guide/examples/examples_vis_pv_streamlines.md +++ b/docs/_pages/user_guide/examples/examples_vis_pv_streamlines.md @@ -48,8 +48,8 @@ strl = grid.streamlines_from_source( # Create a Pyvista plotting scene pl = pv.Plotter() -# Add magnet to scene -magpy.show(magnet, canvas=pl, backend="pyvista") +# Add magnet to scene - streamlines units are assumed to be meters +magpy.show(magnet, canvas=pl, units_length="m", backend="pyvista") # Prepare legend parameters legend_args = { diff --git a/magpylib/__init__.py b/magpylib/__init__.py index f4d092c57..e04e76c49 100644 --- a/magpylib/__init__.py +++ b/magpylib/__init__.py @@ -28,7 +28,7 @@ """ # module level dunders -__version__ = "5.0.4" +__version__ = "5.1.0dev" __author__ = "Michael Ortner & Alexandre Boisselet" __credits__ = "The Magpylib community" __all__ = [ diff --git a/magpylib/_src/display/backend_matplotlib.py b/magpylib/_src/display/backend_matplotlib.py index 3aecfbdcb..206de393a 100644 --- a/magpylib/_src/display/backend_matplotlib.py +++ b/magpylib/_src/display/backend_matplotlib.py @@ -249,6 +249,7 @@ def display_matplotlib( """Display objects and paths graphically using the matplotlib library.""" frames = data["frames"] ranges = data["ranges"] + labels = data["labels"] fig_kwargs = {} if not fig_kwargs else fig_kwargs fig_kwargs = {"dpi": 80, **fig_kwargs} @@ -352,10 +353,11 @@ def draw_frame(frame_ind): for row_col_num, ax in axes.items(): count = count_with_labels.get(row_col_num, 0) if ax.name == "3d": - ax.set( - **{f"{k}label": f"{k} (m)" for k in "xyz"}, - **{f"{k}lim": r for k, r in zip("xyz", ranges)}, - ) + if row_col_num in ranges: + ax.set( + **{f"{k}label": labels[row_col_num][k] for k in "xyz"}, + **{f"{k}lim": r for k, r in zip("xyz", ranges[row_col_num])}, + ) ax.set_box_aspect(aspect=(1, 1, 1)) if 0 < count <= legend_maxitems: lg_kw = {"bbox_to_anchor": (1.04, 1), "loc": "upper left"} diff --git a/magpylib/_src/display/backend_plotly.py b/magpylib/_src/display/backend_plotly.py index 56efb9a97..263c0ce82 100644 --- a/magpylib/_src/display/backend_plotly.py +++ b/magpylib/_src/display/backend_plotly.py @@ -75,15 +75,17 @@ def match_args(ttype: str): return set(named_args) -def apply_fig_ranges(fig, ranges, apply2d=True): +def apply_fig_ranges(fig, ranges_rc, labels_rc, apply2d=True): """This is a helper function which applies the ranges properties of the provided `fig` object - according to a provided ranges. All three space direction will be equal and match the - maximum of the ranges needed to display all objects, including their paths. + according to a provided ranges for each subplot. All three space direction will be equal and + match the maximum of the ranges needed to display all objects, including their paths. Parameters ---------- - ranges: array of dimension=(3,2) + ranges_rc: dict of arrays of dimension=(3,2) min and max graph range + labels_rc: dict of dicts + contains a dict with 'x', 'y', 'z' keys and respective labels as strings for each subplot apply2d: bool, default = True applies fixed range also on 2d traces @@ -92,15 +94,27 @@ def apply_fig_ranges(fig, ranges, apply2d=True): ------- None: NoneType """ - fig.update_scenes( - **{ - f"{k}axis": {"range": ranges[i], "autorange": False, "title": f"{k} (m)"} - for i, k in enumerate("xyz") - }, - aspectratio={k: 1 for k in "xyz"}, - aspectmode="manual", - camera_eye={"x": 1, "y": -1.5, "z": 1.4}, - ) + for rc, ranges in ranges_rc.items(): + row, col = rc + labels = labels_rc.get(rc, {k: "" for k in "xyz"}) + kwargs = { + **{ + f"{k}axis": { + "range": ranges[i], + "autorange": False, + "title": labels[k], + } + for i, k in enumerate("xyz") + }, + "aspectratio": {k: 1 for k in "xyz"}, + "aspectmode": "manual", + "camera_eye": {"x": 1, "y": -1.5, "z": 1.4}, + } + + # pylint: disable=protected-access + if fig._grid_ref is not None: + kwargs.update({"row": row, "col": col}) + fig.update_scenes(**kwargs) if apply2d: apply_2d_ranges(fig) @@ -274,7 +288,6 @@ def process_extra_trace(model): def display_plotly( data, - zoom=1, canvas=None, renderer=None, return_fig=False, @@ -345,11 +358,13 @@ def display_plotly( rows=rows_list, cols=cols_list, ) - ranges = data["ranges"] + ranges_rc = data["ranges"] if extra_data: - ranges = get_scene_ranges(*frames[0]["data"], zoom=zoom) + ranges_rc = get_scene_ranges(*frames[0]["data"]) if update_layout: - apply_fig_ranges(fig, ranges, apply2d=isanimation) + apply_fig_ranges( + fig, ranges_rc, labels_rc=data["labels"], apply2d=isanimation + ) fig.update_layout( legend_itemsizing="constant", # legend_groupclick="toggleitem", diff --git a/magpylib/_src/display/display.py b/magpylib/_src/display/display.py index 946144312..e3f7445d3 100644 --- a/magpylib/_src/display/display.py +++ b/magpylib/_src/display/display.py @@ -11,14 +11,15 @@ from magpylib._src.defaults.defaults_utility import get_defaults_dict from magpylib._src.display.traces_generic import MagpyMarkers from magpylib._src.display.traces_generic import get_frames +from magpylib._src.display.traces_utility import DEFAULT_ROW_COL_PARAMS +from magpylib._src.display.traces_utility import linearize_dict from magpylib._src.display.traces_utility import process_show_input_objs from magpylib._src.input_checks import check_format_input_backend from magpylib._src.input_checks import check_format_input_vector from magpylib._src.input_checks import check_input_animation -from magpylib._src.input_checks import check_input_zoom from magpylib._src.utility import check_path_format -disp_args = get_defaults_dict("display").keys() +disp_args = set(get_defaults_dict("display")) class RegisteredBackend: @@ -30,14 +31,14 @@ def __init__( self, *, name, - show_func_getter, + show_func, supports_animation, supports_subplots, supports_colorgradient, supports_animation_output, ): self.name = name - self.show_func_getter = show_func_getter + self.show_func = show_func self.supports = { "animation": supports_animation, "subplots": supports_subplots, @@ -54,7 +55,6 @@ def show( cls, *objs, backend, - zoom=0, title=None, max_rows=None, max_cols=None, @@ -83,12 +83,18 @@ def show( f"\nFalling back to: {params}" ) kwargs.update(params) - frame_kwargs = { + display_kwargs = { k: v for k, v in kwargs.items() - if any(k.startswith(arg) for arg in disp_args) + if any(k.startswith(arg) for arg in disp_args - {"style"}) + } + style_kwargs = {k: v for k, v in kwargs.items() if k.startswith("style")} + style_kwargs = linearize_dict(style_kwargs, separator="_") + kwargs = { + k: v + for k, v in kwargs.items() + if (k not in display_kwargs and k not in style_kwargs) } - kwargs = {k: v for k, v in kwargs.items() if k not in frame_kwargs} backend_kwargs = { k[len(backend) + 1 :]: v for k, v in kwargs.items() @@ -117,13 +123,12 @@ def show( objs, supports_colorgradient=self.supports["colorgradient"], backend=backend, - zoom=zoom, title=title, - **frame_kwargs, + style_kwargs=style_kwargs, + **display_kwargs, ) - return self.show_func_getter()( + return self.show_func( data, - zoom=zoom, max_rows=max_rows, max_cols=max_cols, subplot_specs=subplot_specs, @@ -136,12 +141,9 @@ def show( def get_show_func(backend): """Return the backend show function""" # defer import to show call. Importerror should only fail if unavalaible backend is called - return lambda: getattr( + return lambda *args, backend=backend, **kwargs: getattr( import_module(f"magpylib._src.display.backend_{backend}"), f"display_{backend}" - ) - - -ROW_COL_SPECIFIC_NAMES = ("row", "col", "output", "sumup", "pixel_agg", "in_out") + )(*args, **kwargs) def infer_backend(canvas): @@ -182,7 +184,6 @@ def _show( *objects, backend=None, animation=False, - zoom=0, markers=None, **kwargs, ): @@ -193,9 +194,10 @@ def _show( # process input objs objects, obj_list_flat, max_rows, max_cols, subplot_specs = process_show_input_objs( - objects, **{k: v for k, v in kwargs.items() if k in ROW_COL_SPECIFIC_NAMES} + objects, + **{k: v for k, v in kwargs.items() if k in DEFAULT_ROW_COL_PARAMS}, ) - kwargs = {k: v for k, v in kwargs.items() if k not in ROW_COL_SPECIFIC_NAMES} + kwargs = {k: v for k, v in kwargs.items() if k not in DEFAULT_ROW_COL_PARAMS} kwargs["max_rows"], kwargs["max_cols"] = max_rows, max_cols kwargs["subplot_specs"] = subplot_specs @@ -204,7 +206,6 @@ def _show( # input checks backend = check_format_input_backend(backend) - check_input_zoom(zoom) check_input_animation(animation) check_format_input_vector( markers, @@ -216,15 +217,7 @@ def _show( ) if markers: - objects = [ - *objects, - { - "objects": [MagpyMarkers(*markers)], - "row": 1, - "col": 1, - "output": "model3d", - }, - ] + objects.append({"objects": [MagpyMarkers(*markers)], **DEFAULT_ROW_COL_PARAMS}) if backend == "auto": backend = infer_backend(kwargs.get("canvas", None)) @@ -232,7 +225,6 @@ def _show( return RegisteredBackend.show( backend=backend, *objects, - zoom=zoom, animation=animation, **kwargs, ) @@ -407,9 +399,9 @@ def show( } ) if ctx.isrunning: - rco = {k: v for k, v in kwargs.items() if k in ROW_COL_SPECIFIC_NAMES} + rco = {k: v for k, v in kwargs.items() if k in DEFAULT_ROW_COL_PARAMS} ctx.kwargs.update( - {k: v for k, v in kwargs.items() if k not in ROW_COL_SPECIFIC_NAMES} + {k: v for k, v in kwargs.items() if k not in DEFAULT_ROW_COL_PARAMS} ) ctx_objects = tuple({**o, **rco} for o in ctx.objects_from_ctx) objects, *_ = process_show_input_objs(ctx_objects + objects, **rco) @@ -452,11 +444,11 @@ def show_context( ) try: ctx.isrunning = True - rco = {k: v for k, v in kwargs.items() if k in ROW_COL_SPECIFIC_NAMES} + rco = {k: v for k, v in kwargs.items() if k in DEFAULT_ROW_COL_PARAMS} objects, *_ = process_show_input_objs(objects, **rco) ctx.objects_from_ctx += tuple(objects) ctx.kwargs.update( - {k: v for k, v in kwargs.items() if k not in ROW_COL_SPECIFIC_NAMES} + {k: v for k, v in kwargs.items() if k not in DEFAULT_ROW_COL_PARAMS} ) yield ctx ctx.show_return_value = _show(*ctx.objects, **ctx.kwargs) @@ -491,7 +483,7 @@ def reset(self, reset_show_return_value=True): RegisteredBackend( name="matplotlib", - show_func_getter=get_show_func("matplotlib"), + show_func=get_show_func("matplotlib"), supports_animation=True, supports_subplots=True, supports_colorgradient=False, @@ -501,7 +493,7 @@ def reset(self, reset_show_return_value=True): RegisteredBackend( name="plotly", - show_func_getter=get_show_func("plotly"), + show_func=get_show_func("plotly"), supports_animation=True, supports_subplots=True, supports_colorgradient=True, @@ -510,7 +502,7 @@ def reset(self, reset_show_return_value=True): RegisteredBackend( name="pyvista", - show_func_getter=get_show_func("pyvista"), + show_func=get_show_func("pyvista"), supports_animation=True, supports_subplots=True, supports_colorgradient=True, diff --git a/magpylib/_src/display/traces_generic.py b/magpylib/_src/display/traces_generic.py index b1d372297..37b1f1b36 100644 --- a/magpylib/_src/display/traces_generic.py +++ b/magpylib/_src/display/traces_generic.py @@ -20,17 +20,21 @@ from magpylib._src.defaults.defaults_utility import ALLOWED_SYMBOLS from magpylib._src.defaults.defaults_utility import linearize_dict from magpylib._src.display.traces_utility import draw_arrowed_line -from magpylib._src.display.traces_utility import get_flatten_objects_properties from magpylib._src.display.traces_utility import get_legend_label +from magpylib._src.display.traces_utility import get_objects_props_by_row_col from magpylib._src.display.traces_utility import get_rot_pos_from_path from magpylib._src.display.traces_utility import get_scene_ranges from magpylib._src.display.traces_utility import getColorscale from magpylib._src.display.traces_utility import getIntensity from magpylib._src.display.traces_utility import group_traces from magpylib._src.display.traces_utility import place_and_orient_model3d +from magpylib._src.display.traces_utility import rescale_traces from magpylib._src.display.traces_utility import slice_mesh_from_colorscale from magpylib._src.style import DefaultMarkers from magpylib._src.utility import format_obj_input +from magpylib._src.utility import get_unit_factor +from magpylib._src.utility import style_temp_edit +from magpylib._src.utility import unit_prefix class MagpyMarkers: @@ -202,13 +206,12 @@ def get_trace2D_dict( field_str, coords_str, obj_lst_str, - frame_focus_inds, + focus_inds, frames_indices, mode, label_suff, - color, - symbol, - linestyle, + units_polarization, + units_magnetization, **kwargs, ): """return a 2d trace based on field and parameters""" @@ -218,9 +221,14 @@ def get_trace2D_dict( y = y[0] else: y = np.linalg.norm(y, axis=0) - marker_size = np.array([2] * len(frames_indices)) - marker_size[frame_focus_inds] = 15 + marker_size = np.array([3] * len(frames_indices)) + marker_size[focus_inds] = 15 title = f"{field_str}{''.join(coords_str)}" + unit = ( + units_polarization + if field_str in "BJ" + else units_magnetization if field_str in "HM" else "" + ) trace = { "mode": "lines+markers", "legendgrouptitle_text": f"{title}" @@ -228,35 +236,35 @@ def get_trace2D_dict( "text": mode, "hovertemplate": ( "Path index: %{x} " - f"{title}: " + "%{y:.3s}T
" + f"{title}: %{{y:.3s}}{unit}
" f"{'sources'}:
{obj_lst_str['sources']}
" f"{'sensors'}:
{obj_lst_str['sensors']}" # "", ), "x": frames_indices, "y": y[frames_indices], - "line_dash": linestyle, - "line_color": color, "marker_size": marker_size, - "marker_color": color, - "marker_symbol": symbol, "showlegend": True, "legendgroup": f"{title}{label_suff}", - **kwargs, } + trace.update(kwargs) return trace -def get_generic_traces_2D( - *, - objects, +def get_traces_2D( + *objects, output=("Bx", "By", "Bz"), row=None, col=None, sumup=True, pixel_agg=None, in_out="auto", - style_path_frames=None, + styles=None, + units_polarization="T", + units_magnetization="A/m", + # pylint: disable=unused-argument + units_length="m", + zoom=0, ): """draws and animates sensor values over a path in a subplot""" # pylint: disable=import-outside-toplevel @@ -278,6 +286,7 @@ def get_generic_traces_2D( if not isinstance(output, (list, tuple)): output = [output] output_params = {} + field_str_list = [] for out, linestyle in zip(output, cycle(ALLOWED_LINESTYLES[:6])): field_str, *coords_str = out if not coords_str: @@ -285,31 +294,44 @@ def get_generic_traces_2D( if field_str not in "BHMJ" and set(coords_str).difference(set("xyz")): raise ValueError( "The `output` parameter must start with 'B', 'H', 'M', 'J' " - "and be followed by a combination of 'x', 'y', 'z' (e.g. 'Bxy' or ('Bxy', 'Hz') )" + "and be followed by a combination of 'x', 'y', 'z' (e.g. 'Bxy' or ('Bxy', 'Bz') )" f"\nreceived {out!r} instead" ) + field_str_list.append(field_str) output_params[out] = { "field_str": field_str, "coords_str": coords_str, - "linestyle": linestyle, + "line_dash": linestyle, } - BH_array = getBH_level2( - sources, - sensors, - sumup=sumup, - squeeze=False, - field=field_str, - pixel_agg=pixel_agg, - output="ndarray", - in_out=in_out, - ) - BH_array = BH_array.swapaxes(1, 2) # swap axes to have sensors first, path second - - frames_indices = np.arange(0, BH_array.shape[2]) - frame_focus_inds = [-1] if style_path_frames is None else style_path_frames - if isinstance(frame_focus_inds, numbers.Number): - # pylint: disable=invalid-unary-operand-type - frame_focus_inds = frames_indices[::-style_path_frames] + field_str_list = list(dict.fromkeys(field_str_list)) + BH_array = {} + for field_str in field_str_list: + BH_array[field_str] = getBH_level2( + sources, + sensors, + sumup=sumup, + squeeze=False, + field=field_str, + pixel_agg=pixel_agg, + output="ndarray", + in_out=in_out, + ) + # swap axes to have sensors first, path second + BH_array[field_str] = BH_array[field_str].swapaxes(1, 2) + frames_indices = np.arange(0, BH_array[field_str_list[0]].shape[2]) + + def get_focus_inds(*objs): + focus_inds = [] + for obj in objs: + style = styles.get(obj, obj.style) + frames = style.path.frames + inds = [] if frames is None else frames + if isinstance(inds, numbers.Number): + # pylint: disable=invalid-unary-operand-type + inds = frames_indices[::-frames] + focus_inds.extend(inds) + focus_inds = list(dict.fromkeys(focus_inds)) + return focus_inds if focus_inds else [-1] def get_obj_list_str(objs): if len(objs) < 8: @@ -320,7 +342,8 @@ def get_obj_list_str(objs): return obj_lst_str def get_label_and_color(obj): - style = obj.style + style = styles.get(obj, None) + style = obj.style if style is None else style label = get_legend_label(obj, style=style) color = getattr(style, "color", None) return label, color @@ -338,15 +361,18 @@ def get_label_and_color(obj): label_src, color_src = get_label_and_color(src) symbols = cycle(ALLOWED_SYMBOLS[:6]) for sens_ind, sens in enumerate(sensors): + focus_inds = get_focus_inds(src, sens) label_sens, color_sens = get_label_and_color(sens) label_suff = label_sens - if mode == "sensors": - label, color = label_src, color_src - else: + label = label_src + line_color = color_src + marker_color = color_sens if len(sensors) > 1 else None + if sumup: + line_color = color_sens + label = label_sens label_suff = ( f"{label_src}" if len(sources) == 1 else f"{len(sources)} sources" ) - label, color = label_sens, color_sens num_of_pix = ( len(sens.pixel.reshape(-1, 3)) if (not isinstance(sens, magpy.Collection)) @@ -357,29 +383,33 @@ def get_label_and_color(obj): pix_suff = "" num_of_pix_to_show = 1 if pixel_agg else num_of_pix for pix_ind in range(num_of_pix_to_show): - symbol = next(symbols) - BH = BH_array[src_ind, sens_ind, :, pix_ind] + marker_symbol = next(symbols) if num_of_pix > 1: if pixel_agg: - pix_suff = f" ({num_of_pix} pixels {pixel_agg})" + pix_suff = f" - {num_of_pix} pixels {pixel_agg}" else: - pix_suff = f" (pixel {pix_ind})" + pix_suff = f" - pixel {pix_ind}" for param in output_params.values(): + BH = BH_array[param["field_str"]][src_ind, sens_ind, :, pix_ind] traces.append( get_trace2D_dict( BH, **param, obj_lst_str=obj_lst_str, - frame_focus_inds=frame_focus_inds, + focus_inds=focus_inds, frames_indices=frames_indices, mode=mode, label_suff=label_suff, name=f"{label}{pix_suff}", - color=color, - symbol=symbol, + line_color=line_color, + marker_color=marker_color, + marker_line_color=marker_color, + marker_symbol=marker_symbol, type="scatter", row=row, col=col, + units_polarization=units_polarization, + units_magnetization=units_magnetization, ) ) return traces @@ -394,9 +424,10 @@ def process_extra_trace(model): "constructor": extr.constructor, "kwargs": model_kwargs, "args": model_args, + "coordsargs": extr.coordsargs, "kwargs_extra": model["kwargs_extra"], } - kwargs, args = place_and_orient_model3d( + kwargs, args, coordsargs = place_and_orient_model3d( model_kwargs=model_kwargs, model_args=model_args, orientation=model["orientation"], @@ -404,13 +435,15 @@ def process_extra_trace(model): coordsargs=extr.coordsargs, scale=extr.scale, return_model_args=True, + return_coordsargs=True, ) + trace3d["coordsargs"] = coordsargs trace3d["kwargs"].update(kwargs) trace3d["args"] = args return trace3d -def get_generic_traces( +def get_generic_traces3D( input_obj, autosize=None, legendgroup=None, @@ -441,7 +474,6 @@ def get_generic_traces( # pylint: disable=too-many-nested-blocks # pylint: disable=protected-access # pylint: disable=import-outside-toplevel - style = input_obj.style is_mag_arrows = False is_mag = hasattr(input_obj, "magnetization") and hasattr(style, "magnetization") @@ -497,24 +529,24 @@ def get_generic_traces( extr.update(extr.updatefunc()) # update before checking backend if extr.backend == "generic": extr.update(extr.updatefunc()) - tr_generic = {"opacity": style.opacity} + tr_non_generic = {"opacity": style.opacity} ttype = extr.constructor.lower() obj_extr_trace = extr.kwargs() if callable(extr.kwargs) else extr.kwargs obj_extr_trace = {"type": ttype, **obj_extr_trace} if ttype == "scatter3d": for k in ("marker", "line"): - tr_generic[f"{k}_color"] = tr_generic.get( + tr_non_generic[f"{k}_color"] = tr_non_generic.get( f"{k}_color", style.color ) elif ttype == "mesh3d": - tr_generic["showscale"] = tr_generic.get("showscale", False) - tr_generic["color"] = tr_generic.get("color", style.color) + tr_non_generic["showscale"] = tr_non_generic.get("showscale", False) + tr_non_generic["color"] = tr_non_generic.get("color", style.color) else: # pragma: no cover raise ValueError( f"{ttype} is not supported, only 'scatter3d' and 'mesh3d' are" ) - tr_generic.update(linearize_dict(obj_extr_trace, separator="_")) - traces_generic.append(tr_generic) + tr_non_generic.update(linearize_dict(obj_extr_trace, separator="_")) + traces_generic.append(tr_non_generic) if is_mag_arrows: mag = input_obj.magnetization @@ -544,6 +576,7 @@ def get_generic_traces( path_traces_generic = group_traces(*path_traces_generic) for tr in path_traces_generic: + tr.update(place_and_orient_model3d(tr)) tr.update(row=row, col=col) if tr.get("opacity", None) is None: tr["opacity"] = style.opacity @@ -573,7 +606,7 @@ def get_generic_traces( extr.update(extr.updatefunc()) # update before checking backend if extr.backend == extra_backend: for orient, pos in zip(orientations, positions): - tr_generic = { + tr_non_generic = { "model3d": extr, "position": pos, "orientation": orient, @@ -591,8 +624,8 @@ def get_generic_traces( "col": col, }, } - tr_generic = process_extra_trace(tr_generic) - path_traces_extra_non_generic_backend.append(tr_generic) + tr_non_generic = process_extra_trace(tr_non_generic) + path_traces_extra_non_generic_backend.append(tr_non_generic) out.update({extra_backend: path_traces_extra_non_generic_backend}) return out @@ -717,7 +750,34 @@ def extract_animation_properties( return path_indices, exp, frame_duration -def draw_frame(objs, colorsequence=None, zoom=0.0, autosize=None, **kwargs) -> Tuple: +def get_traces_3D(flat_objs_props, extra_backend=False, autosize=None, **kwargs): + """Return traces, traces to resize and extra_backend_traces""" + extra_backend_traces = [] + traces_dict = {} + for obj, params in flat_objs_props.items(): + params = {**params, **kwargs} + if autosize is None and getattr(obj, "_autosize", False): + # temporary coordinates to be able to calculate ranges + # pylint: disable=protected-access + x, y, z = obj._position.T + rc_dict = {k: v for k, v in params.items() if k in ("row", "col")} + traces_dict[obj] = [{"x": x, "y": y, "z": z, "_autosize": True, **rc_dict}] + else: + traces_dict[obj] = [] + with style_temp_edit(obj, style_temp=params.pop("style", None), copy=True): + out_traces = get_generic_traces3D( + obj, + extra_backend=extra_backend, + autosize=autosize, + **params, + ) + if extra_backend: + extra_backend_traces.extend(out_traces.get(extra_backend, [])) + traces_dict[obj].extend(out_traces["generic"]) + return traces_dict, extra_backend_traces + + +def draw_frame(objs, *, colorsequence, rc_params, style_kwargs, **kwargs) -> Tuple: """ Creates traces from input `objs` and provided parameters, updates the size of objects like Sensors and Dipoles in `kwargs` depending on the canvas size. @@ -732,89 +792,73 @@ def draw_frame(objs, colorsequence=None, zoom=0.0, autosize=None, **kwargs) -> T colorsequence = default_settings.display.colorsequence # dipoles and sensors use autosize, the trace building has to be put at the back of the queue. # autosize is calculated from the other traces overall scene range - - style_path_frames = kwargs.get( - "style_path_frames", [-1] - ) # get before next func strips style - flat_objs_props, kwargs = get_flatten_objects_properties( - *objs, colorsequence=colorsequence, **kwargs - ) - traces_dict, traces_to_resize_dict, extra_backend_traces = get_row_col_traces( - flat_objs_props, **kwargs - ) - traces = [t for tr in traces_dict.values() for t in tr] - ranges = get_scene_ranges(*traces, *extra_backend_traces, zoom=zoom) - if autosize is None or autosize == "return": - # pylint: disable=no-member - autosize = np.mean(np.diff(ranges)) / default_settings.display.autosizefactor - - traces_dict_2, _, extra_backend_traces2 = get_row_col_traces( - traces_to_resize_dict, autosize=autosize, **kwargs + objs_rc = get_objects_props_by_row_col( + *objs, + colorsequence=colorsequence, + style_kwargs=style_kwargs, ) - traces_dict.update(traces_dict_2) - extra_backend_traces.extend(extra_backend_traces2) + traces_dict = {} + extra_backend_traces = [] + rc_params = {} if rc_params is None else rc_params + for rc, props in objs_rc.items(): + if props["rc_params"]["output"] == "model3d": + rc_params[rc] = rc_params.get(rc, {}) + rc_params[rc]["units_length"] = props["rc_params"]["units_length"] + rc_keys = ("row", "col") + rc_kwargs = {k: v for k, v in props["rc_params"].items() if k in rc_keys} + traces_d1, traces_ex1 = get_traces_3D( + props["objects"], **rc_kwargs, **kwargs + ) + rc_params[rc]["autosize"] = rc_params.get(rc, {}).get("autosize", None) + if rc_params[rc]["autosize"] is None: + zoom = rc_params[rc]["zoom"] = props["rc_params"]["zoom"] + traces = [t for tr in traces_d1.values() for t in tr] + ranges_rc = get_scene_ranges(*traces, *traces_ex1, zoom=zoom) + # pylint: disable=no-member + factor = default_settings.display.autosizefactor + rc_params[rc]["autosize"] = np.mean(np.diff(ranges_rc[rc])) / factor + to_resize_keys = { + k for k, v in traces_d1.items() if v and "_autosize" in v[0] + } + flat_objs_props = { + k: v for k, v in props["objects"].items() if k in to_resize_keys + } + traces_d2, traces_ex2 = get_traces_3D( + flat_objs_props, + autosize=rc_params[rc]["autosize"], + **rc_kwargs, + **kwargs, + ) + traces_dict.update( + {(k, *rc): v for k, v in {**traces_d1, **traces_d2}.items()} + ) + extra_backend_traces.extend([*traces_ex1, *traces_ex2]) traces = group_traces(*[t for tr in traces_dict.values() for t in tr]) - obj_list_2d = [o for o in objs if o["output"] != "model3d"] - for objs_2d in obj_list_2d: - traces2d = get_generic_traces_2D( - **objs_2d, - style_path_frames=style_path_frames, - ) - traces.extend(traces2d) - return traces, autosize, ranges, extra_backend_traces - -def get_row_col_traces(flat_objs_props, extra_backend=False, autosize=None, **kwargs): - """Return traces, traces to resize and extra_backend_traces""" - # pylint: disable=protected-access - extra_backend_traces = [] - traces_dict = {} - traces_to_resize_dict = {} - for obj, params in flat_objs_props.items(): - params.update(kwargs) - if autosize is None and getattr(obj, "_autosize", False): - traces_to_resize_dict[obj] = {**params} - # temporary coordinates to be able to calculate ranges - x, y, z = obj._position.T - traces_dict[obj] = [{"x": x, "y": y, "z": z}] - else: - traces_dict[obj] = [] - rco_obj = params.pop("row_cols") - orig_style = getattr(obj, "_style", None) - try: - style_temp = params.pop("style", None) - for rco in rco_obj: - # temporary replace style attribute - obj._style = style_temp - if len(rco_obj) >= 2 and style_temp: - # deepcopy style only if obj is in multiple subplots. - obj._style = style_temp.copy() - params["row"], params["col"], output_typ = rco - if output_typ == "model3d": - out_traces = get_generic_traces( - obj, - extra_backend=extra_backend, - autosize=autosize, - **params, - ) - if extra_backend: - extra_backend_traces.extend( - out_traces.get(extra_backend, []) - ) - traces_dict[obj].extend(out_traces["generic"]) - finally: - obj._style = orig_style - return traces_dict, traces_to_resize_dict, extra_backend_traces + styles = { + obj: params.get("style", None) + for o_rc in objs_rc.values() + for obj, params in o_rc["objects"].items() + } + for props in objs_rc.values(): + if props["rc_params"]["output"] != "model3d": + traces2d = get_traces_2D( + *props["objects"], + **props["rc_params"], + styles=styles, + ) + traces.extend(traces2d) + return traces, extra_backend_traces, rc_params def get_frames( objs, colorsequence=None, - zoom=1, title=None, animation=False, supports_colorgradient=True, backend="generic", + style_kwargs=None, **kwargs, ): """This is a helper function which generates frames with generic traces to be provided to @@ -844,25 +888,26 @@ def get_frames( ) # create frame for each path index or downsampled path index frames = [] - autosize = "return" + title_str = title + rc_params = {} for i, ind in enumerate(path_indices): extra_backend_traces = [] if animation: - kwargs["style_path_frames"] = [ind] + style_kwargs["style_path_frames"] = [ind] title = "Animation 3D - " if title is None else title title_str = f"""{title}path index: {ind+1:0{exp}d}""" - traces, autosize_init, ranges, extra_backend_traces = draw_frame( + traces, extra_backend_traces, rc_params_temp = draw_frame( objs, - colorsequence, - zoom, - autosize=autosize, + colorsequence=colorsequence, + rc_params=rc_params, supports_colorgradient=supports_colorgradient, extra_backend=backend, + style_kwargs=style_kwargs, **kwargs, ) if i == 0: # get the dipoles and sensors autosize from first frame - autosize = autosize_init + rc_params = rc_params_temp frames.append( { "data": traces, @@ -871,13 +916,30 @@ def get_frames( "extra_backend_traces": extra_backend_traces, } ) - clean_legendgroups(frames) traces = [t for frame in frames for t in frame["data"]] - ranges = get_scene_ranges(*traces, *extra_backend_traces, zoom=zoom) + zoom = {rc: v["zoom"] for rc, v in rc_params.items()} + ranges_rc = get_scene_ranges(*traces, *extra_backend_traces, zoom=zoom) + labels_rc = {(1, 1): {k: "" for k in "xyz"}} + scale_factors_rc = {} + for rc, params in rc_params.items(): + units_length = params["units_length"] + if units_length == "auto": + rmax = np.amax(np.abs(ranges_rc[rc])) + units_length = f"{unit_prefix(rmax, as_tuple=True)[2]}m" + unit_str = "" if not (units_length) else f" ({units_length})" + labels_rc[rc] = {k: f"{k}{unit_str}" for k in "xyz"} + scale_factors_rc[rc] = get_unit_factor(units_length, target_unit="m") + ranges_rc[rc] *= scale_factors_rc[rc] + + for frame in frames: + for key in ("data", "extra_backend_traces"): + frame[key] = rescale_traces(frame[key], factors=scale_factors_rc) + out = { "frames": frames, - "ranges": ranges, + "ranges": ranges_rc, + "labels": labels_rc, "input_kwargs": {**kwargs, **animation_kwargs}, } if animation: diff --git a/magpylib/_src/display/traces_utility.py b/magpylib/_src/display/traces_utility.py index d0851986e..63013fc5e 100644 --- a/magpylib/_src/display/traces_utility.py +++ b/magpylib/_src/display/traces_utility.py @@ -3,6 +3,7 @@ # pylint: disable=too-many-branches from collections import defaultdict from functools import lru_cache +from itertools import chain from itertools import cycle from typing import Tuple @@ -11,8 +12,21 @@ from magpylib._src.defaults.defaults_classes import default_settings from magpylib._src.defaults.defaults_utility import linearize_dict +from magpylib._src.input_checks import check_input_zoom from magpylib._src.style import get_style from magpylib._src.utility import format_obj_input +from magpylib._src.utility import merge_dicts_with_conflict_check + +DEFAULT_ROW_COL_PARAMS = { + "row": 1, + "col": 1, + "output": "model3d", + "sumup": True, + "pixel_agg": "mean", + "in_out": "auto", + "zoom": 0, + "units_length": "auto", +} def get_legend_label(obj, style=None, suffix=True): @@ -34,49 +48,55 @@ def get_legend_label(obj, style=None, suffix=True): def place_and_orient_model3d( model_kwargs, + *, model_args=None, orientation=None, position=None, coordsargs=None, scale=1, return_model_args=False, + return_coordsargs=False, + length_factor=1, **kwargs, ): """places and orients mesh3d dict""" - if orientation is None and position is None: - return {**model_kwargs, **kwargs} - position = (0.0, 0.0, 0.0) if position is None else position - position = np.array(position, dtype=float) - new_model_dict = {} - if model_args is None: - model_args = () - new_model_args = list(model_args) - vertices, coordsargs, useargs = get_vertices_from_model( - model_kwargs, model_args, coordsargs - ) - - # sometimes traces come as (n,m,3) shape - vert_shape = vertices.shape - vertices = np.reshape(vertices, (3, -1)) - - vertices = vertices.T - - if orientation is not None: - vertices = orientation.apply(vertices) - new_vertices = (vertices * scale + position).T - new_vertices = np.reshape(new_vertices, vert_shape) - for i, k in enumerate("xyz"): - key = coordsargs[k] - if useargs: - ind = int(key[5]) - new_model_args[ind] = new_vertices[i] - else: - new_model_dict[key] = new_vertices[i] - new_model_kwargs = {**model_kwargs, **new_model_dict, **kwargs} + if orientation is None and position is None and length_factor == 1: + new_model_kwargs = {**model_kwargs, **kwargs} + new_model_args = model_args + else: + position = (0.0, 0.0, 0.0) if position is None else position + position = np.array(position, dtype=float) + new_model_dict = {} + if model_args is None: + model_args = () + new_model_args = list(model_args) + vertices, coordsargs, useargs = get_vertices_from_model( + model_kwargs, model_args, coordsargs + ) + # sometimes traces come as (n,m,3) shape + vert_shape = vertices.shape + vertices = np.reshape(vertices.astype(float), (3, -1)) + + vertices = vertices.T + + if orientation is not None: + vertices = orientation.apply(vertices) + new_vertices = (vertices * scale + position).T * length_factor + new_vertices = np.reshape(new_vertices, vert_shape) + for i, k in enumerate("xyz"): + key = coordsargs[k] + if useargs: + ind = int(key[5]) + new_model_args[ind] = new_vertices[i] + else: + new_model_dict[key] = new_vertices[i] + new_model_kwargs = {**model_kwargs, **new_model_dict, **kwargs} out = (new_model_kwargs,) if return_model_args: out += (new_model_args,) + if return_coordsargs: + out += (coordsargs,) return out[0] if len(out) == 1 else out @@ -244,42 +264,52 @@ def get_rot_pos_from_path(obj, show_path=None): return rots, poss, inds -def get_flatten_objects_properties(*objs, colorsequence, **kwargs): +def get_objects_props_by_row_col(*objs, colorsequence, style_kwargs): """Return flat dict with objs as keys object properties as values. Properties include: row_cols, style, legendgroup, legendtext""" - flat_objs = {} + flat_objs_rc = {} + rc_params_by_obj = {} + obj_list_semi_flat = [o for obj in objs for o in obj["objects"]] for obj in objs: - flat_sub_objs = get_flatten_objects_properties_recursive( - *obj["objects"], colorsequence=colorsequence, **kwargs - ) - for subobj, props in flat_sub_objs.items(): - if subobj in flat_objs: - props["row_cols"] = flat_objs[subobj]["row_cols"] - elif "row_cols" not in props: - props["row_cols"] = [] - props["row_cols"].extend([(obj["row"], obj["col"], obj["output"])]) - flat_objs.update(flat_sub_objs) - kwargs = {k: v for k, v in kwargs.items() if not k.startswith("style")} - return flat_objs, kwargs + rc_params = {k: v for k, v in obj.items() if k != "objects"} + for subobj in obj["objects"]: + children = getattr(subobj, "children_all", []) + for child in chain([subobj], children): + if child not in rc_params_by_obj: + rc_params_by_obj[child] = [] + rc_params_by_obj[child].append(rc_params) + flat_sub_objs = get_flatten_objects_properties_recursive( + *obj_list_semi_flat, + style_kwargs=style_kwargs, + colorsequence=colorsequence, + ) + for obj, rc_params_list in rc_params_by_obj.items(): + for rc_params in rc_params_list: + rc = rc_params["row"], rc_params["col"] + if rc not in flat_objs_rc: + flat_objs_rc[rc] = {"objects": {}, "rc_params": rc_params} + flat_objs_rc[rc]["objects"][obj] = flat_sub_objs[obj] + return flat_objs_rc def get_flatten_objects_properties_recursive( *obj_list_semi_flat, + style_kwargs=None, colorsequence=None, color_cycle=None, parent_legendgroup=None, parent_color=None, parent_label=None, parent_showlegend=None, - **kwargs, ): """returns a flat dict -> (obj: display_props, ...) from nested collections""" if color_cycle is None: color_cycle = cycle(colorsequence) flat_objs = {} - for subobj in obj_list_semi_flat: + for subobj in dict.fromkeys(obj_list_semi_flat): isCollection = getattr(subobj, "children", None) is not None - style = get_style(subobj, default_settings, **kwargs) + style_kwargs = {} if style_kwargs is None else style_kwargs + style = get_style(subobj, default_settings, **style_kwargs) if style.label is None: style.label = str(type(subobj).__name__) legendgroup = f"{subobj}" if parent_legendgroup is None else parent_legendgroup @@ -299,18 +329,17 @@ def get_flatten_objects_properties_recursive( "showlegend": parent_showlegend, } if isCollection: - flat_objs.update( - get_flatten_objects_properties_recursive( - *subobj.children, - colorsequence=colorsequence, - color_cycle=color_cycle, - parent_legendgroup=legendgroup, - parent_color=style.color, - parent_label=label, - parent_showlegend=style.legend.show, - **kwargs, - ) + new_ojbs = get_flatten_objects_properties_recursive( + *subobj.children, + colorsequence=colorsequence, + color_cycle=color_cycle, + parent_legendgroup=legendgroup, + parent_color=style.color, + parent_label=label, + parent_showlegend=style.legend.show, + style_kwargs=style_kwargs, ) + flat_objs = {**new_ojbs, **flat_objs} return flat_objs @@ -467,38 +496,47 @@ def getColorscale( return colorscale -def get_scene_ranges(*traces, zoom=1) -> np.ndarray: +def get_scene_ranges(*traces, zoom=0) -> np.ndarray: """ Returns 3x2 array of the min and max ranges in x,y,z directions of input traces. Traces can be any plotly trace object or a dict, with x,y,z numbered parameters. """ - trace3d_found = False - if traces: - ranges = {k: [] for k in "xyz"} - for tr in traces: - coords = "xyz" - if "constructor" in tr: - verts, *_ = get_vertices_from_model( - model_args=tr.get("args", None), - model_kwargs=tr.get("kwargs", None), - coordsargs=tr.get("coordsargs", None), - ) - tr = dict(zip("xyz", verts)) - if "z" in tr: # only extend range for 3d traces - trace3d_found = True - pts = np.array([tr[k] for k in coords], dtype="float64").T - try: # for mesh3d, use only vertices part of faces for range calculation - inds = np.array([tr[k] for k in "ijk"], dtype="int64").T - pts = pts[inds] - except KeyError: - # for 2d meshes, nothing special needed - pass - pts = pts.reshape(-1, 3) - if pts.size != 0: - min_max = np.nanmin(pts, axis=0), np.nanmax(pts, axis=0) - for v, min_, max_ in zip(ranges.values(), *min_max): - v.extend([min_, max_]) - if trace3d_found: + ranges_rc = {} + tr_dim_count = {} + for tr in traces: + coords = "xyz" + rc = tr.get("row", 1), tr.get("col", 1) + if "constructor" in tr: + verts, *_ = get_vertices_from_model( + model_args=tr.get("args", None), + model_kwargs=tr.get("kwargs", None), + coordsargs=tr.get("coordsargs", None), + ) + kwex = tr["kwargs_extra"] + tr = dict(zip("xyz", verts)) + rc = kwex["row"], kwex["col"] + if rc not in ranges_rc: + ranges_rc[rc] = {k: [] for k in "xyz"} + tr_dim_count[rc] = {"2D": 0, "3D": 0} + if "z" not in tr: # only extend range for 3d traces + tr_dim_count[rc]["2D"] += 1 + else: + tr_dim_count[rc]["3D"] += 1 + pts = np.array([tr[k] for k in coords], dtype="float64").T + try: # for mesh3d, use only vertices part of faces for range calculation + inds = np.array([tr[k] for k in "ijk"], dtype="int64").T + pts = pts[inds] + except KeyError: + # for 2d meshes, nothing special needed + pass + pts = pts.reshape(-1, 3) + if pts.size != 0: + min_max = np.nanmin(pts, axis=0), np.nanmax(pts, axis=0) + for v, min_, max_ in zip(ranges_rc[rc].values(), *min_max): + v.extend([min_, max_]) + for rc, ranges in ranges_rc.items(): + if tr_dim_count[rc]["3D"]: + zo = zoom[rc] if isinstance(zoom, dict) else zoom # SET 3D PLOT BOUNDARIES # collect min/max from all elements r = np.array([[np.nanmin(v), np.nanmax(v)] for v in ranges.values()]) @@ -506,10 +544,34 @@ def get_scene_ranges(*traces, zoom=1) -> np.ndarray: m = size.max() / 2 m = 1 if m == 0 else m center = r.mean(axis=1) - ranges = np.array([center - m * (1 + zoom), center + m * (1 + zoom)]).T - if not traces or not trace3d_found: - ranges = np.array([[-1.0, 1.0]] * 3) - return ranges + ranges = np.array([center - m * (1 + zo), center + m * (1 + zo)]).T + else: + ranges = np.array([[-1.0, 1.0]] * 3) + ranges_rc[rc] = ranges + if not ranges_rc: + ranges_rc[(1, 1)] = np.array([[-1.0, 1.0]] * 3) + return ranges_rc + + +def rescale_traces(traces, factors): + """Rescale traces based on scale factors by (row,col) index""" + for ind, tr in enumerate(traces): + if "constructor" in tr: + kwex = tr["kwargs_extra"] + rc = kwex["row"], kwex["col"] + kwargs, args = place_and_orient_model3d( + model_kwargs=tr.get("kwargs", None), + model_args=tr.get("args", None), + coordsargs=tr.get("coordsargs", None), + length_factor=factors[rc], + return_model_args=True, + ) + tr["kwargs"].update(kwargs) + tr["args"] = args + if "z" in tr: # rescale only 3d traces + rc = tr.get("row", 1), tr.get("col", 1) + traces[ind] = place_and_orient_model3d(tr, length_factor=factors[rc]) + return traces def group_traces(*traces): @@ -578,57 +640,52 @@ def subdivide_mesh_by_facecolor(trace): def process_show_input_objs(objs, **kwargs): """Extract max_rows and max_cols from obj list of dicts""" - defaults = { - "row": 1, - "col": 1, - "output": "model3d", - "sumup": True, - "pixel_agg": "mean", - "in_out": "auto", - } - max_rows = max_cols = 1 - flat_objs = [] - new_objs = {} - subplot_specs = {} + defaults = DEFAULT_ROW_COL_PARAMS.copy() + identifiers = ("row", "col") + unique_fields = tuple(k for k in defaults if k not in identifiers) + sources_and_sensors_only = [] + new_objs = [] for obj in objs: + # add missing kwargs if isinstance(obj, dict): obj = {**defaults, **obj, **kwargs} else: obj = {**defaults, "objects": obj, **kwargs} + # extend objects list obj["objects"] = format_obj_input( obj["objects"], allow="sources+sensors+collections" ) - flat_objs.extend(format_obj_input(obj["objects"], allow="sources+sensors")) - if obj["row"] is not None: - max_rows = max(max_rows, obj["row"]) - if obj["col"] is not None: - max_cols = max(max_cols, obj["col"]) - out = obj["output"] - key = (obj["row"], obj["col"], out if isinstance(out, str) else tuple(out)) - if key not in new_objs: - new_objs[key] = obj - else: - new_objs[key]["objects"] = list( - dict.fromkeys(new_objs[key]["objects"] + obj["objects"]) - ) - current_subplot_specs = subplot_specs.get(key[:2], obj["output"]) - if current_subplot_specs != obj["output"]: - raise ValueError( - f"Row/Col {key[:2]}, received conflicting output types " - f"{current_subplot_specs!r} vs {obj['output']!r}" - ) - subplot_specs[key[:2]] = obj["output"] + sources_and_sensors_only.extend( + format_obj_input(obj["objects"], allow="sources+sensors") + ) + new_objs.append(obj) + row_col_dict = merge_dicts_with_conflict_check( + new_objs, + target="objects", + identifiers=identifiers, + unique_fields=unique_fields, + ) + + # create subplot specs grid + row_cols = [*row_col_dict] + max_rows, max_cols = np.max(row_cols, axis=0).astype(int) if row_cols else (1, 1) + # convert to int to avoid np.int32 type conflicting with plolty subplot specs + max_rows, max_cols = int(max_rows), int(max_cols) specs = np.array([[{"type": "scene"}] * max_cols] * max_rows) - for inds, out in subplot_specs.items(): - if out != "model3d": - specs[inds[0] - 1, inds[1] - 1] = {"type": "xy"} + for rc, obj in row_col_dict.items(): + if obj["output"] != "model3d": + specs[rc[0] - 1, rc[1] - 1] = {"type": "xy"} if max_rows == 1 and max_cols == 1: max_rows = max_cols = None + + for obj in row_col_dict.values(): + check_input_zoom(obj.get("zoom", None)) + return ( - list(new_objs.values()), - list(dict.fromkeys(flat_objs)), + list(row_col_dict.values()), + list(dict.fromkeys(sources_and_sensors_only)), max_rows, max_cols, specs, diff --git a/magpylib/_src/input_checks.py b/magpylib/_src/input_checks.py index 434c6a8b4..54b2a029d 100644 --- a/magpylib/_src/input_checks.py +++ b/magpylib/_src/input_checks.py @@ -69,14 +69,9 @@ def check_array_shape(inp: np.ndarray, dims: tuple, shape_m1: int, length=None, def check_input_zoom(inp): """check show zoom input""" - if not isinstance(inp, numbers.Number): - raise MagpylibBadUserInput( - "Input parameter `zoom` must be a number `zoom>=0`.\n" - f"Instead received {inp}." - ) - if inp < 0: + if not (isinstance(inp, numbers.Number) and inp >= 0): raise MagpylibBadUserInput( - "Input parameter `zoom` must be a number `zoom>=0`.\n" + "Input parameter `zoom` must be a positive number or zero.\n" f"Instead received {inp}." ) diff --git a/magpylib/_src/utility.py b/magpylib/_src/utility.py index d62d377c1..89a24fcff 100644 --- a/magpylib/_src/utility.py +++ b/magpylib/_src/utility.py @@ -3,6 +3,7 @@ # pylint: disable=import-outside-toplevel # pylint: disable=cyclic-import # import numbers +from contextlib import contextmanager from functools import lru_cache from inspect import signature from math import log10 @@ -237,8 +238,36 @@ def filter_objects(obj_list, allow="sources+sensors", warn=True): 24: "Y", # yotta } +_UNIT_PREFIX_REVERSED = {v: k for k, v in _UNIT_PREFIX.items()} -def unit_prefix(number, unit="", precision=3, char_between="") -> str: + +@lru_cache(maxsize=None) +def get_unit_factor(unit_input, *, target_unit, deci_centi=True): + """return unit factor based on input and target unit""" + if unit_input is None or unit_input == target_unit: + return 1 + pref, suff, factor_power = "", "", None + prefs = _UNIT_PREFIX_REVERSED + if deci_centi: + prefs = {**_UNIT_PREFIX_REVERSED, "d": -1, "c": -2} + unit_input_str = str(unit_input) + if unit_input_str: + if len(unit_input_str) >= 2: + pref, *suff = unit_input_str + suff = "".join(suff) + if suff == target_unit: + factor_power = prefs.get(pref, None) + + if factor_power is None or len(unit_input_str) > 2: + valid_inputs = [f"{k}{target_unit}" for k in prefs] + raise ValueError( + f"Invalid unit input ({unit_input!r}), must be one of {valid_inputs}" + ) + factor = 1 / (10**factor_power) + return factor + + +def unit_prefix(number, unit="", precision=3, char_between="", as_tuple=False) -> str: """ displays a number with given unit and precision and uses unit prefixes for the exponents from yotta (y) to Yocto (Y). If the exponent is smaller or bigger, falls back to scientific notation. @@ -253,10 +282,13 @@ def unit_prefix(number, unit="", precision=3, char_between="") -> str: char_between : str, optional character to insert between number of prefix. Can be " " or any string, if a space is wanted before the unit symbol , by default "" + as_tuple: bool, optional + if True returns (new_number_str, char_between, prefix, unit) tuple + else returns the joined string Returns ------- - str - returns formatted number as string + str or tuple + returns formatted number as string or tuple """ digits = int(log10(abs(number))) // 3 * 3 if number != 0 else 0 prefix = _UNIT_PREFIX.get(digits, "") @@ -264,7 +296,10 @@ def unit_prefix(number, unit="", precision=3, char_between="") -> str: if prefix == "": digits = 0 new_number_str = f"{number / 10 ** digits:.{precision}g}" - return f"{new_number_str}{char_between}{prefix}{unit}" + res = (new_number_str, char_between, prefix, unit) + if as_tuple: + return res + return "".join(f"{v}" for v in res) def add_iteration_suffix(name): @@ -398,3 +433,80 @@ def has_parameter(func: Callable, param_name: str) -> bool: """Check if input function has a specific parameter""" sig = signature(func) return param_name in sig.parameters + + +def merge_dicts_with_conflict_check(objs, *, target, identifiers, unique_fields): + """ + Merge dictionaries ensuring unique identifier fields don't lead to conflict. + + Parameters + ---------- + objs : list of dicts + List of dictionaries to be merged based on identifier fields. + target : str + The key in the dictionaries whose values are lists to be merged. + identifiers : list of str + Keys used to identify a unique dictionary. + unique_fields : list of str + Additional keys that must not conflict across merged dictionaries. + + Returns + ------- + dict of dicts + Merged dictionaries with combined `target` lists, ensuring no conflicts + in `unique_fields`. + + Raises + ------ + ValueError + If a conflict is detected in `unique_fields` for any `identifiers`. + + Notes + ----- + `objs` should be a list of dictionaries. Identifiers determine uniqueness, + and merging is done by extending the lists in the `target` key. If any of + the `unique_fields` conflict with previously tracked identifiers, a + `ValueError` is raised detailing the conflict. + + """ + merged_dict = {} + tracker = {} + for obj in objs: + key_dict = {k: obj[k] for k in identifiers} + key = tuple(key_dict.values()) + tracker_previous = tracker.get(key, None) + tracker_actual = tuple(obj[field] for field in unique_fields) + if key in tracker and tracker_previous != tracker_actual: + diff = [ + f"{f!r} first got {a!r} then {t!r}" + for f, a, t in zip(unique_fields, tracker_actual, tracker_previous) + if a != t + ] + raise ValueError( + f"Conflicting parameters detected for {key_dict}: {', '.join(diff)}." + ) + tracker[key] = tracker_actual + + if key not in merged_dict: + merged_dict[key] = obj + else: + merged_dict[key][target] = list( + dict.fromkeys([*merged_dict[key][target], *obj[target]]) + ) + return merged_dict + + +@contextmanager +def style_temp_edit(obj, style_temp, copy=True): + """Temporary replace style to allow edits before returning to original state""" + # pylint: disable=protected-access + orig_style = getattr(obj, "_style", None) + try: + # temporary replace style attribute + obj._style = style_temp + if style_temp and copy: + # deepcopy style only if obj is in multiple subplots. + obj._style = style_temp.copy() + yield + finally: + obj._style = orig_style diff --git a/tests/test_display_matplotlib.py b/tests/test_display_matplotlib.py index 8a4fc464a..890aabfbd 100644 --- a/tests/test_display_matplotlib.py +++ b/tests/test_display_matplotlib.py @@ -422,8 +422,7 @@ def test_matplotlib_model3d_extra_updatefunc(): def test_empty_display(): """should not fail if nothing to display""" - ax = plt.subplot(projection="3d") - magpy.show(canvas=ax, backend="matplotlib", return_fig=True) + magpy.show(backend="matplotlib", return_fig=True) def test_graphics_model_mpl(): @@ -523,7 +522,10 @@ def test_bad_show_inputs(): ) with pytest.raises( ValueError, - match=r"Row/Col .* received conflicting output types.*", + match=( + r"Conflicting parameters detected for {'row': 1, 'col': 1}:" + r" 'output' first got 'model3d' then 'Bx'." + ), ): with magpy.show_context(animation=False, sumup=True, pixel_agg="mean") as s: s.show(cyl1, sensor, col=1, output="Bx") @@ -615,3 +617,13 @@ def test_show_legend(): s2.style.legend = "full legend replace" s3.style.description = "description replace only" magpy.show(s1, s2, s3, return_fig=True) + + +@pytest.mark.parametrize("units_length", ["mT", "inch", "dam", "e"]) +def test_bad_units_length(units_length): + """test units lenghts""" + + c = magpy.magnet.Cuboid(polarization=(0, 0, 1), dimension=(1, 1, 1)) + + with pytest.raises(ValueError, match=r"Invalid unit input.*"): + c.show(units_length=units_length, return_fig=True, backend="matplotlib") diff --git a/tests/test_display_plotly.py b/tests/test_display_plotly.py index 93dc908d2..60a4d682a 100644 --- a/tests/test_display_plotly.py +++ b/tests/test_display_plotly.py @@ -4,6 +4,7 @@ import magpylib as magpy from magpylib._src.exceptions import MagpylibBadUserInput +from magpylib._src.utility import get_unit_factor # pylint: disable=assignment-from-no-return # pylint: disable=no-member @@ -393,3 +394,94 @@ def test_legends(): assert [t.name for t in fig.data] == ["Marker"] assert [t.showlegend for t in fig.data] == [False] + + +def test_color_precedence(): + """Test if color precedence is respected when calling in nested collections""" + c1 = magpy.magnet.Cuboid(polarization=(0, 0, 1), dimension=(1, 1, 1)) + c2 = c1.copy(position=(1, 0, 0)) + c3 = c1.copy(position=(2, 0, 0)) + coll = magpy.Collection(c1, magpy.Collection(c2, c3)) + kw = { + "backend": "plotly", + "style_magnetization_show": False, + "colorsequence": ["red", "blue", "green"], + "return_fig": True, + } + fig = magpy.show(coll, **kw) + assert [tr["color"] for tr in fig.data] == ["red"] + + fig = magpy.show(*coll, **kw) + assert [tr["color"] for tr in fig.data] == ["red", "blue"] + + fig = magpy.show(*coll.sources_all, **kw) + assert [tr["color"] for tr in fig.data] == ["red", "blue", "green"] + + fig = magpy.show({"objects": c1, "col": 1}, {"objects": c1, "col": 2}, **kw) + # sane obj in different subplot should have same color + assert [tr["color"] for tr in fig.data] == ["red", "red"] + + +def test_colors_output2d(): + """Tests if lines have objects corresponding colors in ouptut=Bx, By...""" + l1 = magpy.current.Circle( + current=1, + diameter=1, + style_label="L1", + style_arrow_show=False, + ) + l2 = l1.copy(diameter=2) + s1 = magpy.Sensor( + pixel=[[0, 0, 0], [0, 1, 0]], + position=np.linspace((-1, 0, 1), (1, 0, 1), 10), + style_label="S", + style_model3d_showdefault=False, + ) + s2 = s1.copy().move((0, 0, 1)) + objs = {"objects": [l1, l2, s1, s2]} + kw = { + "return_fig": True, + "colorsequence": ["red", "blue", "green", "cyan"], + } + kw2d = {"output": "Bx", "col": 2} + + def get_scatters2d(fig): + return [t.line.color for t in fig.data if t.type == "scatter"] + + fig = magpy.show(objs, {**objs, **kw2d, "sumup": True}, **kw) + assert get_scatters2d(fig) == ["green", "cyan"] + + fig = magpy.show(objs, {**objs, **kw2d, "sumup": True, "pixel_agg": None}, **kw) + assert get_scatters2d(fig) == [*["green"] * 2, *["cyan"] * 2] + + fig = magpy.show(objs, {**objs, **kw2d, "sumup": False}, **kw) + assert get_scatters2d(fig) == [*["red"] * 2, *["blue"] * 2] + + fig = magpy.show(objs, {**objs, **kw2d, "sumup": False, "pixel_agg": None}, **kw) + assert get_scatters2d(fig) == [*["red"] * 4, *["blue"] * 4] + + +def test_units_length(): + """test units lenghts""" + + dims = (1, 2, 3) + c1 = magpy.magnet.Cuboid(dimension=dims, polarization=(1, 2, 3)) + inputs = [ + {"objects": c1, "row": 1, "col": 1, "units_length": "m", "zoom": 3}, + {"objects": c1, "row": 1, "col": 2, "units_length": "dm", "zoom": 2}, + {"objects": c1, "row": 2, "col": 1, "units_length": "cm", "zoom": 1}, + {"objects": c1, "row": 2, "col": 2, "units_length": "mm", "zoom": 0}, + ] + fig = magpy.show( + *inputs, + backend="plotly", + return_fig=True, + ) + for ind, inp in enumerate(inputs): + scene = getattr(fig.layout, f"scene{'' if ind==0 else ind+1}") + for k in "xyz": + ax = getattr(scene, f"{k}axis") + assert ax.title.text == f"{k} ({inp['units_length']})" + factor = get_unit_factor(inp["units_length"], target_unit="m") + r = (inp["zoom"] + 1) / 2 * factor * max(dims) + assert ax.range == (-r, r) diff --git a/tests/test_input_checks.py b/tests/test_input_checks.py index 68024e801..b92c048b9 100644 --- a/tests/test_input_checks.py +++ b/tests/test_input_checks.py @@ -553,7 +553,7 @@ def test_input_show_zoom_bad(zoom): """bad show zoom inputs""" x = magpy.Sensor() with pytest.raises(MagpylibBadUserInput): - magpy.show(x, zoom=zoom) + magpy.show(x, zoom=zoom, return_fig=True, backend="plotly") @pytest.mark.parametrize(