Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 75cc817

Browse files
authored
Merge pull request #1076 from czgdp1807/numpy_mod
Implementing ``numpy.mod`` for array inputs
2 parents 062bab9 + c698e53 commit 75cc817

File tree

56 files changed

+246
-125
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+246
-125
lines changed

integration_tests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ RUN(NAME elemental_06 LABELS cpython llvm)
177177
RUN(NAME elemental_07 LABELS cpython llvm)
178178
RUN(NAME elemental_08 LABELS cpython llvm)
179179
RUN(NAME elemental_09 LABELS cpython llvm)
180+
RUN(NAME elemental_10 LABELS cpython llvm)
180181
RUN(NAME elemental_11 LABELS cpython llvm)
181182
RUN(NAME test_random LABELS cpython llvm)
182183
RUN(NAME test_os LABELS cpython llvm)

integration_tests/elemental_10.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from ltypes import i32, i64
2+
from numpy import mod, int64, empty
3+
4+
def test_numpy_mod():
5+
q1: i64[32, 16, 7] = empty((32, 16, 7), dtype=int64)
6+
d1: i64[32, 16, 7] = empty((32, 16, 7), dtype=int64)
7+
r1: i64[32, 16, 7] = empty((32, 16, 7), dtype=int64)
8+
r1neg: i64[32, 16, 7] = empty((32, 16, 7), dtype=int64)
9+
q2: i64[100] = empty(100, dtype=int64)
10+
d2: i64[100] = empty(100, dtype=int64)
11+
r2: i64[100] = empty(100, dtype=int64)
12+
i: i32; j: i32; k: i32
13+
rem: i64; q: i64; d: i64
14+
15+
for i in range(32):
16+
for j in range(16):
17+
for k in range(7):
18+
d1[i, j, k] = k + 1
19+
q1[i, j, k] = (i + j) * (k + 1) + k
20+
21+
r1 = mod(q1, d1)
22+
r1neg = mod(-q1, d1)
23+
24+
for i in range(32):
25+
for j in range(16):
26+
for k in range(7):
27+
assert r1[i, j, k] == k
28+
if k == 0:
29+
rem = 0
30+
else:
31+
rem = d1[i, j, k] - k
32+
assert r1neg[i, j, k] == rem
33+
34+
for i in range(32):
35+
for j in range(16):
36+
for k in range(7):
37+
d1[i, j, k] = k + 2
38+
q1[i, j, k] = i + j
39+
40+
r1 = mod(d1 * q1 + r1 + 1, d1)
41+
42+
for i in range(32):
43+
for j in range(16):
44+
for k in range(7):
45+
assert r1[i, j, k] == k + 1
46+
47+
r1 = mod(2 * q1 + 1, int(2))
48+
49+
for i in range(32):
50+
for j in range(16):
51+
for k in range(7):
52+
assert r1[i, j, k] == 1
53+
54+
for i in range(100):
55+
d2[i] = i + 1
56+
57+
r2 = mod(int(100), d2)
58+
59+
for i in range(100):
60+
assert r2[i] == 100 % (i + 1)
61+
62+
for i in range(100):
63+
d2[i] = 50 - i
64+
q2[i] = 39 - i
65+
66+
r2 = mod(q2, d2)
67+
68+
for i in range(100):
69+
d = 50 - i
70+
q = 39 - i
71+
rem = r2[i]
72+
if d == 0:
73+
assert rem == 0
74+
else:
75+
assert int((q - rem)/d) - (q - rem)/d == 0
76+
77+
78+
test_numpy_mod()

src/libasr/pass/array_op.cpp

Lines changed: 81 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -870,66 +870,99 @@ class ArrayOpVisitor : public PassUtils::PassVisitor<ArrayOpVisitor>
870870
sub, nullptr,
871871
s_args.p, s_args.size(), nullptr));
872872
pass_result.push_back(al, subrout_call);
873-
} else if( is_elemental(x.m_name) && x.n_args == 1 &&
874-
ASRUtils::is_array(ASRUtils::expr_type(x.m_args[0].m_value)) ) {
873+
} else if( is_elemental(x.m_name) ) {
874+
std::vector<bool> array_mask(x.n_args, false);
875+
bool at_least_one_array = false;
876+
for( size_t iarg = 0; iarg < x.n_args; iarg++ ) {
877+
array_mask[iarg] = ASRUtils::is_array(
878+
ASRUtils::expr_type(x.m_args[iarg].m_value));
879+
at_least_one_array = at_least_one_array || array_mask[iarg];
880+
}
881+
if (!at_least_one_array) {
882+
return ;
883+
}
875884
std::string res_prefix = "_elemental_func_call_res";
876885
ASR::expr_t* result_var_copy = result_var;
877-
result_var = nullptr;
878-
this->visit_expr(*(x.m_args[0].m_value));
879-
ASR::expr_t* operand = tmp_val;
880-
int rank_operand = PassUtils::get_rank(operand);
881-
if( rank_operand == 0 ) {
886+
bool is_all_rank_0 = true;
887+
std::vector<ASR::expr_t*> operands;
888+
ASR::expr_t* operand = nullptr;
889+
int common_rank = 0;
890+
bool are_all_rank_same = true;
891+
for( size_t iarg = 0; iarg < x.n_args; iarg++ ) {
892+
result_var = nullptr;
893+
this->visit_expr(*(x.m_args[iarg].m_value));
894+
operand = tmp_val;
895+
operands.push_back(operand);
896+
int rank_operand = PassUtils::get_rank(operand);
897+
if( common_rank == 0 ) {
898+
common_rank = rank_operand;
899+
}
900+
if( common_rank != rank_operand &&
901+
rank_operand > 0 ) {
902+
are_all_rank_same = false;
903+
}
904+
array_mask[iarg] = (rank_operand > 0);
905+
is_all_rank_0 = is_all_rank_0 && (rank_operand <= 0);
906+
}
907+
if( is_all_rank_0 ) {
882908
tmp_val = const_cast<ASR::expr_t*>(&(x.base));
883909
return ;
884910
}
885-
if( rank_operand > 0 ) {
886-
result_var = result_var_copy;
887-
if( result_var == nullptr ) {
888-
result_var = create_var(result_var_num, res_prefix,
889-
x.base.base.loc, operand);
890-
result_var_num += 1;
891-
}
892-
tmp_val = result_var;
893-
894-
int n_dims = rank_operand;
895-
Vec<ASR::expr_t*> idx_vars;
896-
PassUtils::create_idx_vars(idx_vars, n_dims, x.base.base.loc, al, current_scope);
897-
ASR::stmt_t* doloop = nullptr;
898-
for( int i = n_dims - 1; i >= 0; i-- ) {
899-
// TODO: Add an If debug node to check if the lower and upper bounds of both the arrays are same.
900-
ASR::do_loop_head_t head;
901-
head.m_v = idx_vars[i];
902-
head.m_start = PassUtils::get_bound(result_var, i + 1, "lbound", al);
903-
head.m_end = PassUtils::get_bound(result_var, i + 1, "ubound", al);
904-
head.m_increment = nullptr;
905-
head.loc = head.m_v->base.loc;
906-
Vec<ASR::stmt_t*> doloop_body;
907-
doloop_body.reserve(al, 1);
908-
if( doloop == nullptr ) {
909-
ASR::expr_t* ref = PassUtils::create_array_ref(operand, idx_vars, al);
910-
ASR::expr_t* res = PassUtils::create_array_ref(result_var, idx_vars, al);
911-
ASR::expr_t* op_el_wise = nullptr;
911+
if( !are_all_rank_same ) {
912+
throw LCompilersException("Broadcasting support not yet available "
913+
"for different shape arrays.");
914+
}
915+
result_var = result_var_copy;
916+
if( result_var == nullptr ) {
917+
result_var = create_var(result_var_num, res_prefix,
918+
x.base.base.loc, operand);
919+
result_var_num += 1;
920+
}
921+
tmp_val = result_var;
922+
923+
int n_dims = common_rank;
924+
Vec<ASR::expr_t*> idx_vars;
925+
PassUtils::create_idx_vars(idx_vars, n_dims, x.base.base.loc, al, current_scope);
926+
ASR::stmt_t* doloop = nullptr;
927+
for( int i = n_dims - 1; i >= 0; i-- ) {
928+
// TODO: Add an If debug node to check if the lower and upper bounds of both the arrays are same.
929+
ASR::do_loop_head_t head;
930+
head.m_v = idx_vars[i];
931+
head.m_start = PassUtils::get_bound(result_var, i + 1, "lbound", al);
932+
head.m_end = PassUtils::get_bound(result_var, i + 1, "ubound", al);
933+
head.m_increment = nullptr;
934+
head.loc = head.m_v->base.loc;
935+
Vec<ASR::stmt_t*> doloop_body;
936+
doloop_body.reserve(al, 1);
937+
if( doloop == nullptr ) {
938+
Vec<ASR::call_arg_t> ref_args;
939+
ref_args.reserve(al, x.n_args);
940+
for( size_t iarg = 0; iarg < x.n_args; iarg++ ) {
941+
ASR::expr_t* ref = operands[iarg];
942+
if( array_mask[iarg] ) {
943+
ref = PassUtils::create_array_ref(operands[iarg], idx_vars, al);
944+
}
912945
ASR::call_arg_t ref_arg;
913946
ref_arg.loc = ref->base.loc;
914947
ref_arg.m_value = ref;
915-
Vec<ASR::call_arg_t> ref_args;
916-
ref_args.reserve(al, 1);
917948
ref_args.push_back(al, ref_arg);
918-
Vec<ASR::dimension_t> empty_dim;
919-
empty_dim.reserve(al, 1);
920-
ASR::ttype_t* dim_less_type = ASRUtils::duplicate_type(al, x.m_type, &empty_dim);
921-
op_el_wise = LFortran::ASRUtils::EXPR(ASR::make_FunctionCall_t(al, x.base.base.loc,
922-
x.m_name, x.m_original_name, ref_args.p, ref_args.size(), dim_less_type,
923-
nullptr, x.m_dt));
924-
ASR::stmt_t* assign = LFortran::ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, res, op_el_wise, nullptr));
925-
doloop_body.push_back(al, assign);
926-
} else {
927-
doloop_body.push_back(al, doloop);
928949
}
929-
doloop = LFortran::ASRUtils::STMT(ASR::make_DoLoop_t(al, x.base.base.loc, head, doloop_body.p, doloop_body.size()));
950+
Vec<ASR::dimension_t> empty_dim;
951+
empty_dim.reserve(al, 1);
952+
ASR::ttype_t* dim_less_type = ASRUtils::duplicate_type(al, x.m_type, &empty_dim);
953+
ASR::expr_t* op_el_wise = nullptr;
954+
op_el_wise = LFortran::ASRUtils::EXPR(ASR::make_FunctionCall_t(al, x.base.base.loc,
955+
x.m_name, x.m_original_name, ref_args.p, ref_args.size(), dim_less_type,
956+
nullptr, x.m_dt));
957+
ASR::expr_t* res = PassUtils::create_array_ref(result_var, idx_vars, al);
958+
ASR::stmt_t* assign = LFortran::ASRUtils::STMT(ASR::make_Assignment_t(al, x.base.base.loc, res, op_el_wise, nullptr));
959+
doloop_body.push_back(al, assign);
960+
} else {
961+
doloop_body.push_back(al, doloop);
930962
}
931-
pass_result.push_back(al, doloop);
963+
doloop = LFortran::ASRUtils::STMT(ASR::make_DoLoop_t(al, x.base.base.loc, head, doloop_body.p, doloop_body.size()));
932964
}
965+
pass_result.push_back(al, doloop);
933966
}
934967
result_var = nullptr;
935968
}

src/libasr/pass/pass_utils.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,10 @@ namespace LFortran {
142142
ASR::AssociateBlock_t *s = ASR::down_cast<ASR::AssociateBlock_t>(item.second);
143143
self().visit_AssociateBlock(*s);
144144
}
145+
if (ASR::is_a<ASR::Block_t>(*item.second)) {
146+
ASR::Block_t *s = ASR::down_cast<ASR::Block_t>(item.second);
147+
self().visit_Block(*s);
148+
}
145149
}
146150
}
147151

@@ -151,6 +155,13 @@ namespace LFortran {
151155
ASR::Function_t &xx = const_cast<ASR::Function_t&>(x);
152156
current_scope = xx.m_symtab;
153157
transform_stmts(xx.m_body, xx.n_body);
158+
159+
for (auto &item : x.m_symtab->get_scope()) {
160+
if (ASR::is_a<ASR::Block_t>(*item.second)) {
161+
ASR::Block_t *s = ASR::down_cast<ASR::Block_t>(item.second);
162+
self().visit_Block(*s);
163+
}
164+
}
154165
}
155166

156167
void visit_AssociateBlock(const ASR::AssociateBlock_t& x) {
@@ -163,6 +174,10 @@ namespace LFortran {
163174
ASR::Block_t &xx = const_cast<ASR::Block_t&>(x);
164175
current_scope = xx.m_symtab;
165176
transform_stmts(xx.m_body, xx.n_body);
177+
178+
for (auto &item : x.m_symtab->get_scope()) {
179+
self().visit_symbol(*item.second);
180+
}
166181
}
167182

168183
};

src/lpython/semantics/python_ast_to_asr.cpp

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -862,21 +862,18 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
862862
}
863863

864864
SymbolTable *symtab = current_scope;
865-
while (symtab->parent != nullptr && symtab->get_scope().find(local_sym) == symtab->get_scope().end()) {
866-
symtab = symtab->parent;
867-
}
868-
if (symtab->get_scope().find(local_sym) == symtab->get_scope().end()) {
865+
if (symtab->resolve_symbol(local_sym) == nullptr) {
869866
LFORTRAN_ASSERT(ASR::is_a<ASR::ExternalSymbol_t>(*stemp));
870867
std::string mod_name = ASR::down_cast<ASR::ExternalSymbol_t>(stemp)->m_module_name;
871-
ASR::symbol_t *mt = symtab->get_symbol(mod_name);
868+
ASR::symbol_t *mt = symtab->resolve_symbol(mod_name);
872869
ASR::Module_t *m = ASR::down_cast<ASR::Module_t>(mt);
873870
stemp = import_from_module(al, m, symtab, mod_name,
874871
remote_sym, local_sym, loc);
875872
LFORTRAN_ASSERT(ASR::is_a<ASR::ExternalSymbol_t>(*stemp));
876873
symtab->add_symbol(local_sym, stemp);
877874
s = ASRUtils::symbol_get_past_external(stemp);
878875
} else {
879-
stemp = symtab->get_symbol(local_sym);
876+
stemp = symtab->resolve_symbol(local_sym);
880877
}
881878
}
882879
if (ASR::is_a<ASR::Function_t>(*s)) {
@@ -3419,9 +3416,6 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
34193416
current_scope = al.make_new<SymbolTable>(parent_scope);
34203417
transform_stmts(body, x.n_body, x.m_body);
34213418
int32_t total_syms = current_scope->get_scope().size();
3422-
for( auto& item: current_scope->get_scope() ) {
3423-
total_syms -= ASR::is_a<ASR::ExternalSymbol_t>(*item.second);
3424-
}
34253419
if( total_syms > 0 ) {
34263420
std::string name = parent_scope->get_unique_name("block");
34273421
ASR::asr_t* block = ASR::make_Block_t(al, x.base.base.loc,
@@ -3433,24 +3427,8 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
34333427
ASR::down_cast<ASR::symbol_t>(block)));
34343428
body.reserve(al, 1);
34353429
body.push_back(al, decls);
3436-
} else {
3437-
// Revert global counter as no variables
3438-
// were declared inside the loop so
3439-
// current_scope is not needed.
3440-
for( auto& item: current_scope->get_scope() ) {
3441-
if( !ASR::is_a<ASR::ExternalSymbol_t>(*item.second) ) {
3442-
continue ;
3443-
}
3444-
3445-
ASR::ExternalSymbol_t* ext_sym = ASR::down_cast<ASR::ExternalSymbol_t>(item.second);
3446-
ASR::symbol_t* new_ext_sym = ASR::down_cast<ASR::symbol_t>(
3447-
ASR::make_ExternalSymbol_t(al, ext_sym->base.base.loc, parent_scope, ext_sym->m_name,
3448-
ext_sym->m_external, ext_sym->m_module_name, ext_sym->m_scope_names,
3449-
ext_sym->n_scope_names, ext_sym->m_original_name, ext_sym->m_access));
3450-
parent_scope->add_symbol(item.first, new_ext_sym);
3451-
}
3452-
current_scope = parent_scope;
34533430
}
3431+
current_scope = parent_scope;
34543432

34553433
if (loop_start) {
34563434
head.m_start = loop_start;

src/runtime/lpython_intrinsic_numpy.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ltypes import f64, f32, ccall, vectorize, overload
1+
from ltypes import i32, i64, f64, f32, ccall, vectorize, overload
22

33
pi_64: f64 = 3.141592653589793238462643383279502884197
44
pi_32: f32 = 3.141592653589793238462643383279502884197
@@ -354,3 +354,19 @@ def _lfortran_satanh(x: f32) -> f32:
354354
@vectorize
355355
def arctanh(x: f32) -> f32:
356356
return _lfortran_satanh(x)
357+
358+
########## mod ##########
359+
360+
@overload
361+
@vectorize
362+
def mod(x1: i64, x2: i64) -> i64:
363+
if x2 == 0:
364+
return int(0)
365+
return x1 % x2
366+
367+
@overload
368+
@vectorize
369+
def mod(x1: i32, x2: i32) -> i32:
370+
if x2 == 0:
371+
return 0
372+
return x1 % x2

tests/reference/asr-array_01_decl-39cf894.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"outfile": null,
77
"outfile_hash": null,
88
"stdout": "asr-array_01_decl-39cf894.stdout",
9-
"stdout_hash": "626eb3f1b687f1ba8499cb42d852d4e2fc316a11b2f48a058380c23a",
9+
"stdout_hash": "2380e43101fe35534821c0b5707103f481413bb34c98d6c9ea7cf3dc",
1010
"stderr": null,
1111
"stderr_hash": null,
1212
"returncode": 0

tests/reference/asr-array_01_decl-39cf894.stdout

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

tests/reference/asr-array_02_decl-e8f6874.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"outfile": null,
77
"outfile_hash": null,
88
"stdout": "asr-array_02_decl-e8f6874.stdout",
9-
"stdout_hash": "c25f1a5906a7ff591373fdd8a67d5e879f359390030a27d9bfc3ee9a",
9+
"stdout_hash": "dfcfd217678648da9009018c17a45aab3334b7d0625a80318066c2d7",
1010
"stderr": null,
1111
"stderr_hash": null,
1212
"returncode": 0

tests/reference/asr-array_02_decl-e8f6874.stdout

Lines changed: 1 addition & 1 deletion
Large diffs are not rendered by default.

tests/reference/asr-complex1-f26c460.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"outfile": null,
77
"outfile_hash": null,
88
"stdout": "asr-complex1-f26c460.stdout",
9-
"stdout_hash": "c741bb3cc0e4ba3e9076c989424b8eddb5ebdb9458b9fb894dc84044",
9+
"stdout_hash": "c6f072f66291b41dec2703ad724fcacb63378c280ee00018739ec8b7",
1010
"stderr": null,
1111
"stderr_hash": null,
1212
"returncode": 0

0 commit comments

Comments
 (0)