Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Add clear method to dictionary and set #2747

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -584,10 +584,12 @@ RUN(NAME test_dict_bool LABELS cpython llvm llvm_jit)
RUN(NAME test_dict_increment LABELS cpython llvm llvm_jit)
RUN(NAME test_dict_keys_values LABELS cpython llvm llvm_jit)
RUN(NAME test_dict_nested1 LABELS cpython llvm llvm_jit)
RUN(NAME test_dict_clear LABELS cpython llvm)
RUN(NAME test_set_len LABELS cpython llvm llvm_jit)
RUN(NAME test_set_add LABELS cpython llvm llvm_jit)
RUN(NAME test_set_remove LABELS cpython llvm llvm_jit)
RUN(NAME test_set_discard LABELS cpython llvm llvm_jit)
RUN(NAME test_set_clear LABELS cpython llvm)
RUN(NAME test_global_set LABELS cpython llvm llvm_jit)
RUN(NAME test_for_loop LABELS cpython llvm llvm_jit c)
RUN(NAME modules_01 LABELS cpython llvm llvm_jit c wasm wasm_x86 wasm_x64)
Expand Down
18 changes: 18 additions & 0 deletions integration_tests/test_dict_clear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
def test_clear():
a: dict[i32, i32] = {1:1, 2:2}

a.clear()
a[3] = 3

assert len(a) == 1
assert 3 in a

b: dict[str, str] = {'a':'a', 'b':'b'}

b.clear()
b['c'] = 'c'

assert len(b) == 1
assert 'c' in b

test_clear()
21 changes: 21 additions & 0 deletions integration_tests/test_set_clear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
def test_clear():
a: set[i32] = {1, 2}

a.clear()
a.add(3)

assert len(a) == 1
a.remove(3)
assert len(a) == 0

b: set[str] = {'a', 'b'}

b.clear()
b.add('c')

assert len(b) == 1
b.remove('c')
assert len(b) == 0


test_clear()
2 changes: 2 additions & 0 deletions src/libasr/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ stmt
| ListRemove(expr a, expr ele)
| ListClear(expr a)
| DictInsert(expr a, expr key, expr value)
| DictClear(expr a)
| SetClear(expr a)
| Expr(expr expression)

expr
Expand Down
23 changes: 23 additions & 0 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1637,6 +1637,29 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}
}

void visit_DictClear(const ASR::DictClear_t& x) {
int64_t ptr_loads_copy = ptr_loads;
ptr_loads = 0;
this->visit_expr(*x.m_a);
llvm::Value* pdict = tmp;
ptr_loads = ptr_loads_copy;
ASR::Dict_t* dict_type = ASR::down_cast<ASR::Dict_t>(ASRUtils::expr_type(x.m_a));

llvm_utils->dict_api->dict_clear(pdict, module.get(), dict_type->m_key_type, dict_type->m_value_type);
}

void visit_SetClear(const ASR::SetClear_t& x) {
int64_t ptr_loads_copy = ptr_loads;
ptr_loads = 0;
this->visit_expr(*x.m_a);
llvm::Value* pset = tmp;
ptr_loads = ptr_loads_copy;
ASR::Set_t *set_type = ASR::down_cast<ASR::Set_t>(
ASRUtils::expr_type(x.m_a));

llvm_utils->set_api->set_clear(pset, module.get(), set_type->m_type);
}

