Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Support struct initialization with named arguments #1813

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,8 @@ RUN(NAME structs_20 LABELS cpython llvm c
EXTRAFILES structs_20b.c)
RUN(NAME structs_21 LABELS cpython llvm c)
RUN(NAME structs_22 LABELS cpython llvm c)
RUN(NAME structs_23 LABELS cpython llvm c)
RUN(NAME structs_24 LABELS cpython llvm c)
RUN(NAME sizeof_01 LABELS llvm c
EXTRAFILES sizeof_01b.c)
RUN(NAME sizeof_02 LABELS cpython llvm c)
Expand Down
30 changes: 30 additions & 0 deletions integration_tests/structs_23.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from lpython import dataclass, i32, u64, f64

@dataclass
class A:
a: i32
b: i32

@dataclass
class B:
a: u64
b: f64

def main0():
s: A = A(b=-24, a=6)
print(s.a)
print(s.b)

assert s.a == 6
assert s.b == -24

def main1():
s: B = B(u64(22), b=3.14)
print(s.a)
print(s.b)

assert s.a == u64(22)
assert abs(s.b - 3.14) <= 1e-12

main0()
main1()
18 changes: 18 additions & 0 deletions integration_tests/structs_24.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from lpython import dataclass, i32, f64, u64
from numpy import array

@dataclass
class Foo:
x: i32
y: i32

def main0() -> None:
foos: Foo[2] = array([Foo(y=2, x=1), Foo(x=3, y=4)])
print(foos[0].x, foos[0].y, foos[1].x, foos[1].y)

assert foos[0].x == 1
assert foos[0].y == 2
assert foos[1].x == 3
assert foos[1].y == 4

main0()
100 changes: 85 additions & 15 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,63 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
return true;
}

int64_t find_argument_position_from_name(ASR::StructType_t* orig_struct, std::string arg_name) {
int64_t arg_position = -1;
for( size_t i = 0; i < orig_struct->n_members; i++ ) {
std::string original_arg_name = std::string(orig_struct->m_members[i]);
if( original_arg_name == arg_name ) {
return i;
}
}
return arg_position;
}

void visit_expr_list(AST::expr_t** pos_args, size_t n_pos_args,
AST::keyword_t* kwargs, size_t n_kwargs,
Vec<ASR::call_arg_t>& call_args_vec,
ASR::StructType_t* orig_struct, const Location &loc) {
LCOMPILERS_ASSERT(call_args_vec.reserve_called);

// Fill the whole call_args_vec with nullptr
// This is for error handling later on.
for( size_t i = 0; i < n_pos_args + n_kwargs; i++ ) {
ASR::call_arg_t call_arg;
Location loc;
loc.first = loc.last = 1;
call_arg.m_value = nullptr;
call_arg.loc = loc;
call_args_vec.push_back(al, call_arg);
}

// Now handle positional arguments in the following loop
for( size_t i = 0; i < n_pos_args; i++ ) {
this->visit_expr(*pos_args[i]);
ASR::expr_t* expr = ASRUtils::EXPR(tmp);
call_args_vec.p[i].loc = expr->base.loc;
call_args_vec.p[i].m_value = expr;
}

// Now handle keyword arguments in the following loop
for( size_t i = 0; i < n_kwargs; i++ ) {
this->visit_expr(*kwargs[i].m_value);
ASR::expr_t* expr = ASRUtils::EXPR(tmp);
std::string arg_name = std::string(kwargs[i].m_arg);
int64_t arg_pos = find_argument_position_from_name(orig_struct, arg_name);
if( arg_pos == -1 ) {
throw SemanticError("Member '" + arg_name + "' not found in struct", kwargs[i].loc);
} else if (arg_pos >= (int64_t)call_args_vec.size()) {
throw SemanticError("Not enough arguments to " + std::string(orig_struct->m_name)
+ "(), expected " + std::to_string(orig_struct->n_members), loc);
}
if( call_args_vec[arg_pos].m_value != nullptr ) {
throw SemanticError(std::string(orig_struct->m_name) + "() got multiple values for argument '"
+ arg_name + "'", kwargs[i].loc);
}
call_args_vec.p[arg_pos].loc = expr->base.loc;
call_args_vec.p[arg_pos].m_value = expr;
}
}

void visit_expr_list_with_cast(ASR::expr_t** m_args, size_t n_args,
Vec<ASR::call_arg_t>& call_args_vec,
Vec<ASR::call_arg_t>& args,
Expand Down Expand Up @@ -1195,7 +1252,17 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
}
} else if(ASR::is_a<ASR::StructType_t>(*s)) {
ASR::StructType_t* StructType = ASR::down_cast<ASR::StructType_t>(s);
for( size_t i = 0; i < std::min(args.size(), StructType->n_members); i++ ) {
if (n_kwargs > 0) {
args.reserve(al, n_pos_args + n_kwargs);
visit_expr_list(pos_args, n_pos_args, kwargs, n_kwargs,
args, StructType, loc);
}

if (args.size() > 0 && args.size() != StructType->n_members) {
throw SemanticError("StructConstructor arguments do not match the number of struct members", loc);
}

for( size_t i = 0; i < args.size(); i++ ) {
std::string member_name = StructType->m_members[i];
ASR::Variable_t* member_var = ASR::down_cast<ASR::Variable_t>(
StructType->m_symtab->resolve_symbol(member_name));
Expand Down Expand Up @@ -6599,6 +6666,14 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
tmp = ASR::make_StringConstant_t(al, loc, s2c(al, s_var), str_type);
}

void parse_args(const AST::Call_t &x, Vec<ASR::call_arg_t> &args) {
// Keyword arguments handled in make_call_helper()
if( x.n_keywords == 0 ) {
args.reserve(al, x.n_args);
visit_expr_list(x.m_args, x.n_args, args);
}
}

void visit_Call(const AST::Call_t &x) {
std::string call_name = "";
Vec<ASR::call_arg_t> args;
Expand All @@ -6612,14 +6687,9 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
tmp = nullptr;
return ;
}
// Keyword arguments handled in make_call_helper
#define parse_args() if( x.n_keywords == 0 ) { \
args.reserve(al, x.n_args); \
visit_expr_list(x.m_args, x.n_args, args); \
} \
Comment on lines -6615 to -6619
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, I was overlooking the first line if( x.n_keywords == 0 ) { \. Hence, I refactored this into a function, which I am hoping is more readable.


if (AST::is_a<AST::Attribute_t>(*x.m_func)) {
parse_args()
parse_args(x, args);
AST::Attribute_t *at = AST::down_cast<AST::Attribute_t>(x.m_func);
if (AST::is_a<AST::Name_t>(*at->m_value)) {
AST::Name_t *n = AST::down_cast<AST::Name_t>(at->m_value);
Expand Down Expand Up @@ -6788,7 +6858,7 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
// This will all be removed once we port it to intrinsic functions
// Intrinsic functions
if (call_name == "size") {
parse_args();
parse_args(x, args);;
if( args.size() < 1 || args.size() > 2 ) {
throw SemanticError("array accepts only 1 (arr) or 2 (arr, axis) arguments, got " +
std::to_string(args.size()) + " arguments instead.",
Expand Down Expand Up @@ -6820,7 +6890,7 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
tmp = nullptr;
return;
} else if (call_name == "callable") {
parse_args()
parse_args(x, args);
if (args.size() != 1) {
throw SemanticError(call_name + "() takes exactly one argument (" +
std::to_string(args.size()) + " given)", x.base.base.loc);
Expand All @@ -6836,13 +6906,13 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
tmp = ASR::make_LogicalConstant_t(al, x.base.base.loc, result, type);
return;
} else if( call_name == "pointer" ) {
parse_args()
parse_args(x, args);
ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_Pointer_t(al, x.base.base.loc,
ASRUtils::expr_type(args[0].m_value)));
tmp = ASR::make_GetPointer_t(al, x.base.base.loc, args[0].m_value, type, nullptr);
return ;
} else if( call_name == "array" ) {
parse_args()
parse_args(x, args);
if( args.size() != 1 ) {
throw SemanticError("array accepts only 1 argument for now, got " +
std::to_string(args.size()) + " arguments instead.",
Expand All @@ -6862,7 +6932,7 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
}
return;
} else if( call_name == "deepcopy" ) {
parse_args()
parse_args(x, args);
if( args.size() != 1 ) {
throw SemanticError("deepcopy only accepts one argument, found " +
std::to_string(args.size()) + " instead.",
Expand Down Expand Up @@ -6921,7 +6991,7 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
call_name == "c32" ||
call_name == "c64"
) {
parse_args()
parse_args(x, args);
ASR::ttype_t* target_type = nullptr;
if( call_name == "i8" ) {
target_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 1, nullptr, 0));
Expand Down Expand Up @@ -6953,7 +7023,7 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
tmp = (ASR::asr_t*) arg;
return ;
} else if (intrinsic_node_handler.is_present(call_name)) {
parse_args()
parse_args(x, args);
tmp = intrinsic_node_handler.get_intrinsic_node(call_name, al,
x.base.base.loc, args);
return;
Expand All @@ -6965,7 +7035,7 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
} // end of "comment"
}

parse_args()
parse_args(x, args);
tmp = make_call_helper(al, s, current_scope, args, call_name, x.base.base.loc,
false, x.m_args, x.n_args, x.m_keywords, x.n_keywords);
}
Expand Down
10 changes: 10 additions & 0 deletions tests/errors/structs_03.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from lpython import i32, dataclass

@dataclass
class S:
x: i32

def main0():
s: S = S(y=2)

main0()
11 changes: 11 additions & 0 deletions tests/errors/structs_04.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from lpython import i32, dataclass

@dataclass
class S:
x: i32
y: i32

def main0():
s: S = S(24, x=2)

main0()
11 changes: 11 additions & 0 deletions tests/errors/structs_05.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from lpython import i32, dataclass

@dataclass
class S:
x: i32
y: i32

def main0():
s: S = S(2)

main0()
11 changes: 11 additions & 0 deletions tests/errors/structs_06.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from lpython import i32, dataclass

@dataclass
class S:
x: i32
y: i32

def main0():
s: S = S(2, 3, 4, 5)

main0()
11 changes: 11 additions & 0 deletions tests/errors/structs_07.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from lpython import i32, dataclass

@dataclass
class S:
x: i32
y: i32

def main0():
s: S = S(y=2)

main0()
13 changes: 13 additions & 0 deletions tests/reference/asr-structs_03-754fb64.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"basename": "asr-structs_03-754fb64",
"cmd": "lpython --show-asr --no-color {infile} -o {outfile}",
"infile": "tests/errors/structs_03.py",
"infile_hash": "19180d0a7a22141e74e61452cc6cc185f1dd1c4f4315446450ce98db",
"outfile": null,
"outfile_hash": null,
"stdout": null,
"stdout_hash": null,
"stderr": "asr-structs_03-754fb64.stderr",
"stderr_hash": "c6410f9948863d922cb0a0cd36613c529ad45fdf556d393d36e2df07",
"returncode": 2
}
5 changes: 5 additions & 0 deletions tests/reference/asr-structs_03-754fb64.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
semantic error: Member 'y' not found in struct
--> tests/errors/structs_03.py:8:14
|
8 | s: S = S(y=2)
| ^^^
13 changes: 13 additions & 0 deletions tests/reference/asr-structs_04-7b864bc.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"basename": "asr-structs_04-7b864bc",
"cmd": "lpython --show-asr --no-color {infile} -o {outfile}",
"infile": "tests/errors/structs_04.py",
"infile_hash": "5951c49d2d7f143bbe3d67b982770ceb6d709939eb2d5ed544888f16",
"outfile": null,
"outfile_hash": null,
"stdout": null,
"stdout_hash": null,
"stderr": "asr-structs_04-7b864bc.stderr",
"stderr_hash": "e4e04a1a30ae38b6587c4c3ad12a7e83839c63938c025a3884f62ef8",
"returncode": 2
}
5 changes: 5 additions & 0 deletions tests/reference/asr-structs_04-7b864bc.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
semantic error: S() got multiple values for argument 'x'
--> tests/errors/structs_04.py:9:18
|
9 | s: S = S(24, x=2)
| ^^^
13 changes: 13 additions & 0 deletions tests/reference/asr-structs_05-a89315d.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"basename": "asr-structs_05-a89315d",
"cmd": "lpython --show-asr --no-color {infile} -o {outfile}",
"infile": "tests/errors/structs_05.py",
"infile_hash": "3b94e692a074b226736f068daf39c876f113277a73468bd21c01d3cc",
"outfile": null,
"outfile_hash": null,
"stdout": null,
"stdout_hash": null,
"stderr": "asr-structs_05-a89315d.stderr",
"stderr_hash": "227decb39171becb34a42cbdd93d96bcdd4d8c9dc5151706a74d7074",
"returncode": 2
}
5 changes: 5 additions & 0 deletions tests/reference/asr-structs_05-a89315d.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
semantic error: StructConstructor arguments do not match the number of struct members
--> tests/errors/structs_05.py:9:12
|
9 | s: S = S(2)
| ^^^^
13 changes: 13 additions & 0 deletions tests/reference/asr-structs_06-6e14537.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"basename": "asr-structs_06-6e14537",
"cmd": "lpython --show-asr --no-color {infile} -o {outfile}",
"infile": "tests/errors/structs_06.py",
"infile_hash": "9f4273c5fb4469837f65003255dcdca067c5c17735d0642757fd069c",
"outfile": null,
"outfile_hash": null,
"stdout": null,
"stdout_hash": null,
"stderr": "asr-structs_06-6e14537.stderr",
"stderr_hash": "21e94af3d6a631d4871d9bc2a86200c3c3c3b661964a079105721dde",
"returncode": 2
}
5 changes: 5 additions & 0 deletions tests/reference/asr-structs_06-6e14537.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
semantic error: StructConstructor arguments do not match the number of struct members
--> tests/errors/structs_06.py:9:12
|
9 | s: S = S(2, 3, 4, 5)
| ^^^^^^^^^^^^^
Loading