Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Support defining Union inside a dataclass #1362

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,8 @@ RUN(NAME structs_13 LABELS llvm c
EXTRAFILES structs_13b.c)
RUN(NAME structs_14 LABELS cpython llvm c)
RUN(NAME structs_15 LABELS cpython llvm c)
RUN(NAME structs_16 LABELS cpython llvm c)
RUN(NAME structs_17 LABELS cpython llvm c)
RUN(NAME sizeof_01 LABELS llvm c
EXTRAFILES sizeof_01b.c)
RUN(NAME enum_01 LABELS cpython llvm c)
Expand Down
19 changes: 19 additions & 0 deletions integration_tests/structs_16.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from ltypes import i32, i64, dataclass, union, Union

@dataclass
class A:
@union
class B(Union):
x: i32
y: i64
b: B
c: i32
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perfect! While at this, given the new StructStaticMember, can you also please add a test for:

@dataclass
class A:
    @dataclass
    class B(Union):
        x: i32
        y: i64
    b: B
    c: i32

Essentially a "struct in struct"? We will need it too sooner or later, so now is the time to make sure it works.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Please let me know if there is anything else to do here.


def test_ordering():
bd: A.B = A.B()
bd.x = 1
ad: A = A(bd, 2)
assert ad.b.x == 1
assert ad.c == 2

test_ordering()
31 changes: 31 additions & 0 deletions integration_tests/structs_17.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from ltypes import i32, f32, f64, dataclass

@dataclass
class B:
z: i32
@dataclass
class C:
cz: f32
bc: C

@dataclass
class A:
y: f32
x: i32
b: B


def f(a: A):
print(a.x)
print(a.y)
print(a.b.z)

def g():
x: A = A(f32(3.25), 3, B(71, B.C(f32(4.0))))
f(x)
assert x.x == 3
assert f64(x.y) == 3.25
assert x.b.z == 71
assert f64(x.b.bc.cz) == 4.0

g()
1 change: 1 addition & 0 deletions src/libasr/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ expr

| BitCast(expr source, expr mold, expr? size, ttype type, expr? value)
| StructInstanceMember(expr v, symbol m, ttype type, expr? value)
| StructStaticMember(expr v, symbol m, ttype type, expr? value)
| EnumMember(expr v, symbol m, ttype type, expr? value)
| UnionRef(expr v, symbol m, ttype type, expr? value)
| EnumName(expr v, ttype enum_type, ttype type, expr? value)
Expand Down
32 changes: 19 additions & 13 deletions src/libasr/asr_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ ASR::TranslationUnit_t* find_and_load_module(Allocator &al, const std::string &m
ASR::asr_t* getStructInstanceMember_t(Allocator& al, const Location& loc,
ASR::asr_t* v_var, ASR::symbol_t* member,
SymbolTable* current_scope) {
ASR::Variable_t* member_variable = ((ASR::Variable_t*)(&(member->base)));
ASR::Variable_t* member_variable = ASR::down_cast<ASR::Variable_t>(member);
ASR::ttype_t* member_type = member_variable->m_type;
switch( member_type->type ) {
case ASR::ttypeType::Struct: {
Expand All @@ -317,23 +317,27 @@ ASR::asr_t* getStructInstanceMember_t(Allocator& al, const Location& loc,
ASR::symbol_t* der_ext;
char* module_name = (char*)"~nullptr";
ASR::symbol_t* m_external = der->m_derived_type;
if( m_external->type == ASR::symbolType::ExternalSymbol ) {
ASR::ExternalSymbol_t* m_ext = (ASR::ExternalSymbol_t*)(&(m_external->base));
if( ASR::is_a<ASR::ExternalSymbol_t>(*m_external) ) {
ASR::ExternalSymbol_t* m_ext = ASR::down_cast<ASR::ExternalSymbol_t>(m_external);
m_external = m_ext->m_external;
module_name = m_ext->m_module_name;
} else if( ASR::is_a<ASR::StructType_t>(*m_external) ) {
ASR::symbol_t* asr_owner = ASRUtils::get_asr_owner(m_external);
if( ASR::is_a<ASR::StructType_t>(*asr_owner) ) {
module_name = ASRUtils::symbol_name(asr_owner);
}
}
Str mangled_name;
mangled_name.from_str(al, "1_" +
std::string mangled_name = current_scope->get_unique_name(
std::string(module_name) + "_" +
std::string(der_type_name));
char* mangled_name_char = mangled_name.c_str(al);
if( current_scope->get_symbol(mangled_name.str()) == nullptr ) {
char* mangled_name_char = s2c(al, mangled_name);
if( current_scope->get_symbol(mangled_name) == nullptr ) {
bool make_new_ext_sym = true;
ASR::symbol_t* der_tmp = nullptr;
if( current_scope->get_symbol(std::string(der_type_name)) != nullptr ) {
der_tmp = current_scope->get_symbol(std::string(der_type_name));
if( der_tmp->type == ASR::symbolType::ExternalSymbol ) {
ASR::ExternalSymbol_t* der_ext_tmp = (ASR::ExternalSymbol_t*)(&(der_tmp->base));
if( ASR::is_a<ASR::ExternalSymbol_t>(*der_tmp) ) {
ASR::ExternalSymbol_t* der_ext_tmp = ASR::down_cast<ASR::ExternalSymbol_t>(der_tmp);
if( der_ext_tmp->m_external == m_external ) {
make_new_ext_sym = false;
}
Expand All @@ -342,15 +346,17 @@ ASR::asr_t* getStructInstanceMember_t(Allocator& al, const Location& loc,
}
}
if( make_new_ext_sym ) {
der_ext = (ASR::symbol_t*)ASR::make_ExternalSymbol_t(al, loc, current_scope, mangled_name_char, m_external,
module_name, nullptr, 0, s2c(al, der_type_name), ASR::accessType::Public);
current_scope->add_symbol(mangled_name.str(), der_ext);
der_ext = ASR::down_cast<ASR::symbol_t>(ASR::make_ExternalSymbol_t(
al, loc, current_scope, mangled_name_char, m_external,
module_name, nullptr, 0, s2c(al, der_type_name),
ASR::accessType::Public));
current_scope->add_symbol(mangled_name, der_ext);
} else {
LFORTRAN_ASSERT(der_tmp != nullptr);
der_ext = der_tmp;
}
} else {
der_ext = current_scope->get_symbol(mangled_name.str());
der_ext = current_scope->get_symbol(mangled_name);
}
ASR::asr_t* der_new = ASR::make_Struct_t(al, loc, der_ext, der->m_dims, der->n_dims);
member_type = ASRUtils::TYPE(der_new);
Expand Down
9 changes: 9 additions & 0 deletions src/libasr/asr_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,15 @@ static inline ASR::Module_t *get_sym_module(const ASR::symbol_t *sym) {
return nullptr;
}

// Returns the ASR owner of the symbol
static inline ASR::symbol_t *get_asr_owner(const ASR::symbol_t *sym) {
const SymbolTable *s = symbol_parent_symtab(sym);
if( !ASR::is_a<ASR::symbol_t>(*s->asr_owner) ) {
return nullptr;
}
return ASR::down_cast<ASR::symbol_t>(s->asr_owner);
}

// Returns the Module_t the symbol is in, or nullptr if not in a module
// or no asr_owner yet
static inline ASR::Module_t *get_sym_module0(const ASR::symbol_t *sym) {
Expand Down
36 changes: 28 additions & 8 deletions src/libasr/asr_verify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,9 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
for (auto &a : x.m_symtab->get_scope()) {
this->visit_symbol(*a.second);
if( ASR::is_a<ASR::ClassProcedure_t>(*a.second) ||
ASR::is_a<ASR::GenericProcedure_t>(*a.second) ) {
ASR::is_a<ASR::GenericProcedure_t>(*a.second) ||
ASR::is_a<ASR::StructType_t>(*a.second) ||
ASR::is_a<ASR::UnionType_t>(*a.second) ) {
continue ;
}
ASR::ttype_t* var_type = ASRUtils::type_get_past_pointer(ASRUtils::symbol_type(a.second));
Expand All @@ -346,7 +348,8 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
ASR::symbol_t* sym = ASR::down_cast<ASR::Union_t>(var_type)->m_union_type;
aggregate_type_name = ASRUtils::symbol_name(sym);
}
if( aggregate_type_name ) {
if( aggregate_type_name &&
!current_symtab->get_symbol(std::string(aggregate_type_name)) ) {
struct_dependencies.push_back(std::string(aggregate_type_name));
require(present(x.m_dependencies, x.n_dependencies, std::string(aggregate_type_name)),
std::string(x.m_name) + " depends on " + std::string(aggregate_type_name)
Expand Down Expand Up @@ -481,17 +484,34 @@ class VerifyVisitor : public BaseWalkVisitor<VerifyVisitor>
require(std::string(x.m_original_name) == std::string(orig_name),
"ExternalSymbol::m_original_name must match external->m_name");
ASR::Module_t *m = ASRUtils::get_sym_module(x.m_external);
require(m,
"ExternalSymbol::m_external is not in a module");
require(std::string(x.m_module_name) == std::string(m->m_name),
ASR::StructType_t* sm = nullptr;
bool is_valid_owner = false;
is_valid_owner = m != nullptr;
std::string asr_owner_name = "";
if( !is_valid_owner ) {
ASR::symbol_t* asr_owner_sym = ASRUtils::get_asr_owner(x.m_external);
is_valid_owner = ASR::is_a<ASR::StructType_t>(*asr_owner_sym);
sm = ASR::down_cast<ASR::StructType_t>(asr_owner_sym);
asr_owner_name = sm->m_name;
} else {
asr_owner_name = m->m_name;
}
require(is_valid_owner,
"ExternalSymbol::m_external is not in a module or struct type");
require(std::string(x.m_module_name) == asr_owner_name,
"ExternalSymbol::m_module_name `" + std::string(x.m_module_name)
+ "` must match external's module name `" + std::string(m->m_name) + "`");
ASR::symbol_t *s = m->m_symtab->find_scoped_symbol(x.m_original_name, x.n_scope_names, x.m_scope_names);
+ "` must match external's module name `" + asr_owner_name + "`");
ASR::symbol_t *s = nullptr;
if( m ) {
s = m->m_symtab->find_scoped_symbol(x.m_original_name, x.n_scope_names, x.m_scope_names);
} else if( sm ) {
s = sm->m_symtab->resolve_symbol(std::string(x.m_original_name));
}
require(s != nullptr,
"ExternalSymbol::m_original_name ('"
+ std::string(x.m_original_name)
+ "') + scope_names not found in a module '"
+ std::string(m->m_name) + "'");
+ asr_owner_name + "'");
require(s == x.m_external,
"ExternalSymbol::m_name + scope_names found but not equal to m_external");
}
Expand Down
40 changes: 32 additions & 8 deletions src/libasr/codegen/asr_to_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ class ASRToCVisitor : public BaseCCPPVisitor<ASRToCVisitor>
void allocate_array_members_of_struct(ASR::StructType_t* der_type_t, std::string& sub,
std::string indent, std::string name) {
for( auto itr: der_type_t->m_symtab->get_scope() ) {
if( ASR::is_a<ASR::UnionType_t>(*itr.second) ||
ASR::is_a<ASR::StructType_t>(*itr.second) ) {
continue ;
}
ASR::ttype_t* mem_type = ASRUtils::symbol_type(itr.second);
if( ASRUtils::is_character(*mem_type) ) {
sub += indent + name + "->" + itr.first + " = (char*) malloc(40 * sizeof(char));\n";
Expand Down Expand Up @@ -418,7 +422,8 @@ class ASRToCVisitor : public BaseCCPPVisitor<ASRToCVisitor>
} else if (ASR::is_a<ASR::Union_t>(*v_m_type)) {
std::string indent(indentation_level*indentation_spaces, ' ');
ASR::Union_t *t = ASR::down_cast<ASR::Union_t>(v_m_type);
std::string der_type_name = ASRUtils::symbol_name(t->m_union_type);
std::string der_type_name = ASRUtils::symbol_name(
ASRUtils::symbol_get_past_external(t->m_union_type));
if( is_array ) {
bool is_fixed_size = true;
dims = convert_dims_c(t->n_dims, t->m_dims, v_m_type, is_fixed_size, true);
Expand Down Expand Up @@ -706,11 +711,25 @@ R"(
}

template <typename T>
void visit_AggregateTypeUtil(const T& x, std::string c_type_name) {
void visit_AggregateTypeUtil(const T& x, std::string c_type_name,
std::string& src_dest) {
std::string body = "";
int indendation_level_copy = indentation_level;
for( auto itr: x.m_symtab->get_scope() ) {
if( ASR::is_a<ASR::UnionType_t>(*itr.second) ) {
visit_AggregateTypeUtil(*ASR::down_cast<ASR::UnionType_t>(itr.second),
"union", src_dest);
} else if( ASR::is_a<ASR::StructType_t>(*itr.second) ) {
std::string struct_c_type_name = get_StructCTypeName(
*ASR::down_cast<ASR::StructType_t>(itr.second));
visit_AggregateTypeUtil(*ASR::down_cast<ASR::StructType_t>(itr.second),
struct_c_type_name, src_dest);
}
}
indentation_level = indendation_level_copy;
std::string indent(indentation_level*indentation_spaces, ' ');
indentation_level += 1;
std::string open_struct = indent + c_type_name + " " + std::string(x.m_name) + " {\n";
std::string body = "";
indent.push_back(' ');
for( size_t i = 0; i < x.n_members; i++ ) {
ASR::symbol_t* member = x.m_symtab->get_symbol(x.m_members[i]);
Expand All @@ -725,11 +744,10 @@ R"(
}
indentation_level -= 1;
std::string end_struct = "};\n\n";
array_types_decls += open_struct + body + end_struct;
src_dest += open_struct + body + end_struct;
}

void visit_StructType(const ASR::StructType_t& x) {
src = "";
std::string get_StructCTypeName(const ASR::StructType_t& x) {
std::string c_type_name = "struct";
if( x.m_is_packed ) {
std::string attr_args = "(packed";
Expand All @@ -745,12 +763,18 @@ R"(
attr_args += ")";
c_type_name += " __attribute__(" + attr_args + ")";
}
visit_AggregateTypeUtil(x, c_type_name);
return c_type_name;
}

void visit_StructType(const ASR::StructType_t& x) {
src = "";
std::string c_type_name = get_StructCTypeName(x);
visit_AggregateTypeUtil(x, c_type_name, array_types_decls);
src = "";
}

void visit_UnionType(const ASR::UnionType_t& x) {
visit_AggregateTypeUtil(x, "union");
visit_AggregateTypeUtil(x, "union", array_types_decls);
}

void visit_EnumType(const ASR::EnumType_t& x) {
Expand Down
11 changes: 9 additions & 2 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,10 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}
const std::map<std::string, ASR::symbol_t*>& scope = der_type->m_symtab->get_scope();
for( auto itr = scope.begin(); itr != scope.end(); itr++ ) {
if( ASR::is_a<ASR::UnionType_t>(*itr->second) ||
ASR::is_a<ASR::StructType_t>(*itr->second) ) {
continue ;
}
ASR::Variable_t* member = ASR::down_cast<ASR::Variable_t>(itr->second);
llvm::Type* llvm_mem_type = get_type_from_ttype_t_util(member->m_type, member->m_abi);
member_types.push_back(llvm_mem_type);
Expand Down Expand Up @@ -2665,7 +2669,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
std::string struct_type_name = struct_type_t->m_name;
for( auto item: struct_type_t->m_symtab->get_scope() ) {
if( ASR::is_a<ASR::ClassProcedure_t>(*item.second) ||
ASR::is_a<ASR::GenericProcedure_t>(*item.second) ) {
ASR::is_a<ASR::GenericProcedure_t>(*item.second) ||
ASR::is_a<ASR::UnionType_t>(*item.second) ||
ASR::is_a<ASR::StructType_t>(*item.second) ) {
continue ;
}
ASR::ttype_t* symbol_type = ASRUtils::symbol_type(item.second);
Expand Down Expand Up @@ -5171,7 +5177,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}
case ASR::ttypeType::Union: {
ASR::Union_t* der = ASR::down_cast<ASR::Union_t>(x->m_type);
ASR::UnionType_t* der_type = ASR::down_cast<ASR::UnionType_t>(der->m_union_type);
ASR::UnionType_t* der_type = ASR::down_cast<ASR::UnionType_t>(
ASRUtils::symbol_get_past_external(der->m_union_type));
der_type_name = std::string(der_type->m_name);
uint32_t h = get_hash((ASR::asr_t*)x);
if( llvm_symtab.find(h) != llvm_symtab.end() ) {
Expand Down
Loading