diff --git a/src/libasr/asdl_cpp.py b/src/libasr/asdl_cpp.py index 8eb94ef999..e9b66117b9 100644 --- a/src/libasr/asdl_cpp.py +++ b/src/libasr/asdl_cpp.py @@ -1303,6 +1303,8 @@ def visitField(self, field): self.emit(" self().replace_expr(x->m_%s[i]);"%(field.name), level) self.emit(" current_expr = current_expr_copy_%d;" % (self.current_expr_copy_variable_count), level) self.current_expr_copy_variable_count += 1 + elif field.type == "ttype": + self.emit(" self().replace_%s(x->m_%s[i]);" % (field.type, field.name), level) self.emit("}", level) else: if field.type != "symbol": diff --git a/src/libasr/asr_utils.h b/src/libasr/asr_utils.h index bcab5f631f..3880ce25d2 100644 --- a/src/libasr/asr_utils.h +++ b/src/libasr/asr_utils.h @@ -3247,10 +3247,12 @@ class ReplaceWithFunctionParamVisitor: public ASR::BaseExprReplacerm_v)); + ASRUtils::symbol_type(x->m_v), current_scope); *current_expr = ASRUtils::EXPR(ASR::make_FunctionParam_t( al, m_args[arg_idx]->base.loc, arg_idx, t_, nullptr)); } } - ASR::ttype_t* replace_args_with_FunctionParam(ASR::ttype_t* t) { + void replace_Struct(ASR::Struct_t *x) { + std::string derived_type_name = ASRUtils::symbol_name(x->m_derived_type); + ASR::symbol_t* derived_type_sym = current_scope->resolve_symbol(derived_type_name); + LCOMPILERS_ASSERT_MSG( derived_type_sym != nullptr, + "derived_type_sym cannot be nullptr"); + if (derived_type_sym != x->m_derived_type) { + x->m_derived_type = derived_type_sym; + } + } + + ASR::ttype_t* replace_args_with_FunctionParam(ASR::ttype_t* t, SymbolTable* current_scope) { + this->current_scope = current_scope; + ASRUtils::ExprStmtDuplicator duplicator(al); duplicator.allow_procedure_calls = true; @@ -3312,7 +3326,7 @@ inline ASR::asr_t* make_FunctionType_t_util(Allocator &al, ASR::expr_t* a_return_var, ASR::abiType a_abi, ASR::deftypeType a_deftype, char* a_bindc_name, bool a_elemental, bool a_pure, bool a_module, bool a_inline, bool a_static, - ASR::symbol_t** a_restrictions, size_t n_restrictions, bool a_is_restriction) { + ASR::symbol_t** a_restrictions, size_t n_restrictions, bool a_is_restriction, SymbolTable* current_scope) { Vec arg_types; arg_types.reserve(al, n_args); ReplaceWithFunctionParamVisitor replacer(al, a_args, n_args); @@ -3320,13 +3334,13 @@ inline ASR::asr_t* make_FunctionType_t_util(Allocator &al, // We need to substitute all direct argument variable references with // FunctionParam. ASR::ttype_t *t = replacer.replace_args_with_FunctionParam( - expr_type(a_args[i])); + expr_type(a_args[i]), current_scope); arg_types.push_back(al, t); } ASR::ttype_t* return_var_type = nullptr; if( a_return_var ) { return_var_type = replacer.replace_args_with_FunctionParam( - ASRUtils::expr_type(a_return_var)); + ASRUtils::expr_type(a_return_var), current_scope); } LCOMPILERS_ASSERT(arg_types.size() == n_args); @@ -3338,12 +3352,12 @@ inline ASR::asr_t* make_FunctionType_t_util(Allocator &al, } inline ASR::asr_t* make_FunctionType_t_util(Allocator &al, const Location &a_loc, - ASR::expr_t** a_args, size_t n_args, ASR::expr_t* a_return_var, ASR::FunctionType_t* ft) { + ASR::expr_t** a_args, size_t n_args, ASR::expr_t* a_return_var, ASR::FunctionType_t* ft, SymbolTable* current_scope) { return ASRUtils::make_FunctionType_t_util(al, a_loc, a_args, n_args, a_return_var, ft->m_abi, ft->m_deftype, ft->m_bindc_name, ft->m_elemental, ft->m_pure, ft->m_module, ft->m_inline, ft->m_static, ft->m_restrictions, - ft->n_restrictions, ft->m_is_restriction); + ft->n_restrictions, ft->m_is_restriction, current_scope); } inline ASR::asr_t* make_Function_t_util(Allocator& al, const Location& loc, @@ -3357,7 +3371,7 @@ inline ASR::asr_t* make_Function_t_util(Allocator& al, const Location& loc, ASR::ttype_t* func_type = ASRUtils::TYPE(ASRUtils::make_FunctionType_t_util( al, loc, a_args, n_args, m_return_var, m_abi, m_deftype, m_bindc_name, m_elemental, m_pure, m_module, m_inline, m_static, - m_restrictions, n_restrictions, m_is_restriction)); + m_restrictions, n_restrictions, m_is_restriction, m_symtab)); return ASR::make_Function_t( al, loc, m_symtab, m_name, func_type, m_dependencies, n_dependencies, a_args, n_args, m_body, n_body, m_return_var, m_access, m_deterministic, diff --git a/src/libasr/pass/pass_array_by_data.cpp b/src/libasr/pass/pass_array_by_data.cpp index eb53480c8e..a9dbcdcaa1 100644 --- a/src/libasr/pass/pass_array_by_data.cpp +++ b/src/libasr/pass/pass_array_by_data.cpp @@ -199,7 +199,7 @@ class PassArrayByDataProcedureVisitor : public PassUtils::PassVisitorm_function_signature = ASRUtils::TYPE(ASRUtils::make_FunctionType_t_util( - al, func_type->base.base.loc, new_args.p, new_args.size(), x->m_return_var, func_type)); + al, func_type->base.base.loc, new_args.p, new_args.size(), x->m_return_var, func_type, current_scope)); x->m_args = new_args.p; x->n_args = new_args.size(); } diff --git a/src/libasr/pass/pass_utils.h b/src/libasr/pass/pass_utils.h index 27c7edea7e..714f78a355 100644 --- a/src/libasr/pass/pass_utils.h +++ b/src/libasr/pass/pass_utils.h @@ -744,7 +744,7 @@ namespace LCompilers { for(auto &e: a_args) { ASRUtils::ReplaceWithFunctionParamVisitor replacer(al, x->m_args, x->n_args); arg_types.push_back(al, replacer.replace_args_with_FunctionParam( - ASRUtils::expr_type(e))); + ASRUtils::expr_type(e), x->m_symtab)); } s_func_type->m_arg_types = arg_types.p; s_func_type->n_arg_types = arg_types.n;