Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Prev Previous commit
Next Next commit
use generic procedure
  • Loading branch information
Smit-create authored and certik committed Mar 9, 2022
commit 32bede6832bb6d849f913e320d94ec6ee608e377
115 changes: 64 additions & 51 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,13 @@ ASR::Module_t* load_module(Allocator &al, SymbolTable *symtab,
}

template <typename T>
bool argument_types_match(const Vec<ASR::ttype_t*> &args,
bool argument_types_match(const Vec<ASR::expr_t*> &args,
const T &sub) {
if (args.size() <= sub.n_args) {
size_t i;
for (i = 0; i < args.size(); i++) {
ASR::Variable_t *v = LFortran::ASRUtils::EXPR2VAR(sub.m_args[i]);
ASR::ttype_t *arg1 = args[i];
ASR::ttype_t *arg1 = ASRUtils::expr_type(args[i]);
ASR::ttype_t *arg2 = v->m_type;
if (!ASRUtils::check_equal_type(arg1, arg2)) {
return false;
Expand All @@ -164,7 +164,7 @@ bool argument_types_match(const Vec<ASR::ttype_t*> &args,
}
}

bool select_func_subrout(const ASR::symbol_t* proc, const Vec<ASR::ttype_t*> &args,
bool select_func_subrout(const ASR::symbol_t* proc, const Vec<ASR::expr_t*> &args,
const Location& loc, const std::function<void (const std::string &, const Location &)> err) {
bool result = false;
if (ASR::is_a<ASR::Subroutine_t>(*proc)) {
Expand All @@ -185,8 +185,7 @@ bool select_func_subrout(const ASR::symbol_t* proc, const Vec<ASR::ttype_t*> &ar
return result;
}

std::map<std::string, std::vector<std::string>> overload_definitons;

std::map<int, ASR::symbol_t*> ast_overload;
template <class Derived>
class CommonVisitor : public AST::BaseVisitor<Derived> {
public:
Expand All @@ -204,7 +203,7 @@ class CommonVisitor : public AST::BaseVisitor<Derived> {
// The main module is stored directly in TranslationUnit, other modules are Modules
bool main_module;
PythonIntrinsicProcedures intrinsic_procedures;
std::map<std::string, std::vector<std::string>> overload_defs;
std::map<std::string, Vec<ASR::symbol_t* >> overload_defs;

CommonVisitor(Allocator &al, SymbolTable *symbol_table,
diag::Diagnostics &diagnostics, bool main_module)
Expand Down Expand Up @@ -571,7 +570,9 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
for (size_t i=0; i<x.n_body; i++) {
visit_stmt(*x.m_body[i]);
}

if (!overload_defs.empty()) {
create_GenericProcedure(x.base.base.loc);
}
global_scope = nullptr;
tmp = tmp0;
}
Expand Down Expand Up @@ -642,6 +643,9 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
std::string overload_number;
if (overload_defs.find(sym_name) == overload_defs.end()){
overload_number = "0";
Vec<ASR::symbol_t *> v;
v.reserve(al, 1);
overload_defs[sym_name] = v;
} else {
overload_number = std::to_string(overload_defs[sym_name].size());
}
Expand Down Expand Up @@ -700,10 +704,22 @@ class SymbolTableVisitor : public CommonVisitor<SymbolTableVisitor> {
s_access, deftype, bindc_name,
is_pure, is_module);
}
parent_scope->scope[sym_name] = ASR::down_cast<ASR::symbol_t>(tmp);
ASR::symbol_t * t = ASR::down_cast<ASR::symbol_t>(tmp);
parent_scope->scope[sym_name] = t;
current_scope = parent_scope;
if (overload) {
overload_defs[x.m_name].push_back(sym_name);
overload_defs[x.m_name].push_back(al, t);
ast_overload[(int64_t)&x] = t;
}
}

void create_GenericProcedure(const Location &loc) {
for(auto &p: overload_defs) {
std::string def_name = p.first;
tmp = ASR::make_GenericProcedure_t(al, loc, current_scope, s2c(al, def_name),
p.second.p, p.second.size(), ASR::accessType::Public);
ASR::symbol_t *t = ASR::down_cast<ASR::symbol_t>(tmp);
current_scope->scope[def_name] = t;
}
}

Expand Down Expand Up @@ -801,7 +817,6 @@ Result<ASR::asr_t*> symbol_table_visitor(Allocator &al, const AST::Module_t &ast
SymbolTableVisitor v(al, nullptr, diagnostics, main_module);
try {
v.visit_Module(ast);
overload_definitons = v.overload_defs;
} catch (const SemanticError &e) {
Error error;
diagnostics.diagnostics.push_back(e.d);
Expand Down Expand Up @@ -882,44 +897,24 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
v.n_body = body.size();
}

ASR::symbol_t* overloaddef_find_helper(std::string func_name, Vec<ASR::ttype_t*> args,
const Location &loc) {
for(auto &t: overload_defs[func_name]) {
SymbolTable *symtab = current_scope;
while (symtab!= nullptr && symtab->scope.find(t) == symtab->scope.end()) {
symtab = symtab->parent;
}
LFORTRAN_ASSERT(symtab != nullptr);
ASR::symbol_t *st = symtab->scope[t];
bool ok = select_func_subrout(st, args, loc,
[&](const std::string &msg, const Location &l) { throw SemanticError(msg, l); });
if (ok) {
return st;
}
}
return nullptr;
}

void visit_FunctionDef(const AST::FunctionDef_t &x) {
SymbolTable *old_scope = current_scope;
ASR::symbol_t *t = nullptr;
if (overload_defs.find(x.m_name) != overload_defs.end()) {
Vec<ASR::ttype_t *> args;
args.reserve(al, x.m_args.n_args);
for (size_t i=0; i<x.m_args.n_args; i++) {
ASR::ttype_t *arg_type = ast_expr_to_asr_type(x.base.base.loc,
*x.m_args.m_args[i].m_annotation);
args.push_back(al, arg_type);
}
t = overloaddef_find_helper(x.m_name, args, x.base.base.loc);
} else {
t = current_scope->scope[x.m_name];
}
ASR::symbol_t *t = t = current_scope->scope[x.m_name];
if (ASR::is_a<ASR::Subroutine_t>(*t)) {
handle_fn(x, *ASR::down_cast<ASR::Subroutine_t>(t));
} else if (ASR::is_a<ASR::Function_t>(*t)) {
ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(t);
handle_fn(x, *f);
} else if (ASR::is_a<ASR::GenericProcedure_t>(*t)) {
ASR::symbol_t *s = ast_overload[(int64_t)&x];
if (ASR::is_a<ASR::Subroutine_t>(*s)) {
handle_fn(x, *ASR::down_cast<ASR::Subroutine_t>(s));
} else if (ASR::is_a<ASR::Function_t>(*s)) {
ASR::Function_t *f = ASR::down_cast<ASR::Function_t>(s);
handle_fn(x, *f);
} else {
LFORTRAN_ASSERT(false);
}
} else {
LFORTRAN_ASSERT(false);
}
Expand Down Expand Up @@ -2211,15 +2206,13 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
x.base.base.loc);
}

ASR::symbol_t *s = current_scope->resolve_symbol(call_name);

if (!s && overload_defs.find(call_name)!=overload_defs.end()) {
Vec<ASR::ttype_t*> args_type;
args_type.reserve(al, x.n_args);
for(size_t i=0; i<x.n_args; i++) {
args_type.push_back(al, ASRUtils::expr_type(args[i]));
}
s = overloaddef_find_helper(call_name, args_type, x.base.base.loc);
ASR::symbol_t *s = current_scope->resolve_symbol(call_name), *s_generic = nullptr;
if (s->type == ASR::symbolType::GenericProcedure){
ASR::GenericProcedure_t *p = ASR::down_cast<ASR::GenericProcedure_t>(s);
int idx = select_generic_procedure(args, *p, x.base.base.loc);
// Create ExternalSymbol for procedures in different modules.
s_generic = s;
s = p->m_procs[idx];
}

if (!s) {
Expand Down Expand Up @@ -2366,6 +2359,27 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
x.base.base.loc);
}
}
int select_generic_procedure(const Vec<ASR::expr_t*> &args,
const ASR::GenericProcedure_t &p, Location loc) {
for (size_t i=0; i < p.n_procs; i++) {

if( ASR::is_a<ASR::ClassProcedure_t>(*p.m_procs[i]) ) {
ASR::ClassProcedure_t *clss_fn
= ASR::down_cast<ASR::ClassProcedure_t>(p.m_procs[i]);
const ASR::symbol_t *proc = ASRUtils::symbol_get_past_external(clss_fn->m_proc);
if( select_func_subrout(proc, args, loc,
[&](const std::string &msg, const Location &loc) { throw SemanticError(msg, loc); })
){
return i;
}
} else {
if( select_func_subrout(p.m_procs[i], args, loc, [&](const std::string &msg, const Location &loc) { throw SemanticError(msg, loc); }) ) {
return i;
}
}
}
throw SemanticError("Arguments do not match for any generic procedure", loc);
}

void visit_ImportFrom(const AST::ImportFrom_t &/*x*/) {
// Handled by SymbolTableVisitor already
Expand All @@ -2380,7 +2394,6 @@ Result<ASR::TranslationUnit_t*> body_visitor(Allocator &al,
{
BodyVisitor b(al, unit, diagnostics, main_module);
try {
b.overload_defs = overload_definitons;
b.visit_Module(ast);
} catch (const SemanticError &e) {
Error error;
Expand Down