void visit_DictContains(const ASR::DictContains_t &x) {
if (x.m_value) {
this->visit_expr(*x.m_value);
Expand Down
37 changes: 37 additions & 0 deletions src/libasr/codegen/llvm_utils.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#include "llvm_utils.h"
#include <functional>
#include <libasr/assert.h>
#include <libasr/codegen/llvm_utils.h>
#include <libasr/codegen/llvm_array_utils.h>
Expand Down Expand Up @@ -4384,6 +4386,25 @@ namespace LCompilers {
llvm_utils->start_new_block(loopend);
}

void LLVMDict::dict_clear(llvm::Value *dict, llvm::Module *module,
ASR::ttype_t *key_asr_type, ASR::ttype_t* value_asr_type) {
llvm::Value* key_list = get_key_list(dict);
llvm::Value* value_list = get_value_list(dict);
llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict));
llvm_utils->list_api->free_data(key_list, *module);
llvm_utils->list_api->free_data(value_list, *module);
LLVM::lfortran_free(context, *module, *builder, key_mask);

std::string key_type_code = ASRUtils::get_type_code(key_asr_type);
std::string value_type_code = ASRUtils::get_type_code(value_asr_type);
dict_init(key_type_code, value_type_code, dict, module, 0);
}

void LLVMDictSeparateChaining::dict_clear(llvm::Value *dict, llvm::Module *module,
ASR::ttype_t *key_asr_type, ASR::ttype_t* value_asr_type) {
dict_init(ASRUtils::get_type_code(key_asr_type),
ASRUtils::get_type_code(value_asr_type), dict, module, 0);
}

llvm::Value* LLVMList::read_item(llvm::Value* list, llvm::Value* pos,
bool enable_bounds_checking,
Expand Down Expand Up @@ -6880,6 +6901,22 @@ namespace LCompilers {
llvm_utils->start_new_block(loopend);
}

void LLVMSetLinearProbing::set_clear(llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type) {

llvm::Value* el_list = get_el_list(set);

llvm_utils->list_api->free_data(el_list, *module);
LLVM::lfortran_free(context, *module, *builder, LLVM::CreateLoad(*builder, get_pointer_to_mask(set)));

set_init(ASRUtils::get_type_code(el_asr_type), set, module, 0);
}

void LLVMSetSeparateChaining::set_clear(llvm::Value* set, llvm::Module* module, ASR::ttype_t* el_asr_type) {
LLVM::lfortran_free(context, *module, *builder, LLVM::CreateLoad(*builder, get_pointer_to_mask(set)));
llvm::Value* llvm_zero = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), llvm::APInt(32, 0));
set_init_given_initial_capacity(ASRUtils::get_type_code(el_asr_type), set, module, llvm_zero);
}

llvm::Value* LLVMSetInterface::len(llvm::Value* set) {
return LLVM::CreateLoad(*builder, get_pointer_to_occupancy(set));
}
Expand Down
15 changes: 15 additions & 0 deletions src/libasr/codegen/llvm_utils.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef LFORTRAN_LLVM_UTILS_H
#define LFORTRAN_LLVM_UTILS_H

#include <cstdlib>
#include <memory>

