Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit b2d4fc0

Browse files
committed
objstr: Make *strip() accept bytes.
1 parent ce6c101 commit b2d4fc0

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

py/objstr.c

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,8 @@ enum { LSTRIP, RSTRIP, STRIP };
540540

541541
STATIC mp_obj_t str_uni_strip(int type, uint n_args, const mp_obj_t *args) {
542542
assert(1 <= n_args && n_args <= 2);
543-
assert(MP_OBJ_IS_STR(args[0]));
543+
assert(is_str_or_bytes(args[0]));
544+
const mp_obj_type_t *self_type = mp_obj_get_type(args[0]);
544545

545546
const byte *chars_to_del;
546547
uint chars_to_del_len;
@@ -550,7 +551,9 @@ STATIC mp_obj_t str_uni_strip(int type, uint n_args, const mp_obj_t *args) {
550551
chars_to_del = whitespace;
551552
chars_to_del_len = sizeof(whitespace);
552553
} else {
553-
assert(MP_OBJ_IS_STR(args[1]));
554+
if (mp_obj_get_type(args[1]) != self_type) {
555+
arg_type_mixup();
556+
}
554557
GET_STR_DATA_LEN(args[1], s, l);
555558
chars_to_del = s;
556559
chars_to_del_len = l;
@@ -594,7 +597,7 @@ STATIC mp_obj_t str_uni_strip(int type, uint n_args, const mp_obj_t *args) {
594597
assert(last_good_char_pos >= first_good_char_pos);
595598
//+1 to accomodate the last character
596599
machine_uint_t stripped_len = last_good_char_pos - first_good_char_pos + 1;
597-
return mp_obj_new_str(orig_str + first_good_char_pos, stripped_len, false);
600+
return str_new(self_type, orig_str + first_good_char_pos, stripped_len);
598601
}
599602

600603
STATIC mp_obj_t str_strip(uint n_args, const mp_obj_t *args) {

tests/basics/string_strip.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,13 @@
1010

1111
print(' spacious '.rstrip())
1212
print('mississippi'.rstrip('ipz'))
13+
14+
print(b'mississippi'.rstrip(b'ipz'))
15+
try:
16+
print(b'mississippi'.rstrip('ipz'))
17+
except TypeError:
18+
print("TypeError")
19+
try:
20+
print('mississippi'.rstrip(b'ipz'))
21+
except TypeError:
22+
print("TypeError")

0 commit comments

Comments
 (0)