@@ -1298,86 +1298,68 @@ fn _check_kinds() {
12981298 is_sync :: < SslStream < TcpStream > > ( ) ;
12991299}
13001300
1301- #[ test]
13021301#[ cfg( ossl111) ]
1303- fn stateless ( ) {
1304- use super :: SslOptions ;
1305-
1306- #[ derive( Debug ) ]
1307- struct MemoryStream {
1308- incoming : io:: Cursor < Vec < u8 > > ,
1309- outgoing : Vec < u8 > ,
1310- }
1311-
1312- impl MemoryStream {
1313- pub fn new ( ) -> Self {
1314- Self {
1315- incoming : io:: Cursor :: new ( Vec :: new ( ) ) ,
1316- outgoing : Vec :: new ( ) ,
1317- }
1318- }
1319-
1320- pub fn extend_incoming ( & mut self , data : & [ u8 ] ) {
1321- self . incoming . get_mut ( ) . extend_from_slice ( data) ;
1322- }
1302+ #[ derive( Debug ) ]
1303+ struct MemoryStream {
1304+ incoming : io:: Cursor < Vec < u8 > > ,
1305+ outgoing : Vec < u8 > ,
1306+ }
13231307
1324- pub fn take_outgoing ( & mut self ) -> Outgoing < ' _ > {
1325- Outgoing ( & mut self . outgoing )
1308+ #[ cfg( ossl111) ]
1309+ impl MemoryStream {
1310+ fn new ( ) -> Self {
1311+ Self {
1312+ incoming : io:: Cursor :: new ( Vec :: new ( ) ) ,
1313+ outgoing : Vec :: new ( ) ,
13261314 }
13271315 }
13281316
1329- impl Read for MemoryStream {
1330- fn read ( & mut self , buf : & mut [ u8 ] ) -> io:: Result < usize > {
1331- let n = self . incoming . read ( buf) ?;
1332- if self . incoming . position ( ) == self . incoming . get_ref ( ) . len ( ) as u64 {
1333- self . incoming . set_position ( 0 ) ;
1334- self . incoming . get_mut ( ) . clear ( ) ;
1335- }
1336- if n == 0 {
1337- return Err ( io:: Error :: new (
1338- io:: ErrorKind :: WouldBlock ,
1339- "no data available" ,
1340- ) ) ;
1341- }
1342- Ok ( n)
1343- }
1317+ fn extend_incoming ( & mut self , data : & [ u8 ] ) {
1318+ self . incoming . get_mut ( ) . extend_from_slice ( data) ;
13441319 }
13451320
1346- impl Write for MemoryStream {
1347- fn write ( & mut self , buf : & [ u8 ] ) -> io:: Result < usize > {
1348- self . outgoing . write ( buf)
1349- }
1350-
1351- fn flush ( & mut self ) -> io:: Result < ( ) > {
1352- Ok ( ( ) )
1353- }
1321+ fn take_outgoing ( & mut self ) -> Vec < u8 > {
1322+ mem:: take ( & mut self . outgoing )
13541323 }
1324+ }
13551325
1356- pub struct Outgoing < ' a > ( & ' a mut Vec < u8 > ) ;
1357-
1358- impl Drop for Outgoing < ' _ > {
1359- fn drop ( & mut self ) {
1360- self . 0 . clear ( ) ;
1326+ #[ cfg( ossl111) ]
1327+ impl Read for MemoryStream {
1328+ fn read ( & mut self , buf : & mut [ u8 ] ) -> io:: Result < usize > {
1329+ let n = self . incoming . read ( buf) ?;
1330+ if self . incoming . position ( ) == self . incoming . get_ref ( ) . len ( ) as u64 {
1331+ self . incoming . set_position ( 0 ) ;
1332+ self . incoming . get_mut ( ) . clear ( ) ;
13611333 }
1362- }
1363-
1364- impl :: std:: ops:: Deref for Outgoing < ' _ > {
1365- type Target = [ u8 ] ;
1366- fn deref ( & self ) -> & [ u8 ] {
1367- self . 0
1334+ if n == 0 {
1335+ return Err ( io:: Error :: new (
1336+ io:: ErrorKind :: WouldBlock ,
1337+ "no data available" ,
1338+ ) ) ;
13681339 }
1340+ Ok ( n)
13691341 }
1342+ }
13701343
1371- impl AsRef < [ u8 ] > for Outgoing < ' _ > {
1372- fn as_ref ( & self ) -> & [ u8 ] {
1373- self . 0
1374- }
1344+ # [ cfg ( ossl111 ) ]
1345+ impl Write for MemoryStream {
1346+ fn write ( & mut self , buf : & [ u8 ] ) -> io :: Result < usize > {
1347+ self . outgoing . write ( buf )
13751348 }
13761349
1377- fn send ( from : & mut MemoryStream , to : & mut MemoryStream ) {
1378- to . extend_incoming ( & from . take_outgoing ( ) ) ;
1350+ fn flush ( & mut self ) -> io :: Result < ( ) > {
1351+ Ok ( ( ) )
13791352 }
1353+ }
13801354
1355+ #[ cfg( ossl111) ]
1356+ fn send ( from : & mut MemoryStream , to : & mut MemoryStream ) {
1357+ to. extend_incoming ( & from. take_outgoing ( ) ) ;
1358+ }
1359+
1360+ #[ test]
1361+ #[ cfg( ossl111) ]
1362+ fn stateless ( ) {
13811363 //
13821364 // Setup
13831365 //
@@ -1467,6 +1449,149 @@ fn psk_ciphers() {
14671449 assert ! ( CLIENT_CALLED . load( Ordering :: SeqCst ) ) ;
14681450}
14691451
1452+ // Regression tests: the PSK/cookie trampolines used to forward the callback's
1453+ // returned `usize` to OpenSSL without checking it against the slice length.
1454+
1455+ #[ cfg( not( osslconf = "OPENSSL_NO_PSK" ) ) ]
1456+ #[ cfg( target_pointer_width = "64" ) ]
1457+ #[ test]
1458+ fn psk_client_cb_oversize_psk_len_rejected ( ) {
1459+ // Without the fix, `psk_len as u32` truncates the returned length; the low
1460+ // 32 bits match `PSK.len()` and slip past OpenSSL's `> PSK_MAX_PSK_LEN`
1461+ // check. (Rust's slice length equals `PSK_MAX_PSK_LEN`, so truncation is
1462+ // the only way to differentiate — hence the 64-bit guard.)
1463+ const CIPHER : & str = "PSK-AES256-CBC-SHA" ;
1464+ const PSK : & [ u8 ] = b"thisisaverysecurekey" ;
1465+ const CLIENT_IDENT : & [ u8 ] = b"thisisaclient" ;
1466+
1467+ let mut server = Server :: builder ( ) ;
1468+ server. ctx ( ) . set_cipher_list ( CIPHER ) . unwrap ( ) ;
1469+ server. ctx ( ) . set_psk_server_callback ( |_, _identity, psk| {
1470+ psk[ ..PSK . len ( ) ] . copy_from_slice ( PSK ) ;
1471+ Ok ( PSK . len ( ) )
1472+ } ) ;
1473+ server. should_error ( ) ;
1474+ let server = server. build ( ) ;
1475+
1476+ let mut client = server. client ( ) ;
1477+ #[ cfg( any( boringssl, ossl111, awslc) ) ]
1478+ client. ctx ( ) . set_options ( SslOptions :: NO_TLSV1_3 ) ;
1479+ client. ctx ( ) . set_cipher_list ( CIPHER ) . unwrap ( ) ;
1480+ client
1481+ . ctx ( )
1482+ . set_psk_client_callback ( move |_, _, identity, psk| {
1483+ identity[ ..CLIENT_IDENT . len ( ) ] . copy_from_slice ( CLIENT_IDENT ) ;
1484+ identity[ CLIENT_IDENT . len ( ) ] = 0 ;
1485+ psk[ ..PSK . len ( ) ] . copy_from_slice ( PSK ) ;
1486+ Ok ( ( u32:: MAX as usize ) + 1 + PSK . len ( ) )
1487+ } ) ;
1488+
1489+ client. connect_err ( ) ;
1490+ }
1491+
1492+ #[ cfg( not( osslconf = "OPENSSL_NO_PSK" ) ) ]
1493+ #[ cfg( target_pointer_width = "64" ) ]
1494+ #[ test]
1495+ fn psk_server_cb_oversize_psk_len_rejected ( ) {
1496+ // Server-side counterpart — same `as u32` truncation bypass.
1497+ const CIPHER : & str = "PSK-AES256-CBC-SHA" ;
1498+ const PSK : & [ u8 ] = b"thisisaverysecurekey" ;
1499+ const CLIENT_IDENT : & [ u8 ] = b"thisisaclient" ;
1500+
1501+ let mut server = Server :: builder ( ) ;
1502+ server. ctx ( ) . set_cipher_list ( CIPHER ) . unwrap ( ) ;
1503+ server. ctx ( ) . set_psk_server_callback ( |_, _identity, psk| {
1504+ psk[ ..PSK . len ( ) ] . copy_from_slice ( PSK ) ;
1505+ Ok ( ( u32:: MAX as usize ) + 1 + PSK . len ( ) )
1506+ } ) ;
1507+ server. should_error ( ) ;
1508+ let server = server. build ( ) ;
1509+
1510+ let mut client = server. client ( ) ;
1511+ #[ cfg( any( boringssl, ossl111, awslc) ) ]
1512+ client. ctx ( ) . set_options ( SslOptions :: NO_TLSV1_3 ) ;
1513+ client. ctx ( ) . set_cipher_list ( CIPHER ) . unwrap ( ) ;
1514+ client
1515+ . ctx ( )
1516+ . set_psk_client_callback ( move |_, _, identity, psk| {
1517+ identity[ ..CLIENT_IDENT . len ( ) ] . copy_from_slice ( CLIENT_IDENT ) ;
1518+ identity[ CLIENT_IDENT . len ( ) ] = 0 ;
1519+ psk[ ..PSK . len ( ) ] . copy_from_slice ( PSK ) ;
1520+ Ok ( PSK . len ( ) )
1521+ } ) ;
1522+
1523+ client. connect_err ( ) ;
1524+ }
1525+
1526+ #[ test]
1527+ #[ cfg( ossl111) ]
1528+ fn stateless_cookie_cb_oversize_length_rejected ( ) {
1529+ // Callback claims a length past the slice end. The fix makes the
1530+ // trampoline report failure so stateless() errors cleanly.
1531+ let mut client_ctx = SslContext :: builder ( SslMethod :: tls ( ) ) . unwrap ( ) ;
1532+ client_ctx. clear_options ( SslOptions :: ENABLE_MIDDLEBOX_COMPAT ) ;
1533+ let mut client_stream =
1534+ SslStream :: new ( Ssl :: new ( & client_ctx. build ( ) ) . unwrap ( ) , MemoryStream :: new ( ) ) . unwrap ( ) ;
1535+
1536+ let mut server_ctx = SslContext :: builder ( SslMethod :: tls ( ) ) . unwrap ( ) ;
1537+ server_ctx
1538+ . set_certificate_file ( Path :: new ( "test/cert.pem" ) , SslFiletype :: PEM )
1539+ . unwrap ( ) ;
1540+ server_ctx
1541+ . set_private_key_file ( Path :: new ( "test/key.pem" ) , SslFiletype :: PEM )
1542+ . unwrap ( ) ;
1543+ server_ctx. set_stateless_cookie_generate_cb ( |_, buf| Ok ( buf. len ( ) + 1 ) ) ;
1544+ server_ctx. set_stateless_cookie_verify_cb ( |_, _| true ) ;
1545+ let mut server_stream =
1546+ SslStream :: new ( Ssl :: new ( & server_ctx. build ( ) ) . unwrap ( ) , MemoryStream :: new ( ) ) . unwrap ( ) ;
1547+
1548+ client_stream. connect ( ) . unwrap_err ( ) ;
1549+ send ( client_stream. get_mut ( ) , server_stream. get_mut ( ) ) ;
1550+ assert ! ( server_stream. stateless( ) . is_err( ) ) ;
1551+ }
1552+
1553+ #[ test]
1554+ #[ cfg( not( any( boringssl, awslc) ) ) ]
1555+ fn dtls_cookie_generate_cb_oversize_length_rejected ( ) {
1556+ // Rust hands the callback `DTLS1_COOKIE_LENGTH - 1` bytes but OpenSSL's
1557+ // internal cookie buffer is `DTLS1_COOKIE_LENGTH`; returning `buf.len() + 1`
1558+ // passes OpenSSL's `cookie_leni > sizeof(s->d1->cookie)` check. Without the
1559+ // fix, the server sends a HelloVerifyRequest containing one unwritten byte
1560+ // and the verify callback fires on the client's echo.
1561+ static VERIFY_CALLED : AtomicBool = AtomicBool :: new ( false ) ;
1562+ VERIFY_CALLED . store ( false , Ordering :: SeqCst ) ;
1563+
1564+ let listener = TcpListener :: bind ( "127.0.0.1:0" ) . unwrap ( ) ;
1565+ let addr = listener. local_addr ( ) . unwrap ( ) ;
1566+
1567+ let server = thread:: spawn ( move || {
1568+ let stream = listener. accept ( ) . unwrap ( ) . 0 ;
1569+ let mut ctx = SslContext :: builder ( SslMethod :: dtls ( ) ) . unwrap ( ) ;
1570+ ctx. set_certificate_file ( Path :: new ( "test/cert.pem" ) , SslFiletype :: PEM )
1571+ . unwrap ( ) ;
1572+ ctx. set_private_key_file ( Path :: new ( "test/key.pem" ) , SslFiletype :: PEM )
1573+ . unwrap ( ) ;
1574+ ctx. set_options ( SslOptions :: COOKIE_EXCHANGE ) ;
1575+ ctx. set_cookie_generate_cb ( |_, buf| Ok ( buf. len ( ) + 1 ) ) ;
1576+ ctx. set_cookie_verify_cb ( |_, _| {
1577+ VERIFY_CALLED . store ( true , Ordering :: SeqCst ) ;
1578+ true
1579+ } ) ;
1580+ let mut ssl = Ssl :: new ( & ctx. build ( ) ) . unwrap ( ) ;
1581+ ssl. set_mtu ( 1500 ) . unwrap ( ) ;
1582+ let _ = ssl. accept ( stream) ;
1583+ } ) ;
1584+
1585+ let stream = TcpStream :: connect ( addr) . unwrap ( ) ;
1586+ let ctx = SslContext :: builder ( SslMethod :: dtls ( ) ) . unwrap ( ) ;
1587+ let mut ssl = Ssl :: new ( & ctx. build ( ) ) . unwrap ( ) ;
1588+ ssl. set_mtu ( 1500 ) . unwrap ( ) ;
1589+ let _ = ssl. connect ( stream) ;
1590+
1591+ server. join ( ) . unwrap ( ) ;
1592+ assert ! ( !VERIFY_CALLED . load( Ordering :: SeqCst ) ) ;
1593+ }
1594+
14701595#[ test]
14711596fn sni_callback_swapped_ctx ( ) {
14721597 static CALLED_BACK : AtomicBool = AtomicBool :: new ( false ) ;
0 commit comments