#include <llvm/IR/Value.h>
Expand Down Expand Up @@ -644,6 +645,9 @@ namespace LCompilers {
virtual
void set_is_dict_present(bool value);

virtual
void dict_clear(llvm::Value *dict, llvm::Module *module,
ASR::ttype_t *key_asr_type, ASR::ttype_t* value_asr_type) = 0;

virtual
void get_elements_list(llvm::Value* dict,
Expand Down Expand Up @@ -739,6 +743,8 @@ namespace LCompilers {

llvm::Value* len(llvm::Value* dict);

void dict_clear(llvm::Value *dict, llvm::Module *module,
ASR::ttype_t *key_asr_type, ASR::ttype_t* value_asr_type);

void get_elements_list(llvm::Value* dict,
llvm::Value* elements_list, ASR::ttype_t* key_asr_type,
Expand Down Expand Up @@ -889,6 +895,8 @@ namespace LCompilers {

llvm::Value* len(llvm::Value* dict);

void dict_clear(llvm::Value *dict, llvm::Module *module,
ASR::ttype_t *key_asr_type, ASR::ttype_t* value_asr_type);

void get_elements_list(llvm::Value* dict,
llvm::Value* elements_list, ASR::ttype_t* key_asr_type,
Expand Down Expand Up @@ -987,6 +995,9 @@ namespace LCompilers {
virtual
llvm::Value* len(llvm::Value* set);

virtual
void set_clear(llvm::Value *set, llvm::Module *module, ASR::ttype_t *el_asr_type) = 0;

virtual
bool is_set_present();

Expand Down Expand Up @@ -1053,6 +1064,8 @@ namespace LCompilers {
ASR::Set_t* set_type, llvm::Module* module,
std::map<std::string, std::map<std::string, int>>& name2memidx);

void set_clear(llvm::Value *set, llvm::Module *module, ASR::ttype_t *el_asr_type);

~LLVMSetLinearProbing();
};

Expand Down Expand Up @@ -1134,6 +1147,8 @@ namespace LCompilers {
ASR::Set_t* set_type, llvm::Module* module,
std::map<std::string, std::map<std::string, int>>& name2memidx);

void set_clear(llvm::Value *set, llvm::Module *module, ASR::ttype_t *el_asr_type);

~LLVMSetSeparateChaining();
};

Expand Down
44 changes: 43 additions & 1 deletion src/lpython/semantics/python_attribute_eval.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@ struct AttributeHandler {
{"set@add", &eval_set_add},
{"set@remove", &eval_set_remove},
{"set@discard", &eval_set_discard},
{"set@clear", &eval_set_clear},
{"dict@get", &eval_dict_get},
{"dict@pop", &eval_dict_pop},
{"dict@keys", &eval_dict_keys},
{"dict@values", &eval_dict_values}
{"dict@values", &eval_dict_values},
{"dict@clear", &eval_dict_clear}
};

modify_attr_set = {"list@append", "list@remove",
Expand Down Expand Up @@ -356,6 +358,26 @@ struct AttributeHandler {
return create_function(al, loc, args_with_set, diag);
}

static ASR::asr_t* eval_set_clear(ASR::expr_t *s, Allocator &al,
const Location &loc, Vec<ASR::expr_t*> &args, diag::Diagnostics & diag) {
if (ASRUtils::is_const(s)) {
throw SemanticError("cannot clear elements from a const set", loc);
}
if (args.size() != 0) {
diag.add(diag::Diagnostic(
"Incorrect number of arguments in 'clear', it accepts no argument",
diag::Level::Error, diag::Stage::Semantic, {
diag::Label("incorrect number of arguments in clear (found: " +
std::to_string(args.size()) + ", expected: 0)",
{loc})
})
);
throw SemanticAbort();
}

return make_SetClear_t(al, loc, s);
}

static ASR::asr_t* eval_dict_get(ASR::expr_t *s, Allocator &al, const Location &loc,
Vec<ASR::expr_t*> &args, diag::Diagnostics &diag) {
ASR::expr_t *def = nullptr;
Expand Down Expand Up @@ -448,6 +470,26 @@ struct AttributeHandler {
return create_function(al, loc, args_with_dict, diag);
}

static ASR::asr_t* eval_dict_clear(ASR::expr_t *s, Allocator &al,
const Location &loc, Vec<ASR::expr_t*> &args, diag::Diagnostics & diag) {
if (ASRUtils::is_const(s)) {
throw SemanticError("cannot clear elements from a const dict", loc);
}
if (args.size() != 0) {
diag.add(diag::Diagnostic(
"Incorrect number of arguments in 'clear', it accepts no argument",
diag::Level::Error, diag::Stage::Semantic, {
diag::Label("incorrect number of arguments in clear (found: " +
std::to_string(args.size()) + ", expected: 0)",
{loc})
})
);
throw SemanticAbort();
}

return make_DictClear_t(al, loc, s);
}

static ASR::asr_t* eval_symbolic_diff(ASR::expr_t *s, Allocator &al, const Location &loc,
Vec<ASR::expr_t*> &args, diag::Diagnostics &diag) {
Vec<ASR::expr_t*> args_with_list;
Expand Down
Loading