diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index e31e899223..103a954244 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -353,6 +353,8 @@ RUN(NAME structs_13 LABELS llvm c EXTRAFILES structs_13b.c) RUN(NAME structs_14 LABELS cpython llvm c) RUN(NAME structs_15 LABELS cpython llvm c) +RUN(NAME structs_16 LABELS cpython llvm c) +RUN(NAME structs_17 LABELS cpython llvm c) RUN(NAME sizeof_01 LABELS llvm c EXTRAFILES sizeof_01b.c) RUN(NAME enum_01 LABELS cpython llvm c) diff --git a/integration_tests/structs_16.py b/integration_tests/structs_16.py new file mode 100644 index 0000000000..1cb49e8e3f --- /dev/null +++ b/integration_tests/structs_16.py @@ -0,0 +1,19 @@ +from ltypes import i32, i64, dataclass, union, Union + +@dataclass +class A: + @union + class B(Union): + x: i32 + y: i64 + b: B + c: i32 + +def test_ordering(): + bd: A.B = A.B() + bd.x = 1 + ad: A = A(bd, 2) + assert ad.b.x == 1 + assert ad.c == 2 + +test_ordering() diff --git a/integration_tests/structs_17.py b/integration_tests/structs_17.py new file mode 100644 index 0000000000..10d9717451 --- /dev/null +++ b/integration_tests/structs_17.py @@ -0,0 +1,31 @@ +from ltypes import i32, f32, f64, dataclass + +@dataclass +class B: + z: i32 + @dataclass + class C: + cz: f32 + bc: C + +@dataclass +class A: + y: f32 + x: i32 + b: B + + +def f(a: A): + print(a.x) + print(a.y) + print(a.b.z) + +def g(): + x: A = A(f32(3.25), 3, B(71, B.C(f32(4.0)))) + f(x) + assert x.x == 3 + assert f64(x.y) == 3.25 + assert x.b.z == 71 + assert f64(x.b.bc.cz) == 4.0 + +g() diff --git a/src/libasr/ASR.asdl b/src/libasr/ASR.asdl index 634f5fa98d..99e7bb8c03 100644 --- a/src/libasr/ASR.asdl +++ b/src/libasr/ASR.asdl @@ -288,6 +288,7 @@ expr | BitCast(expr source, expr mold, expr? size, ttype type, expr? value) | StructInstanceMember(expr v, symbol m, ttype type, expr? value) + | StructStaticMember(expr v, symbol m, ttype type, expr? value) | EnumMember(expr v, symbol m, ttype type, expr? value) | UnionRef(expr v, symbol m, ttype type, expr? value) | EnumName(expr v, ttype enum_type, ttype type, expr? value) diff --git a/src/libasr/asr_utils.cpp b/src/libasr/asr_utils.cpp index 22f0a2ebea..d1b94a4b3d 100644 --- a/src/libasr/asr_utils.cpp +++ b/src/libasr/asr_utils.cpp @@ -306,7 +306,7 @@ ASR::TranslationUnit_t* find_and_load_module(Allocator &al, const std::string &m ASR::asr_t* getStructInstanceMember_t(Allocator& al, const Location& loc, ASR::asr_t* v_var, ASR::symbol_t* member, SymbolTable* current_scope) { - ASR::Variable_t* member_variable = ((ASR::Variable_t*)(&(member->base))); + ASR::Variable_t* member_variable = ASR::down_cast(member); ASR::ttype_t* member_type = member_variable->m_type; switch( member_type->type ) { case ASR::ttypeType::Struct: { @@ -317,23 +317,27 @@ ASR::asr_t* getStructInstanceMember_t(Allocator& al, const Location& loc, ASR::symbol_t* der_ext; char* module_name = (char*)"~nullptr"; ASR::symbol_t* m_external = der->m_derived_type; - if( m_external->type == ASR::symbolType::ExternalSymbol ) { - ASR::ExternalSymbol_t* m_ext = (ASR::ExternalSymbol_t*)(&(m_external->base)); + if( ASR::is_a(*m_external) ) { + ASR::ExternalSymbol_t* m_ext = ASR::down_cast(m_external); m_external = m_ext->m_external; module_name = m_ext->m_module_name; + } else if( ASR::is_a(*m_external) ) { + ASR::symbol_t* asr_owner = ASRUtils::get_asr_owner(m_external); + if( ASR::is_a(*asr_owner) ) { + module_name = ASRUtils::symbol_name(asr_owner); + } } - Str mangled_name; - mangled_name.from_str(al, "1_" + + std::string mangled_name = current_scope->get_unique_name( std::string(module_name) + "_" + std::string(der_type_name)); - char* mangled_name_char = mangled_name.c_str(al); - if( current_scope->get_symbol(mangled_name.str()) == nullptr ) { + char* mangled_name_char = s2c(al, mangled_name); + if( current_scope->get_symbol(mangled_name) == nullptr ) { bool make_new_ext_sym = true; ASR::symbol_t* der_tmp = nullptr; if( current_scope->get_symbol(std::string(der_type_name)) != nullptr ) { der_tmp = current_scope->get_symbol(std::string(der_type_name)); - if( der_tmp->type == ASR::symbolType::ExternalSymbol ) { - ASR::ExternalSymbol_t* der_ext_tmp = (ASR::ExternalSymbol_t*)(&(der_tmp->base)); + if( ASR::is_a(*der_tmp) ) { + ASR::ExternalSymbol_t* der_ext_tmp = ASR::down_cast(der_tmp); if( der_ext_tmp->m_external == m_external ) { make_new_ext_sym = false; } @@ -342,15 +346,17 @@ ASR::asr_t* getStructInstanceMember_t(Allocator& al, const Location& loc, } } if( make_new_ext_sym ) { - der_ext = (ASR::symbol_t*)ASR::make_ExternalSymbol_t(al, loc, current_scope, mangled_name_char, m_external, - module_name, nullptr, 0, s2c(al, der_type_name), ASR::accessType::Public); - current_scope->add_symbol(mangled_name.str(), der_ext); + der_ext = ASR::down_cast(ASR::make_ExternalSymbol_t( + al, loc, current_scope, mangled_name_char, m_external, + module_name, nullptr, 0, s2c(al, der_type_name), + ASR::accessType::Public)); + current_scope->add_symbol(mangled_name, der_ext); } else { LFORTRAN_ASSERT(der_tmp != nullptr); der_ext = der_tmp; } } else { - der_ext = current_scope->get_symbol(mangled_name.str()); + der_ext = current_scope->get_symbol(mangled_name); } ASR::asr_t* der_new = ASR::make_Struct_t(al, loc, der_ext, der->m_dims, der->n_dims); member_type = ASRUtils::TYPE(der_new); diff --git a/src/libasr/asr_utils.h b/src/libasr/asr_utils.h index be273e553b..c5653d8e46 100644 --- a/src/libasr/asr_utils.h +++ b/src/libasr/asr_utils.h @@ -451,6 +451,15 @@ static inline ASR::Module_t *get_sym_module(const ASR::symbol_t *sym) { return nullptr; } +// Returns the ASR owner of the symbol +static inline ASR::symbol_t *get_asr_owner(const ASR::symbol_t *sym) { + const SymbolTable *s = symbol_parent_symtab(sym); + if( !ASR::is_a(*s->asr_owner) ) { + return nullptr; + } + return ASR::down_cast(s->asr_owner); +} + // Returns the Module_t the symbol is in, or nullptr if not in a module // or no asr_owner yet static inline ASR::Module_t *get_sym_module0(const ASR::symbol_t *sym) { diff --git a/src/libasr/asr_verify.cpp b/src/libasr/asr_verify.cpp index 57e2366cd9..3fe01476ee 100644 --- a/src/libasr/asr_verify.cpp +++ b/src/libasr/asr_verify.cpp @@ -331,7 +331,9 @@ class VerifyVisitor : public BaseWalkVisitor for (auto &a : x.m_symtab->get_scope()) { this->visit_symbol(*a.second); if( ASR::is_a(*a.second) || - ASR::is_a(*a.second) ) { + ASR::is_a(*a.second) || + ASR::is_a(*a.second) || + ASR::is_a(*a.second) ) { continue ; } ASR::ttype_t* var_type = ASRUtils::type_get_past_pointer(ASRUtils::symbol_type(a.second)); @@ -346,7 +348,8 @@ class VerifyVisitor : public BaseWalkVisitor ASR::symbol_t* sym = ASR::down_cast(var_type)->m_union_type; aggregate_type_name = ASRUtils::symbol_name(sym); } - if( aggregate_type_name ) { + if( aggregate_type_name && + !current_symtab->get_symbol(std::string(aggregate_type_name)) ) { struct_dependencies.push_back(std::string(aggregate_type_name)); require(present(x.m_dependencies, x.n_dependencies, std::string(aggregate_type_name)), std::string(x.m_name) + " depends on " + std::string(aggregate_type_name) @@ -481,17 +484,34 @@ class VerifyVisitor : public BaseWalkVisitor require(std::string(x.m_original_name) == std::string(orig_name), "ExternalSymbol::m_original_name must match external->m_name"); ASR::Module_t *m = ASRUtils::get_sym_module(x.m_external); - require(m, - "ExternalSymbol::m_external is not in a module"); - require(std::string(x.m_module_name) == std::string(m->m_name), + ASR::StructType_t* sm = nullptr; + bool is_valid_owner = false; + is_valid_owner = m != nullptr; + std::string asr_owner_name = ""; + if( !is_valid_owner ) { + ASR::symbol_t* asr_owner_sym = ASRUtils::get_asr_owner(x.m_external); + is_valid_owner = ASR::is_a(*asr_owner_sym); + sm = ASR::down_cast(asr_owner_sym); + asr_owner_name = sm->m_name; + } else { + asr_owner_name = m->m_name; + } + require(is_valid_owner, + "ExternalSymbol::m_external is not in a module or struct type"); + require(std::string(x.m_module_name) == asr_owner_name, "ExternalSymbol::m_module_name `" + std::string(x.m_module_name) - + "` must match external's module name `" + std::string(m->m_name) + "`"); - ASR::symbol_t *s = m->m_symtab->find_scoped_symbol(x.m_original_name, x.n_scope_names, x.m_scope_names); + + "` must match external's module name `" + asr_owner_name + "`"); + ASR::symbol_t *s = nullptr; + if( m ) { + s = m->m_symtab->find_scoped_symbol(x.m_original_name, x.n_scope_names, x.m_scope_names); + } else if( sm ) { + s = sm->m_symtab->resolve_symbol(std::string(x.m_original_name)); + } require(s != nullptr, "ExternalSymbol::m_original_name ('" + std::string(x.m_original_name) + "') + scope_names not found in a module '" - + std::string(m->m_name) + "'"); + + asr_owner_name + "'"); require(s == x.m_external, "ExternalSymbol::m_name + scope_names found but not equal to m_external"); } diff --git a/src/libasr/codegen/asr_to_c.cpp b/src/libasr/codegen/asr_to_c.cpp index d30b913ae8..ecd754e56d 100644 --- a/src/libasr/codegen/asr_to_c.cpp +++ b/src/libasr/codegen/asr_to_c.cpp @@ -139,6 +139,10 @@ class ASRToCVisitor : public BaseCCPPVisitor void allocate_array_members_of_struct(ASR::StructType_t* der_type_t, std::string& sub, std::string indent, std::string name) { for( auto itr: der_type_t->m_symtab->get_scope() ) { + if( ASR::is_a(*itr.second) || + ASR::is_a(*itr.second) ) { + continue ; + } ASR::ttype_t* mem_type = ASRUtils::symbol_type(itr.second); if( ASRUtils::is_character(*mem_type) ) { sub += indent + name + "->" + itr.first + " = (char*) malloc(40 * sizeof(char));\n"; @@ -418,7 +422,8 @@ class ASRToCVisitor : public BaseCCPPVisitor } else if (ASR::is_a(*v_m_type)) { std::string indent(indentation_level*indentation_spaces, ' '); ASR::Union_t *t = ASR::down_cast(v_m_type); - std::string der_type_name = ASRUtils::symbol_name(t->m_union_type); + std::string der_type_name = ASRUtils::symbol_name( + ASRUtils::symbol_get_past_external(t->m_union_type)); if( is_array ) { bool is_fixed_size = true; dims = convert_dims_c(t->n_dims, t->m_dims, v_m_type, is_fixed_size, true); @@ -706,11 +711,25 @@ R"( } template - void visit_AggregateTypeUtil(const T& x, std::string c_type_name) { + void visit_AggregateTypeUtil(const T& x, std::string c_type_name, + std::string& src_dest) { + std::string body = ""; + int indendation_level_copy = indentation_level; + for( auto itr: x.m_symtab->get_scope() ) { + if( ASR::is_a(*itr.second) ) { + visit_AggregateTypeUtil(*ASR::down_cast(itr.second), + "union", src_dest); + } else if( ASR::is_a(*itr.second) ) { + std::string struct_c_type_name = get_StructCTypeName( + *ASR::down_cast(itr.second)); + visit_AggregateTypeUtil(*ASR::down_cast(itr.second), + struct_c_type_name, src_dest); + } + } + indentation_level = indendation_level_copy; std::string indent(indentation_level*indentation_spaces, ' '); indentation_level += 1; std::string open_struct = indent + c_type_name + " " + std::string(x.m_name) + " {\n"; - std::string body = ""; indent.push_back(' '); for( size_t i = 0; i < x.n_members; i++ ) { ASR::symbol_t* member = x.m_symtab->get_symbol(x.m_members[i]); @@ -725,11 +744,10 @@ R"( } indentation_level -= 1; std::string end_struct = "};\n\n"; - array_types_decls += open_struct + body + end_struct; + src_dest += open_struct + body + end_struct; } - void visit_StructType(const ASR::StructType_t& x) { - src = ""; + std::string get_StructCTypeName(const ASR::StructType_t& x) { std::string c_type_name = "struct"; if( x.m_is_packed ) { std::string attr_args = "(packed"; @@ -745,12 +763,18 @@ R"( attr_args += ")"; c_type_name += " __attribute__(" + attr_args + ")"; } - visit_AggregateTypeUtil(x, c_type_name); + return c_type_name; + } + + void visit_StructType(const ASR::StructType_t& x) { + src = ""; + std::string c_type_name = get_StructCTypeName(x); + visit_AggregateTypeUtil(x, c_type_name, array_types_decls); src = ""; } void visit_UnionType(const ASR::UnionType_t& x) { - visit_AggregateTypeUtil(x, "union"); + visit_AggregateTypeUtil(x, "union", array_types_decls); } void visit_EnumType(const ASR::EnumType_t& x) { diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 60421f5174..5822d32033 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -732,6 +732,10 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } const std::map& scope = der_type->m_symtab->get_scope(); for( auto itr = scope.begin(); itr != scope.end(); itr++ ) { + if( ASR::is_a(*itr->second) || + ASR::is_a(*itr->second) ) { + continue ; + } ASR::Variable_t* member = ASR::down_cast(itr->second); llvm::Type* llvm_mem_type = get_type_from_ttype_t_util(member->m_type, member->m_abi); member_types.push_back(llvm_mem_type); @@ -2665,7 +2669,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor std::string struct_type_name = struct_type_t->m_name; for( auto item: struct_type_t->m_symtab->get_scope() ) { if( ASR::is_a(*item.second) || - ASR::is_a(*item.second) ) { + ASR::is_a(*item.second) || + ASR::is_a(*item.second) || + ASR::is_a(*item.second) ) { continue ; } ASR::ttype_t* symbol_type = ASRUtils::symbol_type(item.second); @@ -5171,7 +5177,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } case ASR::ttypeType::Union: { ASR::Union_t* der = ASR::down_cast(x->m_type); - ASR::UnionType_t* der_type = ASR::down_cast(der->m_union_type); + ASR::UnionType_t* der_type = ASR::down_cast( + ASRUtils::symbol_get_past_external(der->m_union_type)); der_type_name = std::string(der_type->m_name); uint32_t h = get_hash((ASR::asr_t*)x); if( llvm_symtab.find(h) != llvm_symtab.end() ) { diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index f6ab880f4e..95a56dff7c 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -1163,8 +1163,8 @@ class CommonVisitor : public AST::BaseVisitor { } args_new.p[i] = arg_new_i; } - ASR::ttype_t* der_type = ASRUtils::TYPE(ASR::make_Struct_t(al, loc, s, nullptr, 0)); - return ASR::make_StructTypeConstructor_t(al, loc, s, args_new.p, args_new.size(), der_type, nullptr); + ASR::ttype_t* der_type = ASRUtils::TYPE(ASR::make_Struct_t(al, loc, stemp, nullptr, 0)); + return ASR::make_StructTypeConstructor_t(al, loc, stemp, args_new.p, args_new.size(), der_type, nullptr); } else if( ASR::is_a(*s) ) { Vec args_new; args_new.reserve(al, args.size()); @@ -1179,14 +1179,14 @@ class CommonVisitor : public AST::BaseVisitor { args_new.p[i] = arg_new_i; } ASR::ttype_t* der_type = ASRUtils::TYPE(ASR::make_Enum_t(al, loc, s, nullptr, 0)); - return ASR::make_EnumTypeConstructor_t(al, loc, s, args_new.p, args_new.size(), der_type, nullptr); + return ASR::make_EnumTypeConstructor_t(al, loc, stemp, args_new.p, args_new.size(), der_type, nullptr); } else if( ASR::is_a(*s) ) { if( args.size() != 0 ) { throw SemanticError("Union constructors do not accept any argument as of now.", loc); } - ASR::ttype_t* union_ = ASRUtils::TYPE(ASR::make_Union_t(al, loc, s, nullptr, 0)); - return ASR::make_UnionTypeConstructor_t(al, loc, s, nullptr, 0, union_, nullptr); + ASR::ttype_t* union_ = ASRUtils::TYPE(ASR::make_Union_t(al, loc, stemp, nullptr, 0)); + return ASR::make_UnionTypeConstructor_t(al, loc, stemp, nullptr, 0, union_, nullptr); } else { throw SemanticError("Unsupported call type for " + call_name, loc); } @@ -1543,6 +1543,47 @@ class CommonVisitor : public AST::BaseVisitor { fill_dims_for_asr_type(dims, value, loc); } } + } else if (AST::is_a(annotation)) { + AST::Attribute_t* attr_annotation = AST::down_cast(&annotation); + LFORTRAN_ASSERT(AST::is_a(*attr_annotation->m_value)); + std::string value = AST::down_cast(attr_annotation->m_value)->m_id; + ASR::symbol_t *t = current_scope->resolve_symbol(value); + + if (!t) { + throw SemanticError("'" + value + "' is not defined in the scope", + attr_annotation->base.base.loc); + } + LFORTRAN_ASSERT(ASR::is_a(*t)); + ASR::StructType_t* struct_type = ASR::down_cast(t); + std::string struct_var_name = struct_type->m_name; + std::string struct_member_name = attr_annotation->m_attr; + ASR::symbol_t* struct_member = struct_type->m_symtab->resolve_symbol(struct_member_name); + if( !struct_member ) { + throw SemanticError(struct_member_name + " not present in " + + struct_var_name + " dataclass.", + attr_annotation->base.base.loc); + } + std::string import_name = struct_var_name + "_" + struct_member_name; + ASR::symbol_t* import_struct_member = current_scope->resolve_symbol(import_name); + bool import_from_struct = true; + if( import_struct_member ) { + if( ASR::is_a(*import_struct_member) ) { + ASR::ExternalSymbol_t* ext_sym = ASR::down_cast(import_struct_member); + if( ext_sym->m_external == struct_member && + std::string(ext_sym->m_module_name) == struct_var_name ) { + import_from_struct = false; + } + } + } + if( import_from_struct ) { + import_name = current_scope->get_unique_name(import_name); + import_struct_member = ASR::down_cast(ASR::make_ExternalSymbol_t(al, + attr_annotation->base.base.loc, current_scope, s2c(al, import_name), + struct_member,s2c(al, struct_var_name), nullptr, 0, + s2c(al, struct_member_name), ASR::accessType::Public)); + current_scope->add_symbol(import_name, import_struct_member); + } + return ASRUtils::TYPE(ASR::make_Union_t(al, attr_annotation->base.base.loc, import_struct_member, nullptr, 0)); } else { throw SemanticError("Only Name, Subscript, and Call supported for now in annotation of annotated assignment.", loc); @@ -2330,6 +2371,10 @@ class CommonVisitor : public AST::BaseVisitor { bool is_enum_scope=false, ASR::abiType abi=ASR::abiType::Source) { int64_t prev_value = 1; for( size_t i = 0; i < x.n_body; i++ ) { + if( AST::is_a(*x.m_body[i]) ) { + visit_ClassDef(*AST::down_cast(x.m_body[i])); + continue; + } LFORTRAN_ASSERT(AST::is_a(*x.m_body[i])); AST::AnnAssign_t* ann_assign = AST::down_cast(x.m_body[i]); LFORTRAN_ASSERT(AST::is_a(*ann_assign->m_target)); @@ -2382,7 +2427,8 @@ class CommonVisitor : public AST::BaseVisitor { aggregate_type_name = ASRUtils::symbol_name( ASR::down_cast(var_type)->m_union_type); } - if( aggregate_type_name ) { + if( aggregate_type_name && + !current_scope->get_symbol(std::string(aggregate_type_name)) ) { struct_dependencies.push_back(al, aggregate_type_name); } member_names.push_back(al, n->m_id); @@ -4529,8 +4575,41 @@ class BodyVisitor : public CommonVisitor { ASR::symbol_t* member_sym = der_type->m_symtab->resolve_symbol(member_name); LFORTRAN_ASSERT(ASR::is_a(*member_sym)); ASR::Variable_t* member_var = ASR::down_cast(member_sym); + ASR::ttype_t* member_var_type = member_var->m_type; + if( ASR::is_a(*member_var->m_type) ) { + ASR::Struct_t* member_var_struct_t = ASR::down_cast(member_var->m_type); + if( !ASR::is_a(*member_var_struct_t->m_derived_type) ) { + ASR::StructType_t* struct_type = ASR::down_cast(member_var_struct_t->m_derived_type); + ASR::symbol_t* struct_type_asr_owner = ASRUtils::get_asr_owner(member_var_struct_t->m_derived_type); + if( struct_type_asr_owner && ASR::is_a(*struct_type_asr_owner) ) { + std::string struct_var_name = ASR::down_cast(struct_type_asr_owner)->m_name; + std::string struct_member_name = struct_type->m_name; + std::string import_name = struct_var_name + "_" + struct_member_name; + ASR::symbol_t* import_struct_member = current_scope->resolve_symbol(import_name); + bool import_from_struct = true; + if( import_struct_member ) { + if( ASR::is_a(*import_struct_member) ) { + ASR::ExternalSymbol_t* ext_sym = ASR::down_cast(import_struct_member); + if( ext_sym->m_external == member_var_struct_t->m_derived_type && + std::string(ext_sym->m_module_name) == struct_var_name ) { + import_from_struct = false; + } + } + } + if( import_from_struct ) { + import_name = current_scope->get_unique_name(import_name); + import_struct_member = ASR::down_cast(ASR::make_ExternalSymbol_t(al, + loc, current_scope, s2c(al, import_name), + member_var_struct_t->m_derived_type, s2c(al, struct_var_name), nullptr, 0, + s2c(al, struct_member_name), ASR::accessType::Public)); + current_scope->add_symbol(import_name, import_struct_member); + } + member_var_type = ASRUtils::TYPE(ASR::make_Struct_t(al, loc, import_struct_member, nullptr, 0)); + } + } + } tmp = ASR::make_StructInstanceMember_t(al, loc, e, member_sym, - member_var->m_type, nullptr); + member_var_type, nullptr); } else if(ASR::is_a(*type)) { if( std::string(attr_char) == "value" ) { ASR::Enum_t* enum_ = ASR::down_cast(type); @@ -4628,8 +4707,41 @@ class BodyVisitor : public CommonVisitor { ASR::symbol_t* member_sym = der_type->m_symtab->resolve_symbol(member_name); LFORTRAN_ASSERT(ASR::is_a(*member_sym)); ASR::Variable_t* member_var = ASR::down_cast(member_sym); + ASR::ttype_t* member_var_type = member_var->m_type; + if( ASR::is_a(*member_var->m_type) ) { + ASR::Struct_t* member_var_struct_t = ASR::down_cast(member_var->m_type); + if( !ASR::is_a(*member_var_struct_t->m_derived_type) ) { + ASR::StructType_t* struct_type = ASR::down_cast(member_var_struct_t->m_derived_type); + ASR::symbol_t* struct_type_asr_owner = ASRUtils::get_asr_owner(member_var_struct_t->m_derived_type); + if( struct_type_asr_owner && ASR::is_a(*struct_type_asr_owner) ) { + std::string struct_var_name = ASR::down_cast(struct_type_asr_owner)->m_name; + std::string struct_member_name = struct_type->m_name; + std::string import_name = struct_var_name + "_" + struct_member_name; + ASR::symbol_t* import_struct_member = current_scope->resolve_symbol(import_name); + bool import_from_struct = true; + if( import_struct_member ) { + if( ASR::is_a(*import_struct_member) ) { + ASR::ExternalSymbol_t* ext_sym = ASR::down_cast(import_struct_member); + if( ext_sym->m_external == member_var_struct_t->m_derived_type && + std::string(ext_sym->m_module_name) == struct_var_name ) { + import_from_struct = false; + } + } + } + if( import_from_struct ) { + import_name = current_scope->get_unique_name(import_name); + import_struct_member = ASR::down_cast(ASR::make_ExternalSymbol_t(al, + loc, current_scope, s2c(al, import_name), + member_var_struct_t->m_derived_type, s2c(al, struct_var_name), nullptr, 0, + s2c(al, struct_member_name), ASR::accessType::Public)); + current_scope->add_symbol(import_name, import_struct_member); + } + member_var_type = ASRUtils::TYPE(ASR::make_Struct_t(al, loc, import_struct_member, nullptr, 0)); + } + } + } tmp = ASR::make_StructInstanceMember_t(al, loc, val, member_sym, - member_var->m_type, nullptr); + member_var_type, nullptr); } else if (ASR::is_a(*type)) { ASR::Enum_t* enum_ = ASR::down_cast(type); ASR::EnumType_t* enum_type = ASR::down_cast(enum_->m_enum_type); @@ -4740,6 +4852,25 @@ class BodyVisitor : public CommonVisitor { tmp = ASR::make_EnumValue_t(al, x.base.base.loc, enum_member_var, enum_t, enum_member_variable->m_type, ASRUtils::expr_value(enum_member_variable->m_symbolic_value)); + } else if (ASR::is_a(*t)) { + ASR::StructType_t* struct_type = ASR::down_cast(t); + ASR::symbol_t* struct_member = struct_type->m_symtab->resolve_symbol(std::string(x.m_attr)); + if( !struct_member ) { + throw SemanticError(std::string(x.m_attr) + " not present in " + + std::string(struct_type->m_name) + " dataclass.", + x.base.base.loc); + } + if( ASR::is_a(*struct_member) ) { + ASR::Variable_t* struct_member_variable = ASR::down_cast(struct_member); + ASR::expr_t* struct_type_var = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, t)); + tmp = ASR::make_StructStaticMember_t(al, x.base.base.loc, + struct_type_var, struct_member, struct_member_variable->m_type, + nullptr); + } else if( ASR::is_a(*struct_member) ) { + ASR::expr_t* struct_type_var = ASRUtils::EXPR(ASR::make_Var_t(al, x.base.base.loc, t)); + ASR::ttype_t* union_type = ASRUtils::TYPE(ASR::make_Union_t(al, x.base.base.loc, struct_member, nullptr, 0)); + tmp = ASR::make_StructStaticMember_t(al, x.base.base.loc, struct_type_var, struct_member, union_type, nullptr); + } } else if (ASR::is_a(*t)) { ASR::Module_t *m = ASR::down_cast(t); std::string sym_name = value + "@" + x.m_attr; @@ -5747,11 +5878,44 @@ class BodyVisitor : public CommonVisitor { x.base.base.loc); } ASR::symbol_t *mt = symtab->get_symbol(mod_name); - ASR::Module_t *m = ASR::down_cast(mt); - call_name_store = ASRUtils::get_mangled_name(m, call_name_store); - st = import_from_module(al, m, current_scope, mod_name, - call_name, call_name_store, x.base.base.loc); - current_scope->add_symbol(call_name_store, st); + if( ASR::is_a(*mt) ) { + ASR::Module_t *m = ASR::down_cast(mt); + call_name_store = ASRUtils::get_mangled_name(m, call_name_store); + st = import_from_module(al, m, current_scope, mod_name, + call_name, call_name_store, x.base.base.loc); + current_scope->add_symbol(call_name_store, st); + } else if( ASR::is_a(*mt) ) { + ASR::StructType_t* struct_type = ASR::down_cast(mt); + std::string struct_var_name = struct_type->m_name; + std::string struct_member_name = call_name; + ASR::symbol_t* struct_member = struct_type->m_symtab->resolve_symbol(struct_member_name); + if( !struct_member ) { + throw SemanticError(struct_member_name + " not present in " + + struct_var_name + " dataclass.", + x.base.base.loc); + } + std::string import_name = struct_var_name + "_" + struct_member_name; + ASR::symbol_t* import_struct_member = current_scope->resolve_symbol(import_name); + bool import_from_struct = true; + if( import_struct_member ) { + if( ASR::is_a(*import_struct_member) ) { + ASR::ExternalSymbol_t* ext_sym = ASR::down_cast(import_struct_member); + if( ext_sym->m_external == struct_member && + std::string(ext_sym->m_module_name) == struct_var_name ) { + import_from_struct = false; + } + } + } + if( import_from_struct ) { + import_name = current_scope->get_unique_name(import_name); + import_struct_member = ASR::down_cast(ASR::make_ExternalSymbol_t(al, + x.base.base.loc, current_scope, s2c(al, import_name), + struct_member, s2c(al, struct_var_name), nullptr, 0, + s2c(al, struct_member_name), ASR::accessType::Public)); + current_scope->add_symbol(import_name, import_struct_member); + } + st = import_struct_member; + } } tmp = make_call_helper(al, st, current_scope, args, call_name, x.base.base.loc); return;