From e493b4777037015a59b68790731b414496e57d79 Mon Sep 17 00:00:00 2001 From: thebesttv Date: Thu, 30 Mar 2023 18:20:08 +0800 Subject: [PATCH 01/13] Add new ASR node: ForElse --- src/libasr/ASR.asdl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/libasr/ASR.asdl b/src/libasr/ASR.asdl index 34af5be3cd..c0c8cdede0 100644 --- a/src/libasr/ASR.asdl +++ b/src/libasr/ASR.asdl @@ -168,6 +168,7 @@ stmt | ImplicitDeallocate(expr* vars) | DoConcurrentLoop(do_loop_head head, stmt* body) | DoLoop(identifier? name, do_loop_head head, stmt* body) + | ForElse(do_loop_head head, stmt* body, stmt* orelse) | ErrorStop(expr? code) | Exit(identifier? stmt_name) | ForAllSingle(do_loop_head head, stmt assign_stmt) From fa7d5289272dcbda2572fd4b8877f0c324e8ebe9 Mon Sep 17 00:00:00 2001 From: thebesttv Date: Thu, 30 Mar 2023 18:21:49 +0800 Subject: [PATCH 02/13] Convert For loop with non-empty orelse to ForElse --- src/lpython/semantics/python_ast_to_asr.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 6f33d8090a..a2751f37e5 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -4698,6 +4698,12 @@ class BodyVisitor : public CommonVisitor { if (parallel) { tmp = ASR::make_DoConcurrentLoop_t(al, x.base.base.loc, head, body.p, body.size()); + } else if ( x.n_orelse > 0 ) { + Vec orelse; + orelse.reserve(al, x.n_orelse); + transform_stmts(orelse, x.n_orelse, x.m_orelse); + tmp = ASR::make_ForElse_t(al, x.base.base.loc, head, + body.p, body.size(), orelse.p, orelse.size()); } else { tmp = ASR::make_DoLoop_t(al, x.base.base.loc, nullptr, head, body.p, body.size()); From ee3974e83be7333539239d92681495a5745977dc Mon Sep 17 00:00:00 2001 From: thebesttv Date: Thu, 30 Mar 2023 18:23:35 +0800 Subject: [PATCH 03/13] Add pass for_else to convert ForElse to DoLoop --- src/libasr/CMakeLists.txt | 1 + src/libasr/pass/for_else.cpp | 126 +++++++++++++++++++++++++++++++++ src/libasr/pass/for_else.h | 15 ++++ src/libasr/pass/pass_manager.h | 3 + src/libasr/pass/pass_utils.h | 6 ++ 5 files changed, 151 insertions(+) create mode 100644 src/libasr/pass/for_else.cpp create mode 100644 src/libasr/pass/for_else.h diff --git a/src/libasr/CMakeLists.txt b/src/libasr/CMakeLists.txt index e3439f899c..c73733c005 100644 --- a/src/libasr/CMakeLists.txt +++ b/src/libasr/CMakeLists.txt @@ -29,6 +29,7 @@ set(SRC pass/nested_vars.cpp pass/param_to_const.cpp + pass/for_else.cpp pass/do_loops.cpp pass/for_all.cpp pass/global_stmts.cpp diff --git a/src/libasr/pass/for_else.cpp b/src/libasr/pass/for_else.cpp new file mode 100644 index 0000000000..c719497de1 --- /dev/null +++ b/src/libasr/pass/for_else.cpp @@ -0,0 +1,126 @@ +#include "for_else.h" +#include "libasr/asr_scopes.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace LCompilers { + +class ExitVisitor : public ASR::StatementWalkVisitor { +public: + ASR::expr_t* flag_expr; + + ExitVisitor(Allocator &al, ASR::expr_t* flag_expr) : StatementWalkVisitor(al) { + this->flag_expr = flag_expr; + } + + void visit_Exit(const ASR::Exit_t &x) { + std::cerr << "Break!" << std::endl; + + // Vec result; + // result.reserve(al, 1); + + // Location loc = x.base.base.loc; + // ASR::ttype_t* bool_type = ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4, nullptr, 0)); + // ASR::expr_t* false_expr = ASRUtils::EXPR(ASR::make_LogicalConstant_t(al, loc, false, bool_type)); + // ASR::stmt_t* assign_stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, flag_expr, false_expr, nullptr)); + // result.push_back(al, assign_stmt); + + // pass_result = result; + + auto scope = current_scope->get_scope(); + for (auto it = scope.begin(); it != scope.end(); it++) { + std::cerr << "First: " << it->first << ", second: " << it->second << std::endl; + } + + } +}; + +class ForElseVisitor : public ASR::StatementWalkVisitor +{ +public: + ForElseVisitor(Allocator &al) : StatementWalkVisitor(al) { + counter = 0; + } + + int counter; + + void visit_Exit(const ASR::Exit_t &x) { + std::cerr << "!!!! Break!" << std::endl; + } + + void visit_ForElse(const ASR::ForElse_t &x) { + Location loc = x.base.base.loc; + + Vec result; + result.reserve(al, 1); + + // create a boolean flag and set it to false + ASR::ttype_t* bool_type = ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4, nullptr, 0)); + ASR::expr_t* true_expr = ASRUtils::EXPR(ASR::make_LogicalConstant_t(al, loc, true, bool_type)); + ASR::expr_t* false_expr = ASRUtils::EXPR(ASR::make_LogicalConstant_t(al, loc, false, bool_type)); + + auto target_scope = current_scope; // al.make_new(current_scope); + + Str s; + s.from_str_view(std::string("_no_break_") + std::to_string(counter)); + counter++; + + ASR::symbol_t* flag_symbol = LCompilers::ASR::down_cast( + ASR::make_Variable_t( + al, loc, target_scope, + s.c_str(al), nullptr, 0, ASR::intentType::Local, nullptr, nullptr, + ASR::storage_typeType::Default, bool_type, + ASR::abiType::Source, ASR::Public, + ASR::presenceType::Required, false)); + ASR::expr_t* flag_expr = ASRUtils::EXPR(ASR::make_Var_t(al, loc, flag_symbol)); + target_scope->add_symbol(s.c_str(al), flag_symbol); + + ASR::stmt_t* assign_stmt = ASRUtils::STMT( + ASR::make_Assignment_t(al, loc, flag_expr, true_expr, nullptr)); + result.push_back(al, assign_stmt); + + Vec body; + body.reserve(al, x.n_body); + + for (size_t i = 0; i < x.n_body; i++) { + ASR::stmt_t *stmt = x.m_body[i]; + + if (stmt->type == ASR::stmtType::Exit) { + ASR::stmt_t* assign_stmt = ASRUtils::STMT( + ASR::make_Assignment_t(al, loc, flag_expr, false_expr, nullptr)); + result.push_back(al, assign_stmt); + } + + body.push_back(al, stmt); + } + + // convert head and body to DoLoop + ASR::stmt_t *stmt = ASRUtils::STMT( + ASR::make_DoLoop_t(al, loc, x.m_head, body.p, body.size()) + ); + result.push_back(al, stmt); + + // add an If block that executes the orelse statements when the flag is true + result.push_back( + al, ASRUtils::STMT( + ASR::make_If_t(al, loc, flag_expr, x.m_orelse, x.n_orelse, nullptr, 0))); + + pass_result = result; + } +}; + +void pass_replace_forelse(Allocator &al, ASR::TranslationUnit_t &unit, + const LCompilers::PassOptions& /*pass_options*/) { + ForElseVisitor v(al); + v.visit_TranslationUnit(unit); + // ExitVisitor v2(al, nullptr); + // v2.visit_TranslationUnit(unit); +} + +} // namespace LCompilers diff --git a/src/libasr/pass/for_else.h b/src/libasr/pass/for_else.h new file mode 100644 index 0000000000..189f2882cb --- /dev/null +++ b/src/libasr/pass/for_else.h @@ -0,0 +1,15 @@ +#ifndef FORELSE_H +#define FORELSE_H + +#include +#include + +namespace LCompilers { + + void pass_replace_forelse(Allocator &al, ASR::TranslationUnit_t &unit, + const LCompilers::PassOptions& pass_options); + +} // namespace LCompilers + + +#endif diff --git a/src/libasr/pass/pass_manager.h b/src/libasr/pass/pass_manager.h index cdb8915ba7..1560945da2 100644 --- a/src/libasr/pass/pass_manager.h +++ b/src/libasr/pass/pass_manager.h @@ -14,6 +14,7 @@ #include #endif +#include #include #include #include @@ -64,6 +65,7 @@ namespace LCompilers { std::vector _user_defined_passes; std::vector _skip_passes, _c_skip_passes; std::map _passes_db = { + {"for_else", &pass_replace_forelse}, {"do_loops", &pass_replace_do_loops}, {"global_stmts", &pass_wrap_global_stmts_into_function}, {"implied_do_loops", &pass_replace_implied_do_loops}, @@ -200,6 +202,7 @@ namespace LCompilers { "print_arr", "print_list_tuple", "array_dim_intrinsics_update", + "for_else", "do_loops", "forall", "select_case", diff --git a/src/libasr/pass/pass_utils.h b/src/libasr/pass/pass_utils.h index 71842acb76..33b6eb318e 100644 --- a/src/libasr/pass/pass_utils.h +++ b/src/libasr/pass/pass_utils.h @@ -220,6 +220,12 @@ namespace LCompilers { this->current_scope = current_scope_copy; } + void visit_ForElse(const ASR::ForElse_t& x) { + self().visit_do_loop_head(x.m_head); + ASR::ForElse_t& xx = const_cast(x); + transform_stmts(xx.m_body, xx.n_body); + transform_stmts(xx.m_orelse, xx.n_orelse); + } }; template From c4d5a63cf14b0c5aa452e657e2c246fb2311b1b8 Mon Sep 17 00:00:00 2001 From: thebesttv Date: Thu, 30 Mar 2023 18:59:33 +0800 Subject: [PATCH 04/13] Add some tests for ForElse for_else/break_in_if.py currently does not pass. --- tests/for_else/break_in_if.py | 10 ++++++++++ tests/for_else/no_break.py | 8 ++++++++ tests/for_else/with_break.py | 9 +++++++++ tests/reference/runtime-no_break-1e0d019.json | 13 +++++++++++++ tests/reference/runtime-no_break-1e0d019.stdout | 5 +++++ tests/reference/runtime-with_break-a7ff7d8.json | 13 +++++++++++++ tests/reference/runtime-with_break-a7ff7d8.stdout | 1 + tests/tests.toml | 12 ++++++++++++ 8 files changed, 71 insertions(+) create mode 100644 tests/for_else/break_in_if.py create mode 100644 tests/for_else/no_break.py create mode 100644 tests/for_else/with_break.py create mode 100644 tests/reference/runtime-no_break-1e0d019.json create mode 100644 tests/reference/runtime-no_break-1e0d019.stdout create mode 100644 tests/reference/runtime-with_break-a7ff7d8.json create mode 100644 tests/reference/runtime-with_break-a7ff7d8.stdout diff --git a/tests/for_else/break_in_if.py b/tests/for_else/break_in_if.py new file mode 100644 index 0000000000..0a7799d5ae --- /dev/null +++ b/tests/for_else/break_in_if.py @@ -0,0 +1,10 @@ +def break_in_if(): + i: i32 + for i in range(4): + print(i) + if i == 2: + break + else: + print(10) + +break_in_if() diff --git a/tests/for_else/no_break.py b/tests/for_else/no_break.py new file mode 100644 index 0000000000..66db0701eb --- /dev/null +++ b/tests/for_else/no_break.py @@ -0,0 +1,8 @@ +def no_break(): + i: i32 + for i in range(4): + print(i) + else: + print(10) + +no_break() diff --git a/tests/for_else/with_break.py b/tests/for_else/with_break.py new file mode 100644 index 0000000000..ad0e15ddf2 --- /dev/null +++ b/tests/for_else/with_break.py @@ -0,0 +1,9 @@ +def with_break(): + i: i32 + for i in range(4): + print(i) + break + else: + print(10) + +with_break() diff --git a/tests/reference/runtime-no_break-1e0d019.json b/tests/reference/runtime-no_break-1e0d019.json new file mode 100644 index 0000000000..bb6392c30b --- /dev/null +++ b/tests/reference/runtime-no_break-1e0d019.json @@ -0,0 +1,13 @@ +{ + "basename": "runtime-no_break-1e0d019", + "cmd": "lpython {infile}", + "infile": "tests/for_else/no_break.py", + "infile_hash": "c9c058789cbf4ee7e18ee93f4f307556887c633b227c121b6f5297f5", + "outfile": null, + "outfile_hash": null, + "stdout": "runtime-no_break-1e0d019.stdout", + "stdout_hash": "7b61b2e1808c9472fa08a21dd6e1c0e656cbe912c2dfecddc22c3db5", + "stderr": null, + "stderr_hash": null, + "returncode": 0 +} \ No newline at end of file diff --git a/tests/reference/runtime-no_break-1e0d019.stdout b/tests/reference/runtime-no_break-1e0d019.stdout new file mode 100644 index 0000000000..32cd0b7857 --- /dev/null +++ b/tests/reference/runtime-no_break-1e0d019.stdout @@ -0,0 +1,5 @@ +0 +1 +2 +3 +10 diff --git a/tests/reference/runtime-with_break-a7ff7d8.json b/tests/reference/runtime-with_break-a7ff7d8.json new file mode 100644 index 0000000000..38886016cf --- /dev/null +++ b/tests/reference/runtime-with_break-a7ff7d8.json @@ -0,0 +1,13 @@ +{ + "basename": "runtime-with_break-a7ff7d8", + "cmd": "lpython {infile}", + "infile": "tests/for_else/with_break.py", + "infile_hash": "008d9850b479a3c0d21703c9bd98b147d5b12baab71ec9a72f90e74a", + "outfile": null, + "outfile_hash": null, + "stdout": "runtime-with_break-a7ff7d8.stdout", + "stdout_hash": "51d17b7c777114691588f549eb084256fb6fc05c641289d486bf8367", + "stderr": null, + "stderr_hash": null, + "returncode": 0 +} \ No newline at end of file diff --git a/tests/reference/runtime-with_break-a7ff7d8.stdout b/tests/reference/runtime-with_break-a7ff7d8.stdout new file mode 100644 index 0000000000..573541ac97 --- /dev/null +++ b/tests/reference/runtime-with_break-a7ff7d8.stdout @@ -0,0 +1 @@ +0 diff --git a/tests/tests.toml b/tests/tests.toml index 206fce824a..8f7043332c 100644 --- a/tests/tests.toml +++ b/tests/tests.toml @@ -1058,3 +1058,15 @@ run_with_dbg = true [[test]] filename = "runtime_errors/test_raise_01.py" run_with_dbg = true + +[[test]] +filename = "for_else/with_break.py" +run = true + +[[test]] +filename = "for_else/no_break.py" +run = true + +# [[test]] +# filename = "for_else/break_in_if.py" +# run = true From 5b9186ce38808d3c3d3603d99c8a5d1367447c61 Mon Sep 17 00:00:00 2001 From: thebesttv Date: Thu, 30 Mar 2023 21:49:49 +0800 Subject: [PATCH 05/13] Add test: for_else/break_in_if.py --- tests/reference/runtime-break_in_if-e70c15c.json | 13 +++++++++++++ tests/reference/runtime-break_in_if-e70c15c.stdout | 3 +++ tests/tests.toml | 6 +++--- 3 files changed, 19 insertions(+), 3 deletions(-) create mode 100644 tests/reference/runtime-break_in_if-e70c15c.json create mode 100644 tests/reference/runtime-break_in_if-e70c15c.stdout diff --git a/tests/reference/runtime-break_in_if-e70c15c.json b/tests/reference/runtime-break_in_if-e70c15c.json new file mode 100644 index 0000000000..8808d44788 --- /dev/null +++ b/tests/reference/runtime-break_in_if-e70c15c.json @@ -0,0 +1,13 @@ +{ + "basename": "runtime-break_in_if-e70c15c", + "cmd": "lpython {infile}", + "infile": "tests/for_else/break_in_if.py", + "infile_hash": "48b6f27dc6d584ec9765483a80f3f0a705bbbdae1afddf31bab19305", + "outfile": null, + "outfile_hash": null, + "stdout": "runtime-break_in_if-e70c15c.stdout", + "stdout_hash": "32abcb5e52daed49078b31f00c8096aca4e2bcef16f75512188bc328", + "stderr": null, + "stderr_hash": null, + "returncode": 0 +} \ No newline at end of file diff --git a/tests/reference/runtime-break_in_if-e70c15c.stdout b/tests/reference/runtime-break_in_if-e70c15c.stdout new file mode 100644 index 0000000000..4539bbf2d2 --- /dev/null +++ b/tests/reference/runtime-break_in_if-e70c15c.stdout @@ -0,0 +1,3 @@ +0 +1 +2 diff --git a/tests/tests.toml b/tests/tests.toml index 8f7043332c..be6882cbdc 100644 --- a/tests/tests.toml +++ b/tests/tests.toml @@ -1067,6 +1067,6 @@ run = true filename = "for_else/no_break.py" run = true -# [[test]] -# filename = "for_else/break_in_if.py" -# run = true +[[test]] +filename = "for_else/break_in_if.py" +run = true From 5e7dbc38c092c17334a56fed5d85e8767e6e3d28 Mon Sep 17 00:00:00 2001 From: thebesttv Date: Thu, 30 Mar 2023 21:52:01 +0800 Subject: [PATCH 06/13] Basic support for if blocks in loops --- src/libasr/pass/for_else.cpp | 177 +++++++++++++++++++++++++++-------- 1 file changed, 138 insertions(+), 39 deletions(-) diff --git a/src/libasr/pass/for_else.cpp b/src/libasr/pass/for_else.cpp index c719497de1..67002da863 100644 --- a/src/libasr/pass/for_else.cpp +++ b/src/libasr/pass/for_else.cpp @@ -9,35 +9,150 @@ #include #include +#include + +// the current code passes break_in_if.py + +/* FIXME: + +for ... + for ... + break # does not affect the orelse flag in outer loop +else + ... + + */ + namespace LCompilers { +using ASR::is_a; +using ASR::down_cast; +using ASR::stmtType; + +std::map doLoopFlagMap; + class ExitVisitor : public ASR::StatementWalkVisitor { public: - ASR::expr_t* flag_expr; + std::stack doLoopStack; + + ExitVisitor(Allocator &al) : StatementWalkVisitor(al) { + } - ExitVisitor(Allocator &al, ASR::expr_t* flag_expr) : StatementWalkVisitor(al) { - this->flag_expr = flag_expr; + void visit_DoLoop(const ASR::DoLoop_t &x) { + ASR::stmt_t *doLoopStmt = (ASR::stmt_t*)(&x); + // std::cerr << doLoopStmt << " -- " << doLoopFlagMap[doLoopStmt] << std::endl; + + ASR::DoLoop_t& xx = const_cast(x); + + if (doLoopFlagMap.find(doLoopStmt) != doLoopFlagMap.end()) + doLoopStack.push(doLoopStmt); + + this->transform_stmts(xx.m_body, xx.n_body); + + if (doLoopFlagMap.find(doLoopStmt) != doLoopFlagMap.end()) + doLoopStack.pop(); } void visit_Exit(const ASR::Exit_t &x) { - std::cerr << "Break!" << std::endl; + if (doLoopStack.empty()) + return; - // Vec result; - // result.reserve(al, 1); + // std::cerr << "Break! inside " << doLoopStack.top() << std::endl; - // Location loc = x.base.base.loc; - // ASR::ttype_t* bool_type = ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4, nullptr, 0)); - // ASR::expr_t* false_expr = ASRUtils::EXPR(ASR::make_LogicalConstant_t(al, loc, false, bool_type)); - // ASR::stmt_t* assign_stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, flag_expr, false_expr, nullptr)); - // result.push_back(al, assign_stmt); + Vec result; + result.reserve(al, 1); - // pass_result = result; + Location loc = x.base.base.loc; + ASR::ttype_t* bool_type = ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4, nullptr, 0)); + ASR::symbol_t* flag_symbol = doLoopFlagMap[doLoopStack.top()]; + ASR::expr_t* flag_expr = ASRUtils::EXPR(ASR::make_Var_t(al, loc, flag_symbol)); + ASR::expr_t* false_expr = ASRUtils::EXPR(ASR::make_LogicalConstant_t(al, loc, false, bool_type)); + ASR::stmt_t* assign_stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, flag_expr, false_expr, nullptr)); + result.push_back(al, assign_stmt); + result.push_back(al, ASRUtils::STMT(ASR::make_Exit_t(al, loc))); - auto scope = current_scope->get_scope(); - for (auto it = scope.begin(); it != scope.end(); it++) { - std::cerr << "First: " << it->first << ", second: " << it->second << std::endl; + pass_result = result; + + return; + + auto current = current_scope; + while (current != nullptr) { + std::cerr << "Scope " << current << std::endl; + auto scope = current->get_scope(); + for (auto it = scope.begin(); it != scope.end(); it++) { + if (is_a(*it->second)) { + std::cerr << " Variable: " << it->first << std::endl; + } else if (is_a(*it->second)) { + std::cerr << " Program: " << it->first << std::endl; + } else if (is_a(*it->second)) { + std::cerr << " Block: " << it->first << std::endl; + ASR::Block_t *block = down_cast(it->second); + std::cerr << " body " << block->n_body << std::endl; + for (size_t i = 0; i < block->n_body; i++) { + std::cerr << " " << getStmtType(block->m_body[i]->type) << std::endl; + } + } else if (is_a(*it->second)) { + std::cerr << " Function: " << it->first << std::endl; + ASR::Function_t *block = down_cast(it->second); + for (size_t i = 0; i < block->n_body; i++) { + std::cerr << " " << getStmtType(block->m_body[i]->type) << std::endl; + } + } else { + std::cerr << " First: " << it->first << ", second: " << it->second->type << std::endl; + } + } + current = current->parent; } + } + std::string getStmtType(stmtType t) { + switch (t) { + case stmtType::Allocate: return "Allocate"; + case stmtType::Assign: return "Assign"; + case stmtType::Assignment: return "Assignment"; + case stmtType::Associate: return "Associate"; + case stmtType::Cycle: return "Cycle"; + case stmtType::ExplicitDeallocate: return "ExplicitDeallocate"; + case stmtType::ImplicitDeallocate: return "ImplicitDeallocate"; + case stmtType::DoConcurrentLoop: return "DoConcurrentLoop"; + case stmtType::DoLoop: return "DoLoop"; + case stmtType::ForElse: return "ForElse"; + case stmtType::ErrorStop: return "ErrorStop"; + case stmtType::Exit: return "Exit"; + case stmtType::ForAllSingle: return "ForAllSingle"; + case stmtType::GoTo: return "GoTo"; + case stmtType::GoToTarget: return "GoToTarget"; + case stmtType::If: return "If"; + case stmtType::IfArithmetic: return "IfArithmetic"; + case stmtType::Print: return "Print"; + case stmtType::FileOpen: return "FileOpen"; + case stmtType::FileClose: return "FileClose"; + case stmtType::FileRead: return "FileRead"; + case stmtType::FileBackspace: return "FileBackspace"; + case stmtType::FileRewind: return "FileRewind"; + case stmtType::FileInquire: return "FileInquire"; + case stmtType::FileWrite: return "FileWrite"; + case stmtType::Return: return "Return"; + case stmtType::Select: return "Select"; + case stmtType::Stop: return "Stop"; + case stmtType::Assert: return "Assert"; + case stmtType::SubroutineCall: return "SubroutineCall"; + case stmtType::Where: return "Where"; + case stmtType::WhileLoop: return "WhileLoop"; + case stmtType::Nullify: return "Nullify"; + case stmtType::Flush: return "Flush"; + case stmtType::ListAppend: return "ListAppend"; + case stmtType::AssociateBlockCall: return "AssociateBlockCall"; + case stmtType::SelectType: return "SelectType"; + case stmtType::CPtrToPointer: return "CPtrToPointer"; + case stmtType::BlockCall: return "BlockCall"; + case stmtType::SetInsert: return "SetInsert"; + case stmtType::SetRemove: return "SetRemove"; + case stmtType::ListInsert: return "ListInsert"; + case stmtType::ListRemove: return "ListRemove"; + case stmtType::ListClear: return "ListClear"; + case stmtType::DictInsert: return "DictInsert"; + } } }; @@ -50,10 +165,6 @@ class ForElseVisitor : public ASR::StatementWalkVisitor int counter; - void visit_Exit(const ASR::Exit_t &x) { - std::cerr << "!!!! Break!" << std::endl; - } - void visit_ForElse(const ASR::ForElse_t &x) { Location loc = x.base.base.loc; @@ -85,26 +196,14 @@ class ForElseVisitor : public ASR::StatementWalkVisitor ASR::make_Assignment_t(al, loc, flag_expr, true_expr, nullptr)); result.push_back(al, assign_stmt); - Vec body; - body.reserve(al, x.n_body); - - for (size_t i = 0; i < x.n_body; i++) { - ASR::stmt_t *stmt = x.m_body[i]; - - if (stmt->type == ASR::stmtType::Exit) { - ASR::stmt_t* assign_stmt = ASRUtils::STMT( - ASR::make_Assignment_t(al, loc, flag_expr, false_expr, nullptr)); - result.push_back(al, assign_stmt); - } - - body.push_back(al, stmt); - } - // convert head and body to DoLoop - ASR::stmt_t *stmt = ASRUtils::STMT( - ASR::make_DoLoop_t(al, loc, x.m_head, body.p, body.size()) + ASR::stmt_t *doLoopStmt = ASRUtils::STMT( + ASR::make_DoLoop_t(al, loc, x.m_head, x.m_body, x.n_body) ); - result.push_back(al, stmt); + result.push_back(al, doLoopStmt); + + doLoopFlagMap[doLoopStmt] = flag_symbol; + // std::cerr << doLoopStmt << " -> " << flag_expr << std::endl; // add an If block that executes the orelse statements when the flag is true result.push_back( @@ -119,8 +218,8 @@ void pass_replace_forelse(Allocator &al, ASR::TranslationUnit_t &unit, const LCompilers::PassOptions& /*pass_options*/) { ForElseVisitor v(al); v.visit_TranslationUnit(unit); - // ExitVisitor v2(al, nullptr); - // v2.visit_TranslationUnit(unit); + ExitVisitor v2(al); + v2.visit_TranslationUnit(unit); } } // namespace LCompilers From 10f60a129565b3dcc70d10950233bcd1b73dd002 Mon Sep 17 00:00:00 2001 From: thebesttv Date: Fri, 14 Apr 2023 17:25:11 +0800 Subject: [PATCH 07/13] Fix problem with nested loop --- src/libasr/pass/for_else.cpp | 25 ++++++------------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/src/libasr/pass/for_else.cpp b/src/libasr/pass/for_else.cpp index 67002da863..bf8cd4e50b 100644 --- a/src/libasr/pass/for_else.cpp +++ b/src/libasr/pass/for_else.cpp @@ -11,18 +11,6 @@ #include -// the current code passes break_in_if.py - -/* FIXME: - -for ... - for ... - break # does not affect the orelse flag in outer loop -else - ... - - */ - namespace LCompilers { using ASR::is_a; @@ -42,19 +30,18 @@ class ExitVisitor : public ASR::StatementWalkVisitor { ASR::stmt_t *doLoopStmt = (ASR::stmt_t*)(&x); // std::cerr << doLoopStmt << " -- " << doLoopFlagMap[doLoopStmt] << std::endl; - ASR::DoLoop_t& xx = const_cast(x); - - if (doLoopFlagMap.find(doLoopStmt) != doLoopFlagMap.end()) - doLoopStack.push(doLoopStmt); + doLoopStack.push(doLoopStmt); + ASR::DoLoop_t& xx = const_cast(x); this->transform_stmts(xx.m_body, xx.n_body); - if (doLoopFlagMap.find(doLoopStmt) != doLoopFlagMap.end()) - doLoopStack.pop(); + doLoopStack.pop(); } void visit_Exit(const ASR::Exit_t &x) { - if (doLoopStack.empty()) + if (doLoopStack.empty() || + // the current loop is not originally a ForElse loop + doLoopFlagMap.find(doLoopStack.top()) == doLoopFlagMap.end()) return; // std::cerr << "Break! inside " << doLoopStack.top() << std::endl; From ff2af2e1f8aa50d6f3bfd72864439a05ac92ad2e Mon Sep 17 00:00:00 2001 From: thebesttv Date: Fri, 14 Apr 2023 17:26:53 +0800 Subject: [PATCH 08/13] Update comment --- src/libasr/pass/for_else.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/libasr/pass/for_else.cpp b/src/libasr/pass/for_else.cpp index bf8cd4e50b..8e3784007a 100644 --- a/src/libasr/pass/for_else.cpp +++ b/src/libasr/pass/for_else.cpp @@ -158,7 +158,7 @@ class ForElseVisitor : public ASR::StatementWalkVisitor Vec result; result.reserve(al, 1); - // create a boolean flag and set it to false + // create a boolean flag and set it to true ASR::ttype_t* bool_type = ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4, nullptr, 0)); ASR::expr_t* true_expr = ASRUtils::EXPR(ASR::make_LogicalConstant_t(al, loc, true, bool_type)); ASR::expr_t* false_expr = ASRUtils::EXPR(ASR::make_LogicalConstant_t(al, loc, false, bool_type)); @@ -189,6 +189,7 @@ class ForElseVisitor : public ASR::StatementWalkVisitor ); result.push_back(al, doLoopStmt); + // this DoLoop corresponds to the current flag doLoopFlagMap[doLoopStmt] = flag_symbol; // std::cerr << doLoopStmt << " -> " << flag_expr << std::endl; From e26e703d60be0e9176c8b45aadc8faaaaae4a8b7 Mon Sep 17 00:00:00 2001 From: thebesttv Date: Fri, 14 Apr 2023 17:26:59 +0800 Subject: [PATCH 09/13] Remove unused code --- src/libasr/pass/for_else.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/libasr/pass/for_else.cpp b/src/libasr/pass/for_else.cpp index 8e3784007a..0657548ce1 100644 --- a/src/libasr/pass/for_else.cpp +++ b/src/libasr/pass/for_else.cpp @@ -161,7 +161,6 @@ class ForElseVisitor : public ASR::StatementWalkVisitor // create a boolean flag and set it to true ASR::ttype_t* bool_type = ASRUtils::TYPE(ASR::make_Logical_t(al, loc, 4, nullptr, 0)); ASR::expr_t* true_expr = ASRUtils::EXPR(ASR::make_LogicalConstant_t(al, loc, true, bool_type)); - ASR::expr_t* false_expr = ASRUtils::EXPR(ASR::make_LogicalConstant_t(al, loc, false, bool_type)); auto target_scope = current_scope; // al.make_new(current_scope); From f67409beae1285d0b3a2ab14399d6da3ed2fe132 Mon Sep 17 00:00:00 2001 From: thebesttv Date: Fri, 14 Apr 2023 17:35:41 +0800 Subject: [PATCH 10/13] Make 'doLoopFlagMap' instance variable --- src/libasr/pass/for_else.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/libasr/pass/for_else.cpp b/src/libasr/pass/for_else.cpp index 0657548ce1..284cc96d77 100644 --- a/src/libasr/pass/for_else.cpp +++ b/src/libasr/pass/for_else.cpp @@ -17,13 +17,13 @@ using ASR::is_a; using ASR::down_cast; using ASR::stmtType; -std::map doLoopFlagMap; - class ExitVisitor : public ASR::StatementWalkVisitor { public: std::stack doLoopStack; + std::map &doLoopFlagMap; - ExitVisitor(Allocator &al) : StatementWalkVisitor(al) { + ExitVisitor(Allocator &al, std::map &doLoopFlagMap) + : StatementWalkVisitor(al), doLoopFlagMap(doLoopFlagMap) { } void visit_DoLoop(const ASR::DoLoop_t &x) { @@ -150,6 +150,8 @@ class ForElseVisitor : public ASR::StatementWalkVisitor counter = 0; } + std::map doLoopFlagMap; + int counter; void visit_ForElse(const ASR::ForElse_t &x) { @@ -205,7 +207,7 @@ void pass_replace_forelse(Allocator &al, ASR::TranslationUnit_t &unit, const LCompilers::PassOptions& /*pass_options*/) { ForElseVisitor v(al); v.visit_TranslationUnit(unit); - ExitVisitor v2(al); + ExitVisitor v2(al, v.doLoopFlagMap); v2.visit_TranslationUnit(unit); } From ad437e746da3581ae03f34911ad21b6c29c6d7e2 Mon Sep 17 00:00:00 2001 From: thebesttv Date: Fri, 14 Apr 2023 17:42:42 +0800 Subject: [PATCH 11/13] Remove debug code --- src/libasr/pass/for_else.cpp | 84 ------------------------------------ 1 file changed, 84 deletions(-) diff --git a/src/libasr/pass/for_else.cpp b/src/libasr/pass/for_else.cpp index 284cc96d77..02db470cc9 100644 --- a/src/libasr/pass/for_else.cpp +++ b/src/libasr/pass/for_else.cpp @@ -28,7 +28,6 @@ class ExitVisitor : public ASR::StatementWalkVisitor { void visit_DoLoop(const ASR::DoLoop_t &x) { ASR::stmt_t *doLoopStmt = (ASR::stmt_t*)(&x); - // std::cerr << doLoopStmt << " -- " << doLoopFlagMap[doLoopStmt] << std::endl; doLoopStack.push(doLoopStmt); @@ -44,8 +43,6 @@ class ExitVisitor : public ASR::StatementWalkVisitor { doLoopFlagMap.find(doLoopStack.top()) == doLoopFlagMap.end()) return; - // std::cerr << "Break! inside " << doLoopStack.top() << std::endl; - Vec result; result.reserve(al, 1); @@ -59,87 +56,6 @@ class ExitVisitor : public ASR::StatementWalkVisitor { result.push_back(al, ASRUtils::STMT(ASR::make_Exit_t(al, loc))); pass_result = result; - - return; - - auto current = current_scope; - while (current != nullptr) { - std::cerr << "Scope " << current << std::endl; - auto scope = current->get_scope(); - for (auto it = scope.begin(); it != scope.end(); it++) { - if (is_a(*it->second)) { - std::cerr << " Variable: " << it->first << std::endl; - } else if (is_a(*it->second)) { - std::cerr << " Program: " << it->first << std::endl; - } else if (is_a(*it->second)) { - std::cerr << " Block: " << it->first << std::endl; - ASR::Block_t *block = down_cast(it->second); - std::cerr << " body " << block->n_body << std::endl; - for (size_t i = 0; i < block->n_body; i++) { - std::cerr << " " << getStmtType(block->m_body[i]->type) << std::endl; - } - } else if (is_a(*it->second)) { - std::cerr << " Function: " << it->first << std::endl; - ASR::Function_t *block = down_cast(it->second); - for (size_t i = 0; i < block->n_body; i++) { - std::cerr << " " << getStmtType(block->m_body[i]->type) << std::endl; - } - } else { - std::cerr << " First: " << it->first << ", second: " << it->second->type << std::endl; - } - } - current = current->parent; - } - } - - std::string getStmtType(stmtType t) { - switch (t) { - case stmtType::Allocate: return "Allocate"; - case stmtType::Assign: return "Assign"; - case stmtType::Assignment: return "Assignment"; - case stmtType::Associate: return "Associate"; - case stmtType::Cycle: return "Cycle"; - case stmtType::ExplicitDeallocate: return "ExplicitDeallocate"; - case stmtType::ImplicitDeallocate: return "ImplicitDeallocate"; - case stmtType::DoConcurrentLoop: return "DoConcurrentLoop"; - case stmtType::DoLoop: return "DoLoop"; - case stmtType::ForElse: return "ForElse"; - case stmtType::ErrorStop: return "ErrorStop"; - case stmtType::Exit: return "Exit"; - case stmtType::ForAllSingle: return "ForAllSingle"; - case stmtType::GoTo: return "GoTo"; - case stmtType::GoToTarget: return "GoToTarget"; - case stmtType::If: return "If"; - case stmtType::IfArithmetic: return "IfArithmetic"; - case stmtType::Print: return "Print"; - case stmtType::FileOpen: return "FileOpen"; - case stmtType::FileClose: return "FileClose"; - case stmtType::FileRead: return "FileRead"; - case stmtType::FileBackspace: return "FileBackspace"; - case stmtType::FileRewind: return "FileRewind"; - case stmtType::FileInquire: return "FileInquire"; - case stmtType::FileWrite: return "FileWrite"; - case stmtType::Return: return "Return"; - case stmtType::Select: return "Select"; - case stmtType::Stop: return "Stop"; - case stmtType::Assert: return "Assert"; - case stmtType::SubroutineCall: return "SubroutineCall"; - case stmtType::Where: return "Where"; - case stmtType::WhileLoop: return "WhileLoop"; - case stmtType::Nullify: return "Nullify"; - case stmtType::Flush: return "Flush"; - case stmtType::ListAppend: return "ListAppend"; - case stmtType::AssociateBlockCall: return "AssociateBlockCall"; - case stmtType::SelectType: return "SelectType"; - case stmtType::CPtrToPointer: return "CPtrToPointer"; - case stmtType::BlockCall: return "BlockCall"; - case stmtType::SetInsert: return "SetInsert"; - case stmtType::SetRemove: return "SetRemove"; - case stmtType::ListInsert: return "ListInsert"; - case stmtType::ListRemove: return "ListRemove"; - case stmtType::ListClear: return "ListClear"; - case stmtType::DictInsert: return "DictInsert"; - } } }; From f51b9b752c35a0430253631ff1c1d99256d1d29a Mon Sep 17 00:00:00 2001 From: thebesttv Date: Fri, 14 Apr 2023 17:43:54 +0800 Subject: [PATCH 12/13] Add test for nested loop --- tests/for_else/nested_loop.py | 12 ++++++++++++ tests/reference/runtime-nested_loop-6a5a431.json | 13 +++++++++++++ tests/reference/runtime-nested_loop-6a5a431.stdout | 9 +++++++++ tests/tests.toml | 4 ++++ 4 files changed, 38 insertions(+) create mode 100644 tests/for_else/nested_loop.py create mode 100644 tests/reference/runtime-nested_loop-6a5a431.json create mode 100644 tests/reference/runtime-nested_loop-6a5a431.stdout diff --git a/tests/for_else/nested_loop.py b/tests/for_else/nested_loop.py new file mode 100644 index 0000000000..ade5968637 --- /dev/null +++ b/tests/for_else/nested_loop.py @@ -0,0 +1,12 @@ +def nested_loop(): + i: i32 + j: i32 + for i in range(4): + print("outer: " + str(i)) + for j in range(10, 20): + print(" inner: " + str(j)) + break + else: + print("no break in outer loop") + +nested_loop() diff --git a/tests/reference/runtime-nested_loop-6a5a431.json b/tests/reference/runtime-nested_loop-6a5a431.json new file mode 100644 index 0000000000..5f83825c4d --- /dev/null +++ b/tests/reference/runtime-nested_loop-6a5a431.json @@ -0,0 +1,13 @@ +{ + "basename": "runtime-nested_loop-6a5a431", + "cmd": "lpython {infile}", + "infile": "tests/for_else/nested_loop.py", + "infile_hash": "41866cefd6bbcd693f2f7a3e2fb229fa5500e0fcb41ae864a5902709", + "outfile": null, + "outfile_hash": null, + "stdout": "runtime-nested_loop-6a5a431.stdout", + "stdout_hash": "8bfba961db77bc0c4a2f4cc1332a96d27325a39768c79454eabd871e", + "stderr": null, + "stderr_hash": null, + "returncode": 0 +} \ No newline at end of file diff --git a/tests/reference/runtime-nested_loop-6a5a431.stdout b/tests/reference/runtime-nested_loop-6a5a431.stdout new file mode 100644 index 0000000000..b8da464e3a --- /dev/null +++ b/tests/reference/runtime-nested_loop-6a5a431.stdout @@ -0,0 +1,9 @@ +outer: 0 + inner: 10 +outer: 1 + inner: 10 +outer: 2 + inner: 10 +outer: 3 + inner: 10 +no break in outer loop diff --git a/tests/tests.toml b/tests/tests.toml index be6882cbdc..94b41195cc 100644 --- a/tests/tests.toml +++ b/tests/tests.toml @@ -1070,3 +1070,7 @@ run = true [[test]] filename = "for_else/break_in_if.py" run = true + +[[test]] +filename = "for_else/nested_loop.py" +run = true From 8a24109ceaa8078859719fd0e58dd659e250f465 Mon Sep 17 00:00:00 2001 From: thebesttv Date: Fri, 14 Apr 2023 18:04:38 +0800 Subject: [PATCH 13/13] Update according to synced libasr This fix is due to commit 35a8d865bb960acd7397252607e2e5cd64a4bccd. --- src/libasr/pass/for_else.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/libasr/pass/for_else.cpp b/src/libasr/pass/for_else.cpp index 02db470cc9..04a1ecfc74 100644 --- a/src/libasr/pass/for_else.cpp +++ b/src/libasr/pass/for_else.cpp @@ -53,7 +53,7 @@ class ExitVisitor : public ASR::StatementWalkVisitor { ASR::expr_t* false_expr = ASRUtils::EXPR(ASR::make_LogicalConstant_t(al, loc, false, bool_type)); ASR::stmt_t* assign_stmt = ASRUtils::STMT(ASR::make_Assignment_t(al, loc, flag_expr, false_expr, nullptr)); result.push_back(al, assign_stmt); - result.push_back(al, ASRUtils::STMT(ASR::make_Exit_t(al, loc))); + result.push_back(al, ASRUtils::STMT(ASR::make_Exit_t(al, loc, nullptr))); pass_result = result; } @@ -102,7 +102,7 @@ class ForElseVisitor : public ASR::StatementWalkVisitor // convert head and body to DoLoop ASR::stmt_t *doLoopStmt = ASRUtils::STMT( - ASR::make_DoLoop_t(al, loc, x.m_head, x.m_body, x.n_body) + ASR::make_DoLoop_t(al, loc, nullptr, x.m_head, x.m_body, x.n_body) ); result.push_back(al, doLoopStmt);