diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index ea416e764b..6bde8f5d6c 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -584,10 +584,12 @@ RUN(NAME test_dict_bool LABELS cpython llvm llvm_jit) RUN(NAME test_dict_increment LABELS cpython llvm llvm_jit) RUN(NAME test_dict_keys_values LABELS cpython llvm llvm_jit) RUN(NAME test_dict_nested1 LABELS cpython llvm llvm_jit) +RUN(NAME test_dict_clear LABELS cpython llvm) RUN(NAME test_set_len LABELS cpython llvm llvm_jit) RUN(NAME test_set_add LABELS cpython llvm llvm_jit) RUN(NAME test_set_remove LABELS cpython llvm llvm_jit) RUN(NAME test_set_discard LABELS cpython llvm llvm_jit) +RUN(NAME test_set_clear LABELS cpython llvm) RUN(NAME test_global_set LABELS cpython llvm llvm_jit) RUN(NAME test_for_loop LABELS cpython llvm llvm_jit c) RUN(NAME modules_01 LABELS cpython llvm llvm_jit c wasm wasm_x86 wasm_x64) diff --git a/integration_tests/test_dict_clear.py b/integration_tests/test_dict_clear.py new file mode 100644 index 0000000000..eccfea0aa6 --- /dev/null +++ b/integration_tests/test_dict_clear.py @@ -0,0 +1,18 @@ +def test_clear(): + a: dict[i32, i32] = {1:1, 2:2} + + a.clear() + a[3] = 3 + + assert len(a) == 1 + assert 3 in a + + b: dict[str, str] = {'a':'a', 'b':'b'} + + b.clear() + b['c'] = 'c' + + assert len(b) == 1 + assert 'c' in b + +test_clear() diff --git a/integration_tests/test_set_clear.py b/integration_tests/test_set_clear.py new file mode 100644 index 0000000000..871e2c2bf7 --- /dev/null +++ b/integration_tests/test_set_clear.py @@ -0,0 +1,21 @@ +def test_clear(): + a: set[i32] = {1, 2} + + a.clear() + a.add(3) + + assert len(a) == 1 + a.remove(3) + assert len(a) == 0 + + b: set[str] = {'a', 'b'} + + b.clear() + b.add('c') + + assert len(b) == 1 + b.remove('c') + assert len(b) == 0 + + +test_clear() diff --git a/src/libasr/ASR.asdl b/src/libasr/ASR.asdl index 1d4cf6ee7d..584f8f3f29 100644 --- a/src/libasr/ASR.asdl +++ b/src/libasr/ASR.asdl @@ -74,6 +74,8 @@ stmt | ListRemove(expr a, expr ele) | ListClear(expr a) | DictInsert(expr a, expr key, expr value) + | DictClear(expr a) + | SetClear(expr a) | Expr(expr expression) expr diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 4921c46f75..9076ca556d 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -1637,6 +1637,29 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } } + void visit_DictClear(const ASR::DictClear_t& x) { + int64_t ptr_loads_copy = ptr_loads; + ptr_loads = 0; + this->visit_expr(*x.m_a); + llvm::Value* pdict = tmp; + ptr_loads = ptr_loads_copy; + ASR::Dict_t* dict_type = ASR::down_cast(ASRUtils::expr_type(x.m_a)); + + llvm_utils->dict_api->dict_clear(pdict, module.get(), dict_type->m_key_type, dict_type->m_value_type); + } + + void visit_SetClear(const ASR::SetClear_t& x) { + int64_t ptr_loads_copy = ptr_loads; + ptr_loads = 0; + this->visit_expr(*x.m_a); + llvm::Value* pset = tmp; + ptr_loads = ptr_loads_copy; + ASR::Set_t *set_type = ASR::down_cast( + ASRUtils::expr_type(x.m_a)); + + llvm_utils->set_api->set_clear(pset, module.get(), set_type->m_type); + } + void visit_DictContains(const ASR::DictContains_t &x) { if (x.m_value) { this->visit_expr(*x.m_value); diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 9f20c98f78..2bcb3e355e 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -1,3 +1,5 @@ +#include "llvm_utils.h" +#include #include #include #include @@ -4384,6 +4386,25 @@ namespace LCompilers { llvm_utils->start_new_block(loopend); } + void LLVMDict::dict_clear(llvm::Value *dict, llvm::Module *module, + ASR::ttype_t *key_asr_type, ASR::ttype_t* value_asr_type) { + llvm::Value* key_list = get_key_list(dict); + llvm::Value* value_list = get_value_list(dict); + llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict)); + llvm_utils->list_api->free_data(key_list, *module); + llvm_utils->list_api->free_data(value_list, *module); + LLVM::lfortran_free(context, *module, *builder, key_mask); + + std::string key_type_code = ASRUtils::get_type_code(key_asr_type); + std::string value_type_code = ASRUtils::get_type_code(value_asr_type); + dict_init(key_type_code, value_type_code, dict, module, 0); + } + + void LLVMDictSeparateChaining::dict_clear(llvm::Value *dict, llvm::Module *module, + ASR::ttype_t *key_asr_type, ASR::ttype_t* value_asr_type) { + dict_init(ASRUtils::get_type_code(key_asr_type), + ASRUtils::get_type_code(value_asr_type), dict, module, 0); + } llvm::Value* LLVMList::read_item(llvm::Value* list, llvm::Value* pos, bool enable_bounds_checking, @@ -6880,6 +6901,22 @@ namespace LCompilers { llvm_utils->start_new_block(loopend); } + void LLVMSetLinearProbing::set_clear(llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type) { + + llvm::Value* el_list = get_el_list(set); + + llvm_utils->list_api->free_data(el_list, *module); + LLVM::lfortran_free(context, *module, *builder, LLVM::CreateLoad(*builder, get_pointer_to_mask(set))); + + set_init(ASRUtils::get_type_code(el_asr_type), set, module, 0); + } + + void LLVMSetSeparateChaining::set_clear(llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type) { + LLVM::lfortran_free(context, *module, *builder, LLVM::CreateLoad(*builder, get_pointer_to_mask(set))); + llvm::Value* llvm_zero = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0)); + set_init_given_initial_capacity(ASRUtils::get_type_code(el_asr_type), set, module, llvm_zero); + } + llvm::Value* LLVMSetInterface::len(llvm::Value* set) { return LLVM::CreateLoad(*builder, get_pointer_to_occupancy(set)); } diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index 8e24438100..cff7b9f982 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -1,6 +1,7 @@ #ifndef LFORTRAN_LLVM_UTILS_H #define LFORTRAN_LLVM_UTILS_H +#include #include #include @@ -644,6 +645,9 @@ namespace LCompilers { virtual void set_is_dict_present(bool value); + virtual + void dict_clear(llvm::Value *dict, llvm::Module *module, + ASR::ttype_t *key_asr_type, ASR::ttype_t* value_asr_type) = 0; virtual void get_elements_list(llvm::Value* dict, @@ -739,6 +743,8 @@ namespace LCompilers { llvm::Value* len(llvm::Value* dict); + void dict_clear(llvm::Value *dict, llvm::Module *module, + ASR::ttype_t *key_asr_type, ASR::ttype_t* value_asr_type); void get_elements_list(llvm::Value* dict, llvm::Value* elements_list, ASR::ttype_t* key_asr_type, @@ -889,6 +895,8 @@ namespace LCompilers { llvm::Value* len(llvm::Value* dict); + void dict_clear(llvm::Value *dict, llvm::Module *module, + ASR::ttype_t *key_asr_type, ASR::ttype_t* value_asr_type); void get_elements_list(llvm::Value* dict, llvm::Value* elements_list, ASR::ttype_t* key_asr_type, @@ -987,6 +995,9 @@ namespace LCompilers { virtual llvm::Value* len(llvm::Value* set); + virtual + void set_clear(llvm::Value *set, llvm::Module *module, ASR::ttype_t *el_asr_type) = 0; + virtual bool is_set_present(); @@ -1053,6 +1064,8 @@ namespace LCompilers { ASR::Set_t* set_type, llvm::Module* module, std::map>& name2memidx); + void set_clear(llvm::Value *set, llvm::Module *module, ASR::ttype_t *el_asr_type); + ~LLVMSetLinearProbing(); }; @@ -1134,6 +1147,8 @@ namespace LCompilers { ASR::Set_t* set_type, llvm::Module* module, std::map>& name2memidx); + void set_clear(llvm::Value *set, llvm::Module *module, ASR::ttype_t *el_asr_type); + ~LLVMSetSeparateChaining(); }; diff --git a/src/lpython/semantics/python_attribute_eval.h b/src/lpython/semantics/python_attribute_eval.h index f8926a3eb8..ac9ad26a9c 100644 --- a/src/lpython/semantics/python_attribute_eval.h +++ b/src/lpython/semantics/python_attribute_eval.h @@ -34,10 +34,12 @@ struct AttributeHandler { {"set@add", &eval_set_add}, {"set@remove", &eval_set_remove}, {"set@discard", &eval_set_discard}, + {"set@clear", &eval_set_clear}, {"dict@get", &eval_dict_get}, {"dict@pop", &eval_dict_pop}, {"dict@keys", &eval_dict_keys}, - {"dict@values", &eval_dict_values} + {"dict@values", &eval_dict_values}, + {"dict@clear", &eval_dict_clear} }; modify_attr_set = {"list@append", "list@remove", @@ -356,6 +358,26 @@ struct AttributeHandler { return create_function(al, loc, args_with_set, diag); } + static ASR::asr_t* eval_set_clear(ASR::expr_t *s, Allocator &al, + const Location &loc, Vec &args, diag::Diagnostics & diag) { + if (ASRUtils::is_const(s)) { + throw SemanticError("cannot clear elements from a const set", loc); + } + if (args.size() != 0) { + diag.add(diag::Diagnostic( + "Incorrect number of arguments in 'clear', it accepts no argument", + diag::Level::Error, diag::Stage::Semantic, { + diag::Label("incorrect number of arguments in clear (found: " + + std::to_string(args.size()) + ", expected: 0)", + {loc}) + }) + ); + throw SemanticAbort(); + } + + return make_SetClear_t(al, loc, s); + } + static ASR::asr_t* eval_dict_get(ASR::expr_t *s, Allocator &al, const Location &loc, Vec &args, diag::Diagnostics &diag) { ASR::expr_t *def = nullptr; @@ -448,6 +470,26 @@ struct AttributeHandler { return create_function(al, loc, args_with_dict, diag); } + static ASR::asr_t* eval_dict_clear(ASR::expr_t *s, Allocator &al, + const Location &loc, Vec &args, diag::Diagnostics & diag) { + if (ASRUtils::is_const(s)) { + throw SemanticError("cannot clear elements from a const dict", loc); + } + if (args.size() != 0) { + diag.add(diag::Diagnostic( + "Incorrect number of arguments in 'clear', it accepts no argument", + diag::Level::Error, diag::Stage::Semantic, { + diag::Label("incorrect number of arguments in clear (found: " + + std::to_string(args.size()) + ", expected: 0)", + {loc}) + }) + ); + throw SemanticAbort(); + } + + return make_DictClear_t(al, loc, s); + } + static ASR::asr_t* eval_symbolic_diff(ASR::expr_t *s, Allocator &al, const Location &loc, Vec &args, diag::Diagnostics &diag) { Vec args_with_list;