Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Implement set.discard(elem) #2633

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,7 @@ RUN(NAME test_dict_nested1 LABELS cpython llvm)
RUN(NAME test_set_len LABELS cpython llvm)
RUN(NAME test_set_add LABELS cpython llvm)
RUN(NAME test_set_remove LABELS cpython llvm)
RUN(NAME test_set_discard LABELS cpython llvm)
RUN(NAME test_global_set LABELS cpython llvm)
RUN(NAME test_for_loop LABELS cpython llvm c)
RUN(NAME modules_01 LABELS cpython llvm c wasm wasm_x86 wasm_x64)
Expand Down
48 changes: 48 additions & 0 deletions integration_tests/test_set_discard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from lpython import i32

def test_set_discard():
s1: set[i32]
s2: set[tuple[i32, tuple[i32, i32], str]]
s3: set[str]
st1: str
i: i32
j: i32
k: i32

for k in range(2):
s1 = {0}
s2 = {(0, (1, 2), "a")}
for i in range(20):
j = i % 10
s1.add(j)
s2.add((j, (j + 1, j + 2), "a"))

for i in range(10):
s1.discard(i)
s2.discard((i, (i + 1, i + 2), "a"))
assert len(s1) == 10 - 1 - i
assert len(s1) == len(s2)

st1 = "a"
s3 = {st1}
for i in range(20):
s3.add(st1)
if i < 10:
if i > 0:
st1 += "a"

st1 = "a"
for i in range(10):
s3.discard(st1)
assert len(s3) == 10 - 1 - i
if i < 10:
st1 += "a"

for i in range(20):
s1.add(i)
if i % 2 == 0:
s1.discard(i)
assert len(s1) == (i + 1) // 2


test_set_discard()
1 change: 1 addition & 0 deletions src/libasr/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ stmt
| BlockCall(int label, symbol m)
| SetInsert(expr a, expr ele)
| SetRemove(expr a, expr ele)
| SetDiscard(expr a, expr ele)
| ListInsert(expr a, expr pos, expr ele)
| ListRemove(expr a, expr ele)
| ListClear(expr a)
Expand Down
10 changes: 7 additions & 3 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1905,7 +1905,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
llvm_utils->set_api->write_item(pset, el, module.get(), asr_el_type, name2memidx);
}

void generate_SetRemove(ASR::expr_t* m_arg, ASR::expr_t* m_ele) {
void generate_SetRemove(ASR::expr_t* m_arg, ASR::expr_t* m_ele, bool throw_key_error) {
ASR::Set_t* set_type = ASR::down_cast<ASR::Set_t>(
ASRUtils::expr_type(m_arg));
ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(m_arg));
Expand All @@ -1919,7 +1919,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
ptr_loads = ptr_loads_copy;
llvm::Value *el = tmp;
llvm_utils->set_set_api(set_type);
llvm_utils->set_api->remove_item(pset, el, *module, asr_el_type);
llvm_utils->set_api->remove_item(pset, el, *module, asr_el_type, throw_key_error);
}

