From 10241a4657354b14e7d526a5e3c9dd67596ac4ab Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Mon, 20 Mar 2023 13:49:50 -0500 Subject: [PATCH 1/2] Fix converting atan2 to sympy --- symengine/lib/symengine_wrapper.pyx | 8 ++++++++ symengine/tests/test_sympy_conv.py | 18 ++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/symengine/lib/symengine_wrapper.pyx b/symengine/lib/symengine_wrapper.pyx index 7cbf148c..52c65d6c 100644 --- a/symengine/lib/symengine_wrapper.pyx +++ b/symengine/lib/symengine_wrapper.pyx @@ -2635,6 +2635,14 @@ class atan2(Function): cdef Basic Y = sympify(y) return c2py(symengine.atan2(X.thisptr, Y.thisptr)) + def _sympy_(self): + import sympy + return sympy.atan2(*self.args_as_sympy()) + + def _sage_(self): + import sage.all as sage + return sage.atan2(*self.args_as_sage()) + # For backwards compatibility Sin = sin diff --git a/symengine/tests/test_sympy_conv.py b/symengine/tests/test_sympy_conv.py index 7692ad71..e9fbe0bd 100644 --- a/symengine/tests/test_sympy_conv.py +++ b/symengine/tests/test_sympy_conv.py @@ -171,6 +171,7 @@ def test_conv7(): assert acot(x/3) == acot(sympy.Symbol("x") / 3) assert acsc(x/3) == acsc(sympy.Symbol("x") / 3) assert asec(x/3) == asec(sympy.Symbol("x") / 3) + assert atan2(x/3, y) == atan2(sympy.Symbol("x") / 3, sympy.Symbol("y")) assert sin(x/3)._sympy_() == sympy.sin(sympy.Symbol("x") / 3) assert sin(x/3)._sympy_() != sympy.cos(sympy.Symbol("x") / 3) @@ -185,6 +186,22 @@ def test_conv7(): assert acot(x/3)._sympy_() == sympy.acot(sympy.Symbol("x") / 3) assert acsc(x/3)._sympy_() == sympy.acsc(sympy.Symbol("x") / 3) assert asec(x/3)._sympy_() == sympy.asec(sympy.Symbol("x") / 3) + assert atan2(x/3, y)._sympy_() == sympy.atan2(sympy.Symbol("x") / 3, sympy.Symbol("y")) + + assert sympy.sympify(sin(x/3)) == sympy.sin(sympy.Symbol("x") / 3) + assert sympy.sympify(sin(x/3)) != sympy.cos(sympy.Symbol("x") / 3) + assert sympy.sympify(cos(x/3)) == sympy.cos(sympy.Symbol("x") / 3) + assert sympy.sympify(tan(x/3)) == sympy.tan(sympy.Symbol("x") / 3) + assert sympy.sympify(cot(x/3)) == sympy.cot(sympy.Symbol("x") / 3) + assert sympy.sympify(csc(x/3)) == sympy.csc(sympy.Symbol("x") / 3) + assert sympy.sympify(sec(x/3)) == sympy.sec(sympy.Symbol("x") / 3) + assert sympy.sympify(asin(x/3)) == sympy.asin(sympy.Symbol("x") / 3) + assert sympy.sympify(acos(x/3)) == sympy.acos(sympy.Symbol("x") / 3) + assert sympy.sympify(atan(x/3)) == sympy.atan(sympy.Symbol("x") / 3) + assert sympy.sympify(acot(x/3)) == sympy.acot(sympy.Symbol("x") / 3) + assert sympy.sympify(acsc(x/3)) == sympy.acsc(sympy.Symbol("x") / 3) + assert sympy.sympify(asec(x/3)) == sympy.asec(sympy.Symbol("x") / 3) + assert sympy.sympify(atan2(x/3, y)) == sympy.atan2(sympy.Symbol("x") / 3, sympy.Symbol("y")) @unittest.skipIf(not have_sympy, "SymPy not installed") @@ -204,6 +221,7 @@ def test_conv7b(): assert sympify(sympy.acot(x/3)) == acot(Symbol("x") / 3) assert sympify(sympy.acsc(x/3)) == acsc(Symbol("x") / 3) assert sympify(sympy.asec(x/3)) == asec(Symbol("x") / 3) + assert sympify(sympy.atan2(x/3, y)) == atan2(Symbol("x") / 3, Symbol("y")) @unittest.skipIf(not have_sympy, "SymPy not installed") From eb78901bfbb6b37c15e03d73fb2420a47aa44909 Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Mon, 20 Mar 2023 14:21:00 -0500 Subject: [PATCH 2/2] import atan2 --- symengine/tests/test_sympy_conv.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/symengine/tests/test_sympy_conv.py b/symengine/tests/test_sympy_conv.py index e9fbe0bd..5d173dc4 100644 --- a/symengine/tests/test_sympy_conv.py +++ b/symengine/tests/test_sympy_conv.py @@ -2,7 +2,7 @@ function_symbol, I, E, pi, oo, zoo, nan, true, false, exp, gamma, have_mpfr, have_mpc, DenseMatrix, sin, cos, tan, cot, csc, sec, asin, acos, atan, acot, acsc, asec, sinh, cosh, tanh, coth, - asinh, acosh, atanh, acoth, Add, Mul, Pow, diff, GoldenRatio, + asinh, acosh, atanh, acoth, atan2, Add, Mul, Pow, diff, GoldenRatio, Catalan, EulerGamma, UnevaluatedExpr, RealDouble) from symengine.lib.symengine_wrapper import (Subs, Derivative, RealMPFR, ComplexMPC, PyNumber, Function, LambertW, zeta, dirichlet_eta,