diff --git a/integration_tests/test_gruntz.py b/integration_tests/test_gruntz.py index 6c7ddcb7b8..376f68495a 100644 --- a/integration_tests/test_gruntz.py +++ b/integration_tests/test_gruntz.py @@ -2,16 +2,28 @@ from sympy import Symbol def mmrv(e: S, x: S) -> list[S]: - l: list[S] = [] if not e.has(x): - return l + list0: list[S] = [] + return list0 + elif e == x: + list1: list[S] = [x] + return list1 else: raise -def test_mrv1(): +def test_mrv(): + # Case 1 x: S = Symbol("x") y: S = Symbol("y") - ans: list[S] = mmrv(y, x) - assert len(ans) == 0 + ans1: list[S] = mmrv(y, x) + print(ans1) + assert len(ans1) == 0 -test_mrv1() \ No newline at end of file + # Case 2 + ans2: list[S] = mmrv(x, x) + ele1: S = ans2[0] + print(ele1) + assert ele1 == x + assert len(ans2) == 1 + +test_mrv() \ No newline at end of file diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index afa7082ff9..f85bfde582 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -690,6 +690,15 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(*xx.m_test)) { + ASR::SymbolicCompare_t *s = ASR::down_cast(xx.m_test); + ASR::expr_t* function_call = nullptr; + if (s->m_op == ASR::cmpopType::Eq) { + function_call = basic_compare(xx.base.base.loc, "basic_eq", s->m_left, s->m_right); + } else { + function_call = basic_compare(xx.base.base.loc, "basic_neq", s->m_left, s->m_right); + } + xx.m_test = function_call; } }