diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index d73ab6f5d3..e5f6ff3aa8 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -583,6 +583,7 @@ RUN(NAME test_dict_nested1 LABELS cpython llvm) RUN(NAME test_set_len LABELS cpython llvm) RUN(NAME test_set_add LABELS cpython llvm) RUN(NAME test_set_remove LABELS cpython llvm) +RUN(NAME test_set_discard LABELS cpython llvm) RUN(NAME test_global_set LABELS cpython llvm) RUN(NAME test_for_loop LABELS cpython llvm c) RUN(NAME modules_01 LABELS cpython llvm c wasm wasm_x86 wasm_x64) diff --git a/integration_tests/test_set_discard.py b/integration_tests/test_set_discard.py new file mode 100644 index 0000000000..730abaff7d --- /dev/null +++ b/integration_tests/test_set_discard.py @@ -0,0 +1,48 @@ +from lpython import i32 + +def test_set_discard(): + s1: set[i32] + s2: set[tuple[i32, tuple[i32, i32], str]] + s3: set[str] + st1: str + i: i32 + j: i32 + k: i32 + + for k in range(2): + s1 = {0} + s2 = {(0, (1, 2), "a")} + for i in range(20): + j = i % 10 + s1.add(j) + s2.add((j, (j + 1, j + 2), "a")) + + for i in range(10): + s1.discard(i) + s2.discard((i, (i + 1, i + 2), "a")) + assert len(s1) == 10 - 1 - i + assert len(s1) == len(s2) + + st1 = "a" + s3 = {st1} + for i in range(20): + s3.add(st1) + if i < 10: + if i > 0: + st1 += "a" + + st1 = "a" + for i in range(10): + s3.discard(st1) + assert len(s3) == 10 - 1 - i + if i < 10: + st1 += "a" + + for i in range(20): + s1.add(i) + if i % 2 == 0: + s1.discard(i) + assert len(s1) == (i + 1) // 2 + + +test_set_discard() diff --git a/src/libasr/ASR.asdl b/src/libasr/ASR.asdl index d31b2d95dd..eeda3ebb35 100644 --- a/src/libasr/ASR.asdl +++ b/src/libasr/ASR.asdl @@ -69,6 +69,7 @@ stmt | BlockCall(int label, symbol m) | SetInsert(expr a, expr ele) | SetRemove(expr a, expr ele) + | SetDiscard(expr a, expr ele) | ListInsert(expr a, expr pos, expr ele) | ListRemove(expr a, expr ele) | ListClear(expr a) diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index 4e2c35c5cf..1f973c0074 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -1905,7 +1905,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor llvm_utils->set_api->write_item(pset, el, module.get(), asr_el_type, name2memidx); } - void generate_SetRemove(ASR::expr_t* m_arg, ASR::expr_t* m_ele) { + void generate_SetRemove(ASR::expr_t* m_arg, ASR::expr_t* m_ele, bool throw_key_error) { ASR::Set_t* set_type = ASR::down_cast( ASRUtils::expr_type(m_arg)); ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(m_arg)); @@ -1919,7 +1919,7 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor ptr_loads = ptr_loads_copy; llvm::Value *el = tmp; llvm_utils->set_set_api(set_type); - llvm_utils->set_api->remove_item(pset, el, *module, asr_el_type); + llvm_utils->set_api->remove_item(pset, el, *module, asr_el_type, throw_key_error); } void visit_IntrinsicElementalFunction(const ASR::IntrinsicElementalFunction_t& x) { @@ -1986,7 +1986,11 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor break; } case ASRUtils::IntrinsicElementalFunctions::SetRemove: { - generate_SetRemove(x.m_args[0], x.m_args[1]); + generate_SetRemove(x.m_args[0], x.m_args[1], true); + break; + } + case ASRUtils::IntrinsicElementalFunctions::SetDiscard: { + generate_SetRemove(x.m_args[0], x.m_args[1], false); break; } case ASRUtils::IntrinsicElementalFunctions::Exp: { diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 61e19a6286..0403e2ce61 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -6415,7 +6415,7 @@ namespace LCompilers { void LLVMSetLinearProbing::resolve_collision_for_read_with_bound_check( llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, - llvm::Module& module, ASR::ttype_t* el_asr_type) { + llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) { /** * C++ equivalent: @@ -6467,14 +6467,16 @@ namespace LCompilers { llvm_utils->create_if_else(is_el_matching, [=]() { LLVM::CreateStore(*builder, el_hash, pos_ptr); }, [&]() { - std::string message = "The set does not contain the specified element"; - llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n"); - llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message); - print_error(context, module, *builder, {fmt_ptr, fmt_ptr2}); - int exit_code_int = 1; - llvm::Value *exit_code = llvm::ConstantInt::get(context, - llvm::APInt(32, exit_code_int)); - exit(context, module, *builder, exit_code); + if (throw_key_error) { + std::string message = "The set does not contain the specified element"; + llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n"); + llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message); + print_error(context, module, *builder, {fmt_ptr, fmt_ptr2}); + int exit_code_int = 1; + llvm::Value *exit_code = llvm::ConstantInt::get(context, + llvm::APInt(32, exit_code_int)); + exit(context, module, *builder, exit_code); + } }); } builder->CreateBr(mergeBB); @@ -6491,20 +6493,22 @@ namespace LCompilers { LLVM::is_llvm_struct(el_asr_type)), module, el_asr_type); llvm_utils->create_if_else(is_el_matching, []() {}, [&]() { - std::string message = "The set does not contain the specified element"; - llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n"); - llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message); - print_error(context, module, *builder, {fmt_ptr, fmt_ptr2}); - int exit_code_int = 1; - llvm::Value *exit_code = llvm::ConstantInt::get(context, - llvm::APInt(32, exit_code_int)); - exit(context, module, *builder, exit_code); + if (throw_key_error) { + std::string message = "The set does not contain the specified element"; + llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n"); + llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message); + print_error(context, module, *builder, {fmt_ptr, fmt_ptr2}); + int exit_code_int = 1; + llvm::Value *exit_code = llvm::ConstantInt::get(context, + llvm::APInt(32, exit_code_int)); + exit(context, module, *builder, exit_code); + } }); } void LLVMSetSeparateChaining::resolve_collision_for_read_with_bound_check( llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, - llvm::Module& module, ASR::ttype_t* el_asr_type) { + llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) { /** * C++ equivalent: * @@ -6532,20 +6536,22 @@ namespace LCompilers { ); llvm_utils->create_if_else(does_el_exist, []() {}, [&]() { - std::string message = "The set does not contain the specified element"; - llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n"); - llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message); - print_error(context, module, *builder, {fmt_ptr, fmt_ptr2}); - int exit_code_int = 1; - llvm::Value *exit_code = llvm::ConstantInt::get(context, - llvm::APInt(32, exit_code_int)); - exit(context, module, *builder, exit_code); + if (throw_key_error) { + std::string message = "The set does not contain the specified element"; + llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n"); + llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message); + print_error(context, module, *builder, {fmt_ptr, fmt_ptr2}); + int exit_code_int = 1; + llvm::Value *exit_code = llvm::ConstantInt::get(context, + llvm::APInt(32, exit_code_int)); + exit(context, module, *builder, exit_code); + } }); } void LLVMSetLinearProbing::remove_item( llvm::Value* set, llvm::Value* el, - llvm::Module& module, ASR::ttype_t* el_asr_type) { + llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) { /** * C++ equivalent: * @@ -6555,7 +6561,7 @@ namespace LCompilers { */ llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set)); llvm::Value* el_hash = get_el_hash(current_capacity, el, el_asr_type, module); - this->resolve_collision_for_read_with_bound_check(set, el_hash, el, module, el_asr_type); + this->resolve_collision_for_read_with_bound_check(set, el_hash, el, module, el_asr_type, throw_key_error); llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr); llvm::Value* el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set)); llvm::Value* el_mask_i = llvm_utils->create_ptr_gep(el_mask, pos); @@ -6571,7 +6577,7 @@ namespace LCompilers { void LLVMSetSeparateChaining::remove_item( llvm::Value* set, llvm::Value* el, - llvm::Module& module, ASR::ttype_t* el_asr_type) { + llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) { /** * C++ equivalent: * @@ -6593,7 +6599,7 @@ namespace LCompilers { llvm::Value* current_capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set)); llvm::Value* el_hash = get_el_hash(current_capacity, el, el_asr_type, module); - this->resolve_collision_for_read_with_bound_check(set, el_hash, el, module, el_asr_type); + this->resolve_collision_for_read_with_bound_check(set, el_hash, el, module, el_asr_type, throw_key_error); llvm::Value* prev = LLVM::CreateLoad(*builder, chain_itr_prev); llvm::Value* found = LLVM::CreateLoad(*builder, chain_itr); diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index a4fdedff84..869aef52e7 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -967,12 +967,12 @@ namespace LCompilers { virtual void resolve_collision_for_read_with_bound_check( llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, - llvm::Module& module, ASR::ttype_t* el_asr_type) = 0; + llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) = 0; virtual void remove_item( llvm::Value* set, llvm::Value* el, - llvm::Module& module, ASR::ttype_t* el_asr_type) = 0; + llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) = 0; virtual void set_deepcopy( @@ -1038,11 +1038,11 @@ namespace LCompilers { void resolve_collision_for_read_with_bound_check( llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, - llvm::Module& module, ASR::ttype_t* el_asr_type); + llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error); void remove_item( llvm::Value* set, llvm::Value* el, - llvm::Module& module, ASR::ttype_t* el_asr_type); + llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error); void set_deepcopy( llvm::Value* src, llvm::Value* dest, @@ -1119,11 +1119,11 @@ namespace LCompilers { void resolve_collision_for_read_with_bound_check( llvm::Value* set, llvm::Value* el_hash, llvm::Value* el, - llvm::Module& module, ASR::ttype_t* el_asr_type); + llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error); void remove_item( llvm::Value* set, llvm::Value* el, - llvm::Module& module, ASR::ttype_t* el_asr_type); + llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error); void set_deepcopy( llvm::Value* src, llvm::Value* dest, diff --git a/src/libasr/pass/intrinsic_function_registry.h b/src/libasr/pass/intrinsic_function_registry.h index 6b730fcea6..65437a6518 100644 --- a/src/libasr/pass/intrinsic_function_registry.h +++ b/src/libasr/pass/intrinsic_function_registry.h @@ -110,6 +110,7 @@ inline std::string get_intrinsic_name(int x) { INTRINSIC_NAME_CASE(DictValues) INTRINSIC_NAME_CASE(SetAdd) INTRINSIC_NAME_CASE(SetRemove) + INTRINSIC_NAME_CASE(SetDiscard) INTRINSIC_NAME_CASE(Max) INTRINSIC_NAME_CASE(Min) INTRINSIC_NAME_CASE(Sign) @@ -343,6 +344,8 @@ namespace IntrinsicElementalFunctionRegistry { {nullptr, &SetAdd::verify_args}}, {static_cast(IntrinsicElementalFunctions::SetRemove), {nullptr, &SetRemove::verify_args}}, + {static_cast(IntrinsicElementalFunctions::SetDiscard), + {nullptr, &SetDiscard::verify_args}}, {static_cast(IntrinsicElementalFunctions::Max), {&Max::instantiate_Max, &Max::verify_args}}, {static_cast(IntrinsicElementalFunctions::Min), @@ -630,6 +633,8 @@ namespace IntrinsicElementalFunctionRegistry { "set.add"}, {static_cast(IntrinsicElementalFunctions::SetRemove), "set.remove"}, + {static_cast(IntrinsicElementalFunctions::SetDiscard), + "set.discard"}, {static_cast(IntrinsicElementalFunctions::Max), "max"}, {static_cast(IntrinsicElementalFunctions::Min), @@ -823,6 +828,7 @@ namespace IntrinsicElementalFunctionRegistry { {"dict.values", {&DictValues::create_DictValues, &DictValues::eval_dict_values}}, {"set.add", {&SetAdd::create_SetAdd, &SetAdd::eval_set_add}}, {"set.remove", {&SetRemove::create_SetRemove, &SetRemove::eval_set_remove}}, + {"set.discard", {&SetDiscard::create_SetDiscard, &SetDiscard::eval_set_discard}}, {"max0", {&Max::create_Max, &Max::eval_Max}}, {"adjustl", {&Adjustl::create_Adjustl, &Adjustl::eval_Adjustl}}, {"adjustr", {&Adjustr::create_Adjustr, &Adjustr::eval_Adjustr}}, diff --git a/src/libasr/pass/intrinsic_functions.h b/src/libasr/pass/intrinsic_functions.h index 8d299a5a58..d4495a91bc 100644 --- a/src/libasr/pass/intrinsic_functions.h +++ b/src/libasr/pass/intrinsic_functions.h @@ -108,6 +108,7 @@ enum class IntrinsicElementalFunctions : int64_t { DictValues, SetAdd, SetRemove, + SetDiscard, Max, Min, Radix, @@ -4916,6 +4917,57 @@ static inline ASR::asr_t* create_SetRemove(Allocator& al, const Location& loc, } // namespace SetRemove +namespace SetDiscard { + +static inline void verify_args(const ASR::IntrinsicElementalFunction_t& x, diag::Diagnostics& diagnostics) { + ASRUtils::require_impl(x.n_args == 2, "Call to set.discard must have exactly one argument", + x.base.base.loc, diagnostics); + ASRUtils::require_impl(ASR::is_a(*ASRUtils::expr_type(x.m_args[0])), + "First argument to set.discard must be of set type", + x.base.base.loc, diagnostics); + ASRUtils::require_impl(ASRUtils::check_equal_type(ASRUtils::expr_type(x.m_args[1]), + ASRUtils::get_contained_type(ASRUtils::expr_type(x.m_args[0]))), + "Second argument to set.discard must be of same type as set's element type", + x.base.base.loc, diagnostics); + ASRUtils::require_impl(x.m_type == nullptr, + "Return type of set.discard must be empty", + x.base.base.loc, diagnostics); +} + +static inline ASR::expr_t *eval_set_discard(Allocator &/*al*/, + const Location &/*loc*/, ASR::ttype_t *, Vec& /*args*/, diag::Diagnostics& /*diag*/) { + // TODO: To be implemented for SetConstant expression + return nullptr; +} + +static inline ASR::asr_t* create_SetDiscard(Allocator& al, const Location& loc, + Vec& args, + diag::Diagnostics& diag) { + if (args.size() != 2) { + append_error(diag, "Call to set.discard must have exactly one argument", loc); + return nullptr; + } + if (!ASRUtils::check_equal_type(ASRUtils::expr_type(args[1]), + ASRUtils::get_contained_type(ASRUtils::expr_type(args[0])))) { + append_error(diag, "Argument to set.discard must be of same type as set's " + "element type", loc); + return nullptr; + } + + Vec arg_values; + arg_values.reserve(al, args.size()); + for( size_t i = 0; i < args.size(); i++ ) { + arg_values.push_back(al, ASRUtils::expr_value(args[i])); + } + ASR::expr_t* compile_time_value = eval_set_discard(al, loc, nullptr, arg_values, diag); + return ASR::make_Expr_t(al, loc, + ASRUtils::EXPR(ASR::make_IntrinsicElementalFunction_t(al, loc, + static_cast(IntrinsicElementalFunctions::SetDiscard), + args.p, args.size(), 0, nullptr, compile_time_value))); +} + +} // namespace SetRemove + namespace Max { static inline void verify_args(const ASR::IntrinsicElementalFunction_t& x, diag::Diagnostics& diagnostics) { diff --git a/src/lpython/semantics/python_attribute_eval.h b/src/lpython/semantics/python_attribute_eval.h index 5e5aceb613..ab6a4c169a 100644 --- a/src/lpython/semantics/python_attribute_eval.h +++ b/src/lpython/semantics/python_attribute_eval.h @@ -33,6 +33,7 @@ struct AttributeHandler { {"set@pop", &eval_set_pop}, {"set@add", &eval_set_add}, {"set@remove", &eval_set_remove}, + {"set@discard", &eval_set_discard}, {"dict@get", &eval_dict_get}, {"dict@pop", &eval_dict_pop}, {"dict@keys", &eval_dict_keys}, @@ -41,7 +42,7 @@ struct AttributeHandler { modify_attr_set = {"list@append", "list@remove", "list@reverse", "list@clear", "list@insert", "list@pop", - "set@pop", "set@add", "set@remove", "dict@pop"}; + "set@pop", "set@add", "set@remove", "set@discard", "dict@pop"}; symbolic_attribute_map = { {"diff", &eval_symbolic_diff}, @@ -337,6 +338,19 @@ struct AttributeHandler { return create_function(al, loc, args_with_set, diag); } + static ASR::asr_t* eval_set_discard(ASR::expr_t *s, Allocator &al, const Location &loc, + Vec &args, diag::Diagnostics &diag) { + Vec args_with_set; + args_with_set.reserve(al, args.size() + 1); + args_with_set.push_back(al, s); + for(size_t i = 0; i < args.size(); i++) { + args_with_set.push_back(al, args[i]); + } + ASRUtils::create_intrinsic_function create_function = + ASRUtils::IntrinsicElementalFunctionRegistry::get_create_function("set.discard"); + return create_function(al, loc, args_with_set, diag); + } + static ASR::asr_t* eval_dict_get(ASR::expr_t *s, Allocator &al, const Location &loc, Vec &args, diag::Diagnostics &diag) { ASR::expr_t *def = nullptr;