Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Add membership checks in dictionaries and sets #2711

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion grammar/Python.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ module LPython
-- need sequences for compare to distinguish between
-- x < 4 < 3 and (x < 4) < 3
| Compare(expr left, cmpop ops, expr* comparators)
| Membership(expr left, membershipop op, expr right)
| Call(expr func, expr* args, keyword* keywords)
| FormattedValue(expr value, int conversion, expr? format_spec)
| JoinedStr(expr* values)
Expand Down Expand Up @@ -110,7 +111,9 @@ module LPython

unaryop = Invert | Not | UAdd | USub

cmpop = Eq | NotEq | Lt | LtE | Gt | GtE | Is | IsNot | In | NotIn
cmpop = Eq | NotEq | Lt | LtE | Gt | GtE | Is | IsNot

membershipop = In | NotIn

comprehension = (expr target, expr iter, expr* ifs, int is_async)

Expand Down
1 change: 1 addition & 0 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,7 @@ RUN(NAME test_import_05 LABELS cpython llvm llvm_jit c wasm wasm_x86 wasm_x
RUN(NAME test_import_06 LABELS cpython llvm llvm_jit)
RUN(NAME test_import_07 LABELS cpython llvm llvm_jit c)
RUN(NAME test_math LABELS cpython llvm llvm_jit NOFAST)
RUN(NAME test_membership_01 LABELS cpython llvm)
RUN(NAME test_numpy_01 LABELS cpython llvm llvm_jit c)
RUN(NAME test_numpy_02 LABELS cpython llvm llvm_jit c)
RUN(NAME test_numpy_03 LABELS cpython llvm llvm_jit c)
Expand Down
36 changes: 36 additions & 0 deletions integration_tests/test_membership_01.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
def test_int_dict():
a: dict[i32, i32] = {1:2, 2:3, 3:4, 4:5}
i: i32
assert (1 in a)
assert (6 not in a)
i = 4
assert (i in a)

def test_str_dict():
a: dict[str, str] = {'a':'1', 'b':'2', 'c':'3'}
i: str
assert ('a' in a)
assert ('d' not in a)
i = 'c'
assert (i in a)

def test_int_set():
a: set[i32] = {1, 2, 3, 4}
i: i32
assert (1 in a)
assert (6 not in a)
i = 4
assert (i in a)

def test_str_set():
a: set[str] = {'a', 'b', 'c', 'e', 'f'}
i: str
assert ('a' in a)
# assert ('d' not in a)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Try assert not ('d' in a). It should also fail.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can be fixed in subsequent PR.

i = 'c'
assert (i in a)

test_int_dict()
test_str_dict()
test_int_set()
test_str_set()
5 changes: 5 additions & 0 deletions src/libasr/ASR.asdl
Original file line number Diff line number Diff line change
Expand Up @@ -118,19 +118,22 @@ expr
| ListConcat(expr left, expr right, ttype type, expr? value)
| ListCompare(expr left, cmpop op, expr right, ttype type, expr? value)
| ListCount(expr arg, expr ele, ttype type, expr? value)
| ListContains(expr left, expr right, ttype type, expr? value)
| SetConstant(expr* elements, ttype type)
| SetLen(expr arg, ttype type, expr? value)
| TupleConstant(expr* elements, ttype type)
| TupleLen(expr arg, ttype type, expr value)
| TupleCompare(expr left, cmpop op, expr right, ttype type, expr? value)
| TupleConcat(expr left, expr right, ttype type, expr? value)
| TupleContains(expr left, expr right, ttype type, expr? value)
| StringConstant(string s, ttype type)
| StringConcat(expr left, expr right, ttype type, expr? value)
| StringRepeat(expr left, expr right, ttype type, expr? value)
| StringLen(expr arg, ttype type, expr? value)
| StringItem(expr arg, expr idx, ttype type, expr? value)
| StringSection(expr arg, expr? start, expr? end, expr? step, ttype type, expr? value)
| StringCompare(expr left, cmpop op, expr right, ttype type, expr? value)
| StringContains(expr left, expr right, ttype type, expr? value)
| StringOrd(expr arg, ttype type, expr? value)
| StringChr(expr arg, ttype type, expr? value)
| StringFormat(expr fmt, expr* args, string_format_kind kind, ttype type, expr? value)
Expand Down Expand Up @@ -176,6 +179,8 @@ expr
| ListRepeat(expr left, expr right, ttype type, expr? value)
| DictPop(expr a, expr key, ttype type, expr? value)
| SetPop(expr a, ttype type, expr? value)
| SetContains(expr left, expr right, ttype type, expr? value)
| DictContains(expr left, expr right, ttype type, expr? value)
| IntegerBitLen(expr a, ttype type, expr? value)
| Ichar(expr arg, ttype type, expr? value)
| Iachar(expr arg, ttype type, expr? value)
Expand Down
45 changes: 45 additions & 0 deletions src/libasr/codegen/asr_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1637,6 +1637,51 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor<ASRToLLVMVisitor>
}
}

