From 234e6ec77c29c677289f37fcccc54ce6ec4fb726 Mon Sep 17 00:00:00 2001 From: Richard Murray Date: Thu, 27 Jun 2024 15:34:49 -0700 Subject: [PATCH 1/3] move code around to new locations --- control/ctrlplot.py | 232 ++++++++++++++++++++++++++++++++- control/freqplot.py | 215 +----------------------------- control/nichols.py | 5 +- control/phaseplot.py | 2 +- control/pzmap.py | 3 +- control/tests/ctrlplot_test.py | 42 ++++++ control/tests/timeplot_test.py | 35 ----- control/timeplot.py | 17 +-- 8 files changed, 281 insertions(+), 270 deletions(-) create mode 100644 control/tests/ctrlplot_test.py diff --git a/control/ctrlplot.py b/control/ctrlplot.py index c8c30880d..e53d4917e 100644 --- a/control/ctrlplot.py +++ b/control/ctrlplot.py @@ -5,6 +5,7 @@ from os.path import commonprefix +import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np @@ -12,6 +13,28 @@ __all__ = ['suptitle', 'get_plot_axes'] +# +# Style parameters +# + +_ctrlplot_rcParams = mpl.rcParams.copy() +_ctrlplot_rcParams.update({ + 'axes.labelsize': 'small', + 'axes.titlesize': 'small', + 'figure.titlesize': 'medium', + 'legend.fontsize': 'x-small', + 'xtick.labelsize': 'small', + 'ytick.labelsize': 'small', +}) + + +# +# User functions +# +# The functions below can be used by users to modify ctrl plots or get +# information about them. +# + def suptitle( title, fig=None, frame='axes', **kwargs): @@ -35,7 +58,7 @@ def suptitle( Additional keywords (passed to matplotlib). """ - rcParams = config._get_param('freqplot', 'rcParams', kwargs, pop=True) + rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True) if fig is None: fig = plt.gcf() @@ -61,10 +84,10 @@ def suptitle( def get_plot_axes(line_array): """Get a list of axes from an array of lines. - This function can be used to return the set of axes corresponding to - the line array that is returned by `time_response_plot`. This is useful for - generating an axes array that can be passed to subsequent plotting - calls. + This function can be used to return the set of axes corresponding + to the line array that is returned by `time_response_plot`. This + is useful for generating an axes array that can be passed to + subsequent plotting calls. Parameters ---------- @@ -89,6 +112,125 @@ def get_plot_axes(line_array): # # Utility functions # +# These functions are used by plotting routines to provide a consistent way +# of processing and displaing information. +# + + +def _process_ax_keyword( + axs, shape=(1, 1), rcParams=None, squeeze=False, clear_text=False): + """Utility function to process ax keyword to plotting commands. + + This function processes the `ax` keyword to plotting commands. If no + ax keyword is passed, the current figure is checked to see if it has + the correct shape. If the shape matches the desired shape, then the + current figure and axes are returned. Otherwise a new figure is + created with axes of the desired shape. + + Legacy behavior: some of the older plotting commands use a axes label + to identify the proper axes for plotting. This behavior is supported + through the use of the label keyword, but will only work if shape == + (1, 1) and squeeze == True. + + """ + if axs is None: + fig = plt.gcf() # get current figure (or create new one) + axs = fig.get_axes() + + # Check to see if axes are the right shape; if not, create new figure + # Note: can't actually check the shape, just the total number of axes + if len(axs) != np.prod(shape): + with plt.rc_context(rcParams): + if len(axs) != 0: + # Create a new figure + fig, axs = plt.subplots(*shape, squeeze=False) + else: + # Create new axes on (empty) figure + axs = fig.subplots(*shape, squeeze=False) + fig.set_layout_engine('tight') + fig.align_labels() + else: + # Use the existing axes, properly reshaped + axs = np.asarray(axs).reshape(*shape) + + if clear_text: + # Clear out any old text from the current figure + for text in fig.texts: + text.set_visible(False) # turn off the text + del text # get rid of it completely + else: + try: + axs = np.asarray(axs).reshape(shape) + except ValueError: + raise ValueError( + "specified axes are not the right shape; " + f"got {axs.shape} but expecting {shape}") + fig = axs[0, 0].figure + + # Process the squeeze keyword + if squeeze and shape == (1, 1): + axs = axs[0, 0] # Just return the single axes object + elif squeeze: + axs = axs.squeeze() + + return fig, axs + + +# Turn label keyword into array indexed by trace, output, input +# TODO: move to ctrlutil.py and update parameter names to reflect general use +def _process_line_labels(label, ntraces, ninputs=0, noutputs=0): + if label is None: + return None + + if isinstance(label, str): + label = [label] * ntraces # single label for all traces + + # Convert to an ndarray, if not done aleady + try: + line_labels = np.asarray(label) + except ValueError: + raise ValueError("label must be a string or array_like") + + # Turn the data into a 3D array of appropriate shape + # TODO: allow more sophisticated broadcasting (and error checking) + try: + if ninputs > 0 and noutputs > 0: + if line_labels.ndim == 1 and line_labels.size == ntraces: + line_labels = line_labels.reshape(ntraces, 1, 1) + line_labels = np.broadcast_to( + line_labels, (ntraces, ninputs, noutputs)) + else: + line_labels = line_labels.reshape(ntraces, ninputs, noutputs) + except ValueError: + if line_labels.shape[0] != ntraces: + raise ValueError("number of labels must match number of traces") + else: + raise ValueError("labels must be given for each input/output pair") + + return line_labels + + +# Get labels for all lines in an axes +def _get_line_labels(ax, use_color=True): + labels, lines = [], [] + last_color, counter = None, 0 # label unknown systems + for i, line in enumerate(ax.get_lines()): + label = line.get_label() + if use_color and label.startswith("Unknown"): + label = f"Unknown-{counter}" + if last_color is None: + last_color = line.get_color() + elif last_color != line.get_color(): + counter += 1 + last_color = line.get_color() + elif label[0] == '_': + continue + + if label not in labels: + lines.append(line) + labels.append(label) + + return lines, labels # Utility function to make legend labels @@ -160,3 +302,83 @@ def _find_axes_center(fig, axs): ylim = [min(ll[1], ylim[0]), max(ur[1], ylim[1])] return (np.sum(xlim)/2, np.sum(ylim)/2) + + +# Internal function to add arrows to a curve +def _add_arrows_to_line2D( + axes, line, arrow_locs=[0.2, 0.4, 0.6, 0.8], + arrowstyle='-|>', arrowsize=1, dir=1): + """ + Add arrows to a matplotlib.lines.Line2D at selected locations. + + Parameters: + ----------- + axes: Axes object as returned by axes command (or gca) + line: Line2D object as returned by plot command + arrow_locs: list of locations where to insert arrows, % of total length + arrowstyle: style of the arrow + arrowsize: size of the arrow + + Returns: + -------- + arrows: list of arrows + + Based on https://stackoverflow.com/questions/26911898/ + + """ + # Get the coordinates of the line, in plot coordinates + if not isinstance(line, mpl.lines.Line2D): + raise ValueError("expected a matplotlib.lines.Line2D object") + x, y = line.get_xdata(), line.get_ydata() + + # Determine the arrow properties + arrow_kw = {"arrowstyle": arrowstyle} + + color = line.get_color() + use_multicolor_lines = isinstance(color, np.ndarray) + if use_multicolor_lines: + raise NotImplementedError("multicolor lines not supported") + else: + arrow_kw['color'] = color + + linewidth = line.get_linewidth() + if isinstance(linewidth, np.ndarray): + raise NotImplementedError("multiwidth lines not supported") + else: + arrow_kw['linewidth'] = linewidth + + # Figure out the size of the axes (length of diagonal) + xlim, ylim = axes.get_xlim(), axes.get_ylim() + ul, lr = np.array([xlim[0], ylim[0]]), np.array([xlim[1], ylim[1]]) + diag = np.linalg.norm(ul - lr) + + # Compute the arc length along the curve + s = np.cumsum(np.sqrt(np.diff(x) ** 2 + np.diff(y) ** 2)) + + # Truncate the number of arrows if the curve is short + # TODO: figure out a smarter way to do this + frac = min(s[-1] / diag, 1) + if len(arrow_locs) and frac < 0.05: + arrow_locs = [] # too short; no arrows at all + elif len(arrow_locs) and frac < 0.2: + arrow_locs = [0.5] # single arrow in the middle + + # Plot the arrows (and return list if patches) + arrows = [] + for loc in arrow_locs: + n = np.searchsorted(s, s[-1] * loc) + + if dir == 1 and n == 0: + # Move the arrow forward by one if it is at start of a segment + n = 1 + + # Place the head of the arrow at the desired location + arrow_head = [x[n], y[n]] + arrow_tail = [x[n - dir], y[n - dir]] + + p = mpl.patches.FancyArrowPatch( + arrow_tail, arrow_head, transform=axes.transData, lw=0, + **arrow_kw) + axes.add_patch(p) + arrows.append(p) + return arrows diff --git a/control/freqplot.py b/control/freqplot.py index 5ff690450..277de8a54 100644 --- a/control/freqplot.py +++ b/control/freqplot.py @@ -19,8 +19,9 @@ from . import config from .bdalg import feedback -from .ctrlplot import suptitle, _find_axes_center, _make_legend_labels, \ - _update_suptitle +from .ctrlplot import _add_arrows_to_line2D, _ctrlplot_rcParams, \ + _find_axes_center, _get_line_labels, _make_legend_labels, \ + _process_ax_keyword, _process_line_labels, _update_suptitle, suptitle from .ctrlutil import unwrap from .exception import ControlMIMONotImplemented from .frdata import FrequencyResponseData @@ -34,21 +35,9 @@ 'singular_values_plot', 'gangof4_plot', 'gangof4_response', 'bode', 'nyquist', 'gangof4'] -# Default font dictionary -# TODO: move common plotting params to 'ctrlplot' -_freqplot_rcParams = mpl.rcParams.copy() -_freqplot_rcParams.update({ - 'axes.labelsize': 'small', - 'axes.titlesize': 'small', - 'figure.titlesize': 'medium', - 'legend.fontsize': 'x-small', - 'xtick.labelsize': 'small', - 'ytick.labelsize': 'small', -}) - # Default values for module parameter variables _freqplot_defaults = { - 'freqplot.rcParams': _freqplot_rcParams, + 'freqplot.rcParams': _ctrlplot_rcParams, 'freqplot.feature_periphery_decades': 1, 'freqplot.number_of_samples': 1000, 'freqplot.dB': False, # Plot gain in dB @@ -1937,86 +1926,6 @@ def _parse_linestyle(style_name, allow_false=False): return out -# Internal function to add arrows to a curve -def _add_arrows_to_line2D( - axes, line, arrow_locs=[0.2, 0.4, 0.6, 0.8], - arrowstyle='-|>', arrowsize=1, dir=1): - """ - Add arrows to a matplotlib.lines.Line2D at selected locations. - - Parameters: - ----------- - axes: Axes object as returned by axes command (or gca) - line: Line2D object as returned by plot command - arrow_locs: list of locations where to insert arrows, % of total length - arrowstyle: style of the arrow - arrowsize: size of the arrow - - Returns: - -------- - arrows: list of arrows - - Based on https://stackoverflow.com/questions/26911898/ - - """ - # Get the coordinates of the line, in plot coordinates - if not isinstance(line, mpl.lines.Line2D): - raise ValueError("expected a matplotlib.lines.Line2D object") - x, y = line.get_xdata(), line.get_ydata() - - # Determine the arrow properties - arrow_kw = {"arrowstyle": arrowstyle} - - color = line.get_color() - use_multicolor_lines = isinstance(color, np.ndarray) - if use_multicolor_lines: - raise NotImplementedError("multicolor lines not supported") - else: - arrow_kw['color'] = color - - linewidth = line.get_linewidth() - if isinstance(linewidth, np.ndarray): - raise NotImplementedError("multiwidth lines not supported") - else: - arrow_kw['linewidth'] = linewidth - - # Figure out the size of the axes (length of diagonal) - xlim, ylim = axes.get_xlim(), axes.get_ylim() - ul, lr = np.array([xlim[0], ylim[0]]), np.array([xlim[1], ylim[1]]) - diag = np.linalg.norm(ul - lr) - - # Compute the arc length along the curve - s = np.cumsum(np.sqrt(np.diff(x) ** 2 + np.diff(y) ** 2)) - - # Truncate the number of arrows if the curve is short - # TODO: figure out a smarter way to do this - frac = min(s[-1] / diag, 1) - if len(arrow_locs) and frac < 0.05: - arrow_locs = [] # too short; no arrows at all - elif len(arrow_locs) and frac < 0.2: - arrow_locs = [0.5] # single arrow in the middle - - # Plot the arrows (and return list if patches) - arrows = [] - for loc in arrow_locs: - n = np.searchsorted(s, s[-1] * loc) - - if dir == 1 and n == 0: - # Move the arrow forward by one if it is at start of a segment - n = 1 - - # Place the head of the arrow at the desired location - arrow_head = [x[n], y[n]] - arrow_tail = [x[n - dir], y[n - dir]] - - p = mpl.patches.FancyArrowPatch( - arrow_tail, arrow_head, transform=axes.transData, lw=0, - **arrow_kw) - axes.add_patch(p) - arrows.append(p) - return arrows - - # # Function to compute Nyquist curve offsets # @@ -2672,122 +2581,6 @@ def _default_frequency_range(syslist, Hz=None, number_of_samples=None, return omega -# Get labels for all lines in an axes -def _get_line_labels(ax, use_color=True): - labels, lines = [], [] - last_color, counter = None, 0 # label unknown systems - for i, line in enumerate(ax.get_lines()): - label = line.get_label() - if use_color and label.startswith("Unknown"): - label = f"Unknown-{counter}" - if last_color is None: - last_color = line.get_color() - elif last_color != line.get_color(): - counter += 1 - last_color = line.get_color() - elif label[0] == '_': - continue - - if label not in labels: - lines.append(line) - labels.append(label) - - return lines, labels - - -# Turn label keyword into array indexed by trace, output, input -# TODO: move to ctrlutil.py and update parameter names to reflect general use -def _process_line_labels(label, ntraces, ninputs=0, noutputs=0): - if label is None: - return None - - if isinstance(label, str): - label = [label] * ntraces # single label for all traces - - # Convert to an ndarray, if not done aleady - try: - line_labels = np.asarray(label) - except: - raise ValueError("label must be a string or array_like") - - # Turn the data into a 3D array of appropriate shape - # TODO: allow more sophisticated broadcasting (and error checking) - try: - if ninputs > 0 and noutputs > 0: - if line_labels.ndim == 1 and line_labels.size == ntraces: - line_labels = line_labels.reshape(ntraces, 1, 1) - line_labels = np.broadcast_to( - line_labels, (ntraces, ninputs, noutputs)) - else: - line_labels = line_labels.reshape(ntraces, ninputs, noutputs) - except: - if line_labels.shape[0] != ntraces: - raise ValueError("number of labels must match number of traces") - else: - raise ValueError("labels must be given for each input/output pair") - - return line_labels - - -def _process_ax_keyword( - axs, shape=(1, 1), rcParams=None, squeeze=False, clear_text=False): - """Utility function to process ax keyword to plotting commands. - - This function processes the `ax` keyword to plotting commands. If no - ax keyword is passed, the current figure is checked to see if it has - the correct shape. If the shape matches the desired shape, then the - current figure and axes are returned. Otherwise a new figure is - created with axes of the desired shape. - - Legacy behavior: some of the older plotting commands use a axes label - to identify the proper axes for plotting. This behavior is supported - through the use of the label keyword, but will only work if shape == - (1, 1) and squeeze == True. - - """ - if axs is None: - fig = plt.gcf() # get current figure (or create new one) - axs = fig.get_axes() - - # Check to see if axes are the right shape; if not, create new figure - # Note: can't actually check the shape, just the total number of axes - if len(axs) != np.prod(shape): - with plt.rc_context(rcParams): - if len(axs) != 0: - # Create a new figure - fig, axs = plt.subplots(*shape, squeeze=False) - else: - # Create new axes on (empty) figure - axs = fig.subplots(*shape, squeeze=False) - fig.set_layout_engine('tight') - fig.align_labels() - else: - # Use the existing axes, properly reshaped - axs = np.asarray(axs).reshape(*shape) - - if clear_text: - # Clear out any old text from the current figure - for text in fig.texts: - text.set_visible(False) # turn off the text - del text # get rid of it completely - else: - try: - axs = np.asarray(axs).reshape(shape) - except ValueError: - raise ValueError( - "specified axes are not the right shape; " - f"got {axs.shape} but expecting {shape}") - fig = axs[0, 0].figure - - # Process the squeeze keyword - if squeeze and shape == (1, 1): - axs = axs[0, 0] # Just return the single axes object - elif squeeze: - axs = axs.squeeze() - - return fig, axs - - # # Utility functions to create nice looking labels (KLD 5/23/11) # diff --git a/control/nichols.py b/control/nichols.py index 5eafa594f..78b03b315 100644 --- a/control/nichols.py +++ b/control/nichols.py @@ -18,10 +18,9 @@ import numpy as np from . import config -from .ctrlplot import suptitle +from .ctrlplot import _get_line_labels, _process_ax_keyword, suptitle from .ctrlutil import unwrap -from .freqplot import _default_frequency_range, _freqplot_defaults, \ - _get_line_labels, _process_ax_keyword +from .freqplot import _default_frequency_range, _freqplot_defaults from .lti import frequency_response from .statesp import StateSpace from .xferfcn import TransferFunction diff --git a/control/phaseplot.py b/control/phaseplot.py index a885f2d5c..c7ccd1d1e 100644 --- a/control/phaseplot.py +++ b/control/phaseplot.py @@ -36,8 +36,8 @@ from scipy.integrate import odeint from . import config +from .ctrlplot import _add_arrows_to_line2D from .exception import ControlNotImplemented -from .freqplot import _add_arrows_to_line2D from .nlsys import NonlinearIOSystem, find_eqpt, input_output_response __all__ = ['phase_plane_plot', 'phase_plot', 'box_grid'] diff --git a/control/pzmap.py b/control/pzmap.py index dd3f9e42b..c7082db1d 100644 --- a/control/pzmap.py +++ b/control/pzmap.py @@ -18,7 +18,8 @@ from numpy import cos, exp, imag, linspace, real, sin, sqrt from . import config -from .freqplot import _freqplot_defaults, _get_line_labels +from .ctrlplot import _get_line_labels +from .freqplot import _freqplot_defaults from .grid import nogrid, sgrid, zgrid from .iosys import isctime, isdtime from .lti import LTI diff --git a/control/tests/ctrlplot_test.py b/control/tests/ctrlplot_test.py new file mode 100644 index 000000000..05970bdd1 --- /dev/null +++ b/control/tests/ctrlplot_test.py @@ -0,0 +1,42 @@ +# ctrlplot_test.py - test out control plotting utilities +# RMM, 27 Jun 2024 + +import pytest +import control as ct +import matplotlib.pyplot as plt + +@pytest.mark.usefixtures('mplcleanup') +def test_rcParams(): + sys = ct.rss(2, 2, 2) + + # Create new set of rcParams + my_rcParams = {} + for key in [ + 'axes.labelsize', 'axes.titlesize', 'figure.titlesize', + 'legend.fontsize', 'xtick.labelsize', 'ytick.labelsize']: + match plt.rcParams[key]: + case 8 | 9 | 10: + my_rcParams[key] = plt.rcParams[key] + 1 + case 'medium': + my_rcParams[key] = 11.5 + case 'large': + my_rcParams[key] = 9.5 + case _: + raise ValueError(f"unknown rcParam type for {key}") + + # Generate a figure with the new rcParams + out = ct.step_response(sys).plot(rcParams=my_rcParams) + ax = out[0, 0][0].axes + fig = ax.figure + + # Check to make sure new settings were used + assert ax.xaxis.get_label().get_fontsize() == my_rcParams['axes.labelsize'] + assert ax.yaxis.get_label().get_fontsize() == my_rcParams['axes.labelsize'] + assert ax.title.get_fontsize() == my_rcParams['axes.titlesize'] + assert ax.get_xticklabels()[0].get_fontsize() == \ + my_rcParams['xtick.labelsize'] + assert ax.get_yticklabels()[0].get_fontsize() == \ + my_rcParams['ytick.labelsize'] + assert fig._suptitle.get_fontsize() == my_rcParams['figure.titlesize'] + + diff --git a/control/tests/timeplot_test.py b/control/tests/timeplot_test.py index 0fcc159be..6c124c48f 100644 --- a/control/tests/timeplot_test.py +++ b/control/tests/timeplot_test.py @@ -397,41 +397,6 @@ def test_linestyles(): assert lines[7].get_color() == 'green' and lines[7].get_linestyle() == '--' -@pytest.mark.usefixtures('mplcleanup') -def test_rcParams(): - sys = ct.rss(2, 2, 2) - - # Create new set of rcParams - my_rcParams = {} - for key in [ - 'axes.labelsize', 'axes.titlesize', 'figure.titlesize', - 'legend.fontsize', 'xtick.labelsize', 'ytick.labelsize']: - match plt.rcParams[key]: - case 8 | 9 | 10: - my_rcParams[key] = plt.rcParams[key] + 1 - case 'medium': - my_rcParams[key] = 11.5 - case 'large': - my_rcParams[key] = 9.5 - case _: - raise ValueError(f"unknown rcParam type for {key}") - - # Generate a figure with the new rcParams - out = ct.step_response(sys).plot(rcParams=my_rcParams) - ax = out[0, 0][0].axes - fig = ax.figure - - # Check to make sure new settings were used - assert ax.xaxis.get_label().get_fontsize() == my_rcParams['axes.labelsize'] - assert ax.yaxis.get_label().get_fontsize() == my_rcParams['axes.labelsize'] - assert ax.title.get_fontsize() == my_rcParams['axes.titlesize'] - assert ax.get_xticklabels()[0].get_fontsize() == \ - my_rcParams['xtick.labelsize'] - assert ax.get_yticklabels()[0].get_fontsize() == \ - my_rcParams['ytick.labelsize'] - assert fig._suptitle.get_fontsize() == my_rcParams['figure.titlesize'] - - @pytest.mark.parametrize("resp_fcn", [ ct.step_response, ct.initial_response, ct.impulse_response, ct.forced_response, ct.input_output_response]) diff --git a/control/timeplot.py b/control/timeplot.py index 2eb7aec9b..01b5c7945 100644 --- a/control/timeplot.py +++ b/control/timeplot.py @@ -15,24 +15,13 @@ import numpy as np from . import config -from .ctrlplot import _make_legend_labels, _update_suptitle +from .ctrlplot import _ctrlplot_rcParams, _make_legend_labels, _update_suptitle __all__ = ['time_response_plot', 'combine_time_responses'] -# Default font dictionary -_timeplot_rcParams = mpl.rcParams.copy() -_timeplot_rcParams.update({ - 'axes.labelsize': 'small', - 'axes.titlesize': 'small', - 'figure.titlesize': 'medium', - 'legend.fontsize': 'x-small', - 'xtick.labelsize': 'small', - 'ytick.labelsize': 'small', -}) - # Default values for module parameter variables _timeplot_defaults = { - 'timeplot.rcParams': _timeplot_rcParams, + 'timeplot.rcParams': _ctrlplot_rcParams, 'timeplot.trace_props': [ {'linestyle': s} for s in ['-', '--', ':', '-.']], 'timeplot.output_props': [ @@ -162,7 +151,7 @@ def time_response_plot( config.defaults[''timeplot.rcParams']. """ - from .freqplot import _process_ax_keyword, _process_line_labels + from .ctrlplot import _process_ax_keyword, _process_line_labels from .iosys import InputOutputSystem from .timeresp import TimeResponseData From c67da3aa13a6f4804264e005733878348b7cfbf3 Mon Sep 17 00:00:00 2001 From: Richard Murray Date: Sun, 21 Jul 2024 08:34:28 -0700 Subject: [PATCH 2/3] move code around in grid.py --- control/grid.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/control/grid.py b/control/grid.py index ef9995947..dfe8f9a3e 100644 --- a/control/grid.py +++ b/control/grid.py @@ -141,18 +141,6 @@ def sgrid(scaling=None): return ax, fig -# Utility function used by all grid code -def _final_setup(ax, scaling=None): - ax.set_xlabel('Real') - ax.set_ylabel('Imaginary') - ax.axhline(y=0, color='black', lw=0.25) - ax.axvline(x=0, color='black', lw=0.25) - - # Set up the scaling for the axes - scaling = 'equal' if scaling is None else scaling - plt.axis(scaling) - - # If not grid is given, at least separate stable/unstable regions def nogrid(dt=None, ax=None, scaling=None): fig = plt.gcf() @@ -226,3 +214,15 @@ def zgrid(zetas=None, wns=None, ax=None, scaling=None): _final_setup(ax, scaling=scaling) return ax, fig + + +# Utility function used by all grid code +def _final_setup(ax, scaling=None): + ax.set_xlabel('Real') + ax.set_ylabel('Imaginary') + ax.axhline(y=0, color='black', lw=0.25) + ax.axvline(x=0, color='black', lw=0.25) + + # Set up the scaling for the axes + scaling = 'equal' if scaling is None else scaling + plt.axis(scaling) From 009b8212a7d44f901cfda060168c9b3ecc72145d Mon Sep 17 00:00:00 2001 From: Richard Murray Date: Sun, 21 Jul 2024 09:43:50 -0700 Subject: [PATCH 3/3] fix small typo caught by @slivingston --- control/ctrlplot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/control/ctrlplot.py b/control/ctrlplot.py index e53d4917e..6d31664a0 100644 --- a/control/ctrlplot.py +++ b/control/ctrlplot.py @@ -113,7 +113,7 @@ def get_plot_axes(line_array): # Utility functions # # These functions are used by plotting routines to provide a consistent way -# of processing and displaing information. +# of processing and displaying information. #