@@ -357,7 +357,8 @@ STATIC mp_obj_t str_subscr(mp_obj_t self_in, mp_obj_t index, mp_obj_t value) {
357
357
}
358
358
359
359
STATIC mp_obj_t str_join (mp_obj_t self_in , mp_obj_t arg ) {
360
- assert (MP_OBJ_IS_STR (self_in ));
360
+ assert (is_str_or_bytes (self_in ));
361
+ const mp_obj_type_t * self_type = mp_obj_get_type (self_in );
361
362
362
363
// get separation string
363
364
GET_STR_DATA_LEN (self_in , sep_str , sep_len );
@@ -379,8 +380,9 @@ STATIC mp_obj_t str_join(mp_obj_t self_in, mp_obj_t arg) {
379
380
// count required length
380
381
int required_len = 0 ;
381
382
for (int i = 0 ; i < seq_len ; i ++ ) {
382
- if (!MP_OBJ_IS_STR (seq_items [i ])) {
383
- nlr_raise (mp_obj_new_exception_msg (& mp_type_TypeError , "join expected a list of str's" ));
383
+ if (mp_obj_get_type (seq_items [i ]) != self_type ) {
384
+ nlr_raise (mp_obj_new_exception_msg (& mp_type_TypeError ,
385
+ "join expects a list of str/bytes objects consistent with self object" ));
384
386
}
385
387
if (i > 0 ) {
386
388
required_len += sep_len ;
@@ -391,7 +393,7 @@ STATIC mp_obj_t str_join(mp_obj_t self_in, mp_obj_t arg) {
391
393
392
394
// make joined string
393
395
byte * data ;
394
- mp_obj_t joined_str = mp_obj_str_builder_start (mp_obj_get_type ( self_in ) , required_len , & data );
396
+ mp_obj_t joined_str = mp_obj_str_builder_start (self_type , required_len , & data );
395
397
for (int i = 0 ; i < seq_len ; i ++ ) {
396
398
if (i > 0 ) {
397
399
memcpy (data , sep_str , sep_len );
0 commit comments