diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index 580b9bbb9f..04be6abf67 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -685,6 +685,7 @@ RUN(NAME structs_31 LABELS cpython llvm c) RUN(NAME structs_32 LABELS cpython llvm c) RUN(NAME structs_33 LABELS cpython llvm c) RUN(NAME structs_34 LABELS cpython llvm c) +RUN(NAME structs_35 LABELS cpython llvm) RUN(NAME symbolics_01 LABELS cpython_sym c_sym llvm_sym NOFAST) RUN(NAME symbolics_02 LABELS cpython_sym c_sym llvm_sym NOFAST) diff --git a/integration_tests/structs_35.py b/integration_tests/structs_35.py new file mode 100644 index 0000000000..4bdb499d75 --- /dev/null +++ b/integration_tests/structs_35.py @@ -0,0 +1,26 @@ +from lpython import dataclass, field, i32 +from numpy import array + +@dataclass +class X: + a: i32 = 123 + b: bool = True + c: list[i32] = field(default_factory=lambda: [1, 2, 3]) + d: i32[3] = field(default_factory=lambda: array([4, 5, 6])) + e: i32 = field(default=-5) + +def main0(): + x: X = X() + print(x) + assert x.a == 123 + assert x.b == True + assert x.c[0] == 1 + assert x.d[1] == 5 + assert x.e == -5 + x.c[0] = 3 + x.d[0] = 3 + print(x) + assert x.c[0] == 3 + assert x.d[0] == 3 + +main0() diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 0ce768f9d2..c7376cb0dc 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -7802,6 +7802,35 @@ class BodyVisitor : public CommonVisitor { tmp = ASR::make_SizeOfType_t(al, x.base.base.loc, arg_type, size_type, nullptr); return ; + } else if( call_name == "field" ) { + if (x.n_args != 0) { + throw SemanticError("'field' expects only keyword arguments", x.base.base.loc); + } + + if (x.n_keywords != 1) { + throw SemanticError("'field' expects one keyword argument", x.base.base.loc); + } + + args.reserve(al, 1); + visit_expr_list(x.m_args, x.n_args, args); + + if( std::string(x.m_keywords[0].m_arg) != "default_factory" && std::string(x.m_keywords[0].m_arg) != "default" ) { + throw SemanticError("Unrecognised keyword argument, " + + std::string(x.m_keywords[0].m_arg), x.base.base.loc); + } + + if ( std::string(x.m_keywords[0].m_arg) == "default_factory") { + if (!AST::is_a(*x.m_keywords[0].m_value)) { + throw SemanticError("Only lambda functions currently supported as default_factory value", x.base.base.loc); + } + + AST::Lambda_t* lambda_fn = AST::down_cast(x.m_keywords[0].m_value); + this->visit_expr(*lambda_fn->m_body); + } else { + // field has default argument provided + this->visit_expr(*x.m_keywords[0].m_value); + } + return ; } else if( call_name == "f64" || call_name == "f32" || diff --git a/src/runtime/lpython/lpython.py b/src/runtime/lpython/lpython.py index 3ac9811c8c..028c731266 100644 --- a/src/runtime/lpython/lpython.py +++ b/src/runtime/lpython/lpython.py @@ -2,7 +2,7 @@ import os import ctypes import platform -from dataclasses import dataclass as py_dataclass, is_dataclass as py_is_dataclass +from dataclasses import dataclass as py_dataclass, is_dataclass as py_is_dataclass, field import functools @@ -11,7 +11,7 @@ "overload", "ccall", "TypeVar", "pointer", "c_p_pointer", "Pointer", "p_c_pointer", "vectorize", "inline", "Union", "static", "packed", "Const", "sizeof", "ccallable", "ccallback", "Callable", - "Allocatable", "In", "Out", "InOut", "dataclass", "S"] + "Allocatable", "In", "Out", "InOut", "dataclass", "field", "S"] # data-types