diff --git a/control/statesp.py b/control/statesp.py index 0c2856b15..717fc9a73 100644 --- a/control/statesp.py +++ b/control/statesp.py @@ -50,6 +50,7 @@ import math from copy import deepcopy from warnings import warn +from collections.abc import Iterable import numpy as np import scipy as sp @@ -289,9 +290,9 @@ def __init__(self, *args, **kwargs): raise ValueError("A and B must have the same number of rows.") if self.nstates != C.shape[1]: raise ValueError("A and C must have the same number of columns.") - if self.ninputs != B.shape[1]: + if self.ninputs != B.shape[1] or self.ninputs != D.shape[1]: raise ValueError("B and D must have the same number of columns.") - if self.noutputs != C.shape[0]: + if self.noutputs != C.shape[0] or self.noutputs != D.shape[0]: raise ValueError("C and D must have the same number of rows.") # @@ -1215,17 +1216,23 @@ def append(self, other): def __getitem__(self, indices): """Array style access""" - if len(indices) != 2: + if not isinstance(indices, Iterable) or len(indices) != 2: raise IOError('must provide indices of length 2 for state space') - outdx = indices[0] if isinstance(indices[0], list) else [indices[0]] - inpdx = indices[1] if isinstance(indices[1], list) else [indices[1]] + outdx, inpdx = indices + + # Convert int to slice to ensure that numpy doesn't drop the dimension + if isinstance(outdx, int): outdx = slice(outdx, outdx+1, 1) + if isinstance(inpdx, int): inpdx = slice(inpdx, inpdx+1, 1) + + if not isinstance(outdx, slice) or not isinstance(inpdx, slice): + raise TypeError(f"system indices must be integers or slices") + sysname = config.defaults['iosys.indexed_system_name_prefix'] + \ self.name + config.defaults['iosys.indexed_system_name_suffix'] return StateSpace( self.A, self.B[:, inpdx], self.C[outdx, :], self.D[outdx, inpdx], - self.dt, name=sysname, - inputs=[self.input_labels[i] for i in list(inpdx)], - outputs=[self.output_labels[i] for i in list(outdx)]) + self.dt, name=sysname, + inputs=self.input_labels[inpdx], outputs=self.output_labels[outdx]) def sample(self, Ts, method='zoh', alpha=None, prewarp_frequency=None, name=None, copy_names=True, **kwargs): diff --git a/control/tests/statesp_test.py b/control/tests/statesp_test.py index 59f441456..6ddf9933e 100644 --- a/control/tests/statesp_test.py +++ b/control/tests/statesp_test.py @@ -463,28 +463,53 @@ def test_append_tf(self): np.testing.assert_array_almost_equal(sys3c.A[:3, 3:], np.zeros((3, 2))) np.testing.assert_array_almost_equal(sys3c.A[3:, :3], np.zeros((2, 3))) - def test_array_access_ss(self): - + def test_array_access_ss_failure(self): sys1 = StateSpace( [[1., 2.], [3., 4.]], [[5., 6.], [6., 8.]], [[9., 10.], [11., 12.]], [[13., 14.], [15., 16.]], 1, inputs=['u0', 'u1'], outputs=['y0', 'y1']) + with pytest.raises(IOError): + sys1[0] + + @pytest.mark.parametrize("outdx, inpdx", + [(0, 1), + (slice(0, 1, 1), 1), + (0, slice(1, 2, 1)), + (slice(0, 1, 1), slice(1, 2, 1)), + (slice(None, None, -1), 1), + (0, slice(None, None, -1)), + (slice(None, 2, None), 1), + (slice(None, None, 1), slice(None, None, 2)), + (0, slice(1, 2, 1)), + (slice(0, 1, 1), slice(1, 2, 1))]) + def test_array_access_ss(self, outdx, inpdx): + sys1 = StateSpace( + [[1., 2.], [3., 4.]], + [[5., 6.], [7., 8.]], + [[9., 10.], [11., 12.]], + [[13., 14.], [15., 16.]], 1, + inputs=['u0', 'u1'], outputs=['y0', 'y1']) - sys1_01 = sys1[0, 1] + sys1_01 = sys1[outdx, inpdx] + + # Convert int to slice to ensure that numpy doesn't drop the dimension + if isinstance(outdx, int): outdx = slice(outdx, outdx+1, 1) + if isinstance(inpdx, int): inpdx = slice(inpdx, inpdx+1, 1) + np.testing.assert_array_almost_equal(sys1_01.A, sys1.A) np.testing.assert_array_almost_equal(sys1_01.B, - sys1.B[:, 1:2]) + sys1.B[:, inpdx]) np.testing.assert_array_almost_equal(sys1_01.C, - sys1.C[0:1, :]) + sys1.C[outdx, :]) np.testing.assert_array_almost_equal(sys1_01.D, - sys1.D[0, 1]) + sys1.D[outdx, inpdx]) assert sys1.dt == sys1_01.dt - assert sys1_01.input_labels == ['u1'] - assert sys1_01.output_labels == ['y0'] + assert sys1_01.input_labels == sys1.input_labels[inpdx] + assert sys1_01.output_labels == sys1.output_labels[outdx] assert sys1_01.name == sys1.name + "$indexed" def test_dc_gain_cont(self): diff --git a/control/xferfcn.py b/control/xferfcn.py index 63aeff8f9..ba9af3913 100644 --- a/control/xferfcn.py +++ b/control/xferfcn.py @@ -47,6 +47,8 @@ """ +from collections.abc import Iterable + # External function declarations import numpy as np from numpy import angle, array, empty, finfo, ndarray, ones, \ @@ -758,7 +760,12 @@ def __pow__(self, other): return (TransferFunction([1], [1]) / self) * (self**(other + 1)) def __getitem__(self, key): + if not isinstance(key, Iterable) or len(key) != 2: + raise IOError('must provide indices of length 2 for transfer functions') + key1, key2 = key + if not isinstance(key1, (int, slice)) or not isinstance(key2, (int, slice)): + raise TypeError(f"system indices must be integers or slices") # pre-process if isinstance(key1, int):