diff --git a/Lib/multiprocessing/connection.py b/Lib/multiprocessing/connection.py index 510e4b5aba..d0582e3cd5 100644 --- a/Lib/multiprocessing/connection.py +++ b/Lib/multiprocessing/connection.py @@ -9,6 +9,7 @@ __all__ = [ 'Client', 'Listener', 'Pipe', 'wait' ] +import errno import io import os import sys @@ -73,11 +74,6 @@ def arbitrary_address(family): if family == 'AF_INET': return ('localhost', 0) elif family == 'AF_UNIX': - # Prefer abstract sockets if possible to avoid problems with the address - # size. When coding portable applications, some implementations have - # sun_path as short as 92 bytes in the sockaddr_un struct. - if util.abstract_sockets_supported: - return f"\0listener-{os.getpid()}-{next(_mmap_counter)}" return tempfile.mktemp(prefix='listener-', dir=util.get_temp_dir()) elif family == 'AF_PIPE': return tempfile.mktemp(prefix=r'\\.\pipe\pyc-%d-%d-' % @@ -188,10 +184,9 @@ def send_bytes(self, buf, offset=0, size=None): self._check_closed() self._check_writable() m = memoryview(buf) - # HACK for byte-indexing of non-bytewise buffers (e.g. array.array) if m.itemsize > 1: - m = memoryview(bytes(m)) - n = len(m) + m = m.cast('B') + n = m.nbytes if offset < 0: raise ValueError("offset is negative") if n < offset: @@ -277,12 +272,22 @@ class PipeConnection(_ConnectionBase): with FILE_FLAG_OVERLAPPED. """ _got_empty_message = False + _send_ov = None def _close(self, _CloseHandle=_winapi.CloseHandle): + ov = self._send_ov + if ov is not None: + # Interrupt WaitForMultipleObjects() in _send_bytes() + ov.cancel() _CloseHandle(self._handle) def _send_bytes(self, buf): + if self._send_ov is not None: + # A connection should only be used by a single thread + raise ValueError("concurrent send_bytes() calls " + "are not supported") ov, err = _winapi.WriteFile(self._handle, buf, overlapped=True) + self._send_ov = ov try: if err == _winapi.ERROR_IO_PENDING: waitres = _winapi.WaitForMultipleObjects( @@ -292,7 +297,13 @@ def _send_bytes(self, buf): ov.cancel() raise finally: + self._send_ov = None nwritten, err = ov.GetOverlappedResult(True) + if err == _winapi.ERROR_OPERATION_ABORTED: + # close() was called by another thread while + # WaitForMultipleObjects() was waiting for the overlapped + # operation. + raise OSError(errno.EPIPE, "handle is closed") assert err == 0 assert nwritten == len(buf) @@ -465,8 +476,9 @@ def accept(self): ''' if self._listener is None: raise OSError('listener is closed') + c = self._listener.accept() - if self._authkey: + if self._authkey is not None: deliver_challenge(c, self._authkey) answer_challenge(c, self._authkey) return c @@ -728,39 +740,227 @@ def PipeClient(address): # Authentication stuff # -MESSAGE_LENGTH = 20 +MESSAGE_LENGTH = 40 # MUST be > 20 -CHALLENGE = b'#CHALLENGE#' -WELCOME = b'#WELCOME#' -FAILURE = b'#FAILURE#' +_CHALLENGE = b'#CHALLENGE#' +_WELCOME = b'#WELCOME#' +_FAILURE = b'#FAILURE#' -def deliver_challenge(connection, authkey): +# multiprocessing.connection Authentication Handshake Protocol Description +# (as documented for reference after reading the existing code) +# ============================================================================= +# +# On Windows: native pipes with "overlapped IO" are used to send the bytes, +# instead of the length prefix SIZE scheme described below. (ie: the OS deals +# with message sizes for us) +# +# Protocol error behaviors: +# +# On POSIX, any failure to receive the length prefix into SIZE, for SIZE greater +# than the requested maxsize to receive, or receiving fewer than SIZE bytes +# results in the connection being closed and auth to fail. +# +# On Windows, receiving too few bytes is never a low level _recv_bytes read +# error, receiving too many will trigger an error only if receive maxsize +# value was larger than 128 OR the if the data arrived in smaller pieces. +# +# Serving side Client side +# ------------------------------ --------------------------------------- +# 0. Open a connection on the pipe. +# 1. Accept connection. +# 2. Random 20+ bytes -> MESSAGE +# Modern servers always send +# more than 20 bytes and include +# a {digest} prefix on it with +# their preferred HMAC digest. +# Legacy ones send ==20 bytes. +# 3. send 4 byte length (net order) +# prefix followed by: +# b'#CHALLENGE#' + MESSAGE +# 4. Receive 4 bytes, parse as network byte +# order integer. If it is -1, receive an +# additional 8 bytes, parse that as network +# byte order. The result is the length of +# the data that follows -> SIZE. +# 5. Receive min(SIZE, 256) bytes -> M1 +# 6. Assert that M1 starts with: +# b'#CHALLENGE#' +# 7. Strip that prefix from M1 into -> M2 +# 7.1. Parse M2: if it is exactly 20 bytes in +# length this indicates a legacy server +# supporting only HMAC-MD5. Otherwise the +# 7.2. preferred digest is looked up from an +# expected "{digest}" prefix on M2. No prefix +# or unsupported digest? <- AuthenticationError +# 7.3. Put divined algorithm name in -> D_NAME +# 8. Compute HMAC-D_NAME of AUTHKEY, M2 -> C_DIGEST +# 9. Send 4 byte length prefix (net order) +# followed by C_DIGEST bytes. +# 10. Receive 4 or 4+8 byte length +# prefix (#4 dance) -> SIZE. +# 11. Receive min(SIZE, 256) -> C_D. +# 11.1. Parse C_D: legacy servers +# accept it as is, "md5" -> D_NAME +# 11.2. modern servers check the length +# of C_D, IF it is 16 bytes? +# 11.2.1. "md5" -> D_NAME +# and skip to step 12. +# 11.3. longer? expect and parse a "{digest}" +# prefix into -> D_NAME. +# Strip the prefix and store remaining +# bytes in -> C_D. +# 11.4. Don't like D_NAME? <- AuthenticationError +# 12. Compute HMAC-D_NAME of AUTHKEY, +# MESSAGE into -> M_DIGEST. +# 13. Compare M_DIGEST == C_D: +# 14a: Match? Send length prefix & +# b'#WELCOME#' +# <- RETURN +# 14b: Mismatch? Send len prefix & +# b'#FAILURE#' +# <- CLOSE & AuthenticationError +# 15. Receive 4 or 4+8 byte length prefix (net +# order) again as in #4 into -> SIZE. +# 16. Receive min(SIZE, 256) bytes -> M3. +# 17. Compare M3 == b'#WELCOME#': +# 17a. Match? <- RETURN +# 17b. Mismatch? <- CLOSE & AuthenticationError +# +# If this RETURNed, the connection remains open: it has been authenticated. +# +# Length prefixes are used consistently. Even on the legacy protocol, this +# was good fortune and allowed us to evolve the protocol by using the length +# of the opening challenge or length of the returned digest as a signal as +# to which protocol the other end supports. + +_ALLOWED_DIGESTS = frozenset( + {b'md5', b'sha256', b'sha384', b'sha3_256', b'sha3_384'}) +_MAX_DIGEST_LEN = max(len(_) for _ in _ALLOWED_DIGESTS) + +# Old hmac-md5 only server versions from Python <=3.11 sent a message of this +# length. It happens to not match the length of any supported digest so we can +# use a message of this length to indicate that we should work in backwards +# compatible md5-only mode without a {digest_name} prefix on our response. +_MD5ONLY_MESSAGE_LENGTH = 20 +_MD5_DIGEST_LEN = 16 +_LEGACY_LENGTHS = (_MD5ONLY_MESSAGE_LENGTH, _MD5_DIGEST_LEN) + + +def _get_digest_name_and_payload(message: bytes) -> (str, bytes): + """Returns a digest name and the payload for a response hash. + + If a legacy protocol is detected based on the message length + or contents the digest name returned will be empty to indicate + legacy mode where MD5 and no digest prefix should be sent. + """ + # modern message format: b"{digest}payload" longer than 20 bytes + # legacy message format: 16 or 20 byte b"payload" + if len(message) in _LEGACY_LENGTHS: + # Either this was a legacy server challenge, or we're processing + # a reply from a legacy client that sent an unprefixed 16-byte + # HMAC-MD5 response. All messages using the modern protocol will + # be longer than either of these lengths. + return '', message + if (message.startswith(b'{') and + (curly := message.find(b'}', 1, _MAX_DIGEST_LEN+2)) > 0): + digest = message[1:curly] + if digest in _ALLOWED_DIGESTS: + payload = message[curly+1:] + return digest.decode('ascii'), payload + raise AuthenticationError( + 'unsupported message length, missing digest prefix, ' + f'or unsupported digest: {message=}') + + +def _create_response(authkey, message): + """Create a MAC based on authkey and message + + The MAC algorithm defaults to HMAC-MD5, unless MD5 is not available or + the message has a '{digest_name}' prefix. For legacy HMAC-MD5, the response + is the raw MAC, otherwise the response is prefixed with '{digest_name}', + e.g. b'{sha256}abcdefg...' + + Note: The MAC protects the entire message including the digest_name prefix. + """ import hmac + digest_name = _get_digest_name_and_payload(message)[0] + # The MAC protects the entire message: digest header and payload. + if not digest_name: + # Legacy server without a {digest} prefix on message. + # Generate a legacy non-prefixed HMAC-MD5 reply. + try: + return hmac.new(authkey, message, 'md5').digest() + except ValueError: + # HMAC-MD5 is not available (FIPS mode?), fall back to + # HMAC-SHA2-256 modern protocol. The legacy server probably + # doesn't support it and will reject us anyways. :shrug: + digest_name = 'sha256' + # Modern protocol, indicate the digest used in the reply. + response = hmac.new(authkey, message, digest_name).digest() + return b'{%s}%s' % (digest_name.encode('ascii'), response) + + +def _verify_challenge(authkey, message, response): + """Verify MAC challenge + + If our message did not include a digest_name prefix, the client is allowed + to select a stronger digest_name from _ALLOWED_DIGESTS. + + In case our message is prefixed, a client cannot downgrade to a weaker + algorithm, because the MAC is calculated over the entire message + including the '{digest_name}' prefix. + """ + import hmac + response_digest, response_mac = _get_digest_name_and_payload(response) + response_digest = response_digest or 'md5' + try: + expected = hmac.new(authkey, message, response_digest).digest() + except ValueError: + raise AuthenticationError(f'{response_digest=} unsupported') + if len(expected) != len(response_mac): + raise AuthenticationError( + f'expected {response_digest!r} of length {len(expected)} ' + f'got {len(response_mac)}') + if not hmac.compare_digest(expected, response_mac): + raise AuthenticationError('digest received was wrong') + + +def deliver_challenge(connection, authkey: bytes, digest_name='sha256'): if not isinstance(authkey, bytes): raise ValueError( "Authkey must be bytes, not {0!s}".format(type(authkey))) + assert MESSAGE_LENGTH > _MD5ONLY_MESSAGE_LENGTH, "protocol constraint" message = os.urandom(MESSAGE_LENGTH) - connection.send_bytes(CHALLENGE + message) - digest = hmac.new(authkey, message, 'md5').digest() + message = b'{%s}%s' % (digest_name.encode('ascii'), message) + # Even when sending a challenge to a legacy client that does not support + # digest prefixes, they'll take the entire thing as a challenge and + # respond to it with a raw HMAC-MD5. + connection.send_bytes(_CHALLENGE + message) response = connection.recv_bytes(256) # reject large message - if response == digest: - connection.send_bytes(WELCOME) + try: + _verify_challenge(authkey, message, response) + except AuthenticationError: + connection.send_bytes(_FAILURE) + raise else: - connection.send_bytes(FAILURE) - raise AuthenticationError('digest received was wrong') + connection.send_bytes(_WELCOME) -def answer_challenge(connection, authkey): - import hmac + +def answer_challenge(connection, authkey: bytes): if not isinstance(authkey, bytes): raise ValueError( "Authkey must be bytes, not {0!s}".format(type(authkey))) message = connection.recv_bytes(256) # reject large message - assert message[:len(CHALLENGE)] == CHALLENGE, 'message = %r' % message - message = message[len(CHALLENGE):] - digest = hmac.new(authkey, message, 'md5').digest() + if not message.startswith(_CHALLENGE): + raise AuthenticationError( + f'Protocol error, expected challenge: {message=}') + message = message[len(_CHALLENGE):] + if len(message) < _MD5ONLY_MESSAGE_LENGTH: + raise AuthenticationError('challenge too short: {len(message)} bytes') + digest = _create_response(authkey, message) connection.send_bytes(digest) response = connection.recv_bytes(256) # reject large message - if response != WELCOME: + if response != _WELCOME: raise AuthenticationError('digest sent was rejected') # @@ -943,7 +1143,7 @@ def wait(object_list, timeout=None): return ready # -# Make connection and socket objects sharable if possible +# Make connection and socket objects shareable if possible # if sys.platform == 'win32': diff --git a/Lib/multiprocessing/context.py b/Lib/multiprocessing/context.py index 8d0525d5d6..de8a264829 100644 --- a/Lib/multiprocessing/context.py +++ b/Lib/multiprocessing/context.py @@ -223,6 +223,10 @@ class Process(process.BaseProcess): def _Popen(process_obj): return _default_context.get_context().Process._Popen(process_obj) + @staticmethod + def _after_fork(): + return _default_context.get_context().Process._after_fork() + class DefaultContext(BaseContext): Process = Process @@ -254,6 +258,7 @@ def get_start_method(self, allow_none=False): return self._actual_context._name def get_all_start_methods(self): + """Returns a list of the supported start methods, default first.""" if sys.platform == 'win32': return ['spawn'] else: @@ -283,6 +288,11 @@ def _Popen(process_obj): from .popen_spawn_posix import Popen return Popen(process_obj) + @staticmethod + def _after_fork(): + # process is spawned, nothing to do + pass + class ForkServerProcess(process.BaseProcess): _start_method = 'forkserver' @staticmethod @@ -326,6 +336,11 @@ def _Popen(process_obj): from .popen_spawn_win32 import Popen return Popen(process_obj) + @staticmethod + def _after_fork(): + # process is spawned, nothing to do + pass + class SpawnContext(BaseContext): _name = 'spawn' Process = SpawnProcess diff --git a/Lib/multiprocessing/forkserver.py b/Lib/multiprocessing/forkserver.py index 22a911a7a2..4642707dae 100644 --- a/Lib/multiprocessing/forkserver.py +++ b/Lib/multiprocessing/forkserver.py @@ -61,7 +61,7 @@ def _stop_unlocked(self): def set_forkserver_preload(self, modules_names): '''Set list of module names to try to load in forkserver process.''' - if not all(type(mod) is str for mod in self._preload_modules): + if not all(type(mod) is str for mod in modules_names): raise TypeError('module_names must be a list of strings') self._preload_modules = modules_names diff --git a/Lib/multiprocessing/managers.py b/Lib/multiprocessing/managers.py index 22292c78b7..75d9c18c20 100644 --- a/Lib/multiprocessing/managers.py +++ b/Lib/multiprocessing/managers.py @@ -49,11 +49,11 @@ def reduce_array(a): reduction.register(array.array, reduce_array) view_types = [type(getattr({}, name)()) for name in ('items','keys','values')] -if view_types[0] is not list: # only needed in Py3.0 - def rebuild_as_list(obj): - return list, (list(obj),) - for view_type in view_types: - reduction.register(view_type, rebuild_as_list) +def rebuild_as_list(obj): + return list, (list(obj),) +for view_type in view_types: + reduction.register(view_type, rebuild_as_list) +del view_type, view_types # # Type for identifying shared objects @@ -153,7 +153,7 @@ def __init__(self, registry, address, authkey, serializer): Listener, Client = listener_client[serializer] # do authentication later - self.listener = Listener(address=address, backlog=16) + self.listener = Listener(address=address, backlog=128) self.address = self.listener.address self.id_to_obj = {'0': (None, ())} @@ -433,7 +433,6 @@ def incref(self, c, ident): self.id_to_refcount[ident] = 1 self.id_to_obj[ident] = \ self.id_to_local_proxy_obj[ident] - obj, exposed, gettypeid = self.id_to_obj[ident] util.debug('Server re-enabled tracking & INCREF %r', ident) else: raise ke @@ -497,7 +496,7 @@ class BaseManager(object): _Server = Server def __init__(self, address=None, authkey=None, serializer='pickle', - ctx=None): + ctx=None, *, shutdown_timeout=1.0): if authkey is None: authkey = process.current_process().authkey self._address = address # XXX not final address if eg ('', 0) @@ -507,6 +506,7 @@ def __init__(self, address=None, authkey=None, serializer='pickle', self._serializer = serializer self._Listener, self._Client = listener_client[serializer] self._ctx = ctx or get_context() + self._shutdown_timeout = shutdown_timeout def get_server(self): ''' @@ -570,8 +570,8 @@ def start(self, initializer=None, initargs=()): self._state.value = State.STARTED self.shutdown = util.Finalize( self, type(self)._finalize_manager, - args=(self._process, self._address, self._authkey, - self._state, self._Client), + args=(self._process, self._address, self._authkey, self._state, + self._Client, self._shutdown_timeout), exitpriority=0 ) @@ -656,7 +656,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.shutdown() @staticmethod - def _finalize_manager(process, address, authkey, state, _Client): + def _finalize_manager(process, address, authkey, state, _Client, + shutdown_timeout): ''' Shutdown the manager process; will be registered as a finalizer ''' @@ -671,15 +672,17 @@ def _finalize_manager(process, address, authkey, state, _Client): except Exception: pass - process.join(timeout=1.0) + process.join(timeout=shutdown_timeout) if process.is_alive(): util.info('manager still alive') if hasattr(process, 'terminate'): util.info('trying to `terminate()` manager process') process.terminate() - process.join(timeout=1.0) + process.join(timeout=shutdown_timeout) if process.is_alive(): util.info('manager still alive after terminate') + process.kill() + process.join() state.value = State.SHUTDOWN try: @@ -1338,7 +1341,6 @@ def __init__(self, *args, **kwargs): def __del__(self): util.debug(f"{self.__class__.__name__}.__del__ by pid {getpid()}") - pass def get_server(self): 'Better than monkeypatching for now; merge into Server ultimately' diff --git a/Lib/multiprocessing/pool.py b/Lib/multiprocessing/pool.py index bbe05a550c..4f5d88cb97 100644 --- a/Lib/multiprocessing/pool.py +++ b/Lib/multiprocessing/pool.py @@ -203,6 +203,9 @@ def __init__(self, processes=None, initializer=None, initargs=(), processes = os.cpu_count() or 1 if processes < 1: raise ValueError("Number of processes must be at least 1") + if maxtasksperchild is not None: + if not isinstance(maxtasksperchild, int) or maxtasksperchild <= 0: + raise ValueError("maxtasksperchild must be a positive int or None") if initializer is not None and not callable(initializer): raise TypeError('initializer must be a callable') @@ -693,7 +696,7 @@ def _terminate_pool(cls, taskqueue, inqueue, outqueue, pool, change_notifier, if (not result_handler.is_alive()) and (len(cache) != 0): raise AssertionError( - "Cannot have cache with result_hander not alive") + "Cannot have cache with result_handler not alive") result_handler._state = TERMINATE change_notifier.put(None) diff --git a/Lib/multiprocessing/popen_spawn_win32.py b/Lib/multiprocessing/popen_spawn_win32.py index 9c4098d0fa..49d4c7eea2 100644 --- a/Lib/multiprocessing/popen_spawn_win32.py +++ b/Lib/multiprocessing/popen_spawn_win32.py @@ -14,6 +14,7 @@ # # +# Exit code used by Popen.terminate() TERMINATE = 0x10000 WINEXE = (sys.platform == 'win32' and getattr(sys, 'frozen', False)) WINSERVICE = sys.executable.lower().endswith("pythonservice.exe") @@ -54,19 +55,20 @@ def __init__(self, process_obj): wfd = msvcrt.open_osfhandle(whandle, 0) cmd = spawn.get_command_line(parent_pid=os.getpid(), pipe_handle=rhandle) - cmd = ' '.join('"%s"' % x for x in cmd) python_exe = spawn.get_executable() # bpo-35797: When running in a venv, we bypass the redirect # executor and launch our base Python. if WINENV and _path_eq(python_exe, sys.executable): - python_exe = sys._base_executable + cmd[0] = python_exe = sys._base_executable env = os.environ.copy() env["__PYVENV_LAUNCHER__"] = sys.executable else: env = None + cmd = ' '.join('"%s"' % x for x in cmd) + with open(wfd, 'wb', closefd=True) as to_child: # start process try: @@ -99,18 +101,20 @@ def duplicate_for_child(self, handle): return reduction.duplicate(handle, self.sentinel) def wait(self, timeout=None): - if self.returncode is None: - if timeout is None: - msecs = _winapi.INFINITE - else: - msecs = max(0, int(timeout * 1000 + 0.5)) - - res = _winapi.WaitForSingleObject(int(self._handle), msecs) - if res == _winapi.WAIT_OBJECT_0: - code = _winapi.GetExitCodeProcess(self._handle) - if code == TERMINATE: - code = -signal.SIGTERM - self.returncode = code + if self.returncode is not None: + return self.returncode + + if timeout is None: + msecs = _winapi.INFINITE + else: + msecs = max(0, int(timeout * 1000 + 0.5)) + + res = _winapi.WaitForSingleObject(int(self._handle), msecs) + if res == _winapi.WAIT_OBJECT_0: + code = _winapi.GetExitCodeProcess(self._handle) + if code == TERMINATE: + code = -signal.SIGTERM + self.returncode = code return self.returncode @@ -118,12 +122,22 @@ def poll(self): return self.wait(timeout=0) def terminate(self): - if self.returncode is None: - try: - _winapi.TerminateProcess(int(self._handle), TERMINATE) - except OSError: - if self.wait(timeout=1.0) is None: - raise + if self.returncode is not None: + return + + try: + _winapi.TerminateProcess(int(self._handle), TERMINATE) + except PermissionError: + # ERROR_ACCESS_DENIED (winerror 5) is received when the + # process already died. + code = _winapi.GetExitCodeProcess(int(self._handle)) + if code == _winapi.STILL_ACTIVE: + raise + + # gh-113009: Don't set self.returncode. Even if GetExitCodeProcess() + # returns an exit code different than STILL_ACTIVE, the process can + # still be running. Only set self.returncode once WaitForSingleObject() + # returns WAIT_OBJECT_0 in wait(). kill = terminate diff --git a/Lib/multiprocessing/process.py b/Lib/multiprocessing/process.py index 0b2e0b45b2..271ba3fd32 100644 --- a/Lib/multiprocessing/process.py +++ b/Lib/multiprocessing/process.py @@ -61,7 +61,7 @@ def parent_process(): def _cleanup(): # check for processes which have finished for p in list(_children): - if p._popen.poll() is not None: + if (child_popen := p._popen) and child_popen.poll() is not None: _children.discard(p) # @@ -304,8 +304,7 @@ def _bootstrap(self, parent_sentinel=None): if threading._HAVE_THREAD_NATIVE_ID: threading.main_thread()._set_native_id() try: - util._finalizer_registry.clear() - util._run_after_forkers() + self._after_fork() finally: # delay finalization of the old process object until after # _run_after_forkers() is executed @@ -336,6 +335,13 @@ def _bootstrap(self, parent_sentinel=None): return exitcode + @staticmethod + def _after_fork(): + from . import util + util._finalizer_registry.clear() + util._run_after_forkers() + + # # We subclass bytes to avoid accidental transmission of auth keys over network # @@ -427,6 +433,7 @@ def close(self): for name, signum in list(signal.__dict__.items()): if name[:3]=='SIG' and '_' not in name: _exitcode_to_name[-signum] = f'-{name}' +del name, signum # For debug and leak testing _dangling = WeakSet() diff --git a/Lib/multiprocessing/queues.py b/Lib/multiprocessing/queues.py index f37f114a96..852ae87b27 100644 --- a/Lib/multiprocessing/queues.py +++ b/Lib/multiprocessing/queues.py @@ -158,6 +158,20 @@ def cancel_join_thread(self): except AttributeError: pass + def _terminate_broken(self): + # Close a Queue on error. + + # gh-94777: Prevent queue writing to a pipe which is no longer read. + self._reader.close() + + # gh-107219: Close the connection writer which can unblock + # Queue._feed() if it was stuck in send_bytes(). + if sys.platform == 'win32': + self._writer.close() + + self.close() + self.join_thread() + def _start_thread(self): debug('Queue._start_thread()') @@ -169,13 +183,19 @@ def _start_thread(self): self._wlock, self._reader.close, self._writer.close, self._ignore_epipe, self._on_queue_feeder_error, self._sem), - name='QueueFeederThread' + name='QueueFeederThread', + daemon=True, ) - self._thread.daemon = True - debug('doing self._thread.start()') - self._thread.start() - debug('... done self._thread.start()') + try: + debug('doing self._thread.start()') + self._thread.start() + debug('... done self._thread.start()') + except: + # gh-109047: During Python finalization, creating a thread + # can fail with RuntimeError. + self._thread = None + raise if not self._joincancelled: self._jointhread = Finalize( @@ -280,6 +300,8 @@ def _on_queue_feeder_error(e, obj): import traceback traceback.print_exc() + __class_getitem__ = classmethod(types.GenericAlias) + _sentinel = object() diff --git a/Lib/multiprocessing/resource_sharer.py b/Lib/multiprocessing/resource_sharer.py index 66076509a1..b8afb0fbed 100644 --- a/Lib/multiprocessing/resource_sharer.py +++ b/Lib/multiprocessing/resource_sharer.py @@ -123,7 +123,7 @@ def _start(self): from .connection import Listener assert self._listener is None, "Already have Listener" util.debug('starting listener and thread for sending handles') - self._listener = Listener(authkey=process.current_process().authkey) + self._listener = Listener(authkey=process.current_process().authkey, backlog=128) self._address = self._listener.address t = threading.Thread(target=self._serve) t.daemon = True diff --git a/Lib/multiprocessing/resource_tracker.py b/Lib/multiprocessing/resource_tracker.py index cc42dbdda0..79e96ecf32 100644 --- a/Lib/multiprocessing/resource_tracker.py +++ b/Lib/multiprocessing/resource_tracker.py @@ -51,15 +51,31 @@ }) +class ReentrantCallError(RuntimeError): + pass + + class ResourceTracker(object): def __init__(self): - self._lock = threading.Lock() + self._lock = threading.RLock() self._fd = None self._pid = None + def _reentrant_call_error(self): + # gh-109629: this happens if an explicit call to the ResourceTracker + # gets interrupted by a garbage collection, invoking a finalizer (*) + # that itself calls back into ResourceTracker. + # (*) for example the SemLock finalizer + raise ReentrantCallError( + "Reentrant call into the multiprocessing resource tracker") + def _stop(self): with self._lock: + # This should not happen (_stop() isn't called by a finalizer) + # but we check for it anyway. + if self._lock._recursion_count() > 1: + return self._reentrant_call_error() if self._fd is None: # not running return @@ -81,6 +97,9 @@ def ensure_running(self): This can be run from any process. Usually a child process will use the resource created by its parent.''' with self._lock: + if self._lock._recursion_count() > 1: + # The code below is certainly not reentrant-safe, so bail out + return self._reentrant_call_error() if self._fd is not None: # resource tracker was launched before, is it still running? if self._check_alive(): @@ -159,12 +178,22 @@ def unregister(self, name, rtype): self._send('UNREGISTER', name, rtype) def _send(self, cmd, name, rtype): - self.ensure_running() + try: + self.ensure_running() + except ReentrantCallError: + # The code below might or might not work, depending on whether + # the resource tracker was already running and still alive. + # Better warn the user. + # (XXX is warnings.warn itself reentrant-safe? :-) + warnings.warn( + f"ResourceTracker called reentrantly for resource cleanup, " + f"which is unsupported. " + f"The {rtype} object {name!r} might leak.") msg = '{0}:{1}:{2}\n'.format(cmd, name, rtype).encode('ascii') - if len(name) > 512: + if len(msg) > 512: # posix guarantees that writes to a pipe of less than PIPE_BUF # bytes are atomic, and that PIPE_BUF >= 512 - raise ValueError('name too long') + raise ValueError('msg too long') nbytes = os.write(self._fd, msg) assert nbytes == len(msg), "nbytes {0:n} but len(msg) {1:n}".format( nbytes, len(msg)) @@ -176,6 +205,7 @@ def _send(self, cmd, name, rtype): unregister = _resource_tracker.unregister getfd = _resource_tracker.getfd + def main(fd): '''Run resource tracker.''' # protect the process from ^C and "killall python" etc diff --git a/Lib/multiprocessing/shared_memory.py b/Lib/multiprocessing/shared_memory.py index 122b3fcebf..9a1e5aa17b 100644 --- a/Lib/multiprocessing/shared_memory.py +++ b/Lib/multiprocessing/shared_memory.py @@ -23,6 +23,7 @@ import _posixshmem _USE_POSIX = True +from . import resource_tracker _O_CREX = os.O_CREAT | os.O_EXCL @@ -116,8 +117,7 @@ def __init__(self, name=None, create=False, size=0): self.unlink() raise - from .resource_tracker import register - register(self._name, "shared_memory") + resource_tracker.register(self._name, "shared_memory") else: @@ -173,7 +173,10 @@ def __init__(self, name=None, create=False, size=0): ) finally: _winapi.CloseHandle(h_map) - size = _winapi.VirtualQuerySize(p_buf) + try: + size = _winapi.VirtualQuerySize(p_buf) + finally: + _winapi.UnmapViewOfFile(p_buf) self._mmap = mmap.mmap(-1, size, tagname=name) self._size = size @@ -237,9 +240,8 @@ def unlink(self): called once (and only once) across all processes which have access to the shared memory block.""" if _USE_POSIX and self._name: - from .resource_tracker import unregister _posixshmem.shm_unlink(self._name) - unregister(self._name, "shared_memory") + resource_tracker.unregister(self._name, "shared_memory") _encoding = "utf8" diff --git a/Lib/multiprocessing/spawn.py b/Lib/multiprocessing/spawn.py index 7cc129e261..daac1ecc34 100644 --- a/Lib/multiprocessing/spawn.py +++ b/Lib/multiprocessing/spawn.py @@ -31,20 +31,25 @@ WINSERVICE = False else: WINEXE = getattr(sys, 'frozen', False) - WINSERVICE = sys.executable.lower().endswith("pythonservice.exe") - -if WINSERVICE: - _python_exe = os.path.join(sys.exec_prefix, 'python.exe') -else: - _python_exe = sys.executable + WINSERVICE = sys.executable and sys.executable.lower().endswith("pythonservice.exe") def set_executable(exe): global _python_exe - _python_exe = exe + if exe is None: + _python_exe = exe + elif sys.platform == 'win32': + _python_exe = os.fsdecode(exe) + else: + _python_exe = os.fsencode(exe) def get_executable(): return _python_exe +if WINSERVICE: + set_executable(os.path.join(sys.exec_prefix, 'python.exe')) +else: + set_executable(sys.executable) + # # # @@ -86,7 +91,8 @@ def get_command_line(**kwds): prog = 'from multiprocessing.spawn import spawn_main; spawn_main(%s)' prog %= ', '.join('%s=%r' % item for item in kwds.items()) opts = util._args_from_interpreter_flags() - return [_python_exe] + opts + ['-c', prog, '--multiprocessing-fork'] + exe = get_executable() + return [exe] + opts + ['-c', prog, '--multiprocessing-fork'] def spawn_main(pipe_handle, parent_pid=None, tracker_fd=None): @@ -144,7 +150,11 @@ def _check_not_importing_main(): ... The "freeze_support()" line can be omitted if the program - is not going to be frozen to produce an executable.''') + is not going to be frozen to produce an executable. + + To fix this issue, refer to the "Safe importing of main module" + section in https://docs.python.org/3/library/multiprocessing.html + ''') def get_preparation_data(name): diff --git a/Lib/multiprocessing/synchronize.py b/Lib/multiprocessing/synchronize.py index d0be48f1fd..3ccbfe311c 100644 --- a/Lib/multiprocessing/synchronize.py +++ b/Lib/multiprocessing/synchronize.py @@ -50,8 +50,8 @@ class SemLock(object): def __init__(self, kind, value, maxvalue, *, ctx): if ctx is None: ctx = context._default_context.get_context() - name = ctx.get_start_method() - unlink_now = sys.platform == 'win32' or name == 'fork' + self._is_fork_ctx = ctx.get_start_method() == 'fork' + unlink_now = sys.platform == 'win32' or self._is_fork_ctx for i in range(100): try: sl = self._semlock = _multiprocessing.SemLock( @@ -103,6 +103,11 @@ def __getstate__(self): if sys.platform == 'win32': h = context.get_spawning_popen().duplicate_for_child(sl.handle) else: + if self._is_fork_ctx: + raise RuntimeError('A SemLock created in a fork context is being ' + 'shared with a process in a spawn context. This is ' + 'not supported. Please use the same context to create ' + 'multiprocessing objects and Process.') h = sl.handle return (h, sl.kind, sl.maxvalue, sl.name) @@ -110,6 +115,8 @@ def __setstate__(self, state): self._semlock = _multiprocessing.SemLock._rebuild(*state) util.debug('recreated blocker with handle %r' % state[0]) self._make_methods() + # Ensure that deserialized SemLock can be serialized again (gh-108520). + self._is_fork_ctx = False @staticmethod def _make_name(): @@ -353,6 +360,9 @@ def wait(self, timeout=None): return True return False + def __repr__(self) -> str: + set_status = 'set' if self.is_set() else 'unset' + return f"<{type(self).__qualname__} at {id(self):#x} {set_status}>" # # Barrier # diff --git a/Lib/multiprocessing/util.py b/Lib/multiprocessing/util.py index 9e07a4e93e..79559823fb 100644 --- a/Lib/multiprocessing/util.py +++ b/Lib/multiprocessing/util.py @@ -43,19 +43,19 @@ def sub_debug(msg, *args): if _logger: - _logger.log(SUBDEBUG, msg, *args) + _logger.log(SUBDEBUG, msg, *args, stacklevel=2) def debug(msg, *args): if _logger: - _logger.log(DEBUG, msg, *args) + _logger.log(DEBUG, msg, *args, stacklevel=2) def info(msg, *args): if _logger: - _logger.log(INFO, msg, *args) + _logger.log(INFO, msg, *args, stacklevel=2) def sub_warning(msg, *args): if _logger: - _logger.log(SUBWARNING, msg, *args) + _logger.log(SUBWARNING, msg, *args, stacklevel=2) def get_logger(): ''' @@ -130,7 +130,10 @@ def is_abstract_socket_namespace(address): # def _remove_temp_dir(rmtree, tempdir): - rmtree(tempdir) + def onerror(func, path, err_info): + if not issubclass(err_info[0], FileNotFoundError): + raise + rmtree(tempdir, onerror=onerror) current_process = process.current_process() # current_process() can be None if the finalizer is called @@ -446,13 +449,15 @@ def _flush_std_streams(): def spawnv_passfds(path, args, passfds): import _posixsubprocess + import subprocess passfds = tuple(sorted(map(int, passfds))) errpipe_read, errpipe_write = os.pipe() try: return _posixsubprocess.fork_exec( - args, [os.fsencode(path)], True, passfds, None, None, + args, [path], True, passfds, None, None, -1, -1, -1, -1, -1, -1, errpipe_read, errpipe_write, - False, False, None, None, None, -1, None) + False, False, -1, None, None, None, -1, None, + subprocess._USE_VFORK) finally: os.close(errpipe_read) os.close(errpipe_write) diff --git a/Lib/test/_test_multiprocessing.py b/Lib/test/_test_multiprocessing.py new file mode 100644 index 0000000000..9e688efb1e --- /dev/null +++ b/Lib/test/_test_multiprocessing.py @@ -0,0 +1,6307 @@ +# +# Unit tests for the multiprocessing package +# + +import unittest +import unittest.mock +import queue as pyqueue +import textwrap +import time +import io +import itertools +import sys +import os +import gc +import errno +import functools +import signal +import array +import socket +import random +import logging +import subprocess +import struct +import operator +import pathlib +import pickle +import weakref +import warnings +import test.support +import test.support.script_helper +from test import support +from test.support import hashlib_helper +from test.support import import_helper +from test.support import os_helper +from test.support import script_helper +from test.support import socket_helper +from test.support import threading_helper +from test.support import warnings_helper + + +# Skip tests if _multiprocessing wasn't built. +_multiprocessing = import_helper.import_module('_multiprocessing') +# Skip tests if sem_open implementation is broken. +support.skip_if_broken_multiprocessing_synchronize() +import threading + +import multiprocessing.connection +import multiprocessing.dummy +import multiprocessing.heap +import multiprocessing.managers +import multiprocessing.pool +import multiprocessing.queues +from multiprocessing.connection import wait, AuthenticationError + +from multiprocessing import util + +try: + from multiprocessing import reduction + HAS_REDUCTION = reduction.HAVE_SEND_HANDLE +except ImportError: + HAS_REDUCTION = False + +try: + from multiprocessing.sharedctypes import Value, copy + HAS_SHAREDCTYPES = True +except ImportError: + HAS_SHAREDCTYPES = False + +try: + from multiprocessing import shared_memory + HAS_SHMEM = True +except ImportError: + HAS_SHMEM = False + +try: + import msvcrt +except ImportError: + msvcrt = None + + +if support.HAVE_ASAN_FORK_BUG: + # gh-89363: Skip multiprocessing tests if Python is built with ASAN to + # work around a libasan race condition: dead lock in pthread_create(). + raise unittest.SkipTest("libasan has a pthread_create() dead lock related to thread+fork") + + +# gh-110666: Tolerate a difference of 100 ms when comparing timings +# (clock resolution) +CLOCK_RES = 0.100 + + +def latin(s): + return s.encode('latin') + + +def close_queue(queue): + if isinstance(queue, multiprocessing.queues.Queue): + queue.close() + queue.join_thread() + + +def join_process(process): + # Since multiprocessing.Process has the same API than threading.Thread + # (join() and is_alive(), the support function can be reused + threading_helper.join_thread(process) + + +if os.name == "posix": + from multiprocessing import resource_tracker + + def _resource_unlink(name, rtype): + resource_tracker._CLEANUP_FUNCS[rtype](name) + + +# +# Constants +# + +LOG_LEVEL = util.SUBWARNING +#LOG_LEVEL = logging.DEBUG + +DELTA = 0.1 +CHECK_TIMINGS = False # making true makes tests take a lot longer + # and can sometimes cause some non-serious + # failures because some calls block a bit + # longer than expected +if CHECK_TIMINGS: + TIMEOUT1, TIMEOUT2, TIMEOUT3 = 0.82, 0.35, 1.4 +else: + TIMEOUT1, TIMEOUT2, TIMEOUT3 = 0.1, 0.1, 0.1 + +# BaseManager.shutdown_timeout +SHUTDOWN_TIMEOUT = support.SHORT_TIMEOUT + +WAIT_ACTIVE_CHILDREN_TIMEOUT = 5.0 + +HAVE_GETVALUE = not getattr(_multiprocessing, + 'HAVE_BROKEN_SEM_GETVALUE', False) + +WIN32 = (sys.platform == "win32") + +def wait_for_handle(handle, timeout): + if timeout is not None and timeout < 0.0: + timeout = None + return wait([handle], timeout) + +try: + MAXFD = os.sysconf("SC_OPEN_MAX") +except: + MAXFD = 256 + +# To speed up tests when using the forkserver, we can preload these: +PRELOAD = ['__main__', 'test.test_multiprocessing_forkserver'] + +# +# Some tests require ctypes +# + +try: + from ctypes import Structure, c_int, c_double, c_longlong +except ImportError: + Structure = object + c_int = c_double = c_longlong = None + + +def check_enough_semaphores(): + """Check that the system supports enough semaphores to run the test.""" + # minimum number of semaphores available according to POSIX + nsems_min = 256 + try: + nsems = os.sysconf("SC_SEM_NSEMS_MAX") + except (AttributeError, ValueError): + # sysconf not available or setting not available + return + if nsems == -1 or nsems >= nsems_min: + return + raise unittest.SkipTest("The OS doesn't support enough semaphores " + "to run the test (required: %d)." % nsems_min) + + +def only_run_in_spawn_testsuite(reason): + """Returns a decorator: raises SkipTest when SM != spawn at test time. + + This can be useful to save overall Python test suite execution time. + "spawn" is the universal mode available on all platforms so this limits the + decorated test to only execute within test_multiprocessing_spawn. + + This would not be necessary if we refactored our test suite to split things + into other test files when they are not start method specific to be rerun + under all start methods. + """ + + def decorator(test_item): + + @functools.wraps(test_item) + def spawn_check_wrapper(*args, **kwargs): + if (start_method := multiprocessing.get_start_method()) != "spawn": + raise unittest.SkipTest(f"{start_method=}, not 'spawn'; {reason}") + return test_item(*args, **kwargs) + + return spawn_check_wrapper + + return decorator + + +class TestInternalDecorators(unittest.TestCase): + """Logic within a test suite that could errantly skip tests? Test it!""" + + @unittest.skipIf(sys.platform == "win32", "test requires that fork exists.") + def test_only_run_in_spawn_testsuite(self): + if multiprocessing.get_start_method() != "spawn": + raise unittest.SkipTest("only run in test_multiprocessing_spawn.") + + try: + @only_run_in_spawn_testsuite("testing this decorator") + def return_four_if_spawn(): + return 4 + except Exception as err: + self.fail(f"expected decorated `def` not to raise; caught {err}") + + orig_start_method = multiprocessing.get_start_method(allow_none=True) + try: + multiprocessing.set_start_method("spawn", force=True) + self.assertEqual(return_four_if_spawn(), 4) + multiprocessing.set_start_method("fork", force=True) + with self.assertRaises(unittest.SkipTest) as ctx: + return_four_if_spawn() + self.assertIn("testing this decorator", str(ctx.exception)) + self.assertIn("start_method=", str(ctx.exception)) + finally: + multiprocessing.set_start_method(orig_start_method, force=True) + + +# +# Creates a wrapper for a function which records the time it takes to finish +# + +class TimingWrapper(object): + + def __init__(self, func): + self.func = func + self.elapsed = None + + def __call__(self, *args, **kwds): + t = time.monotonic() + try: + return self.func(*args, **kwds) + finally: + self.elapsed = time.monotonic() - t + +# +# Base class for test cases +# + +class BaseTestCase(object): + + ALLOWED_TYPES = ('processes', 'manager', 'threads') + + def assertTimingAlmostEqual(self, a, b): + if CHECK_TIMINGS: + self.assertAlmostEqual(a, b, 1) + + def assertReturnsIfImplemented(self, value, func, *args): + try: + res = func(*args) + except NotImplementedError: + pass + else: + return self.assertEqual(value, res) + + # For the sanity of Windows users, rather than crashing or freezing in + # multiple ways. + def __reduce__(self, *args): + raise NotImplementedError("shouldn't try to pickle a test case") + + __reduce_ex__ = __reduce__ + +# +# Return the value of a semaphore +# + +def get_value(self): + try: + return self.get_value() + except AttributeError: + try: + return self._Semaphore__value + except AttributeError: + try: + return self._value + except AttributeError: + raise NotImplementedError + +# +# Testcases +# + +class DummyCallable: + def __call__(self, q, c): + assert isinstance(c, DummyCallable) + q.put(5) + + +class _TestProcess(BaseTestCase): + + ALLOWED_TYPES = ('processes', 'threads') + + def test_current(self): + if self.TYPE == 'threads': + self.skipTest('test not appropriate for {}'.format(self.TYPE)) + + current = self.current_process() + authkey = current.authkey + + self.assertTrue(current.is_alive()) + self.assertTrue(not current.daemon) + self.assertIsInstance(authkey, bytes) + self.assertTrue(len(authkey) > 0) + self.assertEqual(current.ident, os.getpid()) + self.assertEqual(current.exitcode, None) + + def test_set_executable(self): + if self.TYPE == 'threads': + self.skipTest(f'test not appropriate for {self.TYPE}') + paths = [ + sys.executable, # str + sys.executable.encode(), # bytes + pathlib.Path(sys.executable) # os.PathLike + ] + for path in paths: + self.set_executable(path) + p = self.Process() + p.start() + p.join() + self.assertEqual(p.exitcode, 0) + + @support.requires_resource('cpu') + def test_args_argument(self): + # bpo-45735: Using list or tuple as *args* in constructor could + # achieve the same effect. + args_cases = (1, "str", [1], (1,)) + args_types = (list, tuple) + + test_cases = itertools.product(args_cases, args_types) + + for args, args_type in test_cases: + with self.subTest(args=args, args_type=args_type): + q = self.Queue(1) + # pass a tuple or list as args + p = self.Process(target=self._test_args, args=args_type((q, args))) + p.daemon = True + p.start() + child_args = q.get() + self.assertEqual(child_args, args) + p.join() + close_queue(q) + + @classmethod + def _test_args(cls, q, arg): + q.put(arg) + + def test_daemon_argument(self): + if self.TYPE == "threads": + self.skipTest('test not appropriate for {}'.format(self.TYPE)) + + # By default uses the current process's daemon flag. + proc0 = self.Process(target=self._test) + self.assertEqual(proc0.daemon, self.current_process().daemon) + proc1 = self.Process(target=self._test, daemon=True) + self.assertTrue(proc1.daemon) + proc2 = self.Process(target=self._test, daemon=False) + self.assertFalse(proc2.daemon) + + @classmethod + def _test(cls, q, *args, **kwds): + current = cls.current_process() + q.put(args) + q.put(kwds) + q.put(current.name) + if cls.TYPE != 'threads': + q.put(bytes(current.authkey)) + q.put(current.pid) + + def test_parent_process_attributes(self): + if self.TYPE == "threads": + self.skipTest('test not appropriate for {}'.format(self.TYPE)) + + self.assertIsNone(self.parent_process()) + + rconn, wconn = self.Pipe(duplex=False) + p = self.Process(target=self._test_send_parent_process, args=(wconn,)) + p.start() + p.join() + parent_pid, parent_name = rconn.recv() + self.assertEqual(parent_pid, self.current_process().pid) + self.assertEqual(parent_pid, os.getpid()) + self.assertEqual(parent_name, self.current_process().name) + + @classmethod + def _test_send_parent_process(cls, wconn): + from multiprocessing.process import parent_process + wconn.send([parent_process().pid, parent_process().name]) + + def test_parent_process(self): + if self.TYPE == "threads": + self.skipTest('test not appropriate for {}'.format(self.TYPE)) + + # Launch a child process. Make it launch a grandchild process. Kill the + # child process and make sure that the grandchild notices the death of + # its parent (a.k.a the child process). + rconn, wconn = self.Pipe(duplex=False) + p = self.Process( + target=self._test_create_grandchild_process, args=(wconn, )) + p.start() + + if not rconn.poll(timeout=support.LONG_TIMEOUT): + raise AssertionError("Could not communicate with child process") + parent_process_status = rconn.recv() + self.assertEqual(parent_process_status, "alive") + + p.terminate() + p.join() + + if not rconn.poll(timeout=support.LONG_TIMEOUT): + raise AssertionError("Could not communicate with child process") + parent_process_status = rconn.recv() + self.assertEqual(parent_process_status, "not alive") + + @classmethod + def _test_create_grandchild_process(cls, wconn): + p = cls.Process(target=cls._test_report_parent_status, args=(wconn, )) + p.start() + time.sleep(300) + + @classmethod + def _test_report_parent_status(cls, wconn): + from multiprocessing.process import parent_process + wconn.send("alive" if parent_process().is_alive() else "not alive") + parent_process().join(timeout=support.SHORT_TIMEOUT) + wconn.send("alive" if parent_process().is_alive() else "not alive") + + def test_process(self): + q = self.Queue(1) + e = self.Event() + args = (q, 1, 2) + kwargs = {'hello':23, 'bye':2.54} + name = 'SomeProcess' + p = self.Process( + target=self._test, args=args, kwargs=kwargs, name=name + ) + p.daemon = True + current = self.current_process() + + if self.TYPE != 'threads': + self.assertEqual(p.authkey, current.authkey) + self.assertEqual(p.is_alive(), False) + self.assertEqual(p.daemon, True) + self.assertNotIn(p, self.active_children()) + self.assertTrue(type(self.active_children()) is list) + self.assertEqual(p.exitcode, None) + + p.start() + + self.assertEqual(p.exitcode, None) + self.assertEqual(p.is_alive(), True) + self.assertIn(p, self.active_children()) + + self.assertEqual(q.get(), args[1:]) + self.assertEqual(q.get(), kwargs) + self.assertEqual(q.get(), p.name) + if self.TYPE != 'threads': + self.assertEqual(q.get(), current.authkey) + self.assertEqual(q.get(), p.pid) + + p.join() + + self.assertEqual(p.exitcode, 0) + self.assertEqual(p.is_alive(), False) + self.assertNotIn(p, self.active_children()) + close_queue(q) + + @unittest.skipUnless(threading._HAVE_THREAD_NATIVE_ID, "needs native_id") + def test_process_mainthread_native_id(self): + if self.TYPE == 'threads': + self.skipTest('test not appropriate for {}'.format(self.TYPE)) + + current_mainthread_native_id = threading.main_thread().native_id + + q = self.Queue(1) + p = self.Process(target=self._test_process_mainthread_native_id, args=(q,)) + p.start() + + child_mainthread_native_id = q.get() + p.join() + close_queue(q) + + self.assertNotEqual(current_mainthread_native_id, child_mainthread_native_id) + + @classmethod + def _test_process_mainthread_native_id(cls, q): + mainthread_native_id = threading.main_thread().native_id + q.put(mainthread_native_id) + + @classmethod + def _sleep_some(cls): + time.sleep(100) + + @classmethod + def _test_sleep(cls, delay): + time.sleep(delay) + + def _kill_process(self, meth): + if self.TYPE == 'threads': + self.skipTest('test not appropriate for {}'.format(self.TYPE)) + + p = self.Process(target=self._sleep_some) + p.daemon = True + p.start() + + self.assertEqual(p.is_alive(), True) + self.assertIn(p, self.active_children()) + self.assertEqual(p.exitcode, None) + + join = TimingWrapper(p.join) + + self.assertEqual(join(0), None) + self.assertTimingAlmostEqual(join.elapsed, 0.0) + self.assertEqual(p.is_alive(), True) + + self.assertEqual(join(-1), None) + self.assertTimingAlmostEqual(join.elapsed, 0.0) + self.assertEqual(p.is_alive(), True) + + # XXX maybe terminating too soon causes the problems on Gentoo... + time.sleep(1) + + meth(p) + + if hasattr(signal, 'alarm'): + # On the Gentoo buildbot waitpid() often seems to block forever. + # We use alarm() to interrupt it if it blocks for too long. + def handler(*args): + raise RuntimeError('join took too long: %s' % p) + old_handler = signal.signal(signal.SIGALRM, handler) + try: + signal.alarm(10) + self.assertEqual(join(), None) + finally: + signal.alarm(0) + signal.signal(signal.SIGALRM, old_handler) + else: + self.assertEqual(join(), None) + + self.assertTimingAlmostEqual(join.elapsed, 0.0) + + self.assertEqual(p.is_alive(), False) + self.assertNotIn(p, self.active_children()) + + p.join() + + return p.exitcode + + def test_terminate(self): + exitcode = self._kill_process(multiprocessing.Process.terminate) + self.assertEqual(exitcode, -signal.SIGTERM) + + def test_kill(self): + exitcode = self._kill_process(multiprocessing.Process.kill) + if os.name != 'nt': + self.assertEqual(exitcode, -signal.SIGKILL) + else: + self.assertEqual(exitcode, -signal.SIGTERM) + + def test_cpu_count(self): + try: + cpus = multiprocessing.cpu_count() + except NotImplementedError: + cpus = 1 + self.assertTrue(type(cpus) is int) + self.assertTrue(cpus >= 1) + + def test_active_children(self): + self.assertEqual(type(self.active_children()), list) + + p = self.Process(target=time.sleep, args=(DELTA,)) + self.assertNotIn(p, self.active_children()) + + p.daemon = True + p.start() + self.assertIn(p, self.active_children()) + + p.join() + self.assertNotIn(p, self.active_children()) + + @classmethod + def _test_recursion(cls, wconn, id): + wconn.send(id) + if len(id) < 2: + for i in range(2): + p = cls.Process( + target=cls._test_recursion, args=(wconn, id+[i]) + ) + p.start() + p.join() + + def test_recursion(self): + rconn, wconn = self.Pipe(duplex=False) + self._test_recursion(wconn, []) + + time.sleep(DELTA) + result = [] + while rconn.poll(): + result.append(rconn.recv()) + + expected = [ + [], + [0], + [0, 0], + [0, 1], + [1], + [1, 0], + [1, 1] + ] + self.assertEqual(result, expected) + + @classmethod + def _test_sentinel(cls, event): + event.wait(10.0) + + def test_sentinel(self): + if self.TYPE == "threads": + self.skipTest('test not appropriate for {}'.format(self.TYPE)) + event = self.Event() + p = self.Process(target=self._test_sentinel, args=(event,)) + with self.assertRaises(ValueError): + p.sentinel + p.start() + self.addCleanup(p.join) + sentinel = p.sentinel + self.assertIsInstance(sentinel, int) + self.assertFalse(wait_for_handle(sentinel, timeout=0.0)) + event.set() + p.join() + self.assertTrue(wait_for_handle(sentinel, timeout=1)) + + @classmethod + def _test_close(cls, rc=0, q=None): + if q is not None: + q.get() + sys.exit(rc) + + def test_close(self): + if self.TYPE == "threads": + self.skipTest('test not appropriate for {}'.format(self.TYPE)) + q = self.Queue() + p = self.Process(target=self._test_close, kwargs={'q': q}) + p.daemon = True + p.start() + self.assertEqual(p.is_alive(), True) + # Child is still alive, cannot close + with self.assertRaises(ValueError): + p.close() + + q.put(None) + p.join() + self.assertEqual(p.is_alive(), False) + self.assertEqual(p.exitcode, 0) + p.close() + with self.assertRaises(ValueError): + p.is_alive() + with self.assertRaises(ValueError): + p.join() + with self.assertRaises(ValueError): + p.terminate() + p.close() + + wr = weakref.ref(p) + del p + gc.collect() + self.assertIs(wr(), None) + + close_queue(q) + + @support.requires_resource('walltime') + def test_many_processes(self): + if self.TYPE == 'threads': + self.skipTest('test not appropriate for {}'.format(self.TYPE)) + + sm = multiprocessing.get_start_method() + N = 5 if sm == 'spawn' else 100 + + # Try to overwhelm the forkserver loop with events + procs = [self.Process(target=self._test_sleep, args=(0.01,)) + for i in range(N)] + for p in procs: + p.start() + for p in procs: + join_process(p) + for p in procs: + self.assertEqual(p.exitcode, 0) + + procs = [self.Process(target=self._sleep_some) + for i in range(N)] + for p in procs: + p.start() + time.sleep(0.001) # let the children start... + for p in procs: + p.terminate() + for p in procs: + join_process(p) + if os.name != 'nt': + exitcodes = [-signal.SIGTERM] + if sys.platform == 'darwin': + # bpo-31510: On macOS, killing a freshly started process with + # SIGTERM sometimes kills the process with SIGKILL. + exitcodes.append(-signal.SIGKILL) + for p in procs: + self.assertIn(p.exitcode, exitcodes) + + def test_lose_target_ref(self): + c = DummyCallable() + wr = weakref.ref(c) + q = self.Queue() + p = self.Process(target=c, args=(q, c)) + del c + p.start() + p.join() + gc.collect() # For PyPy or other GCs. + self.assertIs(wr(), None) + self.assertEqual(q.get(), 5) + close_queue(q) + + @classmethod + def _test_child_fd_inflation(self, evt, q): + q.put(os_helper.fd_count()) + evt.wait() + + def test_child_fd_inflation(self): + # Number of fds in child processes should not grow with the + # number of running children. + if self.TYPE == 'threads': + self.skipTest('test not appropriate for {}'.format(self.TYPE)) + + sm = multiprocessing.get_start_method() + if sm == 'fork': + # The fork method by design inherits all fds from the parent, + # trying to go against it is a lost battle + self.skipTest('test not appropriate for {}'.format(sm)) + + N = 5 + evt = self.Event() + q = self.Queue() + + procs = [self.Process(target=self._test_child_fd_inflation, args=(evt, q)) + for i in range(N)] + for p in procs: + p.start() + + try: + fd_counts = [q.get() for i in range(N)] + self.assertEqual(len(set(fd_counts)), 1, fd_counts) + + finally: + evt.set() + for p in procs: + p.join() + close_queue(q) + + @classmethod + def _test_wait_for_threads(self, evt): + def func1(): + time.sleep(0.5) + evt.set() + + def func2(): + time.sleep(20) + evt.clear() + + threading.Thread(target=func1).start() + threading.Thread(target=func2, daemon=True).start() + + def test_wait_for_threads(self): + # A child process should wait for non-daemonic threads to end + # before exiting + if self.TYPE == 'threads': + self.skipTest('test not appropriate for {}'.format(self.TYPE)) + + evt = self.Event() + proc = self.Process(target=self._test_wait_for_threads, args=(evt,)) + proc.start() + proc.join() + self.assertTrue(evt.is_set()) + + @classmethod + def _test_error_on_stdio_flush(self, evt, break_std_streams={}): + for stream_name, action in break_std_streams.items(): + if action == 'close': + stream = io.StringIO() + stream.close() + else: + assert action == 'remove' + stream = None + setattr(sys, stream_name, None) + evt.set() + + def test_error_on_stdio_flush_1(self): + # Check that Process works with broken standard streams + streams = [io.StringIO(), None] + streams[0].close() + for stream_name in ('stdout', 'stderr'): + for stream in streams: + old_stream = getattr(sys, stream_name) + setattr(sys, stream_name, stream) + try: + evt = self.Event() + proc = self.Process(target=self._test_error_on_stdio_flush, + args=(evt,)) + proc.start() + proc.join() + self.assertTrue(evt.is_set()) + self.assertEqual(proc.exitcode, 0) + finally: + setattr(sys, stream_name, old_stream) + + def test_error_on_stdio_flush_2(self): + # Same as test_error_on_stdio_flush_1(), but standard streams are + # broken by the child process + for stream_name in ('stdout', 'stderr'): + for action in ('close', 'remove'): + old_stream = getattr(sys, stream_name) + try: + evt = self.Event() + proc = self.Process(target=self._test_error_on_stdio_flush, + args=(evt, {stream_name: action})) + proc.start() + proc.join() + self.assertTrue(evt.is_set()) + self.assertEqual(proc.exitcode, 0) + finally: + setattr(sys, stream_name, old_stream) + + @classmethod + def _sleep_and_set_event(self, evt, delay=0.0): + time.sleep(delay) + evt.set() + + def check_forkserver_death(self, signum): + # bpo-31308: if the forkserver process has died, we should still + # be able to create and run new Process instances (the forkserver + # is implicitly restarted). + if self.TYPE == 'threads': + self.skipTest('test not appropriate for {}'.format(self.TYPE)) + sm = multiprocessing.get_start_method() + if sm != 'forkserver': + # The fork method by design inherits all fds from the parent, + # trying to go against it is a lost battle + self.skipTest('test not appropriate for {}'.format(sm)) + + from multiprocessing.forkserver import _forkserver + _forkserver.ensure_running() + + # First process sleeps 500 ms + delay = 0.5 + + evt = self.Event() + proc = self.Process(target=self._sleep_and_set_event, args=(evt, delay)) + proc.start() + + pid = _forkserver._forkserver_pid + os.kill(pid, signum) + # give time to the fork server to die and time to proc to complete + time.sleep(delay * 2.0) + + evt2 = self.Event() + proc2 = self.Process(target=self._sleep_and_set_event, args=(evt2,)) + proc2.start() + proc2.join() + self.assertTrue(evt2.is_set()) + self.assertEqual(proc2.exitcode, 0) + + proc.join() + self.assertTrue(evt.is_set()) + self.assertIn(proc.exitcode, (0, 255)) + + def test_forkserver_sigint(self): + # Catchable signal + self.check_forkserver_death(signal.SIGINT) + + def test_forkserver_sigkill(self): + # Uncatchable signal + if os.name != 'nt': + self.check_forkserver_death(signal.SIGKILL) + + +# +# +# + +class _UpperCaser(multiprocessing.Process): + + def __init__(self): + multiprocessing.Process.__init__(self) + self.child_conn, self.parent_conn = multiprocessing.Pipe() + + def run(self): + self.parent_conn.close() + for s in iter(self.child_conn.recv, None): + self.child_conn.send(s.upper()) + self.child_conn.close() + + def submit(self, s): + assert type(s) is str + self.parent_conn.send(s) + return self.parent_conn.recv() + + def stop(self): + self.parent_conn.send(None) + self.parent_conn.close() + self.child_conn.close() + +class _TestSubclassingProcess(BaseTestCase): + + ALLOWED_TYPES = ('processes',) + + def test_subclassing(self): + uppercaser = _UpperCaser() + uppercaser.daemon = True + uppercaser.start() + self.assertEqual(uppercaser.submit('hello'), 'HELLO') + self.assertEqual(uppercaser.submit('world'), 'WORLD') + uppercaser.stop() + uppercaser.join() + + def test_stderr_flush(self): + # sys.stderr is flushed at process shutdown (issue #13812) + if self.TYPE == "threads": + self.skipTest('test not appropriate for {}'.format(self.TYPE)) + + testfn = os_helper.TESTFN + self.addCleanup(os_helper.unlink, testfn) + proc = self.Process(target=self._test_stderr_flush, args=(testfn,)) + proc.start() + proc.join() + with open(testfn, encoding="utf-8") as f: + err = f.read() + # The whole traceback was printed + self.assertIn("ZeroDivisionError", err) + self.assertIn("test_multiprocessing.py", err) + self.assertIn("1/0 # MARKER", err) + + @classmethod + def _test_stderr_flush(cls, testfn): + fd = os.open(testfn, os.O_WRONLY | os.O_CREAT | os.O_EXCL) + sys.stderr = open(fd, 'w', encoding="utf-8", closefd=False) + 1/0 # MARKER + + + @classmethod + def _test_sys_exit(cls, reason, testfn): + fd = os.open(testfn, os.O_WRONLY | os.O_CREAT | os.O_EXCL) + sys.stderr = open(fd, 'w', encoding="utf-8", closefd=False) + sys.exit(reason) + + def test_sys_exit(self): + # See Issue 13854 + if self.TYPE == 'threads': + self.skipTest('test not appropriate for {}'.format(self.TYPE)) + + testfn = os_helper.TESTFN + self.addCleanup(os_helper.unlink, testfn) + + for reason in ( + [1, 2, 3], + 'ignore this', + ): + p = self.Process(target=self._test_sys_exit, args=(reason, testfn)) + p.daemon = True + p.start() + join_process(p) + self.assertEqual(p.exitcode, 1) + + with open(testfn, encoding="utf-8") as f: + content = f.read() + self.assertEqual(content.rstrip(), str(reason)) + + os.unlink(testfn) + + cases = [ + ((True,), 1), + ((False,), 0), + ((8,), 8), + ((None,), 0), + ((), 0), + ] + + for args, expected in cases: + with self.subTest(args=args): + p = self.Process(target=sys.exit, args=args) + p.daemon = True + p.start() + join_process(p) + self.assertEqual(p.exitcode, expected) + +# +# +# + +def queue_empty(q): + if hasattr(q, 'empty'): + return q.empty() + else: + return q.qsize() == 0 + +def queue_full(q, maxsize): + if hasattr(q, 'full'): + return q.full() + else: + return q.qsize() == maxsize + + +class _TestQueue(BaseTestCase): + + + @classmethod + def _test_put(cls, queue, child_can_start, parent_can_continue): + child_can_start.wait() + for i in range(6): + queue.get() + parent_can_continue.set() + + def test_put(self): + MAXSIZE = 6 + queue = self.Queue(maxsize=MAXSIZE) + child_can_start = self.Event() + parent_can_continue = self.Event() + + proc = self.Process( + target=self._test_put, + args=(queue, child_can_start, parent_can_continue) + ) + proc.daemon = True + proc.start() + + self.assertEqual(queue_empty(queue), True) + self.assertEqual(queue_full(queue, MAXSIZE), False) + + queue.put(1) + queue.put(2, True) + queue.put(3, True, None) + queue.put(4, False) + queue.put(5, False, None) + queue.put_nowait(6) + + # the values may be in buffer but not yet in pipe so sleep a bit + time.sleep(DELTA) + + self.assertEqual(queue_empty(queue), False) + self.assertEqual(queue_full(queue, MAXSIZE), True) + + put = TimingWrapper(queue.put) + put_nowait = TimingWrapper(queue.put_nowait) + + self.assertRaises(pyqueue.Full, put, 7, False) + self.assertTimingAlmostEqual(put.elapsed, 0) + + self.assertRaises(pyqueue.Full, put, 7, False, None) + self.assertTimingAlmostEqual(put.elapsed, 0) + + self.assertRaises(pyqueue.Full, put_nowait, 7) + self.assertTimingAlmostEqual(put_nowait.elapsed, 0) + + self.assertRaises(pyqueue.Full, put, 7, True, TIMEOUT1) + self.assertTimingAlmostEqual(put.elapsed, TIMEOUT1) + + self.assertRaises(pyqueue.Full, put, 7, False, TIMEOUT2) + self.assertTimingAlmostEqual(put.elapsed, 0) + + self.assertRaises(pyqueue.Full, put, 7, True, timeout=TIMEOUT3) + self.assertTimingAlmostEqual(put.elapsed, TIMEOUT3) + + child_can_start.set() + parent_can_continue.wait() + + self.assertEqual(queue_empty(queue), True) + self.assertEqual(queue_full(queue, MAXSIZE), False) + + proc.join() + close_queue(queue) + + @classmethod + def _test_get(cls, queue, child_can_start, parent_can_continue): + child_can_start.wait() + #queue.put(1) + queue.put(2) + queue.put(3) + queue.put(4) + queue.put(5) + parent_can_continue.set() + + def test_get(self): + queue = self.Queue() + child_can_start = self.Event() + parent_can_continue = self.Event() + + proc = self.Process( + target=self._test_get, + args=(queue, child_can_start, parent_can_continue) + ) + proc.daemon = True + proc.start() + + self.assertEqual(queue_empty(queue), True) + + child_can_start.set() + parent_can_continue.wait() + + time.sleep(DELTA) + self.assertEqual(queue_empty(queue), False) + + # Hangs unexpectedly, remove for now + #self.assertEqual(queue.get(), 1) + self.assertEqual(queue.get(True, None), 2) + self.assertEqual(queue.get(True), 3) + self.assertEqual(queue.get(timeout=1), 4) + self.assertEqual(queue.get_nowait(), 5) + + self.assertEqual(queue_empty(queue), True) + + get = TimingWrapper(queue.get) + get_nowait = TimingWrapper(queue.get_nowait) + + self.assertRaises(pyqueue.Empty, get, False) + self.assertTimingAlmostEqual(get.elapsed, 0) + + self.assertRaises(pyqueue.Empty, get, False, None) + self.assertTimingAlmostEqual(get.elapsed, 0) + + self.assertRaises(pyqueue.Empty, get_nowait) + self.assertTimingAlmostEqual(get_nowait.elapsed, 0) + + self.assertRaises(pyqueue.Empty, get, True, TIMEOUT1) + self.assertTimingAlmostEqual(get.elapsed, TIMEOUT1) + + self.assertRaises(pyqueue.Empty, get, False, TIMEOUT2) + self.assertTimingAlmostEqual(get.elapsed, 0) + + self.assertRaises(pyqueue.Empty, get, timeout=TIMEOUT3) + self.assertTimingAlmostEqual(get.elapsed, TIMEOUT3) + + proc.join() + close_queue(queue) + + @classmethod + def _test_fork(cls, queue): + for i in range(10, 20): + queue.put(i) + # note that at this point the items may only be buffered, so the + # process cannot shutdown until the feeder thread has finished + # pushing items onto the pipe. + + def test_fork(self): + # Old versions of Queue would fail to create a new feeder + # thread for a forked process if the original process had its + # own feeder thread. This test checks that this no longer + # happens. + + queue = self.Queue() + + # put items on queue so that main process starts a feeder thread + for i in range(10): + queue.put(i) + + # wait to make sure thread starts before we fork a new process + time.sleep(DELTA) + + # fork process + p = self.Process(target=self._test_fork, args=(queue,)) + p.daemon = True + p.start() + + # check that all expected items are in the queue + for i in range(20): + self.assertEqual(queue.get(), i) + self.assertRaises(pyqueue.Empty, queue.get, False) + + p.join() + close_queue(queue) + + def test_qsize(self): + q = self.Queue() + try: + self.assertEqual(q.qsize(), 0) + except NotImplementedError: + self.skipTest('qsize method not implemented') + q.put(1) + self.assertEqual(q.qsize(), 1) + q.put(5) + self.assertEqual(q.qsize(), 2) + q.get() + self.assertEqual(q.qsize(), 1) + q.get() + self.assertEqual(q.qsize(), 0) + close_queue(q) + + @classmethod + def _test_task_done(cls, q): + for obj in iter(q.get, None): + time.sleep(DELTA) + q.task_done() + + def test_task_done(self): + queue = self.JoinableQueue() + + workers = [self.Process(target=self._test_task_done, args=(queue,)) + for i in range(4)] + + for p in workers: + p.daemon = True + p.start() + + for i in range(10): + queue.put(i) + + queue.join() + + for p in workers: + queue.put(None) + + for p in workers: + p.join() + close_queue(queue) + + def test_no_import_lock_contention(self): + with os_helper.temp_cwd(): + module_name = 'imported_by_an_imported_module' + with open(module_name + '.py', 'w', encoding="utf-8") as f: + f.write("""if 1: + import multiprocessing + + q = multiprocessing.Queue() + q.put('knock knock') + q.get(timeout=3) + q.close() + del q + """) + + with import_helper.DirsOnSysPath(os.getcwd()): + try: + __import__(module_name) + except pyqueue.Empty: + self.fail("Probable regression on import lock contention;" + " see Issue #22853") + + def test_timeout(self): + q = multiprocessing.Queue() + start = time.monotonic() + self.assertRaises(pyqueue.Empty, q.get, True, 0.200) + delta = time.monotonic() - start + # bpo-30317: Tolerate a delta of 100 ms because of the bad clock + # resolution on Windows (usually 15.6 ms). x86 Windows7 3.x once + # failed because the delta was only 135.8 ms. + self.assertGreaterEqual(delta, 0.100) + close_queue(q) + + def test_queue_feeder_donot_stop_onexc(self): + # bpo-30414: verify feeder handles exceptions correctly + if self.TYPE != 'processes': + self.skipTest('test not appropriate for {}'.format(self.TYPE)) + + class NotSerializable(object): + def __reduce__(self): + raise AttributeError + with test.support.captured_stderr(): + q = self.Queue() + q.put(NotSerializable()) + q.put(True) + self.assertTrue(q.get(timeout=support.SHORT_TIMEOUT)) + close_queue(q) + + with test.support.captured_stderr(): + # bpo-33078: verify that the queue size is correctly handled + # on errors. + q = self.Queue(maxsize=1) + q.put(NotSerializable()) + q.put(True) + try: + self.assertEqual(q.qsize(), 1) + except NotImplementedError: + # qsize is not available on all platform as it + # relies on sem_getvalue + pass + self.assertTrue(q.get(timeout=support.SHORT_TIMEOUT)) + # Check that the size of the queue is correct + self.assertTrue(q.empty()) + close_queue(q) + + def test_queue_feeder_on_queue_feeder_error(self): + # bpo-30006: verify feeder handles exceptions using the + # _on_queue_feeder_error hook. + if self.TYPE != 'processes': + self.skipTest('test not appropriate for {}'.format(self.TYPE)) + + class NotSerializable(object): + """Mock unserializable object""" + def __init__(self): + self.reduce_was_called = False + self.on_queue_feeder_error_was_called = False + + def __reduce__(self): + self.reduce_was_called = True + raise AttributeError + + class SafeQueue(multiprocessing.queues.Queue): + """Queue with overloaded _on_queue_feeder_error hook""" + @staticmethod + def _on_queue_feeder_error(e, obj): + if (isinstance(e, AttributeError) and + isinstance(obj, NotSerializable)): + obj.on_queue_feeder_error_was_called = True + + not_serializable_obj = NotSerializable() + # The captured_stderr reduces the noise in the test report + with test.support.captured_stderr(): + q = SafeQueue(ctx=multiprocessing.get_context()) + q.put(not_serializable_obj) + + # Verify that q is still functioning correctly + q.put(True) + self.assertTrue(q.get(timeout=support.SHORT_TIMEOUT)) + + # Assert that the serialization and the hook have been called correctly + self.assertTrue(not_serializable_obj.reduce_was_called) + self.assertTrue(not_serializable_obj.on_queue_feeder_error_was_called) + + def test_closed_queue_put_get_exceptions(self): + for q in multiprocessing.Queue(), multiprocessing.JoinableQueue(): + q.close() + with self.assertRaisesRegex(ValueError, 'is closed'): + q.put('foo') + with self.assertRaisesRegex(ValueError, 'is closed'): + q.get() +# +# +# + +class _TestLock(BaseTestCase): + + def test_lock(self): + lock = self.Lock() + self.assertEqual(lock.acquire(), True) + self.assertEqual(lock.acquire(False), False) + self.assertEqual(lock.release(), None) + self.assertRaises((ValueError, threading.ThreadError), lock.release) + + def test_rlock(self): + lock = self.RLock() + self.assertEqual(lock.acquire(), True) + self.assertEqual(lock.acquire(), True) + self.assertEqual(lock.acquire(), True) + self.assertEqual(lock.release(), None) + self.assertEqual(lock.release(), None) + self.assertEqual(lock.release(), None) + self.assertRaises((AssertionError, RuntimeError), lock.release) + + def test_lock_context(self): + with self.Lock(): + pass + + +class _TestSemaphore(BaseTestCase): + + def _test_semaphore(self, sem): + self.assertReturnsIfImplemented(2, get_value, sem) + self.assertEqual(sem.acquire(), True) + self.assertReturnsIfImplemented(1, get_value, sem) + self.assertEqual(sem.acquire(), True) + self.assertReturnsIfImplemented(0, get_value, sem) + self.assertEqual(sem.acquire(False), False) + self.assertReturnsIfImplemented(0, get_value, sem) + self.assertEqual(sem.release(), None) + self.assertReturnsIfImplemented(1, get_value, sem) + self.assertEqual(sem.release(), None) + self.assertReturnsIfImplemented(2, get_value, sem) + + def test_semaphore(self): + sem = self.Semaphore(2) + self._test_semaphore(sem) + self.assertEqual(sem.release(), None) + self.assertReturnsIfImplemented(3, get_value, sem) + self.assertEqual(sem.release(), None) + self.assertReturnsIfImplemented(4, get_value, sem) + + def test_bounded_semaphore(self): + sem = self.BoundedSemaphore(2) + self._test_semaphore(sem) + # Currently fails on OS/X + #if HAVE_GETVALUE: + # self.assertRaises(ValueError, sem.release) + # self.assertReturnsIfImplemented(2, get_value, sem) + + def test_timeout(self): + if self.TYPE != 'processes': + self.skipTest('test not appropriate for {}'.format(self.TYPE)) + + sem = self.Semaphore(0) + acquire = TimingWrapper(sem.acquire) + + self.assertEqual(acquire(False), False) + self.assertTimingAlmostEqual(acquire.elapsed, 0.0) + + self.assertEqual(acquire(False, None), False) + self.assertTimingAlmostEqual(acquire.elapsed, 0.0) + + self.assertEqual(acquire(False, TIMEOUT1), False) + self.assertTimingAlmostEqual(acquire.elapsed, 0) + + self.assertEqual(acquire(True, TIMEOUT2), False) + self.assertTimingAlmostEqual(acquire.elapsed, TIMEOUT2) + + self.assertEqual(acquire(timeout=TIMEOUT3), False) + self.assertTimingAlmostEqual(acquire.elapsed, TIMEOUT3) + + +class _TestCondition(BaseTestCase): + + @classmethod + def f(cls, cond, sleeping, woken, timeout=None): + cond.acquire() + sleeping.release() + cond.wait(timeout) + woken.release() + cond.release() + + def assertReachesEventually(self, func, value): + for i in range(10): + try: + if func() == value: + break + except NotImplementedError: + break + time.sleep(DELTA) + time.sleep(DELTA) + self.assertReturnsIfImplemented(value, func) + + def check_invariant(self, cond): + # this is only supposed to succeed when there are no sleepers + if self.TYPE == 'processes': + try: + sleepers = (cond._sleeping_count.get_value() - + cond._woken_count.get_value()) + self.assertEqual(sleepers, 0) + self.assertEqual(cond._wait_semaphore.get_value(), 0) + except NotImplementedError: + pass + + def test_notify(self): + cond = self.Condition() + sleeping = self.Semaphore(0) + woken = self.Semaphore(0) + + p = self.Process(target=self.f, args=(cond, sleeping, woken)) + p.daemon = True + p.start() + self.addCleanup(p.join) + + p = threading.Thread(target=self.f, args=(cond, sleeping, woken)) + p.daemon = True + p.start() + self.addCleanup(p.join) + + # wait for both children to start sleeping + sleeping.acquire() + sleeping.acquire() + + # check no process/thread has woken up + time.sleep(DELTA) + self.assertReturnsIfImplemented(0, get_value, woken) + + # wake up one process/thread + cond.acquire() + cond.notify() + cond.release() + + # check one process/thread has woken up + time.sleep(DELTA) + self.assertReturnsIfImplemented(1, get_value, woken) + + # wake up another + cond.acquire() + cond.notify() + cond.release() + + # check other has woken up + time.sleep(DELTA) + self.assertReturnsIfImplemented(2, get_value, woken) + + # check state is not mucked up + self.check_invariant(cond) + p.join() + + def test_notify_all(self): + cond = self.Condition() + sleeping = self.Semaphore(0) + woken = self.Semaphore(0) + + # start some threads/processes which will timeout + for i in range(3): + p = self.Process(target=self.f, + args=(cond, sleeping, woken, TIMEOUT1)) + p.daemon = True + p.start() + self.addCleanup(p.join) + + t = threading.Thread(target=self.f, + args=(cond, sleeping, woken, TIMEOUT1)) + t.daemon = True + t.start() + self.addCleanup(t.join) + + # wait for them all to sleep + for i in range(6): + sleeping.acquire() + + # check they have all timed out + for i in range(6): + woken.acquire() + self.assertReturnsIfImplemented(0, get_value, woken) + + # check state is not mucked up + self.check_invariant(cond) + + # start some more threads/processes + for i in range(3): + p = self.Process(target=self.f, args=(cond, sleeping, woken)) + p.daemon = True + p.start() + self.addCleanup(p.join) + + t = threading.Thread(target=self.f, args=(cond, sleeping, woken)) + t.daemon = True + t.start() + self.addCleanup(t.join) + + # wait for them to all sleep + for i in range(6): + sleeping.acquire() + + # check no process/thread has woken up + time.sleep(DELTA) + self.assertReturnsIfImplemented(0, get_value, woken) + + # wake them all up + cond.acquire() + cond.notify_all() + cond.release() + + # check they have all woken + self.assertReachesEventually(lambda: get_value(woken), 6) + + # check state is not mucked up + self.check_invariant(cond) + + def test_notify_n(self): + cond = self.Condition() + sleeping = self.Semaphore(0) + woken = self.Semaphore(0) + + # start some threads/processes + for i in range(3): + p = self.Process(target=self.f, args=(cond, sleeping, woken)) + p.daemon = True + p.start() + self.addCleanup(p.join) + + t = threading.Thread(target=self.f, args=(cond, sleeping, woken)) + t.daemon = True + t.start() + self.addCleanup(t.join) + + # wait for them to all sleep + for i in range(6): + sleeping.acquire() + + # check no process/thread has woken up + time.sleep(DELTA) + self.assertReturnsIfImplemented(0, get_value, woken) + + # wake some of them up + cond.acquire() + cond.notify(n=2) + cond.release() + + # check 2 have woken + self.assertReachesEventually(lambda: get_value(woken), 2) + + # wake the rest of them + cond.acquire() + cond.notify(n=4) + cond.release() + + self.assertReachesEventually(lambda: get_value(woken), 6) + + # doesn't do anything more + cond.acquire() + cond.notify(n=3) + cond.release() + + self.assertReturnsIfImplemented(6, get_value, woken) + + # check state is not mucked up + self.check_invariant(cond) + + def test_timeout(self): + cond = self.Condition() + wait = TimingWrapper(cond.wait) + cond.acquire() + res = wait(TIMEOUT1) + cond.release() + self.assertEqual(res, False) + self.assertTimingAlmostEqual(wait.elapsed, TIMEOUT1) + + @classmethod + def _test_waitfor_f(cls, cond, state): + with cond: + state.value = 0 + cond.notify() + result = cond.wait_for(lambda : state.value==4) + if not result or state.value != 4: + sys.exit(1) + + @unittest.skipUnless(HAS_SHAREDCTYPES, 'needs sharedctypes') + def test_waitfor(self): + # based on test in test/lock_tests.py + cond = self.Condition() + state = self.Value('i', -1) + + p = self.Process(target=self._test_waitfor_f, args=(cond, state)) + p.daemon = True + p.start() + + with cond: + result = cond.wait_for(lambda : state.value==0) + self.assertTrue(result) + self.assertEqual(state.value, 0) + + for i in range(4): + time.sleep(0.01) + with cond: + state.value += 1 + cond.notify() + + join_process(p) + self.assertEqual(p.exitcode, 0) + + @classmethod + def _test_waitfor_timeout_f(cls, cond, state, success, sem): + sem.release() + with cond: + expected = 0.100 + dt = time.monotonic() + result = cond.wait_for(lambda : state.value==4, timeout=expected) + dt = time.monotonic() - dt + if not result and (expected - CLOCK_RES) <= dt: + success.value = True + + @unittest.skipUnless(HAS_SHAREDCTYPES, 'needs sharedctypes') + def test_waitfor_timeout(self): + # based on test in test/lock_tests.py + cond = self.Condition() + state = self.Value('i', 0) + success = self.Value('i', False) + sem = self.Semaphore(0) + + p = self.Process(target=self._test_waitfor_timeout_f, + args=(cond, state, success, sem)) + p.daemon = True + p.start() + self.assertTrue(sem.acquire(timeout=support.LONG_TIMEOUT)) + + # Only increment 3 times, so state == 4 is never reached. + for i in range(3): + time.sleep(0.010) + with cond: + state.value += 1 + cond.notify() + + join_process(p) + self.assertTrue(success.value) + + @classmethod + def _test_wait_result(cls, c, pid): + with c: + c.notify() + time.sleep(1) + if pid is not None: + os.kill(pid, signal.SIGINT) + + def test_wait_result(self): + if isinstance(self, ProcessesMixin) and sys.platform != 'win32': + pid = os.getpid() + else: + pid = None + + c = self.Condition() + with c: + self.assertFalse(c.wait(0)) + self.assertFalse(c.wait(0.1)) + + p = self.Process(target=self._test_wait_result, args=(c, pid)) + p.start() + + self.assertTrue(c.wait(60)) + if pid is not None: + self.assertRaises(KeyboardInterrupt, c.wait, 60) + + p.join() + + +class _TestEvent(BaseTestCase): + + @classmethod + def _test_event(cls, event): + time.sleep(TIMEOUT2) + event.set() + + def test_event(self): + event = self.Event() + wait = TimingWrapper(event.wait) + + # Removed temporarily, due to API shear, this does not + # work with threading._Event objects. is_set == isSet + self.assertEqual(event.is_set(), False) + + # Removed, threading.Event.wait() will return the value of the __flag + # instead of None. API Shear with the semaphore backed mp.Event + self.assertEqual(wait(0.0), False) + self.assertTimingAlmostEqual(wait.elapsed, 0.0) + self.assertEqual(wait(TIMEOUT1), False) + self.assertTimingAlmostEqual(wait.elapsed, TIMEOUT1) + + event.set() + + # See note above on the API differences + self.assertEqual(event.is_set(), True) + self.assertEqual(wait(), True) + self.assertTimingAlmostEqual(wait.elapsed, 0.0) + self.assertEqual(wait(TIMEOUT1), True) + self.assertTimingAlmostEqual(wait.elapsed, 0.0) + # self.assertEqual(event.is_set(), True) + + event.clear() + + #self.assertEqual(event.is_set(), False) + + p = self.Process(target=self._test_event, args=(event,)) + p.daemon = True + p.start() + self.assertEqual(wait(), True) + p.join() + + def test_repr(self) -> None: + event = self.Event() + if self.TYPE == 'processes': + self.assertRegex(repr(event), r"") + event.set() + self.assertRegex(repr(event), r"") + event.clear() + self.assertRegex(repr(event), r"") + elif self.TYPE == 'manager': + self.assertRegex(repr(event), r" 256 (issue #11657) + if self.TYPE != 'processes': + self.skipTest("only makes sense with processes") + conn, child_conn = self.Pipe(duplex=True) + + p = self.Process(target=self._writefd, args=(child_conn, b"bar", True)) + p.daemon = True + p.start() + self.addCleanup(os_helper.unlink, os_helper.TESTFN) + with open(os_helper.TESTFN, "wb") as f: + fd = f.fileno() + for newfd in range(256, MAXFD): + if not self._is_fd_assigned(newfd): + break + else: + self.fail("could not find an unassigned large file descriptor") + os.dup2(fd, newfd) + try: + reduction.send_handle(conn, newfd, p.pid) + finally: + os.close(newfd) + p.join() + with open(os_helper.TESTFN, "rb") as f: + self.assertEqual(f.read(), b"bar") + + @classmethod + def _send_data_without_fd(self, conn): + os.write(conn.fileno(), b"\0") + + @unittest.skipUnless(HAS_REDUCTION, "test needs multiprocessing.reduction") + @unittest.skipIf(sys.platform == "win32", "doesn't make sense on Windows") + def test_missing_fd_transfer(self): + # Check that exception is raised when received data is not + # accompanied by a file descriptor in ancillary data. + if self.TYPE != 'processes': + self.skipTest("only makes sense with processes") + conn, child_conn = self.Pipe(duplex=True) + + p = self.Process(target=self._send_data_without_fd, args=(child_conn,)) + p.daemon = True + p.start() + self.assertRaises(RuntimeError, reduction.recv_handle, conn) + p.join() + + def test_context(self): + a, b = self.Pipe() + + with a, b: + a.send(1729) + self.assertEqual(b.recv(), 1729) + if self.TYPE == 'processes': + self.assertFalse(a.closed) + self.assertFalse(b.closed) + + if self.TYPE == 'processes': + self.assertTrue(a.closed) + self.assertTrue(b.closed) + self.assertRaises(OSError, a.recv) + self.assertRaises(OSError, b.recv) + +class _TestListener(BaseTestCase): + + ALLOWED_TYPES = ('processes',) + + def test_multiple_bind(self): + for family in self.connection.families: + l = self.connection.Listener(family=family) + self.addCleanup(l.close) + self.assertRaises(OSError, self.connection.Listener, + l.address, family) + + def test_context(self): + with self.connection.Listener() as l: + with self.connection.Client(l.address) as c: + with l.accept() as d: + c.send(1729) + self.assertEqual(d.recv(), 1729) + + if self.TYPE == 'processes': + self.assertRaises(OSError, l.accept) + + def test_empty_authkey(self): + # bpo-43952: allow empty bytes as authkey + def handler(*args): + raise RuntimeError('Connection took too long...') + + def run(addr, authkey): + client = self.connection.Client(addr, authkey=authkey) + client.send(1729) + + key = b'' + + with self.connection.Listener(authkey=key) as listener: + thread = threading.Thread(target=run, args=(listener.address, key)) + thread.start() + try: + with listener.accept() as d: + self.assertEqual(d.recv(), 1729) + finally: + thread.join() + + if self.TYPE == 'processes': + with self.assertRaises(OSError): + listener.accept() + + @unittest.skipUnless(util.abstract_sockets_supported, + "test needs abstract socket support") + def test_abstract_socket(self): + with self.connection.Listener("\0something") as listener: + with self.connection.Client(listener.address) as client: + with listener.accept() as d: + client.send(1729) + self.assertEqual(d.recv(), 1729) + + if self.TYPE == 'processes': + self.assertRaises(OSError, listener.accept) + + +class _TestListenerClient(BaseTestCase): + + ALLOWED_TYPES = ('processes', 'threads') + + @classmethod + def _test(cls, address): + conn = cls.connection.Client(address) + conn.send('hello') + conn.close() + + def test_listener_client(self): + for family in self.connection.families: + l = self.connection.Listener(family=family) + p = self.Process(target=self._test, args=(l.address,)) + p.daemon = True + p.start() + conn = l.accept() + self.assertEqual(conn.recv(), 'hello') + p.join() + l.close() + + def test_issue14725(self): + l = self.connection.Listener() + p = self.Process(target=self._test, args=(l.address,)) + p.daemon = True + p.start() + time.sleep(1) + # On Windows the client process should by now have connected, + # written data and closed the pipe handle by now. This causes + # ConnectNamdedPipe() to fail with ERROR_NO_DATA. See Issue + # 14725. + conn = l.accept() + self.assertEqual(conn.recv(), 'hello') + conn.close() + p.join() + l.close() + + def test_issue16955(self): + for fam in self.connection.families: + l = self.connection.Listener(family=fam) + c = self.connection.Client(l.address) + a = l.accept() + a.send_bytes(b"hello") + self.assertTrue(c.poll(1)) + a.close() + c.close() + l.close() + +class _TestPoll(BaseTestCase): + + ALLOWED_TYPES = ('processes', 'threads') + + def test_empty_string(self): + a, b = self.Pipe() + self.assertEqual(a.poll(), False) + b.send_bytes(b'') + self.assertEqual(a.poll(), True) + self.assertEqual(a.poll(), True) + + @classmethod + def _child_strings(cls, conn, strings): + for s in strings: + time.sleep(0.1) + conn.send_bytes(s) + conn.close() + + def test_strings(self): + strings = (b'hello', b'', b'a', b'b', b'', b'bye', b'', b'lop') + a, b = self.Pipe() + p = self.Process(target=self._child_strings, args=(b, strings)) + p.start() + + for s in strings: + for i in range(200): + if a.poll(0.01): + break + x = a.recv_bytes() + self.assertEqual(s, x) + + p.join() + + @classmethod + def _child_boundaries(cls, r): + # Polling may "pull" a message in to the child process, but we + # don't want it to pull only part of a message, as that would + # corrupt the pipe for any other processes which might later + # read from it. + r.poll(5) + + def test_boundaries(self): + r, w = self.Pipe(False) + p = self.Process(target=self._child_boundaries, args=(r,)) + p.start() + time.sleep(2) + L = [b"first", b"second"] + for obj in L: + w.send_bytes(obj) + w.close() + p.join() + self.assertIn(r.recv_bytes(), L) + + @classmethod + def _child_dont_merge(cls, b): + b.send_bytes(b'a') + b.send_bytes(b'b') + b.send_bytes(b'cd') + + def test_dont_merge(self): + a, b = self.Pipe() + self.assertEqual(a.poll(0.0), False) + self.assertEqual(a.poll(0.1), False) + + p = self.Process(target=self._child_dont_merge, args=(b,)) + p.start() + + self.assertEqual(a.recv_bytes(), b'a') + self.assertEqual(a.poll(1.0), True) + self.assertEqual(a.poll(1.0), True) + self.assertEqual(a.recv_bytes(), b'b') + self.assertEqual(a.poll(1.0), True) + self.assertEqual(a.poll(1.0), True) + self.assertEqual(a.poll(0.0), True) + self.assertEqual(a.recv_bytes(), b'cd') + + p.join() + +# +# Test of sending connection and socket objects between processes +# + +@unittest.skipUnless(HAS_REDUCTION, "test needs multiprocessing.reduction") +@hashlib_helper.requires_hashdigest('sha256') +class _TestPicklingConnections(BaseTestCase): + + ALLOWED_TYPES = ('processes',) + + @classmethod + def tearDownClass(cls): + from multiprocessing import resource_sharer + resource_sharer.stop(timeout=support.LONG_TIMEOUT) + + @classmethod + def _listener(cls, conn, families): + for fam in families: + l = cls.connection.Listener(family=fam) + conn.send(l.address) + new_conn = l.accept() + conn.send(new_conn) + new_conn.close() + l.close() + + l = socket.create_server((socket_helper.HOST, 0)) + conn.send(l.getsockname()) + new_conn, addr = l.accept() + conn.send(new_conn) + new_conn.close() + l.close() + + conn.recv() + + @classmethod + def _remote(cls, conn): + for (address, msg) in iter(conn.recv, None): + client = cls.connection.Client(address) + client.send(msg.upper()) + client.close() + + address, msg = conn.recv() + client = socket.socket() + client.connect(address) + client.sendall(msg.upper()) + client.close() + + conn.close() + + def test_pickling(self): + families = self.connection.families + + lconn, lconn0 = self.Pipe() + lp = self.Process(target=self._listener, args=(lconn0, families)) + lp.daemon = True + lp.start() + lconn0.close() + + rconn, rconn0 = self.Pipe() + rp = self.Process(target=self._remote, args=(rconn0,)) + rp.daemon = True + rp.start() + rconn0.close() + + for fam in families: + msg = ('This connection uses family %s' % fam).encode('ascii') + address = lconn.recv() + rconn.send((address, msg)) + new_conn = lconn.recv() + self.assertEqual(new_conn.recv(), msg.upper()) + + rconn.send(None) + + msg = latin('This connection uses a normal socket') + address = lconn.recv() + rconn.send((address, msg)) + new_conn = lconn.recv() + buf = [] + while True: + s = new_conn.recv(100) + if not s: + break + buf.append(s) + buf = b''.join(buf) + self.assertEqual(buf, msg.upper()) + new_conn.close() + + lconn.send(None) + + rconn.close() + lconn.close() + + lp.join() + rp.join() + + @classmethod + def child_access(cls, conn): + w = conn.recv() + w.send('all is well') + w.close() + + r = conn.recv() + msg = r.recv() + conn.send(msg*2) + + conn.close() + + def test_access(self): + # On Windows, if we do not specify a destination pid when + # using DupHandle then we need to be careful to use the + # correct access flags for DuplicateHandle(), or else + # DupHandle.detach() will raise PermissionError. For example, + # for a read only pipe handle we should use + # access=FILE_GENERIC_READ. (Unfortunately + # DUPLICATE_SAME_ACCESS does not work.) + conn, child_conn = self.Pipe() + p = self.Process(target=self.child_access, args=(child_conn,)) + p.daemon = True + p.start() + child_conn.close() + + r, w = self.Pipe(duplex=False) + conn.send(w) + w.close() + self.assertEqual(r.recv(), 'all is well') + r.close() + + r, w = self.Pipe(duplex=False) + conn.send(r) + r.close() + w.send('foobar') + w.close() + self.assertEqual(conn.recv(), 'foobar'*2) + + p.join() + +# +# +# + +class _TestHeap(BaseTestCase): + + ALLOWED_TYPES = ('processes',) + + def setUp(self): + super().setUp() + # Make pristine heap for these tests + self.old_heap = multiprocessing.heap.BufferWrapper._heap + multiprocessing.heap.BufferWrapper._heap = multiprocessing.heap.Heap() + + def tearDown(self): + multiprocessing.heap.BufferWrapper._heap = self.old_heap + super().tearDown() + + def test_heap(self): + iterations = 5000 + maxblocks = 50 + blocks = [] + + # get the heap object + heap = multiprocessing.heap.BufferWrapper._heap + heap._DISCARD_FREE_SPACE_LARGER_THAN = 0 + + # create and destroy lots of blocks of different sizes + for i in range(iterations): + size = int(random.lognormvariate(0, 1) * 1000) + b = multiprocessing.heap.BufferWrapper(size) + blocks.append(b) + if len(blocks) > maxblocks: + i = random.randrange(maxblocks) + del blocks[i] + del b + + # verify the state of the heap + with heap._lock: + all = [] + free = 0 + occupied = 0 + for L in list(heap._len_to_seq.values()): + # count all free blocks in arenas + for arena, start, stop in L: + all.append((heap._arenas.index(arena), start, stop, + stop-start, 'free')) + free += (stop-start) + for arena, arena_blocks in heap._allocated_blocks.items(): + # count all allocated blocks in arenas + for start, stop in arena_blocks: + all.append((heap._arenas.index(arena), start, stop, + stop-start, 'occupied')) + occupied += (stop-start) + + self.assertEqual(free + occupied, + sum(arena.size for arena in heap._arenas)) + + all.sort() + + for i in range(len(all)-1): + (arena, start, stop) = all[i][:3] + (narena, nstart, nstop) = all[i+1][:3] + if arena != narena: + # Two different arenas + self.assertEqual(stop, heap._arenas[arena].size) # last block + self.assertEqual(nstart, 0) # first block + else: + # Same arena: two adjacent blocks + self.assertEqual(stop, nstart) + + # test free'ing all blocks + random.shuffle(blocks) + while blocks: + blocks.pop() + + self.assertEqual(heap._n_frees, heap._n_mallocs) + self.assertEqual(len(heap._pending_free_blocks), 0) + self.assertEqual(len(heap._arenas), 0) + self.assertEqual(len(heap._allocated_blocks), 0, heap._allocated_blocks) + self.assertEqual(len(heap._len_to_seq), 0) + + def test_free_from_gc(self): + # Check that freeing of blocks by the garbage collector doesn't deadlock + # (issue #12352). + # Make sure the GC is enabled, and set lower collection thresholds to + # make collections more frequent (and increase the probability of + # deadlock). + if not gc.isenabled(): + gc.enable() + self.addCleanup(gc.disable) + thresholds = gc.get_threshold() + self.addCleanup(gc.set_threshold, *thresholds) + gc.set_threshold(10) + + # perform numerous block allocations, with cyclic references to make + # sure objects are collected asynchronously by the gc + for i in range(5000): + a = multiprocessing.heap.BufferWrapper(1) + b = multiprocessing.heap.BufferWrapper(1) + # circular references + a.buddy = b + b.buddy = a + +# +# +# + +class _Foo(Structure): + _fields_ = [ + ('x', c_int), + ('y', c_double), + ('z', c_longlong,) + ] + +class _TestSharedCTypes(BaseTestCase): + + ALLOWED_TYPES = ('processes',) + + def setUp(self): + if not HAS_SHAREDCTYPES: + self.skipTest("requires multiprocessing.sharedctypes") + + @classmethod + def _double(cls, x, y, z, foo, arr, string): + x.value *= 2 + y.value *= 2 + z.value *= 2 + foo.x *= 2 + foo.y *= 2 + string.value *= 2 + for i in range(len(arr)): + arr[i] *= 2 + + def test_sharedctypes(self, lock=False): + x = Value('i', 7, lock=lock) + y = Value(c_double, 1.0/3.0, lock=lock) + z = Value(c_longlong, 2 ** 33, lock=lock) + foo = Value(_Foo, 3, 2, lock=lock) + arr = self.Array('d', list(range(10)), lock=lock) + string = self.Array('c', 20, lock=lock) + string.value = latin('hello') + + p = self.Process(target=self._double, args=(x, y, z, foo, arr, string)) + p.daemon = True + p.start() + p.join() + + self.assertEqual(x.value, 14) + self.assertAlmostEqual(y.value, 2.0/3.0) + self.assertEqual(z.value, 2 ** 34) + self.assertEqual(foo.x, 6) + self.assertAlmostEqual(foo.y, 4.0) + for i in range(10): + self.assertAlmostEqual(arr[i], i*2) + self.assertEqual(string.value, latin('hellohello')) + + def test_synchronize(self): + self.test_sharedctypes(lock=True) + + def test_copy(self): + foo = _Foo(2, 5.0, 2 ** 33) + bar = copy(foo) + foo.x = 0 + foo.y = 0 + foo.z = 0 + self.assertEqual(bar.x, 2) + self.assertAlmostEqual(bar.y, 5.0) + self.assertEqual(bar.z, 2 ** 33) + + +@unittest.skipUnless(HAS_SHMEM, "requires multiprocessing.shared_memory") +@hashlib_helper.requires_hashdigest('sha256') +class _TestSharedMemory(BaseTestCase): + + ALLOWED_TYPES = ('processes',) + + @staticmethod + def _attach_existing_shmem_then_write(shmem_name_or_obj, binary_data): + if isinstance(shmem_name_or_obj, str): + local_sms = shared_memory.SharedMemory(shmem_name_or_obj) + else: + local_sms = shmem_name_or_obj + local_sms.buf[:len(binary_data)] = binary_data + local_sms.close() + + def _new_shm_name(self, prefix): + # Add a PID to the name of a POSIX shared memory object to allow + # running multiprocessing tests (test_multiprocessing_fork, + # test_multiprocessing_spawn, etc) in parallel. + return prefix + str(os.getpid()) + + def test_shared_memory_name_with_embedded_null(self): + name_tsmb = self._new_shm_name('test01_null') + sms = shared_memory.SharedMemory(name_tsmb, create=True, size=512) + self.addCleanup(sms.unlink) + with self.assertRaises(ValueError): + shared_memory.SharedMemory(name_tsmb + '\0a', create=False, size=512) + if shared_memory._USE_POSIX: + orig_name = sms._name + try: + sms._name = orig_name + '\0a' + with self.assertRaises(ValueError): + sms.unlink() + finally: + sms._name = orig_name + + def test_shared_memory_basics(self): + name_tsmb = self._new_shm_name('test01_tsmb') + sms = shared_memory.SharedMemory(name_tsmb, create=True, size=512) + self.addCleanup(sms.unlink) + + # Verify attributes are readable. + self.assertEqual(sms.name, name_tsmb) + self.assertGreaterEqual(sms.size, 512) + self.assertGreaterEqual(len(sms.buf), sms.size) + + # Verify __repr__ + self.assertIn(sms.name, str(sms)) + self.assertIn(str(sms.size), str(sms)) + + # Modify contents of shared memory segment through memoryview. + sms.buf[0] = 42 + self.assertEqual(sms.buf[0], 42) + + # Attach to existing shared memory segment. + also_sms = shared_memory.SharedMemory(name_tsmb) + self.assertEqual(also_sms.buf[0], 42) + also_sms.close() + + # Attach to existing shared memory segment but specify a new size. + same_sms = shared_memory.SharedMemory(name_tsmb, size=20*sms.size) + self.assertLess(same_sms.size, 20*sms.size) # Size was ignored. + same_sms.close() + + # Creating Shared Memory Segment with -ve size + with self.assertRaises(ValueError): + shared_memory.SharedMemory(create=True, size=-2) + + # Attaching Shared Memory Segment without a name + with self.assertRaises(ValueError): + shared_memory.SharedMemory(create=False) + + # Test if shared memory segment is created properly, + # when _make_filename returns an existing shared memory segment name + with unittest.mock.patch( + 'multiprocessing.shared_memory._make_filename') as mock_make_filename: + + NAME_PREFIX = shared_memory._SHM_NAME_PREFIX + names = [self._new_shm_name('test01_fn'), self._new_shm_name('test02_fn')] + # Prepend NAME_PREFIX which can be '/psm_' or 'wnsm_', necessary + # because some POSIX compliant systems require name to start with / + names = [NAME_PREFIX + name for name in names] + + mock_make_filename.side_effect = names + shm1 = shared_memory.SharedMemory(create=True, size=1) + self.addCleanup(shm1.unlink) + self.assertEqual(shm1._name, names[0]) + + mock_make_filename.side_effect = names + shm2 = shared_memory.SharedMemory(create=True, size=1) + self.addCleanup(shm2.unlink) + self.assertEqual(shm2._name, names[1]) + + if shared_memory._USE_POSIX: + # Posix Shared Memory can only be unlinked once. Here we + # test an implementation detail that is not observed across + # all supported platforms (since WindowsNamedSharedMemory + # manages unlinking on its own and unlink() does nothing). + # True release of shared memory segment does not necessarily + # happen until process exits, depending on the OS platform. + name_dblunlink = self._new_shm_name('test01_dblunlink') + sms_uno = shared_memory.SharedMemory( + name_dblunlink, + create=True, + size=5000 + ) + with self.assertRaises(FileNotFoundError): + try: + self.assertGreaterEqual(sms_uno.size, 5000) + + sms_duo = shared_memory.SharedMemory(name_dblunlink) + sms_duo.unlink() # First shm_unlink() call. + sms_duo.close() + sms_uno.close() + + finally: + sms_uno.unlink() # A second shm_unlink() call is bad. + + with self.assertRaises(FileExistsError): + # Attempting to create a new shared memory segment with a + # name that is already in use triggers an exception. + there_can_only_be_one_sms = shared_memory.SharedMemory( + name_tsmb, + create=True, + size=512 + ) + + if shared_memory._USE_POSIX: + # Requesting creation of a shared memory segment with the option + # to attach to an existing segment, if that name is currently in + # use, should not trigger an exception. + # Note: Using a smaller size could possibly cause truncation of + # the existing segment but is OS platform dependent. In the + # case of MacOS/darwin, requesting a smaller size is disallowed. + class OptionalAttachSharedMemory(shared_memory.SharedMemory): + _flags = os.O_CREAT | os.O_RDWR + ok_if_exists_sms = OptionalAttachSharedMemory(name_tsmb) + self.assertEqual(ok_if_exists_sms.size, sms.size) + ok_if_exists_sms.close() + + # Attempting to attach to an existing shared memory segment when + # no segment exists with the supplied name triggers an exception. + with self.assertRaises(FileNotFoundError): + nonexisting_sms = shared_memory.SharedMemory('test01_notthere') + nonexisting_sms.unlink() # Error should occur on prior line. + + sms.close() + + def test_shared_memory_recreate(self): + # Test if shared memory segment is created properly, + # when _make_filename returns an existing shared memory segment name + with unittest.mock.patch( + 'multiprocessing.shared_memory._make_filename') as mock_make_filename: + + NAME_PREFIX = shared_memory._SHM_NAME_PREFIX + names = [self._new_shm_name('test03_fn'), self._new_shm_name('test04_fn')] + # Prepend NAME_PREFIX which can be '/psm_' or 'wnsm_', necessary + # because some POSIX compliant systems require name to start with / + names = [NAME_PREFIX + name for name in names] + + mock_make_filename.side_effect = names + shm1 = shared_memory.SharedMemory(create=True, size=1) + self.addCleanup(shm1.unlink) + self.assertEqual(shm1._name, names[0]) + + mock_make_filename.side_effect = names + shm2 = shared_memory.SharedMemory(create=True, size=1) + self.addCleanup(shm2.unlink) + self.assertEqual(shm2._name, names[1]) + + def test_invalid_shared_memory_creation(self): + # Test creating a shared memory segment with negative size + with self.assertRaises(ValueError): + sms_invalid = shared_memory.SharedMemory(create=True, size=-1) + + # Test creating a shared memory segment with size 0 + with self.assertRaises(ValueError): + sms_invalid = shared_memory.SharedMemory(create=True, size=0) + + # Test creating a shared memory segment without size argument + with self.assertRaises(ValueError): + sms_invalid = shared_memory.SharedMemory(create=True) + + def test_shared_memory_pickle_unpickle(self): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + sms = shared_memory.SharedMemory(create=True, size=512) + self.addCleanup(sms.unlink) + sms.buf[0:6] = b'pickle' + + # Test pickling + pickled_sms = pickle.dumps(sms, protocol=proto) + + # Test unpickling + sms2 = pickle.loads(pickled_sms) + self.assertIsInstance(sms2, shared_memory.SharedMemory) + self.assertEqual(sms.name, sms2.name) + self.assertEqual(bytes(sms.buf[0:6]), b'pickle') + self.assertEqual(bytes(sms2.buf[0:6]), b'pickle') + + # Test that unpickled version is still the same SharedMemory + sms.buf[0:6] = b'newval' + self.assertEqual(bytes(sms.buf[0:6]), b'newval') + self.assertEqual(bytes(sms2.buf[0:6]), b'newval') + + sms2.buf[0:6] = b'oldval' + self.assertEqual(bytes(sms.buf[0:6]), b'oldval') + self.assertEqual(bytes(sms2.buf[0:6]), b'oldval') + + def test_shared_memory_pickle_unpickle_dead_object(self): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + sms = shared_memory.SharedMemory(create=True, size=512) + sms.buf[0:6] = b'pickle' + pickled_sms = pickle.dumps(sms, protocol=proto) + + # Now, we are going to kill the original object. + # So, unpickled one won't be able to attach to it. + sms.close() + sms.unlink() + + with self.assertRaises(FileNotFoundError): + pickle.loads(pickled_sms) + + def test_shared_memory_across_processes(self): + # bpo-40135: don't define shared memory block's name in case of + # the failure when we run multiprocessing tests in parallel. + sms = shared_memory.SharedMemory(create=True, size=512) + self.addCleanup(sms.unlink) + + # Verify remote attachment to existing block by name is working. + p = self.Process( + target=self._attach_existing_shmem_then_write, + args=(sms.name, b'howdy') + ) + p.daemon = True + p.start() + p.join() + self.assertEqual(bytes(sms.buf[:5]), b'howdy') + + # Verify pickling of SharedMemory instance also works. + p = self.Process( + target=self._attach_existing_shmem_then_write, + args=(sms, b'HELLO') + ) + p.daemon = True + p.start() + p.join() + self.assertEqual(bytes(sms.buf[:5]), b'HELLO') + + sms.close() + + @unittest.skipIf(os.name != "posix", "not feasible in non-posix platforms") + def test_shared_memory_SharedMemoryServer_ignores_sigint(self): + # bpo-36368: protect SharedMemoryManager server process from + # KeyboardInterrupt signals. + smm = multiprocessing.managers.SharedMemoryManager() + smm.start() + + # make sure the manager works properly at the beginning + sl = smm.ShareableList(range(10)) + + # the manager's server should ignore KeyboardInterrupt signals, and + # maintain its connection with the current process, and success when + # asked to deliver memory segments. + os.kill(smm._process.pid, signal.SIGINT) + + sl2 = smm.ShareableList(range(10)) + + # test that the custom signal handler registered in the Manager does + # not affect signal handling in the parent process. + with self.assertRaises(KeyboardInterrupt): + os.kill(os.getpid(), signal.SIGINT) + + smm.shutdown() + + @unittest.skipIf(os.name != "posix", "resource_tracker is posix only") + def test_shared_memory_SharedMemoryManager_reuses_resource_tracker(self): + # bpo-36867: test that a SharedMemoryManager uses the + # same resource_tracker process as its parent. + cmd = '''if 1: + from multiprocessing.managers import SharedMemoryManager + + + smm = SharedMemoryManager() + smm.start() + sl = smm.ShareableList(range(10)) + smm.shutdown() + ''' + rc, out, err = test.support.script_helper.assert_python_ok('-c', cmd) + + # Before bpo-36867 was fixed, a SharedMemoryManager not using the same + # resource_tracker process as its parent would make the parent's + # tracker complain about sl being leaked even though smm.shutdown() + # properly released sl. + self.assertFalse(err) + + def test_shared_memory_SharedMemoryManager_basics(self): + smm1 = multiprocessing.managers.SharedMemoryManager() + with self.assertRaises(ValueError): + smm1.SharedMemory(size=9) # Fails if SharedMemoryServer not started + smm1.start() + lol = [ smm1.ShareableList(range(i)) for i in range(5, 10) ] + lom = [ smm1.SharedMemory(size=j) for j in range(32, 128, 16) ] + doppleganger_list0 = shared_memory.ShareableList(name=lol[0].shm.name) + self.assertEqual(len(doppleganger_list0), 5) + doppleganger_shm0 = shared_memory.SharedMemory(name=lom[0].name) + self.assertGreaterEqual(len(doppleganger_shm0.buf), 32) + held_name = lom[0].name + smm1.shutdown() + if sys.platform != "win32": + # Calls to unlink() have no effect on Windows platform; shared + # memory will only be released once final process exits. + with self.assertRaises(FileNotFoundError): + # No longer there to be attached to again. + absent_shm = shared_memory.SharedMemory(name=held_name) + + with multiprocessing.managers.SharedMemoryManager() as smm2: + sl = smm2.ShareableList("howdy") + shm = smm2.SharedMemory(size=128) + held_name = sl.shm.name + if sys.platform != "win32": + with self.assertRaises(FileNotFoundError): + # No longer there to be attached to again. + absent_sl = shared_memory.ShareableList(name=held_name) + + + def test_shared_memory_ShareableList_basics(self): + sl = shared_memory.ShareableList( + ['howdy', b'HoWdY', -273.154, 100, None, True, 42] + ) + self.addCleanup(sl.shm.unlink) + + # Verify __repr__ + self.assertIn(sl.shm.name, str(sl)) + self.assertIn(str(list(sl)), str(sl)) + + # Index Out of Range (get) + with self.assertRaises(IndexError): + sl[7] + + # Index Out of Range (set) + with self.assertRaises(IndexError): + sl[7] = 2 + + # Assign value without format change (str -> str) + current_format = sl._get_packing_format(0) + sl[0] = 'howdy' + self.assertEqual(current_format, sl._get_packing_format(0)) + + # Verify attributes are readable. + self.assertEqual(sl.format, '8s8sdqxxxxxx?xxxxxxxx?q') + + # Exercise len(). + self.assertEqual(len(sl), 7) + + # Exercise index(). + with warnings.catch_warnings(): + # Suppress BytesWarning when comparing against b'HoWdY'. + warnings.simplefilter('ignore') + with self.assertRaises(ValueError): + sl.index('100') + self.assertEqual(sl.index(100), 3) + + # Exercise retrieving individual values. + self.assertEqual(sl[0], 'howdy') + self.assertEqual(sl[-2], True) + + # Exercise iterability. + self.assertEqual( + tuple(sl), + ('howdy', b'HoWdY', -273.154, 100, None, True, 42) + ) + + # Exercise modifying individual values. + sl[3] = 42 + self.assertEqual(sl[3], 42) + sl[4] = 'some' # Change type at a given position. + self.assertEqual(sl[4], 'some') + self.assertEqual(sl.format, '8s8sdq8sxxxxxxx?q') + with self.assertRaisesRegex(ValueError, + "exceeds available storage"): + sl[4] = 'far too many' + self.assertEqual(sl[4], 'some') + sl[0] = 'encodés' # Exactly 8 bytes of UTF-8 data + self.assertEqual(sl[0], 'encodés') + self.assertEqual(sl[1], b'HoWdY') # no spillage + with self.assertRaisesRegex(ValueError, + "exceeds available storage"): + sl[0] = 'encodées' # Exactly 9 bytes of UTF-8 data + self.assertEqual(sl[1], b'HoWdY') + with self.assertRaisesRegex(ValueError, + "exceeds available storage"): + sl[1] = b'123456789' + self.assertEqual(sl[1], b'HoWdY') + + # Exercise count(). + with warnings.catch_warnings(): + # Suppress BytesWarning when comparing against b'HoWdY'. + warnings.simplefilter('ignore') + self.assertEqual(sl.count(42), 2) + self.assertEqual(sl.count(b'HoWdY'), 1) + self.assertEqual(sl.count(b'adios'), 0) + + # Exercise creating a duplicate. + name_duplicate = self._new_shm_name('test03_duplicate') + sl_copy = shared_memory.ShareableList(sl, name=name_duplicate) + try: + self.assertNotEqual(sl.shm.name, sl_copy.shm.name) + self.assertEqual(name_duplicate, sl_copy.shm.name) + self.assertEqual(list(sl), list(sl_copy)) + self.assertEqual(sl.format, sl_copy.format) + sl_copy[-1] = 77 + self.assertEqual(sl_copy[-1], 77) + self.assertNotEqual(sl[-1], 77) + sl_copy.shm.close() + finally: + sl_copy.shm.unlink() + + # Obtain a second handle on the same ShareableList. + sl_tethered = shared_memory.ShareableList(name=sl.shm.name) + self.assertEqual(sl.shm.name, sl_tethered.shm.name) + sl_tethered[-1] = 880 + self.assertEqual(sl[-1], 880) + sl_tethered.shm.close() + + sl.shm.close() + + # Exercise creating an empty ShareableList. + empty_sl = shared_memory.ShareableList() + try: + self.assertEqual(len(empty_sl), 0) + self.assertEqual(empty_sl.format, '') + self.assertEqual(empty_sl.count('any'), 0) + with self.assertRaises(ValueError): + empty_sl.index(None) + empty_sl.shm.close() + finally: + empty_sl.shm.unlink() + + def test_shared_memory_ShareableList_pickling(self): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + sl = shared_memory.ShareableList(range(10)) + self.addCleanup(sl.shm.unlink) + + serialized_sl = pickle.dumps(sl, protocol=proto) + deserialized_sl = pickle.loads(serialized_sl) + self.assertIsInstance( + deserialized_sl, shared_memory.ShareableList) + self.assertEqual(deserialized_sl[-1], 9) + self.assertIsNot(sl, deserialized_sl) + + deserialized_sl[4] = "changed" + self.assertEqual(sl[4], "changed") + sl[3] = "newvalue" + self.assertEqual(deserialized_sl[3], "newvalue") + + larger_sl = shared_memory.ShareableList(range(400)) + self.addCleanup(larger_sl.shm.unlink) + serialized_larger_sl = pickle.dumps(larger_sl, protocol=proto) + self.assertEqual(len(serialized_sl), len(serialized_larger_sl)) + larger_sl.shm.close() + + deserialized_sl.shm.close() + sl.shm.close() + + def test_shared_memory_ShareableList_pickling_dead_object(self): + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto): + sl = shared_memory.ShareableList(range(10)) + serialized_sl = pickle.dumps(sl, protocol=proto) + + # Now, we are going to kill the original object. + # So, unpickled one won't be able to attach to it. + sl.shm.close() + sl.shm.unlink() + + with self.assertRaises(FileNotFoundError): + pickle.loads(serialized_sl) + + def test_shared_memory_cleaned_after_process_termination(self): + cmd = '''if 1: + import os, time, sys + from multiprocessing import shared_memory + + # Create a shared_memory segment, and send the segment name + sm = shared_memory.SharedMemory(create=True, size=10) + sys.stdout.write(sm.name + '\\n') + sys.stdout.flush() + time.sleep(100) + ''' + with subprocess.Popen([sys.executable, '-E', '-c', cmd], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) as p: + name = p.stdout.readline().strip().decode() + + # killing abruptly processes holding reference to a shared memory + # segment should not leak the given memory segment. + p.terminate() + p.wait() + + err_msg = ("A SharedMemory segment was leaked after " + "a process was abruptly terminated") + for _ in support.sleeping_retry(support.LONG_TIMEOUT, err_msg): + try: + smm = shared_memory.SharedMemory(name, create=False) + except FileNotFoundError: + break + + if os.name == 'posix': + # Without this line it was raising warnings like: + # UserWarning: resource_tracker: + # There appear to be 1 leaked shared_memory + # objects to clean up at shutdown + # See: https://bugs.python.org/issue45209 + resource_tracker.unregister(f"/{name}", "shared_memory") + + # A warning was emitted by the subprocess' own + # resource_tracker (on Windows, shared memory segments + # are released automatically by the OS). + err = p.stderr.read().decode() + self.assertIn( + "resource_tracker: There appear to be 1 leaked " + "shared_memory objects to clean up at shutdown", err) + +# +# Test to verify that `Finalize` works. +# + +class _TestFinalize(BaseTestCase): + + ALLOWED_TYPES = ('processes',) + + def setUp(self): + self.registry_backup = util._finalizer_registry.copy() + util._finalizer_registry.clear() + + def tearDown(self): + gc.collect() # For PyPy or other GCs. + self.assertFalse(util._finalizer_registry) + util._finalizer_registry.update(self.registry_backup) + + @classmethod + def _test_finalize(cls, conn): + class Foo(object): + pass + + a = Foo() + util.Finalize(a, conn.send, args=('a',)) + del a # triggers callback for a + gc.collect() # For PyPy or other GCs. + + b = Foo() + close_b = util.Finalize(b, conn.send, args=('b',)) + close_b() # triggers callback for b + close_b() # does nothing because callback has already been called + del b # does nothing because callback has already been called + gc.collect() # For PyPy or other GCs. + + c = Foo() + util.Finalize(c, conn.send, args=('c',)) + + d10 = Foo() + util.Finalize(d10, conn.send, args=('d10',), exitpriority=1) + + d01 = Foo() + util.Finalize(d01, conn.send, args=('d01',), exitpriority=0) + d02 = Foo() + util.Finalize(d02, conn.send, args=('d02',), exitpriority=0) + d03 = Foo() + util.Finalize(d03, conn.send, args=('d03',), exitpriority=0) + + util.Finalize(None, conn.send, args=('e',), exitpriority=-10) + + util.Finalize(None, conn.send, args=('STOP',), exitpriority=-100) + + # call multiprocessing's cleanup function then exit process without + # garbage collecting locals + util._exit_function() + conn.close() + os._exit(0) + + def test_finalize(self): + conn, child_conn = self.Pipe() + + p = self.Process(target=self._test_finalize, args=(child_conn,)) + p.daemon = True + p.start() + p.join() + + result = [obj for obj in iter(conn.recv, 'STOP')] + self.assertEqual(result, ['a', 'b', 'd10', 'd03', 'd02', 'd01', 'e']) + + @support.requires_resource('cpu') + def test_thread_safety(self): + # bpo-24484: _run_finalizers() should be thread-safe + def cb(): + pass + + class Foo(object): + def __init__(self): + self.ref = self # create reference cycle + # insert finalizer at random key + util.Finalize(self, cb, exitpriority=random.randint(1, 100)) + + finish = False + exc = None + + def run_finalizers(): + nonlocal exc + while not finish: + time.sleep(random.random() * 1e-1) + try: + # A GC run will eventually happen during this, + # collecting stale Foo's and mutating the registry + util._run_finalizers() + except Exception as e: + exc = e + + def make_finalizers(): + nonlocal exc + d = {} + while not finish: + try: + # Old Foo's get gradually replaced and later + # collected by the GC (because of the cyclic ref) + d[random.getrandbits(5)] = {Foo() for i in range(10)} + except Exception as e: + exc = e + d.clear() + + old_interval = sys.getswitchinterval() + old_threshold = gc.get_threshold() + try: + sys.setswitchinterval(1e-6) + gc.set_threshold(5, 5, 5) + threads = [threading.Thread(target=run_finalizers), + threading.Thread(target=make_finalizers)] + with threading_helper.start_threads(threads): + time.sleep(4.0) # Wait a bit to trigger race condition + finish = True + if exc is not None: + raise exc + finally: + sys.setswitchinterval(old_interval) + gc.set_threshold(*old_threshold) + gc.collect() # Collect remaining Foo's + + +# +# Test that from ... import * works for each module +# + +class _TestImportStar(unittest.TestCase): + + def get_module_names(self): + import glob + folder = os.path.dirname(multiprocessing.__file__) + pattern = os.path.join(glob.escape(folder), '*.py') + files = glob.glob(pattern) + modules = [os.path.splitext(os.path.split(f)[1])[0] for f in files] + modules = ['multiprocessing.' + m for m in modules] + modules.remove('multiprocessing.__init__') + modules.append('multiprocessing') + return modules + + def test_import(self): + modules = self.get_module_names() + if sys.platform == 'win32': + modules.remove('multiprocessing.popen_fork') + modules.remove('multiprocessing.popen_forkserver') + modules.remove('multiprocessing.popen_spawn_posix') + else: + modules.remove('multiprocessing.popen_spawn_win32') + if not HAS_REDUCTION: + modules.remove('multiprocessing.popen_forkserver') + + if c_int is None: + # This module requires _ctypes + modules.remove('multiprocessing.sharedctypes') + + for name in modules: + __import__(name) + mod = sys.modules[name] + self.assertTrue(hasattr(mod, '__all__'), name) + + for attr in mod.__all__: + self.assertTrue( + hasattr(mod, attr), + '%r does not have attribute %r' % (mod, attr) + ) + +# +# Quick test that logging works -- does not test logging output +# + +class _TestLogging(BaseTestCase): + + ALLOWED_TYPES = ('processes',) + + def test_enable_logging(self): + logger = multiprocessing.get_logger() + logger.setLevel(util.SUBWARNING) + self.assertTrue(logger is not None) + logger.debug('this will not be printed') + logger.info('nor will this') + logger.setLevel(LOG_LEVEL) + + @classmethod + def _test_level(cls, conn): + logger = multiprocessing.get_logger() + conn.send(logger.getEffectiveLevel()) + + def test_level(self): + LEVEL1 = 32 + LEVEL2 = 37 + + logger = multiprocessing.get_logger() + root_logger = logging.getLogger() + root_level = root_logger.level + + reader, writer = multiprocessing.Pipe(duplex=False) + + logger.setLevel(LEVEL1) + p = self.Process(target=self._test_level, args=(writer,)) + p.start() + self.assertEqual(LEVEL1, reader.recv()) + p.join() + p.close() + + logger.setLevel(logging.NOTSET) + root_logger.setLevel(LEVEL2) + p = self.Process(target=self._test_level, args=(writer,)) + p.start() + self.assertEqual(LEVEL2, reader.recv()) + p.join() + p.close() + + root_logger.setLevel(root_level) + logger.setLevel(level=LOG_LEVEL) + + def test_filename(self): + logger = multiprocessing.get_logger() + original_level = logger.level + try: + logger.setLevel(util.DEBUG) + stream = io.StringIO() + handler = logging.StreamHandler(stream) + logging_format = '[%(levelname)s] [%(filename)s] %(message)s' + handler.setFormatter(logging.Formatter(logging_format)) + logger.addHandler(handler) + logger.info('1') + util.info('2') + logger.debug('3') + filename = os.path.basename(__file__) + log_record = stream.getvalue() + self.assertIn(f'[INFO] [{filename}] 1', log_record) + self.assertIn(f'[INFO] [{filename}] 2', log_record) + self.assertIn(f'[DEBUG] [{filename}] 3', log_record) + finally: + logger.setLevel(original_level) + logger.removeHandler(handler) + handler.close() + + +# class _TestLoggingProcessName(BaseTestCase): +# +# def handle(self, record): +# assert record.processName == multiprocessing.current_process().name +# self.__handled = True +# +# def test_logging(self): +# handler = logging.Handler() +# handler.handle = self.handle +# self.__handled = False +# # Bypass getLogger() and side-effects +# logger = logging.getLoggerClass()( +# 'multiprocessing.test.TestLoggingProcessName') +# logger.addHandler(handler) +# logger.propagate = False +# +# logger.warn('foo') +# assert self.__handled + +# +# Check that Process.join() retries if os.waitpid() fails with EINTR +# + +class _TestPollEintr(BaseTestCase): + + ALLOWED_TYPES = ('processes',) + + @classmethod + def _killer(cls, pid): + time.sleep(0.1) + os.kill(pid, signal.SIGUSR1) + + @unittest.skipUnless(hasattr(signal, 'SIGUSR1'), 'requires SIGUSR1') + def test_poll_eintr(self): + got_signal = [False] + def record(*args): + got_signal[0] = True + pid = os.getpid() + oldhandler = signal.signal(signal.SIGUSR1, record) + try: + killer = self.Process(target=self._killer, args=(pid,)) + killer.start() + try: + p = self.Process(target=time.sleep, args=(2,)) + p.start() + p.join() + finally: + killer.join() + self.assertTrue(got_signal[0]) + self.assertEqual(p.exitcode, 0) + finally: + signal.signal(signal.SIGUSR1, oldhandler) + +# +# Test to verify handle verification, see issue 3321 +# + +class TestInvalidHandle(unittest.TestCase): + + @unittest.skipIf(WIN32, "skipped on Windows") + def test_invalid_handles(self): + conn = multiprocessing.connection.Connection(44977608) + # check that poll() doesn't crash + try: + conn.poll() + except (ValueError, OSError): + pass + finally: + # Hack private attribute _handle to avoid printing an error + # in conn.__del__ + conn._handle = None + self.assertRaises((ValueError, OSError), + multiprocessing.connection.Connection, -1) + + + +@hashlib_helper.requires_hashdigest('sha256') +class OtherTest(unittest.TestCase): + # TODO: add more tests for deliver/answer challenge. + def test_deliver_challenge_auth_failure(self): + class _FakeConnection(object): + def recv_bytes(self, size): + return b'something bogus' + def send_bytes(self, data): + pass + self.assertRaises(multiprocessing.AuthenticationError, + multiprocessing.connection.deliver_challenge, + _FakeConnection(), b'abc') + + def test_answer_challenge_auth_failure(self): + class _FakeConnection(object): + def __init__(self): + self.count = 0 + def recv_bytes(self, size): + self.count += 1 + if self.count == 1: + return multiprocessing.connection._CHALLENGE + elif self.count == 2: + return b'something bogus' + return b'' + def send_bytes(self, data): + pass + self.assertRaises(multiprocessing.AuthenticationError, + multiprocessing.connection.answer_challenge, + _FakeConnection(), b'abc') + + +@hashlib_helper.requires_hashdigest('md5') +@hashlib_helper.requires_hashdigest('sha256') +class ChallengeResponseTest(unittest.TestCase): + authkey = b'supadupasecretkey' + + def create_response(self, message): + return multiprocessing.connection._create_response( + self.authkey, message + ) + + def verify_challenge(self, message, response): + return multiprocessing.connection._verify_challenge( + self.authkey, message, response + ) + + def test_challengeresponse(self): + for algo in [None, "md5", "sha256"]: + with self.subTest(f"{algo=}"): + msg = b'is-twenty-bytes-long' # The length of a legacy message. + if algo: + prefix = b'{%s}' % algo.encode("ascii") + else: + prefix = b'' + msg = prefix + msg + response = self.create_response(msg) + if not response.startswith(prefix): + self.fail(response) + self.verify_challenge(msg, response) + + # TODO(gpshead): We need integration tests for handshakes between modern + # deliver_challenge() and verify_response() code and connections running a + # test-local copy of the legacy Python <=3.11 implementations. + + # TODO(gpshead): properly annotate tests for requires_hashdigest rather than + # only running these on a platform supporting everything. otherwise logic + # issues preventing it from working on FIPS mode setups will be hidden. + +# +# Test Manager.start()/Pool.__init__() initializer feature - see issue 5585 +# + +def initializer(ns): + ns.test += 1 + +@hashlib_helper.requires_hashdigest('sha256') +class TestInitializers(unittest.TestCase): + def setUp(self): + self.mgr = multiprocessing.Manager() + self.ns = self.mgr.Namespace() + self.ns.test = 0 + + def tearDown(self): + self.mgr.shutdown() + self.mgr.join() + + def test_manager_initializer(self): + m = multiprocessing.managers.SyncManager() + self.assertRaises(TypeError, m.start, 1) + m.start(initializer, (self.ns,)) + self.assertEqual(self.ns.test, 1) + m.shutdown() + m.join() + + def test_pool_initializer(self): + self.assertRaises(TypeError, multiprocessing.Pool, initializer=1) + p = multiprocessing.Pool(1, initializer, (self.ns,)) + p.close() + p.join() + self.assertEqual(self.ns.test, 1) + +# +# Issue 5155, 5313, 5331: Test process in processes +# Verifies os.close(sys.stdin.fileno) vs. sys.stdin.close() behavior +# + +def _this_sub_process(q): + try: + item = q.get(block=False) + except pyqueue.Empty: + pass + +def _test_process(): + queue = multiprocessing.Queue() + subProc = multiprocessing.Process(target=_this_sub_process, args=(queue,)) + subProc.daemon = True + subProc.start() + subProc.join() + +def _afunc(x): + return x*x + +def pool_in_process(): + pool = multiprocessing.Pool(processes=4) + x = pool.map(_afunc, [1, 2, 3, 4, 5, 6, 7]) + pool.close() + pool.join() + +class _file_like(object): + def __init__(self, delegate): + self._delegate = delegate + self._pid = None + + @property + def cache(self): + pid = os.getpid() + # There are no race conditions since fork keeps only the running thread + if pid != self._pid: + self._pid = pid + self._cache = [] + return self._cache + + def write(self, data): + self.cache.append(data) + + def flush(self): + self._delegate.write(''.join(self.cache)) + self._cache = [] + +class TestStdinBadfiledescriptor(unittest.TestCase): + + def test_queue_in_process(self): + proc = multiprocessing.Process(target=_test_process) + proc.start() + proc.join() + + def test_pool_in_process(self): + p = multiprocessing.Process(target=pool_in_process) + p.start() + p.join() + + def test_flushing(self): + sio = io.StringIO() + flike = _file_like(sio) + flike.write('foo') + proc = multiprocessing.Process(target=lambda: flike.flush()) + flike.flush() + assert sio.getvalue() == 'foo' + + +class TestWait(unittest.TestCase): + + @classmethod + def _child_test_wait(cls, w, slow): + for i in range(10): + if slow: + time.sleep(random.random() * 0.100) + w.send((i, os.getpid())) + w.close() + + def test_wait(self, slow=False): + from multiprocessing.connection import wait + readers = [] + procs = [] + messages = [] + + for i in range(4): + r, w = multiprocessing.Pipe(duplex=False) + p = multiprocessing.Process(target=self._child_test_wait, args=(w, slow)) + p.daemon = True + p.start() + w.close() + readers.append(r) + procs.append(p) + self.addCleanup(p.join) + + while readers: + for r in wait(readers): + try: + msg = r.recv() + except EOFError: + readers.remove(r) + r.close() + else: + messages.append(msg) + + messages.sort() + expected = sorted((i, p.pid) for i in range(10) for p in procs) + self.assertEqual(messages, expected) + + @classmethod + def _child_test_wait_socket(cls, address, slow): + s = socket.socket() + s.connect(address) + for i in range(10): + if slow: + time.sleep(random.random() * 0.100) + s.sendall(('%s\n' % i).encode('ascii')) + s.close() + + def test_wait_socket(self, slow=False): + from multiprocessing.connection import wait + l = socket.create_server((socket_helper.HOST, 0)) + addr = l.getsockname() + readers = [] + procs = [] + dic = {} + + for i in range(4): + p = multiprocessing.Process(target=self._child_test_wait_socket, + args=(addr, slow)) + p.daemon = True + p.start() + procs.append(p) + self.addCleanup(p.join) + + for i in range(4): + r, _ = l.accept() + readers.append(r) + dic[r] = [] + l.close() + + while readers: + for r in wait(readers): + msg = r.recv(32) + if not msg: + readers.remove(r) + r.close() + else: + dic[r].append(msg) + + expected = ''.join('%s\n' % i for i in range(10)).encode('ascii') + for v in dic.values(): + self.assertEqual(b''.join(v), expected) + + def test_wait_slow(self): + self.test_wait(True) + + def test_wait_socket_slow(self): + self.test_wait_socket(True) + + @support.requires_resource('walltime') + def test_wait_timeout(self): + from multiprocessing.connection import wait + + timeout = 5.0 # seconds + a, b = multiprocessing.Pipe() + + start = time.monotonic() + res = wait([a, b], timeout) + delta = time.monotonic() - start + + self.assertEqual(res, []) + self.assertGreater(delta, timeout - CLOCK_RES) + + b.send(None) + res = wait([a, b], 20) + self.assertEqual(res, [a]) + + @classmethod + def signal_and_sleep(cls, sem, period): + sem.release() + time.sleep(period) + + @support.requires_resource('walltime') + def test_wait_integer(self): + from multiprocessing.connection import wait + + expected = 3 + sorted_ = lambda l: sorted(l, key=lambda x: id(x)) + sem = multiprocessing.Semaphore(0) + a, b = multiprocessing.Pipe() + p = multiprocessing.Process(target=self.signal_and_sleep, + args=(sem, expected)) + + p.start() + self.assertIsInstance(p.sentinel, int) + self.assertTrue(sem.acquire(timeout=20)) + + start = time.monotonic() + res = wait([a, p.sentinel, b], expected + 20) + delta = time.monotonic() - start + + self.assertEqual(res, [p.sentinel]) + self.assertLess(delta, expected + 2) + self.assertGreater(delta, expected - 2) + + a.send(None) + + start = time.monotonic() + res = wait([a, p.sentinel, b], 20) + delta = time.monotonic() - start + + self.assertEqual(sorted_(res), sorted_([p.sentinel, b])) + self.assertLess(delta, 0.4) + + b.send(None) + + start = time.monotonic() + res = wait([a, p.sentinel, b], 20) + delta = time.monotonic() - start + + self.assertEqual(sorted_(res), sorted_([a, p.sentinel, b])) + self.assertLess(delta, 0.4) + + p.terminate() + p.join() + + def test_neg_timeout(self): + from multiprocessing.connection import wait + a, b = multiprocessing.Pipe() + t = time.monotonic() + res = wait([a], timeout=-1) + t = time.monotonic() - t + self.assertEqual(res, []) + self.assertLess(t, 1) + a.close() + b.close() + +# +# Issue 14151: Test invalid family on invalid environment +# + +class TestInvalidFamily(unittest.TestCase): + + @unittest.skipIf(WIN32, "skipped on Windows") + def test_invalid_family(self): + with self.assertRaises(ValueError): + multiprocessing.connection.Listener(r'\\.\test') + + @unittest.skipUnless(WIN32, "skipped on non-Windows platforms") + def test_invalid_family_win32(self): + with self.assertRaises(ValueError): + multiprocessing.connection.Listener('/var/test.pipe') + +# +# Issue 12098: check sys.flags of child matches that for parent +# + +class TestFlags(unittest.TestCase): + @classmethod + def run_in_grandchild(cls, conn): + conn.send(tuple(sys.flags)) + + @classmethod + def run_in_child(cls, start_method): + import json + mp = multiprocessing.get_context(start_method) + r, w = mp.Pipe(duplex=False) + p = mp.Process(target=cls.run_in_grandchild, args=(w,)) + with warnings.catch_warnings(category=DeprecationWarning): + p.start() + grandchild_flags = r.recv() + p.join() + r.close() + w.close() + flags = (tuple(sys.flags), grandchild_flags) + print(json.dumps(flags)) + + def test_flags(self): + import json + # start child process using unusual flags + prog = ( + 'from test._test_multiprocessing import TestFlags; ' + f'TestFlags.run_in_child({multiprocessing.get_start_method()!r})' + ) + data = subprocess.check_output( + [sys.executable, '-E', '-S', '-O', '-c', prog]) + child_flags, grandchild_flags = json.loads(data.decode('ascii')) + self.assertEqual(child_flags, grandchild_flags) + +# +# Test interaction with socket timeouts - see Issue #6056 +# + +class TestTimeouts(unittest.TestCase): + @classmethod + def _test_timeout(cls, child, address): + time.sleep(1) + child.send(123) + child.close() + conn = multiprocessing.connection.Client(address) + conn.send(456) + conn.close() + + def test_timeout(self): + old_timeout = socket.getdefaulttimeout() + try: + socket.setdefaulttimeout(0.1) + parent, child = multiprocessing.Pipe(duplex=True) + l = multiprocessing.connection.Listener(family='AF_INET') + p = multiprocessing.Process(target=self._test_timeout, + args=(child, l.address)) + p.start() + child.close() + self.assertEqual(parent.recv(), 123) + parent.close() + conn = l.accept() + self.assertEqual(conn.recv(), 456) + conn.close() + l.close() + join_process(p) + finally: + socket.setdefaulttimeout(old_timeout) + +# +# Test what happens with no "if __name__ == '__main__'" +# + +class TestNoForkBomb(unittest.TestCase): + def test_noforkbomb(self): + sm = multiprocessing.get_start_method() + name = os.path.join(os.path.dirname(__file__), 'mp_fork_bomb.py') + if sm != 'fork': + rc, out, err = test.support.script_helper.assert_python_failure(name, sm) + self.assertEqual(out, b'') + self.assertIn(b'RuntimeError', err) + else: + rc, out, err = test.support.script_helper.assert_python_ok(name, sm) + self.assertEqual(out.rstrip(), b'123') + self.assertEqual(err, b'') + +# +# Issue #17555: ForkAwareThreadLock +# + +class TestForkAwareThreadLock(unittest.TestCase): + # We recursively start processes. Issue #17555 meant that the + # after fork registry would get duplicate entries for the same + # lock. The size of the registry at generation n was ~2**n. + + @classmethod + def child(cls, n, conn): + if n > 1: + p = multiprocessing.Process(target=cls.child, args=(n-1, conn)) + p.start() + conn.close() + join_process(p) + else: + conn.send(len(util._afterfork_registry)) + conn.close() + + def test_lock(self): + r, w = multiprocessing.Pipe(False) + l = util.ForkAwareThreadLock() + old_size = len(util._afterfork_registry) + p = multiprocessing.Process(target=self.child, args=(5, w)) + p.start() + w.close() + new_size = r.recv() + join_process(p) + self.assertLessEqual(new_size, old_size) + +# +# Check that non-forked child processes do not inherit unneeded fds/handles +# + +class TestCloseFds(unittest.TestCase): + + def get_high_socket_fd(self): + if WIN32: + # The child process will not have any socket handles, so + # calling socket.fromfd() should produce WSAENOTSOCK even + # if there is a handle of the same number. + return socket.socket().detach() + else: + # We want to produce a socket with an fd high enough that a + # freshly created child process will not have any fds as high. + fd = socket.socket().detach() + to_close = [] + while fd < 50: + to_close.append(fd) + fd = os.dup(fd) + for x in to_close: + os.close(x) + return fd + + def close(self, fd): + if WIN32: + socket.socket(socket.AF_INET, socket.SOCK_STREAM, fileno=fd).close() + else: + os.close(fd) + + @classmethod + def _test_closefds(cls, conn, fd): + try: + s = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM) + except Exception as e: + conn.send(e) + else: + s.close() + conn.send(None) + + def test_closefd(self): + if not HAS_REDUCTION: + raise unittest.SkipTest('requires fd pickling') + + reader, writer = multiprocessing.Pipe() + fd = self.get_high_socket_fd() + try: + p = multiprocessing.Process(target=self._test_closefds, + args=(writer, fd)) + p.start() + writer.close() + e = reader.recv() + join_process(p) + finally: + self.close(fd) + writer.close() + reader.close() + + if multiprocessing.get_start_method() == 'fork': + self.assertIs(e, None) + else: + WSAENOTSOCK = 10038 + self.assertIsInstance(e, OSError) + self.assertTrue(e.errno == errno.EBADF or + e.winerror == WSAENOTSOCK, e) + +# +# Issue #17097: EINTR should be ignored by recv(), send(), accept() etc +# + +class TestIgnoreEINTR(unittest.TestCase): + + # Sending CONN_MAX_SIZE bytes into a multiprocessing pipe must block + CONN_MAX_SIZE = max(support.PIPE_MAX_SIZE, support.SOCK_MAX_SIZE) + + @classmethod + def _test_ignore(cls, conn): + def handler(signum, frame): + pass + signal.signal(signal.SIGUSR1, handler) + conn.send('ready') + x = conn.recv() + conn.send(x) + conn.send_bytes(b'x' * cls.CONN_MAX_SIZE) + + @unittest.skipUnless(hasattr(signal, 'SIGUSR1'), 'requires SIGUSR1') + def test_ignore(self): + conn, child_conn = multiprocessing.Pipe() + try: + p = multiprocessing.Process(target=self._test_ignore, + args=(child_conn,)) + p.daemon = True + p.start() + child_conn.close() + self.assertEqual(conn.recv(), 'ready') + time.sleep(0.1) + os.kill(p.pid, signal.SIGUSR1) + time.sleep(0.1) + conn.send(1234) + self.assertEqual(conn.recv(), 1234) + time.sleep(0.1) + os.kill(p.pid, signal.SIGUSR1) + self.assertEqual(conn.recv_bytes(), b'x' * self.CONN_MAX_SIZE) + time.sleep(0.1) + p.join() + finally: + conn.close() + + @classmethod + def _test_ignore_listener(cls, conn): + def handler(signum, frame): + pass + signal.signal(signal.SIGUSR1, handler) + with multiprocessing.connection.Listener() as l: + conn.send(l.address) + a = l.accept() + a.send('welcome') + + @unittest.skipUnless(hasattr(signal, 'SIGUSR1'), 'requires SIGUSR1') + def test_ignore_listener(self): + conn, child_conn = multiprocessing.Pipe() + try: + p = multiprocessing.Process(target=self._test_ignore_listener, + args=(child_conn,)) + p.daemon = True + p.start() + child_conn.close() + address = conn.recv() + time.sleep(0.1) + os.kill(p.pid, signal.SIGUSR1) + time.sleep(0.1) + client = multiprocessing.connection.Client(address) + self.assertEqual(client.recv(), 'welcome') + p.join() + finally: + conn.close() + +class TestStartMethod(unittest.TestCase): + @classmethod + def _check_context(cls, conn): + conn.send(multiprocessing.get_start_method()) + + def check_context(self, ctx): + r, w = ctx.Pipe(duplex=False) + p = ctx.Process(target=self._check_context, args=(w,)) + p.start() + w.close() + child_method = r.recv() + r.close() + p.join() + self.assertEqual(child_method, ctx.get_start_method()) + + def test_context(self): + for method in ('fork', 'spawn', 'forkserver'): + try: + ctx = multiprocessing.get_context(method) + except ValueError: + continue + self.assertEqual(ctx.get_start_method(), method) + self.assertIs(ctx.get_context(), ctx) + self.assertRaises(ValueError, ctx.set_start_method, 'spawn') + self.assertRaises(ValueError, ctx.set_start_method, None) + self.check_context(ctx) + + def test_context_check_module_types(self): + try: + ctx = multiprocessing.get_context('forkserver') + except ValueError: + raise unittest.SkipTest('forkserver should be available') + with self.assertRaisesRegex(TypeError, 'module_names must be a list of strings'): + ctx.set_forkserver_preload([1, 2, 3]) + + def test_set_get(self): + multiprocessing.set_forkserver_preload(PRELOAD) + count = 0 + old_method = multiprocessing.get_start_method() + try: + for method in ('fork', 'spawn', 'forkserver'): + try: + multiprocessing.set_start_method(method, force=True) + except ValueError: + continue + self.assertEqual(multiprocessing.get_start_method(), method) + ctx = multiprocessing.get_context() + self.assertEqual(ctx.get_start_method(), method) + self.assertTrue(type(ctx).__name__.lower().startswith(method)) + self.assertTrue( + ctx.Process.__name__.lower().startswith(method)) + self.check_context(multiprocessing) + count += 1 + finally: + multiprocessing.set_start_method(old_method, force=True) + self.assertGreaterEqual(count, 1) + + def test_get_all(self): + methods = multiprocessing.get_all_start_methods() + if sys.platform == 'win32': + self.assertEqual(methods, ['spawn']) + else: + self.assertTrue(methods == ['fork', 'spawn'] or + methods == ['spawn', 'fork'] or + methods == ['fork', 'spawn', 'forkserver'] or + methods == ['spawn', 'fork', 'forkserver']) + + def test_preload_resources(self): + if multiprocessing.get_start_method() != 'forkserver': + self.skipTest("test only relevant for 'forkserver' method") + name = os.path.join(os.path.dirname(__file__), 'mp_preload.py') + rc, out, err = test.support.script_helper.assert_python_ok(name) + out = out.decode() + err = err.decode() + if out.rstrip() != 'ok' or err != '': + print(out) + print(err) + self.fail("failed spawning forkserver or grandchild") + + @unittest.skipIf(sys.platform == "win32", + "Only Spawn on windows so no risk of mixing") + @only_run_in_spawn_testsuite("avoids redundant testing.") + def test_mixed_startmethod(self): + # Fork-based locks cannot be used with spawned process + for process_method in ["spawn", "forkserver"]: + queue = multiprocessing.get_context("fork").Queue() + process_ctx = multiprocessing.get_context(process_method) + p = process_ctx.Process(target=close_queue, args=(queue,)) + err_msg = "A SemLock created in a fork" + with self.assertRaisesRegex(RuntimeError, err_msg): + p.start() + + # non-fork-based locks can be used with all other start methods + for queue_method in ["spawn", "forkserver"]: + for process_method in multiprocessing.get_all_start_methods(): + queue = multiprocessing.get_context(queue_method).Queue() + process_ctx = multiprocessing.get_context(process_method) + p = process_ctx.Process(target=close_queue, args=(queue,)) + p.start() + p.join() + + @classmethod + def _put_one_in_queue(cls, queue): + queue.put(1) + + @classmethod + def _put_two_and_nest_once(cls, queue): + queue.put(2) + process = multiprocessing.Process(target=cls._put_one_in_queue, args=(queue,)) + process.start() + process.join() + + def test_nested_startmethod(self): + # gh-108520: Regression test to ensure that child process can send its + # arguments to another process + queue = multiprocessing.Queue() + + process = multiprocessing.Process(target=self._put_two_and_nest_once, args=(queue,)) + process.start() + process.join() + + results = [] + while not queue.empty(): + results.append(queue.get()) + + # gh-109706: queue.put(1) can write into the queue before queue.put(2), + # there is no synchronization in the test. + self.assertSetEqual(set(results), set([2, 1])) + + +@unittest.skipIf(sys.platform == "win32", + "test semantics don't make sense on Windows") +class TestResourceTracker(unittest.TestCase): + + def test_resource_tracker(self): + # + # Check that killing process does not leak named semaphores + # + cmd = '''if 1: + import time, os + import multiprocessing as mp + from multiprocessing import resource_tracker + from multiprocessing.shared_memory import SharedMemory + + mp.set_start_method("spawn") + + + def create_and_register_resource(rtype): + if rtype == "semaphore": + lock = mp.Lock() + return lock, lock._semlock.name + elif rtype == "shared_memory": + sm = SharedMemory(create=True, size=10) + return sm, sm._name + else: + raise ValueError( + "Resource type {{}} not understood".format(rtype)) + + + resource1, rname1 = create_and_register_resource("{rtype}") + resource2, rname2 = create_and_register_resource("{rtype}") + + os.write({w}, rname1.encode("ascii") + b"\\n") + os.write({w}, rname2.encode("ascii") + b"\\n") + + time.sleep(10) + ''' + for rtype in resource_tracker._CLEANUP_FUNCS: + with self.subTest(rtype=rtype): + if rtype == "noop": + # Artefact resource type used by the resource_tracker + continue + r, w = os.pipe() + p = subprocess.Popen([sys.executable, + '-E', '-c', cmd.format(w=w, rtype=rtype)], + pass_fds=[w], + stderr=subprocess.PIPE) + os.close(w) + with open(r, 'rb', closefd=True) as f: + name1 = f.readline().rstrip().decode('ascii') + name2 = f.readline().rstrip().decode('ascii') + _resource_unlink(name1, rtype) + p.terminate() + p.wait() + + err_msg = (f"A {rtype} resource was leaked after a process was " + f"abruptly terminated") + for _ in support.sleeping_retry(support.SHORT_TIMEOUT, + err_msg): + try: + _resource_unlink(name2, rtype) + except OSError as e: + # docs say it should be ENOENT, but OSX seems to give + # EINVAL + self.assertIn(e.errno, (errno.ENOENT, errno.EINVAL)) + break + + err = p.stderr.read().decode('utf-8') + p.stderr.close() + expected = ('resource_tracker: There appear to be 2 leaked {} ' + 'objects'.format( + rtype)) + self.assertRegex(err, expected) + self.assertRegex(err, r'resource_tracker: %r: \[Errno' % name1) + + def check_resource_tracker_death(self, signum, should_die): + # bpo-31310: if the semaphore tracker process has died, it should + # be restarted implicitly. + from multiprocessing.resource_tracker import _resource_tracker + pid = _resource_tracker._pid + if pid is not None: + os.kill(pid, signal.SIGKILL) + support.wait_process(pid, exitcode=-signal.SIGKILL) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + _resource_tracker.ensure_running() + pid = _resource_tracker._pid + + os.kill(pid, signum) + time.sleep(1.0) # give it time to die + + ctx = multiprocessing.get_context("spawn") + with warnings.catch_warnings(record=True) as all_warn: + warnings.simplefilter("always") + sem = ctx.Semaphore() + sem.acquire() + sem.release() + wr = weakref.ref(sem) + # ensure `sem` gets collected, which triggers communication with + # the semaphore tracker + del sem + gc.collect() + self.assertIsNone(wr()) + if should_die: + self.assertEqual(len(all_warn), 1) + the_warn = all_warn[0] + self.assertTrue(issubclass(the_warn.category, UserWarning)) + self.assertTrue("resource_tracker: process died" + in str(the_warn.message)) + else: + self.assertEqual(len(all_warn), 0) + + def test_resource_tracker_sigint(self): + # Catchable signal (ignored by semaphore tracker) + self.check_resource_tracker_death(signal.SIGINT, False) + + def test_resource_tracker_sigterm(self): + # Catchable signal (ignored by semaphore tracker) + self.check_resource_tracker_death(signal.SIGTERM, False) + + def test_resource_tracker_sigkill(self): + # Uncatchable signal. + self.check_resource_tracker_death(signal.SIGKILL, True) + + @staticmethod + def _is_resource_tracker_reused(conn, pid): + from multiprocessing.resource_tracker import _resource_tracker + _resource_tracker.ensure_running() + # The pid should be None in the child process, expect for the fork + # context. It should not be a new value. + reused = _resource_tracker._pid in (None, pid) + reused &= _resource_tracker._check_alive() + conn.send(reused) + + def test_resource_tracker_reused(self): + from multiprocessing.resource_tracker import _resource_tracker + _resource_tracker.ensure_running() + pid = _resource_tracker._pid + + r, w = multiprocessing.Pipe(duplex=False) + p = multiprocessing.Process(target=self._is_resource_tracker_reused, + args=(w, pid)) + p.start() + is_resource_tracker_reused = r.recv() + + # Clean up + p.join() + w.close() + r.close() + + self.assertTrue(is_resource_tracker_reused) + + def test_too_long_name_resource(self): + # gh-96819: Resource names that will make the length of a write to a pipe + # greater than PIPE_BUF are not allowed + rtype = "shared_memory" + too_long_name_resource = "a" * (512 - len(rtype)) + with self.assertRaises(ValueError): + resource_tracker.register(too_long_name_resource, rtype) + + +class TestSimpleQueue(unittest.TestCase): + + @classmethod + def _test_empty(cls, queue, child_can_start, parent_can_continue): + child_can_start.wait() + # issue 30301, could fail under spawn and forkserver + try: + queue.put(queue.empty()) + queue.put(queue.empty()) + finally: + parent_can_continue.set() + + def test_empty(self): + queue = multiprocessing.SimpleQueue() + child_can_start = multiprocessing.Event() + parent_can_continue = multiprocessing.Event() + + proc = multiprocessing.Process( + target=self._test_empty, + args=(queue, child_can_start, parent_can_continue) + ) + proc.daemon = True + proc.start() + + self.assertTrue(queue.empty()) + + child_can_start.set() + parent_can_continue.wait() + + self.assertFalse(queue.empty()) + self.assertEqual(queue.get(), True) + self.assertEqual(queue.get(), False) + self.assertTrue(queue.empty()) + + proc.join() + + def test_close(self): + queue = multiprocessing.SimpleQueue() + queue.close() + # closing a queue twice should not fail + queue.close() + + # Test specific to CPython since it tests private attributes + @test.support.cpython_only + def test_closed(self): + queue = multiprocessing.SimpleQueue() + queue.close() + self.assertTrue(queue._reader.closed) + self.assertTrue(queue._writer.closed) + + +class TestPoolNotLeakOnFailure(unittest.TestCase): + + def test_release_unused_processes(self): + # Issue #19675: During pool creation, if we can't create a process, + # don't leak already created ones. + will_fail_in = 3 + forked_processes = [] + + class FailingForkProcess: + def __init__(self, **kwargs): + self.name = 'Fake Process' + self.exitcode = None + self.state = None + forked_processes.append(self) + + def start(self): + nonlocal will_fail_in + if will_fail_in <= 0: + raise OSError("Manually induced OSError") + will_fail_in -= 1 + self.state = 'started' + + def terminate(self): + self.state = 'stopping' + + def join(self): + if self.state == 'stopping': + self.state = 'stopped' + + def is_alive(self): + return self.state == 'started' or self.state == 'stopping' + + with self.assertRaisesRegex(OSError, 'Manually induced OSError'): + p = multiprocessing.pool.Pool(5, context=unittest.mock.MagicMock( + Process=FailingForkProcess)) + p.close() + p.join() + self.assertFalse( + any(process.is_alive() for process in forked_processes)) + + +@hashlib_helper.requires_hashdigest('sha256') +class TestSyncManagerTypes(unittest.TestCase): + """Test all the types which can be shared between a parent and a + child process by using a manager which acts as an intermediary + between them. + + In the following unit-tests the base type is created in the parent + process, the @classmethod represents the worker process and the + shared object is readable and editable between the two. + + # The child. + @classmethod + def _test_list(cls, obj): + assert obj[0] == 5 + assert obj.append(6) + + # The parent. + def test_list(self): + o = self.manager.list() + o.append(5) + self.run_worker(self._test_list, o) + assert o[1] == 6 + """ + manager_class = multiprocessing.managers.SyncManager + + def setUp(self): + self.manager = self.manager_class() + self.manager.start() + self.proc = None + + def tearDown(self): + if self.proc is not None and self.proc.is_alive(): + self.proc.terminate() + self.proc.join() + self.manager.shutdown() + self.manager = None + self.proc = None + + @classmethod + def setUpClass(cls): + support.reap_children() + + tearDownClass = setUpClass + + def wait_proc_exit(self): + # Only the manager process should be returned by active_children() + # but this can take a bit on slow machines, so wait a few seconds + # if there are other children too (see #17395). + join_process(self.proc) + + timeout = WAIT_ACTIVE_CHILDREN_TIMEOUT + start_time = time.monotonic() + for _ in support.sleeping_retry(timeout, error=False): + if len(multiprocessing.active_children()) <= 1: + break + else: + dt = time.monotonic() - start_time + support.environment_altered = True + support.print_warning(f"multiprocessing.Manager still has " + f"{multiprocessing.active_children()} " + f"active children after {dt:.1f} seconds") + + def run_worker(self, worker, obj): + self.proc = multiprocessing.Process(target=worker, args=(obj, )) + self.proc.daemon = True + self.proc.start() + self.wait_proc_exit() + self.assertEqual(self.proc.exitcode, 0) + + @classmethod + def _test_event(cls, obj): + assert obj.is_set() + obj.wait() + obj.clear() + obj.wait(0.001) + + def test_event(self): + o = self.manager.Event() + o.set() + self.run_worker(self._test_event, o) + assert not o.is_set() + o.wait(0.001) + + @classmethod + def _test_lock(cls, obj): + obj.acquire() + + def test_lock(self, lname="Lock"): + o = getattr(self.manager, lname)() + self.run_worker(self._test_lock, o) + o.release() + self.assertRaises(RuntimeError, o.release) # already released + + @classmethod + def _test_rlock(cls, obj): + obj.acquire() + obj.release() + + def test_rlock(self, lname="Lock"): + o = getattr(self.manager, lname)() + self.run_worker(self._test_rlock, o) + + @classmethod + def _test_semaphore(cls, obj): + obj.acquire() + + def test_semaphore(self, sname="Semaphore"): + o = getattr(self.manager, sname)() + self.run_worker(self._test_semaphore, o) + o.release() + + def test_bounded_semaphore(self): + self.test_semaphore(sname="BoundedSemaphore") + + @classmethod + def _test_condition(cls, obj): + obj.acquire() + obj.release() + + def test_condition(self): + o = self.manager.Condition() + self.run_worker(self._test_condition, o) + + @classmethod + def _test_barrier(cls, obj): + assert obj.parties == 5 + obj.reset() + + def test_barrier(self): + o = self.manager.Barrier(5) + self.run_worker(self._test_barrier, o) + + @classmethod + def _test_pool(cls, obj): + # TODO: fix https://bugs.python.org/issue35919 + with obj: + pass + + def test_pool(self): + o = self.manager.Pool(processes=4) + self.run_worker(self._test_pool, o) + + @classmethod + def _test_queue(cls, obj): + assert obj.qsize() == 2 + assert obj.full() + assert not obj.empty() + assert obj.get() == 5 + assert not obj.empty() + assert obj.get() == 6 + assert obj.empty() + + def test_queue(self, qname="Queue"): + o = getattr(self.manager, qname)(2) + o.put(5) + o.put(6) + self.run_worker(self._test_queue, o) + assert o.empty() + assert not o.full() + + def test_joinable_queue(self): + self.test_queue("JoinableQueue") + + @classmethod + def _test_list(cls, obj): + case = unittest.TestCase() + case.assertEqual(obj[0], 5) + case.assertEqual(obj.count(5), 1) + case.assertEqual(obj.index(5), 0) + obj.sort() + obj.reverse() + for x in obj: + pass + case.assertEqual(len(obj), 1) + case.assertEqual(obj.pop(0), 5) + + def test_list(self): + o = self.manager.list() + o.append(5) + self.run_worker(self._test_list, o) + self.assertIsNotNone(o) + self.assertEqual(len(o), 0) + + @classmethod + def _test_dict(cls, obj): + case = unittest.TestCase() + case.assertEqual(len(obj), 1) + case.assertEqual(obj['foo'], 5) + case.assertEqual(obj.get('foo'), 5) + case.assertListEqual(list(obj.items()), [('foo', 5)]) + case.assertListEqual(list(obj.keys()), ['foo']) + case.assertListEqual(list(obj.values()), [5]) + case.assertDictEqual(obj.copy(), {'foo': 5}) + case.assertTupleEqual(obj.popitem(), ('foo', 5)) + + def test_dict(self): + o = self.manager.dict() + o['foo'] = 5 + self.run_worker(self._test_dict, o) + self.assertIsNotNone(o) + self.assertEqual(len(o), 0) + + @classmethod + def _test_value(cls, obj): + case = unittest.TestCase() + case.assertEqual(obj.value, 1) + case.assertEqual(obj.get(), 1) + obj.set(2) + + def test_value(self): + o = self.manager.Value('i', 1) + self.run_worker(self._test_value, o) + self.assertEqual(o.value, 2) + self.assertEqual(o.get(), 2) + + @classmethod + def _test_array(cls, obj): + case = unittest.TestCase() + case.assertEqual(obj[0], 0) + case.assertEqual(obj[1], 1) + case.assertEqual(len(obj), 2) + case.assertListEqual(list(obj), [0, 1]) + + def test_array(self): + o = self.manager.Array('i', [0, 1]) + self.run_worker(self._test_array, o) + + @classmethod + def _test_namespace(cls, obj): + case = unittest.TestCase() + case.assertEqual(obj.x, 0) + case.assertEqual(obj.y, 1) + + def test_namespace(self): + o = self.manager.Namespace() + o.x = 0 + o.y = 1 + self.run_worker(self._test_namespace, o) + + +class TestNamedResource(unittest.TestCase): + @only_run_in_spawn_testsuite("spawn specific test.") + def test_global_named_resource_spawn(self): + # + # gh-90549: Check that global named resources in main module + # will not leak by a subprocess, in spawn context. + # + testfn = os_helper.TESTFN + self.addCleanup(os_helper.unlink, testfn) + with open(testfn, 'w', encoding='utf-8') as f: + f.write(textwrap.dedent('''\ + import multiprocessing as mp + ctx = mp.get_context('spawn') + global_resource = ctx.Semaphore() + def submain(): pass + if __name__ == '__main__': + p = ctx.Process(target=submain) + p.start() + p.join() + ''')) + rc, out, err = script_helper.assert_python_ok(testfn) + # on error, err = 'UserWarning: resource_tracker: There appear to + # be 1 leaked semaphore objects to clean up at shutdown' + self.assertFalse(err, msg=err.decode('utf-8')) + + +class MiscTestCase(unittest.TestCase): + def test__all__(self): + # Just make sure names in not_exported are excluded + support.check__all__(self, multiprocessing, extra=multiprocessing.__all__, + not_exported=['SUBDEBUG', 'SUBWARNING']) + + @only_run_in_spawn_testsuite("avoids redundant testing.") + def test_spawn_sys_executable_none_allows_import(self): + # Regression test for a bug introduced in + # https://github.com/python/cpython/issues/90876 that caused an + # ImportError in multiprocessing when sys.executable was None. + # This can be true in embedded environments. + rc, out, err = script_helper.assert_python_ok( + "-c", + """if 1: + import sys + sys.executable = None + assert "multiprocessing" not in sys.modules, "already imported!" + import multiprocessing + import multiprocessing.spawn # This should not fail\n""", + ) + self.assertEqual(rc, 0) + self.assertFalse(err, msg=err.decode('utf-8')) + + +# +# Mixins +# + +class BaseMixin(object): + @classmethod + def setUpClass(cls): + cls.dangling = (multiprocessing.process._dangling.copy(), + threading._dangling.copy()) + + @classmethod + def tearDownClass(cls): + # bpo-26762: Some multiprocessing objects like Pool create reference + # cycles. Trigger a garbage collection to break these cycles. + test.support.gc_collect() + + processes = set(multiprocessing.process._dangling) - set(cls.dangling[0]) + if processes: + test.support.environment_altered = True + support.print_warning(f'Dangling processes: {processes}') + processes = None + + threads = set(threading._dangling) - set(cls.dangling[1]) + if threads: + test.support.environment_altered = True + support.print_warning(f'Dangling threads: {threads}') + threads = None + + +class ProcessesMixin(BaseMixin): + TYPE = 'processes' + Process = multiprocessing.Process + connection = multiprocessing.connection + current_process = staticmethod(multiprocessing.current_process) + parent_process = staticmethod(multiprocessing.parent_process) + active_children = staticmethod(multiprocessing.active_children) + set_executable = staticmethod(multiprocessing.set_executable) + Pool = staticmethod(multiprocessing.Pool) + Pipe = staticmethod(multiprocessing.Pipe) + Queue = staticmethod(multiprocessing.Queue) + JoinableQueue = staticmethod(multiprocessing.JoinableQueue) + Lock = staticmethod(multiprocessing.Lock) + RLock = staticmethod(multiprocessing.RLock) + Semaphore = staticmethod(multiprocessing.Semaphore) + BoundedSemaphore = staticmethod(multiprocessing.BoundedSemaphore) + Condition = staticmethod(multiprocessing.Condition) + Event = staticmethod(multiprocessing.Event) + Barrier = staticmethod(multiprocessing.Barrier) + Value = staticmethod(multiprocessing.Value) + Array = staticmethod(multiprocessing.Array) + RawValue = staticmethod(multiprocessing.RawValue) + RawArray = staticmethod(multiprocessing.RawArray) + + +class ManagerMixin(BaseMixin): + TYPE = 'manager' + Process = multiprocessing.Process + Queue = property(operator.attrgetter('manager.Queue')) + JoinableQueue = property(operator.attrgetter('manager.JoinableQueue')) + Lock = property(operator.attrgetter('manager.Lock')) + RLock = property(operator.attrgetter('manager.RLock')) + Semaphore = property(operator.attrgetter('manager.Semaphore')) + BoundedSemaphore = property(operator.attrgetter('manager.BoundedSemaphore')) + Condition = property(operator.attrgetter('manager.Condition')) + Event = property(operator.attrgetter('manager.Event')) + Barrier = property(operator.attrgetter('manager.Barrier')) + Value = property(operator.attrgetter('manager.Value')) + Array = property(operator.attrgetter('manager.Array')) + list = property(operator.attrgetter('manager.list')) + dict = property(operator.attrgetter('manager.dict')) + Namespace = property(operator.attrgetter('manager.Namespace')) + + @classmethod + def Pool(cls, *args, **kwds): + return cls.manager.Pool(*args, **kwds) + + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.manager = multiprocessing.Manager() + + @classmethod + def tearDownClass(cls): + # only the manager process should be returned by active_children() + # but this can take a bit on slow machines, so wait a few seconds + # if there are other children too (see #17395) + timeout = WAIT_ACTIVE_CHILDREN_TIMEOUT + start_time = time.monotonic() + for _ in support.sleeping_retry(timeout, error=False): + if len(multiprocessing.active_children()) <= 1: + break + else: + dt = time.monotonic() - start_time + support.environment_altered = True + support.print_warning(f"multiprocessing.Manager still has " + f"{multiprocessing.active_children()} " + f"active children after {dt:.1f} seconds") + + gc.collect() # do garbage collection + if cls.manager._number_of_objects() != 0: + # This is not really an error since some tests do not + # ensure that all processes which hold a reference to a + # managed object have been joined. + test.support.environment_altered = True + support.print_warning('Shared objects which still exist ' + 'at manager shutdown:') + support.print_warning(cls.manager._debug_info()) + cls.manager.shutdown() + cls.manager.join() + cls.manager = None + + super().tearDownClass() + + +class ThreadsMixin(BaseMixin): + TYPE = 'threads' + Process = multiprocessing.dummy.Process + connection = multiprocessing.dummy.connection + current_process = staticmethod(multiprocessing.dummy.current_process) + active_children = staticmethod(multiprocessing.dummy.active_children) + Pool = staticmethod(multiprocessing.dummy.Pool) + Pipe = staticmethod(multiprocessing.dummy.Pipe) + Queue = staticmethod(multiprocessing.dummy.Queue) + JoinableQueue = staticmethod(multiprocessing.dummy.JoinableQueue) + Lock = staticmethod(multiprocessing.dummy.Lock) + RLock = staticmethod(multiprocessing.dummy.RLock) + Semaphore = staticmethod(multiprocessing.dummy.Semaphore) + BoundedSemaphore = staticmethod(multiprocessing.dummy.BoundedSemaphore) + Condition = staticmethod(multiprocessing.dummy.Condition) + Event = staticmethod(multiprocessing.dummy.Event) + Barrier = staticmethod(multiprocessing.dummy.Barrier) + Value = staticmethod(multiprocessing.dummy.Value) + Array = staticmethod(multiprocessing.dummy.Array) + +# +# Functions used to create test cases from the base ones in this module +# + +def install_tests_in_module_dict(remote_globs, start_method, + only_type=None, exclude_types=False): + __module__ = remote_globs['__name__'] + local_globs = globals() + ALL_TYPES = {'processes', 'threads', 'manager'} + + for name, base in local_globs.items(): + if not isinstance(base, type): + continue + if issubclass(base, BaseTestCase): + if base is BaseTestCase: + continue + assert set(base.ALLOWED_TYPES) <= ALL_TYPES, base.ALLOWED_TYPES + for type_ in base.ALLOWED_TYPES: + if only_type and type_ != only_type: + continue + if exclude_types: + continue + newname = 'With' + type_.capitalize() + name[1:] + Mixin = local_globs[type_.capitalize() + 'Mixin'] + class Temp(base, Mixin, unittest.TestCase): + pass + if type_ == 'manager': + Temp = hashlib_helper.requires_hashdigest('sha256')(Temp) + Temp.__name__ = Temp.__qualname__ = newname + Temp.__module__ = __module__ + remote_globs[newname] = Temp + elif issubclass(base, unittest.TestCase): + if only_type: + continue + + class Temp(base, object): + pass + Temp.__name__ = Temp.__qualname__ = name + Temp.__module__ = __module__ + remote_globs[name] = Temp + + dangling = [None, None] + old_start_method = [None] + + def setUpModule(): + multiprocessing.set_forkserver_preload(PRELOAD) + multiprocessing.process._cleanup() + dangling[0] = multiprocessing.process._dangling.copy() + dangling[1] = threading._dangling.copy() + old_start_method[0] = multiprocessing.get_start_method(allow_none=True) + try: + multiprocessing.set_start_method(start_method, force=True) + except ValueError: + raise unittest.SkipTest(start_method + + ' start method not supported') + + if sys.platform.startswith("linux"): + try: + lock = multiprocessing.RLock() + except OSError: + raise unittest.SkipTest("OSError raises on RLock creation, " + "see issue 3111!") + check_enough_semaphores() + util.get_temp_dir() # creates temp directory + multiprocessing.get_logger().setLevel(LOG_LEVEL) + + def tearDownModule(): + need_sleep = False + + # bpo-26762: Some multiprocessing objects like Pool create reference + # cycles. Trigger a garbage collection to break these cycles. + test.support.gc_collect() + + multiprocessing.set_start_method(old_start_method[0], force=True) + # pause a bit so we don't get warning about dangling threads/processes + processes = set(multiprocessing.process._dangling) - set(dangling[0]) + if processes: + need_sleep = True + test.support.environment_altered = True + support.print_warning(f'Dangling processes: {processes}') + processes = None + + threads = set(threading._dangling) - set(dangling[1]) + if threads: + need_sleep = True + test.support.environment_altered = True + support.print_warning(f'Dangling threads: {threads}') + threads = None + + # Sleep 500 ms to give time to child processes to complete. + if need_sleep: + time.sleep(0.5) + + multiprocessing.util._cleanup_tests() + + remote_globs['setUpModule'] = setUpModule + remote_globs['tearDownModule'] = tearDownModule + + +@unittest.skipIf(not hasattr(_multiprocessing, 'SemLock'), 'SemLock not available') +@unittest.skipIf(sys.platform != "linux", "Linux only") +class SemLockTests(unittest.TestCase): + + def test_semlock_subclass(self): + class SemLock(_multiprocessing.SemLock): + pass + name = f'test_semlock_subclass-{os.getpid()}' + s = SemLock(1, 0, 10, name, False) + _multiprocessing.sem_unlink(name) diff --git a/Lib/test/test_multiprocessing_fork/__init__.py b/Lib/test/test_multiprocessing_fork/__init__.py new file mode 100644 index 0000000000..aa1fff50b2 --- /dev/null +++ b/Lib/test/test_multiprocessing_fork/__init__.py @@ -0,0 +1,16 @@ +import os.path +import sys +import unittest +from test import support + +if support.PGO: + raise unittest.SkipTest("test is not helpful for PGO") + +if sys.platform == "win32": + raise unittest.SkipTest("fork is not available on Windows") + +if sys.platform == 'darwin': + raise unittest.SkipTest("test may crash on macOS (bpo-33725)") + +def load_tests(*args): + return support.load_package_tests(os.path.dirname(__file__), *args) diff --git a/Lib/test/test_multiprocessing_fork/test_manager.py b/Lib/test/test_multiprocessing_fork/test_manager.py new file mode 100644 index 0000000000..9efbb83bbb --- /dev/null +++ b/Lib/test/test_multiprocessing_fork/test_manager.py @@ -0,0 +1,7 @@ +import unittest +from test._test_multiprocessing import install_tests_in_module_dict + +install_tests_in_module_dict(globals(), 'fork', only_type="manager") + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_multiprocessing_fork/test_misc.py b/Lib/test/test_multiprocessing_fork/test_misc.py new file mode 100644 index 0000000000..891a494020 --- /dev/null +++ b/Lib/test/test_multiprocessing_fork/test_misc.py @@ -0,0 +1,7 @@ +import unittest +from test._test_multiprocessing import install_tests_in_module_dict + +install_tests_in_module_dict(globals(), 'fork', exclude_types=True) + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_multiprocessing_fork/test_processes.py b/Lib/test/test_multiprocessing_fork/test_processes.py new file mode 100644 index 0000000000..e64e9afc01 --- /dev/null +++ b/Lib/test/test_multiprocessing_fork/test_processes.py @@ -0,0 +1,7 @@ +import unittest +from test._test_multiprocessing import install_tests_in_module_dict + +install_tests_in_module_dict(globals(), 'fork', only_type="processes") + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_multiprocessing_fork/test_threads.py b/Lib/test/test_multiprocessing_fork/test_threads.py new file mode 100644 index 0000000000..1670e34cb1 --- /dev/null +++ b/Lib/test/test_multiprocessing_fork/test_threads.py @@ -0,0 +1,7 @@ +import unittest +from test._test_multiprocessing import install_tests_in_module_dict + +install_tests_in_module_dict(globals(), 'fork', only_type="threads") + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_multiprocessing_forkserver/__init__.py b/Lib/test/test_multiprocessing_forkserver/__init__.py new file mode 100644 index 0000000000..d91715a344 --- /dev/null +++ b/Lib/test/test_multiprocessing_forkserver/__init__.py @@ -0,0 +1,13 @@ +import os.path +import sys +import unittest +from test import support + +if support.PGO: + raise unittest.SkipTest("test is not helpful for PGO") + +if sys.platform == "win32": + raise unittest.SkipTest("forkserver is not available on Windows") + +def load_tests(*args): + return support.load_package_tests(os.path.dirname(__file__), *args) diff --git a/Lib/test/test_multiprocessing_forkserver/test_manager.py b/Lib/test/test_multiprocessing_forkserver/test_manager.py new file mode 100644 index 0000000000..14f8f10dfb --- /dev/null +++ b/Lib/test/test_multiprocessing_forkserver/test_manager.py @@ -0,0 +1,7 @@ +import unittest +from test._test_multiprocessing import install_tests_in_module_dict + +install_tests_in_module_dict(globals(), 'forkserver', only_type="manager") + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_multiprocessing_forkserver/test_misc.py b/Lib/test/test_multiprocessing_forkserver/test_misc.py new file mode 100644 index 0000000000..9cae1b50f7 --- /dev/null +++ b/Lib/test/test_multiprocessing_forkserver/test_misc.py @@ -0,0 +1,7 @@ +import unittest +from test._test_multiprocessing import install_tests_in_module_dict + +install_tests_in_module_dict(globals(), 'forkserver', exclude_types=True) + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_multiprocessing_forkserver/test_processes.py b/Lib/test/test_multiprocessing_forkserver/test_processes.py new file mode 100644 index 0000000000..360967cf1a --- /dev/null +++ b/Lib/test/test_multiprocessing_forkserver/test_processes.py @@ -0,0 +1,7 @@ +import unittest +from test._test_multiprocessing import install_tests_in_module_dict + +install_tests_in_module_dict(globals(), 'forkserver', only_type="processes") + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_multiprocessing_forkserver/test_threads.py b/Lib/test/test_multiprocessing_forkserver/test_threads.py new file mode 100644 index 0000000000..719c752aa0 --- /dev/null +++ b/Lib/test/test_multiprocessing_forkserver/test_threads.py @@ -0,0 +1,7 @@ +import unittest +from test._test_multiprocessing import install_tests_in_module_dict + +install_tests_in_module_dict(globals(), 'forkserver', only_type="threads") + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_multiprocessing_main_handling.py b/Lib/test/test_multiprocessing_main_handling.py new file mode 100644 index 0000000000..6b30a89316 --- /dev/null +++ b/Lib/test/test_multiprocessing_main_handling.py @@ -0,0 +1,300 @@ +# tests __main__ module handling in multiprocessing +from test import support +from test.support import import_helper +# Skip tests if _multiprocessing wasn't built. +import_helper.import_module('_multiprocessing') + +import importlib +import importlib.machinery +import unittest +import sys +import os +import os.path +import py_compile + +from test.support import os_helper +from test.support.script_helper import ( + make_pkg, make_script, make_zip_pkg, make_zip_script, + assert_python_ok) + +if support.PGO: + raise unittest.SkipTest("test is not helpful for PGO") + +# Look up which start methods are available to test +import multiprocessing +AVAILABLE_START_METHODS = set(multiprocessing.get_all_start_methods()) + +# Issue #22332: Skip tests if sem_open implementation is broken. +support.skip_if_broken_multiprocessing_synchronize() + +verbose = support.verbose + +test_source = """\ +# multiprocessing includes all sorts of shenanigans to make __main__ +# attributes accessible in the subprocess in a pickle compatible way. + +# We run the "doesn't work in the interactive interpreter" example from +# the docs to make sure it *does* work from an executed __main__, +# regardless of the invocation mechanism + +import sys +import time +from multiprocessing import Pool, set_start_method +from test import support + +# We use this __main__ defined function in the map call below in order to +# check that multiprocessing in correctly running the unguarded +# code in child processes and then making it available as __main__ +def f(x): + return x*x + +# Check explicit relative imports +if "check_sibling" in __file__: + # We're inside a package and not in a __main__.py file + # so make sure explicit relative imports work correctly + from . import sibling + +if __name__ == '__main__': + start_method = sys.argv[1] + set_start_method(start_method) + results = [] + with Pool(5) as pool: + pool.map_async(f, [1, 2, 3], callback=results.extend) + + # up to 1 min to report the results + for _ in support.sleeping_retry(support.LONG_TIMEOUT, + "Timed out waiting for results"): + if results: + break + + results.sort() + print(start_method, "->", results) + + pool.join() +""" + +test_source_main_skipped_in_children = """\ +# __main__.py files have an implied "if __name__ == '__main__'" so +# multiprocessing should always skip running them in child processes + +# This means we can't use __main__ defined functions in child processes, +# so we just use "int" as a passthrough operation below + +if __name__ != "__main__": + raise RuntimeError("Should only be called as __main__!") + +import sys +import time +from multiprocessing import Pool, set_start_method +from test import support + +start_method = sys.argv[1] +set_start_method(start_method) +results = [] +with Pool(5) as pool: + pool.map_async(int, [1, 4, 9], callback=results.extend) + # up to 1 min to report the results + for _ in support.sleeping_retry(support.LONG_TIMEOUT, + "Timed out waiting for results"): + if results: + break + +results.sort() +print(start_method, "->", results) + +pool.join() +""" + +# These helpers were copied from test_cmd_line_script & tweaked a bit... + +def _make_test_script(script_dir, script_basename, + source=test_source, omit_suffix=False): + to_return = make_script(script_dir, script_basename, + source, omit_suffix) + # Hack to check explicit relative imports + if script_basename == "check_sibling": + make_script(script_dir, "sibling", "") + importlib.invalidate_caches() + return to_return + +def _make_test_zip_pkg(zip_dir, zip_basename, pkg_name, script_basename, + source=test_source, depth=1): + to_return = make_zip_pkg(zip_dir, zip_basename, pkg_name, script_basename, + source, depth) + importlib.invalidate_caches() + return to_return + +# There's no easy way to pass the script directory in to get +# -m to work (avoiding that is the whole point of making +# directories and zipfiles executable!) +# So we fake it for testing purposes with a custom launch script +launch_source = """\ +import sys, os.path, runpy +sys.path.insert(0, %s) +runpy._run_module_as_main(%r) +""" + +def _make_launch_script(script_dir, script_basename, module_name, path=None): + if path is None: + path = "os.path.dirname(__file__)" + else: + path = repr(path) + source = launch_source % (path, module_name) + to_return = make_script(script_dir, script_basename, source) + importlib.invalidate_caches() + return to_return + +class MultiProcessingCmdLineMixin(): + maxDiff = None # Show full tracebacks on subprocess failure + + def setUp(self): + if self.start_method not in AVAILABLE_START_METHODS: + self.skipTest("%r start method not available" % self.start_method) + + def _check_output(self, script_name, exit_code, out, err): + if verbose > 1: + print("Output from test script %r:" % script_name) + print(repr(out)) + self.assertEqual(exit_code, 0) + self.assertEqual(err.decode('utf-8'), '') + expected_results = "%s -> [1, 4, 9]" % self.start_method + self.assertEqual(out.decode('utf-8').strip(), expected_results) + + def _check_script(self, script_name, *cmd_line_switches): + if not __debug__: + cmd_line_switches += ('-' + 'O' * sys.flags.optimize,) + run_args = cmd_line_switches + (script_name, self.start_method) + rc, out, err = assert_python_ok(*run_args, __isolated=False) + self._check_output(script_name, rc, out, err) + + def test_basic_script(self): + with os_helper.temp_dir() as script_dir: + script_name = _make_test_script(script_dir, 'script') + self._check_script(script_name) + + def test_basic_script_no_suffix(self): + with os_helper.temp_dir() as script_dir: + script_name = _make_test_script(script_dir, 'script', + omit_suffix=True) + self._check_script(script_name) + + def test_ipython_workaround(self): + # Some versions of the IPython launch script are missing the + # __name__ = "__main__" guard, and multiprocessing has long had + # a workaround for that case + # See https://github.com/ipython/ipython/issues/4698 + source = test_source_main_skipped_in_children + with os_helper.temp_dir() as script_dir: + script_name = _make_test_script(script_dir, 'ipython', + source=source) + self._check_script(script_name) + script_no_suffix = _make_test_script(script_dir, 'ipython', + source=source, + omit_suffix=True) + self._check_script(script_no_suffix) + + def test_script_compiled(self): + with os_helper.temp_dir() as script_dir: + script_name = _make_test_script(script_dir, 'script') + py_compile.compile(script_name, doraise=True) + os.remove(script_name) + pyc_file = import_helper.make_legacy_pyc(script_name) + self._check_script(pyc_file) + + def test_directory(self): + source = self.main_in_children_source + with os_helper.temp_dir() as script_dir: + script_name = _make_test_script(script_dir, '__main__', + source=source) + self._check_script(script_dir) + + def test_directory_compiled(self): + source = self.main_in_children_source + with os_helper.temp_dir() as script_dir: + script_name = _make_test_script(script_dir, '__main__', + source=source) + py_compile.compile(script_name, doraise=True) + os.remove(script_name) + pyc_file = import_helper.make_legacy_pyc(script_name) + self._check_script(script_dir) + + def test_zipfile(self): + source = self.main_in_children_source + with os_helper.temp_dir() as script_dir: + script_name = _make_test_script(script_dir, '__main__', + source=source) + zip_name, run_name = make_zip_script(script_dir, 'test_zip', script_name) + self._check_script(zip_name) + + def test_zipfile_compiled(self): + source = self.main_in_children_source + with os_helper.temp_dir() as script_dir: + script_name = _make_test_script(script_dir, '__main__', + source=source) + compiled_name = py_compile.compile(script_name, doraise=True) + zip_name, run_name = make_zip_script(script_dir, 'test_zip', compiled_name) + self._check_script(zip_name) + + def test_module_in_package(self): + with os_helper.temp_dir() as script_dir: + pkg_dir = os.path.join(script_dir, 'test_pkg') + make_pkg(pkg_dir) + script_name = _make_test_script(pkg_dir, 'check_sibling') + launch_name = _make_launch_script(script_dir, 'launch', + 'test_pkg.check_sibling') + self._check_script(launch_name) + + def test_module_in_package_in_zipfile(self): + with os_helper.temp_dir() as script_dir: + zip_name, run_name = _make_test_zip_pkg(script_dir, 'test_zip', 'test_pkg', 'script') + launch_name = _make_launch_script(script_dir, 'launch', 'test_pkg.script', zip_name) + self._check_script(launch_name) + + def test_module_in_subpackage_in_zipfile(self): + with os_helper.temp_dir() as script_dir: + zip_name, run_name = _make_test_zip_pkg(script_dir, 'test_zip', 'test_pkg', 'script', depth=2) + launch_name = _make_launch_script(script_dir, 'launch', 'test_pkg.test_pkg.script', zip_name) + self._check_script(launch_name) + + def test_package(self): + source = self.main_in_children_source + with os_helper.temp_dir() as script_dir: + pkg_dir = os.path.join(script_dir, 'test_pkg') + make_pkg(pkg_dir) + script_name = _make_test_script(pkg_dir, '__main__', + source=source) + launch_name = _make_launch_script(script_dir, 'launch', 'test_pkg') + self._check_script(launch_name) + + def test_package_compiled(self): + source = self.main_in_children_source + with os_helper.temp_dir() as script_dir: + pkg_dir = os.path.join(script_dir, 'test_pkg') + make_pkg(pkg_dir) + script_name = _make_test_script(pkg_dir, '__main__', + source=source) + compiled_name = py_compile.compile(script_name, doraise=True) + os.remove(script_name) + pyc_file = import_helper.make_legacy_pyc(script_name) + launch_name = _make_launch_script(script_dir, 'launch', 'test_pkg') + self._check_script(launch_name) + +# Test all supported start methods (setupClass skips as appropriate) + +class SpawnCmdLineTest(MultiProcessingCmdLineMixin, unittest.TestCase): + start_method = 'spawn' + main_in_children_source = test_source_main_skipped_in_children + +class ForkCmdLineTest(MultiProcessingCmdLineMixin, unittest.TestCase): + start_method = 'fork' + main_in_children_source = test_source + +class ForkServerCmdLineTest(MultiProcessingCmdLineMixin, unittest.TestCase): + start_method = 'forkserver' + main_in_children_source = test_source_main_skipped_in_children + +def tearDownModule(): + support.reap_children() + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_multiprocessing_spawn/__init__.py b/Lib/test/test_multiprocessing_spawn/__init__.py new file mode 100644 index 0000000000..3fd0f9b390 --- /dev/null +++ b/Lib/test/test_multiprocessing_spawn/__init__.py @@ -0,0 +1,9 @@ +import os.path +import unittest +from test import support + +if support.PGO: + raise unittest.SkipTest("test is not helpful for PGO") + +def load_tests(*args): + return support.load_package_tests(os.path.dirname(__file__), *args) diff --git a/Lib/test/test_multiprocessing_spawn/test_manager.py b/Lib/test/test_multiprocessing_spawn/test_manager.py new file mode 100644 index 0000000000..b40bea0bf6 --- /dev/null +++ b/Lib/test/test_multiprocessing_spawn/test_manager.py @@ -0,0 +1,7 @@ +import unittest +from test._test_multiprocessing import install_tests_in_module_dict + +install_tests_in_module_dict(globals(), 'spawn', only_type="manager") + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_multiprocessing_spawn/test_misc.py b/Lib/test/test_multiprocessing_spawn/test_misc.py new file mode 100644 index 0000000000..32f37c5cc8 --- /dev/null +++ b/Lib/test/test_multiprocessing_spawn/test_misc.py @@ -0,0 +1,7 @@ +import unittest +from test._test_multiprocessing import install_tests_in_module_dict + +install_tests_in_module_dict(globals(), 'spawn', exclude_types=True) + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_multiprocessing_spawn/test_processes.py b/Lib/test/test_multiprocessing_spawn/test_processes.py new file mode 100644 index 0000000000..af764b0d84 --- /dev/null +++ b/Lib/test/test_multiprocessing_spawn/test_processes.py @@ -0,0 +1,7 @@ +import unittest +from test._test_multiprocessing import install_tests_in_module_dict + +install_tests_in_module_dict(globals(), 'spawn', only_type="processes") + +if __name__ == '__main__': + unittest.main() diff --git a/Lib/test/test_multiprocessing_spawn/test_threads.py b/Lib/test/test_multiprocessing_spawn/test_threads.py new file mode 100644 index 0000000000..c1257749b9 --- /dev/null +++ b/Lib/test/test_multiprocessing_spawn/test_threads.py @@ -0,0 +1,7 @@ +import unittest +from test._test_multiprocessing import install_tests_in_module_dict + +install_tests_in_module_dict(globals(), 'spawn', only_type="threads") + +if __name__ == '__main__': + unittest.main()