From 52ab45e7a4289fe15fb8ba5e44f6afc3127fde4c Mon Sep 17 00:00:00 2001 From: Rory Yorke Date: Sat, 1 Feb 2025 13:31:48 +0200 Subject: [PATCH 1/3] Add test for scalar `timepts` arg in `solve_flat_ocp` --- control/tests/flatsys_test.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/control/tests/flatsys_test.py b/control/tests/flatsys_test.py index a12bf1480..46e3c4934 100644 --- a/control/tests/flatsys_test.py +++ b/control/tests/flatsys_test.py @@ -452,6 +452,29 @@ def test_flat_solve_ocp(self, basis): np.testing.assert_almost_equal(x_const, x_nlconst) np.testing.assert_almost_equal(u_const, u_nlconst) + def test_solve_flat_ocp_scalar_timepts(self): + # scalar timepts gives expected result + f = fs.LinearFlatSystem(ct.ss(ct.tf([1],[1,1]))) + + def terminal_cost(x, u): + return (x-5).dot(x-5)+u.dot(u) + + traj1 = fs.solve_flat_ocp(f, [0, 1], x0=[23], + terminal_cost=terminal_cost) + + traj2 = fs.solve_flat_ocp(f, 1, x0=[23], + terminal_cost=terminal_cost) + + teval = np.linspace(0, 1, 101) + + r1 = traj1.response(teval) + r2 = traj2.response(teval) + + assert np.max(abs(r1.x-r2.x)) == 0 + assert np.max(abs(r1.u-r2.u)) == 0 + assert np.max(abs(r1.y-r2.y)) == 0 + + def test_bezier_basis(self): bezier = fs.BezierFamily(4) time = np.linspace(0, 1, 100) From b7eeb79c2c1a21a0bb6416bcb62f4f57ab2ec886 Mon Sep 17 00:00:00 2001 From: Rory Yorke Date: Sat, 1 Feb 2025 13:32:26 +0200 Subject: [PATCH 2/3] Handle scalar `timepts` arg in `solve_flat_ocp` Fixes gh-1110. --- control/flatsys/flatsys.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/control/flatsys/flatsys.py b/control/flatsys/flatsys.py index 7d76b9d78..5818d118b 100644 --- a/control/flatsys/flatsys.py +++ b/control/flatsys/flatsys.py @@ -721,8 +721,7 @@ def solve_flat_ocp( # Process final time timepts = np.atleast_1d(timepts) - Tf = timepts[-1] - T0 = timepts[0] if len(timepts) > 1 else T0 + T0 = timepts[0] if len(timepts) > 1 else 0 # Process keyword arguments if trajectory_constraints is None: From a05b9d50f8304e17ed46d299df79d1d3667d0088 Mon Sep 17 00:00:00 2001 From: Rory Yorke Date: Sat, 1 Feb 2025 18:59:00 +0200 Subject: [PATCH 3/3] Use assert_array_equal to compare results in scalar timeresp test --- control/tests/flatsys_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/control/tests/flatsys_test.py b/control/tests/flatsys_test.py index 46e3c4934..5b66edaf5 100644 --- a/control/tests/flatsys_test.py +++ b/control/tests/flatsys_test.py @@ -470,9 +470,9 @@ def terminal_cost(x, u): r1 = traj1.response(teval) r2 = traj2.response(teval) - assert np.max(abs(r1.x-r2.x)) == 0 - assert np.max(abs(r1.u-r2.u)) == 0 - assert np.max(abs(r1.y-r2.y)) == 0 + np.testing.assert_array_equal(r1.x, r2.x) + np.testing.assert_array_equal(r1.y, r2.y) + np.testing.assert_array_equal(r1.u, r2.u) def test_bezier_basis(self):