diff --git a/integration_tests/test_gruntz.py b/integration_tests/test_gruntz.py index 33b940f86f..0a3948f722 100644 --- a/integration_tests/test_gruntz.py +++ b/integration_tests/test_gruntz.py @@ -1,10 +1,10 @@ from lpython import S -from sympy import Symbol, log +from sympy import Symbol, log, E, Pow def mmrv(e: S, x: S) -> list[S]: + empty_list : list[S] = [] if not e.has(x): - list0: list[S] = [] - return list0 + return empty_list elif e == x: list1: list[S] = [x] return list1 @@ -12,6 +12,25 @@ def mmrv(e: S, x: S) -> list[S]: arg0: S = e.args[0] list2: list[S] = mmrv(arg0, x) return list2 + elif e.func == Pow: + if e.args[0] != E: + e1: S = S(1) + newe: S = e + while newe.func == Pow: + b1: S = newe.args[0] + e1 = e1 * newe.args[1] + newe = b1 + if b1 == S(1): + return empty_list + if not e1.has(x): + list3: list[S] = mmrv(b1, x) + return list3 + else: + # TODO as noted in #2526 + pass + else: + # TODO + pass else: raise @@ -35,6 +54,13 @@ def test_mrv(): ele2: S = ans3[0] print(ele2) assert ele2 == x - assert len(ans2) == 1 + assert len(ans3) == 1 + + # Case 4 + ans4: list[S] = mmrv(x**S(2), x) + ele3: S = ans4[0] + print(ele3) + assert ele3 == x + assert len(ans4) == 1 test_mrv() \ No newline at end of file diff --git a/src/libasr/pass/replace_symbolic.cpp b/src/libasr/pass/replace_symbolic.cpp index d17b575b21..83da75603f 100644 --- a/src/libasr/pass/replace_symbolic.cpp +++ b/src/libasr/pass/replace_symbolic.cpp @@ -1007,6 +1007,18 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor(x); + transform_stmts(xx.m_body, xx.n_body); + if (ASR::is_a(*xx.m_test)) { + ASR::IntrinsicScalarFunction_t* intrinsic_func = ASR::down_cast(xx.m_test); + if (ASR::is_a(*intrinsic_func->m_type)) { + ASR::expr_t* function_call = process_attributes(xx.base.base.loc, xx.m_test); + xx.m_test = function_call; + } + } + } + void visit_Return(const ASR::Return_t &x) { // freeing out variables if (!symbolic_vars_to_free.empty()){