diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index cef7ab8bf1..d681b3037f 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -446,6 +446,7 @@ RUN(NAME expr_18 FAIL LABELS cpython llvm c) RUN(NAME expr_19 LABELS cpython llvm c) RUN(NAME expr_20 LABELS cpython llvm c) RUN(NAME expr_21 LABELS cpython llvm c) +RUN(NAME expr_22 LABELS cpython llvm c) RUN(NAME expr_01u LABELS cpython llvm c NOFAST) RUN(NAME expr_02u LABELS cpython llvm c NOFAST) @@ -640,6 +641,7 @@ RUN(NAME structs_31 LABELS cpython llvm c) RUN(NAME structs_32 LABELS cpython llvm c) RUN(NAME structs_33 LABELS cpython llvm c) RUN(NAME structs_34 LABELS cpython llvm c) +RUN(NAME structs_35 LABELS cpython llvm c) RUN(NAME symbolics_01 LABELS cpython_sym c_sym) RUN(NAME symbolics_02 LABELS cpython_sym c_sym) diff --git a/integration_tests/expr_22.py b/integration_tests/expr_22.py new file mode 100644 index 0000000000..d59fee6941 --- /dev/null +++ b/integration_tests/expr_22.py @@ -0,0 +1,22 @@ +from lpython import i32 + +def f(): + x: i32 = 2 + y: i32 = 1 + z: i32 = 1 + t: i32 = 1 + assert x > y == z + assert not (x == y == z) + assert y == z == t != x + assert x > y == z >= t + t = 0 + assert x > y == z >= t + t = 4 + assert not (x > y == z >= t) + assert t > x > y == z + assert 3 > 2 >= 0 <= 6 + assert t > y < x + assert not (2 == 3 > 4) + + +f() diff --git a/integration_tests/structs_35.py b/integration_tests/structs_35.py new file mode 100644 index 0000000000..ed1ce302a2 --- /dev/null +++ b/integration_tests/structs_35.py @@ -0,0 +1,32 @@ +from lpython import (i8, i32, i64, f32, f64, + dataclass + ) +from numpy import (empty, + int8, + ) + +# test issue 2131 + +@dataclass +class Foo: + a : i8[4] = empty(4, dtype=int8) + dim : i32 = 4 + +def trinary_majority(x : Foo, y : Foo, z : Foo) -> Foo: + foo : Foo = Foo() + + assert foo.dim == x.dim == y.dim == z.dim + + return foo + + +t1 : Foo = Foo() +t1.a = empty(4, dtype=int8) + +t2 : Foo = Foo() +t2.a = empty(4, dtype=int8) + +t3 : Foo = Foo() +t3.a = empty(4, dtype=int8) + +r1 : Foo = trinary_majority(t1, t2, t3) diff --git a/src/libasr/codegen/asr_to_c.cpp b/src/libasr/codegen/asr_to_c.cpp index f25d565c46..89b2c95d34 100644 --- a/src/libasr/codegen/asr_to_c.cpp +++ b/src/libasr/codegen/asr_to_c.cpp @@ -583,6 +583,8 @@ class ASRToCVisitor : public BaseCCPPVisitor std::string unit_src = ""; indentation_level = 0; indentation_spaces = 4; + SymbolTable* current_scope_copy = current_scope; + current_scope = global_scope; c_ds_api->set_indentation(indentation_level, indentation_spaces); c_ds_api->set_global_scope(global_scope); c_utils_functions->set_indentation(indentation_level, indentation_spaces); @@ -760,6 +762,7 @@ R"( out_file.close(); } } + current_scope = current_scope_copy; } void visit_Module(const ASR::Module_t &x) { @@ -768,7 +771,8 @@ R"( } else { intrinsic_module = false; } - + SymbolTable *current_scope_copy = current_scope; + current_scope = x.m_symtab; std::string unit_src = ""; for (auto &item : x.m_symtab->get_scope()) { if (ASR::is_a(*item.second)) { @@ -813,13 +817,15 @@ R"( } src = unit_src; intrinsic_module = false; + current_scope = current_scope_copy; } void visit_Program(const ASR::Program_t &x) { // Topologically sort all program functions // and then define them in the right order std::vector func_order = ASRUtils::determine_function_definition_order(x.m_symtab); - + SymbolTable *current_scope_copy = current_scope; + current_scope = x.m_symtab; // Generate code for nested subroutines and functions first: std::string contains; for (auto &item : func_order) { @@ -898,6 +904,7 @@ R"( // Initialise Numpy + decl + body + indent1 + "return 0;\n}\n"; indentation_level -= 2; + current_scope = current_scope_copy; } template diff --git a/src/libasr/codegen/asr_to_c_cpp.h b/src/libasr/codegen/asr_to_c_cpp.h index ef718db8f6..022e1fde48 100644 --- a/src/libasr/codegen/asr_to_c_cpp.h +++ b/src/libasr/codegen/asr_to_c_cpp.h @@ -193,6 +193,8 @@ class BaseCCPPVisitor : public ASR::BaseVisitor void visit_TranslationUnit(const ASR::TranslationUnit_t &x) { global_scope = x.m_global_scope; + SymbolTable* current_scope_copy = current_scope; + current_scope = global_scope; // All loose statements must be converted to a function, so the items // must be empty: LCOMPILERS_ASSERT(x.n_items == 0); @@ -255,6 +257,7 @@ R"(#include } src = unit_src; + current_scope = current_scope_copy; } std::string check_tmp_buffer() { diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index ac31a64003..f96e9213cb 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -6053,54 +6053,71 @@ class BodyVisitor : public CommonVisitor { body.size()); } - void visit_Compare(const AST::Compare_t &x) { - this->visit_expr(*x.m_left); + void compare_helper(const Location &loc, AST::expr_t *m_left, AST::expr_t *m_right, ASR::cmpopType asr_op) { + this->visit_expr(*m_left); ASR::expr_t *left = ASRUtils::EXPR(tmp); - if (x.n_comparators > 1) { - diag.add(diag::Diagnostic( - "Only one comparison operator is supported for now", - diag::Level::Error, diag::Stage::Semantic, { - diag::Label("multiple comparison operators", - {x.m_comparators[0]->base.loc}) - }) - ); - throw SemanticAbort(); - } - this->visit_expr(*x.m_comparators[0]); + this->visit_expr(*m_right); ASR::expr_t *right = ASRUtils::EXPR(tmp); - - ASR::cmpopType asr_op; - switch (x.m_ops) { - case (AST::cmpopType::Eq): { asr_op = ASR::cmpopType::Eq; break; } - case (AST::cmpopType::Gt): { asr_op = ASR::cmpopType::Gt; break; } - case (AST::cmpopType::GtE): { asr_op = ASR::cmpopType::GtE; break; } - case (AST::cmpopType::Lt): { asr_op = ASR::cmpopType::Lt; break; } - case (AST::cmpopType::LtE): { asr_op = ASR::cmpopType::LtE; break; } - case (AST::cmpopType::NotEq): { asr_op = ASR::cmpopType::NotEq; break; } - default: { - throw SemanticError("Comparison operator not implemented", - x.base.base.loc); - } - } - + ASR::ttype_t *type = ASRUtils::TYPE( + ASR::make_Logical_t(al, loc, 4)); + ASR::expr_t *value = nullptr; ASR::ttype_t *left_type = ASRUtils::expr_type(left); ASR::ttype_t *right_type = ASRUtils::expr_type(right); + ASR::ttype_t *dest_type = left_type; + ASR::expr_t *overloaded = nullptr; if( ASR::is_a(*left_type) ) { left_type = ASRUtils::get_contained_type(left_type); } if( ASR::is_a(*right_type) ) { right_type = ASRUtils::get_contained_type(right_type); } - ASR::expr_t *overloaded = nullptr; if (!ASRUtils::is_logical(*left_type) || !ASRUtils::is_logical(*right_type)) { cast_helper(left, right, false); } - left_type = ASRUtils::expr_type(left); - right_type = ASRUtils::expr_type(right); - ASR::ttype_t *dest_type = left_type; if (!ASRUtils::check_equal_type(left_type, right_type)) { + if (AST::is_a(*m_left)) { + // handle chained comparisons + LCOMPILERS_ASSERT(ASRUtils::is_logical(*left_type)); + AST::Compare_t *lc = AST::down_cast(m_left); + compare_helper(loc, lc->m_comparators[0], m_right, asr_op); + right = ASRUtils::EXPR(tmp); + right_type = ASRUtils::expr_type(right); + LCOMPILERS_ASSERT(ASRUtils::is_logical(*right_type)); + if (ASRUtils::expr_value(left) != nullptr && ASRUtils::expr_value(right) != nullptr) { + bool left_value = ASR::down_cast( + ASRUtils::expr_value(left))->m_value; + bool right_value = ASR::down_cast( + ASRUtils::expr_value(right))->m_value; + bool result = left_value && right_value; + value = ASR::down_cast(ASR::make_LogicalConstant_t( + al, loc, result, type)); + } + tmp = ASR::make_LogicalBinOp_t(al, loc, left, + ASR::logicalbinopType::And, right, type, value); + return; + } else if (AST::is_a(*m_right)) { + // handle chained comparisons + LCOMPILERS_ASSERT(ASRUtils::is_logical(*right_type)); + AST::Compare_t *rc = AST::down_cast(m_right); + compare_helper(loc, m_left, rc->m_left, asr_op); + left = ASRUtils::EXPR(tmp); + left_type = ASRUtils::expr_type(left); + LCOMPILERS_ASSERT(ASRUtils::is_logical(*left_type)); + if (ASRUtils::expr_value(left) != nullptr && ASRUtils::expr_value(right) != nullptr) { + bool left_value = ASR::down_cast( + ASRUtils::expr_value(left))->m_value; + bool right_value = ASR::down_cast( + ASRUtils::expr_value(right))->m_value; + bool result = left_value && right_value; + value = ASR::down_cast(ASR::make_LogicalConstant_t( + al, loc, result, type)); + } + tmp = ASR::make_LogicalBinOp_t(al, loc, left, + ASR::logicalbinopType::And, right, type, value); + return; + } std::string ltype = ASRUtils::type_to_str_python(ASRUtils::expr_type(left)); std::string rtype = ASRUtils::type_to_str_python(ASRUtils::expr_type(right)); diag.add(diag::Diagnostic( @@ -6112,30 +6129,25 @@ class BodyVisitor : public CommonVisitor { ); throw SemanticAbort(); } - ASR::ttype_t *type = ASRUtils::TYPE( - ASR::make_Logical_t(al, x.base.base.loc, 4)); - ASR::expr_t *value = nullptr; - if( ASR::is_a(*dest_type) || ASR::is_a(*dest_type) ) { dest_type = ASRUtils::get_contained_type(dest_type); } - if (ASRUtils::is_array(dest_type)) { ASR::dimension_t* m_dims = nullptr; int n_dims = ASRUtils::extract_dimensions_from_ttype(dest_type, m_dims); int array_size = ASRUtils::get_fixed_size_of_array(m_dims, n_dims); if (array_size == -1) { - throw SemanticError("The truth value of an array is ambiguous. Use a.any() or a.all()", x.base.base.loc); + throw SemanticError("The truth value of an array is ambiguous. Use a.any() or a.all()", loc); } else if (array_size != 1) { - throw SemanticError("The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()", x.base.base.loc); + throw SemanticError("The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()", loc); } else { Vec argsL, argsR; argsL.reserve(al, 1); argsR.reserve(al, 1); for (int i = 0; i < n_dims; i++) { ASR::array_index_t aiL, aiR; - ASR::ttype_t *int_type = ASRUtils::TYPE(ASR::make_Integer_t(al, x.base.base.loc, 4)); - ASR::expr_t* const_zero = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, x.base.base.loc, 0, int_type)); + ASR::ttype_t *int_type = ASRUtils::TYPE(ASR::make_Integer_t(al, loc, 4)); + ASR::expr_t* const_zero = ASRUtils::EXPR(ASR::make_IntegerConstant_t(al, loc, 0, int_type)); aiL.m_right = aiR.m_right = const_zero; aiL.m_left = aiR.m_left = nullptr; aiL.m_step = aiR.m_step = nullptr; @@ -6145,8 +6157,10 @@ class BodyVisitor : public CommonVisitor { argsR.push_back(al, aiR); } dest_type = ASRUtils::type_get_past_array(dest_type); - left = ASRUtils::EXPR(make_ArrayItem_t(al, left->base.loc, left, argsL.p, argsL.n, dest_type, ASR::arraystorageType::RowMajor, nullptr)); - right = ASRUtils::EXPR(make_ArrayItem_t(al, right->base.loc, right, argsR.p, argsR.n, dest_type, ASR::arraystorageType::RowMajor, nullptr)); + left = ASRUtils::EXPR(make_ArrayItem_t(al, left->base.loc, + left, argsL.p, argsL.n, dest_type, ASR::arraystorageType::RowMajor, nullptr)); + right = ASRUtils::EXPR(make_ArrayItem_t(al, right->base.loc, + right, argsR.p, argsR.n, dest_type, ASR::arraystorageType::RowMajor, nullptr)); } } @@ -6166,13 +6180,13 @@ class BodyVisitor : public CommonVisitor { case (ASR::cmpopType::NotEq): { result = left_value != right_value; break; } default: { throw SemanticError("Comparison operator not implemented", - x.base.base.loc); + loc); } } value = ASR::down_cast(ASR::make_LogicalConstant_t( - al, x.base.base.loc, result, type)); + al, loc, result, type)); } - tmp = ASR::make_IntegerCompare_t(al, x.base.base.loc, left, asr_op, right, type, value); + tmp = ASR::make_IntegerCompare_t(al, loc, left, asr_op, right, type, value); } else if (ASRUtils::is_unsigned_integer(*dest_type)) { if (ASRUtils::expr_value(left) != nullptr && ASRUtils::expr_value(right) != nullptr) { int64_t left_value = -1; @@ -6189,13 +6203,13 @@ class BodyVisitor : public CommonVisitor { case (ASR::cmpopType::NotEq): { result = left_value != right_value; break; } default: { throw SemanticError("Comparison operator not implemented", - x.base.base.loc); + loc); } } value = ASR::down_cast(ASR::make_LogicalConstant_t( - al, x.base.base.loc, result, type)); + al, loc, result, type)); } - tmp = ASR::make_UnsignedIntegerCompare_t(al, x.base.base.loc, left, asr_op, right, type, value); + tmp = ASR::make_UnsignedIntegerCompare_t(al, loc, left, asr_op, right, type, value); } else if (ASRUtils::is_real(*dest_type)) { if (ASRUtils::expr_value(left) != nullptr && ASRUtils::expr_value(right) != nullptr) { @@ -6213,14 +6227,14 @@ class BodyVisitor : public CommonVisitor { case (ASR::cmpopType::NotEq): { result = left_value != right_value; break; } default: { throw SemanticError("Comparison operator not implemented", - x.base.base.loc); + loc); } } value = ASR::down_cast(ASR::make_LogicalConstant_t( - al, x.base.base.loc, result, type)); + al, loc, result, type)); } - tmp = ASR::make_RealCompare_t(al, x.base.base.loc, left, asr_op, right, type, value); + tmp = ASR::make_RealCompare_t(al, loc, left, asr_op, right, type, value); } else if (ASRUtils::is_complex(*dest_type)) { @@ -6246,14 +6260,14 @@ class BodyVisitor : public CommonVisitor { default: { throw SemanticError("'" + ASRUtils::cmpop_to_str(asr_op) + "' comparison is not supported between complex numbers", - x.base.base.loc); + loc); } } value = ASR::down_cast(ASR::make_LogicalConstant_t( - al, x.base.base.loc, result, type)); + al, loc, result, type)); } - tmp = ASR::make_ComplexCompare_t(al, x.base.base.loc, left, asr_op, right, type, value); + tmp = ASR::make_ComplexCompare_t(al, loc, left, asr_op, right, type, value); } else if (ASRUtils::is_logical(*dest_type)) { @@ -6272,14 +6286,14 @@ class BodyVisitor : public CommonVisitor { case (ASR::cmpopType::NotEq): { result = left_value != right_value; break; } default: { throw SemanticError("Comparison operator not implemented", - x.base.base.loc); + loc); } } value = ASR::down_cast(ASR::make_LogicalConstant_t( - al, x.base.base.loc, result, type)); + al, loc, result, type)); } - tmp = ASR::make_LogicalCompare_t(al, x.base.base.loc, left, asr_op, right, type, value); + tmp = ASR::make_LogicalCompare_t(al, loc, left, asr_op, right, type, value); } else if (ASRUtils::is_character(*dest_type)) { @@ -6318,50 +6332,77 @@ class BodyVisitor : public CommonVisitor { break; } default: { - throw SemanticError("ICE: Unknown compare operator", x.base.base.loc); // should never happen + throw SemanticError("ICE: Unknown compare operator", loc); // should never happen } } value = ASR::down_cast(ASR::make_LogicalConstant_t( - al, x.base.base.loc, result, type)); + al, loc, result, type)); } - tmp = ASR::make_StringCompare_t(al, x.base.base.loc, left, asr_op, right, type, value); + tmp = ASR::make_StringCompare_t(al, loc, left, asr_op, right, type, value); } else if (ASR::is_a(*dest_type)) { if (asr_op != ASR::cmpopType::Eq && asr_op != ASR::cmpopType::NotEq && asr_op != ASR::cmpopType::Lt && asr_op != ASR::cmpopType::LtE && asr_op != ASR::cmpopType::Gt && asr_op != ASR::cmpopType::GtE) { throw SemanticError("Only ==, !=, <, <=, >, >= operators " "are supported for Tuples", - x.base.base.loc); + loc); } - tmp = ASR::make_TupleCompare_t(al, x.base.base.loc, left, asr_op, right, type, value); + tmp = ASR::make_TupleCompare_t(al, loc, left, asr_op, right, type, value); } else if (ASR::is_a(*dest_type)) { if (asr_op != ASR::cmpopType::Eq && asr_op != ASR::cmpopType::NotEq && asr_op != ASR::cmpopType::Lt && asr_op != ASR::cmpopType::LtE && asr_op != ASR::cmpopType::Gt && asr_op != ASR::cmpopType::GtE) { throw SemanticError("Only ==, !=, <, <=, >, >= operators " "are supported for Lists", - x.base.base.loc); + loc); } - tmp = ASR::make_ListCompare_t(al, x.base.base.loc, left, asr_op, right, type, value); + tmp = ASR::make_ListCompare_t(al, loc, left, asr_op, right, type, value); } else if (ASR::is_a(*dest_type)) { if (asr_op != ASR::cmpopType::Eq && asr_op != ASR::cmpopType::NotEq) { throw SemanticError("Only Equal and Not-equal operators are supported for CPtr", - x.base.base.loc); + loc); } - tmp = ASR::make_CPtrCompare_t(al, x.base.base.loc, left, asr_op, right, type, value); + tmp = ASR::make_CPtrCompare_t(al, loc, left, asr_op, right, type, value); } else if (ASR::is_a(*dest_type)) { - tmp = ASR::make_SymbolicCompare_t(al, x.base.base.loc, left, asr_op, right, type, value); + tmp = ASR::make_SymbolicCompare_t(al, loc, left, asr_op, right, type, value); } else { throw SemanticError("Compare not supported for type: " + ASRUtils::type_to_str_python(dest_type), - x.base.base.loc); + loc); } if (overloaded != nullptr) { - tmp = ASR::make_OverloadedCompare_t(al, x.base.base.loc, left, asr_op, right, type, + tmp = ASR::make_OverloadedCompare_t(al, loc, left, asr_op, right, type, value, overloaded); } } + void visit_Compare(const AST::Compare_t &x) { + if (x.n_comparators > 1) { + diag.add(diag::Diagnostic( + "Only one comparison operator is supported for now", + diag::Level::Error, diag::Stage::Semantic, { + diag::Label("multiple comparison operators", + {x.m_comparators[0]->base.loc}) + }) + ); + throw SemanticAbort(); + } + + ASR::cmpopType asr_op; + switch (x.m_ops) { + case (AST::cmpopType::Eq): { asr_op = ASR::cmpopType::Eq; break; } + case (AST::cmpopType::Gt): { asr_op = ASR::cmpopType::Gt; break; } + case (AST::cmpopType::GtE): { asr_op = ASR::cmpopType::GtE; break; } + case (AST::cmpopType::Lt): { asr_op = ASR::cmpopType::Lt; break; } + case (AST::cmpopType::LtE): { asr_op = ASR::cmpopType::LtE; break; } + case (AST::cmpopType::NotEq): { asr_op = ASR::cmpopType::NotEq; break; } + default: { + throw SemanticError("Comparison operator not implemented", + x.base.base.loc); + } + } + compare_helper(x.base.base.loc, x.m_left, x.m_comparators[0], asr_op); + } void visit_ConstantEllipsis(const AST::ConstantEllipsis_t &/*x*/) { tmp = nullptr;