diff --git a/integration_tests/test_dict_14.py b/integration_tests/test_dict_14.py new file mode 100644 index 0000000000..4fe91c687f --- /dev/null +++ b/integration_tests/test_dict_14.py @@ -0,0 +1,65 @@ +from lpython import i32 + +def test_dict(): + d_i32: dict[i32, i32] = {5: 1, 5: 2} + d_str: dict[str, i32] = {'a': 1, 'a': 2} + l_str_1: list[str] = [] + l_str_2: list[str] = [] + l_i32_1: list[i32] = [] + l_i32_2: list[i32] = [] + i: i32 + s: str + + assert len(d_i32) == 1 + d_i32.pop(5) + assert len(d_i32) == 0 + + assert len(d_str) == 1 + d_str.pop('a') + assert len(d_str) == 0 + + d_str = {'a': 2, 'a': 2, 'b': 2, 'c': 3, 'a': 5} + assert len(d_str) == 3 + d_str.pop('a') + assert len(d_str) == 2 + d_str.pop('b') + assert len(d_str) == 1 + + d_str['a'] = 20 + assert len(d_str) == 2 + d_str.pop('c') + assert len(d_str) == 1 + + l_str_1 = d_str.keys() + for s in l_str_1: + l_str_2.append(s) + assert l_str_2 == ['a'] + l_i32_1 = d_str.values() + for i in l_i32_1: + l_i32_2.append(i) + assert l_i32_2 == [20] + + d_i32 = {5: 2, 5: 2, 6: 2, 7: 3, 5: 5} + assert len(d_i32) == 3 + d_i32.pop(5) + assert len(d_i32) == 2 + d_i32.pop(6) + assert len(d_i32) == 1 + + d_i32[6] = 30 + assert len(d_i32) == 2 + d_i32.pop(7) + assert len(d_i32) == 1 + + l_i32_1 = d_i32.keys() + l_i32_2.clear() + for i in l_i32_1: + l_i32_2.append(i) + assert l_i32_2 == [6] + l_i32_1 = d_i32.values() + l_i32_2.clear() + for i in l_i32_1: + l_i32_2.append(i) + assert l_i32_2 == [30] + +test_dict() diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 76b9a40291..de3e53d272 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -1993,7 +1993,7 @@ namespace LCompilers { void LLVMDictSeparateChaining::dict_init(std::string key_type_code, std::string value_type_code, llvm::Value* dict, llvm::Module* module, size_t initial_capacity) { - llvm::Value* llvm_capacity = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 
initial_capacity + 1)); + llvm::Value* llvm_capacity = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, initial_capacity)); llvm::Value* rehash_flag_ptr = get_pointer_to_rehash_flag(dict); LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), llvm::APInt(1, 1)), rehash_flag_ptr); dict_init_given_initial_capacity(key_type_code, value_type_code, dict, module, llvm_capacity); @@ -2637,17 +2637,30 @@ namespace LCompilers { /** * C++ equivalent: * - * is_key_matching = 1; + * chain_itr_prev = nullptr; * - * while( chain_itr != nullptr && is_key_matching ) { - * break_signal = key != kv_key; - * is_key_matching = break_signal; // 1 means not matching - * if( break_signal ) { - * chain_itr = next_kv_struct; + * ll_exists = key_mask_value == 1; + * if( ll_exists ) { + * chain_itr = ll_head; + * } + * else { + * chain_itr = nullptr; + * } + * is_key_matching = 0; + * + * while( chain_itr != nullptr && !is_key_matching ) { + * is_key_matching = (key == kv_struct_key); + * if( !is_key_matching ) { + * // update for next iteration + * chain_itr_prev = chain_itr; + * chain_itr = next_kv_struct; // (*chain_itr)[2] * } * } * + * // now, chain_itr either points to kv or is nullptr + * */ + get_builder0() chain_itr = builder0.CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); chain_itr_prev = builder0.CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); @@ -2655,12 +2668,19 @@ namespace LCompilers { LLVM::CreateStore(*builder, llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)), chain_itr_prev); - llvm::Value* kv_ll_i8 = builder->CreateBitCast(key_value_pair_linked_list, llvm::Type::getInt8PtrTy(context)); - LLVM::CreateStore(*builder, kv_ll_i8, chain_itr); llvm::Value* key_mask_value = LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(key_mask, key_hash)); + llvm_utils->create_if_else(builder->CreateICmpEQ(key_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))), 
[&]() { + llvm::Value* kv_ll_i8 = builder->CreateBitCast(key_value_pair_linked_list, + llvm::Type::getInt8PtrTy(context)); + LLVM::CreateStore(*builder, kv_ll_i8, chain_itr); + }, [&]() { + LLVM::CreateStore(*builder, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)), chain_itr); + }); LLVM::CreateStore(*builder, - builder->CreateICmpEQ(key_mask_value, llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))), + llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), llvm::APInt(1, 0)), /* is_key_matching = 0; the type width must equal the APInt width (1 bit) */ is_key_matching_var ); llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); @@ -2674,7 +2694,8 @@ namespace LCompilers { LLVM::CreateLoad(*builder, chain_itr), llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)) ); - cond = builder->CreateAnd(cond, LLVM::CreateLoad(*builder, is_key_matching_var)); + cond = builder->CreateAnd(cond, builder->CreateNot(LLVM::CreateLoad( + *builder, is_key_matching_var))); builder->CreateCondBr(cond, loopbody, loopend); } @@ -2682,27 +2703,24 @@ namespace LCompilers { llvm_utils->start_new_block(loopbody); { llvm::Value* kv_struct_i8 = LLVM::CreateLoad(*builder, chain_itr); - LLVM::CreateStore(*builder, kv_struct_i8, chain_itr_prev); llvm::Value* kv_struct = builder->CreateBitCast(kv_struct_i8, kv_pair_type->getPointerTo()); - llvm::Value* kv_key = llvm_utils->create_gep(kv_struct, 0); + llvm::Value* kv_struct_key = llvm_utils->create_gep(kv_struct, 0); if( !LLVM::is_llvm_struct(key_asr_type) ) { - kv_key = LLVM::CreateLoad(*builder, kv_key); + kv_struct_key = LLVM::CreateLoad(*builder, kv_struct_key); } - llvm::Value* break_signal = llvm_utils->is_equal_by_value(key, kv_key, module, key_asr_type); - break_signal = builder->CreateNot(break_signal); - LLVM::CreateStore(*builder, break_signal, is_key_matching_var); - llvm_utils->create_if_else(break_signal, [&]() { + LLVM::CreateStore(*builder, llvm_utils->is_equal_by_value(key, kv_struct_key, + module, key_asr_type), 
is_key_matching_var); + llvm_utils->create_if_else(builder->CreateNot(LLVM::CreateLoad(*builder, is_key_matching_var)), [&]() { + LLVM::CreateStore(*builder, kv_struct_i8, chain_itr_prev); llvm::Value* next_kv_struct = LLVM::CreateLoad(*builder, llvm_utils->create_gep(kv_struct, 2)); LLVM::CreateStore(*builder, next_kv_struct, chain_itr); - }, [=]() { - }); + }, []() {}); } builder->CreateBr(loophead); // end llvm_utils->start_new_block(loopend); - } void LLVMDict::resolve_collision_for_write( @@ -2743,6 +2761,26 @@ namespace LCompilers { llvm::Module* module, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, std::map>& name2memidx) { + + /** + * C++ equivalent: + * + * resolve_collision(); // modifies pos + + * key_list[pos] = key; + * value_list[pos] = value; + + * key_mask_value = key_mask[pos]; + * is_slot_empty = key_mask_value == 0 || key_mask_value == 3; + * occupancy += is_slot_empty; + + * linear_prob_happened = (key_hash != pos) || (key_mask[key_hash] == 2); + * set_max_2 = linear_prob_happened ? 
2 : 1; + * key_mask[key_hash] = set_max_2; + * key_mask[pos] = set_max_2; + * + */ + llvm::Value* key_list = get_key_list(dict); llvm::Value* value_list = get_value_list(dict); llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); @@ -2758,6 +2796,8 @@ namespace LCompilers { llvm_utils->create_ptr_gep(key_mask, pos)); llvm::Value* is_slot_empty = builder->CreateICmpEQ(key_mask_value, llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0))); + is_slot_empty = builder->CreateOr(is_slot_empty, builder->CreateICmpEQ(key_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 3)))); llvm::Value* occupancy_ptr = get_pointer_to_occupancy(dict); is_slot_empty = builder->CreateZExt(is_slot_empty, llvm::Type::getInt32Ty(context)); llvm::Value* occupancy = LLVM::CreateLoad(*builder, occupancy_ptr); @@ -2784,6 +2824,40 @@ namespace LCompilers { llvm::Module* module, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, std::map>& name2memidx) { + + /** + * C++ equivalent: + * + * kv_linked_list = key_value_pairs[key_hash]; + * resolve_collision(key); // modifies chain_itr + * do_insert = chain_itr == nullptr; + * + * if( do_insert ) { + * if( chain_itr_prev != nullptr ) { + * new_kv_struct = malloc(kv_struct_size); + * new_kv_struct[0] = key; + * new_kv_struct[1] = value; + * new_kv_struct[2] = nullptr; + * chain_itr_prev[2] = new_kv_struct; + * } + * else { + * kv_linked_list[0] = key; + * kv_linked_list[1] = value; + * kv_linked_list[2] = nullptr; + * } + * occupancy += 1; + * } + * else { + * kv_struct[0] = key; + * kv_struct[1] = value; + * } + * + * buckets_filled_delta = key_mask[key_hash] == 0; + * buckets_filled += buckets_filled_delta; + * key_mask[key_hash] = 1; + * + */ + llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); llvm::Value* key_value_pairs = LLVM::CreateLoad(*builder, get_pointer_to_key_value_pairs(dict)); llvm::Value* key_value_pair_linked_list = 
llvm_utils->create_ptr_gep(key_value_pairs, key_hash); @@ -2792,6 +2866,7 @@ namespace LCompilers { this->resolve_collision(capacity, key_hash, key, key_value_pair_linked_list, kv_struct_type, key_mask, *module, key_asr_type); llvm::Value* kv_struct_i8 = LLVM::CreateLoad(*builder, chain_itr); + llvm::Function *fn = builder->GetInsertBlock()->getParent(); llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); @@ -2799,21 +2874,38 @@ namespace LCompilers { llvm::Value* do_insert = builder->CreateICmpEQ(kv_struct_i8, llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))); builder->CreateCondBr(do_insert, thenBB, elseBB); + builder->SetInsertPoint(thenBB); { - llvm::DataLayout data_layout(module); - size_t kv_struct_size = data_layout.getTypeAllocSize(kv_struct_type); - llvm::Value* malloc_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), kv_struct_size); - llvm::Value* new_kv_struct_i8 = LLVM::lfortran_malloc(context, *module, *builder, malloc_size); - llvm::Value* new_kv_struct = builder->CreateBitCast(new_kv_struct_i8, kv_struct_type->getPointerTo()); - llvm_utils->deepcopy(key, llvm_utils->create_gep(new_kv_struct, 0), key_asr_type, module, name2memidx); - llvm_utils->deepcopy(value, llvm_utils->create_gep(new_kv_struct, 1), value_asr_type, module, name2memidx); - LLVM::CreateStore(*builder, - llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)), - llvm_utils->create_gep(new_kv_struct, 2)); - llvm::Value* kv_struct_prev_i8 = LLVM::CreateLoad(*builder, chain_itr_prev); - llvm::Value* kv_struct_prev = builder->CreateBitCast(kv_struct_prev_i8, kv_struct_type->getPointerTo()); - LLVM::CreateStore(*builder, new_kv_struct_i8, llvm_utils->create_gep(kv_struct_prev, 2)); + llvm_utils->create_if_else(builder->CreateICmpNE( + LLVM::CreateLoad(*builder, chain_itr_prev), + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))), [&]() { + 
llvm::DataLayout data_layout(module); + size_t kv_struct_size = data_layout.getTypeAllocSize(kv_struct_type); + llvm::Value* malloc_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), kv_struct_size); + llvm::Value* new_kv_struct_i8 = LLVM::lfortran_malloc(context, *module, *builder, malloc_size); + llvm::Value* new_kv_struct = builder->CreateBitCast(new_kv_struct_i8, kv_struct_type->getPointerTo()); + llvm_utils->deepcopy(key, llvm_utils->create_gep(new_kv_struct, 0), key_asr_type, module, name2memidx); + llvm_utils->deepcopy(value, llvm_utils->create_gep(new_kv_struct, 1), value_asr_type, module, name2memidx); + LLVM::CreateStore(*builder, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)), + llvm_utils->create_gep(new_kv_struct, 2)); + llvm::Value* kv_struct_prev_i8 = LLVM::CreateLoad(*builder, chain_itr_prev); + llvm::Value* kv_struct_prev = builder->CreateBitCast(kv_struct_prev_i8, kv_struct_type->getPointerTo()); + LLVM::CreateStore(*builder, new_kv_struct_i8, llvm_utils->create_gep(kv_struct_prev, 2)); + }, [&]() { + llvm_utils->deepcopy(key, llvm_utils->create_gep(key_value_pair_linked_list, 0), key_asr_type, module, name2memidx); + llvm_utils->deepcopy(value, llvm_utils->create_gep(key_value_pair_linked_list, 1), value_asr_type, module, name2memidx); + LLVM::CreateStore(*builder, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)), + llvm_utils->create_gep(key_value_pair_linked_list, 2)); + }); + + llvm::Value* occupancy_ptr = get_pointer_to_occupancy(dict); + llvm::Value* occupancy = LLVM::CreateLoad(*builder, occupancy_ptr); + occupancy = builder->CreateAdd(occupancy, + llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 1)); + LLVM::CreateStore(*builder, occupancy, occupancy_ptr); } builder->CreateBr(mergeBB); llvm_utils->start_new_block(elseBB); @@ -2823,12 +2915,7 @@ namespace LCompilers { llvm_utils->deepcopy(value, llvm_utils->create_gep(kv_struct, 1), value_asr_type, module, name2memidx); } 
llvm_utils->start_new_block(mergeBB); - llvm::Value* occupancy_ptr = get_pointer_to_occupancy(dict); llvm::Value* buckets_filled_ptr = get_pointer_to_number_of_filled_buckets(dict); - llvm::Value* occupancy = LLVM::CreateLoad(*builder, occupancy_ptr); - occupancy = builder->CreateAdd(occupancy, - llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 1))); - LLVM::CreateStore(*builder, occupancy, occupancy_ptr); llvm::Value* key_mask_value_ptr = llvm_utils->create_ptr_gep(key_mask, key_hash); llvm::Value* key_mask_value = LLVM::CreateLoad(*builder, key_mask_value_ptr); llvm::Value* buckets_filled_delta = builder->CreateICmpEQ(key_mask_value, @@ -3143,14 +3230,12 @@ namespace LCompilers { ASRUtils::get_type_code(value_asr_type) ); llvm::Type* value_type = std::get<2>(typecode2dicttype[llvm_key]).second; - llvm::Value* tmp_value_ptr_local = nullptr; get_builder0() tmp_value_ptr = builder0.CreateAlloca(value_type, nullptr); - tmp_value_ptr_local = tmp_value_ptr; llvm::Value* kv_struct_i8 = LLVM::CreateLoad(*builder, chain_itr); llvm::Value* kv_struct = builder->CreateBitCast(kv_struct_i8, kv_struct_type->getPointerTo()); llvm::Value* value = LLVM::CreateLoad(*builder, llvm_utils->create_gep(kv_struct, 1)); - LLVM::CreateStore(*builder, value, tmp_value_ptr_local); + LLVM::CreateStore(*builder, value, tmp_value_ptr); return tmp_value_ptr; } @@ -3158,6 +3243,17 @@ namespace LCompilers { llvm::Value* dict, llvm::Value* key_hash, llvm::Value* key, llvm::Module& module, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type) { + /** + * C++ equivalent: + * + * resolve_collision(key); // modifies chain_itr + * does_kv_exists = key_mask[key_hash] == 1 && chain_itr != nullptr; + * if( !does_kv_exists ) { + * exit(1); // KeyError + * } + * + */ + llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); llvm::Value* key_value_pairs = LLVM::CreateLoad(*builder, get_pointer_to_key_value_pairs(dict)); llvm::Value* 
key_value_pair_linked_list = llvm_utils->create_ptr_gep(key_value_pairs, key_hash); @@ -3170,10 +3266,8 @@ namespace LCompilers { ASRUtils::get_type_code(value_asr_type) ); llvm::Type* value_type = std::get<2>(typecode2dicttype[llvm_key]).second; - llvm::Value* tmp_value_ptr_local = nullptr; get_builder0() tmp_value_ptr = builder0.CreateAlloca(value_type, nullptr); - tmp_value_ptr_local = tmp_value_ptr; llvm::Value* key_mask_value = LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(key_mask, key_hash)); llvm::Value* does_kv_exists = builder->CreateICmpEQ(key_mask_value, @@ -3183,11 +3277,11 @@ namespace LCompilers { llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))) ); - llvm_utils->create_if_else(does_kv_exists, [=]() { + llvm_utils->create_if_else(does_kv_exists, [&]() { llvm::Value* kv_struct_i8 = LLVM::CreateLoad(*builder, chain_itr); llvm::Value* kv_struct = builder->CreateBitCast(kv_struct_i8, kv_struct_type->getPointerTo()); llvm::Value* value = LLVM::CreateLoad(*builder, llvm_utils->create_gep(kv_struct, 1)); - LLVM::CreateStore(*builder, value, tmp_value_ptr_local); + LLVM::CreateStore(*builder, value, tmp_value_ptr); }, [&]() { std::string message = "The dict does not contain the specified key"; llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n"); @@ -3217,10 +3311,8 @@ namespace LCompilers { ASRUtils::get_type_code(value_asr_type) ); llvm::Type* value_type = std::get<2>(typecode2dicttype[llvm_key]).second; - llvm::Value* tmp_value_ptr_local = nullptr; get_builder0() tmp_value_ptr = builder0.CreateAlloca(value_type, nullptr); - tmp_value_ptr_local = tmp_value_ptr; llvm::Value* key_mask_value = LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(key_mask, key_hash)); llvm::Value* does_kv_exists = builder->CreateICmpEQ(key_mask_value, @@ -3230,13 +3322,13 @@ namespace LCompilers { llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))) ); - llvm_utils->create_if_else(does_kv_exists, [=]() { + 
llvm_utils->create_if_else(does_kv_exists, [&]() { llvm::Value* kv_struct_i8 = LLVM::CreateLoad(*builder, chain_itr); llvm::Value* kv_struct = builder->CreateBitCast(kv_struct_i8, kv_struct_type->getPointerTo()); llvm::Value* value = LLVM::CreateLoad(*builder, llvm_utils->create_gep(kv_struct, 1)); - LLVM::CreateStore(*builder, value, tmp_value_ptr_local); + LLVM::CreateStore(*builder, value, tmp_value_ptr); }, [&]() { - LLVM::CreateStore(*builder, LLVM::CreateLoad(*builder, def_value), tmp_value_ptr_local); + LLVM::CreateStore(*builder, LLVM::CreateLoad(*builder, def_value), tmp_value_ptr); }); return tmp_value_ptr; } @@ -3619,25 +3711,32 @@ namespace LCompilers { llvm::Value* dict, llvm::Module* module, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, std::map>& name2memidx) { + + /** + * C++ equivalent: + * + * // this condition will be true with 0 buckets_filled too + * rehash_condition = rehash_flag && (occupancy >= 2 * buckets_filled); + * if( rehash_condition ) { + * rehash(); + * } + * + */ + llvm::Value* occupancy = LLVM::CreateLoad(*builder, get_pointer_to_occupancy(dict)); llvm::Value* buckets_filled = LLVM::CreateLoad(*builder, get_pointer_to_number_of_filled_buckets(dict)); llvm::Value* rehash_condition = LLVM::CreateLoad(*builder, get_pointer_to_rehash_flag(dict)); - rehash_condition = builder->CreateAnd(rehash_condition, builder->CreateICmpNE(buckets_filled, - llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)))); - occupancy = builder->CreateSIToFP(occupancy, llvm::Type::getFloatTy(context)); - buckets_filled = builder->CreateSIToFP(buckets_filled, llvm::Type::getFloatTy(context)); - llvm::Value* avg_ll_length = builder->CreateFDiv(occupancy, buckets_filled); - llvm::Value* avg_ll_length_threshold = llvm::ConstantFP::get(llvm::Type::getFloatTy(context), - llvm::APFloat((float) 2.0)); + llvm::Value* buckets_filled_times_2 = builder->CreateMul(buckets_filled, llvm::ConstantInt::get( + llvm::Type::getInt32Ty(context), 
llvm::APInt(32, 2))); rehash_condition = builder->CreateAnd(rehash_condition, - builder->CreateFCmpOGE(avg_ll_length, avg_ll_length_threshold)); + builder->CreateICmpSGE(occupancy, buckets_filled_times_2)); llvm_utils->create_if_else(rehash_condition, [&]() { rehash(dict, module, key_asr_type, value_asr_type, name2memidx); }, [=]() { }); } - void LLVMDict::write_item(llvm::Value* dict, llvm::Value* key, + void LLVMDictInterface::write_item(llvm::Value* dict, llvm::Value* key, llvm::Value* value, llvm::Module* module, ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, std::map>& name2memidx) { @@ -3646,20 +3745,12 @@ namespace LCompilers { llvm::Value* key_hash = get_key_hash(current_capacity, key, key_asr_type, *module); this->resolve_collision_for_write(dict, key_hash, key, value, module, key_asr_type, value_asr_type, name2memidx); + // A second rehash ensures that the threshold is not breached at any point. + // It can be shown mathematically that rehashing twice would only occur for small dictionaries, + // for example, for threshold set in linear probing, it occurs only when len(dict) <= 2 rehash_all_at_once_if_needed(dict, module, key_asr_type, value_asr_type, name2memidx); } - void LLVMDictSeparateChaining::write_item(llvm::Value* dict, llvm::Value* key, - llvm::Value* value, llvm::Module* module, - ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, - std::map>& name2memidx) { - rehash_all_at_once_if_needed(dict, module, key_asr_type, value_asr_type, name2memidx); - llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); - llvm::Value* key_hash = get_key_hash(current_capacity, key, key_asr_type, *module); - this->resolve_collision_for_write(dict, key_hash, key, value, module, - key_asr_type, value_asr_type, name2memidx); - } - llvm::Value* LLVMDict::read_item(llvm::Value* dict, llvm::Value* key, llvm::Module& module, ASR::Dict_t* dict_type, bool enable_bounds_checking, bool get_pointer) { @@ -3739,9 +3830,17 @@ 
namespace LCompilers { llvm::Value* LLVMDict::pop_item(llvm::Value* dict, llvm::Value* key, llvm::Module& module, ASR::Dict_t* dict_type, bool get_pointer) { + /** + * C++ equivalent: + * + * resolve_collision_for_read_with_bound_check(key); // modifies pos + * key_mask[pos] = 3; // tombstone marker + * occupancy -= 1; + */ + llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); llvm::Value* key_hash = get_key_hash(current_capacity, key, dict_type->m_key_type, module); - llvm::Value* value_ptr = this->resolve_collision_for_read(dict, key_hash, key, module, + llvm::Value* value_ptr = this->resolve_collision_for_read_with_bound_check(dict, key_hash, key, module, dict_type->m_key_type, dict_type->m_value_type); llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); @@ -3773,9 +3872,35 @@ namespace LCompilers { llvm::Value* dict, llvm::Value* key, llvm::Module& module, ASR::Dict_t* dict_type, bool get_pointer) { + /** + * C++ equivalent: + * + * // modifies chain_itr and chain_itr_prev + * resolve_collision_for_read_with_bound_check(key); + * + * if(chain_itr_prev != nullptr) { + * chain_itr_prev[2] = chain_itr[2]; // next + * } + * else { + * // head of linked list removed + * if( chain_itr[2] == nullptr ) { + * // this linked list is now empty + * key_mask[key_hash] = 0; + * num_buckets_filled--; + * } + * else { + * // not empty yet + * key_value_pairs[key_hash] = chain_itr[2]; + * } + * } + * + * occupancy--; + * + */ + llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); llvm::Value* key_hash = get_key_hash(current_capacity, key, dict_type->m_key_type, module); - llvm::Value* value_ptr = this->resolve_collision_for_read(dict, key_hash, key, module, + llvm::Value* value_ptr = this->resolve_collision_for_read_with_bound_check(dict, key_hash, key, module, dict_type->m_key_type, dict_type->m_value_type); 
std::pair llvm_key = std::make_pair( ASRUtils::get_type_code(dict_type->m_key_type), @@ -3785,40 +3910,35 @@ namespace LCompilers { value_ptr = builder->CreateBitCast(value_ptr, value_type->getPointerTo()); llvm::Value* prev = LLVM::CreateLoad(*builder, chain_itr_prev); llvm::Value* found = LLVM::CreateLoad(*builder, chain_itr); + llvm::Type* kv_struct_type = get_key_value_pair_type(dict_type->m_key_type, dict_type->m_value_type); + found = builder->CreateBitCast(found, kv_struct_type->getPointerTo()); + llvm::Value* found_next = LLVM::CreateLoad(*builder, llvm_utils->create_gep(found, 2)); - llvm::Function *fn = builder->GetInsertBlock()->getParent(); - llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); - llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); - llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); - - builder->CreateCondBr( - builder->CreateICmpNE(prev, llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))), - thenBB, elseBB - ); - builder->SetInsertPoint(thenBB); - { - llvm::Type* kv_struct_type = get_key_value_pair_type(dict_type->m_key_type, dict_type->m_value_type); - found = builder->CreateBitCast(found, kv_struct_type->getPointerTo()); - llvm::Value* found_next = LLVM::CreateLoad(*builder, llvm_utils->create_gep(found, 2)); + llvm_utils->create_if_else(builder->CreateICmpNE(prev, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))), [&]() { prev = builder->CreateBitCast(prev, kv_struct_type->getPointerTo()); LLVM::CreateStore(*builder, found_next, llvm_utils->create_gep(prev, 2)); - } - builder->CreateBr(mergeBB); - llvm_utils->start_new_block(elseBB); - { - llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); - LLVM::CreateStore( - *builder, - llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0)), - llvm_utils->create_ptr_gep(key_mask, key_hash) - ); - llvm::Value* num_buckets_filled_ptr = 
get_pointer_to_number_of_filled_buckets(dict); - llvm::Value* num_buckets_filled = LLVM::CreateLoad(*builder, num_buckets_filled_ptr); - num_buckets_filled = builder->CreateSub(num_buckets_filled, llvm::ConstantInt::get( - llvm::Type::getInt32Ty(context), llvm::APInt(32, 1))); - LLVM::CreateStore(*builder, num_buckets_filled, num_buckets_filled_ptr); - } - llvm_utils->start_new_block(mergeBB); + }, [&]() { + llvm_utils->create_if_else(builder->CreateICmpEQ(found_next, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))), [&]() { + llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); + LLVM::CreateStore( + *builder, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0)), + llvm_utils->create_ptr_gep(key_mask, key_hash) + ); + llvm::Value* num_buckets_filled_ptr = get_pointer_to_number_of_filled_buckets(dict); + llvm::Value* num_buckets_filled = LLVM::CreateLoad(*builder, num_buckets_filled_ptr); + num_buckets_filled = builder->CreateSub(num_buckets_filled, llvm::ConstantInt::get( + llvm::Type::getInt32Ty(context), llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, num_buckets_filled, num_buckets_filled_ptr); + }, [&]() { + found_next = builder->CreateBitCast(found_next, kv_struct_type->getPointerTo()); + llvm::Value* key_value_pairs = LLVM::CreateLoad(*builder, get_pointer_to_key_value_pairs(dict)); + LLVM::CreateStore(*builder, LLVM::CreateLoad(*builder, found_next), + llvm_utils->create_ptr_gep(key_value_pairs, key_hash)); + }); + }); llvm::Value* occupancy_ptr = get_pointer_to_occupancy(dict); llvm::Value* occupancy = LLVM::CreateLoad(*builder, occupancy_ptr); diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index 7efa781430..1a77e57d47 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -588,7 +588,7 @@ namespace LCompilers { void write_item(llvm::Value* dict, llvm::Value* key, llvm::Value* value, llvm::Module* module, ASR::ttype_t* 
key_asr_type, ASR::ttype_t* value_asr_type, - std::map>& name2memidx) = 0; + std::map>& name2memidx); virtual llvm::Value* read_item(llvm::Value* dict, llvm::Value* key, @@ -693,11 +693,6 @@ namespace LCompilers { ASR::ttype_t* value_asr_type, std::map>& name2memidx); - void write_item(llvm::Value* dict, llvm::Value* key, - llvm::Value* value, llvm::Module* module, - ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, - std::map>& name2memidx); - llvm::Value* read_item(llvm::Value* dict, llvm::Value* key, llvm::Module& module, ASR::Dict_t* key_asr_type, bool enable_bounds_checking, bool get_pointer=false); @@ -847,11 +842,6 @@ namespace LCompilers { ASR::ttype_t* value_asr_type, std::map>& name2memidx); - void write_item(llvm::Value* dict, llvm::Value* key, - llvm::Value* value, llvm::Module* module, - ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, - std::map>& name2memidx); - llvm::Value* read_item(llvm::Value* dict, llvm::Value* key, llvm::Module& module, ASR::Dict_t* dict_type, bool enable_bounds_checking, bool get_pointer=false); diff --git a/tests/errors/test_dict15.py b/tests/errors/test_dict15.py new file mode 100644 index 0000000000..7818753833 --- /dev/null +++ b/tests/errors/test_dict15.py @@ -0,0 +1,8 @@ +from lpython import i32 + +def test_dict_pop(): + d: dict[i32, i32] = {1: 2} + d.pop(1) + d.pop(1) + +test_dict_pop() diff --git a/tests/errors/test_dict16.py b/tests/errors/test_dict16.py new file mode 100644 index 0000000000..51a19e33f9 --- /dev/null +++ b/tests/errors/test_dict16.py @@ -0,0 +1,8 @@ +from lpython import i32 + +def test_dict_pop(): + d: dict[str, i32] = {'a': 2} + d.pop('a') + d.pop('a') + +test_dict_pop() diff --git a/tests/reference/runtime-test_dict15-6f3af0d.json b/tests/reference/runtime-test_dict15-6f3af0d.json new file mode 100644 index 0000000000..5bf5c80a4b --- /dev/null +++ b/tests/reference/runtime-test_dict15-6f3af0d.json @@ -0,0 +1,13 @@ +{ + "basename": "runtime-test_dict15-6f3af0d", + "cmd": "lpython 
{infile}", + "infile": "tests/errors/test_dict15.py", + "infile_hash": "6a0e507b9a9cf659cb433abbdc3435b4c63a6079eadcd7d2c765def1", + "outfile": null, + "outfile_hash": null, + "stdout": null, + "stdout_hash": null, + "stderr": "runtime-test_dict15-6f3af0d.stderr", + "stderr_hash": "cb46ef04db0862506d688ebe8830a50afaaead9b0d29b0c007dd149a", + "returncode": 1 +} \ No newline at end of file diff --git a/tests/reference/runtime-test_dict15-6f3af0d.stderr b/tests/reference/runtime-test_dict15-6f3af0d.stderr new file mode 100644 index 0000000000..e8c90e4e1d --- /dev/null +++ b/tests/reference/runtime-test_dict15-6f3af0d.stderr @@ -0,0 +1 @@ +KeyError: The dict does not contain the specified key diff --git a/tests/reference/runtime-test_dict16-c5a958d.json b/tests/reference/runtime-test_dict16-c5a958d.json new file mode 100644 index 0000000000..471c82d252 --- /dev/null +++ b/tests/reference/runtime-test_dict16-c5a958d.json @@ -0,0 +1,13 @@ +{ + "basename": "runtime-test_dict16-c5a958d", + "cmd": "lpython {infile}", + "infile": "tests/errors/test_dict16.py", + "infile_hash": "7b00cfd7f6eac8338897bd99e5d953605f16927ee0f27683146b0182", + "outfile": null, + "outfile_hash": null, + "stdout": null, + "stdout_hash": null, + "stderr": "runtime-test_dict16-c5a958d.stderr", + "stderr_hash": "cb46ef04db0862506d688ebe8830a50afaaead9b0d29b0c007dd149a", + "returncode": 1 +} \ No newline at end of file diff --git a/tests/reference/runtime-test_dict16-c5a958d.stderr b/tests/reference/runtime-test_dict16-c5a958d.stderr new file mode 100644 index 0000000000..e8c90e4e1d --- /dev/null +++ b/tests/reference/runtime-test_dict16-c5a958d.stderr @@ -0,0 +1 @@ +KeyError: The dict does not contain the specified key