void visit_IntrinsicElementalFunction(const ASR::IntrinsicElementalFunction_t& x) {
Expand Down Expand Up @@ -1986,7 +1986,11 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
break;
}
case ASRUtils::IntrinsicElementalFunctions::SetRemove: {
generate_SetRemove(x.m_args[0], x.m_args[1]);
generate_SetRemove(x.m_args[0], x.m_args[1], true);
break;
}
case ASRUtils::IntrinsicElementalFunctions::SetDiscard: {
generate_SetRemove(x.m_args[0], x.m_args[1], false);
break;
}
case ASRUtils::IntrinsicElementalFunctions::Exp: {
Expand Down
66 changes: 36 additions & 30 deletions src/libasr/codegen/llvm_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6415,7 +6415,7 @@ namespace LCompilers {

void LLVMSetLinearProbing::resolve_collision_for_read_with_bound_check(
llvm::Value* set, llvm::Value* el_hash, llvm::Value* el,
llvm::Module& module, ASR::ttype_t* el_asr_type) {
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) {

/**
* C++ equivalent:
Expand Down Expand Up @@ -6467,14 +6467,16 @@ namespace LCompilers {
llvm_utils->create_if_else(is_el_matching, [=]() {
LLVM::CreateStore(*builder, el_hash, pos_ptr);
}, [&]() {
std::string message = "The set does not contain the specified element";
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
print_error(context, module, *builder, {fmt_ptr, fmt_ptr2});
int exit_code_int = 1;
llvm::Value *exit_code = llvm::ConstantInt::get(context,
llvm::APInt(32, exit_code_int));
exit(context, module, *builder, exit_code);
if (throw_key_error) {
std::string message = "The set does not contain the specified element";
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
print_error(context, module, *builder, {fmt_ptr, fmt_ptr2});
int exit_code_int = 1;
llvm::Value *exit_code = llvm::ConstantInt::get(context,
llvm::APInt(32, exit_code_int));
exit(context, module, *builder, exit_code);
}
});
}
builder->CreateBr(mergeBB);
Expand All @@ -6491,20 +6493,22 @@ namespace LCompilers {
LLVM::is_llvm_struct(el_asr_type)), module, el_asr_type);

llvm_utils->create_if_else(is_el_matching, []() {}, [&]() {
std::string message = "The set does not contain the specified element";
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
print_error(context, module, *builder, {fmt_ptr, fmt_ptr2});
int exit_code_int = 1;
llvm::Value *exit_code = llvm::ConstantInt::get(context,
llvm::APInt(32, exit_code_int));
exit(context, module, *builder, exit_code);
if (throw_key_error) {
std::string message = "The set does not contain the specified element";
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
print_error(context, module, *builder, {fmt_ptr, fmt_ptr2});
int exit_code_int = 1;
llvm::Value *exit_code = llvm::ConstantInt::get(context,
llvm::APInt(32, exit_code_int));
exit(context, module, *builder, exit_code);
}
});
}

void LLVMSetSeparateChaining::resolve_collision_for_read_with_bound_check(
llvm::Value* set, llvm::Value* el_hash, llvm::Value* el,
llvm::Module& module, ASR::ttype_t* el_asr_type) {
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) {
/**
* C++ equivalent:
*
Expand Down Expand Up @@ -6532,20 +6536,22 @@ namespace LCompilers {
);

llvm_utils->create_if_else(does_el_exist, []() {}, [&]() {
std::string message = "The set does not contain the specified element";
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
print_error(context, module, *builder, {fmt_ptr, fmt_ptr2});
int exit_code_int = 1;
llvm::Value *exit_code = llvm::ConstantInt::get(context,
llvm::APInt(32, exit_code_int));
exit(context, module, *builder, exit_code);
if (throw_key_error) {
std::string message = "The set does not contain the specified element";
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
print_error(context, module, *builder, {fmt_ptr, fmt_ptr2});
int exit_code_int = 1;
llvm::Value *exit_code = llvm::ConstantInt::get(context,
llvm::APInt(32, exit_code_int));
exit(context, module, *builder, exit_code);
}
});
}

void LLVMSetLinearProbing::remove_item(
llvm::Value* set, llvm::Value* el,
llvm::Module& module, ASR::ttype_t* el_asr_type) {
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) {
/**
* C++ equivalent:
*
Expand All @@ -6555,7 +6561,7 @@ namespace LCompilers {
*/
llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set));
llvm::Value* el_hash = get_el_hash(current_capacity, el, el_asr_type, module);
this->resolve_collision_for_read_with_bound_check(set, el_hash, el, module, el_asr_type);
this->resolve_collision_for_read_with_bound_check(set, el_hash, el, module, el_asr_type, throw_key_error);
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
llvm::Value* el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set));
llvm::Value* el_mask_i = llvm_utils->create_ptr_gep(el_mask, pos);
Expand All @@ -6571,7 +6577,7 @@ namespace LCompilers {

void LLVMSetSeparateChaining::remove_item(
llvm::Value* set, llvm::Value* el,
llvm::Module& module, ASR::ttype_t* el_asr_type) {
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) {
/**
* C++ equivalent:
*
Expand All @@ -6593,7 +6599,7 @@ namespace LCompilers {

llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set));
llvm::Value* el_hash = get_el_hash(current_capacity, el, el_asr_type, module);
this->resolve_collision_for_read_with_bound_check(set, el_hash, el, module, el_asr_type);
this->resolve_collision_for_read_with_bound_check(set, el_hash, el, module, el_asr_type, throw_key_error);
llvm::Value* prev = LLVM::CreateLoad(*builder, chain_itr_prev);
llvm::Value* found = LLVM::CreateLoad(*builder, chain_itr);

Expand Down
12 changes: 6 additions & 6 deletions src/libasr/codegen/llvm_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -967,12 +967,12 @@ namespace LCompilers {
virtual
void resolve_collision_for_read_with_bound_check(
llvm::Value* set, llvm::Value* el_hash, llvm::Value* el,
llvm::Module& module, ASR::ttype_t* el_asr_type) = 0;
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) = 0;

virtual
void remove_item(
llvm::Value* set, llvm::Value* el,
llvm::Module& module, ASR::ttype_t* el_asr_type) = 0;
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) = 0;

virtual
void set_deepcopy(
Expand Down Expand Up @@ -1038,11 +1038,11 @@ namespace LCompilers {

void resolve_collision_for_read_with_bound_check(
llvm::Value* set, llvm::Value* el_hash, llvm::Value* el,
llvm::Module& module, ASR::ttype_t* el_asr_type);
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error);

void remove_item(
llvm::Value* set, llvm::Value* el,
llvm::Module& module, ASR::ttype_t* el_asr_type);
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error);

void set_deepcopy(
llvm::Value* src, llvm::Value* dest,
Expand Down Expand Up @@ -1119,11 +1119,11 @@ namespace LCompilers {

void resolve_collision_for_read_with_bound_check(
llvm::Value* set, llvm::Value* el_hash, llvm::Value* el,
llvm::Module& module, ASR::ttype_t* el_asr_type);
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error);

void remove_item(
llvm::Value* set, llvm::Value* el,
llvm::Module& module, ASR::ttype_t* el_asr_type);
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error);

void set_deepcopy(
llvm::Value* src, llvm::Value* dest,
Expand Down
6 changes: 6 additions & 0 deletions src/libasr/pass/intrinsic_function_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ inline std::string get_intrinsic_name(int x) {
INTRINSIC_NAME_CASE(DictValues)
INTRINSIC_NAME_CASE(SetAdd)
INTRINSIC_NAME_CASE(SetRemove)
INTRINSIC_NAME_CASE(SetDiscard)
INTRINSIC_NAME_CASE(Max)
INTRINSIC_NAME_CASE(Min)
INTRINSIC_NAME_CASE(Sign)
Expand Down Expand Up @@ -343,6 +344,8 @@ namespace IntrinsicElementalFunctionRegistry {
{nullptr, &SetAdd::verify_args}},
{static_cast<int64_t>(IntrinsicElementalFunctions::SetRemove),
{nullptr, &SetRemove::verify_args}},
{static_cast<int64_t>(IntrinsicElementalFunctions::SetDiscard),
{nullptr, &SetDiscard::verify_args}},
{static_cast<int64_t>(IntrinsicElementalFunctions::Max),
{&Max::instantiate_Max, &Max::verify_args}},
{static_cast<int64_t>(IntrinsicElementalFunctions::Min),
Expand Down Expand Up @@ -630,6 +633,8 @@ namespace IntrinsicElementalFunctionRegistry {
"set.add"},
{static_cast<int64_t>(IntrinsicElementalFunctions::SetRemove),
"set.remove"},
{static_cast<int64_t>(IntrinsicElementalFunctions::SetDiscard),
"set.discard"},
{static_cast<int64_t>(IntrinsicElementalFunctions::Max),
"max"},
{static_cast<int64_t>(IntrinsicElementalFunctions::Min),
Expand Down Expand Up @@ -823,6 +828,7 @@ namespace IntrinsicElementalFunctionRegistry {
{"dict.values", {&DictValues::create_DictValues, &DictValues::eval_dict_values}},
{"set.add", {&SetAdd::create_SetAdd, &SetAdd::eval_set_add}},
{"set.remove", {&SetRemove::create_SetRemove, &SetRemove::eval_set_remove}},
{"set.discard", {&SetDiscard::create_SetDiscard, &SetDiscard::eval_set_discard}},
{"max0", {&Max::create_Max, &Max::eval_Max}},
{"adjustl", {&Adjustl::create_Adjustl, &Adjustl::eval_Adjustl}},
{"adjustr", {&Adjustr::create_Adjustr, &Adjustr::eval_Adjustr}},
Expand Down
52 changes: 52 additions & 0 deletions src/libasr/pass/intrinsic_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ enum class IntrinsicElementalFunctions : int64_t {
DictValues,
SetAdd,
SetRemove,
SetDiscard,
Max,
Min,
Radix,
Expand Down Expand Up @@ -4916,6 +4917,57 @@ static inline ASR::asr_t* create_SetRemove(Allocator& al, const Location& loc,

} // namespace SetRemove

namespace SetDiscard {

static inline void verify_args(const ASR::IntrinsicElementalFunction_t& x, diag::Diagnostics& diagnostics) {
ASRUtils::require_impl(x.n_args == 2, "Call to set.discard must have exactly one argument",
x.base.base.loc, diagnostics);
ASRUtils::require_impl(ASR::is_a<ASR::Set_t>(*ASRUtils::expr_type(x.m_args[0])),
"First argument to set.discard must be of set type",
x.base.base.loc, diagnostics);
ASRUtils::require_impl(ASRUtils::check_equal_type(ASRUtils::expr_type(x.m_args[1]),
ASRUtils::get_contained_type(ASRUtils::expr_type(x.m_args[0]))),
"Second argument to set.discard must be of same type as set's element type",
x.base.base.loc, diagnostics);
ASRUtils::require_impl(x.m_type == nullptr,
"Return type of set.discard must be empty",
x.base.base.loc, diagnostics);
}

static inline ASR::expr_t *eval_set_discard(Allocator &/*al*/,
const Location &/*loc*/, ASR::ttype_t *, Vec<ASR::expr_t*>& /*args*/, diag::Diagnostics& /*diag*/) {
// TODO: To be implemented for SetConstant expression
return nullptr;
}

static inline ASR::asr_t* create_SetDiscard(Allocator& al, const Location& loc,
Vec<ASR::expr_t*>& args,
diag::Diagnostics& diag) {
if (args.size() != 2) {
append_error(diag, "Call to set.discard must have exactly one argument", loc);
return nullptr;
}
if (!ASRUtils::check_equal_type(ASRUtils::expr_type(args[1]),
ASRUtils::get_contained_type(ASRUtils::expr_type(args[0])))) {
append_error(diag, "Argument to set.discard must be of same type as set's "
"element type", loc);
return nullptr;
}

Vec<ASR::expr_t*> arg_values;
arg_values.reserve(al, args.size());
for( size_t i = 0; i < args.size(); i++ ) {
arg_values.push_back(al, ASRUtils::expr_value(args[i]));
}
ASR::expr_t* compile_time_value = eval_set_discard(al, loc, nullptr, arg_values, diag);
return ASR::make_Expr_t(al, loc,
ASRUtils::EXPR(ASR::make_IntrinsicElementalFunction_t(al, loc,
static_cast<int64_t>(IntrinsicElementalFunctions::SetDiscard),
args.p, args.size(), 0, nullptr, compile_time_value)));
}

} // namespace SetRemove

namespace Max {

static inline void verify_args(const ASR::IntrinsicElementalFunction_t& x, diag::Diagnostics& diagnostics) {
Expand Down
Loading