void visit_DictContains(const ASR::DictContains_t &x) {
if (x.m_value) {
this->visit_expr(*x.m_value);
return;
}

int64_t ptr_loads_copy = ptr_loads;
ptr_loads = 0;
this->visit_expr(*x.m_right);
llvm::Value *right = tmp;
ASR::Dict_t *dict_type = ASR::down_cast<ASR::Dict_t>(
ASRUtils::expr_type(x.m_right));
ptr_loads = !LLVM::is_llvm_struct(dict_type->m_key_type);
this->visit_expr(*x.m_left);
llvm::Value *left = tmp;
ptr_loads = ptr_loads_copy;
llvm::Value *capacity = LLVM::CreateLoad(*builder,
llvm_utils->dict_api->get_pointer_to_capacity(right));
llvm::Value *key_hash = llvm_utils->dict_api->get_key_hash(capacity, left, dict_type->m_key_type, *module);

tmp = llvm_utils->dict_api->resolve_collision_for_read_with_bound_check(right, key_hash, left, *module, dict_type->m_key_type, dict_type->m_value_type, true);
}

void visit_SetContains(const ASR::SetContains_t &x) {
if (x.m_value) {
this->visit_expr(*x.m_value);
return;
}

int64_t ptr_loads_copy = ptr_loads;
ptr_loads = 0;
this->visit_expr(*x.m_right);
llvm::Value *right = tmp;
ASR::ttype_t *el_type = ASRUtils::expr_type(x.m_left);
ptr_loads = !LLVM::is_llvm_struct(el_type);
this->visit_expr(*x.m_left);
llvm::Value *left = tmp;
ptr_loads = ptr_loads_copy;
llvm::Value *capacity = LLVM::CreateLoad(*builder,
llvm_utils->set_api->get_pointer_to_capacity(right));
llvm::Value *el_hash = llvm_utils->set_api->get_el_hash(capacity, left, el_type, *module);

tmp = llvm_utils->set_api->resolve_collision_for_read_with_bound_check(right, el_hash, left, *module, el_type, false, true);
}

