From 9ad9c8bc381c34bb200d97c6bb025e7420a64751 Mon Sep 17 00:00:00 2001 From: kabra1110 Date: Sun, 25 Jun 2023 15:22:01 +0200 Subject: [PATCH 1/4] add dict keys --- src/libasr/asr_utils.h | 12 ++++- src/libasr/codegen/asr_to_llvm.cpp | 18 +++++++ src/libasr/pass/intrinsic_function_registry.h | 53 +++++++++++++++++++ src/lpython/semantics/python_attribute_eval.h | 17 +++++- 4 files changed, 98 insertions(+), 2 deletions(-) diff --git a/src/libasr/asr_utils.h b/src/libasr/asr_utils.h index 796ceafccf..cab4e49141 100644 --- a/src/libasr/asr_utils.h +++ b/src/libasr/asr_utils.h @@ -197,7 +197,7 @@ static inline ASR::abiType symbol_abi(const ASR::symbol_t *f) return ASR::abiType::Source; } -static inline ASR::ttype_t* get_contained_type(ASR::ttype_t* asr_type) { +static inline ASR::ttype_t* get_contained_type(ASR::ttype_t* asr_type, int overload=0) { switch( asr_type->type ) { case ASR::ttypeType::List: { return ASR::down_cast(asr_type)->m_type; @@ -205,6 +205,16 @@ static inline ASR::ttype_t* get_contained_type(ASR::ttype_t* asr_type) { case ASR::ttypeType::Set: { return ASR::down_cast(asr_type)->m_type; } + case ASR::ttypeType::Dict: { + switch( overload ) { + case 0: + return ASR::down_cast(asr_type)->m_key_type; + case 1: + return ASR::down_cast(asr_type)->m_value_type; + default: + return asr_type; + } + } case ASR::ttypeType::Enum: { ASR::Enum_t* enum_asr = ASR::down_cast(asr_type); ASR::EnumType_t* enum_type = ASR::down_cast(enum_asr->m_enum_type); diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 7329e37c40..713eeeaa11 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -2072,6 +2072,20 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor tmp = list_api->pop_position(plist, pos, asr_el_type, module.get(), name2memidx); } + void generate_DictKeys(ASR::expr_t* m_arg) { + ASR::Dict_t* dict_type = ASR::down_cast( + ASRUtils::expr_type(m_arg)); + + int64_t ptr_loads_copy = ptr_loads; + ptr_loads = 0; + this->visit_expr(*m_arg); + llvm::Value* pdict = tmp; + + set_dict_api(dict_type); + ptr_loads = ptr_loads_copy; + tmp = llvm_utils->dict_api->get_key_list(pdict); + } + void visit_IntrinsicFunction(const ASR::IntrinsicFunction_t& x) { switch (static_cast(x.m_intrinsic_id)) { case ASRUtils::IntrinsicFunctions::ListIndex: { @@ -2115,6 +2129,10 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } break; } + case ASRUtils::IntrinsicFunctions::DictKeys: { + generate_DictKeys(x.m_args[0]); + break; + } case ASRUtils::IntrinsicFunctions::Exp: { switch (x.m_overload_id) { case 0: { diff --git a/src/libasr/pass/intrinsic_function_registry.h b/src/libasr/pass/intrinsic_function_registry.h index d82a8c253e..aa6b36aac7 100644 --- a/src/libasr/pass/intrinsic_function_registry.h +++ b/src/libasr/pass/intrinsic_function_registry.h @@ -64,6 +64,7 @@ enum class IntrinsicFunctions : int64_t { Partition, ListReverse, ListPop, + DictKeys, SymbolicSymbol, SymbolicAdd, SymbolicSub, @@ -1139,6 +1140,52 @@ static inline ASR::asr_t* create_ListPop(Allocator& al, const Location& loc, } // namespace ListPop +namespace DictKeys { + +static inline void verify_args(const ASR::IntrinsicFunction_t& x, diag::Diagnostics& diagnostics) { + ASRUtils::require_impl(x.n_args == 1, "Call to dict.keys must have no argument", + x.base.base.loc, diagnostics); + ASRUtils::require_impl(ASR::is_a(*ASRUtils::expr_type(x.m_args[0])), + "Argument to dict.keys must be of dict type", + x.base.base.loc, diagnostics); + ASRUtils::require_impl(ASRUtils::check_equal_type( + ASRUtils::get_contained_type(x.m_type), + ASRUtils::get_contained_type(ASRUtils::expr_type(x.m_args[0]), 0)), + "Return type of dict.keys must be of list of dict key element type", + x.base.base.loc, diagnostics); +} + +static inline ASR::expr_t *eval_dict_keys(Allocator &/*al*/, + const Location &/*loc*/, Vec& /*args*/) { + // TODO: To be implemented for DictConstant expression + return nullptr; +} + +static inline ASR::asr_t* create_DictKeys(Allocator& al, const Location& loc, + Vec& args, + const std::function err) { + if (args.size() != 1) { + err("Call to dict.keys must have no argument", loc); + } + + ASR::expr_t* dict_expr = args[0]; + ASR::ttype_t *type = ASRUtils::expr_type(dict_expr); + ASR::ttype_t *dict_keys_type = ASR::down_cast(type)->m_key_type; + + Vec arg_values; + arg_values.reserve(al, args.size()); + for( size_t i = 0; i < args.size(); i++ ) { + arg_values.push_back(al, ASRUtils::expr_value(args[i])); + } + ASR::expr_t* compile_time_value = eval_dict_keys(al, loc, arg_values); + ASR::ttype_t *to_type = List(dict_keys_type); + return ASR::make_IntrinsicFunction_t(al, loc, + static_cast(ASRUtils::IntrinsicFunctions::DictKeys), + args.p, args.size(), 0, to_type, compile_time_value); +} + +} // namespace DictKeys + namespace Any { static inline void verify_array(ASR::expr_t* array, ASR::ttype_t* return_type, @@ -2212,6 +2259,8 @@ namespace IntrinsicFunctionRegistry { {nullptr, &ListPop::verify_args}}, {static_cast(ASRUtils::IntrinsicFunctions::ListReverse), {nullptr, &ListReverse::verify_args}}, + {static_cast(ASRUtils::IntrinsicFunctions::DictKeys), + {nullptr, &DictKeys::verify_args}}, {static_cast(ASRUtils::IntrinsicFunctions::SymbolicSymbol), {nullptr, &SymbolicSymbol::verify_args}}, {static_cast(ASRUtils::IntrinsicFunctions::SymbolicAdd), @@ -2266,6 +2315,8 @@ namespace IntrinsicFunctionRegistry { "list.reverse"}, {static_cast(ASRUtils::IntrinsicFunctions::ListPop), "list.pop"}, + {static_cast(ASRUtils::IntrinsicFunctions::DictKeys), + "dict.keys"}, {static_cast(ASRUtils::IntrinsicFunctions::SymbolicSymbol), "Symbol"}, {static_cast(ASRUtils::IntrinsicFunctions::SymbolicAdd), @@ -2311,6 +2362,7 @@ namespace IntrinsicFunctionRegistry { {"list.index", {&ListIndex::create_ListIndex, &ListIndex::eval_list_index}}, {"list.reverse", {&ListReverse::create_ListReverse, &ListReverse::eval_list_reverse}}, {"list.pop", {&ListPop::create_ListPop, &ListPop::eval_list_pop}}, + {"dict.keys", {&DictKeys::create_DictKeys, &DictKeys::eval_dict_keys}}, {"Symbol", {&SymbolicSymbol::create_SymbolicSymbol, &SymbolicSymbol::eval_SymbolicSymbol}}, {"SymbolicAdd", {&SymbolicAdd::create_SymbolicAdd, &SymbolicAdd::eval_SymbolicAdd}}, {"SymbolicSub", {&SymbolicSub::create_SymbolicSub, &SymbolicSub::eval_SymbolicSub}}, @@ -2425,6 +2477,7 @@ inline std::string get_intrinsic_name(int x) { INTRINSIC_NAME_CASE(Partition) INTRINSIC_NAME_CASE(ListReverse) INTRINSIC_NAME_CASE(ListPop) + INTRINSIC_NAME_CASE(DictKeys) INTRINSIC_NAME_CASE(SymbolicSymbol) INTRINSIC_NAME_CASE(SymbolicAdd) INTRINSIC_NAME_CASE(SymbolicSub) diff --git a/src/lpython/semantics/python_attribute_eval.h b/src/lpython/semantics/python_attribute_eval.h index 5ce7834bb1..ceb639a5c2 100644 --- a/src/lpython/semantics/python_attribute_eval.h +++ b/src/lpython/semantics/python_attribute_eval.h @@ -34,7 +34,8 @@ struct AttributeHandler { {"set@add", &eval_set_add}, {"set@remove", &eval_set_remove}, {"dict@get", &eval_dict_get}, - {"dict@pop", &eval_dict_pop} + {"dict@pop", &eval_dict_pop}, + {"dict@keys", &eval_dict_keys} }; modify_attr_set = {"list@append", "list@remove", @@ -388,6 +389,20 @@ struct AttributeHandler { return make_DictPop_t(al, loc, s, args[0], value_type, nullptr); } + static ASR::asr_t* eval_dict_keys(ASR::expr_t *s, Allocator &al, const Location &loc, + Vec &args, diag::Diagnostics &/*diag*/) { + Vec args_with_dict; + args_with_dict.reserve(al, args.size() + 1); + args_with_dict.push_back(al, s); + for(size_t i = 0; i < args.size(); i++) { + args_with_dict.push_back(al, args[i]); + } + ASRUtils::create_intrinsic_function create_function = + ASRUtils::IntrinsicFunctionRegistry::get_create_function("dict.keys"); + return create_function(al, loc, args_with_dict, [&](const std::string &msg, const Location &loc) + { throw SemanticError(msg, loc); }); + } + }; // AttributeHandler } // namespace LCompilers::LPython From f28eaf06b676fb0a5500be20415e36913dbeb555 Mon Sep 17 00:00:00 2001 From: kabra1110 Date: Sat, 8 Jul 2023 11:00:46 +0530 Subject: [PATCH 2/4] fix dict keys and values for LP --- integration_tests/CMakeLists.txt | 1 + integration_tests/test_dict_keys_values.py | 43 +++++++++ src/libasr/codegen/asr_to_llvm.cpp | 43 ++++++++- src/libasr/codegen/llvm_utils.cpp | 89 +++++++++++++++++++ src/libasr/codegen/llvm_utils.h | 16 ++++ src/libasr/pass/intrinsic_function_registry.h | 57 +++++++++++- src/lpython/semantics/python_attribute_eval.h | 17 +++- 7 files changed, 260 insertions(+), 6 deletions(-) create mode 100644 integration_tests/test_dict_keys_values.py diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index 3d69c73c32..0006531218 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -472,6 +472,7 @@ RUN(NAME test_dict_12 LABELS cpython llvm c) RUN(NAME test_dict_13 LABELS cpython llvm c) RUN(NAME test_dict_bool LABELS cpython llvm) RUN(NAME test_dict_increment LABELS cpython llvm) +RUN(NAME test_dict_keys_values LABELS cpython llvm) RUN(NAME test_for_loop LABELS cpython llvm c) RUN(NAME modules_01 LABELS cpython llvm c wasm wasm_x86 wasm_x64) RUN(NAME modules_02 LABELS cpython llvm c wasm wasm_x86 wasm_x64) diff --git a/integration_tests/test_dict_keys_values.py b/integration_tests/test_dict_keys_values.py new file mode 100644 index 0000000000..a50a3c64f6 --- /dev/null +++ b/integration_tests/test_dict_keys_values.py @@ -0,0 +1,43 @@ +from lpython import i32, f64 + +def test_dict_keys_values(): + d1: dict[i32, i32] = {} + d2: dict[tuple[i32, i32], tuple[i32, tuple[str, f64]]] = {} + k1: list[i32] + k2: list[tuple[i32, i32]] + v1: list[i32] + v2: list[tuple[i32, tuple[str, f64]]] + i: i32 + j: i32 + key_count: i32 + s: str + + for i in range(105, 115): + d1[i] = i + 1 + k1 = d1.keys() + v1 = d1.values() + assert len(k1) == 10 + for i in range(105, 115): + key_count = 0 + for j in range(len(k1)): + if k1[j] == i: + key_count += 1 + assert v1[j] == d1[i] + assert key_count == 1 + + s = 'a' + for i in range(10): + d2[(i, i + 1)] = (i, (s, f64(i * i))) + s += 'a' + k2 = d2.keys() + v2 = d2.values() + assert len(k2) == 10 + for i in range(10): + key_count = 0 + for j in range(len(k2)): + if k2[j] == (i, i + 1): + key_count += 1 + assert v2[j] == d2[k2[j]] + assert key_count == 1 + +test_dict_keys_values() \ No newline at end of file diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 713eeeaa11..487e6b0b1b 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -2072,9 +2072,11 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor tmp = list_api->pop_position(plist, pos, asr_el_type, module.get(), name2memidx); } - void generate_DictKeys(ASR::expr_t* m_arg) { + void generate_DictElems(ASR::expr_t* m_arg, bool key_or_value, const Location &loc) { ASR::Dict_t* dict_type = ASR::down_cast( ASRUtils::expr_type(m_arg)); + ASR::ttype_t* el_type = key_or_value == 0 ? + dict_type->m_key_type : dict_type->m_value_type; int64_t ptr_loads_copy = ptr_loads; ptr_loads = 0; @@ -2082,8 +2084,39 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor llvm::Value* pdict = tmp; set_dict_api(dict_type); + if(llvm_utils->dict_api == dict_api_sc.get()) { + throw CodeGenError("dict.keys and dict.values are only implemented " + "for linear probing for now", loc); + } ptr_loads = ptr_loads_copy; - tmp = llvm_utils->dict_api->get_key_list(pdict); + + bool is_array_type_local = false, is_malloc_array_type_local = false; + bool is_list_local = false; + ASR::dimension_t* m_dims_local = nullptr; + int n_dims_local = -1, a_kind_local = -1; + llvm::Type* llvm_el_type = get_type_from_ttype_t(el_type, + nullptr, + ASR::storage_typeType::Default, is_array_type_local, + is_malloc_array_type_local, is_list_local, m_dims_local, + n_dims_local, a_kind_local); + std::string type_code = ASRUtils::get_type_code(el_type); + int32_t type_size = -1; + if( ASR::is_a(*el_type) || + LLVM::is_llvm_struct(el_type) || + ASR::is_a(*el_type) ) { + llvm::DataLayout data_layout(module.get()); + type_size = data_layout.getTypeAllocSize(llvm_el_type); + } else { + type_size = ASRUtils::extract_kind_from_ttype_t(el_type); + } + llvm::Type* el_list_type = list_api->get_list_type(llvm_el_type, type_code, type_size); + llvm::Value* el_list = builder->CreateAlloca(el_list_type, nullptr, key_or_value == 0 ? + "keys_list" : "values_list"); + list_api->list_init(type_code, el_list, *module, 0, 0); + + llvm_utils->dict_api->get_elements_list(pdict, el_list, el_type, *module, + name2memidx, key_or_value); + tmp = el_list; } void visit_IntrinsicFunction(const ASR::IntrinsicFunction_t& x) { @@ -2130,7 +2163,11 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor break; } case ASRUtils::IntrinsicFunctions::DictKeys: { - generate_DictKeys(x.m_args[0]); + generate_DictElems(x.m_args[0], 0, x.base.base.loc); + break; + } + case ASRUtils::IntrinsicFunctions::DictValues: { + generate_DictElems(x.m_args[0], 1, x.base.base.loc); break; } case ASRUtils::IntrinsicFunctions::Exp: { diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 5b3f897a2b..072f72a504 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -341,6 +341,12 @@ namespace LCompilers { list_api->list_deepcopy(src, dest, list_type, module, name2memidx); break ; } + case ASR::ttypeType::Dict: { + ASR::Dict_t* dict_type = ASR::down_cast(asr_type); + // set dict api here? + dict_api->dict_deepcopy(src, dest, dict_type, module, name2memidx); + break ; + } case ASR::ttypeType::Struct: { ASR::Struct_t* struct_t = ASR::down_cast(asr_type); ASR::StructType_t* struct_type_t = ASR::down_cast( @@ -2469,6 +2475,89 @@ namespace LCompilers { return LLVM::CreateLoad(*builder, value_ptr); } + void LLVMDict::get_elements_list(llvm::Value* dict, + llvm::Value* elements_list, ASR::ttype_t* el_asr_type, llvm::Module& module, + std::map>& name2memidx, + bool key_or_value) { + + /** + * C++ equivalent: + * + * idx = 0; + * + * while( capacity > idx ) { + * el = key_or_value_list[idx]; + * key_mask_value = key_mask[idx]; + * + * is_key_skip = key_mask_value == 3; // tombstone + * is_key_set = key_mask_value != 0; + * add_el = is_key_set && !is_key_skip; + * if( add_el ) { + * elements_list.append(el); + * } + * + * idx++; + * } + * + */ + + llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); + llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); + llvm::Value* el_list = key_or_value == 0 ? get_key_list(dict) : get_value_list(dict); + if( !are_iterators_set ) { + idx_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + } + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 0)), idx_ptr); + + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + + // head + llvm_utils->start_new_block(loophead); + { + llvm::Value *cond = builder->CreateICmpSGT(capacity, LLVM::CreateLoad(*builder, idx_ptr)); + builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + llvm_utils->start_new_block(loopbody); + { + llvm::Value* idx = LLVM::CreateLoad(*builder, idx_ptr); + llvm::Value* key_mask_value = LLVM::CreateLoad(*builder, + llvm_utils->create_ptr_gep(key_mask, idx)); + llvm::Value* is_key_skip = builder->CreateICmpEQ(key_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 3))); + llvm::Value* is_key_set = builder->CreateICmpNE(key_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 0))); + + llvm::Value* add_el = builder->CreateAnd(is_key_set, + builder->CreateNot(is_key_skip)); + llvm_utils->create_if_else(add_el, [&]() { + llvm::Value* el = llvm_utils->list_api->read_item(el_list, idx, + false, module, LLVM::is_llvm_struct(el_asr_type)); + llvm_utils->list_api->append(elements_list, el, + el_asr_type, &module, name2memidx); + }, [=]() { + }); + + idx = builder->CreateAdd(idx, llvm::ConstantInt::get( + llvm::Type::getInt32Ty(context), llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, idx, idx_ptr); + } + + builder->CreateBr(loophead); + + // end + llvm_utils->start_new_block(loopend); + } + + void LLVMDictSeparateChaining::get_elements_list(llvm::Value* /*dict*/, + llvm::Value* /*elements_list*/, ASR::ttype_t* /*el_asr_type*/, llvm::Module& /*module*/, + std::map>& /*name2memidx*/, + bool /*key_or_value*/) {} + llvm::Value* LLVMList::read_item(llvm::Value* list, llvm::Value* pos, bool enable_bounds_checking, llvm::Module& module, bool get_pointer) { diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index ffcfef16c6..f249bd37af 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -463,6 +463,12 @@ namespace LCompilers { virtual void set_is_dict_present(bool value); + virtual + void get_elements_list(llvm::Value* dict, + llvm::Value* elements_list, ASR::ttype_t* el_asr_type, llvm::Module& module, + std::map>& name2memidx, + bool key_or_value) = 0; + virtual ~LLVMDictInterface() = 0; }; @@ -555,6 +561,11 @@ namespace LCompilers { llvm::Value* len(llvm::Value* dict); + void get_elements_list(llvm::Value* dict, + llvm::Value* elements_list, ASR::ttype_t* el_asr_type, llvm::Module& module, + std::map>& name2memidx, + bool key_or_value); + virtual ~LLVMDict(); }; @@ -702,6 +713,11 @@ namespace LCompilers { llvm::Value* len(llvm::Value* dict); + void get_elements_list(llvm::Value* dict, + llvm::Value* elements_list, ASR::ttype_t* el_asr_type, llvm::Module& module, + std::map>& name2memidx, + bool key_or_value); + virtual ~LLVMDictSeparateChaining(); }; diff --git a/src/libasr/pass/intrinsic_function_registry.h b/src/libasr/pass/intrinsic_function_registry.h index aa6b36aac7..e16f84d19f 100644 --- a/src/libasr/pass/intrinsic_function_registry.h +++ b/src/libasr/pass/intrinsic_function_registry.h @@ -65,6 +65,7 @@ enum class IntrinsicFunctions : int64_t { ListReverse, ListPop, DictKeys, + DictValues, SymbolicSymbol, SymbolicAdd, SymbolicSub, @@ -1148,8 +1149,8 @@ static inline void verify_args(const ASR::IntrinsicFunction_t& x, diag::Diagnost ASRUtils::require_impl(ASR::is_a(*ASRUtils::expr_type(x.m_args[0])), "Argument to dict.keys must be of dict type", x.base.base.loc, diagnostics); - ASRUtils::require_impl(ASRUtils::check_equal_type( - ASRUtils::get_contained_type(x.m_type), + ASRUtils::require_impl(ASR::is_a(*x.m_type) && + ASRUtils::check_equal_type(ASRUtils::get_contained_type(x.m_type), ASRUtils::get_contained_type(ASRUtils::expr_type(x.m_args[0]), 0)), "Return type of dict.keys must be of list of dict key element type", x.base.base.loc, diagnostics); @@ -1186,6 +1187,52 @@ static inline ASR::asr_t* create_DictKeys(Allocator& al, const Location& loc, } // namespace DictKeys +namespace DictValues { + +static inline void verify_args(const ASR::IntrinsicFunction_t& x, diag::Diagnostics& diagnostics) { + ASRUtils::require_impl(x.n_args == 1, "Call to dict.values must have no argument", + x.base.base.loc, diagnostics); + ASRUtils::require_impl(ASR::is_a(*ASRUtils::expr_type(x.m_args[0])), + "Argument to dict.values must be of dict type", + x.base.base.loc, diagnostics); + ASRUtils::require_impl(ASR::is_a(*x.m_type) && + ASRUtils::check_equal_type(ASRUtils::get_contained_type(x.m_type), + ASRUtils::get_contained_type(ASRUtils::expr_type(x.m_args[0]), 1)), + "Return type of dict.values must be of list of dict value element type", + x.base.base.loc, diagnostics); +} + +static inline ASR::expr_t *eval_dict_values(Allocator &/*al*/, + const Location &/*loc*/, Vec& /*args*/) { + // TODO: To be implemented for DictConstant expression + return nullptr; +} + +static inline ASR::asr_t* create_DictValues(Allocator& al, const Location& loc, + Vec& args, + const std::function err) { + if (args.size() != 1) { + err("Call to dict.values must have no argument", loc); + } + + ASR::expr_t* dict_expr = args[0]; + ASR::ttype_t *type = ASRUtils::expr_type(dict_expr); + ASR::ttype_t *dict_values_type = ASR::down_cast(type)->m_value_type; + + Vec arg_values; + arg_values.reserve(al, args.size()); + for( size_t i = 0; i < args.size(); i++ ) { + arg_values.push_back(al, ASRUtils::expr_value(args[i])); + } + ASR::expr_t* compile_time_value = eval_dict_values(al, loc, arg_values); + ASR::ttype_t *to_type = List(dict_values_type); + return ASR::make_IntrinsicFunction_t(al, loc, + static_cast(ASRUtils::IntrinsicFunctions::DictValues), + args.p, args.size(), 0, to_type, compile_time_value); +} + +} // namespace DictValues + namespace Any { static inline void verify_array(ASR::expr_t* array, ASR::ttype_t* return_type, @@ -2261,6 +2308,8 @@ namespace IntrinsicFunctionRegistry { {nullptr, &ListReverse::verify_args}}, {static_cast(ASRUtils::IntrinsicFunctions::DictKeys), {nullptr, &DictKeys::verify_args}}, + {static_cast(ASRUtils::IntrinsicFunctions::DictValues), + {nullptr, &DictValues::verify_args}}, {static_cast(ASRUtils::IntrinsicFunctions::SymbolicSymbol), {nullptr, &SymbolicSymbol::verify_args}}, {static_cast(ASRUtils::IntrinsicFunctions::SymbolicAdd), @@ -2317,6 +2366,8 @@ namespace IntrinsicFunctionRegistry { "list.pop"}, {static_cast(ASRUtils::IntrinsicFunctions::DictKeys), "dict.keys"}, + {static_cast(ASRUtils::IntrinsicFunctions::DictValues), + "dict.values"}, {static_cast(ASRUtils::IntrinsicFunctions::SymbolicSymbol), "Symbol"}, {static_cast(ASRUtils::IntrinsicFunctions::SymbolicAdd), @@ -2363,6 +2414,7 @@ namespace IntrinsicFunctionRegistry { {"list.reverse", {&ListReverse::create_ListReverse, &ListReverse::eval_list_reverse}}, {"list.pop", {&ListPop::create_ListPop, &ListPop::eval_list_pop}}, {"dict.keys", {&DictKeys::create_DictKeys, &DictKeys::eval_dict_keys}}, + {"dict.values", {&DictValues::create_DictValues, &DictValues::eval_dict_values}}, {"Symbol", {&SymbolicSymbol::create_SymbolicSymbol, &SymbolicSymbol::eval_SymbolicSymbol}}, {"SymbolicAdd", {&SymbolicAdd::create_SymbolicAdd, &SymbolicAdd::eval_SymbolicAdd}}, {"SymbolicSub", {&SymbolicSub::create_SymbolicSub, &SymbolicSub::eval_SymbolicSub}}, @@ -2478,6 +2530,7 @@ inline std::string get_intrinsic_name(int x) { INTRINSIC_NAME_CASE(ListReverse) INTRINSIC_NAME_CASE(ListPop) INTRINSIC_NAME_CASE(DictKeys) + INTRINSIC_NAME_CASE(DictValues) INTRINSIC_NAME_CASE(SymbolicSymbol) INTRINSIC_NAME_CASE(SymbolicAdd) INTRINSIC_NAME_CASE(SymbolicSub) diff --git a/src/lpython/semantics/python_attribute_eval.h b/src/lpython/semantics/python_attribute_eval.h index ceb639a5c2..9b7aafcf45 100644 --- a/src/lpython/semantics/python_attribute_eval.h +++ b/src/lpython/semantics/python_attribute_eval.h @@ -35,7 +35,8 @@ struct AttributeHandler { {"set@remove", &eval_set_remove}, {"dict@get", &eval_dict_get}, {"dict@pop", &eval_dict_pop}, - {"dict@keys", &eval_dict_keys} + {"dict@keys", &eval_dict_keys}, + {"dict@values", &eval_dict_values} }; modify_attr_set = {"list@append", "list@remove", @@ -403,6 +404,20 @@ struct AttributeHandler { { throw SemanticError(msg, loc); }); } + static ASR::asr_t* eval_dict_values(ASR::expr_t *s, Allocator &al, const Location &loc, + Vec &args, diag::Diagnostics &/*diag*/) { + Vec args_with_dict; + args_with_dict.reserve(al, args.size() + 1); + args_with_dict.push_back(al, s); + for(size_t i = 0; i < args.size(); i++) { + args_with_dict.push_back(al, args[i]); + } + ASRUtils::create_intrinsic_function create_function = + ASRUtils::IntrinsicFunctionRegistry::get_create_function("dict.values"); + return create_function(al, loc, args_with_dict, [&](const std::string &msg, const Location &loc) + { throw SemanticError(msg, loc); }); + } + }; // AttributeHandler } // namespace LCompilers::LPython From 493c859d3b2e66a2dfd29a24e2875b9049ea3c26 Mon Sep 17 00:00:00 2001 From: kabra1110 Date: Sat, 8 Jul 2023 11:19:26 +0530 Subject: [PATCH 3/4] modify test, incorrect --- integration_tests/test_dict_keys_values.py | 31 +++++++--------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/integration_tests/test_dict_keys_values.py b/integration_tests/test_dict_keys_values.py index a50a3c64f6..2c7c35ed68 100644 --- a/integration_tests/test_dict_keys_values.py +++ b/integration_tests/test_dict_keys_values.py @@ -2,42 +2,29 @@ def test_dict_keys_values(): d1: dict[i32, i32] = {} - d2: dict[tuple[i32, i32], tuple[i32, tuple[str, f64]]] = {} k1: list[i32] - k2: list[tuple[i32, i32]] + k1_copy: list[i32] = [] v1: list[i32] - v2: list[tuple[i32, tuple[str, f64]]] + v1_copy: list[i32] = [] i: i32 j: i32 key_count: i32 - s: str for i in range(105, 115): d1[i] = i + 1 k1 = d1.keys() + for i in k1: + k1_copy.append(i) v1 = d1.values() + for i in v1: + v1_copy.append(i) assert len(k1) == 10 for i in range(105, 115): key_count = 0 for j in range(len(k1)): - if k1[j] == i: + if k1_copy[j] == i: key_count += 1 - assert v1[j] == d1[i] + assert v1_copy[j] == d1[i] assert key_count == 1 - s = 'a' - for i in range(10): - d2[(i, i + 1)] = (i, (s, f64(i * i))) - s += 'a' - k2 = d2.keys() - v2 = d2.values() - assert len(k2) == 10 - for i in range(10): - key_count = 0 - for j in range(len(k2)): - if k2[j] == (i, i + 1): - key_count += 1 - assert v2[j] == d2[k2[j]] - assert key_count == 1 - -test_dict_keys_values() \ No newline at end of file +test_dict_keys_values() From 24e4c25bdf2a8ba8a66b04be30e63843eef9d4c1 Mon Sep 17 00:00:00 2001 From: kabra1110 Date: Sat, 22 Jul 2023 07:43:25 +0530 Subject: [PATCH 4/4] add for SC --- integration_tests/test_dict_keys_values.py | 24 +++++ src/libasr/codegen/asr_to_llvm.cpp | 15 ++- src/libasr/codegen/llvm_utils.cpp | 109 +++++++++++++++++++-- src/libasr/codegen/llvm_utils.h | 9 +- 4 files changed, 135 insertions(+), 22 deletions(-) diff --git a/integration_tests/test_dict_keys_values.py b/integration_tests/test_dict_keys_values.py index 2c7c35ed68..e3c28b72d6 100644 --- a/integration_tests/test_dict_keys_values.py +++ b/integration_tests/test_dict_keys_values.py @@ -8,6 +8,7 @@ def test_dict_keys_values(): v1_copy: list[i32] = [] i: i32 j: i32 + s: str key_count: i32 for i in range(105, 115): @@ -27,4 +28,27 @@ def test_dict_keys_values(): assert v1_copy[j] == d1[i] assert key_count == 1 + d2: dict[str, str] = {} + k2: list[str] + k2_copy: list[str] = [] + v2: list[str] + v2_copy: list[str] = [] + + for i in range(105, 115): + d2[str(i)] = str(i + 1) + k2 = d2.keys() + for s in k2: + k2_copy.append(s) + v2 = d2.values() + for s in v2: + v2_copy.append(s) + assert len(k2) == 10 + for i in range(105, 115): + key_count = 0 + for j in range(len(k2)): + if k2_copy[j] == str(i): + key_count += 1 + assert v2_copy[j] == d2[str(i)] + assert key_count == 1 + test_dict_keys_values() diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 3b4d7e20a0..25edbebafc 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -1684,7 +1684,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor tmp = list_api->pop_position(plist, pos, asr_el_type, module.get(), name2memidx); } - void generate_DictElems(ASR::expr_t* m_arg, bool key_or_value, const Location &loc) { + void generate_DictElems(ASR::expr_t* m_arg, bool key_or_value) { ASR::Dict_t* dict_type = ASR::down_cast( ASRUtils::expr_type(m_arg)); ASR::ttype_t* el_type = key_or_value == 0 ? @@ -1695,11 +1695,6 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor this->visit_expr(*m_arg); llvm::Value* pdict = tmp; - llvm_utils->set_dict_api(dict_type); - if(llvm_utils->dict_api == dict_api_sc.get()) { - throw CodeGenError("dict.keys and dict.values are only implemented " - "for linear probing for now", loc); - } ptr_loads = ptr_loads_copy; bool is_array_type_local = false, is_malloc_array_type_local = false; @@ -1725,7 +1720,9 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor "keys_list" : "values_list"); list_api->list_init(type_code, el_list, *module, 0, 0); - llvm_utils->dict_api->get_elements_list(pdict, el_list, el_type, *module, + llvm_utils->set_dict_api(dict_type); + llvm_utils->dict_api->get_elements_list(pdict, el_list, dict_type->m_key_type, + dict_type->m_value_type, *module, name2memidx, key_or_value); tmp = el_list; } @@ -1802,11 +1799,11 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor break; } case ASRUtils::IntrinsicFunctions::DictKeys: { - generate_DictElems(x.m_args[0], 0, x.base.base.loc); + generate_DictElems(x.m_args[0], 0); break; } case ASRUtils::IntrinsicFunctions::DictValues: { - generate_DictElems(x.m_args[0], 1, x.base.base.loc); + generate_DictElems(x.m_args[0], 1); break; } case ASRUtils::IntrinsicFunctions::SetAdd: { diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 173b17b36b..cfa7710f9e 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -3872,34 +3872,38 @@ namespace LCompilers { } void LLVMDict::get_elements_list(llvm::Value* dict, - llvm::Value* elements_list, ASR::ttype_t* el_asr_type, llvm::Module& module, + llvm::Value* elements_list, ASR::ttype_t* key_asr_type, + ASR::ttype_t* value_asr_type, llvm::Module& module, std::map>& name2memidx, bool key_or_value) { /** * C++ equivalent: - * + * + * // key_or_value = 0 for keys, 1 for values + * * idx = 0; - * + * * while( capacity > idx ) { * el = key_or_value_list[idx]; * key_mask_value = key_mask[idx]; - * + * * is_key_skip = key_mask_value == 3; // tombstone * is_key_set = key_mask_value != 0; * add_el = is_key_set && !is_key_skip; * if( add_el ) { * elements_list.append(el); * } - * + * * idx++; * } - * + * */ llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); llvm::Value* el_list = key_or_value == 0 ? get_key_list(dict) : get_value_list(dict); + ASR::ttype_t* el_asr_type = key_or_value == 0 ? key_asr_type : value_asr_type; if( !are_iterators_set ) { idx_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); } @@ -3949,10 +3953,95 @@ namespace LCompilers { llvm_utils->start_new_block(loopend); } - void LLVMDictSeparateChaining::get_elements_list(llvm::Value* /*dict*/, - llvm::Value* /*elements_list*/, ASR::ttype_t* /*el_asr_type*/, llvm::Module& /*module*/, - std::map>& /*name2memidx*/, - bool /*key_or_value*/) {} + void LLVMDictSeparateChaining::get_elements_list(llvm::Value* dict, + llvm::Value* elements_list, ASR::ttype_t* key_asr_type, + ASR::ttype_t* value_asr_type, llvm::Module& module, + std::map>& name2memidx, + bool key_or_value) { + if( !are_iterators_set ) { + idx_ptr = builder->CreateAlloca(llvm::Type::getInt32Ty(context), nullptr); + chain_itr = builder->CreateAlloca(llvm::Type::getInt8PtrTy(context), nullptr); + } + LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, 0)), idx_ptr); + + llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(dict)); + llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); + llvm::Value* key_value_pairs = LLVM::CreateLoad(*builder, get_pointer_to_key_value_pairs(dict)); + llvm::Type* kv_pair_type = get_key_value_pair_type(key_asr_type, value_asr_type); + ASR::ttype_t* el_asr_type = key_or_value == 0 ? key_asr_type : value_asr_type; + llvm::BasicBlock *loophead = llvm::BasicBlock::Create(context, "loop.head"); + llvm::BasicBlock *loopbody = llvm::BasicBlock::Create(context, "loop.body"); + llvm::BasicBlock *loopend = llvm::BasicBlock::Create(context, "loop.end"); + + // head + llvm_utils->start_new_block(loophead); + { + llvm::Value *cond = builder->CreateICmpSGT( + capacity, + LLVM::CreateLoad(*builder, idx_ptr)); + builder->CreateCondBr(cond, loopbody, loopend); + } + + // body + llvm_utils->start_new_block(loopbody); + { + llvm::Value* idx = LLVM::CreateLoad(*builder, idx_ptr); + llvm::Value* key_mask_value = LLVM::CreateLoad(*builder, + llvm_utils->create_ptr_gep(key_mask, idx)); + llvm::Value* is_key_set = builder->CreateICmpEQ(key_mask_value, + llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1))); + + llvm_utils->create_if_else(is_key_set, [&]() { + llvm::Value* dict_i = llvm_utils->create_ptr_gep(key_value_pairs, idx); + llvm::Value* kv_ll_i8 = builder->CreateBitCast(dict_i, llvm::Type::getInt8PtrTy(context)); + LLVM::CreateStore(*builder, kv_ll_i8, chain_itr); + + llvm::BasicBlock *loop2head = llvm::BasicBlock::Create(context, "loop2.head"); + llvm::BasicBlock *loop2body = llvm::BasicBlock::Create(context, "loop2.body"); + llvm::BasicBlock *loop2end = llvm::BasicBlock::Create(context, "loop2.end"); + + // head + llvm_utils->start_new_block(loop2head); + { + llvm::Value *cond = builder->CreateICmpNE( + LLVM::CreateLoad(*builder, chain_itr), + llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)) + ); + builder->CreateCondBr(cond, loop2body, loop2end); + } + + // body + llvm_utils->start_new_block(loop2body); + { + llvm::Value* kv_struct_i8 = LLVM::CreateLoad(*builder, chain_itr); + llvm::Value* kv_struct = builder->CreateBitCast(kv_struct_i8, kv_pair_type->getPointerTo()); + llvm::Value* kv_el = llvm_utils->create_gep(kv_struct, key_or_value); + if( !LLVM::is_llvm_struct(el_asr_type) ) { + kv_el = LLVM::CreateLoad(*builder, kv_el); + } + llvm_utils->list_api->append(elements_list, kv_el, + el_asr_type, &module, name2memidx); + llvm::Value* next_kv_struct = LLVM::CreateLoad(*builder, llvm_utils->create_gep(kv_struct, 2)); + LLVM::CreateStore(*builder, next_kv_struct, chain_itr); + } + + builder->CreateBr(loop2head); + + // end + llvm_utils->start_new_block(loop2end); + }, [=]() { + }); + llvm::Value* tmp = builder->CreateAdd(idx, + llvm::ConstantInt::get(context, llvm::APInt(32, 1))); + LLVM::CreateStore(*builder, tmp, idx_ptr); + } + + builder->CreateBr(loophead); + + // end + llvm_utils->start_new_block(loopend); + } llvm::Value* LLVMList::read_item(llvm::Value* list, llvm::Value* pos, bool enable_bounds_checking, diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index 8b38e495fe..26d2d48822 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -623,7 +623,8 @@ namespace LCompilers { virtual void get_elements_list(llvm::Value* dict, - llvm::Value* elements_list, ASR::ttype_t* el_asr_type, llvm::Module& module, + llvm::Value* elements_list, ASR::ttype_t* key_asr_type, + ASR::ttype_t* value_asr_type, llvm::Module& module, std::map>& name2memidx, bool key_or_value) = 0; @@ -720,7 +721,8 @@ namespace LCompilers { llvm::Value* len(llvm::Value* dict); void get_elements_list(llvm::Value* dict, - llvm::Value* elements_list, ASR::ttype_t* el_asr_type, llvm::Module& module, + llvm::Value* elements_list, ASR::ttype_t* key_asr_type, + ASR::ttype_t* value_asr_type, llvm::Module& module, std::map>& name2memidx, bool key_or_value); @@ -872,7 +874,8 @@ namespace LCompilers { llvm::Value* len(llvm::Value* dict); void get_elements_list(llvm::Value* dict, - llvm::Value* elements_list, ASR::ttype_t* el_asr_type, llvm::Module& module, + llvm::Value* elements_list, ASR::ttype_t* key_asr_type, + ASR::ttype_t* value_asr_type, llvm::Module& module, std::map>& name2memidx, bool key_or_value);