From 4c970db0e7dcbd488c0be4bf4c4246d00f402731 Mon Sep 17 00:00:00 2001 From: advik Date: Tue, 25 Jun 2024 00:41:38 +0530 Subject: [PATCH 1/4] Add clear method to dictionary and set --- src/libasr/ASR.asdl | 2 + src/libasr/codegen/asr_to_llvm.cpp | 23 ++++++++ src/libasr/codegen/llvm_utils.cpp | 56 +++++++++++++++++++ src/libasr/codegen/llvm_utils.h | 15 +++++ src/lpython/semantics/python_attribute_eval.h | 44 ++++++++++++++- 5 files changed, 139 insertions(+), 1 deletion(-) diff --git a/src/libasr/ASR.asdl b/src/libasr/ASR.asdl index 1d4cf6ee7d..584f8f3f29 100644 --- a/src/libasr/ASR.asdl +++ b/src/libasr/ASR.asdl @@ -74,6 +74,8 @@ stmt | ListRemove(expr a, expr ele) | ListClear(expr a) | DictInsert(expr a, expr key, expr value) + | DictClear(expr a) + | SetClear(expr a) | Expr(expr expression) expr diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 4921c46f75..9076ca556d 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -1637,6 +1637,29 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } } + void visit_DictClear(const ASR::DictClear_t& x) { + int64_t ptr_loads_copy = ptr_loads; + ptr_loads = 0; + this->visit_expr(*x.m_a); + llvm::Value* pdict = tmp; + ptr_loads = ptr_loads_copy; + ASR::Dict_t* dict_type = ASR::down_cast(ASRUtils::expr_type(x.m_a)); + + llvm_utils->dict_api->dict_clear(pdict, module.get(), dict_type->m_key_type, dict_type->m_value_type); + } + + void visit_SetClear(const ASR::SetClear_t& x) { + int64_t ptr_loads_copy = ptr_loads; + ptr_loads = 0; + this->visit_expr(*x.m_a); + llvm::Value* pset = tmp; + ptr_loads = ptr_loads_copy; + ASR::Set_t *set_type = ASR::down_cast( + ASRUtils::expr_type(x.m_a)); + + llvm_utils->set_api->set_clear(pset, module.get(), set_type->m_type); + } + void visit_DictContains(const ASR::DictContains_t &x) { if (x.m_value) { this->visit_expr(*x.m_value); diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 9f20c98f78..b4b04cb7a6 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -1,3 +1,5 @@ +#include "llvm_utils.h" +#include #include #include #include @@ -4384,6 +4386,25 @@ namespace LCompilers { llvm_utils->start_new_block(loopend); } + void LLVMDict::dict_clear(llvm::Value *dict, llvm::Module *module, + ASR::ttype_t *key_asr_type, ASR::ttype_t* value_asr_type) { + llvm::Value* key_list = get_key_list(dict); + llvm::Value* value_list = get_value_list(dict); + llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); + llvm_utils->list_api->free_data(key_list, *module); + llvm_utils->list_api->free_data(value_list, *module); + LLVM::lfortran_free(context, *module, *builder, key_mask); + + std::string key_type_code = ASRUtils::get_type_code(key_asr_type); + std::string value_type_code = ASRUtils::get_type_code(value_asr_type); + dict_init(key_type_code, value_type_code, dict, module, 0); + } + + void LLVMDictSeparateChaining::dict_clear(llvm::Value *dict, llvm::Module *module, + ASR::ttype_t *key_asr_type, ASR::ttype_t* value_asr_type) { + dict_init(ASRUtils::get_type_code(key_asr_type), + ASRUtils::get_type_code(value_asr_type), dict, module, 0); + } llvm::Value* LLVMList::read_item(llvm::Value* list, llvm::Value* pos, bool enable_bounds_checking, @@ -6880,6 +6901,41 @@ namespace LCompilers { llvm_utils->start_new_block(loopend); } + void LLVMSetLinearProbing::set_clear(llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type) { + get_builder0(); + llvm::Value* occupancy_ptr = get_pointer_to_occupancy(set); + llvm::Value* capacity_ptr = get_pointer_to_capacity(set); + llvm::Value* llvm_zero = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)); + LLVM::CreateStore(*builder, llvm_zero, occupancy_ptr); + LLVM::CreateStore(*builder, llvm_zero, capacity_ptr); + + llvm::Value* el_list = get_el_list(set); + llvm::DataLayout data_layout(module); + size_t mask_size = data_layout.getTypeAllocSize(llvm::Type::getInt8Ty(context)); + llvm::Value* llvm_mask_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), + llvm::APInt(32, mask_size)); + llvm::Value* new_el_mask = LLVM::lfortran_calloc(context, *module, *builder, llvm_zero, + llvm_mask_size); + std::string el_type_code = ASRUtils::get_type_code(el_asr_type); + //llvm::Type* el_llvm_type = std::get<2>(typecode2settype[el_type_code]); + //int32_t el_type_size = std::get<1>(typecode2settype[el_type_code]); + + //llvm::Value* new_el_list = builder0.CreateAlloca(llvm_utils->list_api->get_list_type(el_llvm_type, + //el_type_code, el_type_size), nullptr); + llvm_utils->list_api->list_init(el_type_code, el_list, *module, llvm_zero, llvm_zero); + + llvm_utils->list_api->free_data(el_list, *module); + LLVM::lfortran_free(context, *module, *builder, LLVM::CreateLoad(*builder, get_pointer_to_mask(set))); + //LLVM::CreateStore(*builder, LLVM::CreateLoad(*builder, new_el_list), el_list); + LLVM::CreateStore(*builder, new_el_mask, get_pointer_to_mask(set)); + } + + void LLVMSetSeparateChaining::set_clear(llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type) { + LLVM::lfortran_free(context, *module, *builder, LLVM::CreateLoad(*builder, get_pointer_to_mask(set))); + llvm::Value* llvm_zero = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)); + set_init_given_initial_capacity(ASRUtils::get_type_code(el_asr_type), set, module, llvm_zero); + } + llvm::Value* LLVMSetInterface::len(llvm::Value* set) { return LLVM::CreateLoad(*builder, get_pointer_to_occupancy(set)); } diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index 8e24438100..cff7b9f982 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -1,6 +1,7 @@ #ifndef LFORTRAN_LLVM_UTILS_H #define LFORTRAN_LLVM_UTILS_H +#include #include #include @@ -644,6 +645,9 @@ namespace LCompilers { virtual void set_is_dict_present(bool value); + virtual + void dict_clear(llvm::Value *dict, llvm::Module *module, + ASR::ttype_t *key_asr_type, ASR::ttype_t* value_asr_type) = 0; virtual void get_elements_list(llvm::Value* dict, @@ -739,6 +743,8 @@ namespace LCompilers { llvm::Value* len(llvm::Value* dict); + void dict_clear(llvm::Value *dict, llvm::Module *module, + ASR::ttype_t *key_asr_type, ASR::ttype_t* value_asr_type); void get_elements_list(llvm::Value* dict, llvm::Value* elements_list, ASR::ttype_t* key_asr_type, @@ -889,6 +895,8 @@ namespace LCompilers { llvm::Value* len(llvm::Value* dict); + void dict_clear(llvm::Value *dict, llvm::Module *module, + ASR::ttype_t *key_asr_type, ASR::ttype_t* value_asr_type); void get_elements_list(llvm::Value* dict, llvm::Value* elements_list, ASR::ttype_t* key_asr_type, @@ -987,6 +995,9 @@ namespace LCompilers { virtual llvm::Value* len(llvm::Value* set); + virtual + void set_clear(llvm::Value *set, llvm::Module *module, ASR::ttype_t *el_asr_type) = 0; + virtual bool is_set_present(); @@ -1053,6 +1064,8 @@ namespace LCompilers { ASR::Set_t* set_type, llvm::Module* module, std::map>& name2memidx); + void set_clear(llvm::Value *set, llvm::Module *module, ASR::ttype_t *el_asr_type); + ~LLVMSetLinearProbing(); }; @@ -1134,6 +1147,8 @@ namespace LCompilers { ASR::Set_t* set_type, llvm::Module* module, std::map>& name2memidx); + void set_clear(llvm::Value *set, llvm::Module *module, ASR::ttype_t *el_asr_type); + ~LLVMSetSeparateChaining(); }; diff --git a/src/lpython/semantics/python_attribute_eval.h b/src/lpython/semantics/python_attribute_eval.h index f8926a3eb8..ac9ad26a9c 100644 --- a/src/lpython/semantics/python_attribute_eval.h +++ b/src/lpython/semantics/python_attribute_eval.h @@ -34,10 +34,12 @@ struct AttributeHandler { {"set@add", &eval_set_add}, {"set@remove", &eval_set_remove}, {"set@discard", &eval_set_discard}, + {"set@clear", &eval_set_clear}, {"dict@get", &eval_dict_get}, {"dict@pop", &eval_dict_pop}, {"dict@keys", &eval_dict_keys}, - {"dict@values", &eval_dict_values} + {"dict@values", &eval_dict_values}, + {"dict@clear", &eval_dict_clear} }; modify_attr_set = {"list@append", "list@remove", @@ -356,6 +358,26 @@ struct AttributeHandler { return create_function(al, loc, args_with_set, diag); } + static ASR::asr_t* eval_set_clear(ASR::expr_t *s, Allocator &al, + const Location &loc, Vec &args, diag::Diagnostics & diag) { + if (ASRUtils::is_const(s)) { + throw SemanticError("cannot clear elements from a const set", loc); + } + if (args.size() != 0) { + diag.add(diag::Diagnostic( + "Incorrect number of arguments in 'clear', it accepts no argument", + diag::Level::Error, diag::Stage::Semantic, { + diag::Label("incorrect number of arguments in clear (found: " + + std::to_string(args.size()) + ", expected: 0)", + {loc}) + }) + ); + throw SemanticAbort(); + } + + return make_SetClear_t(al, loc, s); + } + static ASR::asr_t* eval_dict_get(ASR::expr_t *s, Allocator &al, const Location &loc, Vec &args, diag::Diagnostics &diag) { ASR::expr_t *def = nullptr; @@ -448,6 +470,26 @@ struct AttributeHandler { return create_function(al, loc, args_with_dict, diag); } + static ASR::asr_t* eval_dict_clear(ASR::expr_t *s, Allocator &al, + const Location &loc, Vec &args, diag::Diagnostics & diag) { + if (ASRUtils::is_const(s)) { + throw SemanticError("cannot clear elements from a const dict", loc); + } + if (args.size() != 0) { + diag.add(diag::Diagnostic( + "Incorrect number of arguments in 'clear', it accepts no argument", + diag::Level::Error, diag::Stage::Semantic, { + diag::Label("incorrect number of arguments in clear (found: " + + std::to_string(args.size()) + ", expected: 0)", + {loc}) + }) + ); + throw SemanticAbort(); + } + + return make_DictClear_t(al, loc, s); + } + static ASR::asr_t* eval_symbolic_diff(ASR::expr_t *s, Allocator &al, const Location &loc, Vec &args, diag::Diagnostics &diag) { Vec args_with_list; From 0eccb251ab0f82c87cffe71805c6916a33d3fffc Mon Sep 17 00:00:00 2001 From: advik Date: Tue, 25 Jun 2024 12:38:50 +0530 Subject: [PATCH 2/4] Added tests --- integration_tests/CMakeLists.txt | 2 ++ integration_tests/test_dict_clear.py | 20 ++++++++++++++++++++ integration_tests/test_set_clear.py | 21 +++++++++++++++++++++ src/libasr/codegen/llvm_utils.cpp | 10 +++++----- 4 files changed, 48 insertions(+), 5 deletions(-) create mode 100644 integration_tests/test_dict_clear.py create mode 100644 integration_tests/test_set_clear.py diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index ea416e764b..6bde8f5d6c 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -584,10 +584,12 @@ RUN(NAME test_dict_bool LABELS cpython llvm llvm_jit) RUN(NAME test_dict_increment LABELS cpython llvm llvm_jit) RUN(NAME test_dict_keys_values LABELS cpython llvm llvm_jit) RUN(NAME test_dict_nested1 LABELS cpython llvm llvm_jit) +RUN(NAME test_dict_clear LABELS cpython llvm) RUN(NAME test_set_len LABELS cpython llvm llvm_jit) RUN(NAME test_set_add LABELS cpython llvm llvm_jit) RUN(NAME test_set_remove LABELS cpython llvm llvm_jit) RUN(NAME test_set_discard LABELS cpython llvm llvm_jit) +RUN(NAME test_set_clear LABELS cpython llvm) RUN(NAME test_global_set LABELS cpython llvm llvm_jit) RUN(NAME test_for_loop LABELS cpython llvm llvm_jit c) RUN(NAME modules_01 LABELS cpython llvm llvm_jit c wasm wasm_x86 wasm_x64) diff --git a/integration_tests/test_dict_clear.py b/integration_tests/test_dict_clear.py new file mode 100644 index 0000000000..67a0bd6109 --- /dev/null +++ b/integration_tests/test_dict_clear.py @@ -0,0 +1,20 @@ +def test_clear(): + a: dict[i32, i32] = {1:1, 2:2} + + a.clear() + a[3] = 3 + + assert len(a) == 1 + assert a.keys() == [3] + assert a.values() == [3] + + b: dict[str, str] = {'a':'a', 'b':'b'} + + b.clear() + b['c'] = 'c' + + assert len(b) == 1 + assert b.keys() == ['c'] + assert b.values() == ['c'] + +test_clear() diff --git a/integration_tests/test_set_clear.py b/integration_tests/test_set_clear.py new file mode 100644 index 0000000000..7d55ceeb5c --- /dev/null +++ b/integration_tests/test_set_clear.py @@ -0,0 +1,21 @@ +def test_clear(): + a: set[i32] = {1, 2} + + a.clear() + a.add(3) + + assert len(a) == 1 + # a.remove(3) + # assert len(a) == 0 + + b: set[str] = {'a', 'b'} + + b.clear() + b.add('c') + + assert len(b) == 1 + # b.remove('c') + # assert len(b) == 0 + + +test_clear() diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index b4b04cb7a6..bef5a31f70 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -6917,16 +6917,16 @@ namespace LCompilers { llvm::Value* new_el_mask = LLVM::lfortran_calloc(context, *module, *builder, llvm_zero, llvm_mask_size); std::string el_type_code = ASRUtils::get_type_code(el_asr_type); - //llvm::Type* el_llvm_type = std::get<2>(typecode2settype[el_type_code]); - //int32_t el_type_size = std::get<1>(typecode2settype[el_type_code]); + llvm::Type* el_llvm_type = std::get<2>(typecode2settype[el_type_code]); + int32_t el_type_size = std::get<1>(typecode2settype[el_type_code]); - //llvm::Value* new_el_list = builder0.CreateAlloca(llvm_utils->list_api->get_list_type(el_llvm_type, - //el_type_code, el_type_size), nullptr); + llvm::Value* new_el_list = builder0.CreateAlloca(llvm_utils->list_api->get_list_type(el_llvm_type, + el_type_code, el_type_size), nullptr); llvm_utils->list_api->list_init(el_type_code, el_list, *module, llvm_zero, llvm_zero); llvm_utils->list_api->free_data(el_list, *module); LLVM::lfortran_free(context, *module, *builder, LLVM::CreateLoad(*builder, get_pointer_to_mask(set))); - //LLVM::CreateStore(*builder, LLVM::CreateLoad(*builder, new_el_list), el_list); + LLVM::CreateStore(*builder, LLVM::CreateLoad(*builder, new_el_list), el_list); LLVM::CreateStore(*builder, new_el_mask, get_pointer_to_mask(set)); } From 02238b72d8ea8297607b0acbc349ad30deef45e0 Mon Sep 17 00:00:00 2001 From: advik Date: Tue, 25 Jun 2024 17:57:42 +0530 Subject: [PATCH 3/4] Update valid tests --- integration_tests/test_dict_clear.py | 6 ++---- integration_tests/test_set_clear.py | 8 ++++---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/integration_tests/test_dict_clear.py b/integration_tests/test_dict_clear.py index 67a0bd6109..eccfea0aa6 100644 --- a/integration_tests/test_dict_clear.py +++ b/integration_tests/test_dict_clear.py @@ -5,8 +5,7 @@ def test_clear(): a[3] = 3 assert len(a) == 1 - assert a.keys() == [3] - assert a.values() == [3] + assert 3 in a b: dict[str, str] = {'a':'a', 'b':'b'} @@ -14,7 +13,6 @@ def test_clear(): b['c'] = 'c' assert len(b) == 1 - assert b.keys() == ['c'] - assert b.values() == ['c'] + assert 'c' in b test_clear() diff --git a/integration_tests/test_set_clear.py b/integration_tests/test_set_clear.py index 7d55ceeb5c..47776a7e07 100644 --- a/integration_tests/test_set_clear.py +++ b/integration_tests/test_set_clear.py @@ -2,18 +2,18 @@ def test_clear(): a: set[i32] = {1, 2} a.clear() - a.add(3) + # a.add(3) - assert len(a) == 1 + assert len(a) == 0 # a.remove(3) # assert len(a) == 0 b: set[str] = {'a', 'b'} b.clear() - b.add('c') + # b.add('c') - assert len(b) == 1 + assert len(b) == 0 # b.remove('c') # assert len(b) == 0 From 6db1860d12fc9f6a5894c58de368ead73e6885bc Mon Sep 17 00:00:00 2001 From: advik Date: Wed, 26 Jun 2024 00:02:59 +0530 Subject: [PATCH 4/4] Fix bugs and refactor code --- integration_tests/test_set_clear.py | 12 ++++++------ src/libasr/codegen/llvm_utils.cpp | 23 ++--------------------- 2 files changed, 8 insertions(+), 27 deletions(-) diff --git a/integration_tests/test_set_clear.py b/integration_tests/test_set_clear.py index 47776a7e07..871e2c2bf7 100644 --- a/integration_tests/test_set_clear.py +++ b/integration_tests/test_set_clear.py @@ -2,20 +2,20 @@ def test_clear(): a: set[i32] = {1, 2} a.clear() - # a.add(3) + a.add(3) + assert len(a) == 1 + a.remove(3) assert len(a) == 0 - # a.remove(3) - # assert len(a) == 0 b: set[str] = {'a', 'b'} b.clear() - # b.add('c') + b.add('c') + assert len(b) == 1 + b.remove('c') assert len(b) == 0 - # b.remove('c') - # assert len(b) == 0 test_clear() diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index bef5a31f70..2bcb3e355e 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -6902,32 +6902,13 @@ namespace LCompilers { } void LLVMSetLinearProbing::set_clear(llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type) { - get_builder0(); - llvm::Value* occupancy_ptr = get_pointer_to_occupancy(set); - llvm::Value* capacity_ptr = get_pointer_to_capacity(set); - llvm::Value* llvm_zero = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)); - LLVM::CreateStore(*builder, llvm_zero, occupancy_ptr); - LLVM::CreateStore(*builder, llvm_zero, capacity_ptr); llvm::Value* el_list = get_el_list(set); - llvm::DataLayout data_layout(module); - size_t mask_size = data_layout.getTypeAllocSize(llvm::Type::getInt8Ty(context)); - llvm::Value* llvm_mask_size = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), - llvm::APInt(32, mask_size)); - llvm::Value* new_el_mask = LLVM::lfortran_calloc(context, *module, *builder, llvm_zero, - llvm_mask_size); - std::string el_type_code = ASRUtils::get_type_code(el_asr_type); - llvm::Type* el_llvm_type = std::get<2>(typecode2settype[el_type_code]); - int32_t el_type_size = std::get<1>(typecode2settype[el_type_code]); - - llvm::Value* new_el_list = builder0.CreateAlloca(llvm_utils->list_api->get_list_type(el_llvm_type, - el_type_code, el_type_size), nullptr); - llvm_utils->list_api->list_init(el_type_code, el_list, *module, llvm_zero, llvm_zero); llvm_utils->list_api->free_data(el_list, *module); LLVM::lfortran_free(context, *module, *builder, LLVM::CreateLoad(*builder, get_pointer_to_mask(set))); - LLVM::CreateStore(*builder, LLVM::CreateLoad(*builder, new_el_list), el_list); - LLVM::CreateStore(*builder, new_el_mask, get_pointer_to_mask(set)); + + set_init(ASRUtils::get_type_code(el_asr_type), set, module, 0); } void LLVMSetSeparateChaining::set_clear(llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type) {