diff --git a/src/libasr/pass/global_stmts.cpp b/src/libasr/pass/global_stmts.cpp index 7fd9310d1a..4324350c25 100644 --- a/src/libasr/pass/global_stmts.cpp +++ b/src/libasr/pass/global_stmts.cpp @@ -105,57 +105,31 @@ void pass_wrap_global_stmts(Allocator &al, } if (return_var) { - // The last item was an expression, create a function returning it - // The last defined `return_var` is the actual return value ASR::down_cast2(return_var)->m_intent = ASRUtils::intent_return_var; + } - - ASR::asr_t *fn = ASRUtils::make_Function_t_util( - al, loc, - /* a_symtab */ fn_scope, - /* a_name */ fn_name, - nullptr, 0, - /* a_args */ nullptr, - /* n_args */ 0, - /* a_body */ body.p, - /* n_body */ body.size(), - /* a_return_var */ return_var_ref, - ASR::abiType::BindC, - ASR::Public, ASR::Implementation, - nullptr, - false, false, false, false, false, - nullptr, 0, nullptr, 0, - false, false, false); - std::string sym_name = fn_name; - if (unit.m_global_scope->get_symbol(sym_name) != nullptr) { - throw LCompilersException("Function already defined"); - } - unit.m_global_scope->add_symbol(sym_name, down_cast(fn)); - } else { - // The last item was a statement, create a subroutine (returning - // nothing) - ASR::asr_t *fn = ASRUtils::make_Function_t_util( - al, loc, - /* a_symtab */ fn_scope, - /* a_name */ fn_name, - nullptr, 0, - /* a_args */ nullptr, - /* n_args */ 0, - /* a_body */ body.p, - /* n_body */ body.size(), - nullptr, - ASR::abiType::Source, - ASR::Public, ASR::Implementation, nullptr, - false, false, false, false, false, - nullptr, 0, nullptr, 0, - false, false, false); - std::string sym_name = fn_name; - if (unit.m_global_scope->get_symbol(sym_name) != nullptr) { - throw LCompilersException("Function already defined"); - } - unit.m_global_scope->add_symbol(sym_name, down_cast(fn)); + ASR::asr_t *fn = ASRUtils::make_Function_t_util( + al, loc, + /* a_symtab */ fn_scope, + /* a_name */ fn_name, + nullptr, 0, + /* a_args */ nullptr, + /* n_args */ 0, + /* a_body */ body.p, + /* n_body */ body.size(), + /* a_return_var */ (return_var ? return_var_ref : nullptr), + (return_var ? ASR::abiType::BindC : ASR::abiType::Source), + ASR::Public, ASR::Implementation, + nullptr, + false, false, false, false, false, + nullptr, 0, nullptr, 0, + false, false, false); + std::string sym_name = fn_name; + if (unit.m_global_scope->get_symbol(sym_name) != nullptr) { + throw LCompilersException("Function already defined"); } + unit.m_global_scope->add_symbol(sym_name, down_cast(fn)); unit.m_items = nullptr; unit.n_items = 0; PassUtils::UpdateDependenciesVisitor v(al); diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 802a236a06..f718609247 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -547,6 +547,39 @@ ASR::symbol_t* import_from_module(Allocator &al, ASR::Module_t *m, SymbolTable * return nullptr; } +// Here, we call the global_initializer & global_statements to +// initialize and execute the global symbols +void get_calls_to_global_init_and_stmts(Allocator &al, const Location &loc, SymbolTable* scope, + ASR::Module_t* mod, std::vector &tmp_vec) { + + std::string mod_name = mod->m_name; + std::string g_func_name = mod_name + "@global_initializer"; + ASR::symbol_t *g_func = mod->m_symtab->get_symbol("global_initializer"); + if (g_func && !scope->get_symbol(g_func_name)) { + ASR::symbol_t *es = ASR::down_cast( + ASR::make_ExternalSymbol_t(al, mod->base.base.loc, + scope, s2c(al, g_func_name), g_func, + s2c(al, mod_name), nullptr, 0, s2c(al, "global_initializer"), + ASR::accessType::Public)); + scope->add_symbol(g_func_name, es); + tmp_vec.push_back(ASRUtils::make_SubroutineCall_t_util(al, loc, + es, g_func, nullptr, 0, nullptr, nullptr, false)); + } + + g_func_name = mod_name + "@global_statements"; + g_func = mod->m_symtab->get_symbol("global_statements"); + if (g_func && !scope->get_symbol(g_func_name)) { + ASR::symbol_t *es = ASR::down_cast( + ASR::make_ExternalSymbol_t(al, mod->base.base.loc, + scope, s2c(al, g_func_name), g_func, + s2c(al, mod_name), nullptr, 0, s2c(al, "global_statements"), + ASR::accessType::Public)); + scope->add_symbol(g_func_name, es); + tmp_vec.push_back(ASRUtils::make_SubroutineCall_t_util(al, loc, + es, g_func, nullptr, 0, nullptr, nullptr, false)); + } +} + template class CommonVisitor : public AST::BaseVisitor { public: @@ -568,9 +601,6 @@ class CommonVisitor : public AST::BaseVisitor { Allocator &al; LocationManager &lm; SymbolTable *current_scope; - // The current_module contains the current module that is being visited; - // this is used to append to the module dependencies if needed - ASR::Module_t *current_module = nullptr; SetChar current_module_dependencies; // True for the main module, false for every other one // The main module is stored directly in TranslationUnit, other modules are Modules @@ -585,6 +615,10 @@ class CommonVisitor : public AST::BaseVisitor { std::map &ast_overload; std::string parent_dir; std::vector import_paths; + /* + current_body exists only for Functions, For, If (& its Else part), While. + current_body does not exist for Modules, ClassDef/Structs. + */ Vec *current_body; ASR::ttype_t* ann_assign_target_type; AST::expr_t* assign_ast_target; @@ -668,24 +702,7 @@ class CommonVisitor : public AST::BaseVisitor { v = current_scope->get_symbol(sym); } - // Now we need to add the module `m` with the intrinsic function - // into the current module dependencies - if (current_module) { - // We are in body visitor, the module is already constructed - // and available as current_module. - // Add the module `m` to current module dependencies - SetChar vec; - vec.from_pointer_n_copy(al, current_module->m_dependencies, - current_module->n_dependencies); - vec.push_back(al, m->m_name); - current_module->m_dependencies = vec.p; - current_module->n_dependencies = vec.size(); - } else { - // We are in the symtab visitor or body visitor and we are - // constructing a module, so current_module is not available yet - // (the current_module_dependencies is not used in body visitor) - current_module_dependencies.push_back(al, m->m_name); - } + current_module_dependencies.push_back(al, m->m_name); return v; } @@ -4322,12 +4339,12 @@ class SymbolTableVisitor : public CommonVisitor { } } else { bool is_pure = false, is_module = false; - + // This checks for internal function defintions as well. for (size_t i = 0; i < x.n_body; i++) { visit_stmt(*x.m_body[i]); } - + tmp = ASRUtils::make_Function_t_util( al, x.base.base.loc, /* a_symtab */ current_scope, @@ -4453,7 +4470,6 @@ class SymbolTableVisitor : public CommonVisitor { ASR::symbol_t *t = nullptr; // current_scope->parent->resolve_symbol(msym); if (!t) { std::string rl_path = get_runtime_library_dir(); - SymbolTable *st = current_scope; std::vector paths; for (auto &path:import_paths) { paths.push_back(path); @@ -4461,12 +4477,9 @@ class SymbolTableVisitor : public CommonVisitor { paths.push_back(rl_path); paths.push_back(parent_dir); - if (!main_module) { - st = st->parent; - } bool lpython, enum_py, copy, sympy; set_module_symbol(msym, paths); - t = (ASR::symbol_t*)(load_module(al, st, + t = (ASR::symbol_t*)(load_module(al, global_scope, msym, x.base.base.loc, diag, lm, false, paths, lpython, enum_py, copy, sympy, [&](const std::string &msg, const Location &loc) { throw SemanticError(msg, loc); }, allow_implicit_casting)); @@ -4530,18 +4543,14 @@ class SymbolTableVisitor : public CommonVisitor { } paths.push_back(rl_path); paths.push_back(parent_dir); - SymbolTable *st = current_scope; std::vector mods; for (size_t i=0; iparent; - } for (auto &mod_sym : mods) { bool lpython, enum_py, copy, sympy; set_module_symbol(mod_sym, paths); - t = (ASR::symbol_t*)(load_module(al, st, + t = (ASR::symbol_t*)(load_module(al, global_scope, mod_sym, x.base.base.loc, diag, lm, false, paths, lpython, enum_py, copy, sympy, [&](const std::string &msg, const Location &loc) { throw SemanticError(msg, loc); }, allow_implicit_casting)); @@ -4738,12 +4747,12 @@ class BodyVisitor : public CommonVisitor { ASR::TranslationUnit_t *unit = ASR::down_cast2(asr); current_scope = unit->m_global_scope; LCOMPILERS_ASSERT(current_scope != nullptr); - ASR::symbol_t* main_module_sym = current_scope->get_symbol(module_name); + ASR::symbol_t* module_sym = nullptr; ASR::Module_t* mod = nullptr; - if( main_module_sym ) { - mod = ASR::down_cast(main_module_sym); - } + if (!main_module) { + module_sym = current_scope->get_symbol(module_name); + mod = ASR::down_cast(module_sym); current_scope = mod->m_symtab; LCOMPILERS_ASSERT(current_scope != nullptr); } @@ -4765,72 +4774,69 @@ class BodyVisitor : public CommonVisitor { tmp_vec.clear(); } } + if( mod ) { for( size_t i = 0; i < mod->n_dependencies; i++ ) { current_module_dependencies.push_back(al, mod->m_dependencies[i]); } mod->m_dependencies = current_module_dependencies.p; - mod->n_dependencies = current_module_dependencies.size(); - } - - if (global_init.n > 0 && main_module_sym) { - // unit->m_items is used and set to nullptr in the - // `pass_wrap_global_stmts_into_function` pass - unit->m_items = global_init.p; - unit->n_items = global_init.size(); - std::string func_name = "global_initializer"; - LCompilers::PassOptions pass_options; - pass_options.run_fun = func_name; - pass_wrap_global_stmts(al, *unit, pass_options); - - ASR::Module_t *mod = ASR::down_cast(main_module_sym); - ASR::symbol_t *f_sym = unit->m_global_scope->get_symbol(func_name); - if (f_sym) { - // Add the `global_initilaizer` function into the `__main__` - // module and later call this function to initialize the - // global variables like list, ... - ASR::Function_t *f = ASR::down_cast(f_sym); - f->m_symtab->parent = mod->m_symtab; - mod->m_symtab->add_symbol(func_name, (ASR::symbol_t *) f); - // Erase the function in TranslationUnit - unit->m_global_scope->erase_symbol(func_name); - } - global_init.p = nullptr; - global_init.n = 0; - } - - if (global_init.n > 0) { - // copy all the item's from `items` (global_statements) - // into `global_init` - for (auto &i: items) { - global_init.push_back(al, i); + mod->n_dependencies = current_module_dependencies.n; + + if (global_init.n > 0) { + // unit->m_items is used and set to nullptr in the + // `pass_wrap_global_stmts_into_function` pass + unit->m_items = global_init.p; + unit->n_items = global_init.size(); + std::string func_name = "global_initializer"; + LCompilers::PassOptions pass_options; + pass_options.run_fun = func_name; + pass_wrap_global_stmts(al, *unit, pass_options); + + ASR::symbol_t *f_sym = unit->m_global_scope->get_symbol(func_name); + if (f_sym) { + // Add the `global_initilaizer` function into the + // module and later call this function to initialize the + // global variables like list, ... + ASR::Function_t *f = ASR::down_cast(f_sym); + f->m_symtab->parent = mod->m_symtab; + mod->m_symtab->add_symbol(func_name, (ASR::symbol_t *) f); + // Erase the function in TranslationUnit + unit->m_global_scope->erase_symbol(func_name); + } + global_init.p = nullptr; + global_init.n = 0; + } + + if (items.n > 0) { + unit->m_items = items.p; + unit->n_items = items.size(); + std::string func_name = "global_statements"; + // Wrap all the global statements into a Function + LCompilers::PassOptions pass_options; + pass_options.run_fun = func_name; + pass_wrap_global_stmts(al, *unit, pass_options); + + ASR::symbol_t *f_sym = unit->m_global_scope->get_symbol(func_name); + if (f_sym) { + // Add the `global_statements` function into the + // module and later call this function to execute the + // global_statements + ASR::Function_t *f = ASR::down_cast(f_sym); + f->m_symtab->parent = mod->m_symtab; + mod->m_symtab->add_symbol(func_name, (ASR::symbol_t *) f); + // Erase the function in TranslationUnit + unit->m_global_scope->erase_symbol(func_name); + } + items.p = nullptr; + items.n = 0; } - unit->m_items = global_init.p; - unit->n_items = global_init.size(); } else { - unit->m_items = items.p; - unit->n_items = items.size(); - } - - if (items.n > 0 && main_module_sym) { - std::string func_name = "global_statements"; - // Wrap all the global statements into a Function - LCompilers::PassOptions pass_options; - pass_options.run_fun = func_name; - pass_wrap_global_stmts(al, *unit, pass_options); - - ASR::Module_t *mod = ASR::down_cast(main_module_sym); - ASR::symbol_t *f_sym = unit->m_global_scope->get_symbol(func_name); - if (f_sym) { - // Add the `global_statements` function into the `__main__` - // module and later call this function to execute the - // global_statements - ASR::Function_t *f = ASR::down_cast(f_sym); - f->m_symtab->parent = mod->m_symtab; - mod->m_symtab->add_symbol(func_name, (ASR::symbol_t *) f); - // Erase the function in TranslationUnit - unit->m_global_scope->erase_symbol(func_name); + // It is main_module + for (auto item:items) { + global_init.push_back(al, item); } + unit->m_items = global_init.p; + unit->n_items = global_init.size(); } tmp = asr; @@ -4886,47 +4892,18 @@ class BodyVisitor : public CommonVisitor { void visit_Import(const AST::Import_t &x) { // All the modules are imported in the SymbolTable visitor - // Here, we call the global_initializer & global_statements to - // initialize and execute the global symbols for (size_t i = 0; i < x.n_names; i++) { std::string mod_name = x.m_names[i].m_name; ASR::symbol_t *mod_sym = current_scope->resolve_symbol(mod_name); if (mod_sym) { ASR::Module_t *mod = ASR::down_cast(mod_sym); - - std::string g_func_name = mod_name + "@global_initializer"; - ASR::symbol_t *g_func = mod->m_symtab->get_symbol("global_initializer"); - if (g_func && !current_scope->get_symbol(g_func_name)) { - ASR::symbol_t *es = ASR::down_cast( - ASR::make_ExternalSymbol_t(al, mod->base.base.loc, - current_scope, s2c(al, g_func_name), g_func, - s2c(al, mod_name), nullptr, 0, s2c(al, "global_initializer"), - ASR::accessType::Public)); - current_scope->add_symbol(g_func_name, es); - tmp_vec.push_back(ASRUtils::make_SubroutineCall_t_util(al, x.base.base.loc, - es, g_func, nullptr, 0, nullptr, nullptr, false)); - } - - g_func_name = mod_name + "@global_statements"; - g_func = mod->m_symtab->get_symbol("global_statements"); - if (g_func && !current_scope->get_symbol(g_func_name)) { - ASR::symbol_t *es = ASR::down_cast( - ASR::make_ExternalSymbol_t(al, mod->base.base.loc, - current_scope, s2c(al, g_func_name), g_func, - s2c(al, mod_name), nullptr, 0, s2c(al, "global_statements"), - ASR::accessType::Public)); - current_scope->add_symbol(g_func_name, es); - tmp_vec.push_back(ASRUtils::make_SubroutineCall_t_util(al, x.base.base.loc, - es, g_func, nullptr, 0, nullptr, nullptr, false)); - } + get_calls_to_global_init_and_stmts(al, x.base.base.loc, current_scope, mod, tmp_vec); } } } void visit_ImportFrom(const AST::ImportFrom_t &x) { // Handled by SymbolTableVisitor already - // Here, we call the global_initializer & global_statements to - // initialize and execute the global symbols std::string mod_name = x.m_module; for (size_t i = 0; i < x.n_names; i++) { imported_functions[x.m_names[i].m_name] = mod_name; @@ -4934,32 +4911,7 @@ class BodyVisitor : public CommonVisitor { ASR::symbol_t *mod_sym = current_scope->resolve_symbol(mod_name); if (mod_sym) { ASR::Module_t *mod = ASR::down_cast(mod_sym); - - std::string g_func_name = mod_name + "@global_initializer"; - ASR::symbol_t *g_func = mod->m_symtab->get_symbol("global_initializer"); - if (g_func && !current_scope->get_symbol(g_func_name)) { - ASR::symbol_t *es = ASR::down_cast( - ASR::make_ExternalSymbol_t(al, mod->base.base.loc, - current_scope, s2c(al, g_func_name), g_func, - s2c(al, mod_name), nullptr, 0, s2c(al, "global_initializer"), - ASR::accessType::Public)); - current_scope->add_symbol(g_func_name, es); - tmp_vec.push_back(ASRUtils::make_SubroutineCall_t_util(al, x.base.base.loc, - es, g_func, nullptr, 0, nullptr, nullptr, false)); - } - - g_func_name = mod_name + "@global_statements"; - g_func = mod->m_symtab->get_symbol("global_statements"); - if (g_func && !current_scope->get_symbol(g_func_name)) { - ASR::symbol_t *es = ASR::down_cast( - ASR::make_ExternalSymbol_t(al, mod->base.base.loc, - current_scope, s2c(al, g_func_name), g_func, - s2c(al, mod_name), nullptr, 0, s2c(al, "global_statements"), - ASR::accessType::Public)); - current_scope->add_symbol(g_func_name, es); - tmp_vec.push_back(ASRUtils::make_SubroutineCall_t_util(al, x.base.base.loc, - es, g_func, nullptr, 0, nullptr, nullptr, false)); - } + get_calls_to_global_init_and_stmts(al, x.base.base.loc, current_scope, mod, tmp_vec); } tmp = nullptr; } diff --git a/src/lpython/semantics/python_ast_to_asr.h b/src/lpython/semantics/python_ast_to_asr.h index 34557bbb95..8270846c32 100644 --- a/src/lpython/semantics/python_ast_to_asr.h +++ b/src/lpython/semantics/python_ast_to_asr.h @@ -8,7 +8,7 @@ namespace LCompilers::LPython { Result python_ast_to_asr(Allocator &al, LocationManager &lm, SymbolTable* symtab, LPython::AST::ast_t &ast, diag::Diagnostics &diagnostics, CompilerOptions &compiler_options, - bool main_module, std::string ext_mod_name, std::string file_path, bool allow_implicit_casting=false); + bool main_module, std::string module_name, std::string file_path, bool allow_implicit_casting=false); int save_pyc_files(const ASR::TranslationUnit_t &u, std::string infile);