diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index 64635ae094..f7c33230ac 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -518,6 +518,7 @@ RUN(NAME test_list_pop2 LABELS cpython llvm NOFAST) # TODO: Remove NOFAST RUN(NAME test_list_pop3 LABELS cpython llvm) RUN(NAME test_list_compare LABELS cpython llvm) RUN(NAME test_list_concat LABELS cpython llvm c NOFAST) +RUN(NAME test_list_reserve LABELS cpython llvm) RUN(NAME test_tuple_01 LABELS cpython llvm c) RUN(NAME test_tuple_02 LABELS cpython llvm c NOFAST) RUN(NAME test_tuple_03 LABELS cpython llvm c) diff --git a/integration_tests/test_list_reserve.py b/integration_tests/test_list_reserve.py new file mode 100644 index 0000000000..9c074c351d --- /dev/null +++ b/integration_tests/test_list_reserve.py @@ -0,0 +1,30 @@ +from lpython import i32, f64, reserve + +def test_list_reserve(): + l1: list[i32] = [] + l2: list[list[tuple[f64, str, tuple[i32, f64]]]] = [] + i: i32 + + reserve(l1, 100) + for i in range(50): + l1.append(i) + assert len(l1) == i + 1 + + reserve(l1, 150) + + for i in range(50): + l1.pop(0) + assert len(l1) == 49 - i + + reserve(l2, 100) + for i in range(50): + l2.append([(f64(i * i), str(i), (i, f64(i + 1))), (f64(i), str(i), (i, f64(i)))]) + assert len(l2) == i + 1 + + reserve(l2, 150) + + for i in range(50): + l2.pop(0) + assert len(l2) == 49 - i + +test_list_reserve() diff --git a/src/libasr/codegen/asr_to_llvm.cpp b/src/libasr/codegen/asr_to_llvm.cpp index cbcaaf2e19..80a6190ddb 100644 --- a/src/libasr/codegen/asr_to_llvm.cpp +++ b/src/libasr/codegen/asr_to_llvm.cpp @@ -1687,6 +1687,21 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor tmp = list_api->pop_position(plist, pos, asr_el_type, module.get(), name2memidx); } + void generate_Reserve(ASR::expr_t* m_arg, ASR::expr_t* m_ele) { + // For now, this only handles lists + ASR::ttype_t* asr_el_type = ASRUtils::get_contained_type(ASRUtils::expr_type(m_arg)); + int64_t ptr_loads_copy = ptr_loads; + ptr_loads = 0; + this->visit_expr(*m_arg); + llvm::Value* plist = tmp; + + ptr_loads = 2; + this->visit_expr_wrapper(m_ele, true); + ptr_loads = ptr_loads_copy; + llvm::Value* n = tmp; + list_api->reserve(plist, n, asr_el_type, module.get()); + } + void generate_DictElems(ASR::expr_t* m_arg, bool key_or_value) { ASR::Dict_t* dict_type = ASR::down_cast( ASRUtils::expr_type(m_arg)); @@ -1807,6 +1822,10 @@ class ASRToLLVMVisitor : public ASR::BaseVisitor } break; } + case ASRUtils::IntrinsicFunctions::Reserve: { + generate_Reserve(x.m_args[0], x.m_args[1]); + break; + } case ASRUtils::IntrinsicFunctions::DictKeys: { generate_DictElems(x.m_args[0], 0); break; diff --git a/src/libasr/codegen/llvm_utils.cpp b/src/libasr/codegen/llvm_utils.cpp index 7fe6d28bf6..cefce251ab 100644 --- a/src/libasr/codegen/llvm_utils.cpp +++ b/src/libasr/codegen/llvm_utils.cpp @@ -4128,6 +4128,33 @@ namespace LCompilers { shift_end_point_by_one(list); } + void LLVMList::reserve(llvm::Value* list, llvm::Value* n, + ASR::ttype_t* asr_type, llvm::Module* module) { + /** + * C++ equivalent + * + * if( n > current_capacity ) { + * list_data = realloc(list_data, sizeof(el_type) * n); + * } + * + */ + llvm::Value* capacity = LLVM::CreateLoad(*builder, get_pointer_to_current_capacity(list)); + std::string type_code = ASRUtils::get_type_code(asr_type); + int type_size = std::get<1>(typecode2listtype[type_code]); + llvm::Type* el_type = std::get<2>(typecode2listtype[type_code]); + llvm_utils->create_if_else(builder->CreateICmpSGT(n, capacity), [&]() { + llvm::Value* arg_size = builder->CreateMul(llvm::ConstantInt::get(context, + llvm::APInt(32, type_size)), n); + llvm::Value* copy_data_ptr = get_pointer_to_list_data(list); + llvm::Value* copy_data = LLVM::CreateLoad(*builder, copy_data_ptr); + copy_data = LLVM::lfortran_realloc(context, *module, *builder, + copy_data, arg_size); + copy_data = builder->CreateBitCast(copy_data, el_type->getPointerTo()); + builder->CreateStore(copy_data, copy_data_ptr); + builder->CreateStore(n, get_pointer_to_current_capacity(list)); + }, []() {}); + } + void LLVMList::reverse(llvm::Value* list, llvm::Module& module) { /* Equivalent in C++: diff --git a/src/libasr/codegen/llvm_utils.h b/src/libasr/codegen/llvm_utils.h index d5d1264c8d..7efa781430 100644 --- a/src/libasr/codegen/llvm_utils.h +++ b/src/libasr/codegen/llvm_utils.h @@ -407,6 +407,9 @@ namespace LCompilers { llvm::Module* module, std::map>& name2memidx); + void reserve(llvm::Value* list, llvm::Value* n, + ASR::ttype_t* asr_type, llvm::Module* module); + void remove(llvm::Value* list, llvm::Value* item, ASR::ttype_t* item_type, llvm::Module& module); diff --git a/src/libasr/pass/intrinsic_function_registry.h b/src/libasr/pass/intrinsic_function_registry.h index 4d6b2b1691..d8fe1d173f 100644 --- a/src/libasr/pass/intrinsic_function_registry.h +++ b/src/libasr/pass/intrinsic_function_registry.h @@ -46,6 +46,7 @@ enum class IntrinsicFunctions : int64_t { Partition, ListReverse, ListPop, + Reserve, DictKeys, DictValues, SetAdd, @@ -102,6 +103,7 @@ inline std::string get_intrinsic_name(int x) { INTRINSIC_NAME_CASE(Partition) INTRINSIC_NAME_CASE(ListReverse) INTRINSIC_NAME_CASE(ListPop) + INTRINSIC_NAME_CASE(Reserve) INTRINSIC_NAME_CASE(DictKeys) INTRINSIC_NAME_CASE(DictValues) INTRINSIC_NAME_CASE(SetAdd) @@ -1262,6 +1264,55 @@ static inline ASR::asr_t* create_ListPop(Allocator& al, const Location& loc, } // namespace ListPop +namespace Reserve { + +static inline void verify_args(const ASR::IntrinsicFunction_t& x, diag::Diagnostics& diagnostics) { + ASRUtils::require_impl(x.n_args == 2, "Call to reserve must have exactly one argument", + x.base.base.loc, diagnostics); + ASRUtils::require_impl(ASR::is_a(*ASRUtils::expr_type(x.m_args[0])), + "First argument to reserve must be of list type", + x.base.base.loc, diagnostics); + ASRUtils::require_impl(ASR::is_a(*ASRUtils::expr_type(x.m_args[1])), + "Second argument to reserve must be an integer", + x.base.base.loc, diagnostics); + ASRUtils::require_impl(x.m_type == nullptr, + "Return type of reserve must be empty", + x.base.base.loc, diagnostics); +} + +static inline ASR::expr_t *eval_reserve(Allocator &/*al*/, + const Location &/*loc*/, Vec& /*args*/) { + // TODO: To be implemented for ListConstant expression + return nullptr; +} + +static inline ASR::asr_t* create_Reserve(Allocator& al, const Location& loc, + Vec& args, + const std::function err) { + if (args.size() != 2) { + err("Call to reserve must have exactly two argument", loc); + } + if (!ASR::is_a(*ASRUtils::expr_type(args[0]))) { + err("First argument to reserve must be of list type", loc); + } + if (!ASR::is_a(*ASRUtils::expr_type(args[1]))) { + err("Second argument to reserve must be an integer", loc); + } + + Vec arg_values; + arg_values.reserve(al, args.size()); + for( size_t i = 0; i < args.size(); i++ ) { + arg_values.push_back(al, ASRUtils::expr_value(args[i])); + } + ASR::expr_t* compile_time_value = eval_reserve(al, loc, arg_values); + return ASR::make_Expr_t(al, loc, + ASRUtils::EXPR(ASRUtils::make_IntrinsicFunction_t_util(al, loc, + static_cast(ASRUtils::IntrinsicFunctions::Reserve), + args.p, args.size(), 0, nullptr, compile_time_value))); +} + +} // namespace Reserve + namespace DictKeys { static inline void verify_args(const ASR::IntrinsicFunction_t& x, diag::Diagnostics& diagnostics) { @@ -3124,6 +3175,8 @@ namespace IntrinsicFunctionRegistry { {nullptr, &DictValues::verify_args}}, {static_cast(ASRUtils::IntrinsicFunctions::ListPop), {nullptr, &ListPop::verify_args}}, + {static_cast(ASRUtils::IntrinsicFunctions::Reserve), + {nullptr, &Reserve::verify_args}}, {static_cast(ASRUtils::IntrinsicFunctions::SetAdd), {nullptr, &SetAdd::verify_args}}, {static_cast(ASRUtils::IntrinsicFunctions::SetRemove), @@ -3206,6 +3259,8 @@ namespace IntrinsicFunctionRegistry { "list.reverse"}, {static_cast(ASRUtils::IntrinsicFunctions::ListPop), "list.pop"}, + {static_cast(ASRUtils::IntrinsicFunctions::Reserve), + "reserve"}, {static_cast(ASRUtils::IntrinsicFunctions::DictKeys), "dict.keys"}, {static_cast(ASRUtils::IntrinsicFunctions::DictValues), @@ -3290,6 +3345,7 @@ namespace IntrinsicFunctionRegistry { {"list.index", {&ListIndex::create_ListIndex, &ListIndex::eval_list_index}}, {"list.reverse", {&ListReverse::create_ListReverse, &ListReverse::eval_list_reverse}}, {"list.pop", {&ListPop::create_ListPop, &ListPop::eval_list_pop}}, + {"reserve", {&Reserve::create_Reserve, &Reserve::eval_reserve}}, {"dict.keys", {&DictKeys::create_DictKeys, &DictKeys::eval_dict_keys}}, {"dict.values", {&DictValues::create_DictValues, &DictValues::eval_dict_values}}, {"set.add", {&SetAdd::create_SetAdd, &SetAdd::eval_set_add}}, diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index ce3cf6ca8d..2d41145d2c 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -6511,6 +6511,9 @@ class BodyVisitor : public CommonVisitor { // Keyword arguments to be handled in make_call_helper args.reserve(al, c->n_args); visit_expr_list(c->m_args, c->n_args, args); + // TODO: Avoid overriding of user defined functions with same name as + // intrinsics like print, quit and reserve. Right now, user defined + // functions will never be considered. if (call_name == "print") { ASR::expr_t *fmt = nullptr; Vec args_expr = ASRUtils::call_arg2expr(al, args); @@ -6570,6 +6573,17 @@ class BodyVisitor : public CommonVisitor { } tmp = ASR::make_Stop_t(al, x.base.base.loc, code); return; + } else if( call_name == "reserve" ) { + ASRUtils::create_intrinsic_function create_func = + ASRUtils::IntrinsicFunctionRegistry::get_create_function("reserve"); + Vec args_exprs; args_exprs.reserve(al, args.size()); + for( size_t i = 0; i < args.size(); i++ ) { + args_exprs.push_back(al, args[i].m_value); + } + tmp = create_func(al, x.base.base.loc, args_exprs, + [&](const std::string &msg, const Location &loc) { + throw SemanticError(msg, loc); }); + return ; } ASR::symbol_t *s = current_scope->resolve_symbol(call_name); if (!s) { diff --git a/src/runtime/lpython/lpython.py b/src/runtime/lpython/lpython.py index 63bcc98c19..3ac9811c8c 100644 --- a/src/runtime/lpython/lpython.py +++ b/src/runtime/lpython/lpython.py @@ -760,6 +760,11 @@ def __lpython(*args, **kwargs): def bitnot(x, bitsize): return (~x) % (2 ** bitsize) +def reserve(data_structure, n): + if isinstance(data_structure, list): + data_structure = [None] * n + # no-op + bitnot_u8 = lambda x: bitnot(x, 8) bitnot_u16 = lambda x: bitnot(x, 16) bitnot_u32 = lambda x: bitnot(x, 32)