diff --git a/src/libasr/ASR.asdl b/src/libasr/ASR.asdl index 34af5be3cd..c0c8cdede0 100644 --- a/src/libasr/ASR.asdl +++ b/src/libasr/ASR.asdl @@ -168,6 +168,7 @@ stmt | ImplicitDeallocate(expr* vars) | DoConcurrentLoop(do_loop_head head, stmt* body) | DoLoop(identifier? name, do_loop_head head, stmt* body) + | ForElse(do_loop_head head, stmt* body, stmt* orelse) | ErrorStop(expr? code) | Exit(identifier? stmt_name) | ForAllSingle(do_loop_head head, stmt assign_stmt) diff --git a/src/libasr/CMakeLists.txt b/src/libasr/CMakeLists.txt index e3439f899c..c73733c005 100644 --- a/src/libasr/CMakeLists.txt +++ b/src/libasr/CMakeLists.txt @@ -29,6 +29,7 @@ set(SRC pass/nested_vars.cpp pass/param_to_const.cpp + pass/for_else.cpp pass/do_loops.cpp pass/for_all.cpp pass/global_stmts.cpp diff --git a/src/libasr/pass/for_else.cpp b/src/libasr/pass/for_else.cpp new file mode 100644 index 0000000000..04a1ecfc74 --- /dev/null +++ b/src/libasr/pass/for_else.cpp @@ -0,0 +1,130 @@ +#include "for_else.h" +#include "libasr/asr_scopes.h" + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace LCompilers { + +using ASR::is_a; +using ASR::down_cast; +using ASR::stmtType; + +class ExitVisitor : public ASR::StatementWalkVisitor { +public: + std::stack doLoopStack; + std::map &doLoopFlagMap; + + ExitVisitor(Allocator &al, std::map &doLoopFlagMap) + : StatementWalkVisitor(al), doLoopFlagMap(doLoopFlagMap) { + } + + void visit_DoLoop(const ASR::DoLoop_t &x) { + ASR::stmt_t *doLoopStmt = (ASR::stmt_t*)(&x); + + doLoopStack.push(doLoopStmt); + + ASR::DoLoop_t& xx = const_cast(x); + this->transform_stmts(xx.m_body, xx.n_body); + + doLoopStack.pop(); + } + + void visit_Exit(const ASR::Exit_t &x) { + if (doLoopStack.empty() || + // the current loop is not originally a ForElse loop + doLoopFlagMap.find(doLoopStack.top()) == doLoopFlagMap.end()) + return; + + Vec result; + result.reserve(al, 1); + + Location loc = x.base.base.loc; + ASR::ttype_t* bool_type = ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4, nullptr, 0)); + ASR::symbol_t* flag_symbol = doLoopFlagMap[doLoopStack.top()]; + ASR::expr_t* flag_expr = ASRUtils::EXPR(ASR::make_Var_t(al, loc, flag_symbol)); + ASR::expr_t* false_expr = ASRUtils::EXPR(ASR::make_LogicalConstant_t(al, loc, false, bool_type)); + ASR::stmt_t* assign_stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, flag_expr, false_expr, nullptr)); + result.push_back(al, assign_stmt); + result.push_back(al, ASRUtils::STMT(ASR::make_Exit_t(al, loc, nullptr))); + + pass_result = result; + } +}; + +class ForElseVisitor : public ASR::StatementWalkVisitor +{ +public: + ForElseVisitor(Allocator &al) : StatementWalkVisitor(al) { + counter = 0; + } + + std::map doLoopFlagMap; + + int counter; + + void visit_ForElse(const ASR::ForElse_t &x) { + Location loc = x.base.base.loc; + + Vec result; + result.reserve(al, 1); + + // create a boolean flag and set it to true + ASR::ttype_t* bool_type = ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4, nullptr, 0)); + ASR::expr_t* true_expr = ASRUtils::EXPR(ASR::make_LogicalConstant_t(al, loc, true, bool_type)); + + auto target_scope = current_scope; // al.make_new(current_scope); + + Str s; + s.from_str_view(std::string("_no_break_") + std::to_string(counter)); + counter++; + + ASR::symbol_t* flag_symbol = LCompilers::ASR::down_cast( + ASR::make_Variable_t( + al, loc, target_scope, + s.c_str(al), nullptr, 0, ASR::intentType::Local, nullptr, nullptr, + ASR::storage_typeType::Default, bool_type, + ASR::abiType::Source, ASR::Public, + ASR::presenceType::Required, false)); + ASR::expr_t* flag_expr = ASRUtils::EXPR(ASR::make_Var_t(al, loc, flag_symbol)); + target_scope->add_symbol(s.c_str(al), flag_symbol); + + ASR::stmt_t* assign_stmt = ASRUtils::STMT( + ASR::make_Assignment_t(al, loc, flag_expr, true_expr, nullptr)); + result.push_back(al, assign_stmt); + + // convert head and body to DoLoop + ASR::stmt_t *doLoopStmt = ASRUtils::STMT( + ASR::make_DoLoop_t(al, loc, nullptr, x.m_head, x.m_body, x.n_body) + ); + result.push_back(al, doLoopStmt); + + // this DoLoop corresponds to the current flag + doLoopFlagMap[doLoopStmt] = flag_symbol; + // std::cerr << doLoopStmt << " -> " << flag_expr << std::endl; + + // add an If block that executes the orelse statements when the flag is true + result.push_back( + al, ASRUtils::STMT( + ASR::make_If_t(al, loc, flag_expr, x.m_orelse, x.n_orelse, nullptr, 0))); + + pass_result = result; + } +}; + +void pass_replace_forelse(Allocator &al, ASR::TranslationUnit_t &unit, + const LCompilers::PassOptions& /*pass_options*/) { + ForElseVisitor v(al); + v.visit_TranslationUnit(unit); + ExitVisitor v2(al, v.doLoopFlagMap); + v2.visit_TranslationUnit(unit); +} + +} // namespace LCompilers diff --git a/src/libasr/pass/for_else.h b/src/libasr/pass/for_else.h new file mode 100644 index 0000000000..189f2882cb --- /dev/null +++ b/src/libasr/pass/for_else.h @@ -0,0 +1,15 @@ +#ifndef FORELSE_H +#define FORELSE_H + +#include +#include + +namespace LCompilers { + + void pass_replace_forelse(Allocator &al, ASR::TranslationUnit_t &unit, + const LCompilers::PassOptions& pass_options); + +} // namespace LCompilers + + +#endif diff --git a/src/libasr/pass/pass_manager.h b/src/libasr/pass/pass_manager.h index cdb8915ba7..1560945da2 100644 --- a/src/libasr/pass/pass_manager.h +++ b/src/libasr/pass/pass_manager.h @@ -14,6 +14,7 @@ #include #endif +#include #include #include #include @@ -64,6 +65,7 @@ namespace LCompilers { std::vector _user_defined_passes; std::vector _skip_passes, _c_skip_passes; std::map _passes_db = { + {"for_else", &pass_replace_forelse}, {"do_loops", &pass_replace_do_loops}, {"global_stmts", &pass_wrap_global_stmts_into_function}, {"implied_do_loops", &pass_replace_implied_do_loops}, @@ -200,6 +202,7 @@ namespace LCompilers { "print_arr", "print_list_tuple", "array_dim_intrinsics_update", + "for_else", "do_loops", "forall", "select_case", diff --git a/src/libasr/pass/pass_utils.h b/src/libasr/pass/pass_utils.h index 71842acb76..33b6eb318e 100644 --- a/src/libasr/pass/pass_utils.h +++ b/src/libasr/pass/pass_utils.h @@ -220,6 +220,12 @@ namespace LCompilers { this->current_scope = current_scope_copy; } + void visit_ForElse(const ASR::ForElse_t& x) { + self().visit_do_loop_head(x.m_head); + ASR::ForElse_t& xx = const_cast(x); + transform_stmts(xx.m_body, xx.n_body); + transform_stmts(xx.m_orelse, xx.n_orelse); + } }; template diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 6f33d8090a..a2751f37e5 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -4698,6 +4698,12 @@ class BodyVisitor : public CommonVisitor { if (parallel) { tmp = ASR::make_DoConcurrentLoop_t(al, x.base.base.loc, head, body.p, body.size()); + } else if ( x.n_orelse > 0 ) { + Vec orelse; + orelse.reserve(al, x.n_orelse); + transform_stmts(orelse, x.n_orelse, x.m_orelse); + tmp = ASR::make_ForElse_t(al, x.base.base.loc, head, + body.p, body.size(), orelse.p, orelse.size()); } else { tmp = ASR::make_DoLoop_t(al, x.base.base.loc, nullptr, head, body.p, body.size()); diff --git a/tests/for_else/break_in_if.py b/tests/for_else/break_in_if.py new file mode 100644 index 0000000000..0a7799d5ae --- /dev/null +++ b/tests/for_else/break_in_if.py @@ -0,0 +1,10 @@ +def break_in_if(): + i: i32 + for i in range(4): + print(i) + if i == 2: + break + else: + print(10) + +break_in_if() diff --git a/tests/for_else/nested_loop.py b/tests/for_else/nested_loop.py new file mode 100644 index 0000000000..ade5968637 --- /dev/null +++ b/tests/for_else/nested_loop.py @@ -0,0 +1,12 @@ +def nested_loop(): + i: i32 + j: i32 + for i in range(4): + print("outer: " + str(i)) + for j in range(10, 20): + print(" inner: " + str(j)) + break + else: + print("no break in outer loop") + +nested_loop() diff --git a/tests/for_else/no_break.py b/tests/for_else/no_break.py new file mode 100644 index 0000000000..66db0701eb --- /dev/null +++ b/tests/for_else/no_break.py @@ -0,0 +1,8 @@ +def no_break(): + i: i32 + for i in range(4): + print(i) + else: + print(10) + +no_break() diff --git a/tests/for_else/with_break.py b/tests/for_else/with_break.py new file mode 100644 index 0000000000..ad0e15ddf2 --- /dev/null +++ b/tests/for_else/with_break.py @@ -0,0 +1,9 @@ +def with_break(): + i: i32 + for i in range(4): + print(i) + break + else: + print(10) + +with_break() diff --git a/tests/reference/runtime-break_in_if-e70c15c.json b/tests/reference/runtime-break_in_if-e70c15c.json new file mode 100644 index 0000000000..8808d44788 --- /dev/null +++ b/tests/reference/runtime-break_in_if-e70c15c.json @@ -0,0 +1,13 @@ +{ + "basename": "runtime-break_in_if-e70c15c", + "cmd": "lpython {infile}", + "infile": "tests/for_else/break_in_if.py", + "infile_hash": "48b6f27dc6d584ec9765483a80f3f0a705bbbdae1afddf31bab19305", + "outfile": null, + "outfile_hash": null, + "stdout": "runtime-break_in_if-e70c15c.stdout", + "stdout_hash": "32abcb5e52daed49078b31f00c8096aca4e2bcef16f75512188bc328", + "stderr": null, + "stderr_hash": null, + "returncode": 0 +} \ No newline at end of file diff --git a/tests/reference/runtime-break_in_if-e70c15c.stdout b/tests/reference/runtime-break_in_if-e70c15c.stdout new file mode 100644 index 0000000000..4539bbf2d2 --- /dev/null +++ b/tests/reference/runtime-break_in_if-e70c15c.stdout @@ -0,0 +1,3 @@ +0 +1 +2 diff --git a/tests/reference/runtime-nested_loop-6a5a431.json b/tests/reference/runtime-nested_loop-6a5a431.json new file mode 100644 index 0000000000..5f83825c4d --- /dev/null +++ b/tests/reference/runtime-nested_loop-6a5a431.json @@ -0,0 +1,13 @@ +{ + "basename": "runtime-nested_loop-6a5a431", + "cmd": "lpython {infile}", + "infile": "tests/for_else/nested_loop.py", + "infile_hash": "41866cefd6bbcd693f2f7a3e2fb229fa5500e0fcb41ae864a5902709", + "outfile": null, + "outfile_hash": null, + "stdout": "runtime-nested_loop-6a5a431.stdout", + "stdout_hash": "8bfba961db77bc0c4a2f4cc1332a96d27325a39768c79454eabd871e", + "stderr": null, + "stderr_hash": null, + "returncode": 0 +} \ No newline at end of file diff --git a/tests/reference/runtime-nested_loop-6a5a431.stdout b/tests/reference/runtime-nested_loop-6a5a431.stdout new file mode 100644 index 0000000000..b8da464e3a --- /dev/null +++ b/tests/reference/runtime-nested_loop-6a5a431.stdout @@ -0,0 +1,9 @@ +outer: 0 + inner: 10 +outer: 1 + inner: 10 +outer: 2 + inner: 10 +outer: 3 + inner: 10 +no break in outer loop diff --git a/tests/reference/runtime-no_break-1e0d019.json b/tests/reference/runtime-no_break-1e0d019.json new file mode 100644 index 0000000000..bb6392c30b --- /dev/null +++ b/tests/reference/runtime-no_break-1e0d019.json @@ -0,0 +1,13 @@ +{ + "basename": "runtime-no_break-1e0d019", + "cmd": "lpython {infile}", + "infile": "tests/for_else/no_break.py", + "infile_hash": "c9c058789cbf4ee7e18ee93f4f307556887c633b227c121b6f5297f5", + "outfile": null, + "outfile_hash": null, + "stdout": "runtime-no_break-1e0d019.stdout", + "stdout_hash": "7b61b2e1808c9472fa08a21dd6e1c0e656cbe912c2dfecddc22c3db5", + "stderr": null, + "stderr_hash": null, + "returncode": 0 +} \ No newline at end of file diff --git a/tests/reference/runtime-no_break-1e0d019.stdout b/tests/reference/runtime-no_break-1e0d019.stdout new file mode 100644 index 0000000000..32cd0b7857 --- /dev/null +++ b/tests/reference/runtime-no_break-1e0d019.stdout @@ -0,0 +1,5 @@ +0 +1 +2 +3 +10 diff --git a/tests/reference/runtime-with_break-a7ff7d8.json b/tests/reference/runtime-with_break-a7ff7d8.json new file mode 100644 index 0000000000..38886016cf --- /dev/null +++ b/tests/reference/runtime-with_break-a7ff7d8.json @@ -0,0 +1,13 @@ +{ + "basename": "runtime-with_break-a7ff7d8", + "cmd": "lpython {infile}", + "infile": "tests/for_else/with_break.py", + "infile_hash": "008d9850b479a3c0d21703c9bd98b147d5b12baab71ec9a72f90e74a", + "outfile": null, + "outfile_hash": null, + "stdout": "runtime-with_break-a7ff7d8.stdout", + "stdout_hash": "51d17b7c777114691588f549eb084256fb6fc05c641289d486bf8367", + "stderr": null, + "stderr_hash": null, + "returncode": 0 +} \ No newline at end of file diff --git a/tests/reference/runtime-with_break-a7ff7d8.stdout b/tests/reference/runtime-with_break-a7ff7d8.stdout new file mode 100644 index 0000000000..573541ac97 --- /dev/null +++ b/tests/reference/runtime-with_break-a7ff7d8.stdout @@ -0,0 +1 @@ +0 diff --git a/tests/tests.toml b/tests/tests.toml index 206fce824a..94b41195cc 100644 --- a/tests/tests.toml +++ b/tests/tests.toml @@ -1058,3 +1058,19 @@ run_with_dbg = true [[test]] filename = "runtime_errors/test_raise_01.py" run_with_dbg = true + +[[test]] +filename = "for_else/with_break.py" +run = true + +[[test]] +filename = "for_else/no_break.py" +run = true + +[[test]] +filename = "for_else/break_in_if.py" +run = true + +[[test]] +filename = "for_else/nested_loop.py" +run = true