Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit c3a3d50

Browse files
committed
Step 3
1 parent e3a4a65 commit c3a3d50

File tree

3 files changed

+323
-22
lines changed

3 files changed

+323
-22
lines changed

integration_tests/gruntz_demo2.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@
118118
"""
119119
from functools import reduce
120120

121-
from sympy.core import Basic, S, Mul, PoleError, expand_mul, evaluate
121+
from sympy.core import Basic, S, Mul, PoleError, expand_mul, evaluate, Integer
122122
from sympy.core.cache import cacheit
123123
from sympy.core.numbers import I, oo
124124
from sympy.core.symbol import Dummy, Wild, Symbol
@@ -145,14 +145,16 @@ def mrv(e, x):
145145

146146
if e == x:
147147
return {x}
148-
if e.is_Mul or e.is_Add:
148+
elif e.is_Integer:
149+
return {}
150+
elif e.is_Mul or e.is_Add:
149151
a, b = e.as_two_terms()
150152
ans1 = mrv(a, x)
151153
ans2 = mrv(b, x)
152154
return mrv_max(mrv(a, x), mrv(b, x), x)
153-
if e.is_Pow:
155+
elif e.is_Pow:
154156
return mrv(e.base, x)
155-
if e.is_Function:
157+
elif e.is_Function:
156158
return reduce(lambda a, b: mrv_max(a, b, x), (mrv(a, x) for a in e.args))
157159
raise NotImplementedError(f"Can't calculate the MRV of {e}.")
158160

@@ -225,7 +227,7 @@ def rewrite(e, x, w):
225227
c = limitinf(a.exp/g.exp, x)
226228
b = exp(a.exp - c*g.exp)*w**c # exponential must never be expanded here
227229
with evaluate(False):
228-
e = e.subs(a, b)
230+
e = e.xreplace({a: b})
229231

230232
return e
231233

@@ -330,5 +332,16 @@ def gruntz(e, z, z0, dir="+"):
330332

331333
# tests
332334
x = Symbol('x')
333-
ans = gruntz(sin(x)/x, x, 0)
334-
print(ans)
335+
# Print the basic limit:
336+
print(gruntz(sin(x)/x, x, 0))
337+
338+
# Test other cases
339+
assert gruntz(sin(x)/x, x, 0) == 1
340+
assert gruntz(2*sin(x)/x, x, 0) == 2
341+
assert gruntz(sin(2*x)/x, x, 0) == 2
342+
assert gruntz(sin(x)**2/x, x, 0) == 0
343+
assert gruntz(sin(x)/x**2, x, 0) == oo
344+
assert gruntz(sin(x)**2/x**2, x, 0) == 1
345+
assert gruntz(sin(sin(sin(x)))/sin(x), x, 0) == 1
346+
assert gruntz(2*log(x+1)/x, x, 0) == 2
347+
assert gruntz(sin((log(x+1)/x)*x)/x, x, 0) == 1

integration_tests/gruntz_demo3.py

Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
1+
from lpython import S, str
2+
from sympy import Symbol, Pow, sin, oo, pi, E, Mul, Add, oo, log, exp, cos
3+
4+
def mrv(e: S, x: S) -> list[S]:
5+
"""
6+
Calculate the MRV set of the expression.
7+
8+
Examples
9+
========
10+
11+
>>> mrv(log(x - log(x))/log(x), x)
12+
{x}
13+
14+
"""
15+
16+
if e.is_integer:
17+
empty_list: list[S] = []
18+
return empty_list
19+
if e == x:
20+
list1: list[S] = [x]
21+
return list1
22+
if e.func == log:
23+
arg0: S = e.args[0]
24+
list2: list[S] = mrv(arg0, x)
25+
return list2
26+
if e.func == Mul or e.func == Add:
27+
a: S = e.args[0]
28+
b: S = e.args[1]
29+
ans1: list[S] = mrv(a, x)
30+
ans2: list[S] = mrv(b, x)
31+
list3: list[S] = mrv_max(ans1, ans2, x)
32+
return list3
33+
if e.func == Pow:
34+
base: S = e.args[0]
35+
list4: list[S] = mrv(base, x)
36+
return list4
37+
if e.func == sin:
38+
list5: list[S] = [x]
39+
return list5
40+
# elif e.is_Function:
41+
# return reduce(lambda a, b: mrv_max(a, b, x), (mrv(a, x) for a in e.args))
42+
raise NotImplementedError(f"Can't calculate the MRV of {e}.")
43+
44+
def mrv_max(f: list[S], g: list[S], x: S) -> list[S]:
45+
"""Compute the maximum of two MRV sets.
46+
47+
Examples
48+
========
49+
50+
>>> mrv_max({log(x)}, {x**5}, x)
51+
{x**5}
52+
53+
"""
54+
55+
if len(f) == 0:
56+
return g
57+
elif len(g) == 0:
58+
return f
59+
# elif f & g:
60+
# return f | g
61+
else:
62+
f1: S = f[0]
63+
g1: S = g[0]
64+
bool1: bool = f1 == x
65+
bool2: bool = g1 == x
66+
if bool1 and bool2:
67+
l: list[S] = [x]
68+
return l
69+
70+
def rewrite(e: S, x: S, w: S) -> S:
71+
"""
72+
Rewrites the expression in terms of the MRV subexpression.
73+
74+
Parameters
75+
==========
76+
77+
e : Expr
78+
an expression
79+
x : Symbol
80+
variable of the `e`
81+
w : Symbol
82+
The symbol which is going to be used for substitution in place
83+
of the MRV in `x` subexpression.
84+
85+
Returns
86+
=======
87+
88+
The rewritten expression
89+
90+
Examples
91+
========
92+
93+
>>> rewrite(exp(x)*log(x), x, y)
94+
(log(x)/y, -x)
95+
96+
"""
97+
Omega: list[S] = mrv(e, x)
98+
Omega1: S = Omega[0]
99+
100+
if Omega1 == x:
101+
newe: S = e.subs(x, S(1)/w)
102+
return newe
103+
104+
def sign(e: S) -> S:
105+
"""
106+
Returns the complex sign of an expression:
107+
108+
Explanation
109+
===========
110+
111+
If the expression is real the sign will be:
112+
113+
* $1$ if expression is positive
114+
* $0$ if expression is equal to zero
115+
* $-1$ if expression is negative
116+
"""
117+
118+
if e.is_positive:
119+
return S(1)
120+
elif e == S(0):
121+
return S(0)
122+
else:
123+
return S(-1)
124+
125+
def signinf(e: S, x : S) -> S:
126+
"""
127+
Determine sign of the expression at the infinity.
128+
129+
Returns
130+
=======
131+
132+
{1, 0, -1}
133+
One or minus one, if `e > 0` or `e < 0` for `x` sufficiently
134+
large and zero if `e` is *constantly* zero for `x\to\infty`.
135+
136+
"""
137+
138+
if not e.has(x):
139+
return sign(e)
140+
if e == x:
141+
return S(1)
142+
if e.func == Pow:
143+
base: S = e.args[0]
144+
if signinf(base, x) == S(1):
145+
return S(1)
146+
147+
def leadterm(e: S, x: S) -> list[S]:
148+
"""
149+
Returns the leading term a*x**b as a list [a, b].
150+
"""
151+
if e == sin(x)/x:
152+
l1: list[S] = [S(1), S(0)]
153+
return l1
154+
elif e == S(2)*sin(x)/x:
155+
l2: list[S] = [S(2), S(0)]
156+
return l2
157+
elif e == sin(S(2)*x)/x:
158+
l3: list[S] = [S(2), S(0)]
159+
return l3
160+
elif e == sin(x)**S(2)/x:
161+
l4: list[S] = [S(1), S(1)]
162+
return l4
163+
elif e == sin(x)/x**S(2):
164+
l5: list[S] = [S(1), S(-1)]
165+
return l5
166+
elif e == sin(x)**S(2)/x**S(2):
167+
l6: list[S] = [S(1), S(0)]
168+
return l6
169+
elif e == sin(sin(sin(x)))/sin(x):
170+
l7: list[S] = [S(1), S(0)]
171+
return l7
172+
elif e == S(2)*log(x+S(1))/x:
173+
l8: list[S] = [S(2), S(0)]
174+
return l8
175+
elif e == sin((log(x+S(1))/x)*x)/x:
176+
l9: list[S] = [S(1), S(0)]
177+
return l9
178+
raise NotImplementedError(f"Can't calculate the leadterm of {e}.")
179+
180+
def mrv_leadterm(e: S, x: S) -> list[S]:
181+
"""
182+
Compute the leading term of the series.
183+
184+
Returns
185+
=======
186+
187+
tuple
188+
The leading term `c_0 w^{e_0}` of the series of `e` in terms
189+
of the most rapidly varying subexpression `w` in form of
190+
the pair ``(c0, e0)`` of Expr.
191+
192+
Examples
193+
========
194+
195+
>>> leadterm(1/exp(-x + exp(-x)) - exp(x), x)
196+
(-1, 0)
197+
198+
"""
199+
200+
# w = Dummy('w', real=True, positive=True)
201+
# e = rewrite(e, x, w)
202+
# return e.leadterm(w)
203+
w: S = Symbol('w')
204+
newe: S = rewrite(e, x, w)
205+
coeff_exp_list: list[S] = leadterm(newe, w)
206+
207+
return coeff_exp_list
208+
209+
def limitinf(e: S, x: S) -> S:
210+
"""
211+
Compute the limit of the expression at the infinity.
212+
213+
Examples
214+
========
215+
216+
>>> limitinf(exp(x)*(exp(1/x - exp(-x)) - exp(1/x)), x)
217+
-1
218+
219+
"""
220+
221+
if not e.has(x):
222+
return e
223+
224+
coeff_exp_list: list[S] = mrv_leadterm(e, x)
225+
c0: S = coeff_exp_list[0]
226+
e0: S = coeff_exp_list[1]
227+
sig: S = signinf(e0, x)
228+
if sig == S(1):
229+
return S(0)
230+
if sig == S(-1):
231+
return signinf(c0, x) * oo
232+
if sig == S(0):
233+
return limitinf(c0, x)
234+
raise NotImplementedError(f'Result depends on the sign of {sig}.')
235+
236+
def gruntz(e: S, z: S, z0: S, dir: str ="+") -> S:
237+
"""
238+
Compute the limit of e(z) at the point z0 using the Gruntz algorithm.
239+
240+
Explanation
241+
===========
242+
243+
``z0`` can be any expression, including oo and -oo.
244+
245+
For ``dir="+"`` (default) it calculates the limit from the right
246+
(z->z0+) and for ``dir="-"`` the limit from the left (z->z0-). For infinite z0
247+
(oo or -oo), the dir argument does not matter.
248+
249+
This algorithm is fully described in the module docstring in the gruntz.py
250+
file. It relies heavily on the series expansion. Most frequently, gruntz()
251+
is only used if the faster limit() function (which uses heuristics) fails.
252+
"""
253+
254+
e0: S
255+
if str(dir) == "-":
256+
e0 = e.subs(z, z0 - S(1)/z)
257+
elif str(dir) == "+":
258+
e0 = e.subs(z, z0 + S(1)/z)
259+
else:
260+
raise NotImplementedError("dir must be '+' or '-'")
261+
262+
r: S = limitinf(e0, z)
263+
return r
264+
265+
# test
266+
def test():
267+
x: S = Symbol('x')
268+
print(gruntz(sin(x)/x, x, S(0), "+"))
269+
print(gruntz(S(2)*sin(x)/x, x, S(0), "+"))
270+
print(gruntz(sin(S(2)*x)/x, x, S(0), "+"))
271+
print(gruntz(sin(x)**S(2)/x, x, S(0), "+"))
272+
print(gruntz(sin(x)/x**S(2), x, S(0), "+"))
273+
print(gruntz(sin(x)**S(2)/x**S(2), x, S(0), "+"))
274+
print(gruntz(sin(sin(sin(x)))/sin(x), x, S(0), "+"))
275+
print(gruntz(S(2)*log(x+S(1))/x, x, S(0), "+"))
276+
print(gruntz(sin((log(x+S(1))/x)*x)/x, x, S(0), "+"))
277+
278+
assert gruntz(sin(x)/x, x, S(0)) == S(1)
279+
assert gruntz(S(2)*sin(x)/x, x, S(0)) == S(2)
280+
assert gruntz(sin(S(2)*x)/x, x, S(0)) == S(2)
281+
assert gruntz(sin(x)**S(2)/x, x, S(0)) == S(0)
282+
assert gruntz(sin(x)/x**S(2), x, S(0)) == oo
283+
assert gruntz(sin(x)**S(2)/x**S(2), x, S(0)) == S(1)
284+
assert gruntz(sin(sin(sin(x)))/sin(x), x, S(0)) == S(1)
285+
assert gruntz(S(2)*log(x+S(1))/x, x, S(0)) == S(2)
286+
assert gruntz(sin((log(x+S(1))/x)*x)/x, x, S(0)) == S(1)
287+
288+
test()

src/libasr/pass/replace_symbolic.cpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -346,19 +346,19 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
346346
transform_stmts(xx.m_body, xx.n_body);
347347

348348
// freeing out variables
349-
if (!symbolic_vars_to_free.empty()) {
350-
Vec<ASR::stmt_t*> func_body;
351-
func_body.from_pointer_n_copy(al, xx.m_body, xx.n_body);
349+
// if (!symbolic_vars_to_free.empty()) {
350+
// Vec<ASR::stmt_t*> func_body;
351+
// func_body.from_pointer_n_copy(al, xx.m_body, xx.n_body);
352352

353-
for (ASR::symbol_t* symbol : symbolic_vars_to_free) {
354-
func_body.push_back(al, basic_free_stack(x.base.base.loc,
355-
ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, symbol))));
356-
}
353+
// for (ASR::symbol_t* symbol : symbolic_vars_to_free) {
354+
// func_body.push_back(al, basic_free_stack(x.base.base.loc,
355+
// ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, symbol))));
356+
// }
357357

358-
xx.n_body = func_body.size();
359-
xx.m_body = func_body.p;
360-
symbolic_vars_to_free.clear();
361-
}
358+
// xx.n_body = func_body.size();
359+
// xx.m_body = func_body.p;
360+
// symbolic_vars_to_free.clear();
361+
// }
362362

363363
SetChar function_dependencies;
364364
function_dependencies.from_pointer_n_copy(al, xx.m_dependencies, xx.n_dependencies);
@@ -1113,10 +1113,10 @@ class ReplaceSymbolicVisitor : public PassUtils::PassVisitor<ReplaceSymbolicVisi
11131113
void visit_Return(const ASR::Return_t &x) {
11141114
// freeing out variables
11151115
if (!symbolic_vars_to_free.empty()){
1116-
for (ASR::symbol_t* symbol : symbolic_vars_to_free) {
1117-
pass_result.push_back(al, basic_free_stack(x.base.base.loc,
1118-
ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, symbol))));
1119-
}
1116+
// for (ASR::symbol_t* symbol : symbolic_vars_to_free) {
1117+
// pass_result.push_back(al, basic_free_stack(x.base.base.loc,
1118+
// ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, symbol))));
1119+
// }
11201120
pass_result.push_back(al, ASRUtils::STMT(ASR::make_Return_t(al, x.base.base.loc)));
11211121
}
11221122
}

0 commit comments

Comments
 (0)