// Count the non-overlapping occurrences of `sub` inside `s_var`, scanning
// left to right, which is how Python's `str.count` counts.
//
// Parameters:
//   s_var - the text that is searched.
//   sub   - the pattern to look for.
//
// Returns:
//   The number of non-overlapping matches. An empty pattern yields
//   `s_var.size() + 1`, as the original code did for that case.
//
// The search is Knuth-Morris-Pratt: `lps[k]` holds the length of the longest
// proper prefix of `sub[0..k]` that is also a suffix of it, so a mismatch
// never re-reads text characters.
static inline int KMP_string_match_count(std::string &s_var, std::string &sub) {
    int str_len = s_var.size();
    int sub_len = sub.size();
    if (sub_len == 0) {
        return str_len + 1;
    }

    // Build the prefix-function (failure) table for the pattern.
    std::vector<int> lps(sub_len, 0);
    for (int i = 1, len = 0; i < sub_len;) {
        if (sub[i] == sub[len]) {
            lps[i++] = ++len;
        } else if (len != 0) {
            len = lps[len - 1];
        } else {
            lps[i++] = 0;
        }
    }

    int count = 0;
    // Stop as soon as the unread text is shorter than the unmatched pattern.
    for (int i = 0, j = 0; (str_len - i) >= (sub_len - j);) {
        if (sub[j] == s_var[i]) {
            j++, i++;
        }
        if (j == sub_len) {
            count++;
            // Restart the pattern from scratch: falling back to lps[j - 1]
            // here would let the next match reuse characters of this one and
            // count overlapping hits ("aaaa".count("aa") would become 3).
            j = 0;
        } else if (i < str_len && sub[j] != s_var[i]) {
            if (j != 0) {
                j = lps[j - 1];
            } else {
                i++;
            }
        }
    }
    return count;
}
LCOMPILERS_ASSERT(exprs_vec.reserve_called); diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index ff2ea35546..6c3fe7ea9d 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -6872,13 +6872,13 @@ class BodyVisitor : public CommonVisitor { } } else if (attr_name == "find") { if (args.size() != 1) { - throw SemanticError("str.find() takes one arguments", + throw SemanticError("str.find() takes one argument", loc); } ASR::expr_t *arg = args[0].m_value; ASR::ttype_t *type = ASRUtils::expr_type(arg); if (!ASRUtils::is_character(*type)) { - throw SemanticError("str.find() takes one arguments of type: str", + throw SemanticError("str.find() takes one argument of type: str", arg->base.loc); } if (ASRUtils::expr_value(arg) != nullptr) { @@ -6905,6 +6905,41 @@ class BodyVisitor : public CommonVisitor { tmp = make_call_helper(al, fn_div, current_scope, args, "_lpython_str_find", loc); } return; + } else if (attr_name == "count") { + if (args.size() != 1) { + throw SemanticError("str.count() takes one argument", + loc); + } + ASR::expr_t *arg = args[0].m_value; + ASR::ttype_t *type = ASRUtils::expr_type(arg); + if (!ASRUtils::is_character(*type)) { + throw SemanticError("str.count() takes one argument of type: str", + arg->base.loc); + } + if (ASRUtils::expr_value(arg) != nullptr) { + ASR::StringConstant_t* sub_str_con = ASR::down_cast(arg); + std::string sub = sub_str_con->m_s; + int res = ASRUtils::KMP_string_match_count(s_var, sub); + tmp = ASR::make_IntegerConstant_t(al, loc, res, + ASRUtils::TYPE(ASR::make_Integer_t(al,loc, 4))); + } else { + ASR::symbol_t *fn_div = resolve_intrinsic_function(loc, "_lpython_str_count"); + Vec args; + args.reserve(al, 1); + ASR::call_arg_t str_arg; + str_arg.loc = loc; + ASR::ttype_t *str_type = ASRUtils::TYPE(ASR::make_Character_t(al, loc, + 1, s_var.size(), nullptr)); + str_arg.m_value = ASRUtils::EXPR( + 
ASR::make_StringConstant_t(al, loc, s2c(al, s_var), str_type)); + ASR::call_arg_t sub_arg; + sub_arg.loc = loc; + sub_arg.m_value = arg; + args.push_back(al, str_arg); + args.push_back(al, sub_arg); + tmp = make_call_helper(al, fn_div, current_scope, args, "_lpython_str_count", loc); + } + return; } else if (attr_name == "rstrip") { if (args.size() != 0) { throw SemanticError("str.rstrip() takes no arguments", diff --git a/src/runtime/lpython_builtin.py b/src/runtime/lpython_builtin.py index 46b758c725..6bf35acbd3 100644 --- a/src/runtime/lpython_builtin.py +++ b/src/runtime/lpython_builtin.py @@ -637,17 +637,50 @@ def _lpython_str_capitalize(x: str) -> str: @overload -def _lpython_str_count(x: str, y: str) -> i32: - if(len(y) == 0): return len(x) + 1 +def _lpython_str_count(s: str, sub: str) -> i32: + s_len :i32; sub_len :i32; flag: bool; _len: i32; + count: i32; i: i32; + lps: list[i32] = [] + s_len = len(s) + sub_len = len(sub) - count: i32 = 0 - curr_char: str - i: i32 + if sub_len == 0: + return s_len + 1 + + count = 0 + + for i in range(sub_len): + lps.append(0) + + i = 1 + _len = 0 + while i < sub_len: + if sub[i] == sub[_len]: + _len += 1 + lps[i] = _len + i += 1 + else: + if _len != 0: + _len = lps[_len - 1] + else: + lps[i] = 0 + i += 1 - for i in range(len(x)): - curr_char = x[i] - if curr_char == y[0]: - count += i32(x[i:i+len(y)] == y) + j: i32 + j = 0 + i = 0 + while (s_len - i) >= (sub_len - j): + if sub[j] == s[i]: + i += 1 + j += 1 + if j == sub_len: + count += 1 + j = lps[j - 1] + elif i < s_len and sub[j] != s[i]: + if j != 0: + j = lps[j - 1] + else: + i = i + 1 return count