From 5f9fd7c4d22fd803464bcfb414188f410e258af6 Mon Sep 17 00:00:00 2001 From: kabra1110 Date: Wed, 19 Jul 2023 21:57:29 +0530 Subject: [PATCH 1/5] changes_benchmark --- src/libasr/codegen/asr_to_llvm.cpp | 2 ++ src/libasr/codegen/llvm_utils.cpp | 27 ++++++++++----------------- src/libasr/codegen/llvm_utils.h | 2 ++ 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 8648ecb027..271bb7362d 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -249,6 +249,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor void create_loop(char *name, Cond condition, Body loop_body) { dict_api_lp->set_iterators(); dict_api_sc->set_iterators(); + set_api->set_iterators(); std::string loop_name; if (name) { @@ -288,6 +289,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor start_new_block(loopend); dict_api_lp->reset_iterators(); dict_api_sc->reset_iterators(); + set_api->reset_iterators(); } void get_type_debug_info(ASR::ttype_t* t, std::string &type_name, diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 470cdacbbd..4d47d77e3d 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -5337,10 +5337,8 @@ namespace LCompilers { /** * C++ equivalent: * - * occupancy += 1; - * load_factor = occupancy / capacity; - * load_factor_threshold = 0.6; - * rehash_condition = (capacity == 0) || (load_factor >= load_factor_threshold); + * // this condition will be true with 0 capacity too + * rehash_condition = 5 * occupancy >= 3 * capacity; * if( rehash_condition ) { * rehash(); * } @@ -5349,21 +5347,16 @@ namespace LCompilers { llvm::Value* occupancy = LLVM::CreateLoad(*builder, get_pointer_to_occupancy(set)); llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set)); - llvm::Value* rehash_condition = builder->CreateICmpEQ(capacity, - llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0))); - occupancy = builder->CreateAdd(occupancy, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), - llvm::APInt(32, 1))); - occupancy = builder->CreateSIToFP(occupancy, llvm::Type::getFloatTy(context)); - capacity = builder->CreateSIToFP(capacity, llvm::Type::getFloatTy(context)); - llvm::Value* load_factor = builder->CreateFDiv(occupancy, capacity); // Threshold hash is chosen from https://en.wikipedia.org/wiki/Hash_table#Load_factor - llvm::Value* load_factor_threshold = llvm::ConstantFP::get(llvm::Type::getFloatTy(context), - llvm::APFloat((float) 0.6)); - rehash_condition = builder->CreateOr(rehash_condition, builder->CreateFCmpOGE(load_factor, load_factor_threshold)); - llvm_utils->create_if_else(rehash_condition, [&]() { + // occupancy / capacity >= 0.6 is same as 5 * occupancy >= 3 * capacity + llvm::Value* occupancy_times_5 = builder->CreateMul(occupancy, llvm::ConstantInt::get( + llvm::Type::getInt32Ty(context), llvm::APInt(32, 5))); + llvm::Value* capacity_times_3 = builder->CreateMul(capacity, llvm::ConstantInt::get( + llvm::Type::getInt32Ty(context), llvm::APInt(32, 3))); + llvm_utils->create_if_else(builder->CreateICmpSGE(occupancy_times_5, + capacity_times_3), [&]() { rehash(set, module, el_asr_type, name2memidx); - }, [=]() { - }); + }, []() {}); } void LLVMSetLinearProbing::write_item( diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index 38f8c5fe12..050b0a1c09 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -17,6 +17,8 @@ # define FIXED_VECTOR_TYPE llvm::VectorType #endif +#define PERTURB_SHIFT 5 + namespace LCompilers { // Platform dependent fast unique hash: From b35bf9e4e6f705c6bc9f2eab221d515655f431f9 Mon Sep 17 00:00:00 2001 From: kabra1110 Date: Fri, 21 Jul 2023 15:28:56 +0530 Subject: [PATCH 2/5] init separate chaining --- src/libasr/codegen/asr_to_llvm.cpp | 44 +- src/libasr/codegen/llvm_utils.cpp | 932 ++++++++++++++++++++++++++++- src/libasr/codegen/llvm_utils.h | 109 +++- 3 files changed, 1030 insertions(+), 55 deletions(-) diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 271bb7362d..973d8b6ce7 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -175,7 +175,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor std::unique_ptr tuple_api; std::unique_ptr dict_api_lp; std::unique_ptr dict_api_sc; - std::unique_ptr set_api; // linear probing + std::unique_ptr set_api_lp; + std::unique_ptr set_api_sc; std::unique_ptr arr_descr; ASRToLLVMVisitor(Allocator &al, llvm::LLVMContext &context, std::string infile, @@ -200,7 +201,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor tuple_api(std::make_unique(context, llvm_utils.get(), builder.get())), dict_api_lp(std::make_unique(context, llvm_utils.get(), builder.get())), dict_api_sc(std::make_unique(context, llvm_utils.get(), builder.get())), - set_api(std::make_unique(context, llvm_utils.get(), builder.get())), + set_api_lp(std::make_unique(context, llvm_utils.get(), builder.get())), + set_api_sc(std::make_unique(context, llvm_utils.get(), builder.get())), arr_descr(LLVMArrUtils::Descriptor::get_descriptor(context, builder.get(), llvm_utils.get(), LLVMArrUtils::DESCR_TYPE::_SimpleCMODescriptor)) @@ -208,10 +210,12 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor llvm_utils->tuple_api = tuple_api.get(); llvm_utils->list_api = list_api.get(); llvm_utils->dict_api = nullptr; - llvm_utils->set_api = set_api.get(); + llvm_utils->set_api = nullptr; llvm_utils->arr_api = arr_descr.get(); llvm_utils->dict_api_lp = dict_api_lp.get(); llvm_utils->dict_api_sc = dict_api_sc.get(); + llvm_utils->set_api_lp = set_api_lp.get(); + llvm_utils->set_api_sc = set_api_sc.get(); } llvm::Value* CreateLoad(llvm::Value *x) { @@ -249,7 +253,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor void create_loop(char *name, Cond condition, Body loop_body) { dict_api_lp->set_iterators(); dict_api_sc->set_iterators(); - set_api->set_iterators(); + set_api_lp->set_iterators(); + set_api_sc->set_iterators(); std::string loop_name; if (name) { @@ -289,7 +294,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor start_new_block(loopend); dict_api_lp->reset_iterators(); dict_api_sc->reset_iterators(); - set_api->reset_iterators(); + set_api_lp->reset_iterators(); + set_api_sc->reset_iterators(); } void get_type_debug_info(ASR::ttype_t* t, std::string &type_name, @@ -1158,12 +1164,13 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor llvm::Type* const_set_type = llvm_utils->get_set_type(x.m_type, module.get()); llvm::Value* const_set = builder->CreateAlloca(const_set_type, nullptr, "const_set"); ASR::Set_t* x_set = ASR::down_cast(x.m_type); + llvm_utils->set_set_api(x_set); std::string el_type_code = ASRUtils::get_type_code(x_set->m_type); llvm_utils->set_api->set_init(el_type_code, const_set, module.get(), x.n_elements); int64_t ptr_loads_el = !LLVM::is_llvm_struct(x_set->m_type); int64_t ptr_loads_copy = ptr_loads; + ptr_loads = ptr_loads_el; for( size_t i = 0; i < x.n_elements; i++ ) { - ptr_loads = ptr_loads_el; visit_expr_wrapper(x.m_elements[i], true); llvm::Value* element = tmp; llvm_utils->set_api->write_item(const_set, element, module.get(), @@ -1522,6 +1529,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor this->visit_expr(*x.m_arg); ptr_loads = ptr_loads_copy; llvm::Value* pset = tmp; + ASR::Set_t* x_set = ASR::down_cast(ASRUtils::expr_type(x.m_arg)); + llvm_utils->set_set_api(x_set); tmp = llvm_utils->set_api->len(pset); } @@ -1687,6 +1696,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } void generate_SetAdd(ASR::expr_t* m_arg, ASR::expr_t* m_ele) { + ASR::Set_t* set_type = ASR::down_cast( + ASRUtils::expr_type(m_arg)); ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(m_arg)); int64_t ptr_loads_copy = ptr_loads; ptr_loads = 0; @@ -1697,10 +1708,13 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor this->visit_expr_wrapper(m_ele, true); ptr_loads = ptr_loads_copy; llvm::Value *el = tmp; - set_api->write_item(pset, el, module.get(), asr_el_type, name2memidx); + llvm_utils->set_set_api(set_type); + llvm_utils->set_api->write_item(pset, el, module.get(), asr_el_type, name2memidx); } void generate_SetRemove(ASR::expr_t* m_arg, ASR::expr_t* m_ele) { + ASR::Set_t* set_type = ASR::down_cast( + ASRUtils::expr_type(m_arg)); ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(m_arg)); int64_t ptr_loads_copy = ptr_loads; ptr_loads = 0; @@ -1711,7 +1725,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor this->visit_expr_wrapper(m_ele, true); ptr_loads = ptr_loads_copy; llvm::Value *el = tmp; - set_api->remove_item(pset, el, *module, asr_el_type); + llvm_utils->set_set_api(set_type); + llvm_utils->set_api->remove_item(pset, el, *module, asr_el_type); } void visit_IntrinsicFunction(const ASR::IntrinsicFunction_t& x) { @@ -2729,6 +2744,10 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor bool is_dict_present_copy_sc = dict_api_sc->is_dict_present(); dict_api_lp->set_is_dict_present(false); dict_api_sc->set_is_dict_present(false); + bool is_set_present_copy_lp = set_api_lp->is_set_present(); + bool is_set_present_copy_sc = set_api_sc->is_set_present(); + set_api_lp->set_is_set_present(false); + set_api_sc->set_is_set_present(false); llvm_goto_targets.clear(); // Generate code for nested subroutines and functions first: for (auto &item : x.m_symtab->get_scope()) { @@ -2788,6 +2807,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor builder->CreateRet(ret_val2); dict_api_lp->set_is_dict_present(is_dict_present_copy_lp); dict_api_sc->set_is_dict_present(is_dict_present_copy_sc); + set_api_lp->set_is_set_present(is_set_present_copy_lp); + set_api_sc->set_is_set_present(is_set_present_copy_sc); // Finalize the debug info. if (compiler_options.emit_debug_info) DBuilder->finalize(); @@ -3277,6 +3298,10 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor bool is_dict_present_copy_sc = dict_api_sc->is_dict_present(); dict_api_lp->set_is_dict_present(false); dict_api_sc->set_is_dict_present(false); + bool is_set_present_copy_lp = set_api_lp->is_set_present(); + bool is_set_present_copy_sc = set_api_sc->is_set_present(); + set_api_lp->set_is_set_present(false); + set_api_sc->set_is_set_present(false); llvm_goto_targets.clear(); instantiate_function(x); if (ASRUtils::get_FunctionType(x)->m_deftype == ASR::deftypeType::Interface) { @@ -3289,6 +3314,8 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor parent_function = nullptr; dict_api_lp->set_is_dict_present(is_dict_present_copy_lp); dict_api_sc->set_is_dict_present(is_dict_present_copy_sc); + set_api_lp->set_is_set_present(is_set_present_copy_lp); + set_api_sc->set_is_set_present(is_set_present_copy_sc); // Finalize the debug info. if (compiler_options.emit_debug_info) DBuilder->finalize(); @@ -4141,6 +4168,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor llvm::Value* target_set = tmp; ptr_loads = ptr_loads_copy; ASR::Set_t* value_set_type = ASR::down_cast(asr_value_type); + llvm_utils->set_set_api(value_set_type); llvm_utils->set_api->set_deepcopy(value_set, target_set, value_set_type, module.get(), name2memidx); return ; diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 4d47d77e3d..e6b8022847 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -121,7 +121,8 @@ namespace LCompilers { name2dertype(name2dertype_), name2dercontext(name2dercontext_), struct_type_stack(struct_type_stack_), dertype2parent(dertype2parent_), name2memidx(name2memidx_), arr_arg_type_cache(arr_arg_type_cache_), fname2arg_type(fname2arg_type_), - dict_api_lp(nullptr), dict_api_sc(nullptr), compiler_options(compiler_options_) { + dict_api_lp(nullptr), dict_api_sc(nullptr), + set_api_lp(nullptr), set_api_sc(nullptr), compiler_options(compiler_options_) { std::vector els_4 = { llvm::Type::getFloatTy(context), llvm::Type::getFloatTy(context)}; @@ -600,6 +601,7 @@ namespace LCompilers { local_a_kind, module); int32_t el_type_size = get_type_size(asr_set->m_type, el_llvm_type, local_a_kind, module); std::string el_type_code = ASRUtils::get_type_code(asr_set->m_type); + set_set_api(asr_set); return set_api->get_set_type(el_type_code, el_type_size, el_llvm_type); } @@ -851,6 +853,7 @@ namespace LCompilers { is_list, m_dims, n_dims, a_kind, module, m_abi); int32_t el_type_size = get_type_size(asr_set->m_type, el_llvm_type, a_kind, module); + set_set_api(asr_set); type = set_api->get_set_type(el_type_code, el_type_size, el_llvm_type)->getPointerTo(); break; } @@ -874,6 +877,14 @@ namespace LCompilers { } } + void LLVMUtils::set_set_api(ASR::Set_t* set_type) { + if( ASR::is_a(*set_type->m_type) ) { + set_api = set_api_sc; + } else { + set_api = set_api_lp; + } + } + std::vector LLVMUtils::convert_args(const ASR::Function_t& x, llvm::Module* module) { std::vector args; for (size_t i=0; im_type, el_llvm_type, local_a_kind, module); + set_set_api(asr_set); + return_type = set_api->get_set_type(el_type_code, el_type_size, el_llvm_type); break; } @@ -2676,6 +2689,20 @@ namespace LCompilers { llvm::Value* key, llvm::Value* key_value_pair_linked_list, llvm::Type* kv_pair_type, llvm::Value* key_mask, llvm::Module& module, ASR::ttype_t* key_asr_type) { + /** + * C++ equivalent: + * + * is_key_matching = 1; + * + * while( chain_itr != nullptr && is_key_matching ) { + * break_signal = key != kv_key; + * is_key_matching = break_signal; // 1 means not matching + * if( break_signal ) { + * chain_itr = next_kv_struct; + * } + * } + * + */ if( !are_iterators_set ) { chain_itr = builder->CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); chain_itr_prev = builder->CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); @@ -4820,16 +4847,36 @@ namespace LCompilers { context(context_), llvm_utils(std::move(llvm_utils_)), builder(std::move(builder_)), - pos_ptr(nullptr), are_iterators_set(false), + pos_ptr(nullptr), is_el_matching_var(nullptr), + idx_ptr(nullptr), hash_iter(nullptr), + hash_value(nullptr), polynomial_powers(nullptr), + chain_itr(nullptr), chain_itr_prev(nullptr), + old_capacity(nullptr), old_elems(nullptr), + old_el_mask(nullptr), are_iterators_set(false), is_set_present_(false) { } + bool LLVMSetInterface::is_set_present() { + return is_set_present_; + } + + void LLVMSetInterface::set_is_set_present(bool value) { + is_set_present_ = value; + } + LLVMSetLinearProbing::LLVMSetLinearProbing(llvm::LLVMContext& context_, LLVMUtils* llvm_utils_, llvm::IRBuilder<>* builder_): LLVMSetInterface(context_, llvm_utils_, builder_) { } + LLVMSetSeparateChaining::LLVMSetSeparateChaining( + llvm::LLVMContext& context_, + LLVMUtils* llvm_utils_, + llvm::IRBuilder<>* builder_): + LLVMSetInterface(context_, llvm_utils_, builder_) { + } + LLVMSetInterface::~LLVMSetInterface() { typecode2settype.clear(); } @@ -4837,6 +4884,9 @@ namespace LCompilers { LLVMSetLinearProbing::~LLVMSetLinearProbing() { } + LLVMSetSeparateChaining::~LLVMSetSeparateChaining() { + } + llvm::Value* LLVMSetLinearProbing::get_pointer_to_occupancy(llvm::Value* set) { return llvm_utils->create_gep(set, 0); } @@ -4854,6 +4904,34 @@ namespace LCompilers { return llvm_utils->create_gep(set, 2); } + llvm::Value* LLVMSetSeparateChaining::get_el_list(llvm::Value* /*set*/) { + return nullptr; + } + + llvm::Value* LLVMSetSeparateChaining::get_pointer_to_occupancy(llvm::Value* set) { + return llvm_utils->create_gep(set, 0); + } + + llvm::Value* LLVMSetSeparateChaining::get_pointer_to_number_of_filled_buckets(llvm::Value* set) { + return llvm_utils->create_gep(set, 1); + } + + llvm::Value* LLVMSetSeparateChaining::get_pointer_to_capacity(llvm::Value* set) { + return llvm_utils->create_gep(set, 2); + } + + llvm::Value* LLVMSetSeparateChaining::get_pointer_to_elems(llvm::Value* set) { + return llvm_utils->create_gep(set, 3); + } + + llvm::Value* LLVMSetSeparateChaining::get_pointer_to_mask(llvm::Value* set) { + return llvm_utils->create_gep(set, 4); + } + + llvm::Value* LLVMSetSeparateChaining::get_pointer_to_rehash_flag(llvm::Value* set) { + return llvm_utils->create_gep(set, 5); + } + llvm::Type* LLVMSetLinearProbing::get_set_type(std::string type_code, int32_t type_size, llvm::Type* el_type) { is_set_present_ = true; @@ -4871,6 +4949,27 @@ namespace LCompilers { return set_desc; } + llvm::Type* LLVMSetSeparateChaining::get_set_type( + std::string el_type_code, int32_t el_type_size, llvm::Type* el_type) { + is_set_present_ = true; + if( typecode2settype.find(el_type_code) != typecode2settype.end() ) { + return std::get<0>(typecode2settype[el_type_code]); + } + + std::vector el_vec = {el_type, llvm::Type::getInt8PtrTy(context)}; + llvm::Type* elstruct = llvm::StructType::create(context, el_vec, "el"); + std::vector set_type_vec = {llvm::Type::getInt32Ty(context), + llvm::Type::getInt32Ty(context), + llvm::Type::getInt32Ty(context), + elstruct->getPointerTo(), + llvm::Type::getInt8PtrTy(context), + llvm::Type::getInt1Ty(context)}; + llvm::Type* set_desc = llvm::StructType::create(context, set_type_vec, "set"); + typecode2settype[el_type_code] = std::make_tuple(set_desc, el_type_size, el_type); + typecode2elstruct[el_type_code] = elstruct; + return set_desc; + } + void LLVMSetLinearProbing::set_init(std::string type_code, llvm::Value* set, llvm::Module* module, size_t initial_capacity) { llvm::Value* n_ptr = get_pointer_to_occupancy(set); @@ -4890,6 +4989,57 @@ namespace LCompilers { LLVM::CreateStore(*builder, el_mask, get_pointer_to_mask(set)); } + void LLVMSetSeparateChaining::set_init( + std::string el_type_code, llvm::Value* set, + llvm::Module* module, size_t initial_capacity) { + llvm::Value* llvm_capacity = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, initial_capacity + 1)); + llvm::Value* rehash_flag_ptr = get_pointer_to_rehash_flag(set); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), + llvm::APInt(1, 1)), rehash_flag_ptr); + set_init_given_initial_capacity(el_type_code, set, module, llvm_capacity); + } + + void LLVMSetSeparateChaining::set_init_given_initial_capacity( + std::string el_type_code, llvm::Value* set, + llvm::Module* module, llvm::Value* llvm_capacity) { + llvm::Value* rehash_flag_ptr = get_pointer_to_rehash_flag(set); + llvm::Value* rehash_flag = LLVM::CreateLoad(*builder, rehash_flag_ptr); + llvm::Value* llvm_zero = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)); + llvm::Value* occupancy_ptr = get_pointer_to_occupancy(set); + LLVM::CreateStore(*builder, llvm_zero, occupancy_ptr); + llvm::Value* num_buckets_filled_ptr = get_pointer_to_number_of_filled_buckets(set); + LLVM::CreateStore(*builder, llvm_zero, num_buckets_filled_ptr); + + llvm::DataLayout data_layout(module); + llvm::Type* el_type = typecode2elstruct[el_type_code]; + size_t el_type_size = data_layout.getTypeAllocSize(el_type); + llvm::Value* llvm_el_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, el_type_size)); + llvm::Value* malloc_size = builder->CreateMul(llvm_capacity, llvm_el_size); + llvm::Value* el_ptr = LLVM::lfortran_malloc(context, *module, *builder, malloc_size); + rehash_flag = builder->CreateAnd(rehash_flag, + builder->CreateICmpNE(el_ptr, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))) + ); + el_ptr = builder->CreateBitCast(el_ptr, el_type->getPointerTo()); + LLVM::CreateStore(*builder, el_ptr, get_pointer_to_elems(set)); + + size_t mask_size = data_layout.getTypeAllocSize(llvm::Type::getInt8Ty(context)); + llvm::Value* llvm_mask_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, mask_size)); + llvm::Value* el_mask = LLVM::lfortran_calloc(context, *module, *builder, llvm_capacity, + llvm_mask_size); + rehash_flag = builder->CreateAnd(rehash_flag, + builder->CreateICmpNE(el_mask, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))) + ); + LLVM::CreateStore(*builder, el_mask, get_pointer_to_mask(set)); + + llvm::Value* capacity_ptr = get_pointer_to_capacity(set); + LLVM::CreateStore(*builder, llvm_capacity, capacity_ptr); + LLVM::CreateStore(*builder, rehash_flag, rehash_flag_ptr); + } + void LLVMSetInterface::set_iterators() { if( are_iterators_set || !is_set_present_ ) { return ; @@ -4914,6 +5064,39 @@ namespace LCompilers { polynomial_powers = builder->CreateAlloca(llvm::Type::getInt64Ty(context), nullptr, "p_pow"); LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt64Ty(context), llvm::APInt(64, 1)), polynomial_powers); + chain_itr = builder->CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); + LLVM::CreateStore(*builder, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)), chain_itr); + chain_itr_prev = builder->CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); + LLVM::CreateStore(*builder, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)), chain_itr_prev); + old_capacity = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 0)), old_capacity); + old_occupancy = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 0)), old_occupancy); + old_number_of_buckets_filled = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 0)), old_number_of_buckets_filled); + old_elems = builder->CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); + LLVM::CreateStore(*builder, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)), old_elems); + old_el_mask = builder->CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); + LLVM::CreateStore(*builder, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)), old_el_mask); + src_itr = builder->CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); + LLVM::CreateStore(*builder, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)), src_itr); + dest_itr = builder->CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); + LLVM::CreateStore(*builder, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)), dest_itr); + next_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 0)), next_ptr); + copy_itr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 0)), copy_itr); are_iterators_set = true; } @@ -4925,6 +5108,17 @@ namespace LCompilers { hash_iter = nullptr; hash_value = nullptr; polynomial_powers = nullptr; + chain_itr = nullptr; + chain_itr_prev = nullptr; + old_capacity = nullptr; + old_occupancy = nullptr; + old_number_of_buckets_filled = nullptr; + old_elems = nullptr; + old_el_mask = nullptr; + src_itr = nullptr; + dest_itr = nullptr; + next_ptr = nullptr; + copy_itr = nullptr; are_iterators_set = false; } @@ -5149,6 +5343,84 @@ namespace LCompilers { llvm_utils->start_new_block(loopend); } + void LLVMSetSeparateChaining::resolve_collision( + llvm::Value* el_hash, llvm::Value* el, llvm::Value* el_linked_list, + llvm::Type* el_struct_type, llvm::Value* el_mask, + llvm::Module& module, ASR::ttype_t* el_asr_type) { + /** + * C++ equivalent: + * + * is_el_matching = 1; + * + * while( chain_itr != nullptr && is_el_matching ) { + * break_signal = el != el_struct_el; + * is_el_matching = break_signal; // 1 means not matching + * if( break_signal ) { + * chain_itr = next_el_struct; + * } + * } + * + * // now, chain_itr either points to element or is nullptr + * + */ + if( !are_iterators_set ) { + chain_itr = builder->CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); + chain_itr_prev = builder->CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); + is_el_matching_var = builder->CreateAlloca(llvm::Type::getInt1Ty(context), nullptr); + } + + LLVM::CreateStore(*builder, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)), chain_itr_prev); + llvm::Value* el_ll_i8 = builder->CreateBitCast(el_linked_list, llvm::Type::getInt8PtrTy(context)); + LLVM::CreateStore(*builder, el_ll_i8, chain_itr); + llvm::Value* el_mask_value = LLVM::CreateLoad(*builder, + llvm_utils->create_ptr_gep(el_mask, el_hash)); + LLVM::CreateStore(*builder, + builder->CreateICmpEQ(el_mask_value, llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))), + is_el_matching_var + ); + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + + // head + llvm_utils->start_new_block(loophead); + { + llvm::Value *cond = builder->CreateICmpNE( + LLVM::CreateLoad(*builder, chain_itr), + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)) + ); + cond = builder->CreateAnd(cond, LLVM::CreateLoad(*builder, is_el_matching_var)); + builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + llvm_utils->start_new_block(loopbody); + { + llvm::Value* el_struct_i8 = LLVM::CreateLoad(*builder, chain_itr); + LLVM::CreateStore(*builder, el_struct_i8, chain_itr_prev); + llvm::Value* el_struct = builder->CreateBitCast(el_struct_i8, el_struct_type->getPointerTo()); + llvm::Value* el_struct_el = llvm_utils->create_gep(el_struct, 0); + if( !LLVM::is_llvm_struct(el_asr_type) ) { + el_struct_el = LLVM::CreateLoad(*builder, el_struct_el); + } + llvm::Value* break_signal = llvm_utils->is_equal_by_value(el, el_struct_el, module, el_asr_type); + break_signal = builder->CreateNot(break_signal); + LLVM::CreateStore(*builder, break_signal, is_el_matching_var); + llvm_utils->create_if_else(break_signal, [&]() { + llvm::Value* next_el_struct = LLVM::CreateLoad(*builder, llvm_utils->create_gep(el_struct, 1)); + LLVM::CreateStore(*builder, next_el_struct, chain_itr); + }, []() { + }); + } + + builder->CreateBr(loophead); + + // end + llvm_utils->start_new_block(loopend); + + } + void LLVMSetLinearProbing::resolve_collision_for_write( llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, llvm::Module* module, ASR::ttype_t* el_asr_type, @@ -5203,6 +5475,91 @@ namespace LCompilers { LLVM::CreateStore(*builder, set_max_2, llvm_utils->create_ptr_gep(el_mask, pos)); } + void LLVMSetSeparateChaining::resolve_collision_for_write( + llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, + llvm::Module* module, ASR::ttype_t* el_asr_type, + std::map>& name2memidx) { + /** + * C++ equivalent: + * + * resolve_collision(el); // modifies chain_itr + * do_insert = chain_itr == nullptr; + * + * if( do_insert ) { + * new_el_struct = malloc(el_struct_size); + * new_el_struct[0] = el; + * new_el_struct[1] = nullptr; + * chain_itr_prev[1] = new_el_struct; + * } + * else { + * el_struct[0] = el; + * } + * + * buckets_filled_delta = el_mask[el_hash] == 0; + * buckets_filled += buckets_filled_delta; + * el_mask[el_hash] = 1; + * + */ + + llvm::Value* elems = LLVM::CreateLoad(*builder, get_pointer_to_elems(set)); + llvm::Value* el_linked_list = llvm_utils->create_ptr_gep(elems, el_hash); + llvm::Value* el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set)); + llvm::Type* el_struct_type = typecode2elstruct[ASRUtils::get_type_code(el_asr_type)]; + this->resolve_collision(el_hash, el, el_linked_list, el_struct_type, + el_mask, *module, el_asr_type); + llvm::Value* el_struct_i8 = LLVM::CreateLoad(*builder, chain_itr); + + llvm::Function *fn = builder->GetInsertBlock()->getParent(); + llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); + llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); + llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); + llvm::Value* do_insert = builder->CreateICmpEQ(el_struct_i8, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))); + builder->CreateCondBr(do_insert, thenBB, elseBB); + + builder->SetInsertPoint(thenBB); + { + llvm::DataLayout data_layout(module); + size_t el_struct_size = data_layout.getTypeAllocSize(el_struct_type); + llvm::Value* malloc_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), el_struct_size); + llvm::Value* new_el_struct_i8 = LLVM::lfortran_malloc(context, *module, *builder, malloc_size); + llvm::Value* new_el_struct = builder->CreateBitCast(new_el_struct_i8, el_struct_type->getPointerTo()); + llvm_utils->deepcopy(el, llvm_utils->create_gep(new_el_struct, 0), el_asr_type, module, name2memidx); + LLVM::CreateStore(*builder, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)), + llvm_utils->create_gep(new_el_struct, 1)); + llvm::Value* el_struct_prev_i8 = LLVM::CreateLoad(*builder, chain_itr_prev); + llvm::Value* el_struct_prev = builder->CreateBitCast(el_struct_prev_i8, el_struct_type->getPointerTo()); + LLVM::CreateStore(*builder, new_el_struct_i8, llvm_utils->create_gep(el_struct_prev, 1)); + } + builder->CreateBr(mergeBB); + llvm_utils->start_new_block(elseBB); + { + llvm::Value* el_struct = builder->CreateBitCast(el_struct_i8, el_struct_type->getPointerTo()); + llvm_utils->deepcopy(el, llvm_utils->create_gep(el_struct, 0), el_asr_type, module, name2memidx); + } + llvm_utils->start_new_block(mergeBB); + llvm::Value* occupancy_ptr = get_pointer_to_occupancy(set); + llvm::Value* buckets_filled_ptr = get_pointer_to_number_of_filled_buckets(set); + llvm::Value* occupancy = LLVM::CreateLoad(*builder, occupancy_ptr); + occupancy = builder->CreateAdd(occupancy, + builder->CreateZExt(do_insert, llvm::Type::getInt32Ty(context))); + LLVM::CreateStore(*builder, occupancy, occupancy_ptr); + llvm::Value* el_mask_value_ptr = llvm_utils->create_ptr_gep(el_mask, el_hash); + llvm::Value* el_mask_value = LLVM::CreateLoad(*builder, el_mask_value_ptr); + llvm::Value* buckets_filled_delta = builder->CreateICmpEQ(el_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0))); + llvm::Value* buckets_filled = LLVM::CreateLoad(*builder, buckets_filled_ptr); + buckets_filled = builder->CreateAdd( + buckets_filled, + builder->CreateZExt(buckets_filled_delta, llvm::Type::getInt32Ty(context)) + ); + LLVM::CreateStore(*builder, buckets_filled, buckets_filled_ptr); + LLVM::CreateStore(*builder, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1)), + el_mask_value_ptr); + } + void LLVMSetLinearProbing::rehash( llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type, std::map>& name2memidx) { @@ -5330,42 +5687,263 @@ namespace LCompilers { LLVM::CreateStore(*builder, new_el_mask, get_pointer_to_mask(set)); } - void LLVMSetLinearProbing::rehash_all_at_once_if_needed( + void LLVMSetSeparateChaining::rehash( llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type, std::map>& name2memidx) { - /** * C++ equivalent: * - * // this condition will be true with 0 capacity too - * rehash_condition = 5 * occupancy >= 3 * capacity; - * if( rehash_condition ) { - * rehash(); + * capacity = 3 * capacity + 1; + * + * if( rehash_flag ) { + * while( old_capacity > idx ) { + * if( el_mask[el_hash] == 1 ) { + * write_el_linked_list(old_elems_value[idx]); + * } + * idx++; + * } + * } + * else { + * // set to old values * } * */ - llvm::Value* occupancy = LLVM::CreateLoad(*builder, get_pointer_to_occupancy(set)); - llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set)); - // Threshold hash is chosen from https://en.wikipedia.org/wiki/Hash_table#Load_factor - // occupancy / capacity >= 0.6 is same as 5 * occupancy >= 3 * capacity - llvm::Value* occupancy_times_5 = builder->CreateMul(occupancy, llvm::ConstantInt::get( - llvm::Type::getInt32Ty(context), llvm::APInt(32, 5))); - llvm::Value* capacity_times_3 = builder->CreateMul(capacity, llvm::ConstantInt::get( - llvm::Type::getInt32Ty(context), llvm::APInt(32, 3))); - llvm_utils->create_if_else(builder->CreateICmpSGE(occupancy_times_5, - capacity_times_3), [&]() { - rehash(set, module, el_asr_type, name2memidx); - }, []() {}); - } - - void LLVMSetLinearProbing::write_item( - llvm::Value* set, llvm::Value* el, - llvm::Module* module, ASR::ttype_t* el_asr_type, - std::map>& name2memidx) { - rehash_all_at_once_if_needed(set, module, el_asr_type, name2memidx); - llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set)); - llvm::Value* el_hash = get_el_hash(current_capacity, el, el_asr_type, *module); + if( !are_iterators_set ) { + old_capacity = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + old_occupancy = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + old_number_of_buckets_filled = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + idx_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + old_elems = builder->CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); + old_el_mask = builder->CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); + } + llvm::Value* capacity_ptr = get_pointer_to_capacity(set); + llvm::Value* occupancy_ptr = get_pointer_to_occupancy(set); + llvm::Value* number_of_buckets_filled_ptr = get_pointer_to_number_of_filled_buckets(set); + llvm::Value* old_capacity_value = LLVM::CreateLoad(*builder, capacity_ptr); + LLVM::CreateStore(*builder, old_capacity_value, old_capacity); + LLVM::CreateStore(*builder, + LLVM::CreateLoad(*builder, occupancy_ptr), + old_occupancy + ); + LLVM::CreateStore(*builder, + LLVM::CreateLoad(*builder, number_of_buckets_filled_ptr), + old_number_of_buckets_filled + ); + llvm::Value* old_el_mask_value = LLVM::CreateLoad(*builder, get_pointer_to_mask(set)); + llvm::Value* old_elems_value = LLVM::CreateLoad(*builder, get_pointer_to_elems(set)); + old_elems_value = builder->CreateBitCast(old_elems_value, llvm::Type::getInt8PtrTy(context)); + LLVM::CreateStore(*builder, old_el_mask_value, old_el_mask); + LLVM::CreateStore(*builder, old_elems_value, old_elems); + + llvm::Value* capacity = builder->CreateMul(old_capacity_value, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 3))); + capacity = builder->CreateAdd(capacity, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 1))); + set_init_given_initial_capacity(ASRUtils::get_type_code(el_asr_type), + set, module, capacity); + + llvm::Function *fn = builder->GetInsertBlock()->getParent(); + llvm::BasicBlock *thenBB_rehash = llvm::BasicBlock::Create(context, "then", fn); + llvm::BasicBlock *elseBB_rehash = llvm::BasicBlock::Create(context, "else"); + llvm::BasicBlock *mergeBB_rehash = llvm::BasicBlock::Create(context, "ifcont"); + llvm::Value* rehash_flag = LLVM::CreateLoad(*builder, get_pointer_to_rehash_flag(set)); + builder->CreateCondBr(rehash_flag, thenBB_rehash, elseBB_rehash); + + builder->SetInsertPoint(thenBB_rehash); + old_elems_value = LLVM::CreateLoad(*builder, old_elems); + old_elems_value = builder->CreateBitCast(old_elems_value, + typecode2elstruct[ASRUtils::get_type_code(el_asr_type)]->getPointerTo()); + old_el_mask_value = LLVM::CreateLoad(*builder, old_el_mask); + old_capacity_value = LLVM::CreateLoad(*builder, old_capacity); + capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set)); + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)), idx_ptr); + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + + // head + llvm_utils->start_new_block(loophead); + { + llvm::Value *cond = builder->CreateICmpSGT( + old_capacity_value, + LLVM::CreateLoad(*builder, idx_ptr)); + builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + llvm_utils->start_new_block(loopbody); + { + llvm::Value* itr = LLVM::CreateLoad(*builder, idx_ptr); + llvm::Value* el_mask_value = LLVM::CreateLoad(*builder, + llvm_utils->create_ptr_gep(old_el_mask_value, itr)); + llvm::Value* is_el_set = builder->CreateICmpEQ(el_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); + + llvm_utils->create_if_else(is_el_set, [&]() { + llvm::Value* srci = llvm_utils->create_ptr_gep(old_elems_value, itr); + write_el_linked_list(srci, set, capacity, el_asr_type, module, name2memidx); + }, [=]() { + }); + llvm::Value* tmp = builder->CreateAdd( + itr, + llvm::ConstantInt::get(context, llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, tmp, idx_ptr); + } + + builder->CreateBr(loophead); + + // end + llvm_utils->start_new_block(loopend); + builder->CreateBr(mergeBB_rehash); + llvm_utils->start_new_block(elseBB_rehash); + { + LLVM::CreateStore(*builder, + LLVM::CreateLoad(*builder, old_capacity), + get_pointer_to_capacity(set) + ); + LLVM::CreateStore(*builder, + LLVM::CreateLoad(*builder, old_occupancy), + get_pointer_to_occupancy(set) + ); + LLVM::CreateStore(*builder, + LLVM::CreateLoad(*builder, old_number_of_buckets_filled), + get_pointer_to_number_of_filled_buckets(set) + ); + LLVM::CreateStore(*builder, + builder->CreateBitCast( + LLVM::CreateLoad(*builder, old_elems), + typecode2elstruct[ASRUtils::get_type_code(el_asr_type)]->getPointerTo() + ), + get_pointer_to_elems(set) + ); + LLVM::CreateStore(*builder, + LLVM::CreateLoad(*builder, old_el_mask), + get_pointer_to_mask(set) + ); + } + llvm_utils->start_new_block(mergeBB_rehash); + } + + void LLVMSetSeparateChaining::write_el_linked_list( + llvm::Value* el_ll, llvm::Value* set, llvm::Value* capacity, + ASR::ttype_t* m_el_type, llvm::Module* module, + std::map>& name2memidx) { + /** + * C++ equivalent: + * + * while( src_itr != nullptr ) { + * resolve_collision_for_write(el_struct[0]); + * src_itr = el_struct[1]; + * } + * + */ + + if( !are_iterators_set ) { + src_itr = builder->CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); + } + llvm::Type* el_struct_type = typecode2elstruct[ASRUtils::get_type_code(m_el_type)]->getPointerTo(); + LLVM::CreateStore(*builder, + builder->CreateBitCast(el_ll, llvm::Type::getInt8PtrTy(context)), + src_itr); + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + // head + llvm_utils->start_new_block(loophead); + { + llvm::Value *cond = builder->CreateICmpNE( + LLVM::CreateLoad(*builder, src_itr), + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)) + ); + builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + llvm_utils->start_new_block(loopbody); + { + llvm::Value* curr_src = builder->CreateBitCast(LLVM::CreateLoad(*builder, src_itr), + el_struct_type); + llvm::Value* src_el_ptr = llvm_utils->create_gep(curr_src, 0); + llvm::Value* src_el = src_el_ptr; + if( !LLVM::is_llvm_struct(m_el_type) ) { + src_el = LLVM::CreateLoad(*builder, src_el_ptr); + } + llvm::Value* el_hash = get_el_hash(capacity, src_el, m_el_type, *module); + resolve_collision_for_write( + set, el_hash, src_el, module, + m_el_type, name2memidx); + + llvm::Value* src_next_ptr = LLVM::CreateLoad(*builder, llvm_utils->create_gep(curr_src, 1)); + LLVM::CreateStore(*builder, src_next_ptr, src_itr); + } + + builder->CreateBr(loophead); + + // end + llvm_utils->start_new_block(loopend); + } + + void LLVMSetLinearProbing::rehash_all_at_once_if_needed( + llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type, + std::map>& name2memidx) { + + /** + * C++ equivalent: + * + * // this condition will be true with 0 capacity too + * rehash_condition = 5 * occupancy >= 3 * capacity; + * if( rehash_condition ) { + * rehash(); + * } + * + */ + + llvm::Value* occupancy = LLVM::CreateLoad(*builder, get_pointer_to_occupancy(set)); + llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set)); + // Threshold hash is chosen from https://en.wikipedia.org/wiki/Hash_table#Load_factor + // occupancy / capacity >= 0.6 is same as 5 * occupancy >= 3 * capacity + llvm::Value* occupancy_times_5 = builder->CreateMul(occupancy, llvm::ConstantInt::get( + llvm::Type::getInt32Ty(context), llvm::APInt(32, 5))); + llvm::Value* capacity_times_3 = builder->CreateMul(capacity, llvm::ConstantInt::get( + llvm::Type::getInt32Ty(context), llvm::APInt(32, 3))); + llvm_utils->create_if_else(builder->CreateICmpSGE(occupancy_times_5, + capacity_times_3), [&]() { + rehash(set, module, el_asr_type, name2memidx); + }, []() {}); + } + + void LLVMSetSeparateChaining::rehash_all_at_once_if_needed( + llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type, + std::map>& name2memidx) { + /** + * C++ equivalent: + * + * rehash_condition = rehash_flag && occupancy >= 2 * buckets_filled; + * if( rehash_condition ) { + * rehash(); + * } + * + */ + llvm::Value* occupancy = LLVM::CreateLoad(*builder, get_pointer_to_occupancy(set)); + llvm::Value* buckets_filled = LLVM::CreateLoad(*builder, get_pointer_to_number_of_filled_buckets(set)); + llvm::Value* rehash_condition = LLVM::CreateLoad(*builder, get_pointer_to_rehash_flag(set)); + llvm::Value* buckets_filled_times_2 = builder->CreateMul(buckets_filled, + llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 2))); + rehash_condition = builder->CreateAnd(rehash_condition, + builder->CreateICmpSGE(occupancy, buckets_filled_times_2)); + llvm_utils->create_if_else(rehash_condition, [&]() { + rehash(set, module, el_asr_type, name2memidx); + }, []() {}); + } + + void LLVMSetInterface::write_item( + llvm::Value* set, llvm::Value* el, + llvm::Module* module, ASR::ttype_t* el_asr_type, + std::map>& name2memidx) { + rehash_all_at_once_if_needed(set, module, el_asr_type, name2memidx); + llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set)); + llvm::Value* el_hash = get_el_hash(current_capacity, el, el_asr_type, *module); this->resolve_collision_for_write(set, el_hash, el, module, el_asr_type, name2memidx); } @@ -5448,8 +6026,48 @@ namespace LCompilers { llvm_utils->list_api->read_item(el_list, pos, false, module, LLVM::is_llvm_struct(el_asr_type)), module, el_asr_type); - llvm_utils->create_if_else(is_el_matching, [&]() { - }, [&]() { + llvm_utils->create_if_else(is_el_matching, []() {}, [&]() { + std::string message = "The set does not contain the specified element"; + llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n"); + llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message); + print_error(context, module, *builder, {fmt_ptr, fmt_ptr2}); + int exit_code_int = 1; + llvm::Value *exit_code = llvm::ConstantInt::get(context, + llvm::APInt(32, exit_code_int)); + exit(context, module, *builder, exit_code); + }); + } + + void LLVMSetSeparateChaining::resolve_collision_for_read_with_bound_check( + llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, + llvm::Module& module, ASR::ttype_t* el_asr_type) { + /** + * C++ equivalent: + * + * resolve_collision(el); // modified chain_itr + * does_el_exist = el_mask[el_hash] == 1 && chain_itr != nullptr; + * if( !does_el_exist ) { + * exit(1); // KeyError + * } + * + */ + llvm::Value* elems = LLVM::CreateLoad(*builder, get_pointer_to_elems(set)); + llvm::Value* el_linked_list = llvm_utils->create_ptr_gep(elems, el_hash); + llvm::Value* el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set)); + std::string el_type_code = ASRUtils::get_type_code(el_asr_type); + llvm::Type* el_struct_type = typecode2elstruct[el_type_code]; + this->resolve_collision(el_hash, el, el_linked_list, + el_struct_type, el_mask, module, el_asr_type); + llvm::Value* el_mask_value = LLVM::CreateLoad(*builder, + llvm_utils->create_ptr_gep(el_mask, el_hash)); + llvm::Value* does_el_exist = builder->CreateICmpEQ(el_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); + does_el_exist = builder->CreateAnd(does_el_exist, + builder->CreateICmpNE(LLVM::CreateLoad(*builder, chain_itr), + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))) + ); + + llvm_utils->create_if_else(does_el_exist, []() {}, [&]() { std::string message = "The set does not contain the specified element"; llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n"); llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message); @@ -5487,6 +6105,75 @@ namespace LCompilers { LLVM::CreateStore(*builder, occupancy, occupancy_ptr); } + void LLVMSetSeparateChaining::remove_item( + llvm::Value* set, llvm::Value* el, + llvm::Module& module, ASR::ttype_t* el_asr_type) { + /** + * C++ equivalent: + * + * // modifies chain_itr and chain_itr_prev + * resolve_collision_for_read_with_bound_check(el); + * + * if(chain_itr_prev != nullptr) { + * chain_itr_prev[1] = chain_itr[1]; // next + * } + * else { + * // this linked list is now empty + * el_mask[el_hash] = 0; + * num_buckets_filled--; + * } + * + * occupancy--; + * + */ + + llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set)); + llvm::Value* el_hash = get_el_hash(current_capacity, el, el_asr_type, module); + this->resolve_collision_for_read_with_bound_check(set, el_hash, el, module, el_asr_type); + llvm::Value* prev = LLVM::CreateLoad(*builder, chain_itr_prev); + llvm::Value* found = LLVM::CreateLoad(*builder, chain_itr); + + llvm::Function *fn = builder->GetInsertBlock()->getParent(); + llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); + llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); + llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); + + builder->CreateCondBr( + builder->CreateICmpNE(prev, llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))), + thenBB, elseBB + ); + builder->SetInsertPoint(thenBB); + { + llvm::Type* el_struct_type = typecode2elstruct[ASRUtils::get_type_code(el_asr_type)]; + found = builder->CreateBitCast(found, el_struct_type->getPointerTo()); + llvm::Value* found_next = LLVM::CreateLoad(*builder, llvm_utils->create_gep(found, 1)); + prev = builder->CreateBitCast(prev, el_struct_type->getPointerTo()); + LLVM::CreateStore(*builder, found_next, llvm_utils->create_gep(prev, 1)); + } + builder->CreateBr(mergeBB); + llvm_utils->start_new_block(elseBB); + { + llvm::Value* el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set)); + LLVM::CreateStore( + *builder, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0)), + llvm_utils->create_ptr_gep(el_mask, el_hash) + ); + llvm::Value* num_buckets_filled_ptr = get_pointer_to_number_of_filled_buckets(set); + llvm::Value* num_buckets_filled = LLVM::CreateLoad(*builder, num_buckets_filled_ptr); + num_buckets_filled = builder->CreateSub(num_buckets_filled, llvm::ConstantInt::get( + llvm::Type::getInt32Ty(context), llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, num_buckets_filled, num_buckets_filled_ptr); + } + llvm_utils->start_new_block(mergeBB); + + llvm::Value* occupancy_ptr = get_pointer_to_occupancy(set); + llvm::Value* occupancy = LLVM::CreateLoad(*builder, occupancy_ptr); + occupancy = builder->CreateSub(occupancy, llvm::ConstantInt::get( + llvm::Type::getInt32Ty(context), llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, occupancy, occupancy_ptr); + } + void LLVMSetLinearProbing::set_deepcopy( llvm::Value* src, llvm::Value* dest, ASR::Set_t* set_type, llvm::Module* module, @@ -5516,7 +6203,188 @@ namespace LCompilers { LLVM::CreateStore(*builder, dest_el_mask, dest_el_mask_ptr); } - llvm::Value* LLVMSetLinearProbing::len(llvm::Value* set) { + void LLVMSetSeparateChaining::set_deepcopy( + llvm::Value* src, llvm::Value* dest, + ASR::Set_t* set_type, llvm::Module* module, + std::map>& name2memidx) { + llvm::Value* src_occupancy = LLVM::CreateLoad(*builder, get_pointer_to_occupancy(src)); + llvm::Value* src_filled_buckets = LLVM::CreateLoad(*builder, get_pointer_to_number_of_filled_buckets(src)); + llvm::Value* src_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(src)); + llvm::Value* src_el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(src)); + llvm::Value* src_rehash_flag = LLVM::CreateLoad(*builder, get_pointer_to_rehash_flag(src)); + LLVM::CreateStore(*builder, src_occupancy, get_pointer_to_occupancy(dest)); + LLVM::CreateStore(*builder, src_filled_buckets, get_pointer_to_number_of_filled_buckets(dest)); + LLVM::CreateStore(*builder, src_capacity, get_pointer_to_capacity(dest)); + LLVM::CreateStore(*builder, src_rehash_flag, get_pointer_to_rehash_flag(dest)); + llvm::DataLayout data_layout(module); + size_t mask_size = data_layout.getTypeAllocSize(llvm::Type::getInt8Ty(context)); + llvm::Value* llvm_mask_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, mask_size)); + llvm::Value* malloc_size = builder->CreateMul(src_capacity, llvm_mask_size); + llvm::Value* dest_el_mask = LLVM::lfortran_malloc(context, *module, *builder, malloc_size); + LLVM::CreateStore(*builder, dest_el_mask, get_pointer_to_mask(dest)); + + malloc_size = builder->CreateSub(src_occupancy, src_filled_buckets); + malloc_size = builder->CreateAdd(src_capacity, malloc_size); + llvm::Type* el_struct_type = typecode2elstruct[ASRUtils::get_type_code(set_type->m_type)]; + size_t el_struct_size = data_layout.getTypeAllocSize(el_struct_type); + llvm::Value* llvm_el_struct_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, el_struct_size)); + malloc_size = builder->CreateMul(malloc_size, llvm_el_struct_size); + llvm::Value* dest_elems = LLVM::lfortran_malloc(context, *module, *builder, malloc_size); + dest_elems = builder->CreateBitCast(dest_elems, el_struct_type->getPointerTo()); + if( !are_iterators_set ) { + copy_itr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + next_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + } + llvm::Value* llvm_zero = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)); + LLVM::CreateStore(*builder, llvm_zero, copy_itr); + LLVM::CreateStore(*builder, src_capacity, next_ptr); + + llvm::Value* src_elems = LLVM::CreateLoad(*builder, get_pointer_to_elems(src)); + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + + // head + llvm_utils->start_new_block(loophead); + { + llvm::Value *cond = builder->CreateICmpSGT( + src_capacity, + LLVM::CreateLoad(*builder, copy_itr)); + builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + llvm_utils->start_new_block(loopbody); + { + llvm::Value* itr = LLVM::CreateLoad(*builder, copy_itr); + llvm::Value* el_mask_value = LLVM::CreateLoad(*builder, + llvm_utils->create_ptr_gep(src_el_mask, itr)); + LLVM::CreateStore(*builder, el_mask_value, + llvm_utils->create_ptr_gep(dest_el_mask, itr)); + llvm::Value* is_el_set = builder->CreateICmpEQ(el_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); + + llvm_utils->create_if_else(is_el_set, [&]() { + llvm::Value* srci = llvm_utils->create_ptr_gep(src_elems, itr); + llvm::Value* desti = llvm_utils->create_ptr_gep(dest_elems, itr); + deepcopy_el_linked_list(srci, desti, dest_elems, + set_type, module, name2memidx); + }, []() {}); + llvm::Value* tmp = builder->CreateAdd( + itr, + llvm::ConstantInt::get(context, llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, tmp, copy_itr); + } + + builder->CreateBr(loophead); + + // end + llvm_utils->start_new_block(loopend); + LLVM::CreateStore(*builder, dest_elems, get_pointer_to_elems(dest)); + } + + void LLVMSetSeparateChaining::deepcopy_el_linked_list( + llvm::Value* srci, llvm::Value* desti, llvm::Value* dest_elems, + ASR::Set_t* set_type, llvm::Module* module, + std::map>& name2memidx) { + /** + * C++ equivalent: + * + * // memory allocation done before calling this function + * + * while( src_itr != nullptr ) { + * deepcopy(src_el, dest_el_ptr); + * src_itr = src_itr_next; + * if( src_next_exists ) { + * *next_ptr = *next_ptr + 1; + * } + * else { + * curr_dest_next_ptr = nullptr; + * } + * } + * + */ + + if( !are_iterators_set ) { + src_itr = builder->CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); + dest_itr = builder->CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); + } + llvm::Type* el_struct_type = typecode2elstruct[ASRUtils::get_type_code(set_type->m_type)]->getPointerTo(); + LLVM::CreateStore(*builder, + builder->CreateBitCast(srci, llvm::Type::getInt8PtrTy(context)), + src_itr); + LLVM::CreateStore(*builder, + builder->CreateBitCast(desti, llvm::Type::getInt8PtrTy(context)), + dest_itr); + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + // head + llvm_utils->start_new_block(loophead); + { + llvm::Value *cond = builder->CreateICmpNE( + LLVM::CreateLoad(*builder, src_itr), + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)) + ); + builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + llvm_utils->start_new_block(loopbody); + { + llvm::Value* curr_src = builder->CreateBitCast(LLVM::CreateLoad(*builder, src_itr), + el_struct_type); + llvm::Value* curr_dest = builder->CreateBitCast(LLVM::CreateLoad(*builder, dest_itr), + el_struct_type); + llvm::Value* src_el_ptr = llvm_utils->create_gep(curr_src, 0); + llvm::Value *src_el = src_el_ptr; + if( !LLVM::is_llvm_struct(set_type->m_type) ) { + src_el = LLVM::CreateLoad(*builder, src_el_ptr); + } + llvm::Value* dest_el_ptr = llvm_utils->create_gep(curr_dest, 0); + llvm_utils->deepcopy(src_el, dest_el_ptr, set_type->m_type, module, name2memidx); + + llvm::Value* src_next_ptr = LLVM::CreateLoad(*builder, llvm_utils->create_gep(curr_src, 1)); + llvm::Value* curr_dest_next_ptr = llvm_utils->create_gep(curr_dest, 1); + LLVM::CreateStore(*builder, src_next_ptr, src_itr); + + llvm::Function *fn = builder->GetInsertBlock()->getParent(); + llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); + llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); + llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); + llvm::Value* src_next_exists = builder->CreateICmpNE(src_next_ptr, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))); + builder->CreateCondBr(src_next_exists, thenBB, elseBB); + builder->SetInsertPoint(thenBB); + { + llvm::Value* next_idx = LLVM::CreateLoad(*builder, next_ptr); + llvm::Value* dest_next_ptr = llvm_utils->create_ptr_gep(dest_elems, next_idx); + dest_next_ptr = builder->CreateBitCast(dest_next_ptr, llvm::Type::getInt8PtrTy(context)); + LLVM::CreateStore(*builder, dest_next_ptr, curr_dest_next_ptr); + LLVM::CreateStore(*builder, dest_next_ptr, dest_itr); + next_idx = builder->CreateAdd(next_idx, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, next_idx, next_ptr); + } + builder->CreateBr(mergeBB); + llvm_utils->start_new_block(elseBB); + { + LLVM::CreateStore(*builder, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)), + curr_dest_next_ptr + ); + } + llvm_utils->start_new_block(mergeBB); + } + + builder->CreateBr(loophead); + + // end + llvm_utils->start_new_block(loopend); + } + + llvm::Value* LLVMSetInterface::len(llvm::Value* set) { return LLVM::CreateLoad(*builder, get_pointer_to_occupancy(set)); } diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index 050b0a1c09..f2174dc9fd 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -17,8 +17,6 @@ # define FIXED_VECTOR_TYPE llvm::VectorType #endif -#define PERTURB_SHIFT 5 - namespace LCompilers { // Platform dependent fast unique hash: @@ -202,6 +200,8 @@ namespace LCompilers { LLVMDictInterface* dict_api_lp; LLVMDictInterface* dict_api_sc; + LLVMSetInterface* set_api_lp; + LLVMSetInterface* set_api_sc; CompilerOptions &compiler_options; @@ -300,6 +300,8 @@ namespace LCompilers { void set_dict_api(ASR::Dict_t* dict_type); + void set_set_api(ASR::Set_t* set_type); + void deepcopy(llvm::Value* src, llvm::Value* dest, ASR::ttype_t* asr_type, llvm::Module* module, std::map>& name2memidx); @@ -876,6 +878,10 @@ namespace LCompilers { llvm::AllocaInst *pos_ptr, *is_el_matching_var; llvm::AllocaInst *idx_ptr, *hash_iter, *hash_value; llvm::AllocaInst *polynomial_powers; + llvm::AllocaInst *chain_itr, *chain_itr_prev; + llvm::AllocaInst *old_capacity, *old_elems, *old_el_mask; + llvm::AllocaInst *old_occupancy, *old_number_of_buckets_filled; + llvm::AllocaInst *src_itr, *dest_itr, *next_ptr, *copy_itr; bool are_iterators_set; std::map> typecode2settype; @@ -915,13 +921,6 @@ namespace LCompilers { llvm::Value* get_el_hash(llvm::Value* capacity, llvm::Value* el, ASR::ttype_t* el_asr_type, llvm::Module& module); - virtual - void resolve_collision( - llvm::Value* capacity, llvm::Value* el_hash, - llvm::Value* el, llvm::Value* el_list, - llvm::Value* el_mask, llvm::Module& module, - ASR::ttype_t* el_asr_type, bool for_read=false) = 0; - virtual void resolve_collision_for_write( llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, @@ -942,7 +941,7 @@ namespace LCompilers { void write_item( llvm::Value* set, llvm::Value* el, llvm::Module* module, ASR::ttype_t* el_asr_type, - std::map>& name2memidx) = 0; + std::map>& name2memidx); virtual void resolve_collision_for_read_with_bound_check( @@ -961,7 +960,13 @@ namespace LCompilers { std::map>& name2memidx) = 0; virtual - llvm::Value* len(llvm::Value* set) = 0; + llvm::Value* len(llvm::Value* set); + + virtual + bool is_set_present(); + + virtual + void set_is_set_present(bool value); virtual ~LLVMSetInterface() = 0; @@ -1010,11 +1015,87 @@ namespace LCompilers { llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type, std::map>& name2memidx); - void write_item( + void resolve_collision_for_read_with_bound_check( + llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, + llvm::Module& module, ASR::ttype_t* el_asr_type); + + void remove_item( llvm::Value* set, llvm::Value* el, + llvm::Module& module, ASR::ttype_t* el_asr_type); + + void set_deepcopy( + llvm::Value* src, llvm::Value* dest, + ASR::Set_t* set_type, llvm::Module* module, + std::map>& name2memidx); + + ~LLVMSetLinearProbing(); + }; + + class LLVMSetSeparateChaining: public LLVMSetInterface { + + protected: + + std::map typecode2elstruct; + + llvm::Value* get_pointer_to_number_of_filled_buckets(llvm::Value* set); + + llvm::Value* get_pointer_to_elems(llvm::Value* set); + + llvm::Value* get_pointer_to_rehash_flag(llvm::Value* set); + + void set_init_given_initial_capacity(std::string el_type_code, + llvm::Value* set, llvm::Module* module, llvm::Value* initial_capacity); + + void resolve_collision( + llvm::Value* el_hash, llvm::Value* el, llvm::Value* el_linked_list, + llvm::Type* el_struct_type, llvm::Value* el_mask, + llvm::Module& module, ASR::ttype_t* el_asr_type); + + void write_el_linked_list( + llvm::Value* el_ll, llvm::Value* set, llvm::Value* capacity, + ASR::ttype_t* m_el_type, llvm::Module* module, + std::map>& name2memidx); + + void deepcopy_el_linked_list( + llvm::Value* srci, llvm::Value* desti, llvm::Value* dest_elems, + ASR::Set_t* set_type, llvm::Module* module, + std::map>& name2memidx); + + public: + + LLVMSetSeparateChaining( + llvm::LLVMContext& context_, + LLVMUtils* llvm_utils, + llvm::IRBuilder<>* builder); + + llvm::Type* get_set_type( + std::string type_code, + int32_t type_size, llvm::Type* el_type); + + void set_init(std::string type_code, llvm::Value* set, + llvm::Module* module, size_t initial_capacity); + + llvm::Value* get_el_list(llvm::Value* set); + + llvm::Value* get_pointer_to_occupancy(llvm::Value* set); + + llvm::Value* get_pointer_to_capacity(llvm::Value* set); + + llvm::Value* get_pointer_to_mask(llvm::Value* set); + + void resolve_collision_for_write( + llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, llvm::Module* module, ASR::ttype_t* el_asr_type, std::map>& name2memidx); + void rehash( + llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type, + std::map>& name2memidx); + + void rehash_all_at_once_if_needed( + llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type, + std::map>& name2memidx); + void resolve_collision_for_read_with_bound_check( llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, llvm::Module& module, ASR::ttype_t* el_asr_type); @@ -1028,9 +1109,7 @@ namespace LCompilers { ASR::Set_t* set_type, llvm::Module* module, std::map>& name2memidx); - llvm::Value* len(llvm::Value* set); - - ~LLVMSetLinearProbing(); + ~LLVMSetSeparateChaining(); }; } // namespace LCompilers From b049ada4f18738ef8414d24cf8ba26494c68ee49 Mon Sep 17 00:00:00 2001 From: kabra1110 Date: Fri, 21 Jul 2023 21:10:23 +0530 Subject: [PATCH 3/5] use only LP for now --- src/libasr/codegen/llvm_utils.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index e6b8022847..309e102632 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -877,12 +877,13 @@ namespace LCompilers { } } - void LLVMUtils::set_set_api(ASR::Set_t* set_type) { - if( ASR::is_a(*set_type->m_type) ) { - set_api = set_api_sc; - } else { - set_api = set_api_lp; - } + void LLVMUtils::set_set_api(ASR::Set_t* /*set_type*/) { + // if( ASR::is_a(*set_type->m_type) ) { + // set_api = set_api_sc; + // } else { + // set_api = set_api_lp; + // } + set_api = set_api_lp; } std::vector LLVMUtils::convert_args(const ASR::Function_t& x, llvm::Module* module) { From 1e858978a262654e82f2de2c941108406cf9644f Mon Sep 17 00:00:00 2001 From: kabra1110 Date: Sat, 29 Jul 2023 16:00:17 +0530 Subject: [PATCH 4/5] fix separate chaining bugs --- integration_tests/test_set_len.py | 6 +- src/libasr/codegen/llvm_utils.cpp | 143 +++++++++++++++++------------- 2 files changed, 87 insertions(+), 62 deletions(-) diff --git a/integration_tests/test_set_len.py b/integration_tests/test_set_len.py index 33d252a0fe..8e66064dd3 100644 --- a/integration_tests/test_set_len.py +++ b/integration_tests/test_set_len.py @@ -3,6 +3,8 @@ def test_set(): s: set[i32] s = {1, 2, 22, 2, -1, 1} - assert len(s) == 4 + s2: set[str] + s2 = {'a', 'b', 'cd', 'b', 'abc', 'a'} + assert len(s2) == 4 -test_set() \ No newline at end of file +test_set() diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 309e102632..2cb36212ed 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -877,13 +877,13 @@ namespace LCompilers { } } - void LLVMUtils::set_set_api(ASR::Set_t* /*set_type*/) { - // if( ASR::is_a(*set_type->m_type) ) { - // set_api = set_api_sc; - // } else { - // set_api = set_api_lp; - // } - set_api = set_api_lp; + void LLVMUtils::set_set_api(ASR::Set_t* set_type) { + if( ASR::is_a(*set_type->m_type) ) { + set_api = set_api_sc; + } else { + set_api = set_api_lp; + } + // set_api = set_api_lp; } std::vector LLVMUtils::convert_args(const ASR::Function_t& x, llvm::Module* module) { @@ -4994,7 +4994,7 @@ namespace LCompilers { std::string el_type_code, llvm::Value* set, llvm::Module* module, size_t initial_capacity) { llvm::Value* llvm_capacity = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), - llvm::APInt(32, initial_capacity + 1)); + llvm::APInt(32, initial_capacity)); llvm::Value* rehash_flag_ptr = get_pointer_to_rehash_flag(set); LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), llvm::APInt(1, 1)), rehash_flag_ptr); @@ -5351,13 +5351,20 @@ namespace LCompilers { /** * C++ equivalent: * - * is_el_matching = 1; + * ll_exists = el_mask_value == 1; + * if( ll_exists ) { + * chain_itr = ll_head; + * } + * else { + * chain_itr = nullptr; + * } + * is_el_matching = 0; * - * while( chain_itr != nullptr && is_el_matching ) { - * break_signal = el != el_struct_el; - * is_el_matching = break_signal; // 1 means not matching - * if( break_signal ) { - * chain_itr = next_el_struct; + * while( chain_itr != nullptr && !is_el_matching ) { + * chain_itr_prev = chain_itr; + * is_el_matching = (el == el_struct_el); + * if( !is_el_matching ) { + * chain_itr = next_el_struct; // (*chain_itr)[1] * } * } * @@ -5372,12 +5379,18 @@ namespace LCompilers { LLVM::CreateStore(*builder, llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)), chain_itr_prev); - llvm::Value* el_ll_i8 = builder->CreateBitCast(el_linked_list, llvm::Type::getInt8PtrTy(context)); - LLVM::CreateStore(*builder, el_ll_i8, chain_itr); llvm::Value* el_mask_value = LLVM::CreateLoad(*builder, llvm_utils->create_ptr_gep(el_mask, el_hash)); + llvm_utils->create_if_else(builder->CreateICmpEQ(el_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))), [&]() { + llvm::Value* el_ll_i8 = builder->CreateBitCast(el_linked_list, llvm::Type::getInt8PtrTy(context)); + LLVM::CreateStore(*builder, el_ll_i8, chain_itr); + }, [&]() { + LLVM::CreateStore(*builder, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)), chain_itr); + }); LLVM::CreateStore(*builder, - builder->CreateICmpEQ(el_mask_value, llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))), + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(1, 0)), is_el_matching_var ); llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); @@ -5391,7 +5404,8 @@ namespace LCompilers { LLVM::CreateLoad(*builder, chain_itr), llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)) ); - cond = builder->CreateAnd(cond, LLVM::CreateLoad(*builder, is_el_matching_var)); + cond = builder->CreateAnd(cond, builder->CreateNot(LLVM::CreateLoad( + *builder, is_el_matching_var))); builder->CreateCondBr(cond, loopbody, loopend); } @@ -5405,14 +5419,12 @@ namespace LCompilers { if( !LLVM::is_llvm_struct(el_asr_type) ) { el_struct_el = LLVM::CreateLoad(*builder, el_struct_el); } - llvm::Value* break_signal = llvm_utils->is_equal_by_value(el, el_struct_el, module, el_asr_type); - break_signal = builder->CreateNot(break_signal); - LLVM::CreateStore(*builder, break_signal, is_el_matching_var); - llvm_utils->create_if_else(break_signal, [&]() { + LLVM::CreateStore(*builder, llvm_utils->is_equal_by_value(el, el_struct_el, + module, el_asr_type), is_el_matching_var); + llvm_utils->create_if_else(builder->CreateNot(LLVM::CreateLoad(*builder, is_el_matching_var)), [&]() { llvm::Value* next_el_struct = LLVM::CreateLoad(*builder, llvm_utils->create_gep(el_struct, 1)); LLVM::CreateStore(*builder, next_el_struct, chain_itr); - }, []() { - }); + }, []() {}); } builder->CreateBr(loophead); @@ -5483,14 +5495,22 @@ namespace LCompilers { /** * C++ equivalent: * + * el_linked_list = elems[el_hash]; * resolve_collision(el); // modifies chain_itr * do_insert = chain_itr == nullptr; * * if( do_insert ) { - * new_el_struct = malloc(el_struct_size); - * new_el_struct[0] = el; - * new_el_struct[1] = nullptr; - * chain_itr_prev[1] = new_el_struct; + * if( chain_itr_prev != nullptr ) { + * new_el_struct = malloc(el_struct_size); + * new_el_struct[0] = el; + * new_el_struct[1] = nullptr; + * chain_itr_prev[1] = new_el_struct; + * } + * else { + * el_linked_list[0] = el; + * el_linked_list[1] = nullptr; + * } + * occupancy += 1; * } * else { * el_struct[0] = el; @@ -5520,18 +5540,33 @@ namespace LCompilers { builder->SetInsertPoint(thenBB); { - llvm::DataLayout data_layout(module); - size_t el_struct_size = data_layout.getTypeAllocSize(el_struct_type); - llvm::Value* malloc_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), el_struct_size); - llvm::Value* new_el_struct_i8 = LLVM::lfortran_malloc(context, *module, *builder, malloc_size); - llvm::Value* new_el_struct = builder->CreateBitCast(new_el_struct_i8, el_struct_type->getPointerTo()); - llvm_utils->deepcopy(el, llvm_utils->create_gep(new_el_struct, 0), el_asr_type, module, name2memidx); - LLVM::CreateStore(*builder, - llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)), - llvm_utils->create_gep(new_el_struct, 1)); - llvm::Value* el_struct_prev_i8 = LLVM::CreateLoad(*builder, chain_itr_prev); - llvm::Value* el_struct_prev = builder->CreateBitCast(el_struct_prev_i8, el_struct_type->getPointerTo()); - LLVM::CreateStore(*builder, new_el_struct_i8, llvm_utils->create_gep(el_struct_prev, 1)); + llvm_utils->create_if_else(builder->CreateICmpNE( + LLVM::CreateLoad(*builder, chain_itr_prev), + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))), [&]() { + llvm::DataLayout data_layout(module); + size_t el_struct_size = data_layout.getTypeAllocSize(el_struct_type); + llvm::Value* malloc_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), el_struct_size); + llvm::Value* new_el_struct_i8 = LLVM::lfortran_malloc(context, *module, *builder, malloc_size); + llvm::Value* new_el_struct = builder->CreateBitCast(new_el_struct_i8, el_struct_type->getPointerTo()); + llvm_utils->deepcopy(el, llvm_utils->create_gep(new_el_struct, 0), el_asr_type, module, name2memidx); + LLVM::CreateStore(*builder, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)), + llvm_utils->create_gep(new_el_struct, 1)); + llvm::Value* el_struct_prev_i8 = LLVM::CreateLoad(*builder, chain_itr_prev); + llvm::Value* el_struct_prev = builder->CreateBitCast(el_struct_prev_i8, el_struct_type->getPointerTo()); + LLVM::CreateStore(*builder, new_el_struct_i8, llvm_utils->create_gep(el_struct_prev, 1)); + }, [&]() { + llvm_utils->deepcopy(el, llvm_utils->create_gep(el_linked_list, 0), el_asr_type, module, name2memidx); + LLVM::CreateStore(*builder, + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)), + llvm_utils->create_gep(el_linked_list, 1)); + }); + + llvm::Value* occupancy_ptr = get_pointer_to_occupancy(set); + llvm::Value* occupancy = LLVM::CreateLoad(*builder, occupancy_ptr); + occupancy = builder->CreateAdd(occupancy, + llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 1)); + LLVM::CreateStore(*builder, occupancy, occupancy_ptr); } builder->CreateBr(mergeBB); llvm_utils->start_new_block(elseBB); @@ -5540,12 +5575,7 @@ namespace LCompilers { llvm_utils->deepcopy(el, llvm_utils->create_gep(el_struct, 0), el_asr_type, module, name2memidx); } llvm_utils->start_new_block(mergeBB); - llvm::Value* occupancy_ptr = get_pointer_to_occupancy(set); llvm::Value* buckets_filled_ptr = get_pointer_to_number_of_filled_buckets(set); - llvm::Value* occupancy = LLVM::CreateLoad(*builder, occupancy_ptr); - occupancy = builder->CreateAdd(occupancy, - builder->CreateZExt(do_insert, llvm::Type::getInt32Ty(context))); - LLVM::CreateStore(*builder, occupancy, occupancy_ptr); llvm::Value* el_mask_value_ptr = llvm_utils->create_ptr_gep(el_mask, el_hash); llvm::Value* el_mask_value = LLVM::CreateLoad(*builder, el_mask_value_ptr); llvm::Value* buckets_filled_delta = builder->CreateICmpEQ(el_mask_value, @@ -6225,6 +6255,7 @@ namespace LCompilers { llvm::Value* dest_el_mask = LLVM::lfortran_malloc(context, *module, *builder, malloc_size); LLVM::CreateStore(*builder, dest_el_mask, get_pointer_to_mask(dest)); + // number of elements to be copied = capacity + (occupancy - filled_buckets) malloc_size = builder->CreateSub(src_occupancy, src_filled_buckets); malloc_size = builder->CreateAdd(src_capacity, malloc_size); llvm::Type* el_struct_type = typecode2elstruct[ASRUtils::get_type_code(set_type->m_type)]; @@ -6295,13 +6326,15 @@ namespace LCompilers { * // memory allocation done before calling this function * * while( src_itr != nullptr ) { - * deepcopy(src_el, dest_el_ptr); + * deepcopy(src_el, curr_dest_ptr); * src_itr = src_itr_next; * if( src_next_exists ) { * *next_ptr = *next_ptr + 1; + * curr_dest[1] = &dest_elems[*next_ptr]; + * curr_dest = *curr_dest[1]; * } * else { - * curr_dest_next_ptr = nullptr; + * curr_dest[1] = nullptr; * } * } * @@ -6350,15 +6383,9 @@ namespace LCompilers { llvm::Value* curr_dest_next_ptr = llvm_utils->create_gep(curr_dest, 1); LLVM::CreateStore(*builder, src_next_ptr, src_itr); - llvm::Function *fn = builder->GetInsertBlock()->getParent(); - llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn); - llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"); - llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"); llvm::Value* src_next_exists = builder->CreateICmpNE(src_next_ptr, llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context))); - builder->CreateCondBr(src_next_exists, thenBB, elseBB); - builder->SetInsertPoint(thenBB); - { + llvm_utils->create_if_else(src_next_exists, [&]() { llvm::Value* next_idx = LLVM::CreateLoad(*builder, next_ptr); llvm::Value* dest_next_ptr = llvm_utils->create_ptr_gep(dest_elems, next_idx); dest_next_ptr = builder->CreateBitCast(dest_next_ptr, llvm::Type::getInt8PtrTy(context)); @@ -6367,16 +6394,12 @@ namespace LCompilers { next_idx = builder->CreateAdd(next_idx, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 1))); LLVM::CreateStore(*builder, next_idx, next_ptr); - } - builder->CreateBr(mergeBB); - llvm_utils->start_new_block(elseBB); - { + }, [&]() { LLVM::CreateStore(*builder, llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)), curr_dest_next_ptr ); - } - llvm_utils->start_new_block(mergeBB); + }); } builder->CreateBr(loophead); From 94aa2d0b7cb8852eef8a419555f9bf4263909469 Mon Sep 17 00:00:00 2001 From: kabra1110 Date: Mon, 31 Jul 2023 17:41:45 +0530 Subject: [PATCH 5/5] use only LP --- src/libasr/codegen/llvm_utils.cpp | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index f02cd6e356..d9963e702f 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -876,13 +876,11 @@ namespace LCompilers { } } - void LLVMUtils::set_set_api(ASR::Set_t* set_type) { - if( ASR::is_a(*set_type->m_type) ) { - set_api = set_api_sc; - } else { - set_api = set_api_lp; - } - // set_api = set_api_lp; + void LLVMUtils::set_set_api(ASR::Set_t* /*set_type*/) { + // As per benchmarks, separate chaining + // does not provide significant gains over + // linear probing. + set_api = set_api_lp; } std::vector LLVMUtils::convert_args(const ASR::Function_t& x, llvm::Module* module) {