Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Add slicing access for state-space models with tests #1012

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions control/statesp.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import math
from copy import deepcopy
from warnings import warn
from collections.abc import Iterable

import numpy as np
import scipy as sp
Expand Down Expand Up @@ -289,9 +290,9 @@ def __init__(self, *args, **kwargs):
raise ValueError("A and B must have the same number of rows.")
if self.nstates != C.shape[1]:
raise ValueError("A and C must have the same number of columns.")
if self.ninputs != B.shape[1]:
if self.ninputs != B.shape[1] or self.ninputs != D.shape[1]:
raise ValueError("B and D must have the same number of columns.")
if self.noutputs != C.shape[0]:
if self.noutputs != C.shape[0] or self.noutputs != D.shape[0]:
raise ValueError("C and D must have the same number of rows.")

#
Expand Down Expand Up @@ -1215,17 +1216,23 @@ def append(self, other):

def __getitem__(self, indices):
"""Array style access"""
if len(indices) != 2:
if not isinstance(indices, Iterable) or len(indices) != 2:
raise IOError('must provide indices of length 2 for state space')
outdx = indices[0] if isinstance(indices[0], list) else [indices[0]]
inpdx = indices[1] if isinstance(indices[1], list) else [indices[1]]
outdx, inpdx = indices

# Convert int to slice to ensure that numpy doesn't drop the dimension
if isinstance(outdx, int): outdx = slice(outdx, outdx+1, 1)
if isinstance(inpdx, int): inpdx = slice(inpdx, inpdx+1, 1)

if not isinstance(outdx, slice) or not isinstance(inpdx, slice):
raise TypeError(f"system indices must be integers or slices")

sysname = config.defaults['iosys.indexed_system_name_prefix'] + \
self.name + config.defaults['iosys.indexed_system_name_suffix']
return StateSpace(
self.A, self.B[:, inpdx], self.C[outdx, :], self.D[outdx, inpdx],
self.dt, name=sysname,
inputs=[self.input_labels[i] for i in list(inpdx)],
outputs=[self.output_labels[i] for i in list(outdx)])
self.dt, name=sysname,
inputs=self.input_labels[inpdx], outputs=self.output_labels[outdx])

def sample(self, Ts, method='zoh', alpha=None, prewarp_frequency=None,
name=None, copy_names=True, **kwargs):
Expand Down
41 changes: 33 additions & 8 deletions control/tests/statesp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,28 +463,53 @@ def test_append_tf(self):
np.testing.assert_array_almost_equal(sys3c.A[:3, 3:], np.zeros((3, 2)))
np.testing.assert_array_almost_equal(sys3c.A[3:, :3], np.zeros((2, 3)))

def test_array_access_ss(self):

def test_array_access_ss_failure(self):
sys1 = StateSpace(
[[1., 2.], [3., 4.]],
[[5., 6.], [6., 8.]],
[[9., 10.], [11., 12.]],
[[13., 14.], [15., 16.]], 1,
inputs=['u0', 'u1'], outputs=['y0', 'y1'])
with pytest.raises(IOError):
sys1[0]

@pytest.mark.parametrize("outdx, inpdx",
[(0, 1),
(slice(0, 1, 1), 1),
(0, slice(1, 2, 1)),
(slice(0, 1, 1), slice(1, 2, 1)),
(slice(None, None, -1), 1),
(0, slice(None, None, -1)),
(slice(None, 2, None), 1),
(slice(None, None, 1), slice(None, None, 2)),
(0, slice(1, 2, 1)),
(slice(0, 1, 1), slice(1, 2, 1))])
def test_array_access_ss(self, outdx, inpdx):
sys1 = StateSpace(
[[1., 2.], [3., 4.]],
[[5., 6.], [7., 8.]],
[[9., 10.], [11., 12.]],
[[13., 14.], [15., 16.]], 1,
inputs=['u0', 'u1'], outputs=['y0', 'y1'])

sys1_01 = sys1[0, 1]
sys1_01 = sys1[outdx, inpdx]

# Convert int to slice to ensure that numpy doesn't drop the dimension
if isinstance(outdx, int): outdx = slice(outdx, outdx+1, 1)
if isinstance(inpdx, int): inpdx = slice(inpdx, inpdx+1, 1)

np.testing.assert_array_almost_equal(sys1_01.A,
sys1.A)
np.testing.assert_array_almost_equal(sys1_01.B,
sys1.B[:, 1:2])
sys1.B[:, inpdx])
np.testing.assert_array_almost_equal(sys1_01.C,
sys1.C[0:1, :])
sys1.C[outdx, :])
np.testing.assert_array_almost_equal(sys1_01.D,
sys1.D[0, 1])
sys1.D[outdx, inpdx])

assert sys1.dt == sys1_01.dt
assert sys1_01.input_labels == ['u1']
assert sys1_01.output_labels == ['y0']
assert sys1_01.input_labels == sys1.input_labels[inpdx]
assert sys1_01.output_labels == sys1.output_labels[outdx]
assert sys1_01.name == sys1.name + "$indexed"

def test_dc_gain_cont(self):
Expand Down
7 changes: 7 additions & 0 deletions control/xferfcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@

"""

from collections.abc import Iterable

# External function declarations
import numpy as np
from numpy import angle, array, empty, finfo, ndarray, ones, \
Expand Down Expand Up @@ -758,7 +760,12 @@ def __pow__(self, other):
return (TransferFunction([1], [1]) / self) * (self**(other + 1))

def __getitem__(self, key):
if not isinstance(key, Iterable) or len(key) != 2:
raise IOError('must provide indices of length 2 for transfer functions')

key1, key2 = key
if not isinstance(key1, (int, slice)) or not isinstance(key2, (int, slice)):
raise TypeError(f"system indices must be integers or slices")

# pre-process
if isinstance(key1, int):
Expand Down
Loading