Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 5e5d69b

Browse files
committed
objstr: Make .join() support bytes.
1 parent 7e7940c commit 5e5d69b

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

py/objstr.c

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,8 @@ STATIC mp_obj_t str_subscr(mp_obj_t self_in, mp_obj_t index, mp_obj_t value) {
357357
}
358358

359359
STATIC mp_obj_t str_join(mp_obj_t self_in, mp_obj_t arg) {
360-
assert(MP_OBJ_IS_STR(self_in));
360+
assert(is_str_or_bytes(self_in));
361+
const mp_obj_type_t *self_type = mp_obj_get_type(self_in);
361362

362363
// get separation string
363364
GET_STR_DATA_LEN(self_in, sep_str, sep_len);
@@ -379,8 +380,9 @@ STATIC mp_obj_t str_join(mp_obj_t self_in, mp_obj_t arg) {
379380
// count required length
380381
int required_len = 0;
381382
for (int i = 0; i < seq_len; i++) {
382-
if (!MP_OBJ_IS_STR(seq_items[i])) {
383-
nlr_raise(mp_obj_new_exception_msg(&mp_type_TypeError, "join expected a list of str's"));
383+
if (mp_obj_get_type(seq_items[i]) != self_type) {
384+
nlr_raise(mp_obj_new_exception_msg(&mp_type_TypeError,
385+
"join expects a list of str/bytes objects consistent with self object"));
384386
}
385387
if (i > 0) {
386388
required_len += sep_len;
@@ -391,7 +393,7 @@ STATIC mp_obj_t str_join(mp_obj_t self_in, mp_obj_t arg) {
391393

392394
// make joined string
393395
byte *data;
394-
mp_obj_t joined_str = mp_obj_str_builder_start(mp_obj_get_type(self_in), required_len, &data);
396+
mp_obj_t joined_str = mp_obj_str_builder_start(self_type, required_len, &data);
395397
for (int i = 0; i < seq_len; i++) {
396398
if (i > 0) {
397399
memcpy(data, sep_str, sep_len);

tests/basics/string-join.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,15 @@
1010
print(''.join('abc'))
1111
print(','.join('abc'))
1212
print(','.join('abc' for i in range(5)))
13+
14+
print(b','.join([b'abc', b'123']))
15+
16+
try:
17+
print(b','.join(['abc', b'123']))
18+
except TypeError:
19+
print("TypeError")
20+
21+
try:
22+
print(','.join([b'abc', b'123']))
23+
except TypeError:
24+
print("TypeError")

0 commit comments

Comments
 (0)