void visit_DictLen(const ASR::DictLen_t& x) {
if (x.m_value) {
this->visit_expr(*x.m_value);
Expand Down
97 changes: 76 additions & 21 deletions src/libasr/codegen/llvm_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3177,7 +3177,7 @@ namespace LCompilers {
llvm::Value* LLVMDict::resolve_collision_for_read_with_bound_check(
llvm::Value* dict, llvm::Value* key_hash,
llvm::Value* key, llvm::Module& module,
ASR::ttype_t* key_asr_type, ASR::ttype_t* /*value_asr_type*/) {
ASR::ttype_t* key_asr_type, ASR::ttype_t* /*value_asr_type*/, bool check_if_exists) {
llvm::Value* key_list = get_key_list(dict);
llvm::Value* value_list = get_value_list(dict);
llvm::Value* key_mask = LLVM::CreateLoad(*builder, get_pointer_to_keymask(dict));
Expand All @@ -3187,6 +3187,8 @@ namespace LCompilers {
llvm::Value* is_key_matching = llvm_utils->is_equal_by_value(key,
llvm_utils->list_api->read_item(key_list, pos, false, module,
LLVM::is_llvm_struct(key_asr_type)), module, key_asr_type);
if (check_if_exists)
return is_key_matching;

llvm_utils->create_if_else(is_key_matching, [&]() {
}, [&]() {
Expand Down Expand Up @@ -3245,7 +3247,7 @@ namespace LCompilers {
llvm::Value* LLVMDictOptimizedLinearProbing::resolve_collision_for_read_with_bound_check(
llvm::Value* dict, llvm::Value* key_hash,
llvm::Value* key, llvm::Module& module,
ASR::ttype_t* key_asr_type, ASR::ttype_t* /*value_asr_type*/) {
ASR::ttype_t* key_asr_type, ASR::ttype_t* /*value_asr_type*/, bool check_if_exists) {

/**
* C++ equivalent:
Expand Down Expand Up @@ -3287,6 +3289,9 @@ namespace LCompilers {
llvm_utils->create_ptr_gep(key_mask, key_hash));
llvm::Value* is_prob_not_neeeded = builder->CreateICmpEQ(key_mask_value,
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1)));
llvm::AllocaInst *flag_ptr = builder->CreateAlloca(llvm::Type::getInt1Ty(context), nullptr);
LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 0), flag_ptr);
LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0), pos_ptr);
builder->CreateCondBr(is_prob_not_neeeded, thenBB, elseBB);
builder->SetInsertPoint(thenBB);
{
Expand All @@ -3304,6 +3309,9 @@ namespace LCompilers {
llvm_utils->create_if_else(is_key_matching, [=]() {
LLVM::CreateStore(*builder, key_hash, pos_ptr);
}, [&]() {
if (check_if_exists) {
LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 1), flag_ptr);
} else {
std::string message = "The dict does not contain the specified key";
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
llvm::Value *fmt_ptr2 = builder->CreateGlobalStringPtr(message);
Expand All @@ -3312,7 +3320,7 @@ namespace LCompilers {
llvm::Value *exit_code = llvm::ConstantInt::get(context,
llvm::APInt(32, exit_code_int));
exit(context, module, *builder, exit_code);
});
}});
}
builder->CreateBr(mergeBB);
llvm_utils->start_new_block(elseBB);
Expand All @@ -3321,11 +3329,24 @@ namespace LCompilers {
module, key_asr_type, true);
}
llvm_utils->start_new_block(mergeBB);
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
// Check if the actual key is present or not
llvm::Value* is_key_matching = llvm_utils->is_equal_by_value(key,
llvm::Value *flag = LLVM::CreateLoad(*builder, flag_ptr);
llvm::Value *pos = LLVM::CreateLoad(*builder, pos_ptr);
llvm::AllocaInst *is_key_matching_ptr = builder->CreateAlloca(llvm::Type::getInt1Ty(context), nullptr);

llvm_utils->create_if_else(flag, [&](){
LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 0), is_key_matching_ptr);
}, [&](){
// Check if the actual element is present or not
LLVM::CreateStore(*builder, llvm_utils->is_equal_by_value(key,
llvm_utils->list_api->read_item(key_list, pos, false, module,
LLVM::is_llvm_struct(key_asr_type)), module, key_asr_type);
LLVM::is_llvm_struct(key_asr_type)), module, key_asr_type), is_key_matching_ptr);
});

llvm::Value *is_key_matching = LLVM::CreateLoad(*builder, is_key_matching_ptr);

if (check_if_exists) {
return is_key_matching;
}

llvm_utils->create_if_else(is_key_matching, [&]() {
}, [&]() {
Expand Down Expand Up @@ -3471,7 +3492,7 @@ namespace LCompilers {
llvm::Value* LLVMDictSeparateChaining::resolve_collision_for_read_with_bound_check(
llvm::Value* dict, llvm::Value* key_hash,
llvm::Value* key, llvm::Module& module,
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type) {
ASR::ttype_t* key_asr_type, ASR::ttype_t* value_asr_type, bool check_if_exists) {
/**
* C++ equivalent:
*
Expand Down Expand Up @@ -3506,6 +3527,10 @@ namespace LCompilers {
llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)))
);

if (check_if_exists) {
return does_kv_exists;
}

llvm_utils->create_if_else(does_kv_exists, [&]() {
llvm::Value* kv_struct_i8 = LLVM::CreateLoad(*builder, chain_itr);
llvm::Value* kv_struct = builder->CreateBitCast(kv_struct_i8, kv_struct_type->getPointerTo());
Expand Down Expand Up @@ -4358,6 +4383,7 @@ namespace LCompilers {
// end
llvm_utils->start_new_block(loopend);
}


llvm::Value* LLVMList::read_item(llvm::Value* list, llvm::Value* pos,
bool enable_bounds_checking,
Expand Down Expand Up @@ -6393,9 +6419,9 @@ namespace LCompilers {
el_asr_type, name2memidx);
}

void LLVMSetLinearProbing::resolve_collision_for_read_with_bound_check(
llvm::Value* LLVMSetLinearProbing::resolve_collision_for_read_with_bound_check(
llvm::Value* set, llvm::Value* el_hash, llvm::Value* el,
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) {
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error, bool check_if_exists) {

/**
* C++ equivalent:
Expand Down Expand Up @@ -6423,18 +6449,22 @@ namespace LCompilers {
*/

get_builder0()
pos_ptr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
llvm::Value* el_list = get_el_list(set);
llvm::Value* el_mask = LLVM::CreateLoad(*builder, get_pointer_to_mask(set));
llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_capacity(set));
pos_ptr = builder0.CreateAlloca(llvm::Type::getInt32Ty(context), nullptr);
llvm::Function *fn = builder->GetInsertBlock()->getParent();
llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then", fn);
llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else");
llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont");
std::string s = check_if_exists ? "qq" : "pp";
llvm::BasicBlock *thenBB = llvm::BasicBlock::Create(context, "then"+s, fn);
llvm::BasicBlock *elseBB = llvm::BasicBlock::Create(context, "else"+s);
llvm::BasicBlock *mergeBB = llvm::BasicBlock::Create(context, "ifcont"+s);
llvm::Value* el_mask_value = LLVM::CreateLoad(*builder,
llvm_utils->create_ptr_gep(el_mask, el_hash));
llvm::Value* is_prob_not_needed = builder->CreateICmpEQ(el_mask_value,
llvm::ConstantInt::get(llvm::Type::getInt8Ty(context), llvm::APInt(8, 1)));
llvm::AllocaInst *flag_ptr = builder->CreateAlloca(llvm::Type::getInt1Ty(context), nullptr);
LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0), pos_ptr);
LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 0), flag_ptr);
builder->CreateCondBr(is_prob_not_needed, thenBB, elseBB);
builder->SetInsertPoint(thenBB);
{
Expand All @@ -6447,6 +6477,9 @@ namespace LCompilers {
llvm_utils->create_if_else(is_el_matching, [=]() {
LLVM::CreateStore(*builder, el_hash, pos_ptr);
}, [&]() {
if (check_if_exists) {
LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 1), flag_ptr);
} else {
if (throw_key_error) {
std::string message = "The set does not contain the specified element";
llvm::Value *fmt_ptr = builder->CreateGlobalStringPtr("KeyError: %s\n");
Expand All @@ -6457,7 +6490,7 @@ namespace LCompilers {
llvm::APInt(32, exit_code_int));
exit(context, module, *builder, exit_code);
}
});
}});
}
builder->CreateBr(mergeBB);
llvm_utils->start_new_block(elseBB);
Expand All @@ -6466,11 +6499,25 @@ namespace LCompilers {
module, el_asr_type, true);
}
llvm_utils->start_new_block(mergeBB);
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
llvm::Value *flag = LLVM::CreateLoad(*builder, flag_ptr);
llvm::AllocaInst *is_el_matching_ptr = builder->CreateAlloca(llvm::Type::getInt1Ty(context), nullptr);

