diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index fb78d7a6fa..e89e12ab65 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -755,6 +755,8 @@ RUN(NAME test_platform LABELS cpython llvm c) RUN(NAME test_vars_01 LABELS cpython llvm) RUN(NAME test_version LABELS cpython llvm) RUN(NAME logical_binop1 LABELS cpython llvm) +RUN(NAME test_logical_compare LABELS cpython llvm) +RUN(NAME test_logical_assignment LABELS cpython llvm) RUN(NAME vec_01 LABELS cpython llvm c NOFAST) RUN(NAME test_str_comparison LABELS cpython llvm c wasm) RUN(NAME test_bit_length LABELS cpython llvm c) diff --git a/integration_tests/test_logical_assignment.py b/integration_tests/test_logical_assignment.py new file mode 100644 index 0000000000..152aa0c822 --- /dev/null +++ b/integration_tests/test_logical_assignment.py @@ -0,0 +1,22 @@ +from lpython import i32, f64 + + +def test_logical_assignment(): + # Can be uncommented after fixing the segfault + # _LPYTHON: str = "LPython" + # s_var: str = "" or _LPYTHON + # assert s_var == "LPython" + # print(s_var) + + _MAX_VAL: i32 = 100 + i_var: i32 = 0 and 100 + assert i_var == 0 + print(i_var) + + _PI: f64 = 3.14 + f_var: f64 = 2.0 * _PI or _PI**2.0 + assert f_var == 6.28 + print(f_var) + + +test_logical_assignment() diff --git a/integration_tests/test_logical_compare.py b/integration_tests/test_logical_compare.py new file mode 100644 index 0000000000..497718a13e --- /dev/null +++ b/integration_tests/test_logical_compare.py @@ -0,0 +1,130 @@ +from lpython import i32, f64 + + +def test_logical_compare_literal(): + # Integers + print(1 or 3) + assert (1 or 3) == 1 + + print(1 and 3) + assert (1 and 3) == 3 + + print(2 or 3 or 5 or 6) + assert (2 or 3 or 5 or 6) == 2 + + print(1 and 3 or 2 and 4) + assert (1 and 3 or 2 and 4) == 3 + + print(1 or 3 and 0 or 4) + assert (1 or 3 and 0 or 4) == 1 + + print(1 and 3 or 2 and 0) + assert (1 and 3 or 2 and 0) == 3 + + print(1 and 0 or 3 and 4) + assert (1 and 0 or 3 and 4) == 4 + + # Floating-point numbers + print(1.33 or 6.67) + assert (1.33 or 6.67) == 1.33 + + print(1.33 and 6.67) + assert (1.33 and 6.67) == 6.67 + + print(1.33 or 6.67 and 3.33 or 0.0) + assert (1.33 or 6.67 and 3.33 or 0.0) == 1.33 + + print(1.33 and 6.67 or 3.33 and 0.0) + assert (1.33 and 6.67 or 3.33 and 0.0) == 6.67 + + print(1.33 and 0.0 and 3.33 and 6.67) + assert (1.33 and 0.0 and 3.33 and 6.67) == 0.0 + + # Strings + print("a" or "b") + assert ("a" or "b") == "a" + + print("abc" or "b") + assert ("abc" or "b") == "abc" + + print("a" and "b") + assert ("a" and "b") == "b" + + print("a" or "b" and "c" or "d") + assert ("a" or "b" and "c" or "d") == "a" + + print("" or " ") + assert ("" or " ") == " " + + print("" and " " or "a" and "b" and "c") + assert ("" and " " or "a" and "b" and "c") == "c" + + print("" and " " and "a" and "b" and "c") + assert ("" and " " and "a" and "b" and "c") == "" + + +def test_logical_compare_variable(): + # Integers + i_a: i32 = 1 + i_b: i32 = 3 + + print(i_a and i_b) + assert (i_a and i_b) == 3 + + print(i_a or i_b or 2 or 4) + assert (i_a or i_b or 2 or 4) == 1 + + print(i_a and i_b or 2 and 4) + assert (i_a and i_b or 2 and 4) == 3 + + print(i_a or i_b and 0 or 4) + assert (i_a or i_b and 0 or 4) == i_a + + print(i_a and i_b or 2 and 0) + assert (i_a and i_b or 2 and 0) == i_b + + print(i_a and 0 or i_b and 4) + assert (i_a and 0 or i_b and 4) == 4 + + print(i_a + i_b or 0 - 4) + assert (i_a + i_b or 0 - 4) == 4 + + # Floating-point numbers + f_a: f64 = 1.67 + f_b: f64 = 3.33 + + print(f_a // f_b and f_a - f_b) + assert (f_a // f_b and f_a - f_b) == 0.0 + + print(f_a**3.0 or 3.0**f_a) + assert (f_a**3.0 or 3.0**f_a) == 4.657462999999999 + + print(f_a - 3.0 and f_a + 3.0 or f_b - 3.0 and f_b + 3.0) + assert (f_a - 3.0 and f_a + 3.0 or f_b - 3.0 and f_b + 3.0) == 4.67 + + # Can be uncommented after fixing the segfault + # Strings + # s_a: str = "a" + # s_b: str = "b" + + # print(s_a or s_b) + # assert (s_a or s_b) == s_a + + # print(s_a and s_b) + # assert (s_a and s_b) == s_b + + # print(s_a + s_b or s_b + s_a) + # assert (s_a + s_b or s_b + s_a) == "ab" + + # print(s_a[0] or s_b[-1]) + # assert (s_a[0] or s_b[-1]) == "a" + + # print(s_a[0] and s_b[-1]) + # assert (s_a[0] and s_b[-1]) == "b" + + # print(s_a + s_b or s_b + s_a + s_a[0] and s_b[-1]) + # assert (s_a + s_b or s_b + s_a + s_a[0] and s_b[-1]) == "ab" + + +test_logical_compare_literal() +test_logical_compare_variable() diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index e8f45844ff..1390106be4 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -3326,29 +3326,116 @@ class CommonVisitor : public AST::BaseVisitor { x.base.base.loc); } } - LCOMPILERS_ASSERT( - ASRUtils::check_equal_type(ASRUtils::expr_type(lhs), ASRUtils::expr_type(rhs))); + ASR::ttype_t *left_operand_type = ASRUtils::expr_type(lhs); + ASR::ttype_t *right_operand_type = ASRUtils::expr_type(rhs); + ASR::expr_t *value = nullptr; - ASR::ttype_t *dest_type = ASRUtils::expr_type(lhs); + ASR::ttype_t *dest_type = left_operand_type; + if (!ASRUtils::check_equal_type(left_operand_type, right_operand_type)) { + throw SemanticError("Type mismatch: '" + ASRUtils::type_to_str_python(left_operand_type) + + "' and '" + ASRUtils::type_to_str_python(right_operand_type) + + "'. Both operands must be of the same type.", x.base.base.loc); + } + // Reference: https://docs.python.org/3/reference/expressions.html#boolean-operations if (ASRUtils::expr_value(lhs) != nullptr && ASRUtils::expr_value(rhs) != nullptr) { - - LCOMPILERS_ASSERT(ASR::is_a(*dest_type)); - bool left_value = ASR::down_cast( - ASRUtils::expr_value(lhs))->m_value; - bool right_value = ASR::down_cast( - ASRUtils::expr_value(rhs))->m_value; - bool result; - switch (op) { - case (ASR::logicalbinopType::And): { result = left_value && right_value; break; } - case (ASR::logicalbinopType::Or): { result = left_value || right_value; break; } - default : { - throw SemanticError("Boolean operator type not supported", - x.base.base.loc); + switch (dest_type->type) { + case ASR::ttypeType::Logical: { + bool left_value = ASR::down_cast( + ASRUtils::expr_value(lhs))->m_value; + bool right_value = ASR::down_cast( + ASRUtils::expr_value(rhs))->m_value; + bool result; + switch (op) { + case (ASR::logicalbinopType::And): { result = left_value && right_value; break; } + case (ASR::logicalbinopType::Or): { result = left_value || right_value; break; } + default : { + throw SemanticError("Boolean operator type not supported", + x.base.base.loc); + } + } + value = ASRUtils::EXPR(ASR::make_LogicalConstant_t( + al, x.base.base.loc, result, dest_type)); + break; + } + case ASR::ttypeType::Integer: { + int64_t left_value = ASR::down_cast( + ASRUtils::expr_value(lhs))->m_n; + int64_t right_value = ASR::down_cast( + ASRUtils::expr_value(rhs))->m_n; + int64_t result; + switch (op) { + case (ASR::logicalbinopType::And): { + result = left_value == 0 ? left_value : right_value; + break; + } + case (ASR::logicalbinopType::Or): { + result = left_value != 0 ? left_value : right_value; + break; + } + default : { + throw SemanticError("Boolean operator type not supported", + x.base.base.loc); + } + } + value = ASRUtils::EXPR(ASR::make_IntegerConstant_t( + al, x.base.base.loc, result, dest_type)); + break; + } + case ASR::ttypeType::Real: { + double left_value = ASR::down_cast( + ASRUtils::expr_value(lhs))->m_r; + double right_value = ASR::down_cast( + ASRUtils::expr_value(rhs))->m_r; + double result; + switch (op) { + case (ASR::logicalbinopType::And): { + result = left_value == 0 ? left_value : right_value; + break; + } + case (ASR::logicalbinopType::Or): { + result = left_value != 0 ? left_value : right_value; + break; + } + default : { + throw SemanticError("Boolean operator type not supported", + x.base.base.loc); + } + } + value = ASRUtils::EXPR(ASR::make_RealConstant_t( + al, x.base.base.loc, result, dest_type)); + break; } + case ASR::ttypeType::Character: { + char* left_value = ASR::down_cast( + ASRUtils::expr_value(lhs))->m_s; + char* right_value = ASR::down_cast( + ASRUtils::expr_value(rhs))->m_s; + char* result; + switch (op) { + case (ASR::logicalbinopType::And): { + result = std::strcmp(left_value, "") == 0 ? left_value : right_value; + break; + } + case (ASR::logicalbinopType::Or): { + result = std::strcmp(left_value, "") != 0 ? left_value : right_value; + break; + } + default : { + throw SemanticError("Boolean operator type not supported", + x.base.base.loc); + } + } + value = ASRUtils::EXPR(ASR::make_StringConstant_t( + al, x.base.base.loc, result, dest_type)); + break; + } + + default: + throw SemanticError("Boolean operation not supported on objects of type '" + + ASRUtils::type_to_str_python(dest_type) + "'", + x.base.base.loc); } - value = ASR::down_cast(ASR::make_LogicalConstant_t( - al, x.base.base.loc, result, dest_type)); } tmp = ASR::make_LogicalBinOp_t(al, x.base.base.loc, lhs, op, rhs, dest_type, value); } @@ -7586,7 +7673,6 @@ we will have to use something else. } Vec args_; args_.reserve(al, x.n_args); visit_expr_list(x.m_args, x.n_args, args_); - if (x.n_args > 0 && ASRUtils::is_array(ASRUtils::expr_type(args_[0])) && imported_functions[call_name] == "math" ) { throw SemanticError("Function '" + call_name + "' does not accept vector values", diff --git a/tests/errors/test_logical_compare_01.py b/tests/errors/test_logical_compare_01.py new file mode 100644 index 0000000000..83ef1f0a8d --- /dev/null +++ b/tests/errors/test_logical_compare_01.py @@ -0,0 +1,5 @@ +def f(): + print("hello" or 10) + + +f() diff --git a/tests/reference/asr-test_logical_compare_01-5db0e2e.json b/tests/reference/asr-test_logical_compare_01-5db0e2e.json new file mode 100644 index 0000000000..7030df8c0e --- /dev/null +++ b/tests/reference/asr-test_logical_compare_01-5db0e2e.json @@ -0,0 +1,13 @@ +{ + "basename": "asr-test_logical_compare_01-5db0e2e", + "cmd": "lpython --show-asr --no-color {infile} -o {outfile}", + "infile": "tests/errors/test_logical_compare_01.py", + "infile_hash": "467dc216d8ce90f4b3a1ec06610cea226ae96152763cfa42d5ab8f33", + "outfile": null, + "outfile_hash": null, + "stdout": null, + "stdout_hash": null, + "stderr": "asr-test_logical_compare_01-5db0e2e.stderr", + "stderr_hash": "d10cac68687315b5d29828e0acb5170f44bd91dd30784f8bd4943bb0", + "returncode": 2 +} \ No newline at end of file diff --git a/tests/reference/asr-test_logical_compare_01-5db0e2e.stderr b/tests/reference/asr-test_logical_compare_01-5db0e2e.stderr new file mode 100644 index 0000000000..c1e876782c --- /dev/null +++ b/tests/reference/asr-test_logical_compare_01-5db0e2e.stderr @@ -0,0 +1,5 @@ +semantic error: Type mismatch: 'str' and 'i32'. Both operands must be of the same type. + --> tests/errors/test_logical_compare_01.py:2:11 + | +2 | print("hello" or 10) + | ^^^^^^^^^^^^^ diff --git a/tests/tests.toml b/tests/tests.toml index 2691050814..0ea59119d2 100644 --- a/tests/tests.toml +++ b/tests/tests.toml @@ -1259,6 +1259,10 @@ asr = true filename = "errors/loop_03.py" asr = true +[[test]] +filename = "errors/test_logical_compare_01.py" +asr = true + [[test]] filename = "errors/bindc_01.py" asr = true