diff --git a/grammar/Python.asdl b/grammar/Python.asdl index a5ca1c672e..ade97a49a0 100644 --- a/grammar/Python.asdl +++ b/grammar/Python.asdl @@ -73,6 +73,7 @@ module LPython -- need sequences for compare to distinguish between -- x < 4 < 3 and (x < 4) < 3 | Compare(expr left, cmpop ops, expr* comparators) + | Membership(expr left, membershipop op, expr right) | Call(expr func, expr* args, keyword* keywords) | FormattedValue(expr value, int conversion, expr? format_spec) | JoinedStr(expr* values) @@ -110,7 +111,9 @@ module LPython unaryop = Invert | Not | UAdd | USub - cmpop = Eq | NotEq | Lt | LtE | Gt | GtE | Is | IsNot | In | NotIn + cmpop = Eq | NotEq | Lt | LtE | Gt | GtE | Is | IsNot + + membershipop = In | NotIn comprehension = (expr target, expr iter, expr* ifs, int is_async) diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index 8d70900cdf..ea416e764b 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -600,6 +600,7 @@ RUN(NAME test_import_05 LABELS cpython llvm llvm_jit c wasm wasm_x86 wasm_x RUN(NAME test_import_06 LABELS cpython llvm llvm_jit) RUN(NAME test_import_07 LABELS cpython llvm llvm_jit c) RUN(NAME test_math LABELS cpython llvm llvm_jit NOFAST) +RUN(NAME test_membership_01 LABELS cpython llvm) RUN(NAME test_numpy_01 LABELS cpython llvm llvm_jit c) RUN(NAME test_numpy_02 LABELS cpython llvm llvm_jit c) RUN(NAME test_numpy_03 LABELS cpython llvm llvm_jit c) diff --git a/integration_tests/test_membership_01.py b/integration_tests/test_membership_01.py new file mode 100644 index 0000000000..1fab47cda0 --- /dev/null +++ b/integration_tests/test_membership_01.py @@ -0,0 +1,36 @@ +def test_int_dict(): + a: dict[i32, i32] = {1:2, 2:3, 3:4, 4:5} + i: i32 + assert (1 in a) + assert (6 not in a) + i = 4 + assert (i in a) + +def test_str_dict(): + a: dict[str, str] = {'a':'1', 'b':'2', 'c':'3'} + i: str + assert ('a' in a) + assert ('d' not in a) + i = 'c' + assert (i in a) + +def test_int_set(): + a: set[i32] = {1, 2, 3, 4} + i: i32 + assert (1 in a) + assert (6 not in a) + i = 4 + assert (i in a) + +def test_str_set(): + a: set[str] = {'a', 'b', 'c', 'e', 'f'} + i: str + assert ('a' in a) + # assert ('d' not in a) + i = 'c' + assert (i in a) + +test_int_dict() +test_str_dict() +test_int_set() +test_str_set() diff --git a/src/libasr/ASR.asdl b/src/libasr/ASR.asdl index 578e31692c..679c43ea98 100644 --- a/src/libasr/ASR.asdl +++ b/src/libasr/ASR.asdl @@ -118,12 +118,14 @@ expr | ListConcat(expr left, expr right, ttype type, expr? value) | ListCompare(expr left, cmpop op, expr right, ttype type, expr? value) | ListCount(expr arg, expr ele, ttype type, expr? value) + | ListContains(expr left, expr right, ttype type, expr? value) | SetConstant(expr* elements, ttype type) | SetLen(expr arg, ttype type, expr? value) | TupleConstant(expr* elements, ttype type) | TupleLen(expr arg, ttype type, expr value) | TupleCompare(expr left, cmpop op, expr right, ttype type, expr? value) | TupleConcat(expr left, expr right, ttype type, expr? value) + | TupleContains(expr left, expr right, ttype type, expr? value) | StringConstant(string s, ttype type) | StringConcat(expr left, expr right, ttype type, expr? value) | StringRepeat(expr left, expr right, ttype type, expr? value) @@ -131,6 +133,7 @@ expr | StringItem(expr arg, expr idx, ttype type, expr? value) | StringSection(expr arg, expr? start, expr? end, expr? step, ttype type, expr? value) | StringCompare(expr left, cmpop op, expr right, ttype type, expr? value) + | StringContains(expr left, expr right, ttype type, expr? value) | StringOrd(expr arg, ttype type, expr? value) | StringChr(expr arg, ttype type, expr? value) | StringFormat(expr fmt, expr* args, string_format_kind kind, ttype type, expr? value) @@ -176,6 +179,8 @@ expr | ListRepeat(expr left, expr right, ttype type, expr? value) | DictPop(expr a, expr key, ttype type, expr? value) | SetPop(expr a, ttype type, expr? value) + | SetContains(expr left, expr right, ttype type, expr? value) + | DictContains(expr left, expr right, ttype type, expr? value) | IntegerBitLen(expr a, ttype type, expr? value) | Ichar(expr arg, ttype type, expr? value) | Iachar(expr arg, ttype type, expr? value) diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index bd267d88d5..61e54152aa 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -1637,6 +1637,51 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } } + void visit_DictContains(const ASR::DictContains_t &x) { + if (x.m_value) { + this->visit_expr(*x.m_value); + return; + } + + int64_t ptr_loads_copy = ptr_loads; + ptr_loads = 0; + this->visit_expr(*x.m_right); + llvm::Value *right = tmp; + ASR::Dict_t *dict_type = ASR::down_cast( + ASRUtils::expr_type(x.m_right)); + ptr_loads = !LLVM::is_llvm_struct(dict_type->m_key_type); + this->visit_expr(*x.m_left); + llvm::Value *left = tmp; + ptr_loads = ptr_loads_copy; + llvm::Value *capacity = LLVM::CreateLoad(*builder, + llvm_utils->dict_api->get_pointer_to_capacity(right)); + llvm::Value *key_hash = llvm_utils->dict_api->get_key_hash(capacity, left, dict_type->m_key_type, *module); + + tmp = llvm_utils->dict_api->resolve_collision_for_read_with_bound_check(right, key_hash, left, *module, dict_type->m_key_type, dict_type->m_value_type, true); + } + + void visit_SetContains(const ASR::SetContains_t &x) { + if (x.m_value) { + this->visit_expr(*x.m_value); + return; + } + + int64_t ptr_loads_copy = ptr_loads; + ptr_loads = 0; + this->visit_expr(*x.m_right); + llvm::Value *right = tmp; + ASR::ttype_t *el_type = ASRUtils::expr_type(x.m_left); + ptr_loads = !LLVM::is_llvm_struct(el_type); + this->visit_expr(*x.m_left); + llvm::Value *left = tmp; + ptr_loads = ptr_loads_copy; + llvm::Value *capacity = LLVM::CreateLoad(*builder, + llvm_utils->set_api->get_pointer_to_capacity(right)); + llvm::Value *el_hash = llvm_utils->set_api->get_el_hash(capacity, left, el_type, *module); + + tmp = llvm_utils->set_api->resolve_collision_for_read_with_bound_check(right, el_hash, left, *module, el_type, false, true); + } + void visit_DictLen(const ASR::DictLen_t& x) { if (x.m_value) { this->visit_expr(*x.m_value); diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 7ac11b9e31..7a317e04c1 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -3177,7 +3177,7 @@ namespace LCompilers { llvm::Value* LLVMDict::resolve_collision_for_read_with_bound_check( llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, - ASR::ttype_t* key_asr_type, ASR::ttype_t* /*value_asr_type*/) { + ASR::ttype_t* key_asr_type, ASR::ttype_t* /*value_asr_type*/, bool check_if_exists) { llvm::Value* key_list = get_key_list(dict); llvm::Value* value_list = get_value_list(dict); llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); @@ -3187,6 +3187,8 @@ namespace LCompilers { llvm::Value* is_key_matching = llvm_utils->is_equal_by_value(key, llvm_utils->list_api->read_item(key_list, pos, false, module, LLVM::is_llvm_struct(key_asr_type)), module, key_asr_type); + if (check_if_exists) + return is_key_matching; llvm_utils->create_if_else(is_key_matching, [&]() { }, [&]() { @@ -3245,7 +3247,7 @@ namespace LCompilers { llvm::Value* LLVMDictOptimizedLinearProbing::resolve_collision_for_read_with_bound_check( llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, - ASR::ttype_t* key_asr_type, ASR::ttype_t* /*value_asr_type*/) { + ASR::ttype_t* key_asr_type, ASR::ttype_t* /*value_asr_type*/, bool check_if_exists) { /** * C++ equivalent: @@ -3287,6 +3289,9 @@ namespace LCompilers { llvm_utils->create_ptr_gep(key_mask, key_hash)); llvm::Value* is_prob_not_neeeded = builder->CreateICmpEQ(key_mask_value, llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); + llvm::AllocaInst *flag_ptr = builder->CreateAlloca(llvm::Type::getInt1Ty(context), nullptr); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 0), flag_ptr); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0), pos_ptr); builder->CreateCondBr(is_prob_not_neeeded, thenBB, elseBB); builder->SetInsertPoint(thenBB); { @@ -3304,6 +3309,9 @@ namespace LCompilers { llvm_utils->create_if_else(is_key_matching, [=]() { LLVM::CreateStore(*builder, key_hash, pos_ptr); }, [&]() { + if (check_if_exists) { + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 1), flag_ptr); + } else { std::string message = "The dict does not contain the specified key"; llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n"); llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message); @@ -3312,7 +3320,7 @@ namespace LCompilers { llvm::Value *exit_code = llvm::ConstantInt::get(context, llvm::APInt(32, exit_code_int)); exit(context, module, *builder, exit_code); - }); + }}); } builder->CreateBr(mergeBB); llvm_utils->start_new_block(elseBB); @@ -3321,11 +3329,24 @@ namespace LCompilers { module, key_asr_type, true); } llvm_utils->start_new_block(mergeBB); - llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); - // Check if the actual key is present or not - llvm::Value* is_key_matching = llvm_utils->is_equal_by_value(key, + llvm::Value *flag = LLVM::CreateLoad(*builder, flag_ptr); + llvm::Value *pos = LLVM::CreateLoad(*builder, pos_ptr); + llvm::AllocaInst *is_key_matching_ptr = builder->CreateAlloca(llvm::Type::getInt1Ty(context), nullptr); + + llvm_utils->create_if_else(flag, [&](){ + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 0), is_key_matching_ptr); + }, [&](){ + // Check if the actual element is present or not + LLVM::CreateStore(*builder, llvm_utils->is_equal_by_value(key, llvm_utils->list_api->read_item(key_list, pos, false, module, - LLVM::is_llvm_struct(key_asr_type)), module, key_asr_type); + LLVM::is_llvm_struct(key_asr_type)), module, key_asr_type), is_key_matching_ptr); + }); + + llvm::Value *is_key_matching = LLVM::CreateLoad(*builder, is_key_matching_ptr); + + if (check_if_exists) { + return is_key_matching; + } llvm_utils->create_if_else(is_key_matching, [&]() { }, [&]() { @@ -3471,7 +3492,7 @@ namespace LCompilers { llvm::Value* LLVMDictSeparateChaining::resolve_collision_for_read_with_bound_check( llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, - ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type) { + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, bool check_if_exists) { /** * C++ equivalent: * @@ -3506,6 +3527,10 @@ namespace LCompilers { llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))) ); + if (check_if_exists) { + return does_kv_exists; + } + llvm_utils->create_if_else(does_kv_exists, [&]() { llvm::Value* kv_struct_i8 = LLVM::CreateLoad(*builder, chain_itr); llvm::Value* kv_struct = builder->CreateBitCast(kv_struct_i8, kv_struct_type->getPointerTo()); @@ -4358,6 +4383,7 @@ namespace LCompilers { // end llvm_utils->start_new_block(loopend); } + llvm::Value* LLVMList::read_item(llvm::Value* list, llvm::Value* pos, bool enable_bounds_checking, @@ -6393,9 +6419,9 @@ namespace LCompilers { el_asr_type, name2memidx); } - void LLVMSetLinearProbing::resolve_collision_for_read_with_bound_check( + llvm::Value* LLVMSetLinearProbing::resolve_collision_for_read_with_bound_check( llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, - llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) { + llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error, bool check_if_exists) { /** * C++ equivalent: @@ -6423,18 +6449,22 @@ namespace LCompilers { */ get_builder0() + pos_ptr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); llvm::Value* el_list = get_el_list(set); llvm::Value* el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set)); llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set)); - pos_ptr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); llvm::Function *fn = builder->GetInsertBlock()->getParent(); - llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); - llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); - llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); + std::string s = check_if_exists ? "qq" : "pp"; + llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then"+s, fn); + llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"+s); + llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"+s); llvm::Value* el_mask_value = LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(el_mask, el_hash)); llvm::Value* is_prob_not_needed = builder->CreateICmpEQ(el_mask_value, llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); + llvm::AllocaInst *flag_ptr = builder->CreateAlloca(llvm::Type::getInt1Ty(context), nullptr); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0), pos_ptr); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 0), flag_ptr); builder->CreateCondBr(is_prob_not_needed, thenBB, elseBB); builder->SetInsertPoint(thenBB); { @@ -6447,6 +6477,9 @@ namespace LCompilers { llvm_utils->create_if_else(is_el_matching, [=]() { LLVM::CreateStore(*builder, el_hash, pos_ptr); }, [&]() { + if (check_if_exists) { + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 1), flag_ptr); + } else { if (throw_key_error) { std::string message = "The set does not contain the specified element"; llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n"); @@ -6457,7 +6490,7 @@ namespace LCompilers { llvm::APInt(32, exit_code_int)); exit(context, module, *builder, exit_code); } - }); + }}); } builder->CreateBr(mergeBB); llvm_utils->start_new_block(elseBB); @@ -6466,11 +6499,25 @@ namespace LCompilers { module, el_asr_type, true); } llvm_utils->start_new_block(mergeBB); - llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); + llvm::Value *flag = LLVM::CreateLoad(*builder, flag_ptr); + llvm::AllocaInst *is_el_matching_ptr = builder->CreateAlloca(llvm::Type::getInt1Ty(context), nullptr); + + llvm_utils->create_if_else(flag, [&](){ + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 0), is_el_matching_ptr); + }, [&](){ // Check if the actual element is present or not - llvm::Value* is_el_matching = llvm_utils->is_equal_by_value(el, - llvm_utils->list_api->read_item(el_list, pos, false, module, - LLVM::is_llvm_struct(el_asr_type)), module, el_asr_type); + llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); + llvm::Value* item = llvm_utils->list_api->read_item(el_list, pos, false, module, + LLVM::is_llvm_struct(el_asr_type)) ; + llvm::Value *iseq =llvm_utils->is_equal_by_value(el, + item, module, el_asr_type) ; + LLVM::CreateStore(*builder, iseq, is_el_matching_ptr); + }); + + llvm::Value *is_el_matching = LLVM::CreateLoad(*builder, is_el_matching_ptr); + if (check_if_exists) { + return is_el_matching; + } llvm_utils->create_if_else(is_el_matching, []() {}, [&]() { if (throw_key_error) { @@ -6484,11 +6531,13 @@ namespace LCompilers { exit(context, module, *builder, exit_code); } }); + + return nullptr; } - void LLVMSetSeparateChaining::resolve_collision_for_read_with_bound_check( + llvm::Value* LLVMSetSeparateChaining::resolve_collision_for_read_with_bound_check( llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, - llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) { + llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error, bool check_if_exists) { /** * C++ equivalent: * @@ -6515,6 +6564,10 @@ namespace LCompilers { llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))) ); + if (check_if_exists) { + return does_el_exist; + } + llvm_utils->create_if_else(does_el_exist, []() {}, [&]() { if (throw_key_error) { std::string message = "The set does not contain the specified element"; @@ -6527,6 +6580,8 @@ namespace LCompilers { exit(context, module, *builder, exit_code); } }); + + return nullptr; } void LLVMSetLinearProbing::remove_item( diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index 869aef52e7..0ea2644e96 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -589,7 +589,7 @@ namespace LCompilers { virtual llvm::Value* resolve_collision_for_read_with_bound_check(llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, - ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type) = 0; + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, bool check_if_exists = false) = 0; virtual llvm::Value* resolve_collision_for_read_with_default(llvm::Value* dict, llvm::Value* key_hash, @@ -644,6 +644,7 @@ namespace LCompilers { virtual void set_is_dict_present(bool value); + virtual void get_elements_list(llvm::Value* dict, llvm::Value* elements_list, ASR::ttype_t* key_asr_type, @@ -700,7 +701,7 @@ namespace LCompilers { llvm::Value* resolve_collision_for_read_with_bound_check(llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, - ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type); + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, bool check_if_exists = false); llvm::Value* resolve_collision_for_read_with_default(llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, @@ -738,6 +739,7 @@ namespace LCompilers { llvm::Value* len(llvm::Value* dict); + void get_elements_list(llvm::Value* dict, llvm::Value* elements_list, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, llvm::Module& module, @@ -772,13 +774,14 @@ namespace LCompilers { llvm::Value* resolve_collision_for_read_with_bound_check(llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, - ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type); + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, bool check_if_exists = false); llvm::Value* resolve_collision_for_read_with_default(llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, llvm::Value *def_value); + virtual ~LLVMDictOptimizedLinearProbing(); }; @@ -849,7 +852,7 @@ namespace LCompilers { llvm::Value* resolve_collision_for_read_with_bound_check(llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, - ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type); + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, bool check_if_exists = false); llvm::Value* resolve_collision_for_read_with_default(llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, @@ -886,6 +889,7 @@ namespace LCompilers { llvm::Value* len(llvm::Value* dict); + void get_elements_list(llvm::Value* dict, llvm::Value* elements_list, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, llvm::Module& module, @@ -965,9 +969,9 @@ namespace LCompilers { std::map>& name2memidx); virtual - void resolve_collision_for_read_with_bound_check( + llvm::Value* resolve_collision_for_read_with_bound_check( llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, - llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) = 0; + llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error, bool check_if_exists = false) = 0; virtual void remove_item( @@ -1036,9 +1040,9 @@ namespace LCompilers { llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type, std::map>& name2memidx); - void resolve_collision_for_read_with_bound_check( + llvm::Value* resolve_collision_for_read_with_bound_check( llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, - llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error); + llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error, bool check_if_exists = false); void remove_item( llvm::Value* set, llvm::Value* el, @@ -1117,9 +1121,9 @@ namespace LCompilers { llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type, std::map>& name2memidx); - void resolve_collision_for_read_with_bound_check( + llvm::Value* resolve_collision_for_read_with_bound_check( llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, - llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error); + llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error, bool check_if_exists = false); void remove_item( llvm::Value* set, llvm::Value* el, diff --git a/src/lpython/parser/parser.yy b/src/lpython/parser/parser.yy index 7d773b962f..6658c3eac7 100644 --- a/src/lpython/parser/parser.yy +++ b/src/lpython/parser/parser.yy @@ -1228,8 +1228,9 @@ expr | expr ">=" expr { $$ = COMPARE($1, GtE, $3, @$); } | expr "is" expr { $$ = COMPARE($1, Is, $3, @$); } | expr "is not" expr { $$ = COMPARE($1, IsNot, $3, @$); } - | expr "in" expr { $$ = COMPARE($1, In, $3, @$); } - | expr "not in" expr { $$ = COMPARE($1, NotIn, $3, @$); } + + | expr "in" expr { $$ = MEMBERSHIP($1, In, $3, @$); } + | expr "not in" expr { $$ = MEMBERSHIP($1, NotIn, $3, @$); } | expr "and" expr { $$ = BOOLOP($1, And, $3, @$); } | expr "or" expr { $$ = BOOLOP($1, Or, $3, @$); } diff --git a/src/lpython/parser/semantics.h b/src/lpython/parser/semantics.h index 9a41278783..7fd17cc566 100644 --- a/src/lpython/parser/semantics.h +++ b/src/lpython/parser/semantics.h @@ -719,6 +719,8 @@ static inline ast_t* BOOLOP_01(Allocator &al, Location &loc, #define UNARY(x, op, l) make_UnaryOp_t(p.m_a, l, unaryopType::op, EXPR(x)) #define COMPARE(x, op, y, l) make_Compare_t(p.m_a, l, \ EXPR(x), cmpopType::op, EXPRS(A2LIST(p.m_a, y)), 1) +#define MEMBERSHIP(x, op, y, l) make_Membership_t(p.m_a, l, \ + EXPR(x), membershipopType::op, EXPR(y)) static inline ast_t* concat_string(Allocator &al, Location &l, expr_t *string, std::string str, expr_t *string_literal) { diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 1364763135..f7067e3e82 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -6575,6 +6575,118 @@ class BodyVisitor : public CommonVisitor { } } + void visit_Membership(const AST::Membership_t &x) { + this->visit_expr(*x.m_left); + ASR::expr_t *left = ASRUtils::EXPR(tmp); + this->visit_expr(*x.m_right); + ASR::expr_t *right = ASRUtils::EXPR(tmp); + + ASR::ttype_t *left_type = ASRUtils::expr_type(left); + ASR::ttype_t *right_type = ASRUtils::expr_type(right); + + ASR::expr_t *value = nullptr; + ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_Logical_t( + al, x.base.base.loc, 4)); + if (ASR::is_a(*right_type)) { + ASR::ttype_t *contained_type = ASRUtils::get_contained_type(right_type); + if (!ASRUtils::check_equal_type(left_type, contained_type)) { + std::string ltype = ASRUtils::type_to_str_python(ASRUtils::expr_type(left)); + std::string rtype = ASRUtils::type_to_str_python(ASRUtils::expr_type(right)); + diag.add(diag::Diagnostic( + "Type mismatch in comparison operator, the types must be compatible", + diag::Level::Error, diag::Stage::Semantic, { + diag::Label("type mismatch ('" + ltype + "' and '" + rtype + "')", + {left->base.loc, right->base.loc}) + }) + ); + throw SemanticAbort(); + } + + tmp = ASR::make_ListContains_t(al, x.base.base.loc, left, right, type, value); + } else if (ASRUtils::is_character(*right_type)) { + if (!ASRUtils::check_equal_type(left_type, right_type)) { + std::string ltype = ASRUtils::type_to_str_python(ASRUtils::expr_type(left)); + std::string rtype = ASRUtils::type_to_str_python(ASRUtils::expr_type(right)); + diag.add(diag::Diagnostic( + "Type mismatch in comparison operator, the types must be compatible", + diag::Level::Error, diag::Stage::Semantic, { + diag::Label("type mismatch ('" + ltype + "' and '" + rtype + "')", + {left->base.loc, right->base.loc}) + }) + ); + throw SemanticAbort(); + } + if (ASRUtils::expr_value(left) != nullptr && ASRUtils::expr_value(right) != nullptr) { + char* left_value = ASR::down_cast( + ASRUtils::expr_value(left))->m_s; + char* right_value = ASR::down_cast( + ASRUtils::expr_value(right))->m_s; + std::string left_str = std::string(left_value); + std::string right_str = std::string(right_value); + + bool result = right_str.find(left_str) != std::string::npos; + + value = ASR::down_cast(ASR::make_LogicalConstant_t( + al, x.base.base.loc, result, type)); + } + tmp = make_StringContains_t(al, x.base.base.loc, left, right, type, value); + } else if (ASR::is_a(*right_type)) { + ASR::ttype_t *contained_type = ASRUtils::get_contained_type(right_type); + if (!ASRUtils::check_equal_type(left_type, contained_type)) { + std::string ltype = ASRUtils::type_to_str_python(ASRUtils::expr_type(left)); + std::string rtype = ASRUtils::type_to_str_python(ASRUtils::expr_type(right)); + diag.add(diag::Diagnostic( + "Type mismatch in comparison operator, the types must be compatible", + diag::Level::Error, diag::Stage::Semantic, { + diag::Label("type mismatch ('" + ltype + "' and '" + rtype + "')", + {left->base.loc, right->base.loc}) + }) + ); + throw SemanticAbort(); + } + + tmp = ASR::make_TupleContains_t(al, x.base.base.loc, left, right, type, value); + } else if (ASR::is_a(*right_type)) { + ASR::ttype_t *contained_type = ASRUtils::get_contained_type(right_type); + if (!ASRUtils::check_equal_type(left_type, contained_type)) { + std::string ltype = ASRUtils::type_to_str_python(ASRUtils::expr_type(left)); + std::string rtype = ASRUtils::type_to_str_python(ASRUtils::expr_type(right)); + diag.add(diag::Diagnostic( + "Type mismatch in comparison operator, the types must be compatible", + diag::Level::Error, diag::Stage::Semantic, { + diag::Label("type mismatch ('" + ltype + "' and '" + rtype + "')", + {left->base.loc, right->base.loc}) + }) + ); + throw SemanticAbort(); + } + + tmp = ASR::make_SetContains_t(al, x.base.base.loc, left, right, type, value); + } else if (ASR::is_a(*right_type)) { + ASR::ttype_t *contained_type = ASRUtils::get_contained_type(right_type); + if (!ASRUtils::check_equal_type(left_type, contained_type)) { + std::string ltype = ASRUtils::type_to_str_python(ASRUtils::expr_type(left)); + std::string rtype = ASRUtils::type_to_str_python(ASRUtils::expr_type(right)); + diag.add(diag::Diagnostic( + "Type mismatch in comparison operator, the types must be compatible", + diag::Level::Error, diag::Stage::Semantic, { + diag::Label("type mismatch ('" + ltype + "' and '" + rtype + "')", + {left->base.loc, right->base.loc}) + }) + ); + throw SemanticAbort(); + } + + tmp = ASR::make_DictContains_t(al, x.base.base.loc, left, right, type, value); + } else { + throw SemanticError("Membership operator is only defined for strings, lists, tuples, sets and dictionaries.", x.base.base.loc); + } + + if (x.m_op == AST::membershipopType::NotIn) { + tmp = ASR::make_LogicalNot_t(al, x.base.base.loc, ASRUtils::EXPR(tmp), type, nullptr); + } + } + void visit_ConstantEllipsis(const AST::ConstantEllipsis_t &/*x*/) { tmp = nullptr; } diff --git a/tests/reference/ast_new-comprehension1-69cf2af.json b/tests/reference/ast_new-comprehension1-69cf2af.json index 1e1b460b96..5bda7d0179 100644 --- a/tests/reference/ast_new-comprehension1-69cf2af.json +++ b/tests/reference/ast_new-comprehension1-69cf2af.json @@ -6,7 +6,7 @@ "outfile": null, "outfile_hash": null, "stdout": "ast_new-comprehension1-69cf2af.stdout", - "stdout_hash": "dd4d6e66646c90be9ebc7070964a2f42ca21d5c782bfddbf89ce854b", + "stdout_hash": "93c8b1b23bf7419338573fda46fd07fc907c0637e0985124bd9f49b1", "stderr": null, "stderr_hash": null, "returncode": 0 diff --git a/tests/reference/ast_new-comprehension1-69cf2af.stdout b/tests/reference/ast_new-comprehension1-69cf2af.stdout index 83f9d88428..6506a37763 100644 --- a/tests/reference/ast_new-comprehension1-69cf2af.stdout +++ b/tests/reference/ast_new-comprehension1-69cf2af.stdout @@ -360,13 +360,13 @@ ) [(BoolOp And - [(Compare + [(Membership (Name i Load ) NotIn - [(List + (List [(ConstantInt 3 () @@ -380,18 +380,18 @@ () )] Load - )] + ) ) - (Compare + (Membership (Name i Load ) In - [(Name + (Name list3 Load - )] + ) )] )] 0)] @@ -641,16 +641,16 @@ )] [] ) - [(Compare + [(Membership (Name i Load ) NotIn - [(Name + (Name axis Load - )] + ) )] 0)] )] diff --git a/tests/reference/ast_new-conditional_expr1-07ccb9e.json b/tests/reference/ast_new-conditional_expr1-07ccb9e.json index e90a4839bd..c3a1c95270 100644 --- a/tests/reference/ast_new-conditional_expr1-07ccb9e.json +++ b/tests/reference/ast_new-conditional_expr1-07ccb9e.json @@ -6,7 +6,7 @@ "outfile": null, "outfile_hash": null, "stdout": "ast_new-conditional_expr1-07ccb9e.stdout", - "stdout_hash": "92adfc3fb76aa117fdee246478837474332ec5de543e164920e3ec40", + "stdout_hash": "dfedb3fe94d880e8827e7569eabc8d1f0e975060db35d4b736e1361d", "stderr": null, "stderr_hash": null, "returncode": 0 diff --git a/tests/reference/ast_new-conditional_expr1-07ccb9e.stdout b/tests/reference/ast_new-conditional_expr1-07ccb9e.stdout index 74739c7294..2d53752fa7 100644 --- a/tests/reference/ast_new-conditional_expr1-07ccb9e.stdout +++ b/tests/reference/ast_new-conditional_expr1-07ccb9e.stdout @@ -327,16 +327,16 @@ (Expr (Call (IfExp - (Compare + (Membership (Name tktype Load ) In - [(Name + (Name whentrue Load - )] + ) ) (Attribute (Name @@ -890,16 +890,16 @@ Load ) (IfExp - (Compare + (Membership (Name start Load ) In - [(Name + (Name labels Load - )] + ) ) (ConstantStr ":" diff --git a/tests/reference/ast_new-for2-af08901.json b/tests/reference/ast_new-for2-af08901.json index ff9c17f689..6e65b70d3a 100644 --- a/tests/reference/ast_new-for2-af08901.json +++ b/tests/reference/ast_new-for2-af08901.json @@ -6,7 +6,7 @@ "outfile": null, "outfile_hash": null, "stdout": "ast_new-for2-af08901.stdout", - "stdout_hash": "ac6e50517c5d609747b66c75e15bfa69ada7f0f41ebeb943da9b3167", + "stdout_hash": "40d6e5ac6ca4865a1b3b257fb4c7f4b2df3b6d8f52e7f38d66e72487", "stderr": null, "stderr_hash": null, "returncode": 0 diff --git a/tests/reference/ast_new-for2-af08901.stdout b/tests/reference/ast_new-for2-af08901.stdout index c495f51677..9b75c2b12e 100644 --- a/tests/reference/ast_new-for2-af08901.stdout +++ b/tests/reference/ast_new-for2-af08901.stdout @@ -169,16 +169,16 @@ i Store ) - (Compare + (Membership (Name a Load ) In - [(Name + (Name list1 Load - )] + ) ) [(Pass)] [] @@ -194,16 +194,16 @@ Load ) [(If - (Compare + (Membership (Name item Load ) In - [(Name + (Name list2 Load - )] + ) ) [(Pass)] [] @@ -216,39 +216,39 @@ Or [(BoolOp And - [(Compare + [(Membership (Name a Load ) In - [(Name + (Name list1 Load - )] + ) ) - (Compare + (Membership (Name b Load ) NotIn - [(Name + (Name list2 Load - )] + ) )] ) - (Compare + (Membership (Name c Load ) In - [(Name + (Name list3 Load - )] + ) )] ) [(Pass)] diff --git a/tests/reference/ast_new-if2-c3b6022.json b/tests/reference/ast_new-if2-c3b6022.json index f9c4d553f4..d154a2684e 100644 --- a/tests/reference/ast_new-if2-c3b6022.json +++ b/tests/reference/ast_new-if2-c3b6022.json @@ -6,7 +6,7 @@ "outfile": null, "outfile_hash": null, "stdout": "ast_new-if2-c3b6022.stdout", - "stdout_hash": "cef89f96f75c68381a475911818e03cbcb78bff27d91b5d356fc667b", + "stdout_hash": "f87ec76a617cdbffb26b6f30b0acfdec3fde29a027ae6bcc1bf03a14", "stderr": null, "stderr_hash": null, "returncode": 0 diff --git a/tests/reference/ast_new-if2-c3b6022.stdout b/tests/reference/ast_new-if2-c3b6022.stdout index 584a5f9094..69bc755dd7 100644 --- a/tests/reference/ast_new-if2-c3b6022.stdout +++ b/tests/reference/ast_new-if2-c3b6022.stdout @@ -131,13 +131,13 @@ () ) (If - (Compare + (Membership (Name a Load ) NotIn - [(List + (List [(ConstantInt 1 () @@ -147,20 +147,20 @@ () )] Load - )] + ) ) [(Pass)] [] ) (If - (Compare - (Compare + (Membership + (Membership (Name a Load ) NotIn - [(List + (List [(ConstantInt 1 () @@ -170,10 +170,10 @@ () )] Load - )] + ) ) NotIn - [(List + (List [(ConstantBool .true. () @@ -183,19 +183,19 @@ () )] Load - )] + ) ) [(Pass)] [] ) (If - (Compare + (Membership (Name field Load ) In - [(List + (List [(ConstantStr "vararg" () @@ -205,7 +205,7 @@ () )] Load - )] + ) ) [(If (Compare @@ -224,16 +224,16 @@ [] ) (If - (Compare + (Membership (Name a Load ) In - [(Name + (Name list1 Load - )] + ) ) [(Pass)] [] diff --git a/tests/reference/ast_new-statements1-e081093.json b/tests/reference/ast_new-statements1-e081093.json index 5676cb70c4..4615757975 100644 --- a/tests/reference/ast_new-statements1-e081093.json +++ b/tests/reference/ast_new-statements1-e081093.json @@ -6,7 +6,7 @@ "outfile": null, "outfile_hash": null, "stdout": "ast_new-statements1-e081093.stdout", - "stdout_hash": "9425fb51c6f0e2ed284e0ba59bb2efee1a86541d77150d20c02fd5fc", + "stdout_hash": "bc316e311b5cc06fc517c2f40759673385f44af66b32bb5f85e0867a", "stderr": null, "stderr_hash": null, "returncode": 0 diff --git a/tests/reference/ast_new-statements1-e081093.stdout b/tests/reference/ast_new-statements1-e081093.stdout index 421e1c8067..adac7b7c1b 100644 --- a/tests/reference/ast_new-statements1-e081093.stdout +++ b/tests/reference/ast_new-statements1-e081093.stdout @@ -1015,26 +1015,26 @@ ) ) (Expr - (Compare + (Membership (ConstantStr "hello" () ) In - [(Name + (Name x Load - )] + ) ) ) (Expr - (Compare + (Membership (ConstantStr "a" () ) In - [(Call + (Call (Attribute (Name a @@ -1045,20 +1045,20 @@ ) [] [] - )] + ) ) ) (Expr - (Compare + (Membership (ConstantStr "lo" () ) In - [(ConstantStr + (ConstantStr "hello" () - )] + ) ) ) (Expr @@ -1460,7 +1460,7 @@ bool Load ) - (Compare + (Membership (List [(Name x @@ -1469,13 +1469,13 @@ Load ) NotIn - [(List + (List [(Name y Load )] Load - )] + ) ) 1 ) @@ -1496,7 +1496,7 @@ output Store )] - (Compare + (Membership (List [(Name x @@ -1505,13 +1505,13 @@ Load ) NotIn - [(List + (List [(Name y Load )] Load - )] + ) ) () ) @@ -1561,7 +1561,7 @@ [] []) [(Return - (Compare + (Membership (List [(Name a @@ -1570,13 +1570,13 @@ Load ) In - [(List + (List [(Name b Load )] Load - )] + ) ) )] [] @@ -1614,7 +1614,7 @@ output Store )] - (Compare + (Membership (List [(Name a @@ -1623,13 +1623,13 @@ Load ) NotIn - [(List + (List [(Name b Load )] Load - )] + ) ) () ) @@ -1662,7 +1662,7 @@ output Store )] - (Compare + (Membership (List [(Name a @@ -1671,13 +1671,13 @@ Load ) NotIn - [(List + (List [(Name b Load )] Load - )] + ) ) () ) diff --git a/tests/reference/ast_new-statements2-c4cdc5f.json b/tests/reference/ast_new-statements2-c4cdc5f.json index efb47d87e7..2d579649cd 100644 --- a/tests/reference/ast_new-statements2-c4cdc5f.json +++ b/tests/reference/ast_new-statements2-c4cdc5f.json @@ -6,7 +6,7 @@ "outfile": null, "outfile_hash": null, "stdout": "ast_new-statements2-c4cdc5f.stdout", - "stdout_hash": "d79c678d3b5de63e5d424a2015595bfc3a686fc5c7ba0802aed6f3af", + "stdout_hash": "5df7c032836575768db845fd1aba55609d5691833e3439d5c077ebae", "stderr": null, "stderr_hash": null, "returncode": 0 diff --git a/tests/reference/ast_new-statements2-c4cdc5f.stdout b/tests/reference/ast_new-statements2-c4cdc5f.stdout index c18d65316e..49de84c2a4 100644 --- a/tests/reference/ast_new-statements2-c4cdc5f.stdout +++ b/tests/reference/ast_new-statements2-c4cdc5f.stdout @@ -232,7 +232,7 @@ ) ) (If - (Compare + (Membership (Subscript (Subscript (Name @@ -256,7 +256,7 @@ Load ) In - [(List + (List [(ConstantStr "" () @@ -266,7 +266,7 @@ () )] Load - )] + ) ) [(Pass)] [] @@ -387,16 +387,16 @@ ) ) (If - (Compare + (Membership (Name x Load ) NotIn - [(Name + (Name z Load - )] + ) ) [(Expr (ConstantEllipsis @@ -406,16 +406,16 @@ [] ) (If - (Compare + (Membership (Name x Load ) NotIn - [(Name + (Name z Load - )] + ) ) [(Expr (ConstantEllipsis @@ -425,16 +425,16 @@ [] ) (If - (Compare + (Membership (Name x Load ) NotIn - [(Name + (Name z Load - )] + ) ) [(Expr (ConstantEllipsis @@ -444,16 +444,16 @@ [] ) (If - (Compare + (Membership (Name x Load ) NotIn - [(Name + (Name z Load - )] + ) ) [(Expr (ConstantEllipsis