llvm_utils->create_if_else(flag, [&](){
LLVM::CreateStore(*builder, llvm::ConstantInt::get(llvm::Type::getInt1Ty(context), 0), is_el_matching_ptr);
}, [&](){
// Check if the actual element is present or not
llvm::Value* is_el_matching = llvm_utils->is_equal_by_value(el,
llvm_utils->list_api->read_item(el_list, pos, false, module,
LLVM::is_llvm_struct(el_asr_type)), module, el_asr_type);
llvm::Value* pos = LLVM::CreateLoad(*builder, pos_ptr);
llvm::Value* item = llvm_utils->list_api->read_item(el_list, pos, false, module,
LLVM::is_llvm_struct(el_asr_type)) ;
llvm::Value *iseq =llvm_utils->is_equal_by_value(el,
item, module, el_asr_type) ;
LLVM::CreateStore(*builder, iseq, is_el_matching_ptr);
});

llvm::Value *is_el_matching = LLVM::CreateLoad(*builder, is_el_matching_ptr);
if (check_if_exists) {
return is_el_matching;
}

llvm_utils->create_if_else(is_el_matching, []() {}, [&]() {
if (throw_key_error) {
Expand All @@ -6484,11 +6531,13 @@ namespace LCompilers {
exit(context, module, *builder, exit_code);
}
});

return nullptr;
}

void LLVMSetSeparateChaining::resolve_collision_for_read_with_bound_check(
llvm::Value* LLVMSetSeparateChaining::resolve_collision_for_read_with_bound_check(
llvm::Value* set, llvm::Value* el_hash, llvm::Value* el,
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error) {
llvm::Module& module, ASR::ttype_t* el_asr_type, bool throw_key_error, bool check_if_exists) {
/**
* C++ equivalent:
*
Expand All @@ -6515,6 +6564,10 @@ namespace LCompilers {
llvm::ConstantPointerNull::get(llvm::Type::getInt8PtrTy(context)))
);

if (check_if_exists) {
return does_el_exist;
}

llvm_utils->create_if_else(does_el_exist, []() {}, [&]() {
if (throw_key_error) {
std::string message = "The set does not contain the specified element";
Expand All @@ -6527,6 +6580,8 @@ namespace LCompilers {
exit(context, module, *builder, exit_code);
}
});

return nullptr;
}

void LLVMSetLinearProbing::remove_item(
Expand Down
Loading
Loading