From b620f037288bdd93bf1e192d28ce1ce688cd2ff2 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Mon, 22 Apr 2024 10:39:54 +0900 Subject: [PATCH 1/2] Update socket from CPython 3.12.2 --- Lib/socket.py | 42 +++-- Lib/test/test_socket.py | 361 +++++++++++++++++++++++++++++++--------- 2 files changed, 310 insertions(+), 93 deletions(-) diff --git a/Lib/socket.py b/Lib/socket.py index 63ba0acc90..42ee130773 100644 --- a/Lib/socket.py +++ b/Lib/socket.py @@ -13,7 +13,7 @@ socketpair() -- create a pair of new socket objects [*] fromfd() -- create a socket object from an open file descriptor [*] send_fds() -- Send file descriptor to the socket. -recv_fds() -- Recieve file descriptors from the socket. +recv_fds() -- Receive file descriptors from the socket. fromshare() -- create a socket object from data received from socket.share() [*] gethostname() -- return the current hostname gethostbyname() -- map a hostname to its IP number @@ -28,6 +28,7 @@ socket.setdefaulttimeout() -- set the default timeout value create_connection() -- connects to an address, with an optional timeout and optional source address. +create_server() -- create a TCP socket and bind it to a specified address. [*] not available on all platforms! @@ -122,7 +123,7 @@ def _intenum_converter(value, enum_klass): errorTab[10014] = "A fault occurred on the network??" # WSAEFAULT errorTab[10022] = "An invalid operation was attempted." errorTab[10024] = "Too many open files." - errorTab[10035] = "The socket operation would block" + errorTab[10035] = "The socket operation would block." errorTab[10036] = "A blocking operation is already in progress." errorTab[10037] = "Operation already in progress." errorTab[10038] = "Socket operation on nonsocket." @@ -254,17 +255,18 @@ def __repr__(self): self.type, self.proto) if not closed: + # getsockname and getpeername may not be available on WASI. try: laddr = self.getsockname() if laddr: s += ", laddr=%s" % str(laddr) - except error: + except (error, AttributeError): pass try: raddr = self.getpeername() if raddr: s += ", raddr=%s" % str(raddr) - except error: + except (error, AttributeError): pass s += '>' return s @@ -380,7 +382,7 @@ def _sendfile_use_sendfile(self, file, offset=0, count=None): if timeout and not selector_select(timeout): raise TimeoutError('timed out') if count: - blocksize = count - total_sent + blocksize = min(count - total_sent, blocksize) if blocksize <= 0: break try: @@ -783,11 +785,11 @@ def getfqdn(name=''): First the hostname returned by gethostbyaddr() is checked, then possibly existing aliases. In case no FQDN is available and `name` - was given, it is returned unchanged. If `name` was empty or '0.0.0.0', + was given, it is returned unchanged. If `name` was empty, '0.0.0.0' or '::', hostname from gethostname() is returned. """ name = name.strip() - if not name or name == '0.0.0.0': + if not name or name in ('0.0.0.0', '::'): name = gethostname() try: hostname, aliases, ipaddrs = gethostbyaddr(name) @@ -806,7 +808,7 @@ def getfqdn(name=''): _GLOBAL_DEFAULT_TIMEOUT = object() def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT, - source_address=None): + source_address=None, *, all_errors=False): """Connect to *address* and return the socket object. Convenience function. Connect to *address* (a 2-tuple ``(host, @@ -816,11 +818,13 @@ def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT, global default timeout setting returned by :func:`getdefaulttimeout` is used. If *source_address* is set it must be a tuple of (host, port) for the socket to bind as a source address before making the connection. - A host of '' or port 0 tells the OS to use the default. + A host of '' or port 0 tells the OS to use the default. When a connection + cannot be created, raises the last error if *all_errors* is False, + and an ExceptionGroup of all errors if *all_errors* is True. """ host, port = address - err = None + exceptions = [] for res in getaddrinfo(host, port, 0, SOCK_STREAM): af, socktype, proto, canonname, sa = res sock = None @@ -832,20 +836,24 @@ def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT, sock.bind(source_address) sock.connect(sa) # Break explicitly a reference cycle - err = None + exceptions.clear() return sock - except error as _: - err = _ + except error as exc: + if not all_errors: + exceptions.clear() # raise only the last error + exceptions.append(exc) if sock is not None: sock.close() - if err is not None: + if len(exceptions): try: - raise err + if not all_errors: + raise exceptions[0] + raise ExceptionGroup("create_connection failed", exceptions) finally: # Break explicitly a reference cycle - err = None + exceptions.clear() else: raise error("getaddrinfo returns an empty list") @@ -902,7 +910,7 @@ def create_server(address, *, family=AF_INET, backlog=None, reuse_port=False, # address, effectively preventing this one from accepting # connections. Also, it may set the process in a state where # it'll no longer respond to any signals or graceful kills. - # See: msdn2.microsoft.com/en-us/library/ms740621(VS.85).aspx + # See: https://learn.microsoft.com/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse if os.name not in ('nt', 'cygwin') and \ hasattr(_socket, 'SO_REUSEADDR'): try: diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index 0e3eb08b82..d448101dcd 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -4,30 +4,31 @@ from test.support import socket_helper from test.support import threading_helper +import _thread as thread +import array +import contextlib import errno +import gc import io import itertools -import socket -import select -import tempfile -import time -import traceback -import queue -import sys -import os -import platform -import array -import contextlib -from weakref import proxy -import signal import math +import os import pickle -import struct +import platform +import queue import random -import shutil +import re +import select +import signal +import socket import string -import _thread as thread +import struct +import sys +import tempfile import threading +import time +import traceback +from weakref import proxy try: import multiprocessing except ImportError: @@ -37,12 +38,15 @@ except ImportError: fcntl = None +support.requires_working_socket(module=True) + HOST = socket_helper.HOST # test unicode string and carriage return MSG = 'Michael Gilfix was here\u1234\r\n'.encode('utf-8') VSOCKPORT = 1234 AIX = platform.system() == "AIX" +WSL = "microsoft-standard-WSL" in platform.release() try: import _socket @@ -141,6 +145,17 @@ def _have_socket_bluetooth(): return True +def _have_socket_hyperv(): + """Check whether AF_HYPERV sockets are supported on this host.""" + try: + s = socket.socket(socket.AF_HYPERV, socket.SOCK_STREAM, socket.HV_PROTOCOL_RAW) + except (AttributeError, OSError): + return False + else: + s.close() + return True + + @contextlib.contextmanager def socket_setdefaulttimeout(timeout): old_timeout = socket.getdefaulttimeout() @@ -169,6 +184,8 @@ def socket_setdefaulttimeout(timeout): HAVE_SOCKET_BLUETOOTH = _have_socket_bluetooth() +HAVE_SOCKET_HYPERV = _have_socket_hyperv() + # Size in bytes of the int type SIZEOF_INT = array.array("i").itemsize @@ -199,24 +216,6 @@ def setUp(self): self.serv = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDPLITE) self.port = socket_helper.bind_port(self.serv) -class ThreadSafeCleanupTestCase: - """Subclass of unittest.TestCase with thread-safe cleanup methods. - - This subclass protects the addCleanup() and doCleanups() methods - with a recursive lock. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._cleanup_lock = threading.RLock() - - def addCleanup(self, *args, **kwargs): - with self._cleanup_lock: - return super().addCleanup(*args, **kwargs) - - def doCleanups(self, *args, **kwargs): - with self._cleanup_lock: - return super().doCleanups(*args, **kwargs) class SocketCANTest(unittest.TestCase): @@ -336,9 +335,7 @@ def serverExplicitReady(self): self.server_ready.set() def _setUp(self): - self.wait_threads = threading_helper.wait_threads_exit() - self.wait_threads.__enter__() - self.addCleanup(self.wait_threads.__exit__, None, None, None) + self.enterContext(threading_helper.wait_threads_exit()) self.server_ready = threading.Event() self.client_ready = threading.Event() @@ -485,6 +482,7 @@ def clientTearDown(self): ThreadableTest.clientTearDown(self) @unittest.skipIf(fcntl is None, "need fcntl") +@unittest.skipIf(WSL, 'VSOCK does not work on Microsoft WSL') @unittest.skipUnless(HAVE_SOCKET_VSOCK, 'VSOCK sockets required for this test.') @unittest.skipUnless(get_cid() != 2, @@ -501,6 +499,7 @@ def setUp(self): self.serv.bind((socket.VMADDR_CID_ANY, VSOCKPORT)) self.serv.listen() self.serverExplicitReady() + self.serv.settimeout(support.LOOPBACK_TIMEOUT) self.conn, self.connaddr = self.serv.accept() self.addCleanup(self.conn.close) @@ -591,17 +590,18 @@ class SocketTestBase(unittest.TestCase): def setUp(self): self.serv = self.newSocket() + self.addCleanup(self.close_server) self.bindServer() + def close_server(self): + self.serv.close() + self.serv = None + def bindServer(self): """Bind server socket and set self.serv_addr to its address.""" self.bindSock(self.serv) self.serv_addr = self.serv.getsockname() - def tearDown(self): - self.serv.close() - self.serv = None - class SocketListeningTestMixin(SocketTestBase): """Mixin to listen on the server socket.""" @@ -611,8 +611,7 @@ def setUp(self): self.serv.listen() -class ThreadedSocketTestMixin(ThreadSafeCleanupTestCase, SocketTestBase, - ThreadableTest): +class ThreadedSocketTestMixin(SocketTestBase, ThreadableTest): """Mixin to add client socket and allow client/server tests. Client socket is self.cli and its address is self.cli_addr. See @@ -686,15 +685,10 @@ class UnixSocketTestBase(SocketTestBase): # can't send anything that might be problematic for a privileged # user running the tests. - def setUp(self): - self.dir_path = tempfile.mkdtemp() - self.addCleanup(os.rmdir, self.dir_path) - super().setUp() - def bindSock(self, sock): - path = tempfile.mktemp(dir=self.dir_path) - socket_helper.bind_unix_socket(sock, path) + path = socket_helper.create_unix_domain_name() self.addCleanup(os_helper.unlink, path) + socket_helper.bind_unix_socket(sock, path) class UnixStreamBase(UnixSocketTestBase): """Base class for Unix-domain SOCK_STREAM tests.""" @@ -827,6 +821,12 @@ def requireSocket(*args): class GeneralModuleTests(unittest.TestCase): + @unittest.skipUnless(_socket is not None, 'need _socket module') + def test_socket_type(self): + self.assertTrue(gc.is_tracked(_socket.socket)) + with self.assertRaisesRegex(TypeError, "immutable"): + _socket.socket.foo = 1 + def test_SocketType_is_socketobject(self): import _socket self.assertTrue(socket.SocketType is _socket.socket) @@ -958,6 +958,19 @@ def testWindowsSpecificConstants(self): socket.IPPROTO_L2TP socket.IPPROTO_SCTP + @unittest.skipIf(support.is_wasi, "WASI is missing these methods") + def test_socket_methods(self): + # socket methods that depend on a configure HAVE_ check. They should + # be present on all platforms except WASI. + names = [ + "_accept", "bind", "connect", "connect_ex", "getpeername", + "getsockname", "listen", "recvfrom", "recvfrom_into", "sendto", + "setsockopt", "shutdown" + ] + for name in names: + if not hasattr(socket.socket, name): + self.fail(f"socket method {name} is missing") + # TODO: RUSTPYTHON @unittest.expectedFailure @unittest.skipUnless(sys.platform == 'darwin', 'macOS specific test') @@ -1021,8 +1034,10 @@ def test_host_resolution(self): def test_host_resolution_bad_address(self): # These are all malformed IP addresses and expected not to resolve to - # any result. But some ISPs, e.g. AWS, may successfully resolve these - # IPs. + # any result. But some ISPs, e.g. AWS and AT&T, may successfully + # resolve these IPs. In particular, AT&T's DNS Error Assist service + # will break this test. See https://bugs.python.org/issue42092 for a + # workaround. explanation = ( "resolving an invalid IP address did not raise OSError; " "can be caused by a broken DNS server" @@ -1074,7 +1089,20 @@ def testInterfaceNameIndex(self): 'socket.if_indextoname() not available.') def testInvalidInterfaceIndexToName(self): self.assertRaises(OSError, socket.if_indextoname, 0) + self.assertRaises(OverflowError, socket.if_indextoname, -1) + self.assertRaises(OverflowError, socket.if_indextoname, 2**1000) self.assertRaises(TypeError, socket.if_indextoname, '_DEADBEEF') + if hasattr(socket, 'if_nameindex'): + indices = dict(socket.if_nameindex()) + for index in indices: + index2 = index + 2**32 + if index2 not in indices: + with self.assertRaises((OverflowError, OSError)): + socket.if_indextoname(index2) + for index in 2**32-1, 2**64-1: + if index not in indices: + with self.assertRaises((OverflowError, OSError)): + socket.if_indextoname(index) @unittest.skipUnless(hasattr(socket, 'if_nametoindex'), 'socket.if_nametoindex() not available.') @@ -1378,10 +1406,21 @@ def testStringToIPv6(self): def testSockName(self): # Testing getsockname() - port = socket_helper.find_unused_port() sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.addCleanup(sock.close) - sock.bind(("0.0.0.0", port)) + + # Since find_unused_port() is inherently subject to race conditions, we + # call it a couple times if necessary. + for i in itertools.count(): + port = socket_helper.find_unused_port() + try: + sock.bind(("0.0.0.0", port)) + except OSError as e: + if e.errno != errno.EADDRINUSE or i == 5: + raise + else: + break + name = sock.getsockname() # XXX(nnorwitz): http://tinyurl.com/os5jz seems to indicate # it reasonable to get the host's addr in addition to 0.0.0.0. @@ -1525,9 +1564,11 @@ def testGetaddrinfo(self): infos = socket.getaddrinfo(HOST, 80, socket.AF_INET, socket.SOCK_STREAM) for family, type, _, _, _ in infos: self.assertEqual(family, socket.AF_INET) - self.assertEqual(str(family), 'AddressFamily.AF_INET') + self.assertEqual(repr(family), '' % family.value) + self.assertEqual(str(family), str(family.value)) self.assertEqual(type, socket.SOCK_STREAM) - self.assertEqual(str(type), 'SocketKind.SOCK_STREAM') + self.assertEqual(repr(type), '' % type.value) + self.assertEqual(str(type), str(type.value)) infos = socket.getaddrinfo(HOST, None, 0, socket.SOCK_STREAM) for _, socktype, _, _, _ in infos: self.assertEqual(socktype, socket.SOCK_STREAM) @@ -1574,6 +1615,54 @@ def testGetaddrinfo(self): except socket.gaierror: pass + def test_getaddrinfo_int_port_overflow(self): + # gh-74895: Test that getaddrinfo does not raise OverflowError on port. + # + # POSIX getaddrinfo() never specify the valid range for "service" + # decimal port number values. For IPv4 and IPv6 they are technically + # unsigned 16-bit values, but the API is protocol agnostic. Which values + # trigger an error from the C library function varies by platform as + # they do not all perform validation. + + # The key here is that we don't want to produce OverflowError as Python + # prior to 3.12 did for ints outside of a [LONG_MIN, LONG_MAX] range. + # Leave the error up to the underlying string based platform C API. + + from _testcapi import ULONG_MAX, LONG_MAX, LONG_MIN + try: + socket.getaddrinfo(None, ULONG_MAX + 1, type=socket.SOCK_STREAM) + except OverflowError: + # Platforms differ as to what values consitute a getaddrinfo() error + # return. Some fail for LONG_MAX+1, others ULONG_MAX+1, and Windows + # silently accepts such huge "port" aka "service" numeric values. + self.fail("Either no error or socket.gaierror expected.") + except socket.gaierror: + pass + + try: + socket.getaddrinfo(None, LONG_MAX + 1, type=socket.SOCK_STREAM) + except OverflowError: + self.fail("Either no error or socket.gaierror expected.") + except socket.gaierror: + pass + + try: + socket.getaddrinfo(None, LONG_MAX - 0xffff + 1, type=socket.SOCK_STREAM) + except OverflowError: + self.fail("Either no error or socket.gaierror expected.") + except socket.gaierror: + pass + + try: + socket.getaddrinfo(None, LONG_MIN - 1, type=socket.SOCK_STREAM) + except OverflowError: + self.fail("Either no error or socket.gaierror expected.") + except socket.gaierror: + pass + + socket.getaddrinfo(None, 0, type=socket.SOCK_STREAM) # No error expected. + socket.getaddrinfo(None, 0xffff, type=socket.SOCK_STREAM) # No error expected. + def test_getnameinfo(self): # only IP addresses are allowed self.assertRaises(OSError, socket.getnameinfo, ('mail.python.org',0), 0) @@ -1746,6 +1835,10 @@ def test_getaddrinfo_ipv6_basic(self): ) self.assertEqual(sockaddr, ('ff02::1de:c0:face:8d', 1234, 0, 0)) + def test_getfqdn_filter_localhost(self): + self.assertEqual(socket.getfqdn(), socket.getfqdn("0.0.0.0")) + self.assertEqual(socket.getfqdn(), socket.getfqdn("::")) + @unittest.skipUnless(socket_helper.IPV6_ENABLED, 'IPv6 required for this test.') @unittest.skipIf(sys.platform == 'win32', 'does not work on Windows') @unittest.skipIf(AIX, 'Symbolic scope id does not work') @@ -1807,8 +1900,10 @@ def test_str_for_enums(self): # Make sure that the AF_* and SOCK_* constants have enum-like string # reprs. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - self.assertEqual(str(s.family), 'AddressFamily.AF_INET') - self.assertEqual(str(s.type), 'SocketKind.SOCK_STREAM') + self.assertEqual(repr(s.family), '' % s.family.value) + self.assertEqual(repr(s.type), '' % s.type.value) + self.assertEqual(str(s.family), str(s.family.value)) + self.assertEqual(str(s.type), str(s.type.value)) @unittest.expectedFailureIf(sys.platform.startswith("linux"), "TODO: RUSTPYTHON, AssertionError: 526337 != ") def test_socket_consistent_sock_type(self): @@ -1902,17 +1997,18 @@ def test_socket_fileno(self): self._test_socket_fileno(s, socket.AF_INET6, socket.SOCK_STREAM) if hasattr(socket, "AF_UNIX"): - tmpdir = tempfile.mkdtemp() - self.addCleanup(shutil.rmtree, tmpdir) + unix_name = socket_helper.create_unix_domain_name() + self.addCleanup(os_helper.unlink, unix_name) + s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - self.addCleanup(s.close) - try: - s.bind(os.path.join(tmpdir, 'socket')) - except PermissionError: - pass - else: - self._test_socket_fileno(s, socket.AF_UNIX, - socket.SOCK_STREAM) + with s: + try: + s.bind(unix_name) + except PermissionError: + pass + else: + self._test_socket_fileno(s, socket.AF_UNIX, + socket.SOCK_STREAM) def test_socket_fileno_rejects_float(self): with self.assertRaises(TypeError): @@ -1956,6 +2052,41 @@ def test_socket_fileno_requires_socket_fd(self): fileno=afile.fileno()) self.assertEqual(cm.exception.errno, errno.ENOTSOCK) + def test_addressfamily_enum(self): + import _socket, enum + CheckedAddressFamily = enum._old_convert_( + enum.IntEnum, 'AddressFamily', 'socket', + lambda C: C.isupper() and C.startswith('AF_'), + source=_socket, + ) + enum._test_simple_enum(CheckedAddressFamily, socket.AddressFamily) + + def test_socketkind_enum(self): + import _socket, enum + CheckedSocketKind = enum._old_convert_( + enum.IntEnum, 'SocketKind', 'socket', + lambda C: C.isupper() and C.startswith('SOCK_'), + source=_socket, + ) + enum._test_simple_enum(CheckedSocketKind, socket.SocketKind) + + def test_msgflag_enum(self): + import _socket, enum + CheckedMsgFlag = enum._old_convert_( + enum.IntFlag, 'MsgFlag', 'socket', + lambda C: C.isupper() and C.startswith('MSG_'), + source=_socket, + ) + enum._test_simple_enum(CheckedMsgFlag, socket.MsgFlag) + + def test_addressinfo_enum(self): + import _socket, enum + CheckedAddressInfo = enum._old_convert_( + enum.IntFlag, 'AddressInfo', 'socket', + lambda C: C.isupper() and C.startswith('AI_'), + source=_socket) + enum._test_simple_enum(CheckedAddressInfo, socket.AddressInfo) + @unittest.skipUnless(HAVE_SOCKET_CAN, 'SocketCan required for this test.') class BasicCANTest(unittest.TestCase): @@ -2449,6 +2580,58 @@ def testCreateScoSocket(self): pass +@unittest.skipUnless(HAVE_SOCKET_HYPERV, + 'Hyper-V sockets required for this test.') +class BasicHyperVTest(unittest.TestCase): + + def testHyperVConstants(self): + socket.HVSOCKET_CONNECT_TIMEOUT + socket.HVSOCKET_CONNECT_TIMEOUT_MAX + socket.HVSOCKET_CONNECTED_SUSPEND + socket.HVSOCKET_ADDRESS_FLAG_PASSTHRU + socket.HV_GUID_ZERO + socket.HV_GUID_WILDCARD + socket.HV_GUID_BROADCAST + socket.HV_GUID_CHILDREN + socket.HV_GUID_LOOPBACK + socket.HV_GUID_PARENT + + def testCreateHyperVSocketWithUnknownProtoFailure(self): + expected = r"\[WinError 10041\]" + with self.assertRaisesRegex(OSError, expected): + socket.socket(socket.AF_HYPERV, socket.SOCK_STREAM) + + def testCreateHyperVSocketAddrNotTupleFailure(self): + expected = "connect(): AF_HYPERV address must be tuple, not str" + with socket.socket(socket.AF_HYPERV, socket.SOCK_STREAM, socket.HV_PROTOCOL_RAW) as s: + with self.assertRaisesRegex(TypeError, re.escape(expected)): + s.connect(socket.HV_GUID_ZERO) + + def testCreateHyperVSocketAddrNotTupleOf2StrsFailure(self): + expected = "AF_HYPERV address must be a str tuple (vm_id, service_id)" + with socket.socket(socket.AF_HYPERV, socket.SOCK_STREAM, socket.HV_PROTOCOL_RAW) as s: + with self.assertRaisesRegex(TypeError, re.escape(expected)): + s.connect((socket.HV_GUID_ZERO,)) + + def testCreateHyperVSocketAddrNotTupleOfStrsFailure(self): + expected = "AF_HYPERV address must be a str tuple (vm_id, service_id)" + with socket.socket(socket.AF_HYPERV, socket.SOCK_STREAM, socket.HV_PROTOCOL_RAW) as s: + with self.assertRaisesRegex(TypeError, re.escape(expected)): + s.connect((1, 2)) + + def testCreateHyperVSocketAddrVmIdNotValidUUIDFailure(self): + expected = "connect(): AF_HYPERV address vm_id is not a valid UUID string" + with socket.socket(socket.AF_HYPERV, socket.SOCK_STREAM, socket.HV_PROTOCOL_RAW) as s: + with self.assertRaisesRegex(ValueError, re.escape(expected)): + s.connect(("00", socket.HV_GUID_ZERO)) + + def testCreateHyperVSocketAddrServiceIdNotValidUUIDFailure(self): + expected = "connect(): AF_HYPERV address service_id is not a valid UUID string" + with socket.socket(socket.AF_HYPERV, socket.SOCK_STREAM, socket.HV_PROTOCOL_RAW) as s: + with self.assertRaisesRegex(ValueError, re.escape(expected)): + s.connect((socket.HV_GUID_ZERO, "00")) + + class BasicTCPTest(SocketConnectedTest): def __init__(self, methodName='runTest'): @@ -2658,7 +2841,7 @@ def _testRecvFromNegative(self): # here assumes that datagram delivery on the local machine will be # reliable. -class SendrecvmsgBase(ThreadSafeCleanupTestCase): +class SendrecvmsgBase: # Base class for sendmsg()/recvmsg() tests. # Time in seconds to wait before considering a test failed, or @@ -4525,7 +4708,6 @@ def testInterruptedRecvmsgIntoTimeout(self): @unittest.skipUnless(hasattr(signal, "alarm") or hasattr(signal, "setitimer"), "Don't have signal.alarm or signal.setitimer") class InterruptedSendTimeoutTest(InterruptedTimeoutBase, - ThreadSafeCleanupTestCase, SocketListeningTestMixin, TCPTestBase): # Test interrupting the interruptible send*() methods with signals # when a timeout is set. @@ -5136,6 +5318,7 @@ def mocked_socket_module(self): finally: socket.socket = old_socket + @socket_helper.skip_if_tcp_blackhole def test_connect(self): port = socket_helper.find_unused_port() cli = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -5144,6 +5327,7 @@ def test_connect(self): cli.connect((HOST, port)) self.assertEqual(cm.exception.errno, errno.ECONNREFUSED) + @socket_helper.skip_if_tcp_blackhole def test_create_connection(self): # Issue #9792: errors raised by create_connection() should have # a proper errno attribute. @@ -5168,6 +5352,24 @@ def test_create_connection(self): expected_errnos = socket_helper.get_socket_conn_refused_errs() self.assertIn(cm.exception.errno, expected_errnos) + def test_create_connection_all_errors(self): + port = socket_helper.find_unused_port() + try: + socket.create_connection((HOST, port), all_errors=True) + except ExceptionGroup as e: + eg = e + else: + self.fail('expected connection to fail') + + self.assertIsInstance(eg, ExceptionGroup) + for e in eg.exceptions: + self.assertIsInstance(e, OSError) + + addresses = socket.getaddrinfo( + 'localhost', port, 0, socket.SOCK_STREAM) + # assert that we got an exception for each address + self.assertEqual(len(addresses), len(eg.exceptions)) + def test_create_connection_timeout(self): # Issue #9792: create_connection() should not recast timeout errors # as generic socket errors. @@ -5184,6 +5386,7 @@ def test_create_connection_timeout(self): class NetworkConnectionAttributesTest(SocketTCPTest, ThreadableTest): + cli = None def __init__(self, methodName='runTest'): SocketTCPTest.__init__(self, methodName=methodName) @@ -5193,7 +5396,8 @@ def clientSetUp(self): self.source_port = socket_helper.find_unused_port() def clientTearDown(self): - self.cli.close() + if self.cli is not None: + self.cli.close() self.cli = None ThreadableTest.clientTearDown(self) @@ -5328,10 +5532,10 @@ def alarm_handler(signal, frame): self.fail("caught timeout instead of Alarm") except Alarm: pass - except: + except BaseException as e: self.fail("caught other exception instead of Alarm:" " %s(%s):\n%s" % - (sys.exc_info()[:2] + (traceback.format_exc(),))) + (type(e), e, traceback.format_exc())) else: self.fail("nothing caught") finally: @@ -6232,6 +6436,7 @@ def _testWithTimeoutTriggeredSend(self): def testWithTimeoutTriggeredSend(self): conn = self.accept_conn() conn.recv(88192) + # bpo-45212: the wait here needs to be longer than the client-side timeout (0.01s) time.sleep(1) # errors @@ -6312,12 +6517,16 @@ def test_sha256(self): # TODO: RUSTPYTHON, OSError: bind(): bad family @unittest.expectedFailure def test_hmac_sha1(self): - expected = bytes.fromhex("effcdf6ae5eb2fa2d27416d5f184df9c259a7c79") + # gh-109396: In FIPS mode, Linux 6.5 requires a key + # of at least 112 bits. Use a key of 152 bits. + key = b"Python loves AF_ALG" + data = b"what do ya want for nothing?" + expected = bytes.fromhex("193dbb43c6297b47ea6277ec0ce67119a3f3aa66") with self.create_alg('hash', 'hmac(sha1)') as algo: - algo.setsockopt(socket.SOL_ALG, socket.ALG_SET_KEY, b"Jefe") + algo.setsockopt(socket.SOL_ALG, socket.ALG_SET_KEY, key) op, _ = algo.accept() with op: - op.sendall(b"what do ya want for nothing?") + op.sendall(data) self.assertEqual(op.recv(512), expected) # Although it should work with 3.19 and newer the test blocks on From 9e2f6bd1875419fa3add740b89db3cdad988fe15 Mon Sep 17 00:00:00 2001 From: Jeong YunWon Date: Mon, 22 Apr 2024 11:20:38 +0900 Subject: [PATCH 2/2] mark failing tests of test_socket --- Lib/test/test_socket.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/Lib/test/test_socket.py b/Lib/test/test_socket.py index d448101dcd..ea544f6afa 100644 --- a/Lib/test/test_socket.py +++ b/Lib/test/test_socket.py @@ -821,6 +821,8 @@ def requireSocket(*args): class GeneralModuleTests(unittest.TestCase): + # TODO: RUSTPYTHON + @unittest.expectedFailure @unittest.skipUnless(_socket is not None, 'need _socket module') def test_socket_type(self): self.assertTrue(gc.is_tracked(_socket.socket)) @@ -1532,8 +1534,6 @@ def test_sio_loopback_fast_path(self): raise self.assertRaises(TypeError, s.ioctl, socket.SIO_LOOPBACK_FAST_PATH, None) - # TODO: RUSTPYTHON, AssertionError: '2' != 'AddressFamily.AF_INET' - @unittest.expectedFailure def testGetaddrinfo(self): try: socket.getaddrinfo('localhost', 80) @@ -1615,6 +1615,8 @@ def testGetaddrinfo(self): except socket.gaierror: pass + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_getaddrinfo_int_port_overflow(self): # gh-74895: Test that getaddrinfo does not raise OverflowError on port. # @@ -1894,8 +1896,6 @@ def test_getnameinfo_ipv6_scopeid_numeric(self): nameinfo = socket.getnameinfo(sockaddr, socket.NI_NUMERICHOST | socket.NI_NUMERICSERV) self.assertEqual(nameinfo, ('ff02::1de:c0:face:8d%' + str(ifindex), '1234')) - # TODO: RUSTPYTHON, AssertionError: '2' != 'AddressFamily.AF_INET' - @unittest.expectedFailure def test_str_for_enums(self): # Make sure that the AF_* and SOCK_* constants have enum-like string # reprs. @@ -2052,6 +2052,8 @@ def test_socket_fileno_requires_socket_fd(self): fileno=afile.fileno()) self.assertEqual(cm.exception.errno, errno.ENOTSOCK) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_addressfamily_enum(self): import _socket, enum CheckedAddressFamily = enum._old_convert_( @@ -2061,6 +2063,8 @@ def test_addressfamily_enum(self): ) enum._test_simple_enum(CheckedAddressFamily, socket.AddressFamily) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_socketkind_enum(self): import _socket, enum CheckedSocketKind = enum._old_convert_( @@ -2070,6 +2074,8 @@ def test_socketkind_enum(self): ) enum._test_simple_enum(CheckedSocketKind, socket.SocketKind) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_msgflag_enum(self): import _socket, enum CheckedMsgFlag = enum._old_convert_( @@ -2079,6 +2085,8 @@ def test_msgflag_enum(self): ) enum._test_simple_enum(CheckedMsgFlag, socket.MsgFlag) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_addressinfo_enum(self): import _socket, enum CheckedAddressInfo = enum._old_convert_( @@ -5352,6 +5360,8 @@ def test_create_connection(self): expected_errnos = socket_helper.get_socket_conn_refused_errs() self.assertIn(cm.exception.errno, expected_errnos) + # TODO: RUSTPYTHON + @unittest.expectedFailure def test_create_connection_all_errors(self): port = socket_helper.find_unused_port() try: