Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Add support for for-else #1711

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/libasr/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ stmt
| ImplicitDeallocate(expr* vars)
| DoConcurrentLoop(do_loop_head head, stmt* body)
| DoLoop(identifier? name, do_loop_head head, stmt* body)
| ForElse(do_loop_head head, stmt* body, stmt* orelse)
| ErrorStop(expr? code)
| Exit(identifier? stmt_name)
| ForAllSingle(do_loop_head head, stmt assign_stmt)
Expand Down
1 change: 1 addition & 0 deletions src/libasr/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ set(SRC

pass/nested_vars.cpp
pass/param_to_const.cpp
pass/for_else.cpp
pass/do_loops.cpp
pass/for_all.cpp
pass/global_stmts.cpp
Expand Down
130 changes: 130 additions & 0 deletions src/libasr/pass/for_else.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
#include "for_else.h"
#include "libasr/asr_scopes.h"

#include <libasr/asr.h>
#include <libasr/containers.h>
#include <libasr/exception.h>
#include <libasr/asr_utils.h>
#include <libasr/asr_verify.h>
#include <libasr/pass/for_all.h>
#include <libasr/pass/stmt_walk_visitor.h>

#include <stack>

namespace LCompilers {

using ASR::is_a;
using ASR::down_cast;
using ASR::stmtType;

class ExitVisitor : public ASR::StatementWalkVisitor<ExitVisitor> {
public:
std::stack<ASR::stmt_t*> doLoopStack;
std::map<ASR::stmt_t*, ASR::symbol_t*> &doLoopFlagMap;

ExitVisitor(Allocator &al, std::map<ASR::stmt_t*, ASR::symbol_t*> &doLoopFlagMap)
: StatementWalkVisitor(al), doLoopFlagMap(doLoopFlagMap) {
}

void visit_DoLoop(const ASR::DoLoop_t &x) {
ASR::stmt_t *doLoopStmt = (ASR::stmt_t*)(&x);

doLoopStack.push(doLoopStmt);

ASR::DoLoop_t& xx = const_cast<ASR::DoLoop_t&>(x);
this->transform_stmts(xx.m_body, xx.n_body);

doLoopStack.pop();
}

void visit_Exit(const ASR::Exit_t &x) {
if (doLoopStack.empty() ||
// the current loop is not originally a ForElse loop
doLoopFlagMap.find(doLoopStack.top()) == doLoopFlagMap.end())
return;

Vec<ASR::stmt_t*> result;
result.reserve(al, 1);

Location loc = x.base.base.loc;
ASR::ttype_t* bool_type = ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4, nullptr, 0));
ASR::symbol_t* flag_symbol = doLoopFlagMap[doLoopStack.top()];
ASR::expr_t* flag_expr = ASRUtils::EXPR(ASR::make_Var_t(al, loc, flag_symbol));
ASR::expr_t* false_expr = ASRUtils::EXPR(ASR::make_LogicalConstant_t(al, loc, false, bool_type));
ASR::stmt_t* assign_stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, flag_expr, false_expr, nullptr));
result.push_back(al, assign_stmt);
result.push_back(al, ASRUtils::STMT(ASR::make_Exit_t(al, loc, nullptr)));

pass_result = result;
}
};

class ForElseVisitor : public ASR::StatementWalkVisitor<ForElseVisitor>
{
public:
ForElseVisitor(Allocator &al) : StatementWalkVisitor(al) {
counter = 0;
}

std::map<ASR::stmt_t*, ASR::symbol_t*> doLoopFlagMap;

int counter;

void visit_ForElse(const ASR::ForElse_t &x) {
Location loc = x.base.base.loc;

Vec<ASR::stmt_t*> result;
result.reserve(al, 1);

// create a boolean flag and set it to true
ASR::ttype_t* bool_type = ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4, nullptr, 0));
ASR::expr_t* true_expr = ASRUtils::EXPR(ASR::make_LogicalConstant_t(al, loc, true, bool_type));

auto target_scope = current_scope; // al.make_new<SymbolTable>(current_scope);

Str s;
s.from_str_view(std::string("_no_break_") + std::to_string(counter));
counter++;

ASR::symbol_t* flag_symbol = LCompilers::ASR::down_cast<ASR::symbol_t>(
ASR::make_Variable_t(
al, loc, target_scope,
s.c_str(al), nullptr, 0, ASR::intentType::Local, nullptr, nullptr,
ASR::storage_typeType::Default, bool_type,
ASR::abiType::Source, ASR::Public,
ASR::presenceType::Required, false));
ASR::expr_t* flag_expr = ASRUtils::EXPR(ASR::make_Var_t(al, loc, flag_symbol));
target_scope->add_symbol(s.c_str(al), flag_symbol);

ASR::stmt_t* assign_stmt = ASRUtils::STMT(
ASR::make_Assignment_t(al, loc, flag_expr, true_expr, nullptr));
result.push_back(al, assign_stmt);

// convert head and body to DoLoop
ASR::stmt_t *doLoopStmt = ASRUtils::STMT(
ASR::make_DoLoop_t(al, loc, nullptr, x.m_head, x.m_body, x.n_body)
);
result.push_back(al, doLoopStmt);

// this DoLoop corresponds to the current flag
doLoopFlagMap[doLoopStmt] = flag_symbol;
// std::cerr << doLoopStmt << " -> " << flag_expr << std::endl;

// add an If block that executes the orelse statements when the flag is true
result.push_back(
al, ASRUtils::STMT(
ASR::make_If_t(al, loc, flag_expr, x.m_orelse, x.n_orelse, nullptr, 0)));

pass_result = result;
}
};

void pass_replace_forelse(Allocator &al, ASR::TranslationUnit_t &unit,
const LCompilers::PassOptions& /*pass_options*/) {
ForElseVisitor v(al);
v.visit_TranslationUnit(unit);
ExitVisitor v2(al, v.doLoopFlagMap);
v2.visit_TranslationUnit(unit);
}

} // namespace LCompilers
15 changes: 15 additions & 0 deletions src/libasr/pass/for_else.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#ifndef FORELSE_H
#define FORELSE_H

#include <libasr/asr.h>
#include <libasr/utils.h>

namespace LCompilers {

void pass_replace_forelse(Allocator &al, ASR::TranslationUnit_t &unit,
const LCompilers::PassOptions& pass_options);

} // namespace LCompilers


#endif
3 changes: 3 additions & 0 deletions src/libasr/pass/pass_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <lpython/utils.h>
#endif

#include <libasr/pass/for_else.h>
#include <libasr/pass/do_loops.h>
#include <libasr/pass/for_all.h>
#include <libasr/pass/init_expr.h>
Expand Down Expand Up @@ -64,6 +65,7 @@ namespace LCompilers {
std::vector<std::string> _user_defined_passes;
std::vector<std::string> _skip_passes, _c_skip_passes;
std::map<std::string, pass_function> _passes_db = {
{"for_else", &pass_replace_forelse},
{"do_loops", &pass_replace_do_loops},
{"global_stmts", &pass_wrap_global_stmts_into_function},
{"implied_do_loops", &pass_replace_implied_do_loops},
Expand Down Expand Up @@ -200,6 +202,7 @@ namespace LCompilers {
"print_arr",
"print_list_tuple",
"array_dim_intrinsics_update",
"for_else",
"do_loops",
"forall",
"select_case",
Expand Down
6 changes: 6 additions & 0 deletions src/libasr/pass/pass_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,12 @@ namespace LCompilers {
this->current_scope = current_scope_copy;
}

void visit_ForElse(const ASR::ForElse_t& x) {
self().visit_do_loop_head(x.m_head);
ASR::ForElse_t& xx = const_cast<ASR::ForElse_t&>(x);
transform_stmts(xx.m_body, xx.n_body);
transform_stmts(xx.m_orelse, xx.n_orelse);
}
};

template <class Struct>
Expand Down
6 changes: 6 additions & 0 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4698,6 +4698,12 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
if (parallel) {
tmp = ASR::make_DoConcurrentLoop_t(al, x.base.base.loc, head,
body.p, body.size());
} else if ( x.n_orelse > 0 ) {
Vec<ASR::stmt_t*> orelse;
orelse.reserve(al, x.n_orelse);
transform_stmts(orelse, x.n_orelse, x.m_orelse);
tmp = ASR::make_ForElse_t(al, x.base.base.loc, head,
body.p, body.size(), orelse.p, orelse.size());
} else {
tmp = ASR::make_DoLoop_t(al, x.base.base.loc, nullptr, head,
body.p, body.size());
Expand Down
10 changes: 10 additions & 0 deletions tests/for_else/break_in_if.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
def break_in_if():
i: i32
for i in range(4):
print(i)
if i == 2:
break
else:
print(10)

break_in_if()
12 changes: 12 additions & 0 deletions tests/for_else/nested_loop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
def nested_loop():
i: i32
j: i32
for i in range(4):
print("outer: " + str(i))
for j in range(10, 20):
print(" inner: " + str(j))
break
else:
print("no break in outer loop")

nested_loop()
8 changes: 8 additions & 0 deletions tests/for_else/no_break.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
def no_break():
i: i32
for i in range(4):
print(i)
else:
print(10)

no_break()
9 changes: 9 additions & 0 deletions tests/for_else/with_break.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
def with_break():
i: i32
for i in range(4):
print(i)
break
else:
print(10)

with_break()
13 changes: 13 additions & 0 deletions tests/reference/runtime-break_in_if-e70c15c.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"basename": "runtime-break_in_if-e70c15c",
"cmd": "lpython {infile}",
"infile": "tests/for_else/break_in_if.py",
"infile_hash": "48b6f27dc6d584ec9765483a80f3f0a705bbbdae1afddf31bab19305",
"outfile": null,
"outfile_hash": null,
"stdout": "runtime-break_in_if-e70c15c.stdout",
"stdout_hash": "32abcb5e52daed49078b31f00c8096aca4e2bcef16f75512188bc328",
"stderr": null,
"stderr_hash": null,
"returncode": 0
}
3 changes: 3 additions & 0 deletions tests/reference/runtime-break_in_if-e70c15c.stdout
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
0
1
2
13 changes: 13 additions & 0 deletions tests/reference/runtime-nested_loop-6a5a431.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"basename": "runtime-nested_loop-6a5a431",
"cmd": "lpython {infile}",
"infile": "tests/for_else/nested_loop.py",
"infile_hash": "41866cefd6bbcd693f2f7a3e2fb229fa5500e0fcb41ae864a5902709",
"outfile": null,
"outfile_hash": null,
"stdout": "runtime-nested_loop-6a5a431.stdout",
"stdout_hash": "8bfba961db77bc0c4a2f4cc1332a96d27325a39768c79454eabd871e",
"stderr": null,
"stderr_hash": null,
"returncode": 0
}
9 changes: 9 additions & 0 deletions tests/reference/runtime-nested_loop-6a5a431.stdout
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
outer: 0
inner: 10
outer: 1
inner: 10
outer: 2
inner: 10
outer: 3
inner: 10
no break in outer loop
13 changes: 13 additions & 0 deletions tests/reference/runtime-no_break-1e0d019.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"basename": "runtime-no_break-1e0d019",
"cmd": "lpython {infile}",
"infile": "tests/for_else/no_break.py",
"infile_hash": "c9c058789cbf4ee7e18ee93f4f307556887c633b227c121b6f5297f5",
"outfile": null,
"outfile_hash": null,
"stdout": "runtime-no_break-1e0d019.stdout",
"stdout_hash": "7b61b2e1808c9472fa08a21dd6e1c0e656cbe912c2dfecddc22c3db5",
"stderr": null,
"stderr_hash": null,
"returncode": 0
}
5 changes: 5 additions & 0 deletions tests/reference/runtime-no_break-1e0d019.stdout
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
0
1
2
3
10
13 changes: 13 additions & 0 deletions tests/reference/runtime-with_break-a7ff7d8.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"basename": "runtime-with_break-a7ff7d8",
"cmd": "lpython {infile}",
"infile": "tests/for_else/with_break.py",
"infile_hash": "008d9850b479a3c0d21703c9bd98b147d5b12baab71ec9a72f90e74a",
"outfile": null,
"outfile_hash": null,
"stdout": "runtime-with_break-a7ff7d8.stdout",
"stdout_hash": "51d17b7c777114691588f549eb084256fb6fc05c641289d486bf8367",
"stderr": null,
"stderr_hash": null,
"returncode": 0
}
1 change: 1 addition & 0 deletions tests/reference/runtime-with_break-a7ff7d8.stdout
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
0
16 changes: 16 additions & 0 deletions tests/tests.toml
Original file line number Diff line number Diff line change
Expand Up @@ -1058,3 +1058,19 @@ run_with_dbg = true
[[test]]
filename = "runtime_errors/test_raise_01.py"
run_with_dbg = true

[[test]]
filename = "for_else/with_break.py"
run = true

[[test]]
filename = "for_else/no_break.py"
run = true

[[test]]
filename = "for_else/break_in_if.py"
run = true

[[test]]
filename = "for_else/nested_loop.py"
run = true