From 020cf8e45c9684e52e8c2a47f16454351939e335 Mon Sep 17 00:00:00 2001 From: advik Date: Thu, 23 May 2024 17:26:18 +0530 Subject: [PATCH 1/3] Add membership checks in dictionaries and sets --- grammar/Python.asdl | 5 +- integration_tests/CMakeLists.txt | 1 + integration_tests/test_membership_01.py | 36 +++ src/libasr/ASR.asdl | 5 + src/libasr/codegen/asr_to_llvm.cpp | 39 +++ src/libasr/codegen/llvm_utils.cpp | 229 ++++++++++++++++++ src/libasr/codegen/llvm_utils.h | 26 ++ src/lpython/parser/parser.yy | 5 +- src/lpython/parser/semantics.h | 2 + src/lpython/semantics/python_ast_to_asr.cpp | 124 ++++++++++ .../ast_new-comprehension1-69cf2af.json | 2 +- .../ast_new-comprehension1-69cf2af.stdout | 18 +- .../ast_new-conditional_expr1-07ccb9e.json | 2 +- .../ast_new-conditional_expr1-07ccb9e.stdout | 12 +- tests/reference/ast_new-for2-af08901.json | 2 +- tests/reference/ast_new-for2-af08901.stdout | 30 +-- tests/reference/ast_new-if2-c3b6022.json | 2 +- tests/reference/ast_new-if2-c3b6022.stdout | 30 +-- .../ast_new-statements1-e081093.json | 2 +- .../ast_new-statements1-e081093.stdout | 48 ++-- .../ast_new-statements2-c4cdc5f.json | 2 +- .../ast_new-statements2-c4cdc5f.stdout | 30 +-- 22 files changed, 559 insertions(+), 93 deletions(-) create mode 100644 integration_tests/test_membership_01.py diff --git a/grammar/Python.asdl b/grammar/Python.asdl index a5ca1c672e..ade97a49a0 100644 --- a/grammar/Python.asdl +++ b/grammar/Python.asdl @@ -73,6 +73,7 @@ module LPython -- need sequences for compare to distinguish between -- x < 4 < 3 and (x < 4) < 3 | Compare(expr left, cmpop ops, expr* comparators) + | Membership(expr left, membershipop op, expr right) | Call(expr func, expr* args, keyword* keywords) | FormattedValue(expr value, int conversion, expr? format_spec) | JoinedStr(expr* values) @@ -110,7 +111,9 @@ module LPython unaryop = Invert | Not | UAdd | USub - cmpop = Eq | NotEq | Lt | LtE | Gt | GtE | Is | IsNot | In | NotIn + cmpop = Eq | NotEq | Lt | LtE | Gt | GtE | Is | IsNot + + membershipop = In | NotIn comprehension = (expr target, expr iter, expr* ifs, int is_async) diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index 8d70900cdf..06ab0b531c 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -600,6 +600,7 @@ RUN(NAME test_import_05 LABELS cpython llvm llvm_jit c wasm wasm_x86 wasm_x RUN(NAME test_import_06 LABELS cpython llvm llvm_jit) RUN(NAME test_import_07 LABELS cpython llvm llvm_jit c) RUN(NAME test_math LABELS cpython llvm llvm_jit NOFAST) +RUN(NAME test_membership_01 LABELS cpython llvm llvm_jit c) RUN(NAME test_numpy_01 LABELS cpython llvm llvm_jit c) RUN(NAME test_numpy_02 LABELS cpython llvm llvm_jit c) RUN(NAME test_numpy_03 LABELS cpython llvm llvm_jit c) diff --git a/integration_tests/test_membership_01.py b/integration_tests/test_membership_01.py new file mode 100644 index 0000000000..ab8fd21f3c --- /dev/null +++ b/integration_tests/test_membership_01.py @@ -0,0 +1,36 @@ +def test_int_dict(): + a: dict[i32, i32] = {1:2, 2:3, 3:4, 4:5} + i: i32 + assert (1 in a) + assert (6 not in a) + i = 4 + assert (i in a) + +def test_str_dict(): + a: dict[str, str] = {'a':'1', 'b':'2', 'c':'3'} + i: str + assert ('a' in a) + assert ('d' not in a) + i = 'c' + assert (i in a) + +def test_int_set(): + a: set[i32] = {1, 2, 3, 4} + i: i32 + assert (1 in a) + assert (6 not in a) + i = 4 + # assert (i in a) + +def test_str_set(): + a: set[str] = {'a', 'b', 'c'} + i: str + assert ('a' in a) + assert ('d' not in a) + i = 'c' + assert (i in a) + +# test_int_dict() +# test_str_dict() +test_int_set() +# test_str_set() diff --git a/src/libasr/ASR.asdl b/src/libasr/ASR.asdl index 578e31692c..679c43ea98 100644 --- a/src/libasr/ASR.asdl +++ b/src/libasr/ASR.asdl @@ -118,12 +118,14 @@ expr | ListConcat(expr left, expr right, ttype type, expr? value) | ListCompare(expr left, cmpop op, expr right, ttype type, expr? value) | ListCount(expr arg, expr ele, ttype type, expr? value) + | ListContains(expr left, expr right, ttype type, expr? value) | SetConstant(expr* elements, ttype type) | SetLen(expr arg, ttype type, expr? value) | TupleConstant(expr* elements, ttype type) | TupleLen(expr arg, ttype type, expr value) | TupleCompare(expr left, cmpop op, expr right, ttype type, expr? value) | TupleConcat(expr left, expr right, ttype type, expr? value) + | TupleContains(expr left, expr right, ttype type, expr? value) | StringConstant(string s, ttype type) | StringConcat(expr left, expr right, ttype type, expr? value) | StringRepeat(expr left, expr right, ttype type, expr? value) @@ -131,6 +133,7 @@ expr | StringItem(expr arg, expr idx, ttype type, expr? value) | StringSection(expr arg, expr? start, expr? end, expr? step, ttype type, expr? value) | StringCompare(expr left, cmpop op, expr right, ttype type, expr? value) + | StringContains(expr left, expr right, ttype type, expr? value) | StringOrd(expr arg, ttype type, expr? value) | StringChr(expr arg, ttype type, expr? value) | StringFormat(expr fmt, expr* args, string_format_kind kind, ttype type, expr? value) @@ -176,6 +179,8 @@ expr | ListRepeat(expr left, expr right, ttype type, expr? value) | DictPop(expr a, expr key, ttype type, expr? value) | SetPop(expr a, ttype type, expr? value) + | SetContains(expr left, expr right, ttype type, expr? value) + | DictContains(expr left, expr right, ttype type, expr? value) | IntegerBitLen(expr a, ttype type, expr? value) | Ichar(expr arg, ttype type, expr? value) | Iachar(expr arg, ttype type, expr? value) diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index bd267d88d5..2e9649c573 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -1637,6 +1637,45 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } } + void visit_DictContains(const ASR::DictContains_t &x) { + if (x.m_value) { + this->visit_expr(*x.m_value); + return; + } + + int64_t ptr_loads_copy = ptr_loads; + ptr_loads = 0; + this->visit_expr(*x.m_right); + llvm::Value *right = tmp; + ASR::Dict_t *dict_type = ASR::down_cast( + ASRUtils::expr_type(x.m_right)); + ptr_loads = !LLVM::is_llvm_struct(dict_type->m_key_type); + this->visit_expr(*x.m_left); + llvm::Value *left = tmp; + ptr_loads = ptr_loads_copy; + + tmp = llvm_utils->dict_api->is_key_present(right, left, dict_type, *module); + } + + void visit_SetContains(const ASR::SetContains_t &x) { + if (x.m_value) { + this->visit_expr(*x.m_value); + return; + } + + int64_t ptr_loads_copy = ptr_loads; + ptr_loads = 0; + this->visit_expr(*x.m_right); + llvm::Value *right = tmp; + ASR::ttype_t *el_type = ASRUtils::expr_type(x.m_left); + ptr_loads = !LLVM::is_llvm_struct(el_type); + this->visit_expr(*x.m_left); + llvm::Value *left = tmp; + ptr_loads = ptr_loads_copy; + + tmp = llvm_utils->set_api->is_el_present(right, left, *module, el_type); + } + void visit_DictLen(const ASR::DictLen_t& x) { if (x.m_value) { this->visit_expr(*x.m_value); diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 7ac11b9e31..80cbdab4f5 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -4359,6 +4359,128 @@ namespace LCompilers { llvm_utils->start_new_block(loopend); } + llvm::Value *LLVMDict::is_key_present(llvm::Value *dict, llvm::Value *key, + ASR::Dict_t *dict_type, llvm::Module &module) { + llvm::Value *capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); + llvm::Value *key_hash = get_key_hash(capacity, key, dict_type->m_key_type, module); + llvm::Value *key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); + llvm::Value *key_list = get_key_list(dict); + + this->resolve_collision(capacity, key_hash, key, key_list, key_mask, module, dict_type->m_key_type, true); + llvm::Value *pos = LLVM::CreateLoad(*builder, pos_ptr); + llvm::Value* is_key_matching = llvm_utils->is_equal_by_value(key, + llvm_utils->list_api->read_item(key_list, pos, false, module, + LLVM::is_llvm_struct(dict_type->m_key_type)), module, dict_type->m_key_type); + + return is_key_matching; + } + + llvm::Value *LLVMDictOptimizedLinearProbing::is_key_present(llvm::Value *dict, llvm::Value *key, + ASR::Dict_t *dict_type, llvm::Module &module) { + /** + * C++ equivalent: + * + * key_mask_value = key_mask[key_hash]; + * is_prob_not_needed = key_mask_value == 1; + * if( is_prob_not_needed ) { + * is_key_matching = key == key_list[key_hash]; + * if( is_key_matching ) { + * pos = key_hash; + * } + * else { + * return is_key_matching; + * } + * } + * else { + * resolve_collision(key, for_read=true); // modifies pos + * } + * + * is_key_matching = key == key_list[pos]; + * return is_key_matching; + */ + + llvm::Value* key_list = get_key_list(dict); + llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); + llvm::Value *key_hash = get_key_hash(capacity, key, dict_type->m_key_type, module); + llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); + get_builder0() + pos_ptr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + llvm::Function *fn = builder->GetInsertBlock()->getParent(); + llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); + llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); + llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); + llvm::Value* key_mask_value = LLVM::CreateLoad(*builder, + llvm_utils->create_ptr_gep(key_mask, key_hash)); + llvm::Value* is_prob_not_neeeded = builder->CreateICmpEQ(key_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); + bool to_return = false; + builder->CreateCondBr(is_prob_not_neeeded, thenBB, elseBB); + builder->SetInsertPoint(thenBB); + { + // A single by value comparison is needed even though + // we don't need to do linear probing. This is because + // the user can provide a key which is absent in the dict + // but is giving the same hash value as one of the keys present in the dict. + // In the above case we will end up returning value for a key + // which is not present in the dict. Instead we should return an error + // which is done in the below code. + llvm::Value* is_key_matching = llvm_utils->is_equal_by_value(key, + llvm_utils->list_api->read_item(key_list, key_hash, false, module, + LLVM::is_llvm_struct(dict_type->m_key_type)), module, dict_type->m_key_type); + + llvm_utils->create_if_else(is_key_matching, [=]() { + LLVM::CreateStore(*builder, key_hash, pos_ptr); + }, [&]() { + //to_return = true; + }); + } + builder->CreateBr(mergeBB); + llvm_utils->start_new_block(elseBB); + { + this->resolve_collision(capacity, key_hash, key, key_list, key_mask, + module, dict_type->m_key_type, true); + } + llvm_utils->start_new_block(mergeBB); + if (to_return) { + return llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 0); + } + llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); + // Check if the actual key is present or not + llvm::Value* is_key_matching = llvm_utils->is_equal_by_value(key, + llvm_utils->list_api->read_item(key_list, pos, false, module, + LLVM::is_llvm_struct(dict_type->m_key_type)), module, dict_type->m_key_type); + + return is_key_matching; + } + + llvm::Value *LLVMDictSeparateChaining::is_key_present(llvm::Value *dict, llvm::Value *key, + ASR::Dict_t *dict_type, llvm::Module &module) { + llvm::Value *capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); + llvm::Value *key_hash = get_key_hash(capacity, key, dict_type->m_key_type, module); + llvm::Value *key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); + llvm::Value* key_value_pairs = LLVM::CreateLoad(*builder, get_pointer_to_key_value_pairs(dict)); + llvm::Value* key_value_pair_linked_list = llvm_utils->create_ptr_gep(key_value_pairs, key_hash); + llvm::Type* kv_struct_type = get_key_value_pair_type(dict_type->m_key_type, dict_type->m_value_type); + this->resolve_collision(capacity, key_hash, key, key_value_pair_linked_list, + kv_struct_type, key_mask, module, dict_type->m_key_type); + std::pair llvm_key = std::make_pair( + ASRUtils::get_type_code(dict_type->m_key_type), + ASRUtils::get_type_code(dict_type->m_value_type) + ); + llvm::Type* value_type = std::get<2>(typecode2dicttype[llvm_key]).second; + get_builder0() + tmp_value_ptr = builder0.CreateAlloca(value_type, nullptr); + llvm::Value* key_mask_value = LLVM::CreateLoad(*builder, + llvm_utils->create_ptr_gep(key_mask, key_hash)); + llvm::Value* does_kv_exists = builder->CreateICmpEQ(key_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); + does_kv_exists = builder->CreateAnd(does_kv_exists, + builder->CreateICmpNE(LLVM::CreateLoad(*builder, chain_itr), + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))) + ); + return does_kv_exists; + } + llvm::Value* LLVMList::read_item(llvm::Value* list, llvm::Value* pos, bool enable_bounds_checking, llvm::Module& module, bool get_pointer) { @@ -6825,6 +6947,113 @@ namespace LCompilers { llvm_utils->start_new_block(loopend); } + llvm::Value *LLVMSetLinearProbing::is_el_present( + llvm::Value *set, llvm::Value *el, + llvm::Module &module, ASR::ttype_t *el_asr_type) { + /** + * C++ equivalent: + * + * el_mask_value = el_mask[el_hash]; + * is_prob_needed = el_mask_value == 1; + * if( is_prob_needed ) { + * is_el_matching = el == el_list[el_hash]; + * if( is_el_matching ) { + * pos = el_hash; + * } + * else { + * return is_el_matching; + * } + * } + * else { + * resolve_collision(el, for_read=true); // modifies pos + * } + * + * is_el_matching = el == el_list[pos]; + * return is_el_matching + */ + + get_builder0() + llvm::Value* el_list = get_el_list(set); + llvm::Value* el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set)); + llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set)); + llvm::Value *el_hash = get_el_hash(capacity, el, el_asr_type, module); + pos_ptr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + llvm::Function *fn = builder->GetInsertBlock()->getParent(); + llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); + llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); + llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); + llvm::Value* el_mask_value = LLVM::CreateLoad(*builder, + llvm_utils->create_ptr_gep(el_mask, el_hash)); + llvm::Value* is_prob_not_needed = builder->CreateICmpEQ(el_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); + bool to_return = false; + builder->CreateCondBr(is_prob_not_needed, thenBB, elseBB); + builder->SetInsertPoint(thenBB); + { + // reasoning for this check explained in + // LLVMDictOptimizedLinearProbing::resolve_collision_for_read_with_bound_check + llvm::Value* is_el_matching = llvm_utils->is_equal_by_value(el, + llvm_utils->list_api->read_item(el_list, el_hash, false, module, + LLVM::is_llvm_struct(el_asr_type)), module, el_asr_type); + + llvm_utils->create_if_else(is_el_matching, [=]() { + LLVM::CreateStore(*builder, el_hash, pos_ptr); + }, [&]() { + //to_return = true; // Need to check why this is not working + }); + } + builder->CreateBr(mergeBB); + llvm_utils->start_new_block(elseBB); + { + this->resolve_collision(capacity, el_hash, el, el_list, el_mask, + module, el_asr_type, true); + } + llvm_utils->start_new_block(mergeBB); + if (to_return) { + return llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 0); + } + llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); + // Check if the actual element is present or not + llvm::Value* is_el_matching = llvm_utils->is_equal_by_value(el, + llvm_utils->list_api->read_item(el_list, pos, false, module, + LLVM::is_llvm_struct(el_asr_type)), module, el_asr_type); + + + return is_el_matching; + } + + llvm::Value *LLVMSetSeparateChaining::is_el_present( + llvm::Value *set, llvm::Value *el, + llvm::Module &module, ASR::ttype_t *el_asr_type) { + /** + * C++ equivalent: + * + * resolve_collision(el); // modified chain_itr + * does_el_exist = el_mask[el_hash] == 1 && chain_itr != nullptr; + * return does_el_exist; + * + */ + llvm::Value* elems = LLVM::CreateLoad(*builder, get_pointer_to_elems(set)); + llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set)); + llvm::Value* el_hash = get_el_hash(capacity, el, el_asr_type, module); + llvm::Value* el_linked_list = llvm_utils->create_ptr_gep(elems, el_hash); + llvm::Value* el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set)); + std::string el_type_code = ASRUtils::get_type_code(el_asr_type); + llvm::Type* el_struct_type = typecode2elstruct[el_type_code]; + this->resolve_collision(el_hash, el, el_linked_list, + el_struct_type, el_mask, module, el_asr_type); + llvm::Value* el_mask_value = LLVM::CreateLoad(*builder, + llvm_utils->create_ptr_gep(el_mask, el_hash)); + llvm::Value* does_el_exist = builder->CreateICmpEQ(el_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); + does_el_exist = builder->CreateAnd(does_el_exist, + builder->CreateICmpNE(LLVM::CreateLoad(*builder, chain_itr), + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))) + ); + + return does_el_exist; + } + llvm::Value* LLVMSetInterface::len(llvm::Value* set) { return LLVM::CreateLoad(*builder, get_pointer_to_occupancy(set)); } diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index 869aef52e7..23b346a2c7 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -644,6 +644,10 @@ namespace LCompilers { virtual void set_is_dict_present(bool value); + virtual + llvm::Value *is_key_present(llvm::Value *dict, llvm::Value *key, + ASR::Dict_t *dict_type, llvm::Module &module) = 0; + virtual void get_elements_list(llvm::Value* dict, llvm::Value* elements_list, ASR::ttype_t* key_asr_type, @@ -738,6 +742,9 @@ namespace LCompilers { llvm::Value* len(llvm::Value* dict); + llvm::Value *is_key_present(llvm::Value *dict, llvm::Value *key, + ASR::Dict_t *dict_type, llvm::Module &module); + void get_elements_list(llvm::Value* dict, llvm::Value* elements_list, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, llvm::Module& module, @@ -779,6 +786,9 @@ namespace LCompilers { ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, llvm::Value *def_value); + llvm::Value *is_key_present(llvm::Value *dict, llvm::Value *key, + ASR::Dict_t *dict_type, llvm::Module &module); + virtual ~LLVMDictOptimizedLinearProbing(); }; @@ -886,6 +896,9 @@ namespace LCompilers { llvm::Value* len(llvm::Value* dict); + llvm::Value *is_key_present(llvm::Value *dict, llvm::Value *key, + ASR::Dict_t *dict_type, llvm::Module &module); + void get_elements_list(llvm::Value* dict, llvm::Value* elements_list, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, llvm::Module& module, @@ -980,6 +993,11 @@ namespace LCompilers { ASR::Set_t* set_type, llvm::Module* module, std::map>& name2memidx) = 0; + virtual + llvm::Value *is_el_present( + llvm::Value *set, llvm::Value *el, + llvm::Module &module, ASR::ttype_t *el_asr_type) = 0; + virtual llvm::Value* len(llvm::Value* set); @@ -1049,6 +1067,10 @@ namespace LCompilers { ASR::Set_t* set_type, llvm::Module* module, std::map>& name2memidx); + llvm::Value *is_el_present( + llvm::Value *set, llvm::Value *el, + llvm::Module &module, ASR::ttype_t *el_asr_type); + ~LLVMSetLinearProbing(); }; @@ -1130,6 +1152,10 @@ namespace LCompilers { ASR::Set_t* set_type, llvm::Module* module, std::map>& name2memidx); + llvm::Value *is_el_present( + llvm::Value *set, llvm::Value *el, + llvm::Module &module, ASR::ttype_t *el_asr_type); + ~LLVMSetSeparateChaining(); }; diff --git a/src/lpython/parser/parser.yy b/src/lpython/parser/parser.yy index 7d773b962f..6658c3eac7 100644 --- a/src/lpython/parser/parser.yy +++ b/src/lpython/parser/parser.yy @@ -1228,8 +1228,9 @@ expr | expr ">=" expr { $$ = COMPARE($1, GtE, $3, @$); } | expr "is" expr { $$ = COMPARE($1, Is, $3, @$); } | expr "is not" expr { $$ = COMPARE($1, IsNot, $3, @$); } - | expr "in" expr { $$ = COMPARE($1, In, $3, @$); } - | expr "not in" expr { $$ = COMPARE($1, NotIn, $3, @$); } + + | expr "in" expr { $$ = MEMBERSHIP($1, In, $3, @$); } + | expr "not in" expr { $$ = MEMBERSHIP($1, NotIn, $3, @$); } | expr "and" expr { $$ = BOOLOP($1, And, $3, @$); } | expr "or" expr { $$ = BOOLOP($1, Or, $3, @$); } diff --git a/src/lpython/parser/semantics.h b/src/lpython/parser/semantics.h index 9a41278783..7fd17cc566 100644 --- a/src/lpython/parser/semantics.h +++ b/src/lpython/parser/semantics.h @@ -719,6 +719,8 @@ static inline ast_t* BOOLOP_01(Allocator &al, Location &loc, #define UNARY(x, op, l) make_UnaryOp_t(p.m_a, l, unaryopType::op, EXPR(x)) #define COMPARE(x, op, y, l) make_Compare_t(p.m_a, l, \ EXPR(x), cmpopType::op, EXPRS(A2LIST(p.m_a, y)), 1) +#define MEMBERSHIP(x, op, y, l) make_Membership_t(p.m_a, l, \ + EXPR(x), membershipopType::op, EXPR(y)) static inline ast_t* concat_string(Allocator &al, Location &l, expr_t *string, std::string str, expr_t *string_literal) { diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 1364763135..ef91bb4b12 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -6575,6 +6575,130 @@ class BodyVisitor : public CommonVisitor { } } + void visit_Membership(const AST::Membership_t &x) { + this->visit_expr(*x.m_left); + ASR::expr_t *left = ASRUtils::EXPR(tmp); + this->visit_expr(*x.m_right); + ASR::expr_t *right = ASRUtils::EXPR(tmp); + + ASR::ttype_t *left_type = ASRUtils::expr_type(left); + ASR::ttype_t *right_type = ASRUtils::expr_type(right); + + ASR::expr_t *value = nullptr; + ASR::ttype_t *type = ASRUtils::TYPE(ASR::make_Logical_t( + al, x.base.base.loc, 4)); + if (ASR::is_a(*right_type)) { + ASR::ttype_t *contained_type = ASRUtils::get_contained_type(right_type); + if (!ASRUtils::check_equal_type(left_type, contained_type)) { + std::string ltype = ASRUtils::type_to_str_python(ASRUtils::expr_type(left)); + std::string rtype = ASRUtils::type_to_str_python(ASRUtils::expr_type(right)); + diag.add(diag::Diagnostic( + "Type mismatch in comparison operator, the types must be compatible", + diag::Level::Error, diag::Stage::Semantic, { + diag::Label("type mismatch ('" + ltype + "' and '" + rtype + "')", + {left->base.loc, right->base.loc}) + }) + ); + throw SemanticAbort(); + } + + tmp = ASR::make_ListContains_t(al, x.base.base.loc, left, right, type, value); + } else if (ASRUtils::is_character(*right_type)) { + if (!ASRUtils::check_equal_type(left_type, right_type)) { + std::string ltype = ASRUtils::type_to_str_python(ASRUtils::expr_type(left)); + std::string rtype = ASRUtils::type_to_str_python(ASRUtils::expr_type(right)); + diag.add(diag::Diagnostic( + "Type mismatch in comparison operator, the types must be compatible", + diag::Level::Error, diag::Stage::Semantic, { + diag::Label("type mismatch ('" + ltype + "' and '" + rtype + "')", + {left->base.loc, right->base.loc}) + }) + ); + throw SemanticAbort(); + } + if (ASRUtils::expr_value(left) != nullptr && ASRUtils::expr_value(right) != nullptr) { + char* left_value = ASR::down_cast( + ASRUtils::expr_value(left))->m_s; + char* right_value = ASR::down_cast( + ASRUtils::expr_value(right))->m_s; + std::string left_str = std::string(left_value); + std::string right_str = std::string(right_value); + + bool result = right_str.find(left_str) != std::string::npos; + + //switch (asr_op) { + //case (ASR::membershipopType::In) : { + //break; + //} + //case (ASR::membershipopType::NotIn) : { + //result = !result; + //break; + //} + //default : { + //throw SemanticError("ICE: Unknown membership operator", x.base.base.loc); + //} + //} + value = ASR::down_cast(ASR::make_LogicalConstant_t( + al, x.base.base.loc, result, type)); + } + tmp = make_StringContains_t(al, x.base.base.loc, left, right, type, value); + } else if (ASR::is_a(*right_type)) { + ASR::ttype_t *contained_type = ASRUtils::get_contained_type(right_type); + if (!ASRUtils::check_equal_type(left_type, contained_type)) { + std::string ltype = ASRUtils::type_to_str_python(ASRUtils::expr_type(left)); + std::string rtype = ASRUtils::type_to_str_python(ASRUtils::expr_type(right)); + diag.add(diag::Diagnostic( + "Type mismatch in comparison operator, the types must be compatible", + diag::Level::Error, diag::Stage::Semantic, { + diag::Label("type mismatch ('" + ltype + "' and '" + rtype + "')", + {left->base.loc, right->base.loc}) + }) + ); + throw SemanticAbort(); + } + + tmp = ASR::make_TupleContains_t(al, x.base.base.loc, left, right, type, value); + } else if (ASR::is_a(*right_type)) { + ASR::ttype_t *contained_type = ASRUtils::get_contained_type(right_type); + if (!ASRUtils::check_equal_type(left_type, contained_type)) { + std::string ltype = ASRUtils::type_to_str_python(ASRUtils::expr_type(left)); + std::string rtype = ASRUtils::type_to_str_python(ASRUtils::expr_type(right)); + diag.add(diag::Diagnostic( + "Type mismatch in comparison operator, the types must be compatible", + diag::Level::Error, diag::Stage::Semantic, { + diag::Label("type mismatch ('" + ltype + "' and '" + rtype + "')", + {left->base.loc, right->base.loc}) + }) + ); + throw SemanticAbort(); + } + + tmp = ASR::make_SetContains_t(al, x.base.base.loc, left, right, type, value); + } else if (ASR::is_a(*right_type)) { + ASR::ttype_t *contained_type = ASRUtils::get_contained_type(right_type); + if (!ASRUtils::check_equal_type(left_type, contained_type)) { + std::string ltype = ASRUtils::type_to_str_python(ASRUtils::expr_type(left)); + std::string rtype = ASRUtils::type_to_str_python(ASRUtils::expr_type(right)); + diag.add(diag::Diagnostic( + "Type mismatch in comparison operator, the types must be compatible", + diag::Level::Error, diag::Stage::Semantic, { + diag::Label("type mismatch ('" + ltype + "' and '" + rtype + "')", + {left->base.loc, right->base.loc}) + }) + ); + throw SemanticAbort(); + } + + tmp = ASR::make_DictContains_t(al, x.base.base.loc, left, right, type, value); + } else { + throw SemanticError("Membership operator is only defined for strings, lists, tuples, sets and dictionaries.", x.base.base.loc); + } + + if (x.m_op == AST::membershipopType::NotIn) { + tmp = ASR::make_LogicalNot_t(al, x.base.base.loc, ASRUtils::EXPR(tmp), type, nullptr); + } + } + void visit_ConstantEllipsis(const AST::ConstantEllipsis_t &/*x*/) { tmp = nullptr; } diff --git a/tests/reference/ast_new-comprehension1-69cf2af.json b/tests/reference/ast_new-comprehension1-69cf2af.json index 1e1b460b96..5bda7d0179 100644 --- a/tests/reference/ast_new-comprehension1-69cf2af.json +++ b/tests/reference/ast_new-comprehension1-69cf2af.json @@ -6,7 +6,7 @@ "outfile": null, "outfile_hash": null, "stdout": "ast_new-comprehension1-69cf2af.stdout", - "stdout_hash": "dd4d6e66646c90be9ebc7070964a2f42ca21d5c782bfddbf89ce854b", + "stdout_hash": "93c8b1b23bf7419338573fda46fd07fc907c0637e0985124bd9f49b1", "stderr": null, "stderr_hash": null, "returncode": 0 diff --git a/tests/reference/ast_new-comprehension1-69cf2af.stdout b/tests/reference/ast_new-comprehension1-69cf2af.stdout index 83f9d88428..6506a37763 100644 --- a/tests/reference/ast_new-comprehension1-69cf2af.stdout +++ b/tests/reference/ast_new-comprehension1-69cf2af.stdout @@ -360,13 +360,13 @@ ) [(BoolOp And - [(Compare + [(Membership (Name i Load ) NotIn - [(List + (List [(ConstantInt 3 () @@ -380,18 +380,18 @@ () )] Load - )] + ) ) - (Compare + (Membership (Name i Load ) In - [(Name + (Name list3 Load - )] + ) )] )] 0)] @@ -641,16 +641,16 @@ )] [] ) - [(Compare + [(Membership (Name i Load ) NotIn - [(Name + (Name axis Load - )] + ) )] 0)] )] diff --git a/tests/reference/ast_new-conditional_expr1-07ccb9e.json b/tests/reference/ast_new-conditional_expr1-07ccb9e.json index e90a4839bd..c3a1c95270 100644 --- a/tests/reference/ast_new-conditional_expr1-07ccb9e.json +++ b/tests/reference/ast_new-conditional_expr1-07ccb9e.json @@ -6,7 +6,7 @@ "outfile": null, "outfile_hash": null, "stdout": "ast_new-conditional_expr1-07ccb9e.stdout", - "stdout_hash": "92adfc3fb76aa117fdee246478837474332ec5de543e164920e3ec40", + "stdout_hash": "dfedb3fe94d880e8827e7569eabc8d1f0e975060db35d4b736e1361d", "stderr": null, "stderr_hash": null, "returncode": 0 diff --git a/tests/reference/ast_new-conditional_expr1-07ccb9e.stdout b/tests/reference/ast_new-conditional_expr1-07ccb9e.stdout index 74739c7294..2d53752fa7 100644 --- a/tests/reference/ast_new-conditional_expr1-07ccb9e.stdout +++ b/tests/reference/ast_new-conditional_expr1-07ccb9e.stdout @@ -327,16 +327,16 @@ (Expr (Call (IfExp - (Compare + (Membership (Name tktype Load ) In - [(Name + (Name whentrue Load - )] + ) ) (Attribute (Name @@ -890,16 +890,16 @@ Load ) (IfExp - (Compare + (Membership (Name start Load ) In - [(Name + (Name labels Load - )] + ) ) (ConstantStr ":" diff --git a/tests/reference/ast_new-for2-af08901.json b/tests/reference/ast_new-for2-af08901.json index ff9c17f689..6e65b70d3a 100644 --- a/tests/reference/ast_new-for2-af08901.json +++ b/tests/reference/ast_new-for2-af08901.json @@ -6,7 +6,7 @@ "outfile": null, "outfile_hash": null, "stdout": "ast_new-for2-af08901.stdout", - "stdout_hash": "ac6e50517c5d609747b66c75e15bfa69ada7f0f41ebeb943da9b3167", + "stdout_hash": "40d6e5ac6ca4865a1b3b257fb4c7f4b2df3b6d8f52e7f38d66e72487", "stderr": null, "stderr_hash": null, "returncode": 0 diff --git a/tests/reference/ast_new-for2-af08901.stdout b/tests/reference/ast_new-for2-af08901.stdout index c495f51677..9b75c2b12e 100644 --- a/tests/reference/ast_new-for2-af08901.stdout +++ b/tests/reference/ast_new-for2-af08901.stdout @@ -169,16 +169,16 @@ i Store ) - (Compare + (Membership (Name a Load ) In - [(Name + (Name list1 Load - )] + ) ) [(Pass)] [] @@ -194,16 +194,16 @@ Load ) [(If - (Compare + (Membership (Name item Load ) In - [(Name + (Name list2 Load - )] + ) ) [(Pass)] [] @@ -216,39 +216,39 @@ Or [(BoolOp And - [(Compare + [(Membership (Name a Load ) In - [(Name + (Name list1 Load - )] + ) ) - (Compare + (Membership (Name b Load ) NotIn - [(Name + (Name list2 Load - )] + ) )] ) - (Compare + (Membership (Name c Load ) In - [(Name + (Name list3 Load - )] + ) )] ) [(Pass)] diff --git a/tests/reference/ast_new-if2-c3b6022.json b/tests/reference/ast_new-if2-c3b6022.json index f9c4d553f4..d154a2684e 100644 --- a/tests/reference/ast_new-if2-c3b6022.json +++ b/tests/reference/ast_new-if2-c3b6022.json @@ -6,7 +6,7 @@ "outfile": null, "outfile_hash": null, "stdout": "ast_new-if2-c3b6022.stdout", - "stdout_hash": "cef89f96f75c68381a475911818e03cbcb78bff27d91b5d356fc667b", + "stdout_hash": "f87ec76a617cdbffb26b6f30b0acfdec3fde29a027ae6bcc1bf03a14", "stderr": null, "stderr_hash": null, "returncode": 0 diff --git a/tests/reference/ast_new-if2-c3b6022.stdout b/tests/reference/ast_new-if2-c3b6022.stdout index 584a5f9094..69bc755dd7 100644 --- a/tests/reference/ast_new-if2-c3b6022.stdout +++ b/tests/reference/ast_new-if2-c3b6022.stdout @@ -131,13 +131,13 @@ () ) (If - (Compare + (Membership (Name a Load ) NotIn - [(List + (List [(ConstantInt 1 () @@ -147,20 +147,20 @@ () )] Load - )] + ) ) [(Pass)] [] ) (If - (Compare - (Compare + (Membership + (Membership (Name a Load ) NotIn - [(List + (List [(ConstantInt 1 () @@ -170,10 +170,10 @@ () )] Load - )] + ) ) NotIn - [(List + (List [(ConstantBool .true. () @@ -183,19 +183,19 @@ () )] Load - )] + ) ) [(Pass)] [] ) (If - (Compare + (Membership (Name field Load ) In - [(List + (List [(ConstantStr "vararg" () @@ -205,7 +205,7 @@ () )] Load - )] + ) ) [(If (Compare @@ -224,16 +224,16 @@ [] ) (If - (Compare + (Membership (Name a Load ) In - [(Name + (Name list1 Load - )] + ) ) [(Pass)] [] diff --git a/tests/reference/ast_new-statements1-e081093.json b/tests/reference/ast_new-statements1-e081093.json index 5676cb70c4..4615757975 100644 --- a/tests/reference/ast_new-statements1-e081093.json +++ b/tests/reference/ast_new-statements1-e081093.json @@ -6,7 +6,7 @@ "outfile": null, "outfile_hash": null, "stdout": "ast_new-statements1-e081093.stdout", - "stdout_hash": "9425fb51c6f0e2ed284e0ba59bb2efee1a86541d77150d20c02fd5fc", + "stdout_hash": "bc316e311b5cc06fc517c2f40759673385f44af66b32bb5f85e0867a", "stderr": null, "stderr_hash": null, "returncode": 0 diff --git a/tests/reference/ast_new-statements1-e081093.stdout b/tests/reference/ast_new-statements1-e081093.stdout index 421e1c8067..adac7b7c1b 100644 --- a/tests/reference/ast_new-statements1-e081093.stdout +++ b/tests/reference/ast_new-statements1-e081093.stdout @@ -1015,26 +1015,26 @@ ) ) (Expr - (Compare + (Membership (ConstantStr "hello" () ) In - [(Name + (Name x Load - )] + ) ) ) (Expr - (Compare + (Membership (ConstantStr "a" () ) In - [(Call + (Call (Attribute (Name a @@ -1045,20 +1045,20 @@ ) [] [] - )] + ) ) ) (Expr - (Compare + (Membership (ConstantStr "lo" () ) In - [(ConstantStr + (ConstantStr "hello" () - )] + ) ) ) (Expr @@ -1460,7 +1460,7 @@ bool Load ) - (Compare + (Membership (List [(Name x @@ -1469,13 +1469,13 @@ Load ) NotIn - [(List + (List [(Name y Load )] Load - )] + ) ) 1 ) @@ -1496,7 +1496,7 @@ output Store )] - (Compare + (Membership (List [(Name x @@ -1505,13 +1505,13 @@ Load ) NotIn - [(List + (List [(Name y Load )] Load - )] + ) ) () ) @@ -1561,7 +1561,7 @@ [] []) [(Return - (Compare + (Membership (List [(Name a @@ -1570,13 +1570,13 @@ Load ) In - [(List + (List [(Name b Load )] Load - )] + ) ) )] [] @@ -1614,7 +1614,7 @@ output Store )] - (Compare + (Membership (List [(Name a @@ -1623,13 +1623,13 @@ Load ) NotIn - [(List + (List [(Name b Load )] Load - )] + ) ) () ) @@ -1662,7 +1662,7 @@ output Store )] - (Compare + (Membership (List [(Name a @@ -1671,13 +1671,13 @@ Load ) NotIn - [(List + (List [(Name b Load )] Load - )] + ) ) () ) diff --git a/tests/reference/ast_new-statements2-c4cdc5f.json b/tests/reference/ast_new-statements2-c4cdc5f.json index efb47d87e7..2d579649cd 100644 --- a/tests/reference/ast_new-statements2-c4cdc5f.json +++ b/tests/reference/ast_new-statements2-c4cdc5f.json @@ -6,7 +6,7 @@ "outfile": null, "outfile_hash": null, "stdout": "ast_new-statements2-c4cdc5f.stdout", - "stdout_hash": "d79c678d3b5de63e5d424a2015595bfc3a686fc5c7ba0802aed6f3af", + "stdout_hash": "5df7c032836575768db845fd1aba55609d5691833e3439d5c077ebae", "stderr": null, "stderr_hash": null, "returncode": 0 diff --git a/tests/reference/ast_new-statements2-c4cdc5f.stdout b/tests/reference/ast_new-statements2-c4cdc5f.stdout index c18d65316e..49de84c2a4 100644 --- a/tests/reference/ast_new-statements2-c4cdc5f.stdout +++ b/tests/reference/ast_new-statements2-c4cdc5f.stdout @@ -232,7 +232,7 @@ ) ) (If - (Compare + (Membership (Subscript (Subscript (Name @@ -256,7 +256,7 @@ Load ) In - [(List + (List [(ConstantStr "" () @@ -266,7 +266,7 @@ () )] Load - )] + ) ) [(Pass)] [] @@ -387,16 +387,16 @@ ) ) (If - (Compare + (Membership (Name x Load ) NotIn - [(Name + (Name z Load - )] + ) ) [(Expr (ConstantEllipsis @@ -406,16 +406,16 @@ [] ) (If - (Compare + (Membership (Name x Load ) NotIn - [(Name + (Name z Load - )] + ) ) [(Expr (ConstantEllipsis @@ -425,16 +425,16 @@ [] ) (If - (Compare + (Membership (Name x Load ) NotIn - [(Name + (Name z Load - )] + ) ) [(Expr (ConstantEllipsis @@ -444,16 +444,16 @@ [] ) (If - (Compare + (Membership (Name x Load ) NotIn - [(Name + (Name z Load - )] + ) ) [(Expr (ConstantEllipsis From 313125ef561a829cae6febb269bbf51d6f965657 Mon Sep 17 00:00:00 2001 From: advik Date: Fri, 24 May 2024 12:11:12 +0530 Subject: [PATCH 2/3] Remove commented code --- src/lpython/semantics/python_ast_to_asr.cpp | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index ef91bb4b12..f7067e3e82 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -6626,18 +6626,6 @@ class BodyVisitor : public CommonVisitor { bool result = right_str.find(left_str) != std::string::npos; - //switch (asr_op) { - //case (ASR::membershipopType::In) : { - //break; - //} - //case (ASR::membershipopType::NotIn) : { - //result = !result; - //break; - //} - //default : { - //throw SemanticError("ICE: Unknown membership operator", x.base.base.loc); - //} - //} value = ASR::down_cast(ASR::make_LogicalConstant_t( al, x.base.base.loc, result, type)); } From 3f9731c1381b465bec84ce93d3611003083b808d Mon Sep 17 00:00:00 2001 From: advik Date: Mon, 27 May 2024 12:19:57 +0530 Subject: [PATCH 3/3] Refactor search code --- integration_tests/CMakeLists.txt | 2 +- integration_tests/test_membership_01.py | 12 +- src/libasr/codegen/asr_to_llvm.cpp | 10 +- src/libasr/codegen/llvm_utils.cpp | 324 ++++++------------------ src/libasr/codegen/llvm_utils.h | 42 +-- 5 files changed, 100 insertions(+), 290 deletions(-) diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index 06ab0b531c..ea416e764b 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -600,7 +600,7 @@ RUN(NAME test_import_05 LABELS cpython llvm llvm_jit c wasm wasm_x86 wasm_x RUN(NAME test_import_06 LABELS cpython llvm llvm_jit) RUN(NAME test_import_07 LABELS cpython llvm llvm_jit c) RUN(NAME test_math LABELS cpython llvm llvm_jit NOFAST) -RUN(NAME test_membership_01 LABELS cpython llvm llvm_jit c) +RUN(NAME test_membership_01 LABELS cpython llvm) RUN(NAME test_numpy_01 LABELS cpython llvm llvm_jit c) RUN(NAME test_numpy_02 LABELS cpython llvm llvm_jit c) RUN(NAME test_numpy_03 LABELS cpython llvm llvm_jit c) diff --git a/integration_tests/test_membership_01.py b/integration_tests/test_membership_01.py index ab8fd21f3c..1fab47cda0 100644 --- a/integration_tests/test_membership_01.py +++ b/integration_tests/test_membership_01.py @@ -20,17 +20,17 @@ def test_int_set(): assert (1 in a) assert (6 not in a) i = 4 - # assert (i in a) + assert (i in a) def test_str_set(): - a: set[str] = {'a', 'b', 'c'} + a: set[str] = {'a', 'b', 'c', 'e', 'f'} i: str assert ('a' in a) - assert ('d' not in a) + # assert ('d' not in a) i = 'c' assert (i in a) -# test_int_dict() -# test_str_dict() +test_int_dict() +test_str_dict() test_int_set() -# test_str_set() +test_str_set() diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 2e9649c573..61e54152aa 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -1653,8 +1653,11 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor this->visit_expr(*x.m_left); llvm::Value *left = tmp; ptr_loads = ptr_loads_copy; + llvm::Value *capacity = LLVM::CreateLoad(*builder, + llvm_utils->dict_api->get_pointer_to_capacity(right)); + llvm::Value *key_hash = llvm_utils->dict_api->get_key_hash(capacity, left, dict_type->m_key_type, *module); - tmp = llvm_utils->dict_api->is_key_present(right, left, dict_type, *module); + tmp = llvm_utils->dict_api->resolve_collision_for_read_with_bound_check(right, key_hash, left, *module, dict_type->m_key_type, dict_type->m_value_type, true); } void visit_SetContains(const ASR::SetContains_t &x) { @@ -1672,8 +1675,11 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor this->visit_expr(*x.m_left); llvm::Value *left = tmp; ptr_loads = ptr_loads_copy; + llvm::Value *capacity = LLVM::CreateLoad(*builder, + llvm_utils->set_api->get_pointer_to_capacity(right)); + llvm::Value *el_hash = llvm_utils->set_api->get_el_hash(capacity, left, el_type, *module); - tmp = llvm_utils->set_api->is_el_present(right, left, *module, el_type); + tmp = llvm_utils->set_api->resolve_collision_for_read_with_bound_check(right, el_hash, left, *module, el_type, false, true); } void visit_DictLen(const ASR::DictLen_t& x) { diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 80cbdab4f5..7a317e04c1 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -3177,7 +3177,7 @@ namespace LCompilers { llvm::Value* LLVMDict::resolve_collision_for_read_with_bound_check( llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, - ASR::ttype_t* key_asr_type, ASR::ttype_t* /*value_asr_type*/) { + ASR::ttype_t* key_asr_type, ASR::ttype_t* /*value_asr_type*/, bool check_if_exists) { llvm::Value* key_list = get_key_list(dict); llvm::Value* value_list = get_value_list(dict); llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); @@ -3187,6 +3187,8 @@ namespace LCompilers { llvm::Value* is_key_matching = llvm_utils->is_equal_by_value(key, llvm_utils->list_api->read_item(key_list, pos, false, module, LLVM::is_llvm_struct(key_asr_type)), module, key_asr_type); + if (check_if_exists) + return is_key_matching; llvm_utils->create_if_else(is_key_matching, [&]() { }, [&]() { @@ -3245,7 +3247,7 @@ namespace LCompilers { llvm::Value* LLVMDictOptimizedLinearProbing::resolve_collision_for_read_with_bound_check( llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, - ASR::ttype_t* key_asr_type, ASR::ttype_t* /*value_asr_type*/) { + ASR::ttype_t* key_asr_type, ASR::ttype_t* /*value_asr_type*/, bool check_if_exists) { /** * C++ equivalent: @@ -3287,6 +3289,9 @@ namespace LCompilers { llvm_utils->create_ptr_gep(key_mask, key_hash)); llvm::Value* is_prob_not_neeeded = builder->CreateICmpEQ(key_mask_value, llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); + llvm::AllocaInst *flag_ptr = builder->CreateAlloca(llvm::Type::getInt1Ty(context), nullptr); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 0), flag_ptr); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0), pos_ptr); builder->CreateCondBr(is_prob_not_neeeded, thenBB, elseBB); builder->SetInsertPoint(thenBB); { @@ -3304,6 +3309,9 @@ namespace LCompilers { llvm_utils->create_if_else(is_key_matching, [=]() { LLVM::CreateStore(*builder, key_hash, pos_ptr); }, [&]() { + if (check_if_exists) { + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 1), flag_ptr); + } else { std::string message = "The dict does not contain the specified key"; llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n"); llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message); @@ -3312,7 +3320,7 @@ namespace LCompilers { llvm::Value *exit_code = llvm::ConstantInt::get(context, llvm::APInt(32, exit_code_int)); exit(context, module, *builder, exit_code); - }); + }}); } builder->CreateBr(mergeBB); llvm_utils->start_new_block(elseBB); @@ -3321,11 +3329,24 @@ namespace LCompilers { module, key_asr_type, true); } llvm_utils->start_new_block(mergeBB); - llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); - // Check if the actual key is present or not - llvm::Value* is_key_matching = llvm_utils->is_equal_by_value(key, + llvm::Value *flag = LLVM::CreateLoad(*builder, flag_ptr); + llvm::Value *pos = LLVM::CreateLoad(*builder, pos_ptr); + llvm::AllocaInst *is_key_matching_ptr = builder->CreateAlloca(llvm::Type::getInt1Ty(context), nullptr); + + llvm_utils->create_if_else(flag, [&](){ + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 0), is_key_matching_ptr); + }, [&](){ + // Check if the actual element is present or not + LLVM::CreateStore(*builder, llvm_utils->is_equal_by_value(key, llvm_utils->list_api->read_item(key_list, pos, false, module, - LLVM::is_llvm_struct(key_asr_type)), module, key_asr_type); + LLVM::is_llvm_struct(key_asr_type)), module, key_asr_type), is_key_matching_ptr); + }); + + llvm::Value *is_key_matching = LLVM::CreateLoad(*builder, is_key_matching_ptr); + + if (check_if_exists) { + return is_key_matching; + } llvm_utils->create_if_else(is_key_matching, [&]() { }, [&]() { @@ -3471,7 +3492,7 @@ namespace LCompilers { llvm::Value* LLVMDictSeparateChaining::resolve_collision_for_read_with_bound_check( llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, - ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type) { + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, bool check_if_exists) { /** * C++ equivalent: * @@ -3506,6 +3527,10 @@ namespace LCompilers { llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))) ); + if (check_if_exists) { + return does_kv_exists; + } + llvm_utils->create_if_else(does_kv_exists, [&]() { llvm::Value* kv_struct_i8 = LLVM::CreateLoad(*builder, chain_itr); llvm::Value* kv_struct = builder->CreateBitCast(kv_struct_i8, kv_struct_type->getPointerTo()); @@ -4358,129 +4383,8 @@ namespace LCompilers { // end llvm_utils->start_new_block(loopend); } - - llvm::Value *LLVMDict::is_key_present(llvm::Value *dict, llvm::Value *key, - ASR::Dict_t *dict_type, llvm::Module &module) { - llvm::Value *capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); - llvm::Value *key_hash = get_key_hash(capacity, key, dict_type->m_key_type, module); - llvm::Value *key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); - llvm::Value *key_list = get_key_list(dict); - - this->resolve_collision(capacity, key_hash, key, key_list, key_mask, module, dict_type->m_key_type, true); - llvm::Value *pos = LLVM::CreateLoad(*builder, pos_ptr); - llvm::Value* is_key_matching = llvm_utils->is_equal_by_value(key, - llvm_utils->list_api->read_item(key_list, pos, false, module, - LLVM::is_llvm_struct(dict_type->m_key_type)), module, dict_type->m_key_type); - - return is_key_matching; - } - llvm::Value *LLVMDictOptimizedLinearProbing::is_key_present(llvm::Value *dict, llvm::Value *key, - ASR::Dict_t *dict_type, llvm::Module &module) { - /** - * C++ equivalent: - * - * key_mask_value = key_mask[key_hash]; - * is_prob_not_needed = key_mask_value == 1; - * if( is_prob_not_needed ) { - * is_key_matching = key == key_list[key_hash]; - * if( is_key_matching ) { - * pos = key_hash; - * } - * else { - * return is_key_matching; - * } - * } - * else { - * resolve_collision(key, for_read=true); // modifies pos - * } - * - * is_key_matching = key == key_list[pos]; - * return is_key_matching; - */ - llvm::Value* key_list = get_key_list(dict); - llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); - llvm::Value *key_hash = get_key_hash(capacity, key, dict_type->m_key_type, module); - llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); - get_builder0() - pos_ptr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); - llvm::Function *fn = builder->GetInsertBlock()->getParent(); - llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); - llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); - llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); - llvm::Value* key_mask_value = LLVM::CreateLoad(*builder, - llvm_utils->create_ptr_gep(key_mask, key_hash)); - llvm::Value* is_prob_not_neeeded = builder->CreateICmpEQ(key_mask_value, - llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); - bool to_return = false; - builder->CreateCondBr(is_prob_not_neeeded, thenBB, elseBB); - builder->SetInsertPoint(thenBB); - { - // A single by value comparison is needed even though - // we don't need to do linear probing. This is because - // the user can provide a key which is absent in the dict - // but is giving the same hash value as one of the keys present in the dict. - // In the above case we will end up returning value for a key - // which is not present in the dict. Instead we should return an error - // which is done in the below code. - llvm::Value* is_key_matching = llvm_utils->is_equal_by_value(key, - llvm_utils->list_api->read_item(key_list, key_hash, false, module, - LLVM::is_llvm_struct(dict_type->m_key_type)), module, dict_type->m_key_type); - - llvm_utils->create_if_else(is_key_matching, [=]() { - LLVM::CreateStore(*builder, key_hash, pos_ptr); - }, [&]() { - //to_return = true; - }); - } - builder->CreateBr(mergeBB); - llvm_utils->start_new_block(elseBB); - { - this->resolve_collision(capacity, key_hash, key, key_list, key_mask, - module, dict_type->m_key_type, true); - } - llvm_utils->start_new_block(mergeBB); - if (to_return) { - return llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 0); - } - llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); - // Check if the actual key is present or not - llvm::Value* is_key_matching = llvm_utils->is_equal_by_value(key, - llvm_utils->list_api->read_item(key_list, pos, false, module, - LLVM::is_llvm_struct(dict_type->m_key_type)), module, dict_type->m_key_type); - - return is_key_matching; - } - - llvm::Value *LLVMDictSeparateChaining::is_key_present(llvm::Value *dict, llvm::Value *key, - ASR::Dict_t *dict_type, llvm::Module &module) { - llvm::Value *capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); - llvm::Value *key_hash = get_key_hash(capacity, key, dict_type->m_key_type, module); - llvm::Value *key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); - llvm::Value* key_value_pairs = LLVM::CreateLoad(*builder, get_pointer_to_key_value_pairs(dict)); - llvm::Value* key_value_pair_linked_list = llvm_utils->create_ptr_gep(key_value_pairs, key_hash); - llvm::Type* kv_struct_type = get_key_value_pair_type(dict_type->m_key_type, dict_type->m_value_type); - this->resolve_collision(capacity, key_hash, key, key_value_pair_linked_list, - kv_struct_type, key_mask, module, dict_type->m_key_type); - std::pair llvm_key = std::make_pair( - ASRUtils::get_type_code(dict_type->m_key_type), - ASRUtils::get_type_code(dict_type->m_value_type) - ); - llvm::Type* value_type = std::get<2>(typecode2dicttype[llvm_key]).second; - get_builder0() - tmp_value_ptr = builder0.CreateAlloca(value_type, nullptr); - llvm::Value* key_mask_value = LLVM::CreateLoad(*builder, - llvm_utils->create_ptr_gep(key_mask, key_hash)); - llvm::Value* does_kv_exists = builder->CreateICmpEQ(key_mask_value, - llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); - does_kv_exists = builder->CreateAnd(does_kv_exists, - builder->CreateICmpNE(LLVM::CreateLoad(*builder, chain_itr), - llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))) - ); - return does_kv_exists; - } - llvm::Value* LLVMList::read_item(llvm::Value* list, llvm::Value* pos, bool enable_bounds_checking, llvm::Module& module, bool get_pointer) { @@ -6515,9 +6419,9 @@ namespace LCompilers { el_asr_type, name2memidx); } - void LLVMSetLinearProbing::resolve_collision_for_read_with_bound_check( + llvm::Value* LLVMSetLinearProbing::resolve_collision_for_read_with_bound_check( llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, - llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) { + llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error, bool check_if_exists) { /** * C++ equivalent: @@ -6545,18 +6449,22 @@ namespace LCompilers { */ get_builder0() + pos_ptr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); llvm::Value* el_list = get_el_list(set); llvm::Value* el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set)); llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set)); - pos_ptr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); llvm::Function *fn = builder->GetInsertBlock()->getParent(); - llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); - llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); - llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); + std::string s = check_if_exists ? "qq" : "pp"; + llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then"+s, fn); + llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"+s); + llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"+s); llvm::Value* el_mask_value = LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(el_mask, el_hash)); llvm::Value* is_prob_not_needed = builder->CreateICmpEQ(el_mask_value, llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); + llvm::AllocaInst *flag_ptr = builder->CreateAlloca(llvm::Type::getInt1Ty(context), nullptr); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0), pos_ptr); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 0), flag_ptr); builder->CreateCondBr(is_prob_not_needed, thenBB, elseBB); builder->SetInsertPoint(thenBB); { @@ -6569,6 +6477,9 @@ namespace LCompilers { llvm_utils->create_if_else(is_el_matching, [=]() { LLVM::CreateStore(*builder, el_hash, pos_ptr); }, [&]() { + if (check_if_exists) { + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 1), flag_ptr); + } else { if (throw_key_error) { std::string message = "The set does not contain the specified element"; llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n"); @@ -6579,7 +6490,7 @@ namespace LCompilers { llvm::APInt(32, exit_code_int)); exit(context, module, *builder, exit_code); } - }); + }}); } builder->CreateBr(mergeBB); llvm_utils->start_new_block(elseBB); @@ -6588,11 +6499,25 @@ namespace LCompilers { module, el_asr_type, true); } llvm_utils->start_new_block(mergeBB); - llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); + llvm::Value *flag = LLVM::CreateLoad(*builder, flag_ptr); + llvm::AllocaInst *is_el_matching_ptr = builder->CreateAlloca(llvm::Type::getInt1Ty(context), nullptr); + + llvm_utils->create_if_else(flag, [&](){ + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 0), is_el_matching_ptr); + }, [&](){ // Check if the actual element is present or not - llvm::Value* is_el_matching = llvm_utils->is_equal_by_value(el, - llvm_utils->list_api->read_item(el_list, pos, false, module, - LLVM::is_llvm_struct(el_asr_type)), module, el_asr_type); + llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); + llvm::Value* item = llvm_utils->list_api->read_item(el_list, pos, false, module, + LLVM::is_llvm_struct(el_asr_type)) ; + llvm::Value *iseq =llvm_utils->is_equal_by_value(el, + item, module, el_asr_type) ; + LLVM::CreateStore(*builder, iseq, is_el_matching_ptr); + }); + + llvm::Value *is_el_matching = LLVM::CreateLoad(*builder, is_el_matching_ptr); + if (check_if_exists) { + return is_el_matching; + } llvm_utils->create_if_else(is_el_matching, []() {}, [&]() { if (throw_key_error) { @@ -6606,11 +6531,13 @@ namespace LCompilers { exit(context, module, *builder, exit_code); } }); + + return nullptr; } - void LLVMSetSeparateChaining::resolve_collision_for_read_with_bound_check( + llvm::Value* LLVMSetSeparateChaining::resolve_collision_for_read_with_bound_check( llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, - llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) { + llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error, bool check_if_exists) { /** * C++ equivalent: * @@ -6637,6 +6564,10 @@ namespace LCompilers { llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))) ); + if (check_if_exists) { + return does_el_exist; + } + llvm_utils->create_if_else(does_el_exist, []() {}, [&]() { if (throw_key_error) { std::string message = "The set does not contain the specified element"; @@ -6649,6 +6580,8 @@ namespace LCompilers { exit(context, module, *builder, exit_code); } }); + + return nullptr; } void LLVMSetLinearProbing::remove_item( @@ -6947,113 +6880,6 @@ namespace LCompilers { llvm_utils->start_new_block(loopend); } - llvm::Value *LLVMSetLinearProbing::is_el_present( - llvm::Value *set, llvm::Value *el, - llvm::Module &module, ASR::ttype_t *el_asr_type) { - /** - * C++ equivalent: - * - * el_mask_value = el_mask[el_hash]; - * is_prob_needed = el_mask_value == 1; - * if( is_prob_needed ) { - * is_el_matching = el == el_list[el_hash]; - * if( is_el_matching ) { - * pos = el_hash; - * } - * else { - * return is_el_matching; - * } - * } - * else { - * resolve_collision(el, for_read=true); // modifies pos - * } - * - * is_el_matching = el == el_list[pos]; - * return is_el_matching - */ - - get_builder0() - llvm::Value* el_list = get_el_list(set); - llvm::Value* el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set)); - llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set)); - llvm::Value *el_hash = get_el_hash(capacity, el, el_asr_type, module); - pos_ptr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); - llvm::Function *fn = builder->GetInsertBlock()->getParent(); - llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); - llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); - llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); - llvm::Value* el_mask_value = LLVM::CreateLoad(*builder, - llvm_utils->create_ptr_gep(el_mask, el_hash)); - llvm::Value* is_prob_not_needed = builder->CreateICmpEQ(el_mask_value, - llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); - bool to_return = false; - builder->CreateCondBr(is_prob_not_needed, thenBB, elseBB); - builder->SetInsertPoint(thenBB); - { - // reasoning for this check explained in - // LLVMDictOptimizedLinearProbing::resolve_collision_for_read_with_bound_check - llvm::Value* is_el_matching = llvm_utils->is_equal_by_value(el, - llvm_utils->list_api->read_item(el_list, el_hash, false, module, - LLVM::is_llvm_struct(el_asr_type)), module, el_asr_type); - - llvm_utils->create_if_else(is_el_matching, [=]() { - LLVM::CreateStore(*builder, el_hash, pos_ptr); - }, [&]() { - //to_return = true; // Need to check why this is not working - }); - } - builder->CreateBr(mergeBB); - llvm_utils->start_new_block(elseBB); - { - this->resolve_collision(capacity, el_hash, el, el_list, el_mask, - module, el_asr_type, true); - } - llvm_utils->start_new_block(mergeBB); - if (to_return) { - return llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 0); - } - llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); - // Check if the actual element is present or not - llvm::Value* is_el_matching = llvm_utils->is_equal_by_value(el, - llvm_utils->list_api->read_item(el_list, pos, false, module, - LLVM::is_llvm_struct(el_asr_type)), module, el_asr_type); - - - return is_el_matching; - } - - llvm::Value *LLVMSetSeparateChaining::is_el_present( - llvm::Value *set, llvm::Value *el, - llvm::Module &module, ASR::ttype_t *el_asr_type) { - /** - * C++ equivalent: - * - * resolve_collision(el); // modified chain_itr - * does_el_exist = el_mask[el_hash] == 1 && chain_itr != nullptr; - * return does_el_exist; - * - */ - llvm::Value* elems = LLVM::CreateLoad(*builder, get_pointer_to_elems(set)); - llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set)); - llvm::Value* el_hash = get_el_hash(capacity, el, el_asr_type, module); - llvm::Value* el_linked_list = llvm_utils->create_ptr_gep(elems, el_hash); - llvm::Value* el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set)); - std::string el_type_code = ASRUtils::get_type_code(el_asr_type); - llvm::Type* el_struct_type = typecode2elstruct[el_type_code]; - this->resolve_collision(el_hash, el, el_linked_list, - el_struct_type, el_mask, module, el_asr_type); - llvm::Value* el_mask_value = LLVM::CreateLoad(*builder, - llvm_utils->create_ptr_gep(el_mask, el_hash)); - llvm::Value* does_el_exist = builder->CreateICmpEQ(el_mask_value, - llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); - does_el_exist = builder->CreateAnd(does_el_exist, - builder->CreateICmpNE(LLVM::CreateLoad(*builder, chain_itr), - llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))) - ); - - return does_el_exist; - } - llvm::Value* LLVMSetInterface::len(llvm::Value* set) { return LLVM::CreateLoad(*builder, get_pointer_to_occupancy(set)); } diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index 23b346a2c7..0ea2644e96 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -589,7 +589,7 @@ namespace LCompilers { virtual llvm::Value* resolve_collision_for_read_with_bound_check(llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, - ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type) = 0; + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, bool check_if_exists = false) = 0; virtual llvm::Value* resolve_collision_for_read_with_default(llvm::Value* dict, llvm::Value* key_hash, @@ -644,9 +644,6 @@ namespace LCompilers { virtual void set_is_dict_present(bool value); - virtual - llvm::Value *is_key_present(llvm::Value *dict, llvm::Value *key, - ASR::Dict_t *dict_type, llvm::Module &module) = 0; virtual void get_elements_list(llvm::Value* dict, @@ -704,7 +701,7 @@ namespace LCompilers { llvm::Value* resolve_collision_for_read_with_bound_check(llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, - ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type); + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, bool check_if_exists = false); llvm::Value* resolve_collision_for_read_with_default(llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, @@ -742,8 +739,6 @@ namespace LCompilers { llvm::Value* len(llvm::Value* dict); - llvm::Value *is_key_present(llvm::Value *dict, llvm::Value *key, - ASR::Dict_t *dict_type, llvm::Module &module); void get_elements_list(llvm::Value* dict, llvm::Value* elements_list, ASR::ttype_t* key_asr_type, @@ -779,15 +774,13 @@ namespace LCompilers { llvm::Value* resolve_collision_for_read_with_bound_check(llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, - ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type); + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, bool check_if_exists = false); llvm::Value* resolve_collision_for_read_with_default(llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, llvm::Value *def_value); - llvm::Value *is_key_present(llvm::Value *dict, llvm::Value *key, - ASR::Dict_t *dict_type, llvm::Module &module); virtual ~LLVMDictOptimizedLinearProbing(); @@ -859,7 +852,7 @@ namespace LCompilers { llvm::Value* resolve_collision_for_read_with_bound_check(llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, - ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type); + ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, bool check_if_exists = false); llvm::Value* resolve_collision_for_read_with_default(llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, @@ -896,8 +889,6 @@ namespace LCompilers { llvm::Value* len(llvm::Value* dict); - llvm::Value *is_key_present(llvm::Value *dict, llvm::Value *key, - ASR::Dict_t *dict_type, llvm::Module &module); void get_elements_list(llvm::Value* dict, llvm::Value* elements_list, ASR::ttype_t* key_asr_type, @@ -978,9 +969,9 @@ namespace LCompilers { std::map>& name2memidx); virtual - void resolve_collision_for_read_with_bound_check( + llvm::Value* resolve_collision_for_read_with_bound_check( llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, - llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) = 0; + llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error, bool check_if_exists = false) = 0; virtual void remove_item( @@ -993,11 +984,6 @@ namespace LCompilers { ASR::Set_t* set_type, llvm::Module* module, std::map>& name2memidx) = 0; - virtual - llvm::Value *is_el_present( - llvm::Value *set, llvm::Value *el, - llvm::Module &module, ASR::ttype_t *el_asr_type) = 0; - virtual llvm::Value* len(llvm::Value* set); @@ -1054,9 +1040,9 @@ namespace LCompilers { llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type, std::map>& name2memidx); - void resolve_collision_for_read_with_bound_check( + llvm::Value* resolve_collision_for_read_with_bound_check( llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, - llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error); + llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error, bool check_if_exists = false); void remove_item( llvm::Value* set, llvm::Value* el, @@ -1067,10 +1053,6 @@ namespace LCompilers { ASR::Set_t* set_type, llvm::Module* module, std::map>& name2memidx); - llvm::Value *is_el_present( - llvm::Value *set, llvm::Value *el, - llvm::Module &module, ASR::ttype_t *el_asr_type); - ~LLVMSetLinearProbing(); }; @@ -1139,9 +1121,9 @@ namespace LCompilers { llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type, std::map>& name2memidx); - void resolve_collision_for_read_with_bound_check( + llvm::Value* resolve_collision_for_read_with_bound_check( llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, - llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error); + llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error, bool check_if_exists = false); void remove_item( llvm::Value* set, llvm::Value* el, @@ -1152,10 +1134,6 @@ namespace LCompilers { ASR::Set_t* set_type, llvm::Module* module, std::map>& name2memidx); - llvm::Value *is_el_present( - llvm::Value *set, llvm::Value *el, - llvm::Module &module, ASR::ttype_t *el_asr_type); - ~LLVMSetSeparateChaining(); };