diff --git a/Lib/test/test_codecs.py b/Lib/test/test_codecs.py index a12e5893dc..f9b0a00019 100644 --- a/Lib/test/test_codecs.py +++ b/Lib/test/test_codecs.py @@ -878,8 +878,6 @@ class UTF16LETest(ReadTest, unittest.TestCase): encoding = "utf-16-le" ill_formed_sequence = b"\x80\xdc" - # TODO: RUSTPYTHON - @unittest.expectedFailure def test_partial(self): self.check_partial( "\x00\xff\u0100\uffff\U00010000", @@ -922,10 +920,6 @@ def test_nonbmp(self): self.assertEqual(b'\x00\xd8\x03\xde'.decode(self.encoding), "\U00010203") - # TODO: RUSTPYTHON - @unittest.expectedFailure - def test_incremental_surrogatepass(self): - super().test_incremental_surrogatepass() class UTF16BETest(ReadTest, unittest.TestCase): encoding = "utf-16-be" diff --git a/common/src/encodings.rs b/common/src/encodings.rs index 39ca266126..e0075ad1eb 100644 --- a/common/src/encodings.rs +++ b/common/src/encodings.rs @@ -633,3 +633,176 @@ pub mod ascii { ) } } + +pub mod utf16_le { + use super::*; + + pub const ENCODING_NAME: &str = "utf-16-le"; + + pub fn encode(mut ctx: Ctx, errors: &E) -> Result, Ctx::Error> + where + Ctx: EncodeContext, + E: EncodeErrorHandler, + { + let mut out = Vec::::new(); + loop { + let data = ctx.remaining_data(); + let error_info = { + let mut iter = iter_code_points(data); + iter.find(|(_, c)| c.to_u32() > 0x10FFFF) + }; + let Some((i, ch)) = error_info else { + break; + }; + + // Add valid part up to the error + for ch in data[..i.bytes].code_points() { + let ch_u32 = ch.to_u32(); + if ch_u32 <= 0xFFFF { + out.extend_from_slice(&(ch_u32 as u16).to_le_bytes()); + } else if ch_u32 <= 0x10FFFF { + let code = ch_u32 - 0x10000; + let high = 0xD800 + (code >> 10); + let low = 0xDC00 + (code & 0x3FF); + out.extend_from_slice(&(high as u16).to_le_bytes()); + out.extend_from_slice(&(low as u16).to_le_bytes()); + } + } + + let err_start = ctx.position() + i; + let err_end = StrSize { + bytes: i.bytes + ch.len_wtf8(), + chars: i.chars + 1, + }; + let err_end = ctx.position() + err_end; + let replace = + ctx.handle_error(errors, err_start..err_end, Some("surrogates not allowed"))?; + match replace { + EncodeReplace::Str(s) => { + // Re-encode the replacement string + for cp in s.as_ref().code_points() { + let cp_u32 = cp.to_u32(); + if cp_u32 <= 0xFFFF { + out.extend_from_slice(&(cp_u32 as u16).to_le_bytes()); + } else if cp_u32 <= 0x10FFFF { + let code = cp_u32 - 0x10000; + let high = 0xD800 + (code >> 10); + let low = 0xDC00 + (code & 0x3FF); + out.extend_from_slice(&(high as u16).to_le_bytes()); + out.extend_from_slice(&(low as u16).to_le_bytes()); + } + } + } + EncodeReplace::Bytes(b) => { + out.extend_from_slice(b.as_ref()); + } + } + } + + // Process all remaining data + for ch in ctx.remaining_data().code_points() { + let ch_u32 = ch.to_u32(); + if ch_u32 <= 0xFFFF { + out.extend_from_slice(&(ch_u32 as u16).to_le_bytes()); + } else if ch_u32 <= 0x10FFFF { + let code = ch_u32 - 0x10000; + let high = 0xD800 + (code >> 10); + let low = 0xDC00 + (code & 0x3FF); + out.extend_from_slice(&(high as u16).to_le_bytes()); + out.extend_from_slice(&(low as u16).to_le_bytes()); + } + } + Ok(out) + } + + pub fn decode>( + mut ctx: Ctx, + errors: &E, + final_decode: bool, + ) -> Result<(Wtf8Buf, usize), Ctx::Error> { + let mut out = Wtf8Buf::new(); + + while ctx.remaining_data().len() >= 2 { + let data = ctx.remaining_data(); + let ch = u16::from_le_bytes([data[0], data[1]]); + + if ch < 0xD800 || ch > 0xDFFF { + // BMP character + if let Some(c) = char::from_u32(ch as u32) { + out.push_str(&c.to_string()); + ctx.advance(2); + } else { + let pos = ctx.position(); + let replace = + ctx.handle_error(errors, pos..pos + 2, Some("invalid character"))?; + out.push_wtf8(replace.as_ref()); + // Don't advance here, the error handler already positioned us + } + } else if ch >= 0xD800 && ch <= 0xDBFF { + // High surrogate + if data.len() < 4 { + if final_decode { + let pos = ctx.position(); + let replace = + ctx.handle_error(errors, pos..pos + 2, Some("unexpected end of data"))?; + out.push_wtf8(replace.as_ref()); + // Don't advance here, the error handler already positioned us + } else { + // In partial mode, stop here and return what we have + break; + } + } else { + let ch2 = u16::from_le_bytes([data[2], data[3]]); + if ch2 >= 0xDC00 && ch2 <= 0xDFFF { + // Valid surrogate pair + let code = (((ch & 0x3FF) as u32) << 10) | ((ch2 & 0x3FF) as u32); + let code_point = code + 0x10000; + if let Some(c) = char::from_u32(code_point) { + out.push_str(&c.to_string()); + ctx.advance(4); + } else { + let pos = ctx.position(); + let replace = ctx.handle_error( + errors, + pos..pos + 4, + Some("invalid surrogate pair"), + )?; + out.push_wtf8(replace.as_ref()); + // Don't advance here, the error handler already positioned us + } + } else { + // Invalid surrogate pair + let pos = ctx.position(); + let replace = ctx.handle_error( + errors, + pos..pos + 2, + Some("illegal UTF-16 surrogate"), + )?; + out.push_wtf8(replace.as_ref()); + // Don't advance here, the error handler already positioned us + } + } + } else { + // Low surrogate without high surrogate + let pos = ctx.position(); + let replace = + ctx.handle_error(errors, pos..pos + 2, Some("illegal UTF-16 surrogate"))?; + out.push_wtf8(replace.as_ref()); + // Don't advance here, the error handler already positioned us + } + } + + // Handle remaining single byte + if ctx.remaining_data().len() == 1 { + if final_decode { + let pos = ctx.position(); + let replace = ctx.handle_error(errors, pos..pos + 1, Some("truncated data"))?; + out.push_wtf8(replace.as_ref()); + // Don't advance here, the error handler already positioned us + } + // In partial mode, just leave it for next call + } + + Ok((out, ctx.position())) + } +} diff --git a/vm/src/stdlib/codecs.rs b/vm/src/stdlib/codecs.rs index c0a091bcf8..a739886941 100644 --- a/vm/src/stdlib/codecs.rs +++ b/vm/src/stdlib/codecs.rs @@ -286,12 +286,12 @@ mod _codecs { delegate_pycodecs!(charmap_build, args, vm) } #[pyfunction] - fn utf_16_le_encode(args: FuncArgs, vm: &VirtualMachine) -> PyResult { - delegate_pycodecs!(utf_16_le_encode, args, vm) + fn utf_16_le_encode(args: EncodeArgs, vm: &VirtualMachine) -> EncodeResult { + do_codec!(utf16_le::encode, args, vm) } #[pyfunction] - fn utf_16_le_decode(args: FuncArgs, vm: &VirtualMachine) -> PyResult { - delegate_pycodecs!(utf_16_le_decode, args, vm) + fn utf_16_le_decode(args: DecodeArgs, vm: &VirtualMachine) -> DecodeResult { + do_codec!(utf16_le::decode, args, vm) } #[pyfunction] fn utf_16_be_encode(args: FuncArgs, vm: &VirtualMachine) -> PyResult {